202 lines
7.3 KiB
Python
202 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from datetime import date, datetime, timezone
|
|
|
|
from sqlalchemy import delete, func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import (
|
|
AuthState,
|
|
PlaylistRun,
|
|
PlaylistRunTrack,
|
|
RecommendationHistory,
|
|
SavedTrack,
|
|
User,
|
|
)
|
|
from app.utils.time import ensure_utc
|
|
|
|
|
|
class UserRepository:
    """Data-access helpers for :class:`User` rows bound to one async session.

    The session belongs to the caller: this class flushes when it needs
    database-assigned values but never commits.
    """

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def get_or_create_by_chat(self, chat_id: str, username: str | None = None) -> User:
        """Return the user for ``chat_id``, creating it if it does not exist yet.

        For an existing user, a non-empty ``username`` that differs from the
        stored one replaces it (the change is left pending, not flushed).
        A newly created user is flushed so database-assigned values such as
        its primary key are available to the caller.

        NOTE(review): lookup and insert are separate statements, so two
        concurrent calls for a brand-new ``chat_id`` can both miss the lookup;
        whether the second insert fails depends on a unique constraint on
        ``telegram_chat_id`` that is not visible here.
        """
        existing = await self.get_by_chat_id(chat_id)
        if not existing:
            created = User(telegram_chat_id=chat_id, telegram_username=username)
            self.session.add(created)
            await self.session.flush()
            return created

        if username and existing.telegram_username != username:
            existing.telegram_username = username
        return existing

    async def get_by_chat_id(self, chat_id: str) -> User | None:
        """Return the user whose ``telegram_chat_id`` equals ``chat_id``, or ``None``.

        Raises ``MultipleResultsFound`` (via ``scalar_one_or_none``) if more
        than one row matches.
        """
        stmt = select(User).where(User.telegram_chat_id == chat_id)
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_by_id(self, user_id: int) -> User | None:
        """Return the user with primary key ``user_id``, or ``None`` if absent."""
        stmt = select(User).where(User.id == user_id)
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def list_active_connected_users(self) -> list[User]:
        """Return every active user that has a stored Spotify refresh token."""
        stmt = select(User).where(
            User.is_active.is_(True),
            User.spotify_refresh_token.is_not(None),
        )
        rows = await self.session.execute(stmt)
        return list(rows.scalars())
|
|
|
|
|
|
class AuthStateRepository:
    """Persistence for expiring authorization ``state`` tokens (:class:`AuthState`).

    Presumably these back the OAuth ``state`` parameter used when a Telegram
    chat links its account. Confirm against the auth flow.
    """

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def create(self, state: str, telegram_chat_id: str, expires_at: datetime) -> AuthState:
        """Insert a new state token, stamped with the current UTC time, and flush it."""
        issued_at = datetime.now(timezone.utc)
        token = AuthState(
            created_at=issued_at,
            expires_at=expires_at,
            telegram_chat_id=telegram_chat_id,
            state=state,
        )
        self.session.add(token)
        await self.session.flush()
        return token

    async def pop_valid(self, state: str) -> AuthState | None:
        """Consume the token ``state`` and return it if it has not expired.

        The token is single-use: once found, it is marked for deletion
        whether or not it is still valid. ``None`` is returned when no such
        token exists or when it expired before this call.

        NOTE(review): the lookup and the delete are separate steps with no
        row lock, so two concurrent calls with the same ``state`` may both
        read the row before either delete is flushed.
        """
        checked_at = datetime.now(timezone.utc)
        lookup = select(AuthState).where(AuthState.state == state)
        token = (await self.session.execute(lookup)).scalar_one_or_none()
        if not token:
            return None

        await self.session.delete(token)
        # ensure_utc normalizes expires_at before comparing it with an aware
        # "now"; presumably the driver can return naive values.
        has_expired = ensure_utc(token.expires_at) < checked_at
        return None if has_expired else token

    async def delete_expired(self) -> int:
        """Bulk-delete tokens whose ``expires_at`` is in the past.

        Returns the driver-reported row count, or 0 when it is unavailable.
        """
        cutoff = datetime.now(timezone.utc)
        purge = delete(AuthState).where(AuthState.expires_at < cutoff)
        outcome = await self.session.execute(purge)
        return outcome.rowcount or 0
