"""Repository classes that wrap an ``AsyncSession`` for the app's ORM models."""

from __future__ import annotations

from collections.abc import Iterable
from datetime import date, datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import (
    AuthState,
    PlaylistRun,
    PlaylistRunTrack,
    RecommendationHistory,
    SavedTrack,
    User,
)
from app.utils.time import ensure_utc


class UserRepository:
    """Lookups and creation of ``User`` rows."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def get_or_create_by_chat(self, chat_id: str, username: str | None = None) -> User:
        """Return the user for ``chat_id``, creating and flushing one if absent.

        When the user already exists and a different, non-empty ``username``
        is given, the stored ``telegram_username`` is updated on the object;
        this branch does not flush.
        """
        user = await self.get_by_chat_id(chat_id)
        if user is not None:
            if username and user.telegram_username != username:
                user.telegram_username = username
            return user

        user = User(telegram_chat_id=chat_id, telegram_username=username)
        self.session.add(user)
        await self.session.flush()
        return user

    async def get_by_chat_id(self, chat_id: str) -> User | None:
        """Return the user whose ``telegram_chat_id`` equals ``chat_id``, or ``None``."""
        result = await self.session.execute(
            select(User).where(User.telegram_chat_id == chat_id)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, user_id: int) -> User | None:
        """Return the user with primary id ``user_id``, or ``None``."""
        result = await self.session.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def list_active_connected_users(self) -> list[User]:
        """Return users with ``is_active`` true and a non-null ``spotify_refresh_token``."""
        result = await self.session.execute(
            select(User).where(
                User.is_active.is_(True),
                User.spotify_refresh_token.is_not(None),
            )
        )
        return list(result.scalars().all())


class AuthStateRepository:
    """Storage for ``AuthState`` rows keyed by their ``state`` value."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def create(self, state: str, telegram_chat_id: str, expires_at: datetime) -> AuthState:
        """Add and flush a new ``AuthState`` with ``created_at`` set to the current UTC time."""
        row = AuthState(
            state=state,
            telegram_chat_id=telegram_chat_id,
            expires_at=expires_at,
            created_at=datetime.now(timezone.utc),
        )
        self.session.add(row)
        await self.session.flush()
        return row

    async def pop_valid(self, state: str) -> AuthState | None:
        """Consume the ``AuthState`` matching ``state``.

        The row, if found, is always passed to ``session.delete`` (no flush is
        issued here). Returns the row only when it has not expired; returns
        ``None`` when no row matches or when it is already past ``expires_at``.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(AuthState).where(AuthState.state == state)
        )
        row = result.scalar_one_or_none()
        if row is None:
            return None

        # Delete before the expiry check so an expired state is removed too.
        await self.session.delete(row)
        if ensure_utc(row.expires_at) < now:
            return None
        return row

    async def delete_expired(self) -> int:
        """Delete rows whose ``expires_at`` is before now; return the reported row count (0 if falsy)."""
        result = await self.session.execute(
            delete(AuthState).where(AuthState.expires_at < datetime.now(timezone.utc))
        )
        return result.rowcount or 0


class SavedTrackRepository:
    """Storage for a user's ``SavedTrack`` rows."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def replace_for_user(self, user_id: int, tracks: Iterable[dict]) -> None:
        """Delete all saved tracks of ``user_id`` and insert ``tracks`` in their place.

        Each item must provide ``id``, ``name``, ``artist_names`` and
        ``artist_ids`` (the latter two are joined into strings); ``album_name``,
        ``added_at`` and ``popularity`` are optional. Flushes once at the end.
        """
        await self.session.execute(delete(SavedTrack).where(SavedTrack.user_id == user_id))
        for item in tracks:
            self.session.add(
                SavedTrack(
                    user_id=user_id,
                    spotify_track_id=item["id"],
                    name=item["name"],
                    artist_names=", ".join(item["artist_names"]),
                    artist_ids_csv=",".join(item["artist_ids"]),
                    album_name=item.get("album_name"),
                    added_at=item.get("added_at"),
                    popularity=item.get("popularity"),
                )
            )
        await self.session.flush()

    async def list_for_user(self, user_id: int) -> list[SavedTrack]:
        """Return every ``SavedTrack`` row belonging to ``user_id``."""
        result = await self.session.execute(
            select(SavedTrack).where(SavedTrack.user_id == user_id)
        )
        return list(result.scalars().all())

    async def count_for_user(self, user_id: int) -> int:
        """Return the number of ``SavedTrack`` rows belonging to ``user_id``."""
        result = await self.session.execute(
            select(func.count()).select_from(SavedTrack).where(SavedTrack.user_id == user_id)
        )
        return int(result.scalar_one())


