A kind of initial commit
This commit is contained in:
201
app/db/repositories.py
Normal file
201
app/db/repositories.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import (
|
||||
AuthState,
|
||||
PlaylistRun,
|
||||
PlaylistRunTrack,
|
||||
RecommendationHistory,
|
||||
SavedTrack,
|
||||
User,
|
||||
)
|
||||
from app.utils.time import ensure_utc
|
||||
|
||||
|
||||
class UserRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def get_or_create_by_chat(self, chat_id: str, username: str | None = None) -> User:
|
||||
user = await self.get_by_chat_id(chat_id)
|
||||
if user:
|
||||
if username and user.telegram_username != username:
|
||||
user.telegram_username = username
|
||||
return user
|
||||
user = User(telegram_chat_id=chat_id, telegram_username=username)
|
||||
self.session.add(user)
|
||||
await self.session.flush()
|
||||
return user
|
||||
|
||||
async def get_by_chat_id(self, chat_id: str) -> User | None:
|
||||
result = await self.session.execute(select(User).where(User.telegram_chat_id == chat_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, user_id: int) -> User | None:
|
||||
result = await self.session.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_active_connected_users(self) -> list[User]:
|
||||
result = await self.session.execute(
|
||||
select(User).where(User.is_active.is_(True), User.spotify_refresh_token.is_not(None))
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
class AuthStateRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def create(self, state: str, telegram_chat_id: str, expires_at: datetime) -> AuthState:
|
||||
row = AuthState(
|
||||
state=state,
|
||||
telegram_chat_id=telegram_chat_id,
|
||||
expires_at=expires_at,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.session.add(row)
|
||||
await self.session.flush()
|
||||
return row
|
||||
|
||||
async def pop_valid(self, state: str) -> AuthState | None:
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await self.session.execute(select(AuthState).where(AuthState.state == state))
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
await self.session.delete(row)
|
||||
if ensure_utc(row.expires_at) < now:
|
||||
return None
|
||||
return row
|
||||
|
||||
async def delete_expired(self) -> int:
|
||||
result = await self.session.execute(delete(AuthState).where(AuthState.expires_at < datetime.now(timezone.utc)))
|
||||
return result.rowcount or 0
|
||||
|
||||
|
||||
class SavedTrackRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def replace_for_user(self, user_id: int, tracks: Iterable[dict]) -> None:
|
||||
await self.session.execute(delete(SavedTrack).where(SavedTrack.user_id == user_id))
|
||||
for item in tracks:
|
||||
self.session.add(
|
||||
SavedTrack(
|
||||
user_id=user_id,
|
||||
spotify_track_id=item["id"],
|
||||
name=item["name"],
|
||||
artist_names=", ".join(item["artist_names"]),
|
||||
artist_ids_csv=",".join(item["artist_ids"]),
|
||||
album_name=item.get("album_name"),
|
||||
added_at=item.get("added_at"),
|
||||
popularity=item.get("popularity"),
|
||||
)
|
||||
)
|
||||
await self.session.flush()
|
||||
|
||||
async def list_for_user(self, user_id: int) -> list[SavedTrack]:
|
||||
result = await self.session.execute(select(SavedTrack).where(SavedTrack.user_id == user_id))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_for_user(self, user_id: int) -> int:
|
||||
result = await self.session.execute(select(func.count()).select_from(SavedTrack).where(SavedTrack.user_id == user_id))
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
class RecommendationHistoryRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def list_track_ids(self, user_id: int) -> set[str]:
|
||||
result = await self.session.execute(
|
||||
select(RecommendationHistory.spotify_track_id).where(RecommendationHistory.user_id == user_id)
|
||||
)
|
||||
return {row[0] for row in result.all()}
|
||||
|
||||
async def mark_tracks(self, user_id: int, track_ids: list[str]) -> None:
|
||||
if not track_ids:
|
||||
return
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await self.session.execute(select(RecommendationHistory).where(RecommendationHistory.user_id == user_id))
|
||||
existing = {row.spotify_track_id: row for row in result.scalars().all()}
|
||||
for track_id in track_ids:
|
||||
if track_id in existing:
|
||||
row = existing[track_id]
|
||||
row.last_recommended_at = now
|
||||
row.times_recommended += 1
|
||||
else:
|
||||
self.session.add(
|
||||
RecommendationHistory(
|
||||
user_id=user_id,
|
||||
spotify_track_id=track_id,
|
||||
first_recommended_at=now,
|
||||
last_recommended_at=now,
|
||||
times_recommended=1,
|
||||
)
|
||||
)
|
||||
await self.session.flush()
|
||||
|
||||
|
||||
class PlaylistRunRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def create_run(self, user_id: int, run_date: date, notes: str | None = None) -> PlaylistRun:
|
||||
run = PlaylistRun(user_id=user_id, run_date=run_date, status="running", notes=notes)
|
||||
self.session.add(run)
|
||||
await self.session.flush()
|
||||
return run
|
||||
|
||||
async def add_tracks(self, run_id: int, tracks: list[dict]) -> None:
|
||||
for idx, track in enumerate(tracks, start=1):
|
||||
self.session.add(
|
||||
PlaylistRunTrack(
|
||||
run_id=run_id,
|
||||
spotify_track_id=track["id"],
|
||||
name=track["name"],
|
||||
artist_names=", ".join(track["artist_names"]),
|
||||
source=track["source"],
|
||||
position=idx,
|
||||
is_new_to_bot=track.get("is_new_to_bot", True),
|
||||
)
|
||||
)
|
||||
await self.session.flush()
|
||||
|
||||
async def mark_success(
|
||||
self,
|
||||
run: PlaylistRun,
|
||||
*,
|
||||
playlist_id: str,
|
||||
playlist_name: str,
|
||||
playlist_url: str | None,
|
||||
total_tracks: int,
|
||||
new_tracks: int,
|
||||
reused_tracks: int,
|
||||
notes: str | None = None,
|
||||
) -> None:
|
||||
run.status = "success"
|
||||
run.playlist_id = playlist_id
|
||||
run.playlist_name = playlist_name
|
||||
run.playlist_url = playlist_url
|
||||
run.total_tracks = total_tracks
|
||||
run.new_tracks = new_tracks
|
||||
run.reused_tracks = reused_tracks
|
||||
run.notes = notes
|
||||
await self.session.flush()
|
||||
|
||||
async def mark_failed(self, run: PlaylistRun, message: str) -> None:
|
||||
run.status = "failed"
|
||||
run.notes = message
|
||||
await self.session.flush()
|
||||
|
||||
async def latest_for_user(self, user_id: int) -> PlaylistRun | None:
|
||||
result = await self.session.execute(
|
||||
select(PlaylistRun).where(PlaylistRun.user_id == user_id).order_by(PlaylistRun.created_at.desc()).limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
Reference in New Issue
Block a user