375 lines
16 KiB
Python
375 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
import random
|
|
from collections import Counter, defaultdict
|
|
from datetime import date, timedelta
|
|
from typing import Any
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.clients.lastfm import LastFmClient
|
|
from app.clients.spotify import SpotifyApiError, SpotifyClient
|
|
from app.config import Settings
|
|
from app.db.models import SavedTrack, User
|
|
from app.db.repositories import RecommendationHistoryRepository, SavedTrackRepository
|
|
from app.types import PlaylistBuildResult, TrackCandidate
|
|
from app.utils.text import normalize_track_signature
|
|
from app.utils.time import utcnow
|
|
|
|
|
|
class RecommendationEngine:
    """Build a per-user daily playlist from Spotify data, optionally enriched by Last.fm.

    ``build_daily_playlist`` runs the pipeline:

    1. Load (or sync) the user's saved tracks.
    2. Derive a seed profile from recent plays and saved tracks.
    3. Gather candidate tracks from several Spotify / Last.fm sources.
    4. Rank the candidates and select a playlist that favours tracks that are neither
       in the user's saved tracks nor in their recommendation history.
    """

    def __init__(self, settings: Settings, spotify: SpotifyClient, lastfm: LastFmClient) -> None:
        self.settings = settings
        self.spotify = spotify
        self.lastfm = lastfm

    async def sync_saved_tracks(self, session: AsyncSession, user: User, access_token: str) -> list[SavedTrack]:
        """Replace the user's stored saved tracks with everything Spotify currently returns.

        Returns:
            The user's saved-track rows as read back from the repository after the replace.
        """
        saved_tracks_repo = SavedTrackRepository(session)
        raw = await self.spotify.get_saved_tracks_all(access_token)
        await saved_tracks_repo.replace_for_user(user.id, raw)
        return await saved_tracks_repo.list_for_user(user.id)

    async def build_daily_playlist(self, session: AsyncSession, user: User, access_token: str) -> PlaylistBuildResult:
        """Build today's playlist for ``user``.

        Saved tracks are synced from Spotify only when none are stored for the user yet.
        Plays from the last ``settings.recent_days_window`` days feed the seed profile.
        The per-user ``playlist_size`` / ``min_new_ratio`` override the settings defaults
        when set.
        """
        saved_tracks_repo = SavedTrackRepository(session)
        history_repo = RecommendationHistoryRepository(session)

        saved_rows = await saved_tracks_repo.list_for_user(user.id)
        if not saved_rows:
            saved_rows = await self.sync_saved_tracks(session, user, access_token)

        recent_since = utcnow() - timedelta(days=self.settings.recent_days_window)
        recent_plays = await self.spotify.get_recently_played(access_token, since=recent_since, max_pages=12)

        seed = self._build_seed_profile(saved_rows, recent_plays, user_id=user.id)
        # set() gives O(1) membership tests in _rank_and_select, whatever collection the repo returns.
        history_ids = set(await history_repo.list_track_ids(user.id))
        liked_ids = {row.spotify_track_id for row in saved_rows}

        market = self._normalize_spotify_market(self.settings.spotify_default_market)
        candidates = await self._collect_candidates(
            access_token=access_token,
            seed=seed,
            market=market,
        )

        min_new_ratio = user.min_new_ratio if user.min_new_ratio is not None else self.settings.min_new_ratio
        return self._rank_and_select(
            candidates=candidates,
            liked_ids=liked_ids,
            history_ids=history_ids,
            target_size=max(1, user.playlist_size or self.settings.default_playlist_size),
            min_new_ratio=min_new_ratio,
        )

    def _build_seed_profile(self, saved_rows: list[SavedTrack], recent_plays: list[dict[str, Any]], *, user_id: int) -> dict[str, Any]:
        """Derive seed tracks and artists from recent listening and saved tracks.

        Recent plays are weighted by recency: the newest play gets 3.0, and each older
        position loses 0.04, down to a floor of 1.0. Saved tracks add artist weight. The
        120 most recently added are always used. Up to 180 older ones are sampled with an
        RNG seeded by the user id and today's date.

        Returns:
            A dict with keys:
            ``seed_track_ids``
                Up to 10 track ids: the most-played recent tracks, topped up from the
                newest saved tracks.
            ``seed_artists``
                Up to 20 artist ids ordered by accumulated weight.
            ``seed_artist_names``
                Names for those artist ids, where known.
            ``recent_track_meta``
                Track id -> a play dict for that track.
        """
        today = date.today()
        rng = random.Random(f"{user_id}-{today.isoformat()}")

        recent_track_counts: Counter[str] = Counter()
        recent_track_meta: dict[str, dict[str, Any]] = {}
        artist_weights: Counter[str] = Counter()
        artist_names: dict[str, str] = {}

        plays_newest_first = sorted(recent_plays, key=lambda x: x.get("played_at") or utcnow(), reverse=True)
        for idx, play in enumerate(plays_newest_first):
            track_id = play["id"]
            weight = max(1.0, 3.0 - (idx * 0.04))
            recent_track_counts[track_id] += weight
            recent_track_meta[track_id] = play
            for artist_id, artist_name in zip(play.get("artist_ids") or [], play.get("artist_names") or []):
                artist_weights[artist_id] += weight
                artist_names[artist_id] = artist_name

        sorted_saved = sorted(saved_rows, key=lambda x: x.added_at or utcnow(), reverse=True)
        recent_likes = sorted_saved[:120]
        older_likes = sorted_saved[120:]
        sampled_older = rng.sample(older_likes, k=min(180, len(older_likes))) if older_likes else []
        exploration_pool = recent_likes + sampled_older

        for idx, row in enumerate(exploration_pool):
            # The first 50 pool entries (newest saved tracks) weigh twice as much as the rest.
            base_weight = 1.2 if idx < 50 else 0.6
            artist_ids = [a for a in (row.artist_ids_csv or "").split(",") if a]
            # NOTE(review): names are split on "," — an artist whose name itself contains a
            # comma would shift the id/name pairing below; confirm how artist_names is stored.
            artist_list = [x.strip() for x in (row.artist_names or "").split(",") if x.strip()]
            for artist_id, artist_name in zip(artist_ids, artist_list):
                artist_weights[artist_id] += base_weight
                artist_names[artist_id] = artist_name

        seed_track_ids = [t for t, _ in recent_track_counts.most_common(10)]
        for row in recent_likes[:20]:
            if len(seed_track_ids) >= 10:
                break
            if row.spotify_track_id not in seed_track_ids:
                seed_track_ids.append(row.spotify_track_id)

        seed_artist_ids = [a for a, _ in artist_weights.most_common(20)]
        seed_artist_names = [artist_names[a] for a in seed_artist_ids if a in artist_names]

        return {
            "seed_track_ids": seed_track_ids,
            "seed_artists": seed_artist_ids,
            "seed_artist_names": seed_artist_names,
            "recent_track_meta": recent_track_meta,
        }

    @staticmethod
    def _merge_sources(existing: str, new: str) -> str:
        """Append ``new`` to a "+"-joined source label unless it is already one of its parts."""
        if new in existing.split("+"):
            return existing
        return f"{existing}+{new}"

    async def _collect_candidates(self, *, access_token: str, seed: dict[str, Any], market: str | None) -> list[TrackCandidate]:
        """Gather candidate tracks from every available source, de-duplicated.

        Sources and base scores:

        - Spotify recommendations: 1.0
        - Artist top tracks: 0.68
        - Spotify artist search, only used when fewer than 40 candidates were found: 0.55
        - Last.fm similar tracks: 0.9, only when Last.fm is enabled
        - Last.fm similar artists: 0.78, only when Last.fm is enabled

        Spotify API errors on any single request are skipped rather than aborting the run.
        Candidates are unique by track id and by ``normalize_track_signature``; on a
        signature clash the higher-scoring copy wins.
        """
        by_id: dict[str, TrackCandidate] = {}
        sig_to_id: dict[str, str] = {}
        recent_track_meta = seed["recent_track_meta"]

        def upsert(candidate: TrackCandidate) -> None:
            sig = normalize_track_signature(candidate.name, candidate.artist_names)
            existing_id = sig_to_id.get(sig)
            if existing_id and existing_id != candidate.id and existing_id in by_id:
                if candidate.score <= by_id[existing_id].score:
                    return
                del by_id[existing_id]
            # Always point the signature at the surviving id, so a deleted id is never referenced.
            sig_to_id[sig] = candidate.id
            existing = by_id.get(candidate.id)
            if existing is None:
                by_id[candidate.id] = candidate
                return
            existing.score = max(existing.score, candidate.score)
            existing.source = self._merge_sources(existing.source, candidate.source)
            for reason in candidate.seed_reasons:
                if reason not in existing.seed_reasons:
                    existing.seed_reasons.append(reason)

        seed_tracks = list(seed["seed_track_ids"])
        seed_artists = list(seed["seed_artists"])
        seed_artist_names = list(seed.get("seed_artist_names") or [])
        top_tracks_market = market or "US"

        for batch_idx in range(4):
            # Spotify recommendations endpoint supports max 5 total seeds.
            track_start = batch_idx * 2
            artist_start = batch_idx * 3
            batch_seed_tracks = seed_tracks[track_start : track_start + 2]
            remaining_slots = max(0, 5 - len(batch_seed_tracks))
            batch_seed_artists = seed_artists[artist_start : artist_start + remaining_slots]
            if not batch_seed_tracks and not batch_seed_artists:
                continue
            try:
                rec_tracks = await self.spotify.get_recommendations(
                    access_token,
                    seed_tracks=batch_seed_tracks,
                    seed_artists=batch_seed_artists,
                    limit=100,
                    market=market,
                )
            except SpotifyApiError:
                rec_tracks = []
            for raw in rec_tracks:
                cand = self._candidate_from_spotify_track(raw, source="spotify_recommendations", base_score=1.0)
                if not cand:
                    continue
                if any(a in batch_seed_artists for a in cand.artist_ids):
                    cand.score += 0.08
                upsert(cand)

        for artist_id in seed_artists[:12]:
            try:
                top_tracks = await self.spotify.get_artist_top_tracks(
                    access_token, artist_id=artist_id, market=top_tracks_market
                )
            except SpotifyApiError:
                continue
            for raw in top_tracks:
                cand = self._candidate_from_spotify_track(raw, source="artist_top_tracks", base_score=0.68)
                if not cand:
                    continue
                if artist_id in cand.artist_ids:
                    cand.score += 0.07
                upsert(cand)

        # Fallback for apps/accounts where top-tracks and recommendations endpoints are restricted.
        if len(by_id) < 40:
            for artist_name in seed_artist_names[:12]:
                if not artist_name:
                    continue
                try:
                    search_hits = await self.spotify.search_track(
                        access_token,
                        track_name="",
                        artist_name=artist_name,
                        market=market,
                    )
                except SpotifyApiError:
                    continue
                wanted = artist_name.lower()
                for raw in search_hits[:3]:
                    cand = self._candidate_from_spotify_track(raw, source="spotify_search_artist", base_score=0.55)
                    if not cand:
                        continue
                    if wanted in {a.lower() for a in cand.artist_names}:
                        cand.score += 0.05
                    upsert(cand)

        if self.lastfm.enabled:
            for track_id in seed_tracks[:10]:
                meta = recent_track_meta.get(track_id)
                if not meta:
                    continue
                artist_name = (meta.get("artist_names") or [None])[0]
                track_name = meta.get("name")
                if not artist_name or not track_name:
                    continue
                try:
                    similars = await self.lastfm.track_similar(artist=artist_name, track=track_name, limit=10)
                except Exception:
                    # Last.fm enrichment is best-effort: any client failure only skips this seed.
                    similars = []
                for item in similars[:5]:
                    lf_name = item.get("name")
                    if not lf_name:
                        continue
                    artist_info = item.get("artist")
                    lf_artist = artist_info.get("name") if isinstance(artist_info, dict) else None
                    try:
                        search_hits = await self.spotify.search_track(
                            access_token, track_name=lf_name, artist_name=lf_artist, market=market
                        )
                    except SpotifyApiError:
                        search_hits = []
                    for raw in search_hits[:1]:
                        cand = self._candidate_from_spotify_track(raw, source="lastfm_track_similar", base_score=0.9)
                        if cand:
                            upsert(cand)

            for artist_name in seed_artist_names[:8]:
                if not artist_name:
                    continue
                try:
                    similars = await self.lastfm.artist_similar(artist=artist_name, limit=8)
                except Exception:
                    # Best-effort, as above.
                    similars = []
                for item in similars[:4]:
                    sim_artist = item.get("name")
                    if not sim_artist:
                        continue
                    try:
                        search_hits = await self.spotify.search_track(
                            access_token, track_name="", artist_name=sim_artist, market=market
                        )
                    except SpotifyApiError:
                        search_hits = []
                    for raw in search_hits[:2]:
                        cand = self._candidate_from_spotify_track(raw, source="lastfm_artist_similar", base_score=0.78)
                        if cand:
                            upsert(cand)

        return list(by_id.values())

    def _candidate_from_spotify_track(self, raw: dict[str, Any], *, source: str, base_score: float) -> TrackCandidate | None:
        """Convert a Spotify track dict into a scored ``TrackCandidate``.

        Returns ``None`` for entries that are not dicts or have no ``id``. An integer
        ``popularity`` adjusts the score by up to +0.15, peaking at popularity 55.
        """
        if not isinstance(raw, dict):
            return None
        track_id = raw.get("id")
        if not track_id:
            return None
        artists = raw.get("artists") or []
        artist_names = [a.get("name") or "Unknown" for a in artists]
        artist_ids = [a.get("id") for a in artists if a.get("id")]
        popularity = raw.get("popularity")

        score = base_score
        if isinstance(popularity, int):
            # Prefer mid-popularity a bit to avoid obvious mainstream repeats and totally obscure misses.
            score += max(-0.12, 0.15 - abs(popularity - 55) / 250)

        return TrackCandidate(
            id=track_id,
            uri=raw.get("uri") or f"spotify:track:{track_id}",
            name=raw.get("name") or "Unknown",
            artist_names=artist_names,
            artist_ids=artist_ids,
            popularity=popularity,
            source=source,
            score=score,
        )

    @staticmethod
    def _min_new_required(target_size: int, min_new_ratio: float) -> int:
        """Return how many never-seen tracks a playlist of ``target_size`` should contain.

        The ratio is clamped to [0, 1], so the result never exceeds ``target_size``.
        The product is rounded before ``ceil`` so float noise (e.g.
        ``0.14 * 100 == 14.000000000000002``) cannot bump the requirement by one.
        """
        ratio = min(1.0, max(0.0, min_new_ratio))
        return math.ceil(round(target_size * ratio, 9))

    def _rank_and_select(
        self,
        *,
        candidates: list[TrackCandidate],
        liked_ids: set[str],
        history_ids: set[str],
        target_size: int,
        min_new_ratio: float,
    ) -> PlaylistBuildResult:
        """Score, sort and pick up to ``target_size`` tracks.

        Scoring adjustments:

        - Liked tracks: -0.35. These only appear when every candidate was already
          liked (the fallback path).
        - Previously recommended tracks: -0.2.
        - Multi-artist tracks: +0.03.
        - Each seed reason: +0.01, capped at +0.15.

        Selection fills the "new" quota first (tracks neither liked nor in history). It
        then tops up with more new tracks, then with reused ones. A per-main-artist cap
        limits repeats; the cap is relaxed in later passes. Candidate scores are
        modified in place.
        """
        ratio = min(1.0, max(0.0, min_new_ratio))
        min_new_required = self._min_new_required(target_size, ratio)

        filtered = [c for c in candidates if c.id not in liked_ids]
        liked_fallback_used = False
        if not filtered and candidates:
            # If every discovered candidate is already liked, degrade gracefully instead of failing the run.
            filtered = list(candidates)
            liked_fallback_used = True

        for c in filtered:
            if c.id in liked_ids:
                c.score -= 0.35
            if c.id in history_ids:
                c.score -= 0.2
            if len(c.artist_ids) > 1:
                c.score += 0.03
            c.score += min(0.15, 0.01 * len(c.seed_reasons or []))

        filtered.sort(key=lambda c: (c.score, c.popularity or 0), reverse=True)

        def is_new(c: TrackCandidate) -> bool:
            return c.id not in history_ids and c.id not in liked_ids

        novel = [c for c in filtered if is_new(c)]
        reused = [c for c in filtered if not is_new(c)]

        selected: list[TrackCandidate] = []
        selected_ids: set[str] = set()
        artist_caps: defaultdict[str, int] = defaultdict(int)

        def try_take(pool: list[TrackCandidate], count: int, hard_artist_cap: int) -> None:
            """Append tracks from ``pool`` until ``selected`` holds ``count`` items."""
            for c in pool:
                if len(selected) >= count:
                    return
                if c.id in selected_ids:
                    continue
                # Tracks without artist ids are grouped by their first artist's lowercased name.
                main_artist = c.artist_ids[0] if c.artist_ids else f"name:{(c.artist_names or [''])[0].lower()}"
                if artist_caps[main_artist] >= hard_artist_cap:
                    continue
                artist_caps[main_artist] += 1
                selected.append(c)
                selected_ids.add(c.id)

        try_take(novel, min_new_required, hard_artist_cap=2)
        if sum(1 for c in selected if is_new(c)) < min_new_required:
            try_take(novel, min_new_required, hard_artist_cap=4)

        try_take(novel, target_size, hard_artist_cap=3)
        try_take(reused, target_size, hard_artist_cap=2)
        if len(selected) < target_size:
            try_take(reused, target_size, hard_artist_cap=4)

        # Normalise seed_reasons so every returned track carries a list.
        for c in selected:
            c.seed_reasons = c.seed_reasons or []

        new_count = sum(1 for c in selected if is_new(c))
        reused_count = len(selected) - new_count

        notes_parts: list[str] = []
        if liked_fallback_used:
            notes_parts.append("All discovered candidates were already in Liked Songs; allowed liked-track fallback.")
        if new_count < min_new_required:
            # round(), not int(): int(0.29 * 100) would truncate to 28.
            notes_parts.append(
                f"Not enough completely new tracks to satisfy {round(ratio * 100)}% target "
                f"(got {new_count}/{target_size})."
            )
        notes = " ".join(notes_parts) if notes_parts else None

        return PlaylistBuildResult(
            tracks=selected,
            target_size=target_size,
            new_count=new_count,
            reused_count=reused_count,
            min_new_required=min_new_required,
            notes=notes,
        )

    @staticmethod
    def _normalize_spotify_market(value: str | None) -> str | None:
        """Normalise a configured market to an uppercase two-letter ASCII code, or ``None``.

        ``None`` means "no market" and is returned for empty values, region shorthands
        ("EU", "GLOBAL", "WORLD", "ALL") and anything that is not two ASCII letters.
        """
        if not value:
            return None
        market = value.strip().upper()
        if not market:
            return None
        # Common shorthand users put in .env, but Spotify APIs expect a country code.
        if market in {"EU", "GLOBAL", "WORLD", "ALL"}:
            return None
        # "UK" is not an ISO 3166-1 alpha-2 code; the United Kingdom is "GB".
        if market == "UK":
            return "GB"
        if len(market) == 2 and market.isascii() and market.isalpha():
            return market
        return None
|