154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
import contextvars
|
|
from contextlib import asynccontextmanager, contextmanager
|
|
from typing import AsyncGenerator, Generator, Generic, TypeVar
|
|
|
|
from pydantic_filters import BasePagination, BaseSort
|
|
from pydantic_filters.drivers.sqlalchemy import append_to_statement
|
|
from sqlalchemy import func
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlmodel import delete, insert, select, update
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from .exceptions import AlreadyExistsError, HaveNoSessionError
|
|
from .settings import RepositorySettings
|
|
from .tables import BaseSQLModel
|
|
|
|
|
|
# Type parameters of BaseRepository:
#   DTOType    - the item type returned by the repository methods (``to_item()`` results)
#   FilterType - the filter object accepted via the ``filter_`` parameters
DTOType, FilterType = TypeVar("DTOType"), TypeVar("FilterType")
|
|
|
|
|
|
class BaseRepository(Generic[DTOType, FilterType]):
    """Generic async repository over a single SQLModel table.

    Subclasses must override :meth:`get_db_table` to return the table model.
    Every data-access method reads the active session from a context variable
    (see :attr:`session`), so it must be called inside
    ``async with repo.transaction():``; otherwise :class:`HaveNoSessionError`
    is raised.
    """

    def __init__(self, settings: RepositorySettings):
        """Create the async engine and the session context variable.

        :param settings: repository settings; ``dsn``, ``pool_recycle`` and,
            for non-SQLite DSNs, ``pool_timeout`` and ``pool_size`` are read.
        """
        self._settings = settings

        dsn = str(self._settings.dsn)
        engine_parameters = {
            "url": dsn,
            "pool_recycle": self._settings.pool_recycle,
        }
        # Pool sizing options are only passed for non-SQLite DSNs.
        # NOTE(review): presumably because the pool SQLAlchemy selects for
        # SQLite does not accept them — confirm per dialect/driver.
        if not dsn.startswith("sqlite"):
            engine_parameters["pool_timeout"] = self._settings.pool_timeout
            engine_parameters["pool_size"] = self._settings.pool_size

        self._engine = create_async_engine(**engine_parameters)
        # Holds the session opened by ``transaction()`` for the current
        # execution context (each asyncio task sees its own value).
        self._session_context = contextvars.ContextVar("session_context")

    @property
    def session(self) -> AsyncSession:
        """Return the session bound by the enclosing ``transaction()``.

        :raises HaveNoSessionError: when no ``transaction()`` is active in the
            current context.
        """
        session = self._session_context.get(None)
        if session is None:
            raise HaveNoSessionError()

        return session

    @asynccontextmanager
    async def transaction(self) -> AsyncGenerator[None, None]:
        """Open a new session and transaction and expose it via :attr:`session`.

        ``session.begin()`` commits when the block exits normally and rolls
        back when it raises. ``expire_on_commit=False`` keeps loaded attributes
        readable after the commit.
        """
        session_args = {"bind": self._engine, "expire_on_commit": False}
        async with AsyncSession(**session_args) as session:
            with self._set_session_to_session_context(session=session):
                async with session.begin():
                    yield

    @contextmanager
    def _set_session_to_session_context(self, session: AsyncSession) -> Generator[None, None, None]:
        """Bind ``session`` to the context variable for the duration of the block."""
        token = self._session_context.set(session)
        try:
            yield
        finally:
            # Always restore the previous value, even if the block raised;
            # otherwise a closed session would remain visible via `session`.
            self._session_context.reset(token)

    def get_db_table(self) -> type[BaseSQLModel]:
        """Return the SQLModel table class this repository operates on.

        Must be overridden by subclasses.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    async def get_items_count(
        self,
        filter_: FilterType | None = None,
    ) -> int:
        """Count rows of the table matching ``filter_``.

        :param filter_: optional filter applied via ``append_to_statement``;
            when ``None`` every row of the table is counted.
        :returns: the row count.
        :raises HaveNoSessionError: when called outside ``transaction()``.
        """
        table = self.get_db_table()
        # The explicit FROM is required: without it, an unfiltered
        # ``SELECT count(*)`` has no table and always yields 1.
        statement = select(func.count()).select_from(table)
        statement = append_to_statement(
            statement=statement,
            model=table,
            filter_=filter_,
        )

        result = await self.session.exec(statement)
        items_count = result.first()

        return items_count

    async def get_items(
        self,
        filter_: FilterType | None = None,
        pagination: BasePagination | None = None,
        sort: BaseSort | None = None,
        options: list | None = None,
    ) -> AsyncGenerator[DTOType, None]:
        """Yield items from the table, converted with ``to_item()``.

        :param filter_: optional filter applied via ``append_to_statement``.
        :param pagination: optional pagination applied via ``append_to_statement``.
        :param sort: optional sort applied via ``append_to_statement``.
        :param options: optional loader options passed to ``statement.options``.
        :raises HaveNoSessionError: when called outside ``transaction()``.
        """
        table = self.get_db_table()
        statement = select(table)
        statement = append_to_statement(
            statement=statement,
            model=table,
            filter_=filter_,
            pagination=pagination,
            sort=sort,
        )
        if options:
            statement = statement.options(*options)

        result = await self.session.exec(statement)
        # All rows are fetched before the first item is yielded.
        db_items = result.all()

        for db_item in db_items:
            yield db_item.to_item()

    async def create_item(self, item: DTOType) -> DTOType:
        """Insert ``item`` and return the inserted row converted with ``to_item()``.

        :param item: the item to insert; converted with
            ``table.from_item(item=item).to_values()``.
        :returns: the row returned by ``INSERT ... RETURNING``, as an item.
        :raises AlreadyExistsError: when the insert raises ``IntegrityError``.
        :raises HaveNoSessionError: when called outside ``transaction()``.
        """
        table = self.get_db_table()
        values = table.from_item(item=item).to_values()
        statement = insert(table).values(values).returning(table)

        try:
            result = await self.session.exec(statement)
        except IntegrityError as exc:
            # Chain the original database error so the cause is preserved.
            raise AlreadyExistsError(model=table, values=values) from exc

        # Each row holds the returned entity in its first column.
        db_item = result.first()[0]

        return db_item.to_item()

    async def update_items(
        self,
        filter_: FilterType | None = None,
        **values,
    ) -> AsyncGenerator[DTOType, None]:
        """Update rows matching ``filter_`` and yield the updated items.

        :param filter_: optional filter applied via ``append_to_statement``;
            when ``None`` no WHERE clause is added by this method.
        :param values: column values passed to ``UPDATE ... SET``.
        :raises HaveNoSessionError: when called outside ``transaction()``.
        """
        table = self.get_db_table()
        statement = update(table)
        statement = append_to_statement(
            statement=statement,
            model=table,
            filter_=filter_,
        )
        statement = statement.values(**values).returning(table)

        result = await self.session.exec(statement)
        db_items = result.all()

        # Each row holds the returned entity in its first column.
        for db_item, *_ in db_items:
            yield db_item.to_item()

    async def delete_items(
        self,
        filter_: FilterType | None = None,
        pagination: BasePagination | None = None,
        sort: BaseSort | None = None,
    ) -> None:
        """Delete rows matching ``filter_``.

        :param filter_: optional filter applied via ``append_to_statement``;
            when ``None`` no WHERE clause is added by this method.
        :param pagination: forwarded to ``append_to_statement``.
        :param sort: forwarded to ``append_to_statement``.
        :raises HaveNoSessionError: when called outside ``transaction()``.

        NOTE(review): whether ORDER BY / LIMIT are accepted on DELETE depends on
        the database backend — confirm before passing pagination or sort.
        """
        table = self.get_db_table()
        statement = delete(table)
        statement = append_to_statement(
            statement=statement,
            model=table,
            filter_=filter_,
            pagination=pagination,
            sort=sort,
        )

        await self.session.exec(statement)