Как написать контрольные приборы SQLAlchemy для приложений FastAPI
Я пишу приложение FastAPI, использующее базу данных SQLAlchemy. Я скопировал пример из документации FastAPI, упрощая схему базы данных для краткости. Полный исходный код находится внизу этого поста.
Это работает. Я могу запустить это сuvicorn sql_app.main:app
и взаимодействовать с базой данных через документацию Swagger. Когда он запускается, он создаетtest.db
в рабочем каталоге.
Теперь я хочу добавить модульный тест. Что-то вроде этого.
from fastapi import status
from fastapi.testclient import TestClient
from pytest import fixture
from main import app
@fixture
def client() -> TestClient:
return TestClient(app)
def test_fast_sql(client: TestClient):
response = client.get("/users/")
assert response.status_code == status.HTTP_200_OK
assert response.json() == []
Используя исходный код ниже, это требует test.db
в рабочем каталоге как база данных. Вместо этого я хочу создать новую базу данных для каждого модульного теста, который удаляется в конце теста.
Я мог бы поставить глобальный database.engine
а также database.SessionLocal
внутри объекта, который создается во время выполнения, например:
class UserDatabase:
def __init__(self, directory: Path):
directory.mkdir(exist_ok=True, parents=True)
sqlalchemy_database_url = f"sqlite:///{directory}/store.db"
self.engine = create_engine(
sqlalchemy_database_url, connect_args={"check_same_thread": False}
)
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
models.Base.metadata.create_all(bind=self.engine)
но я не знаю, как заставить это работать с main.get_db
, поскольку Depends(get_db)
логика в конечном итоге предполагает database.engine
а также database.SessionLocal
доступны по всему миру.
Я привык работать с Flask, чьи средства модульного тестирования делают все это за вас. Сам не знаю, как это написать. Может ли кто-нибудь показать мне минимальные изменения, которые мне нужно было бы внести, чтобы создать новую базу данных для каждого модульного теста в этой структуре?
Полный исходный код упрощенного приложения FastAPI/SQLAlchemy выглядит следующим образом.
database.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
models.py
from sqlalchemy import Column, Integer, String
from database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
name = Column(String)
age = Column(Integer)
schemas.py
from pydantic import BaseModel
class UserBase(BaseModel):
name: str
age: int
class UserCreate(UserBase):
pass
class User(UserBase):
id: int
class Config:
orm_mode = True
crud.py
from sqlalchemy.orm import Session
import schemas
import models
def get_user(db: Session, user_id: int):
return db.query(models.User).filter(models.User.id == user_id).first()
def get_users(db: Session, skip: int = 0, limit: int = 100):
return db.query(models.User).offset(skip).limit(limit).all()
def create_user(db: Session, user: schemas.UserCreate):
db_user = models.User(name=user.name, age=user.age)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
main.py
from typing import List
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session
import schemas
import models
import crud
from database import SessionLocal, engine
models.Base.metadata.create_all(bind=engine)
app = FastAPI()
# Dependency
def get_db():
try:
db = SessionLocal()
yield db
finally:
db.close()
@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
return crud.create_user(db=db, user=user)
@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
users = crud.get_users(db, skip=skip, limit=limit)
return users
@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)):
db_user = crud.get_user(db, user_id=user_id)
if db_user is None:
raise HTTPException(status_code=404, detail="User not found")
return db_user
1 ответ
Вам нужно переопределить ваш get_db
зависимости в ваших тестах, см. эти документы.
Что-то вроде этого для вашего приспособления:
@fixture
def db_fixture() -> Session:
raise NotImplementError() # Make this return you temporary session
@fixture
def client(db_fixture) -> TestClient:
def _get_db_override():
return db_fixture
app.dependency_overrides[get_db] = _get_db_override
return TestClient(app)