Files
soul_diary/soul_diary/backend/database/service.py

149 lines
4.5 KiB
Python

import pathlib
import uuid
from contextlib import asynccontextmanager
from typing import Type
import bcrypt
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from facet import ServiceMixin
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from .models import Sense, Session, User
from .settings import DatabaseSettings
class DatabaseService(ServiceMixin):
    """Async database access layer for users, login sessions and senses.

    Owns a SQLAlchemy async engine and session factory, exposes Alembic
    migration helpers, and provides the queries for ``User``, ``Session``
    and ``Sense``.

    The data methods receive an ``AsyncSession`` and never commit it: they
    only execute queries or register objects for insert/delete. The caller
    owns the transaction boundary, for example through ``transaction()``.
    """

    def __init__(self, dsn: str):
        """Create the engine and session factory for the given ``dsn``."""
        self._dsn = dsn
        # pool_recycle=60: pooled connections older than 60 seconds are
        # replaced on checkout instead of being reused.
        self._engine = create_async_engine(self._dsn, pool_recycle=60)
        # expire_on_commit=False keeps already-loaded attributes readable
        # after the transaction commits, without another round trip.
        self._sessionmaker = async_sessionmaker(self._engine, expire_on_commit=False)

    def get_alembic_config(self) -> AlembicConfig:
        """Build an in-memory Alembic config pointing at the bundled migrations."""
        migrations_path = pathlib.Path(__file__).parent / "migrations"
        config = AlembicConfig()
        config.set_main_option("script_location", str(migrations_path))
        # Alembic options go through ConfigParser interpolation, so a literal
        # "%" in the DSN (e.g. a URL-encoded password) must be doubled.
        config.set_main_option("sqlalchemy.url", self._dsn.replace("%", "%%"))
        return config

    def get_models(self) -> list[Type[DeclarativeBase]]:
        """Return the model classes exposed by this service."""
        # NOTE(review): ``Session`` is not listed here although it is used
        # below; presumably intentional, confirm against the callers.
        return [User, Sense]

    @asynccontextmanager
    async def transaction(self):
        """Yield a new ``AsyncSession`` wrapped in a transaction.

        The transaction is committed when the ``async with`` body finishes
        normally and rolled back if it raises; the session is closed after.
        """
        async with self._sessionmaker() as session:
            async with session.begin():
                yield session

    def migrate(self):
        """Upgrade the database schema to the latest revision (``head``)."""
        alembic_command.upgrade(self.get_alembic_config(), "head")

    def rollback(self, revision: str | None = None):
        """Downgrade the schema to ``revision``.

        When ``revision`` is ``None`` or empty, ``"-1"`` is used, which
        Alembic interprets as one step back from the current revision.
        """
        revision = revision or "-1"
        alembic_command.downgrade(self.get_alembic_config(), revision)

    def show_migrations(self):
        """Print the migration history via Alembic's ``history`` command."""
        alembic_command.history(self.get_alembic_config())

    def create_migration(self, message: str | None = None):
        """Autogenerate a new migration revision with an optional message."""
        alembic_command.revision(
            self.get_alembic_config(),
            message=message,
            autogenerate=True,
        )

    async def create_user(self, session: AsyncSession, username: str, password: str) -> User:
        """Create a user with a bcrypt-hashed password and an initial session.

        Both objects are added to ``session`` but not flushed or committed.

        Returns:
            The new, still pending ``User``.
        """
        # NOTE(review): bcrypt hashing is synchronous and runs inside this
        # coroutine, so it occupies the event loop while it computes.
        hashed_password = bcrypt.hashpw(
            password.encode("utf-8"),
            bcrypt.gensalt(),
        ).decode("utf-8")

        user = User(username=username, password=hashed_password)
        user_session = Session(user=user)
        # NOTE(review): the session is linked from both sides (``user=`` above
        # and the append below); whether the append is redundant depends on
        # the relationship configuration in the models. Verify there.
        user.sessions.append(user_session)

        session.add_all([user, user_session])
        return user

    async def auth_user(
        self,
        session: AsyncSession,
        username: str,
        password: str,
    ) -> Session | None:
        """Check credentials and open a new login session.

        Returns:
            A new pending ``Session`` for the user, or ``None`` when the
            username is unknown or the password does not match.
        """
        query = select(User).where(User.username == username)
        result = await session.execute(query)
        user = result.scalar_one_or_none()

        # NOTE(review): an unknown user returns before any bcrypt work, so
        # the two failure paths do not take the same amount of time.
        if user is None:
            return None
        if not bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")):
            return None

        user_session = Session(user=user)
        session.add(user_session)
        return user_session

    async def logout_user(self, session: AsyncSession, user_session: Session):
        """Mark ``user_session`` for deletion in ``session``."""
        await session.delete(user_session)

    async def get_user_session(self, session: AsyncSession, token: str) -> Session | None:
        """Return the login session with the given ``token``, or ``None``."""
        query = select(Session).where(Session.token == token)
        result = await session.execute(query)
        return result.scalar_one_or_none()

    async def get_sense_list(
        self,
        session: AsyncSession,
        user: User,
        page: int = 1,
        limit: int = 10,
    ) -> list[Sense]:
        """Return one page of the user's senses, newest first.

        Args:
            session: Session used to run the query.
            user: Owner of the senses.
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size; negative values are treated as 0.

        Returns:
            Up to ``limit`` senses ordered by ``created_at`` descending.
        """
        # Clamp so the query never carries a negative LIMIT/OFFSET: some
        # backends reject it outright (aborting the surrounding transaction)
        # and others interpret a negative LIMIT as "no limit".
        page = max(page, 1)
        limit = max(limit, 0)

        query = (
            select(Sense)
            .where(Sense.user == user)
            .order_by(Sense.created_at.desc())
            .limit(limit)
            .offset((page - 1) * limit)
        )
        result = await session.execute(query)
        return list(result.scalars().all())

    async def create_sense(self, session: AsyncSession, user: User, data: str) -> Sense:
        """Create a sense owned by ``user`` and add it to ``session``."""
        sense = Sense(user=user, data=data)
        session.add(sense)
        return sense

    async def get_sense(self, session: AsyncSession, sense_id: uuid.UUID) -> Sense | None:
        """Return the sense with the given id, or ``None`` if there is none."""
        query = select(Sense).where(Sense.id == sense_id)
        result = await session.execute(query)
        return result.scalar_one_or_none()

    async def update_sense(self, session: AsyncSession, sense: Sense, data: str) -> Sense:
        """Replace the sense's ``data`` and register it with ``session``."""
        sense.data = data
        session.add(sense)
        return sense

    async def delete_sense(self, session: AsyncSession, sense: Sense):
        """Mark ``sense`` for deletion in ``session``."""
        await session.delete(sense)
def get_service() -> DatabaseService:
    """Build a ``DatabaseService`` from a freshly created ``DatabaseSettings``."""
    # The settings value is converted with str() before being handed over,
    # since the service constructor takes the DSN as a plain string.
    dsn = str(DatabaseSettings().dsn)
    return DatabaseService(dsn=dsn)