|
|
|
|
|
|
class SavedTrackRepository:
    """Stores a per-user snapshot of saved tracks as :class:`SavedTrack` rows."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def replace_for_user(self, user_id: int, tracks: Iterable[dict]) -> None:
        """Replace every stored saved track of ``user_id`` with ``tracks``.

        Existing rows are removed with a bulk DELETE, the new rows are added,
        and the session is flushed.

        Each item must supply ``id``, ``name``, ``artist_names`` and
        ``artist_ids``. The last two are iterables of strings, stored joined
        with ``", "`` and ``","`` respectively. ``album_name``, ``added_at``
        and ``popularity`` are optional and default to ``None``.
        """
        purge = delete(SavedTrack).where(SavedTrack.user_id == user_id)
        await self.session.execute(purge)

        # A lazy generator keeps the add-one-row-at-a-time behaviour of a loop.
        self.session.add_all(
            SavedTrack(
                user_id=user_id,
                spotify_track_id=item["id"],
                name=item["name"],
                artist_names=", ".join(item["artist_names"]),
                artist_ids_csv=",".join(item["artist_ids"]),
                album_name=item.get("album_name"),
                added_at=item.get("added_at"),
                popularity=item.get("popularity"),
            )
            for item in tracks
        )
        await self.session.flush()

    async def list_for_user(self, user_id: int) -> list[SavedTrack]:
        """Return all saved-track rows of ``user_id`` (no ordering is applied)."""
        query = select(SavedTrack).where(SavedTrack.user_id == user_id)
        rows = (await self.session.execute(query)).scalars().all()
        return list(rows)

    async def count_for_user(self, user_id: int) -> int:
        """Return how many saved-track rows ``user_id`` has."""
        query = select(func.count()).select_from(SavedTrack).where(SavedTrack.user_id == user_id)
        total = (await self.session.execute(query)).scalar_one()
        return int(total)
|
|
|
|
|
|
class RecommendationHistoryRepository:
    """Records which tracks have been recommended to each user, and how often."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def list_track_ids(self, user_id: int) -> set[str]:
        """Return the IDs of every track ever recommended to ``user_id``."""
        result = await self.session.execute(
            select(RecommendationHistory.spotify_track_id).where(RecommendationHistory.user_id == user_id)
        )
        return set(result.scalars().all())

    async def mark_tracks(self, user_id: int, track_ids: list[str]) -> None:
        """Record that ``track_ids`` were just recommended to ``user_id``.

        Tracks with an existing history row get ``last_recommended_at`` set
        to now and ``times_recommended`` incremented. Unseen tracks get a new
        row with ``times_recommended=1``. An ID repeated within ``track_ids``
        is counted once per occurrence but never produces a second row for
        the same (user, track) pair. An empty ``track_ids`` is a no-op.
        The session is flushed at the end.
        """
        if not track_ids:
            return

        now = datetime.now(timezone.utc)
        # Fetch only the rows this call can touch, not the whole history.
        result = await self.session.execute(
            select(RecommendationHistory).where(
                RecommendationHistory.user_id == user_id,
                RecommendationHistory.spotify_track_id.in_(sorted(set(track_ids))),
            )
        )
        history = {row.spotify_track_id: row for row in result.scalars().all()}

        for track_id in track_ids:
            row = history.get(track_id)
            if row is not None:
                row.last_recommended_at = now
                row.times_recommended += 1
                continue

            row = RecommendationHistory(
                user_id=user_id,
                spotify_track_id=track_id,
                first_recommended_at=now,
                last_recommended_at=now,
                times_recommended=1,
            )
            self.session.add(row)
            # Register the pending row so a repeated ID updates it instead of
            # inserting a duplicate (user_id, spotify_track_id) row.
            history[track_id] = row

        await self.session.flush()
|
|
|
|
|
|
class PlaylistRunRepository:
    """Persistence for playlist-generation runs and the tracks each run produced."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def create_run(self, user_id: int, run_date: date, notes: str | None = None) -> PlaylistRun:
        """Insert a run in the ``"running"`` state and flush it.

        Flushing lets the database assign generated values (e.g. the primary
        key) before the run is returned.
        """
        run = PlaylistRun(user_id=user_id, run_date=run_date, status="running", notes=notes)
        self.session.add(run)
        await self.session.flush()
        return run

    async def add_tracks(self, run_id: int, tracks: list[dict]) -> None:
        """Attach ``tracks`` to run ``run_id`` with 1-based positions in list order.

        Each track dict must provide ``id``, ``name``, ``artist_names``
        (an iterable of strings, stored joined with ``", "``) and ``source``.
        ``is_new_to_bot`` is optional and defaults to ``True``.
        """
        for idx, track in enumerate(tracks, start=1):
            self.session.add(
                PlaylistRunTrack(
                    run_id=run_id,
                    spotify_track_id=track["id"],
                    name=track["name"],
                    artist_names=", ".join(track["artist_names"]),
                    source=track["source"],
                    position=idx,
                    is_new_to_bot=track.get("is_new_to_bot", True),
                )
            )
        await self.session.flush()

    async def mark_success(
        self,
        run: PlaylistRun,
        *,
        playlist_id: str,
        playlist_name: str,
        playlist_url: str | None,
        total_tracks: int,
        new_tracks: int,
        reused_tracks: int,
        notes: str | None = None,
    ) -> None:
        """Mark ``run`` as ``"success"``, store the playlist details, and flush.

        ``notes`` always replaces the stored value, including any notes given
        to :meth:`create_run`. Omitting it clears them. Pass the existing
        notes explicitly to keep them.
        """
        run.status = "success"
        run.playlist_id = playlist_id
        run.playlist_name = playlist_name
        run.playlist_url = playlist_url
        run.total_tracks = total_tracks
        run.new_tracks = new_tracks
        run.reused_tracks = reused_tracks
        run.notes = notes
        await self.session.flush()

    async def mark_failed(self, run: PlaylistRun, message: str) -> None:
        """Mark ``run`` as ``"failed"``, store ``message`` in ``notes``, and flush."""
        run.status = "failed"
        run.notes = message
        await self.session.flush()

    async def latest_for_user(self, user_id: int) -> PlaylistRun | None:
        """Return the most recently created run of ``user_id``, or ``None``.

        ``created_at`` alone can tie (e.g. two runs stamped within the same
        timestamp resolution), which made the chosen row nondeterministic.
        A descending ``id`` is used as the tie-breaker. This assumes
        ``PlaylistRun`` has an increasing ``id`` primary key, which is what
        ``PlaylistRunTrack.run_id`` appears to reference; confirm in the
        models.
        """
        result = await self.session.execute(
            select(PlaylistRun)
            .where(PlaylistRun.user_id == user_id)
            .order_by(PlaylistRun.created_at.desc(), PlaylistRun.id.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()
|