16
loading...
This website collects cookies to deliver better user experience
Run `docker-compose up`
to launch everything. Next, create database.py:
# database.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os

# We're using postgres but you could use any other engine supported by
# SQLAlchemy. The URL can be overridden with the DATABASE_URL environment
# variable so real credentials do not have to live in source control; the
# fallback keeps the original docker-compose default.
SQLALCHEMY_DATABASE_URL = os.environ.get(
    "DATABASE_URL", "postgresql://test-fastapi:password@db/db"
)
engine = create_engine(SQLALCHEMY_DATABASE_URL)
# Sessions neither autocommit nor autoflush; callers commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class that the ORM models inherit from.
Base = declarative_base()
Now that we have a `Base`, to define the `Item`
model we simply create a file called models.py
and declare it using the `Base`
we've just configured:
# models.py
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from .database import Base
class Item(Base):
    """ORM model mapped to the ``items`` table."""

    __tablename__ = "items"

    # Every column carries an index -- don't forget to set one!
    id = Column(Integer, index=True, primary_key=True)
    title = Column(String, index=True)
    description = Column(String, index=True)
# schemas.py
from typing import Optional
from pydantic import BaseModel
class ItemBase(BaseModel):
    # Fields shared by every Item payload, both request and response.
    # Plain (unquoted) annotations: the previous string annotation
    # "Optional[str] = None" swallowed the default into an unparseable type.
    # Required item title.
    title: str
    # Optional free-text description; None when the client omits it.
    description: Optional[str] = None
class ItemCreate(ItemBase):
    # Request body for creating an item (used by POST /items/); it adds no
    # fields beyond ItemBase.
    pass
class Item(ItemBase):
    # Response model: the ItemBase fields plus the integer ``id``.
    id: int

    class Config:
        # pydantic v1 setting that lets this model be populated from ORM
        # objects via attribute access (not only from dicts), which is what
        # the endpoints return as their response_model.
        orm_mode = True
`Item`
instances:
# crud.py
from sqlalchemy import select
from sqlalchemy.orm import Session
from . import schemas
from .models import Item
def get_items(db: Session):
    """Return every row of the ``items`` table as a list of ``Item`` objects."""
    statement = select(Item)
    result = db.execute(statement)
    return result.scalars().all()
def create_item(db: Session, item: schemas.ItemCreate):
    """Persist a new ``Item`` built from *item* and return the stored row.

    The session is committed, so the row is written immediately.
    """
    fields = item.dict()
    new_item = Item(**fields)
    db.add(new_item)
    db.commit()
    # Reload the instance so database-assigned values (such as the primary
    # key) are available on the returned object.
    db.refresh(new_item)
    return new_item
`main.py`
file and add the following lines to it:
# main.py
from typing import List
from fastapi import Depends
from fastapi import FastAPI
from sqlalchemy.orm import Session
from . import crud
from . import models
from . import schemas
from .database import engine
from .database import SessionLocal
# Here we create all the tables directly in the app. In a real-life
# situation this would be handled by a migration tool like Alembic.
models.Base.metadata.create_all(bind=engine)

app = FastAPI()
# FastAPI dependency: hands one SessionLocal session to the endpoint.
def get_db():
    """Yield a database session and close it once the caller is done."""
    # Session's context manager calls close() on exit, even on error.
    with SessionLocal() as session:
        yield session
@app.post("/items/", response_model=schemas.Item)
def create_item(item: schemas.ItemCreate, db: Session = Depends(get_db)):
    # Delegate persistence to the CRUD layer; FastAPI serializes the returned
    # ORM object through the schemas.Item response model.
    new_item = crud.create_item(db, item)
    return new_item
`/docs`
endpoint. Another awesome feature of FastAPI! Create a `test_database.py` file
to write our tests, and add the following code to it:
# test_database.py
import os

# URL of the dedicated test database. It can be overridden with the
# TEST_DATABASE_URL environment variable; the fallback is the original
# docker-compose default.
SQLALCHEMY_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql://test-fastapi:password@db/test-fastapi-test",
)
engine = create_engine(SQLALCHEMY_DATABASE_URL)
# Session factory for the test database, configured like SessionLocal.
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def override_get_db():
    """Yield a Session connected to the test database and always close it."""
    # Create the session *before* entering ``try``: if construction fails,
    # ``finally`` must not touch an unbound ``db`` (UnboundLocalError would
    # hide the real error).
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


