"""Daily playlist recommendation engine built from Spotify and (optionally) Last.fm signals."""

from __future__ import annotations

import logging
import math
import random
from collections import Counter, defaultdict
from datetime import date, timedelta
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.clients.lastfm import LastFmClient
from app.clients.spotify import SpotifyApiError, SpotifyClient
from app.config import Settings
from app.db.models import SavedTrack, User
from app.db.repositories import RecommendationHistoryRepository, SavedTrackRepository
from app.types import PlaylistBuildResult, TrackCandidate
from app.utils.text import normalize_track_signature
from app.utils.time import utcnow

logger = logging.getLogger(__name__)


class RecommendationEngine:
    """Build a daily playlist for a user from their Spotify library and listening history.

    Candidates are gathered from several sources (Spotify recommendations, artist top
    tracks, artist search, and Last.fm similarity when enabled), deduplicated, scored,
    and then selected so that a minimum share of the playlist consists of tracks the
    user has neither liked nor been recommended before.
    """

    def __init__(self, settings: Settings, spotify: SpotifyClient, lastfm: LastFmClient) -> None:
        self.settings = settings
        self.spotify = spotify
        self.lastfm = lastfm

    async def sync_saved_tracks(self, session: AsyncSession, user: User, access_token: str) -> list[SavedTrack]:
        """Refresh the user's stored Liked Songs from Spotify and return the stored rows.

        Fetches all saved tracks via ``spotify.get_saved_tracks_all``, replaces the user's
        rows with ``SavedTrackRepository.replace_for_user``, then re-reads them.
        """
        saved_tracks_repo = SavedTrackRepository(session)
        raw = await self.spotify.get_saved_tracks_all(access_token)
        await saved_tracks_repo.replace_for_user(user.id, raw)
        return await saved_tracks_repo.list_for_user(user.id)

    async def build_daily_playlist(self, session: AsyncSession, user: User, access_token: str) -> PlaylistBuildResult:
        """Assemble today's playlist for ``user``.

        Steps:
        1. Load stored saved tracks, syncing them from Spotify first if none are stored.
        2. Fetch tracks played within the last ``settings.recent_days_window`` days.
        3. Derive a seed profile, collect candidates, then rank and select them.

        The playlist size is ``user.playlist_size`` (falling back to
        ``settings.default_playlist_size`` when unset or 0) and is at least 1.
        ``min_new_ratio`` falls back to ``settings.min_new_ratio`` when the user's value
        is ``None``.
        """
        saved_tracks_repo = SavedTrackRepository(session)
        history_repo = RecommendationHistoryRepository(session)

        saved_rows = await saved_tracks_repo.list_for_user(user.id)
        if not saved_rows:
            saved_rows = await self.sync_saved_tracks(session, user, access_token)

        recent_since = utcnow() - timedelta(days=self.settings.recent_days_window)
        recent_plays = await self.spotify.get_recently_played(access_token, since=recent_since, max_pages=12)

        seed = self._build_seed_profile(saved_rows, recent_plays, user_id=user.id)
        history_ids = await history_repo.list_track_ids(user.id)
        liked_ids = {row.spotify_track_id for row in saved_rows}
        market = self._normalize_spotify_market(self.settings.spotify_default_market)

        candidates = await self._collect_candidates(
            access_token=access_token,
            seed=seed,
            market=market,
        )
        return self._rank_and_select(
            candidates=candidates,
            liked_ids=liked_ids,
            history_ids=history_ids,
            target_size=max(1, user.playlist_size or self.settings.default_playlist_size),
            min_new_ratio=user.min_new_ratio if user.min_new_ratio is not None else self.settings.min_new_ratio,
        )

    def _build_seed_profile(
        self, saved_rows: list[SavedTrack], recent_plays: list[dict[str, Any]], *, user_id: int
    ) -> dict[str, Any]:
        """Derive seed tracks and artists from recent plays and Liked Songs.

        Recent plays are weighted by recency: the newest play weighs 3.0, decaying by 0.04
        per position down to a floor of 1.0. Liked Songs add artist weight from the 120
        most recently added tracks plus a random sample of up to 180 older ones. The
        sample's RNG is seeded with the user id and today's date, so it is stable within
        a day.

        Returns a dict with keys:
            ``seed_track_ids``: up to 10 track ids (most-played recent tracks, topped up
                with recent likes).
            ``seed_artists``: up to 20 artist ids by accumulated weight.
            ``seed_artist_names``: names for those artist ids.
            ``recent_track_meta``: track id -> a play dict for that track.
        """
        today = date.today()
        rng = random.Random(f"{user_id}-{today.isoformat()}")
        # One shared fallback timestamp so undated plays/likes tie deterministically.
        now = utcnow()

        recent_track_counts: Counter[str] = Counter()
        recent_track_meta: dict[str, dict[str, Any]] = {}
        artist_weights: Counter[str] = Counter()
        artist_names: dict[str, str] = {}

        newest_first = sorted(recent_plays, key=lambda play: play.get("played_at") or now, reverse=True)
        for idx, play in enumerate(newest_first):
            track_id = play.get("id")
            if not track_id:
                continue
            weight = max(1.0, 3.0 - (idx * 0.04))
            recent_track_counts[track_id] += weight
            recent_track_meta[track_id] = play
            for artist_id, artist_name in zip(play.get("artist_ids", []), play.get("artist_names", [])):
                artist_weights[artist_id] += weight
                artist_names[artist_id] = artist_name

        sorted_saved = sorted(saved_rows, key=lambda row: row.added_at or now, reverse=True)
        recent_likes = sorted_saved[:120]
        older_likes = sorted_saved[120:]
        sampled_older = rng.sample(older_likes, k=min(180, len(older_likes))) if older_likes else []

        for idx, row in enumerate(recent_likes + sampled_older):
            # The 50 most recent likes weigh twice as much as the rest of the pool.
            base_weight = 1.2 if idx < 50 else 0.6
            artist_ids = [a for a in row.artist_ids_csv.split(",") if a]
            # NOTE(review): names are split on "," and zipped with ids, so an artist name that
            # itself contains a comma would shift the id/name pairing. Confirm against how
            # SavedTrackRepository stores these columns.
            artist_list = [x.strip() for x in row.artist_names.split(",") if x.strip()]
            for artist_id, artist_name in zip(artist_ids, artist_list):
                artist_weights[artist_id] += base_weight
                artist_names[artist_id] = artist_name

        seed_track_ids = [t for t, _ in recent_track_counts.most_common(10)]
        # Top up with the most recently liked tracks when listening history is thin.
        for row in recent_likes[:20]:
            if len(seed_track_ids) >= 10:
                break
            if row.spotify_track_id not in seed_track_ids:
                seed_track_ids.append(row.spotify_track_id)

        seed_artist_ids = [a for a, _ in artist_weights.most_common(20)]
        seed_artist_names = [artist_names[a] for a in seed_artist_ids if a in artist_names]

        return {
            "seed_track_ids": seed_track_ids,
            "seed_artists": seed_artist_ids,
            "seed_artist_names": seed_artist_names,
            "recent_track_meta": recent_track_meta,
        }

    async def _collect_candidates(
        self, *, access_token: str, seed: dict[str, Any], market: str | None
    ) -> list[TrackCandidate]:
        """Gather deduplicated track candidates from every available source.

        Sources, in order:
        - Spotify recommendations, over up to 4 seed batches.
        - Top tracks for up to 12 seed artists.
        - An artist-name search fallback, used when fewer than 40 candidates were found.
        - If Last.fm is enabled: tracks similar to seed tracks, and tracks by artists
          similar to seed artists.

        A failed call to any single source is logged and skipped, so one failing source
        does not abort the build.
        """
        by_id: dict[str, TrackCandidate] = {}
        sig_to_id: dict[str, str] = {}
        recent_track_meta = seed["recent_track_meta"]

        def upsert(candidate: TrackCandidate) -> None:
            """Add ``candidate``, merging it with an existing entry for the same id or signature.

            Different ids sharing a normalized name/artist signature count as duplicates,
            and only the higher-scoring one is kept. Seeing the same id again keeps the max
            score and accumulates sources and seed reasons.
            """
            sig = normalize_track_signature(candidate.name, candidate.artist_names)
            existing_id = sig_to_id.get(sig)
            if existing_id and existing_id != candidate.id and existing_id in by_id:
                if candidate.score <= by_id[existing_id].score:
                    return
                del by_id[existing_id]

            existing = by_id.get(candidate.id)
            if existing is not None:
                existing.score = max(existing.score, candidate.score)
                if candidate.source not in existing.source:
                    existing.source = f"{existing.source}+{candidate.source}"
                for reason in candidate.seed_reasons:
                    if reason not in existing.seed_reasons:
                        existing.seed_reasons.append(reason)
                return

            by_id[candidate.id] = candidate
            sig_to_id[sig] = candidate.id

        seed_tracks = list(seed["seed_track_ids"])
        seed_artists = list(seed["seed_artists"])
        # Artist top-tracks is always called with a concrete market (defaulting to US).
        top_tracks_market = market or "US"

        # The Spotify recommendations endpoint accepts at most 5 seeds in total. Each batch
        # takes 2 seed tracks and fills the remaining slots with seed artists.
        for batch_idx in range(4):
            track_start = batch_idx * 2
            artist_start = batch_idx * 3
            batch_seed_tracks = seed_tracks[track_start : track_start + 2]
            remaining_slots = max(0, 5 - len(batch_seed_tracks))
            batch_seed_artists = seed_artists[artist_start : artist_start + remaining_slots]
            if not batch_seed_tracks and not batch_seed_artists:
                continue
            try:
                rec_tracks = await self.spotify.get_recommendations(
                    access_token,
                    seed_tracks=batch_seed_tracks,
                    seed_artists=batch_seed_artists,
                    limit=100,
                    market=market,
                )
            except SpotifyApiError:
                logger.debug("Spotify recommendations failed for seed batch %d", batch_idx, exc_info=True)
                continue
            for raw in rec_tracks:
                cand = self._candidate_from_spotify_track(raw, source="spotify_recommendations", base_score=1.0)
                if cand is None:
                    continue
                if any(a in batch_seed_artists for a in cand.artist_ids):
                    cand.score += 0.08
                upsert(cand)

        for artist_id in seed_artists[:12]:
            try:
                top_tracks = await self.spotify.get_artist_top_tracks(
                    access_token, artist_id=artist_id, market=top_tracks_market
                )
            except SpotifyApiError:
                logger.debug("Spotify top tracks failed for artist %s", artist_id, exc_info=True)
                continue
            for raw in top_tracks:
                cand = self._candidate_from_spotify_track(raw, source="artist_top_tracks", base_score=0.68)
                if cand is None:
                    continue
                if artist_id in cand.artist_ids:
                    cand.score += 0.07
                upsert(cand)

        # Fallback for apps/accounts where top-tracks and recommendations endpoints are restricted.
        if len(by_id) < 40:
            for artist_name in seed.get("seed_artist_names", [])[:12]:
                if not artist_name:
                    continue
                try:
                    search_hits = await self.spotify.search_track(
                        access_token,
                        track_name="",
                        artist_name=artist_name,
                        market=market,
                    )
                except SpotifyApiError:
                    logger.debug("Spotify artist search failed for %r", artist_name, exc_info=True)
                    continue
                artist_key = artist_name.lower()
                for raw in search_hits[:3]:
                    cand = self._candidate_from_spotify_track(raw, source="spotify_search_artist", base_score=0.55)
                    if cand is None:
                        continue
                    if artist_key in {a.lower() for a in cand.artist_names}:
                        cand.score += 0.05
                    upsert(cand)

        if self.lastfm.enabled:
            for track_id in seed_tracks[:10]:
                meta = recent_track_meta.get(track_id)
                if not meta:
                    continue
                artist_name = (meta.get("artist_names") or [None])[0]
                track_name = meta.get("name")
                if not artist_name or not track_name:
                    continue
                try:
                    similars = await self.lastfm.track_similar(artist=artist_name, track=track_name, limit=10)
                except Exception:
                    # Last.fm enrichment is best-effort; never fail the build because of it.
                    logger.warning(
                        "Last.fm track_similar failed for %r / %r", artist_name, track_name, exc_info=True
                    )
                    similars = []
                for item in similars[:5]:
                    lf_name = item.get("name")
                    if not lf_name:
                        continue
                    lf_artist_obj = item.get("artist")
                    lf_artist = lf_artist_obj.get("name") if isinstance(lf_artist_obj, dict) else None
                    try:
                        search_hits = await self.spotify.search_track(
                            access_token, track_name=lf_name, artist_name=lf_artist, market=market
                        )
                    except SpotifyApiError:
                        logger.debug("Spotify search failed for %r / %r", lf_name, lf_artist, exc_info=True)
                        continue
                    for raw in search_hits[:1]:
                        cand = self._candidate_from_spotify_track(raw, source="lastfm_track_similar", base_score=0.9)
                        if cand is not None:
                            upsert(cand)

            for artist_name in seed.get("seed_artist_names", [])[:8]:
                if not artist_name:
                    continue
                try:
                    similars = await self.lastfm.artist_similar(artist=artist_name, limit=8)
                except Exception:
                    # Last.fm enrichment is best-effort; never fail the build because of it.
                    logger.warning("Last.fm artist_similar failed for %r", artist_name, exc_info=True)
                    similars = []
                for item in similars[:4]:
                    sim_artist = item.get("name")
                    if not sim_artist:
                        continue
                    try:
                        search_hits = await self.spotify.search_track(
                            access_token, track_name="", artist_name=sim_artist, market=market
                        )
                    except SpotifyApiError:
                        logger.debug("Spotify artist search failed for %r", sim_artist, exc_info=True)
                        continue
                    for raw in search_hits[:2]:
                        cand = self._candidate_from_spotify_track(raw, source="lastfm_artist_similar", base_score=0.78)
                        if cand is not None:
                            upsert(cand)

        return list(by_id.values())

    def _candidate_from_spotify_track(
        self, raw: dict[str, Any], *, source: str, base_score: float
    ) -> TrackCandidate | None:
        """Convert a Spotify track dict into a scored ``TrackCandidate``.

        Returns ``None`` when the track has no ``id``. When ``popularity`` is an int, the
        score gets a popularity adjustment. The adjustment peaks at +0.15 at popularity
        55, falls off linearly on either side, and is floored at -0.12.
        """
        track_id = raw.get("id")
        if not track_id:
            return None
        artists = raw.get("artists") or []
        artist_names = [a.get("name") or "Unknown" for a in artists]
        artist_ids = [a.get("id") for a in artists if a.get("id")]
        popularity = raw.get("popularity")
        score = base_score
        if isinstance(popularity, int):
            # Prefer mid-popularity a bit to avoid obvious mainstream repeats and totally obscure misses.
            score += max(-0.12, 0.15 - abs(popularity - 55) / 250)
        return TrackCandidate(
            id=track_id,
            uri=raw.get("uri") or f"spotify:track:{track_id}",
            name=raw.get("name") or "Unknown",
            artist_names=artist_names,
            artist_ids=artist_ids,
            popularity=popularity,
            source=source,
            score=score,
        )

    def _rank_and_select(
        self,
        *,
        candidates: list[TrackCandidate],
        liked_ids: set[str],
        history_ids: set[str],
        target_size: int,
        min_new_ratio: float,
    ) -> PlaylistBuildResult:
        """Score candidates and pick up to ``target_size`` tracks.

        A track is "new" when it is neither liked nor in recommendation history. At least
        ``ceil(target_size * min_new_ratio)`` new tracks are sought first, clamped to
        ``[0, target_size]``. Remaining slots are filled with more new tracks, then with
        reused ones. Each pass enforces a cumulative per-main-artist cap.

        Liked tracks are excluded unless every candidate is liked; in that case they are
        allowed back in, and the result carries an explanatory note. Candidate ``score``
        values are adjusted in place.
        """
        # round() first so float noise (e.g. 100 * 0.14 == 14.000000000000002) cannot
        # push ceil() one track too high; then clamp to a feasible count.
        min_new_required = min(target_size, max(0, math.ceil(round(target_size * min_new_ratio, 9))))

        def is_new(c: TrackCandidate) -> bool:
            return c.id not in history_ids and c.id not in liked_ids

        filtered = [c for c in candidates if c.id not in liked_ids]
        liked_fallback_used = False
        if not filtered and candidates:
            # If every discovered candidate is already liked, degrade gracefully instead of failing the run.
            filtered = list(candidates)
            liked_fallback_used = True

        for c in filtered:
            if c.id in liked_ids:
                c.score -= 0.35
            if c.id in history_ids:
                c.score -= 0.2
            if len(c.artist_ids) > 1:
                c.score += 0.03
            c.score += min(0.15, 0.01 * len(c.seed_reasons or ()))

        filtered.sort(key=lambda c: (c.score, c.popularity or 0), reverse=True)
        novel = [c for c in filtered if is_new(c)]
        reused = [c for c in filtered if not is_new(c)]

        selected: list[TrackCandidate] = []
        selected_ids: set[str] = set()
        artist_caps: defaultdict[str, int] = defaultdict(int)

        def try_take(pool: list[TrackCandidate], count: int, hard_artist_cap: int) -> None:
            """Append from ``pool`` until ``selected`` holds ``count`` tracks, honoring the artist cap."""
            for c in pool:
                if len(selected) >= count:
                    return
                if c.id in selected_ids:
                    continue
                main_artist = c.artist_ids[0] if c.artist_ids else f"name:{(c.artist_names or [''])[0].lower()}"
                if artist_caps[main_artist] >= hard_artist_cap:
                    continue
                artist_caps[main_artist] += 1
                selected.append(c)
                selected_ids.add(c.id)

        try_take(novel, min_new_required, hard_artist_cap=2)
        if sum(1 for c in selected if is_new(c)) < min_new_required:
            # Relax the artist cap to reach the new-track quota.
            try_take(novel, min_new_required, hard_artist_cap=4)
        try_take(novel, target_size, hard_artist_cap=3)
        try_take(reused, target_size, hard_artist_cap=2)
        if len(selected) < target_size:
            try_take(reused, target_size, hard_artist_cap=4)

        # Ensure seed_reasons is a list (never None) on every selected track.
        for c in selected:
            c.seed_reasons = c.seed_reasons or []

        new_count = sum(1 for c in selected if is_new(c))
        reused_count = len(selected) - new_count

        notes_parts: list[str] = []
        if liked_fallback_used:
            notes_parts.append("All discovered candidates were already in Liked Songs; allowed liked-track fallback.")
        if new_count < min_new_required:
            notes_parts.append(
                f"Not enough completely new tracks to satisfy {min_new_ratio:.0%} target "
                f"(got {new_count}/{target_size})."
            )
        notes = " ".join(notes_parts) if notes_parts else None

        return PlaylistBuildResult(
            tracks=selected,
            target_size=target_size,
            new_count=new_count,
            reused_count=reused_count,
            min_new_required=min_new_required,
            notes=notes,
        )

    @staticmethod
    def _normalize_spotify_market(value: str | None) -> str | None:
        """Return an upper-cased two-letter market code, or ``None`` if ``value`` is not one.

        Blank values and the shorthands EU/GLOBAL/WORLD/ALL map to ``None``, as does
        anything that is not exactly two letters.
        """
        if not value:
            return None
        market = value.strip().upper()
        if not market:
            return None
        # Common shorthand users put in .env, but Spotify APIs expect a country code.
        if market in {"EU", "GLOBAL", "WORLD", "ALL"}:
            return None
        if len(market) == 2 and market.isalpha():
            return market
        return None