class RecommendationHistoryRepository:
    """Storage for per-user ``RecommendationHistory`` rows."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def list_track_ids(self, user_id: int) -> set[str]:
        """Return the set of ``spotify_track_id`` values recorded for ``user_id``."""
        result = await self.session.execute(
            select(RecommendationHistory.spotify_track_id).where(
                RecommendationHistory.user_id == user_id
            )
        )
        return {row[0] for row in result.all()}

    async def mark_tracks(self, user_id: int, track_ids: list[str]) -> None:
        """Record that ``track_ids`` were recommended to ``user_id``.

        For every occurrence of a track id, the matching history row gets its
        ``last_recommended_at`` set to now and ``times_recommended`` incremented;
        a row is created (with ``times_recommended=1``) the first time an id is
        seen. Does nothing when ``track_ids`` is empty. Flushes once at the end.
        """
        if not track_ids:
            return

        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(RecommendationHistory).where(RecommendationHistory.user_id == user_id)
        )
        history_by_track = {row.spotify_track_id: row for row in result.scalars().all()}

        for track_id in track_ids:
            row = history_by_track.get(track_id)
            if row is not None:
                row.last_recommended_at = now
                row.times_recommended += 1
                continue

            row = RecommendationHistory(
                user_id=user_id,
                spotify_track_id=track_id,
                first_recommended_at=now,
                last_recommended_at=now,
                times_recommended=1,
            )
            self.session.add(row)
            # Remember the new row so a repeated id in ``track_ids`` updates it
            # instead of inserting a second row for the same (user, track).
            history_by_track[track_id] = row

        await self.session.flush()


class PlaylistRunRepository:
    """Storage for ``PlaylistRun`` rows and their ``PlaylistRunTrack`` children."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def create_run(self, user_id: int, run_date: date, notes: str | None = None) -> PlaylistRun:
        """Add and flush a new run for ``user_id`` with status ``"running"``."""
        run = PlaylistRun(user_id=user_id, run_date=run_date, status="running", notes=notes)
        self.session.add(run)
        await self.session.flush()
        return run

    async def add_tracks(self, run_id: int, tracks: list[dict]) -> None:
        """Add one ``PlaylistRunTrack`` per item, numbering positions from 1, then flush.

        Each item must provide ``id``, ``name``, ``artist_names`` and ``source``;
        ``is_new_to_bot`` defaults to ``True`` when missing.
        """
        for position, track in enumerate(tracks, start=1):
            self.session.add(
                PlaylistRunTrack(
                    run_id=run_id,
                    spotify_track_id=track["id"],
                    name=track["name"],
                    artist_names=", ".join(track["artist_names"]),
                    source=track["source"],
                    position=position,
                    is_new_to_bot=track.get("is_new_to_bot", True),
                )
            )
        await self.session.flush()

    async def mark_success(
        self,
        run: PlaylistRun,
        *,
        playlist_id: str,
        playlist_name: str,
        playlist_url: str | None,
        total_tracks: int,
        new_tracks: int,
        reused_tracks: int,
        notes: str | None = None,
    ) -> None:
        """Set ``run`` to status ``"success"``, store the playlist details and counts, then flush.

        ``notes`` always overwrites the run's existing notes, including with
        ``None`` when it is omitted.
        """
        run.status = "success"
        run.playlist_id = playlist_id
        run.playlist_name = playlist_name
        run.playlist_url = playlist_url
        run.total_tracks = total_tracks
        run.new_tracks = new_tracks
        run.reused_tracks = reused_tracks
        run.notes = notes
        await self.session.flush()

    async def mark_failed(self, run: PlaylistRun, message: str) -> None:
        """Set ``run`` to status ``"failed"``, store ``message`` in its notes, then flush."""
        run.status = "failed"
        run.notes = message
        await self.session.flush()

    async def latest_for_user(self, user_id: int) -> PlaylistRun | None:
        """Return the run of ``user_id`` with the greatest ``created_at``, or ``None``."""
        result = await self.session.execute(
            select(PlaylistRun)
            .where(PlaylistRun.user_id == user_id)
            .order_by(PlaylistRun.created_at.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()