# Route every get_db dependency to the test database.
app.dependency_overrides[get_db] = override_get_db
# The override replaces the ``get_db`` dependency with ``override_get_db``,
# which yields a Session connected to our test database.
def test_post_items():
    """POST two items, then check from a separate session that both exist."""
    client = TestClient(app)
    for title in ("Item 1", "Item 2"):
        response = client.post("/items/", json={"title": title})
        assert response.status_code == 200

    # We grab another session to check if the items are created.
    # override_get_db is a generator function (calling it gives a generator,
    # not a Session), so open a session from the factory directly.
    db = TestingSessionLocal()
    try:
        items = crud.get_items(db)
        assert len(items) == 2
    finally:
        db.close()
# The following doesn't work: changes won't be rolled back!
# The endpoint calls crud.create_item, which calls db.commit(); that commits
# the transaction opened by db.begin() below, so the db.rollback() in
# ``finally`` has nothing left to undo.
# NOTE(review): db is assigned inside ``try``; if TestingSessionLocal() raised,
# ``finally`` would fail with UnboundLocalError.
def override_get_db():
    try:
        db = TestingSessionLocal()
        db.begin()
        yield db
    finally:
        db.rollback()
        db.close()


app.dependency_overrides[get_db] = override_get_db
def override_get_db():
    """Yield a Session whose work is discarded when the request finishes.

    Follows SQLAlchemy's "join a Session into an external transaction"
    recipe: the Session is bound to one connection on which an outer,
    non-ORM transaction is already open, and that outer transaction is
    rolled back on teardown, so commits made through the Session (e.g. by
    the CRUD layer) are undone too.
    """
    connection = engine.connect()
    # Begin a non-ORM transaction that owns everything the request writes.
    transaction = connection.begin()
    # Bind an individual Session to the connection.
    db = Session(bind=connection)
    try:
        yield db
    finally:
        # Runs even if the request raised, so the connection is never leaked.
        db.close()
        transaction.rollback()
        connection.close()


# Register this override; redefining the name alone would leave the previously
# registered function in place.
app.dependency_overrides[get_db] = override_get_db
@app.get("/items/", response_model=List[schemas.Item])
def read_items(db: Session = Depends(get_db)):
    # Return every stored item; each ORM row is serialized via schemas.Item.
    all_items = crud.get_items(db)
    return all_items
# conftest.py
@pytest.fixture(scope="session")
def db_engine():
    """Create the test database and its tables once per test session.

    Yields the Engine and disposes of its connection pool after the last test.
    """
    engine = create_engine(SQLALCHEMY_DATABASE_URL)
    # database_exists must be *called*: a bare function reference is always
    # truthy, so the database would never be created.
    # (presumably sqlalchemy_utils.database_exists / create_database -- confirm)
    if not database_exists(engine.url):
        create_database(engine.url)

    Base.metadata.create_all(bind=engine)
    yield engine
    engine.dispose()
@pytest.fixture(scope="function")
def db(db_engine):
    """Yield a Session whose changes are discarded after each test.

    Follows SQLAlchemy's "join a Session into an external transaction"
    recipe: the Session is bound to a connection on which an outer, non-ORM
    transaction has already begun, and that transaction is rolled back on
    teardown (undoing commits made through the Session as well).
    """
    connection = db_engine.connect()
    # Begin a non-ORM transaction that owns everything the test writes.
    transaction = connection.begin()
    # Bind an individual Session to the connection.
    session = Session(bind=connection)
    try:
        yield session
    finally:
        session.close()
        # Roll back the outer transaction explicitly rather than relying on
        # connection.close() to discard it.
        transaction.rollback()
        connection.close()
#
@pytest.fixture(scope="function")
def client(db):
    """Yield a TestClient whose get_db dependency returns this test's ``db``.

    Any override that was registered before the test is restored afterwards,
    so this test's session does not leak into later tests.
    """
    previous = app.dependency_overrides.get(get_db)
    app.dependency_overrides[get_db] = lambda: db
    try:
        with TestClient(app) as c:
            yield c
    finally:
        if previous is None:
            app.dependency_overrides.pop(get_db, None)
        else:
            app.dependency_overrides[get_db] = previous
# The ``db`` fixture rolls back the session after each test, and we can use
# it to seed the database. The dependency override also lives in a fixture,
# alongside the client, so each test that uses ``client`` gets an override
# bound to its own session.
@pytest.fixture
def items(db):
    """Seed two items through the per-test ``db`` session and return them."""
    return [
        create_item(db, schemas.ItemCreate(title="item 1")),
        create_item(db, schemas.ItemCreate(title="item 2")),
    ]
# test_database.py
def test_list_items(items, client):
    """GET /items/ returns exactly the two items seeded by ``items``."""
    # Use the route's canonical path; "/items" only works via a redirect.
    response = client.get("/items/")
    assert response.status_code == 200
    assert len(response.json()) == 2