A kind of initial commit
This commit is contained in:
0
app/services/__init__.py
Normal file
0
app/services/__init__.py
Normal file
15
app/services/app_services.py
Normal file
15
app/services/app_services.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.playlist_job import PlaylistJobService
|
||||
from app.services.recommendation import RecommendationEngine
|
||||
from app.services.spotify_auth import SpotifyAuthService
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppServices:
|
||||
auth: SpotifyAuthService
|
||||
recommendation: RecommendationEngine
|
||||
jobs: PlaylistJobService
|
||||
|
||||
149
app/services/playlist_job.py
Normal file
149
app/services/playlist_job.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import Settings
|
||||
from app.db.repositories import PlaylistRunRepository, RecommendationHistoryRepository, UserRepository
|
||||
from app.services.recommendation import RecommendationEngine
|
||||
from app.services.spotify_auth import SpotifyAuthService
|
||||
from app.types import PlaylistBuildResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobOutcome:
|
||||
user_id: int
|
||||
ok: bool
|
||||
message: str
|
||||
playlist_url: str | None = None
|
||||
|
||||
|
||||
class PlaylistJobService:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
auth_service: SpotifyAuthService,
|
||||
recommendation_engine: RecommendationEngine,
|
||||
generate_lock: asyncio.Lock,
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self.session_factory = session_factory
|
||||
self.auth_service = auth_service
|
||||
self.recommendation_engine = recommendation_engine
|
||||
self.generate_lock = generate_lock
|
||||
self._notify = None
|
||||
|
||||
def set_notifier(self, notifier) -> None:
|
||||
self._notify = notifier
|
||||
|
||||
async def generate_for_user(self, user_id: int, *, force: bool = False, notify: bool = True) -> JobOutcome:
|
||||
async with self.generate_lock:
|
||||
async with self.session_factory() as session:
|
||||
users = UserRepository(session)
|
||||
runs = PlaylistRunRepository(session)
|
||||
history = RecommendationHistoryRepository(session)
|
||||
|
||||
user = await users.get_by_id(user_id)
|
||||
if not user:
|
||||
return JobOutcome(user_id=user_id, ok=False, message="User not found")
|
||||
if not user.spotify_refresh_token:
|
||||
return JobOutcome(user_id=user_id, ok=False, message="Spotify is not connected")
|
||||
if not force and user.last_generated_date == date.today():
|
||||
latest = await runs.latest_for_user(user.id)
|
||||
return JobOutcome(
|
||||
user_id=user.id,
|
||||
ok=True,
|
||||
message="Already generated today",
|
||||
playlist_url=latest.playlist_url if latest else user.latest_playlist_url,
|
||||
)
|
||||
|
||||
run = await runs.create_run(user_id=user.id, run_date=date.today())
|
||||
try:
|
||||
access_token = await self.auth_service.ensure_valid_access_token(session, user)
|
||||
# Re-sync likes each run so new likes affect next day's picks.
|
||||
await self.recommendation_engine.sync_saved_tracks(session, user, access_token)
|
||||
build = await self.recommendation_engine.build_daily_playlist(session, user, access_token)
|
||||
if not build.tracks:
|
||||
raise RuntimeError("No candidate tracks found. Try listening more or widen sources.")
|
||||
playlist_meta = await self._create_spotify_playlist(session, user, access_token, build)
|
||||
|
||||
serialized_tracks = []
|
||||
for c in build.tracks:
|
||||
serialized_tracks.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"artist_names": c.artist_names,
|
||||
"source": c.source,
|
||||
"is_new_to_bot": True, # fixed below
|
||||
}
|
||||
)
|
||||
history_ids = await history.list_track_ids(user.id)
|
||||
for item in serialized_tracks:
|
||||
item["is_new_to_bot"] = item["id"] not in history_ids
|
||||
|
||||
await runs.add_tracks(run.id, serialized_tracks)
|
||||
await history.mark_tracks(user.id, [c.id for c in build.tracks])
|
||||
await runs.mark_success(
|
||||
run,
|
||||
playlist_id=playlist_meta["id"],
|
||||
playlist_name=playlist_meta["name"],
|
||||
playlist_url=playlist_meta.get("url"),
|
||||
total_tracks=len(build.tracks),
|
||||
new_tracks=build.new_count,
|
||||
reused_tracks=build.reused_count,
|
||||
notes=build.notes,
|
||||
)
|
||||
user.last_generated_date = date.today()
|
||||
user.latest_playlist_id = playlist_meta["id"]
|
||||
user.latest_playlist_url = playlist_meta.get("url")
|
||||
await session.commit()
|
||||
message = (
|
||||
f"Playlist ready: {playlist_meta['name']} ({len(build.tracks)} tracks, "
|
||||
f"new {build.new_count}/{len(build.tracks)})"
|
||||
)
|
||||
if notify and self._notify:
|
||||
await self._notify(user.telegram_chat_id, f"{message}\n{playlist_meta.get('url', '')}".strip())
|
||||
return JobOutcome(user_id=user.id, ok=True, message=message, playlist_url=playlist_meta.get("url"))
|
||||
except Exception as exc:
|
||||
await runs.mark_failed(run, str(exc))
|
||||
await session.commit()
|
||||
if notify and self._notify:
|
||||
await self._notify(user.telegram_chat_id, f"Playlist generation failed: {exc}")
|
||||
return JobOutcome(user_id=user.id, ok=False, message=str(exc))
|
||||
|
||||
async def generate_for_all_connected_users(self) -> list[JobOutcome]:
|
||||
async with self.session_factory() as session:
|
||||
users_repo = UserRepository(session)
|
||||
users = await users_repo.list_active_connected_users()
|
||||
outcomes: list[JobOutcome] = []
|
||||
for user in users:
|
||||
outcomes.append(await self.generate_for_user(user.id, notify=True))
|
||||
return outcomes
|
||||
|
||||
async def _create_spotify_playlist(
|
||||
self, session: AsyncSession, user, access_token: str, build: PlaylistBuildResult
|
||||
) -> dict[str, str | None]:
|
||||
public = self.settings.playlist_visibility.lower() == "public"
|
||||
name = f"Daily Vibe {date.today().isoformat()}"
|
||||
desc = (
|
||||
"Auto-generated from your recent listening + liked tracks. "
|
||||
f"New-to-bot: {build.new_count}/{len(build.tracks)}."
|
||||
)
|
||||
playlist = await self.auth_service.spotify.create_playlist(
|
||||
access_token,
|
||||
user_id=user.spotify_user_id,
|
||||
name=name,
|
||||
description=desc,
|
||||
public=public,
|
||||
)
|
||||
await self.auth_service.spotify.add_playlist_items(access_token, playlist["id"], [c.uri for c in build.tracks])
|
||||
return {
|
||||
"id": playlist["id"],
|
||||
"name": playlist["name"],
|
||||
"url": ((playlist.get("external_urls") or {}).get("spotify")),
|
||||
}
|
||||
374
app/services/recommendation.py
Normal file
374
app/services/recommendation.py
Normal file
@@ -0,0 +1,374 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import random
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.clients.lastfm import LastFmClient
|
||||
from app.clients.spotify import SpotifyApiError, SpotifyClient
|
||||
from app.config import Settings
|
||||
from app.db.models import SavedTrack, User
|
||||
from app.db.repositories import RecommendationHistoryRepository, SavedTrackRepository
|
||||
from app.types import PlaylistBuildResult, TrackCandidate
|
||||
from app.utils.text import normalize_track_signature
|
||||
from app.utils.time import utcnow
|
||||
|
||||
|
||||
class RecommendationEngine:
|
||||
def __init__(self, settings: Settings, spotify: SpotifyClient, lastfm: LastFmClient) -> None:
|
||||
self.settings = settings
|
||||
self.spotify = spotify
|
||||
self.lastfm = lastfm
|
||||
|
||||
async def sync_saved_tracks(self, session: AsyncSession, user: User, access_token: str) -> list[SavedTrack]:
|
||||
saved_tracks_repo = SavedTrackRepository(session)
|
||||
raw = await self.spotify.get_saved_tracks_all(access_token)
|
||||
await saved_tracks_repo.replace_for_user(user.id, raw)
|
||||
return await saved_tracks_repo.list_for_user(user.id)
|
||||
|
||||
async def build_daily_playlist(self, session: AsyncSession, user: User, access_token: str) -> PlaylistBuildResult:
|
||||
saved_tracks_repo = SavedTrackRepository(session)
|
||||
history_repo = RecommendationHistoryRepository(session)
|
||||
|
||||
saved_rows = await saved_tracks_repo.list_for_user(user.id)
|
||||
if not saved_rows:
|
||||
saved_rows = await self.sync_saved_tracks(session, user, access_token)
|
||||
|
||||
recent_since = utcnow() - timedelta(days=self.settings.recent_days_window)
|
||||
recent_plays = await self.spotify.get_recently_played(access_token, since=recent_since, max_pages=12)
|
||||
|
||||
seed = self._build_seed_profile(saved_rows, recent_plays, user_id=user.id)
|
||||
history_ids = await history_repo.list_track_ids(user.id)
|
||||
liked_ids = {row.spotify_track_id for row in saved_rows}
|
||||
|
||||
market = self._normalize_spotify_market(self.settings.spotify_default_market)
|
||||
candidates = await self._collect_candidates(
|
||||
access_token=access_token,
|
||||
seed=seed,
|
||||
market=market,
|
||||
)
|
||||
result = self._rank_and_select(
|
||||
candidates=candidates,
|
||||
liked_ids=liked_ids,
|
||||
history_ids=history_ids,
|
||||
target_size=max(1, user.playlist_size or self.settings.default_playlist_size),
|
||||
min_new_ratio=user.min_new_ratio if user.min_new_ratio is not None else self.settings.min_new_ratio,
|
||||
)
|
||||
return result
|
||||
|
||||
def _build_seed_profile(self, saved_rows: list[SavedTrack], recent_plays: list[dict[str, Any]], *, user_id: int) -> dict[str, Any]:
|
||||
today = date.today()
|
||||
rng = random.Random(f"{user_id}-{today.isoformat()}")
|
||||
|
||||
recent_track_counts: Counter[str] = Counter()
|
||||
recent_track_meta: dict[str, dict[str, Any]] = {}
|
||||
artist_weights: Counter[str] = Counter()
|
||||
artist_names: dict[str, str] = {}
|
||||
|
||||
for idx, play in enumerate(sorted(recent_plays, key=lambda x: x.get("played_at") or utcnow(), reverse=True)):
|
||||
track_id = play["id"]
|
||||
weight = max(1.0, 3.0 - (idx * 0.04))
|
||||
recent_track_counts[track_id] += weight
|
||||
recent_track_meta[track_id] = play
|
||||
for artist_id, artist_name in zip(play.get("artist_ids", []), play.get("artist_names", [])):
|
||||
artist_weights[artist_id] += weight
|
||||
artist_names[artist_id] = artist_name
|
||||
|
||||
sorted_saved = sorted(saved_rows, key=lambda x: x.added_at or utcnow(), reverse=True)
|
||||
recent_likes = sorted_saved[:120]
|
||||
sampled_older = rng.sample(sorted_saved[120:], k=min(180, max(0, len(sorted_saved) - 120))) if len(sorted_saved) > 120 else []
|
||||
exploration_pool = recent_likes + sampled_older
|
||||
|
||||
for idx, row in enumerate(exploration_pool):
|
||||
base_weight = 1.2 if idx < 50 else 0.6
|
||||
artist_ids = [a for a in row.artist_ids_csv.split(",") if a]
|
||||
artist_list = [x.strip() for x in row.artist_names.split(",") if x.strip()]
|
||||
for artist_id, artist_name in zip(artist_ids, artist_list):
|
||||
artist_weights[artist_id] += base_weight
|
||||
artist_names[artist_id] = artist_name
|
||||
|
||||
seed_track_ids = [t for t, _ in recent_track_counts.most_common(10)]
|
||||
if len(seed_track_ids) < 10:
|
||||
for row in recent_likes[:20]:
|
||||
if row.spotify_track_id not in seed_track_ids:
|
||||
seed_track_ids.append(row.spotify_track_id)
|
||||
if len(seed_track_ids) >= 10:
|
||||
break
|
||||
|
||||
seed_artist_ids = [a for a, _ in artist_weights.most_common(20)]
|
||||
seed_artist_names = [artist_names[a] for a in seed_artist_ids if a in artist_names]
|
||||
|
||||
return {
|
||||
"seed_track_ids": seed_track_ids,
|
||||
"seed_artists": seed_artist_ids,
|
||||
"seed_artist_names": seed_artist_names,
|
||||
"recent_track_meta": recent_track_meta,
|
||||
}
|
||||
|
||||
async def _collect_candidates(self, *, access_token: str, seed: dict[str, Any], market: str | None) -> list[TrackCandidate]:
|
||||
by_id: dict[str, TrackCandidate] = {}
|
||||
sig_to_id: dict[str, str] = {}
|
||||
source_count: Counter[str] = Counter()
|
||||
recent_track_meta = seed["recent_track_meta"]
|
||||
|
||||
def upsert(candidate: TrackCandidate) -> None:
|
||||
sig = normalize_track_signature(candidate.name, candidate.artist_names)
|
||||
existing_id = sig_to_id.get(sig)
|
||||
if existing_id and existing_id != candidate.id and existing_id in by_id:
|
||||
if candidate.score <= by_id[existing_id].score:
|
||||
return
|
||||
del by_id[existing_id]
|
||||
existing = by_id.get(candidate.id)
|
||||
if existing:
|
||||
existing.score = max(existing.score, candidate.score)
|
||||
if candidate.source not in existing.source:
|
||||
existing.source = f"{existing.source}+{candidate.source}"
|
||||
for reason in candidate.seed_reasons:
|
||||
if reason not in existing.seed_reasons:
|
||||
existing.seed_reasons.append(reason)
|
||||
return
|
||||
by_id[candidate.id] = candidate
|
||||
sig_to_id[sig] = candidate.id
|
||||
source_count[candidate.source] += 1
|
||||
|
||||
seed_tracks = list(seed["seed_track_ids"])
|
||||
seed_artists = list(seed["seed_artists"])
|
||||
top_tracks_market = market or "US"
|
||||
|
||||
for batch_idx in range(4):
|
||||
# Spotify recommendations endpoint supports max 5 total seeds.
|
||||
track_start = batch_idx * 2
|
||||
artist_start = batch_idx * 3
|
||||
batch_seed_tracks = seed_tracks[track_start : track_start + 2]
|
||||
remaining_slots = max(0, 5 - len(batch_seed_tracks))
|
||||
batch_seed_artists = seed_artists[artist_start : artist_start + remaining_slots]
|
||||
if not batch_seed_tracks and not batch_seed_artists:
|
||||
continue
|
||||
try:
|
||||
rec_tracks = await self.spotify.get_recommendations(
|
||||
access_token,
|
||||
seed_tracks=batch_seed_tracks,
|
||||
seed_artists=batch_seed_artists,
|
||||
limit=100,
|
||||
market=market,
|
||||
)
|
||||
except SpotifyApiError:
|
||||
rec_tracks = []
|
||||
for raw in rec_tracks:
|
||||
cand = self._candidate_from_spotify_track(raw, source="spotify_recommendations", base_score=1.0)
|
||||
if not cand:
|
||||
continue
|
||||
if any(a in batch_seed_artists for a in cand.artist_ids):
|
||||
cand.score += 0.08
|
||||
upsert(cand)
|
||||
|
||||
for artist_id in seed_artists[:12]:
|
||||
try:
|
||||
top_tracks = await self.spotify.get_artist_top_tracks(
|
||||
access_token, artist_id=artist_id, market=top_tracks_market
|
||||
)
|
||||
except SpotifyApiError:
|
||||
continue
|
||||
for raw in top_tracks:
|
||||
cand = self._candidate_from_spotify_track(raw, source="artist_top_tracks", base_score=0.68)
|
||||
if not cand:
|
||||
continue
|
||||
if artist_id in cand.artist_ids:
|
||||
cand.score += 0.07
|
||||
upsert(cand)
|
||||
|
||||
# Fallback for apps/accounts where top-tracks and recommendations endpoints are restricted.
|
||||
if len(by_id) < 40:
|
||||
for artist_name in seed.get("seed_artist_names", [])[:12]:
|
||||
if not artist_name:
|
||||
continue
|
||||
try:
|
||||
search_hits = await self.spotify.search_track(
|
||||
access_token,
|
||||
track_name="",
|
||||
artist_name=artist_name,
|
||||
market=market,
|
||||
)
|
||||
except SpotifyApiError:
|
||||
continue
|
||||
for raw in search_hits[:3]:
|
||||
cand = self._candidate_from_spotify_track(raw, source="spotify_search_artist", base_score=0.55)
|
||||
if not cand:
|
||||
continue
|
||||
if artist_name.lower() in {a.lower() for a in cand.artist_names}:
|
||||
cand.score += 0.05
|
||||
upsert(cand)
|
||||
|
||||
if self.lastfm.enabled:
|
||||
for track_id in seed_tracks[:10]:
|
||||
meta = recent_track_meta.get(track_id)
|
||||
if not meta:
|
||||
continue
|
||||
artist_name = (meta.get("artist_names") or [None])[0]
|
||||
if not artist_name:
|
||||
continue
|
||||
try:
|
||||
similars = await self.lastfm.track_similar(artist=artist_name, track=meta["name"], limit=10)
|
||||
except Exception:
|
||||
similars = []
|
||||
for item in similars[:5]:
|
||||
lf_name = item.get("name")
|
||||
lf_artist = (item.get("artist") or {}).get("name") if isinstance(item.get("artist"), dict) else None
|
||||
if not lf_name:
|
||||
continue
|
||||
try:
|
||||
search_hits = await self.spotify.search_track(
|
||||
access_token, track_name=lf_name, artist_name=lf_artist, market=market
|
||||
)
|
||||
except SpotifyApiError:
|
||||
search_hits = []
|
||||
for raw in search_hits[:1]:
|
||||
cand = self._candidate_from_spotify_track(raw, source="lastfm_track_similar", base_score=0.9)
|
||||
if cand:
|
||||
upsert(cand)
|
||||
|
||||
for artist_name in seed.get("seed_artist_names", [])[:8]:
|
||||
try:
|
||||
similars = await self.lastfm.artist_similar(artist=artist_name, limit=8)
|
||||
except Exception:
|
||||
similars = []
|
||||
for item in similars[:4]:
|
||||
sim_artist = item.get("name")
|
||||
if not sim_artist:
|
||||
continue
|
||||
try:
|
||||
search_hits = await self.spotify.search_track(
|
||||
access_token, track_name="", artist_name=sim_artist, market=market
|
||||
)
|
||||
except SpotifyApiError:
|
||||
search_hits = []
|
||||
for raw in search_hits[:2]:
|
||||
cand = self._candidate_from_spotify_track(raw, source="lastfm_artist_similar", base_score=0.78)
|
||||
if cand:
|
||||
upsert(cand)
|
||||
|
||||
return list(by_id.values())
|
||||
|
||||
def _candidate_from_spotify_track(self, raw: dict[str, Any], *, source: str, base_score: float) -> TrackCandidate | None:
|
||||
track_id = raw.get("id")
|
||||
if not track_id:
|
||||
return None
|
||||
artists = raw.get("artists") or []
|
||||
artist_names = [a.get("name") or "Unknown" for a in artists]
|
||||
artist_ids = [a.get("id") for a in artists if a.get("id")]
|
||||
popularity = raw.get("popularity")
|
||||
|
||||
score = base_score
|
||||
if isinstance(popularity, int):
|
||||
# Prefer mid-popularity a bit to avoid obvious mainstream repeats and totally obscure misses.
|
||||
score += max(-0.12, 0.15 - abs(popularity - 55) / 250)
|
||||
|
||||
return TrackCandidate(
|
||||
id=track_id,
|
||||
uri=raw.get("uri") or f"spotify:track:{track_id}",
|
||||
name=raw.get("name") or "Unknown",
|
||||
artist_names=artist_names,
|
||||
artist_ids=artist_ids,
|
||||
popularity=popularity,
|
||||
source=source,
|
||||
score=score,
|
||||
)
|
||||
|
||||
def _rank_and_select(
|
||||
self,
|
||||
*,
|
||||
candidates: list[TrackCandidate],
|
||||
liked_ids: set[str],
|
||||
history_ids: set[str],
|
||||
target_size: int,
|
||||
min_new_ratio: float,
|
||||
) -> PlaylistBuildResult:
|
||||
min_new_required = math.ceil(target_size * min_new_ratio)
|
||||
|
||||
filtered = [c for c in candidates if c.id not in liked_ids]
|
||||
liked_fallback_used = False
|
||||
if not filtered and candidates:
|
||||
# If every discovered candidate is already liked, degrade gracefully instead of failing the run.
|
||||
filtered = list(candidates)
|
||||
liked_fallback_used = True
|
||||
for c in filtered:
|
||||
if c.id in liked_ids:
|
||||
c.score -= 0.35
|
||||
if c.id in history_ids:
|
||||
c.score -= 0.2
|
||||
if len(c.artist_ids) > 1:
|
||||
c.score += 0.03
|
||||
c.score += min(0.15, 0.01 * len(c.seed_reasons))
|
||||
|
||||
filtered.sort(key=lambda c: (c.score, c.popularity or 0), reverse=True)
|
||||
|
||||
novel = [c for c in filtered if c.id not in history_ids and c.id not in liked_ids]
|
||||
reused = [c for c in filtered if c.id in history_ids or c.id in liked_ids]
|
||||
|
||||
selected: list[TrackCandidate] = []
|
||||
artist_caps: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
def try_take(pool: list[TrackCandidate], count: int, hard_artist_cap: int) -> None:
|
||||
for c in pool:
|
||||
if len(selected) >= count:
|
||||
return
|
||||
if any(s.id == c.id for s in selected):
|
||||
continue
|
||||
main_artist = c.artist_ids[0] if c.artist_ids else f"name:{(c.artist_names or [''])[0].lower()}"
|
||||
if artist_caps[main_artist] >= hard_artist_cap:
|
||||
continue
|
||||
artist_caps[main_artist] += 1
|
||||
selected.append(c)
|
||||
|
||||
try_take(novel, min_new_required, hard_artist_cap=2)
|
||||
if len([c for c in selected if c.id not in history_ids and c.id not in liked_ids]) < min_new_required:
|
||||
try_take(novel, min_new_required, hard_artist_cap=4)
|
||||
|
||||
try_take(novel, target_size, hard_artist_cap=3)
|
||||
try_take(reused, target_size, hard_artist_cap=2)
|
||||
if len(selected) < target_size:
|
||||
try_take(reused, target_size, hard_artist_cap=4)
|
||||
|
||||
# Mark novelty for persistence.
|
||||
for c in selected:
|
||||
c.seed_reasons = c.seed_reasons or []
|
||||
|
||||
new_count = len([c for c in selected if c.id not in history_ids and c.id not in liked_ids])
|
||||
reused_count = len(selected) - new_count
|
||||
|
||||
notes_parts: list[str] = []
|
||||
if liked_fallback_used:
|
||||
notes_parts.append("All discovered candidates were already in Liked Songs; allowed liked-track fallback.")
|
||||
if new_count < min_new_required:
|
||||
notes_parts.append(
|
||||
f"Not enough completely new tracks to satisfy {int(min_new_ratio * 100)}% target "
|
||||
f"(got {new_count}/{target_size})."
|
||||
)
|
||||
notes = " ".join(notes_parts) if notes_parts else None
|
||||
|
||||
return PlaylistBuildResult(
|
||||
tracks=selected,
|
||||
target_size=target_size,
|
||||
new_count=new_count,
|
||||
reused_count=reused_count,
|
||||
min_new_required=min_new_required,
|
||||
notes=notes,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_spotify_market(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
market = value.strip().upper()
|
||||
if not market:
|
||||
return None
|
||||
# Common shorthand users put in .env, but Spotify APIs expect a country code.
|
||||
if market in {"EU", "GLOBAL", "WORLD", "ALL"}:
|
||||
return None
|
||||
if len(market) == 2 and market.isalpha():
|
||||
return market
|
||||
return None
|
||||
86
app/services/spotify_auth.py
Normal file
86
app/services/spotify_auth.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.clients.spotify import SpotifyClient
|
||||
from app.config import Settings
|
||||
from app.db.repositories import AuthStateRepository, UserRepository
|
||||
from app.utils.time import ensure_utc, utcnow
|
||||
|
||||
|
||||
SPOTIFY_SCOPES = [
|
||||
"user-library-read",
|
||||
"user-read-recently-played",
|
||||
"playlist-modify-private",
|
||||
"playlist-modify-public",
|
||||
]
|
||||
|
||||
|
||||
class SpotifyAuthService:
|
||||
def __init__(self, settings: Settings, spotify: SpotifyClient, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self.settings = settings
|
||||
self.spotify = spotify
|
||||
self.session_factory = session_factory
|
||||
|
||||
async def create_connect_url(self, chat_id: str, username: str | None = None) -> str:
|
||||
async with self.session_factory() as session:
|
||||
users = UserRepository(session)
|
||||
states = AuthStateRepository(session)
|
||||
user = await users.get_or_create_by_chat(chat_id=chat_id, username=username)
|
||||
user.timezone = user.timezone or self.settings.app_timezone
|
||||
state = secrets.token_urlsafe(24)
|
||||
await states.delete_expired()
|
||||
await states.create(state=state, telegram_chat_id=user.telegram_chat_id, expires_at=utcnow() + timedelta(minutes=15))
|
||||
await session.commit()
|
||||
return self.spotify.build_authorize_url(state=state, scopes=SPOTIFY_SCOPES)
|
||||
|
||||
async def handle_callback(self, code: str, state: str) -> tuple[str, str]:
|
||||
async with self.session_factory() as session:
|
||||
users = UserRepository(session)
|
||||
states = AuthStateRepository(session)
|
||||
auth_state = await states.pop_valid(state)
|
||||
if not auth_state:
|
||||
await session.commit()
|
||||
raise ValueError("OAuth state is invalid or expired")
|
||||
|
||||
user = await users.get_by_chat_id(auth_state.telegram_chat_id)
|
||||
if not user:
|
||||
raise ValueError("User not found for auth state")
|
||||
|
||||
token_payload = await self.spotify.exchange_code(code)
|
||||
access_token = token_payload["access_token"]
|
||||
me = await self.spotify.get_current_user(access_token)
|
||||
|
||||
user.spotify_user_id = me.get("id")
|
||||
user.spotify_access_token = access_token
|
||||
user.spotify_refresh_token = token_payload.get("refresh_token") or user.spotify_refresh_token
|
||||
user.spotify_token_expires_at = self.spotify.token_expiry_from_response(token_payload)
|
||||
user.spotify_scopes = token_payload.get("scope")
|
||||
user.is_active = True
|
||||
if not user.timezone:
|
||||
user.timezone = self.settings.app_timezone
|
||||
|
||||
await session.commit()
|
||||
return user.telegram_chat_id, me.get("display_name") or me.get("id") or "Spotify user"
|
||||
|
||||
async def ensure_valid_access_token(self, session: AsyncSession, user) -> str:
|
||||
if (
|
||||
user.spotify_access_token
|
||||
and user.spotify_token_expires_at
|
||||
and ensure_utc(user.spotify_token_expires_at) > utcnow()
|
||||
):
|
||||
return user.spotify_access_token
|
||||
if not user.spotify_refresh_token:
|
||||
raise RuntimeError("User is not connected to Spotify")
|
||||
token_payload = await self.spotify.refresh_access_token(user.spotify_refresh_token)
|
||||
user.spotify_access_token = token_payload["access_token"]
|
||||
if token_payload.get("refresh_token"):
|
||||
user.spotify_refresh_token = token_payload["refresh_token"]
|
||||
user.spotify_token_expires_at = self.spotify.token_expiry_from_response(token_payload)
|
||||
if token_payload.get("scope"):
|
||||
user.spotify_scopes = token_payload["scope"]
|
||||
await session.flush()
|
||||
return user.spotify_access_token
|
||||
Reference in New Issue
Block a user