252 lines
9.0 KiB
Python
252 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
|
|
from app.clients.spotify import SpotifyApiError
|
|
from app.services.recommendation import RecommendationEngine
|
|
from app.types import TrackCandidate
|
|
|
|
|
|
class DummyLastFm:
    """Last.fm stand-in that reports itself as disabled.

    It exposes only the ``enabled`` flag and defines no lookup methods.
    """

    enabled: bool = False
|
|
|
|
|
|
class RaisingLastFm:
    """Enabled Last.fm stand-in whose every lookup fails.

    Both ``track_similar`` and ``artist_similar`` raise ``RuntimeError`` when
    awaited, so tests can check that callers tolerate Last.fm failures.
    """

    enabled = True

    @staticmethod
    def _invalid_key_error() -> RuntimeError:
        # Single source for the error raised by both lookups.
        return RuntimeError("Last.fm key invalid")

    async def track_similar(self, *, artist: str, track: str, limit: int = 20) -> list[dict]:
        raise self._invalid_key_error()

    async def artist_similar(self, *, artist: str, limit: int = 15) -> list[dict]:
        raise self._invalid_key_error()
|
|
|
|
|
|
class StaticLastFm:
|
|
enabled = True
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
track_similar_results: dict[tuple[str, str], list[dict]] | None = None,
|
|
artist_similar_results: dict[str, list[dict]] | None = None,
|
|
) -> None:
|
|
self.track_similar_results = track_similar_results or {}
|
|
self.artist_similar_results = artist_similar_results or {}
|
|
|
|
async def track_similar(self, *, artist: str, track: str, limit: int = 20) -> list[dict]:
|
|
return list(self.track_similar_results.get((artist, track), []))
|
|
|
|
async def artist_similar(self, *, artist: str, limit: int = 15) -> list[dict]:
|
|
return list(self.artist_similar_results.get(artist, []))
|
|
|
|
|
|
class RecordingSpotifyStub:
|
|
def __init__(self) -> None:
|
|
self.recommendation_calls: list[tuple[list[str], list[str], str | None]] = []
|
|
self.top_tracks_calls: list[tuple[str, str]] = []
|
|
self.search_calls: list[tuple[str, str | None]] = []
|
|
|
|
self.raise_recommendations = False
|
|
self.raise_top_tracks = False
|
|
self.search_results_by_artist: dict[str, list[dict]] = {}
|
|
|
|
async def get_recommendations(
|
|
self,
|
|
access_token: str,
|
|
*,
|
|
seed_tracks: list[str],
|
|
seed_artists: list[str],
|
|
limit: int = 100,
|
|
market: str | None = None,
|
|
) -> list[dict]:
|
|
self.recommendation_calls.append((list(seed_tracks), list(seed_artists), market))
|
|
if self.raise_recommendations:
|
|
raise SpotifyApiError("recommendations disabled", 404, "")
|
|
return []
|
|
|
|
async def get_artist_top_tracks(self, access_token: str, artist_id: str, market: str) -> list[dict]:
|
|
self.top_tracks_calls.append((artist_id, market))
|
|
if self.raise_top_tracks:
|
|
raise SpotifyApiError("top tracks forbidden", 403, "")
|
|
return []
|
|
|
|
async def search_track(
|
|
self,
|
|
access_token: str,
|
|
*,
|
|
track_name: str,
|
|
artist_name: str | None = None,
|
|
market: str | None = None,
|
|
) -> list[dict]:
|
|
self.search_calls.append((track_name, artist_name))
|
|
if not artist_name:
|
|
return []
|
|
return list(self.search_results_by_artist.get(artist_name, []))
|
|
|
|
|
|
def make_engine(spotify_stub: RecordingSpotifyStub, lastfm=None) -> RecommendationEngine:
    """Build a ``RecommendationEngine`` wired to test doubles.

    Args:
        spotify_stub: Fake Spotify client handed to the engine.
        lastfm: Optional Last.fm client. When omitted (``None``), a disabled
            ``DummyLastFm`` is used instead.

    Returns:
        An engine whose settings are a ``SimpleNamespace`` with
        ``recent_days_window=5``, ``spotify_default_market="US"``,
        ``default_playlist_size=30`` and ``min_new_ratio=0.8``.
    """
    settings = SimpleNamespace(
        recent_days_window=5,
        spotify_default_market="US",
        default_playlist_size=30,
        min_new_ratio=0.8,
    )
    # Compare against None explicitly. A truthiness check (`lastfm or ...`)
    # would silently replace any injected client that happens to be falsy.
    if lastfm is None:
        lastfm = DummyLastFm()
    return RecommendationEngine(settings, spotify_stub, lastfm)
|
|
|
|
|
|
def fake_spotify_track(track_id: str, name: str, artist_id: str, artist_name: str, popularity: int = 50) -> dict:
    """Return a minimal Spotify track payload credited to a single artist.

    The result has the keys ``id``, ``uri`` (``spotify:track:<track_id>``),
    ``name``, ``artists`` (a one-element list of ``{"id", "name"}``) and
    ``popularity``, in that order.
    """
    sole_artist = dict(id=artist_id, name=artist_name)
    return dict(
        id=track_id,
        uri=f"spotify:track:{track_id}",
        name=name,
        artists=[sole_artist],
        popularity=popularity,
    )
|
|
|
|
|
|
class RecommendationEngineTests(unittest.IsolatedAsyncioTestCase):
    """Tests for ``RecommendationEngine`` candidate collection, market
    normalisation and ranking, driven through the in-file test doubles."""

    @staticmethod
    def _seed(
        *,
        track_ids: list[str] | None = None,
        artist_ids: list[str] | None = None,
        artist_names: list[str] | None = None,
        recent_track_meta: dict | None = None,
    ) -> dict:
        """Build the seed dict passed to ``_collect_candidates``; omitted fields are empty."""
        return {
            "seed_track_ids": list(track_ids or []),
            "seed_artists": list(artist_ids or []),
            "seed_artist_names": list(artist_names or []),
            "recent_track_meta": dict(recent_track_meta or {}),
        }

    @staticmethod
    def _spotify_with_search_only(search_results_by_artist: dict[str, list[dict]]) -> RecordingSpotifyStub:
        """Return a stub whose recommendation and top-track calls raise, leaving only
        ``search_track`` to return tracks (from *search_results_by_artist*)."""
        spotify = RecordingSpotifyStub()
        spotify.raise_recommendations = True
        spotify.raise_top_tracks = True
        spotify.search_results_by_artist = search_results_by_artist
        return spotify

    async def test_collect_candidates_limits_recommendation_seeds_to_five(self) -> None:
        """Every recommendation call carries at most five seeds (tracks + artists)."""
        spotify = RecordingSpotifyStub()
        engine = make_engine(spotify)
        seed = self._seed(
            track_ids=[f"t{i}" for i in range(10)],
            artist_ids=[f"a{i}" for i in range(20)],
        )

        candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)

        self.assertEqual(candidates, [])
        self.assertEqual(len(spotify.recommendation_calls), 4)
        # Check each call separately so a failure names the offending call
        # instead of reporting a bare "False is not true".
        for index, (seed_tracks, seed_artists, _market) in enumerate(spotify.recommendation_calls):
            with self.subTest(call=index):
                self.assertLessEqual(len(seed_tracks) + len(seed_artists), 5)

    async def test_collect_candidates_uses_search_artist_fallback_when_other_sources_fail(self) -> None:
        """When recommendations and top tracks fail, candidates come from artist search."""
        spotify = self._spotify_with_search_only(
            {
                "Artist One": [fake_spotify_track("c1", "Song 1", "ax1", "Artist One")],
                "Artist Two": [fake_spotify_track("c2", "Song 2", "ax2", "Artist Two")],
            }
        )
        engine = make_engine(spotify)
        seed = self._seed(
            track_ids=["t1", "t2"],
            artist_ids=["a1", "a2"],
            artist_names=["Artist One", "Artist Two"],
        )

        candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)

        self.assertGreaterEqual(len(spotify.search_calls), 1)
        self.assertEqual({c.id for c in candidates}, {"c1", "c2"})
        for candidate in candidates:
            with self.subTest(candidate=candidate.id):
                self.assertIn("spotify_search_artist", candidate.source)

    async def test_collect_candidates_tolerates_lastfm_errors(self) -> None:
        """Exceptions from Last.fm do not abort collection; artist search still yields results."""
        spotify = self._spotify_with_search_only(
            {"Seed Artist": [fake_spotify_track("c1", "Song 1", "ax1", "Seed Artist")]}
        )
        engine = make_engine(spotify, lastfm=RaisingLastFm())
        seed = self._seed(
            track_ids=["t1"],
            artist_ids=["a1"],
            artist_names=["Seed Artist"],
            recent_track_meta={
                "t1": {
                    "id": "t1",
                    "name": "Seed Track",
                    "artist_names": ["Seed Artist"],
                }
            },
        )

        candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)

        self.assertEqual({c.id for c in candidates}, {"c1"})
        sources = [c.source for c in candidates]
        self.assertTrue(any("spotify_search_artist" in s for s in sources), sources)

    async def test_collect_candidates_uses_lastfm_artist_similar_search(self) -> None:
        """Artists returned by Last.fm ``artist_similar`` are looked up via Spotify search."""
        spotify = self._spotify_with_search_only(
            {"Similar Artist": [fake_spotify_track("lf1", "LF Song", "lfa1", "Similar Artist")]}
        )
        lastfm = StaticLastFm(
            artist_similar_results={
                "Seed Artist": [{"name": "Similar Artist"}],
            }
        )
        engine = make_engine(spotify, lastfm=lastfm)
        seed = self._seed(artist_ids=["a1"], artist_names=["Seed Artist"])

        candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)

        self.assertIn("lf1", {c.id for c in candidates})
        sources = [c.source for c in candidates]
        self.assertTrue(any("lastfm_artist_similar" in s for s in sources), sources)

    def test_normalize_spotify_market(self) -> None:
        """Valid codes are upper-cased; "EU", "global", "USA" and "" map to None."""
        engine = make_engine(RecordingSpotifyStub())
        cases = [
            ("EU", None),
            ("global", None),
            ("de", "DE"),
            ("US", "US"),
            ("USA", None),
            ("", None),
        ]
        for raw, expected in cases:
            with self.subTest(market=raw):
                actual = engine._normalize_spotify_market(raw)
                if expected is None:
                    self.assertIsNone(actual)
                else:
                    self.assertEqual(actual, expected)

    def test_rank_and_select_uses_liked_fallback_and_counts_as_reused(self) -> None:
        """When every candidate is already liked, they still fill the playlist, are
        counted as reused, and the notes mention the liked-track fallback."""
        engine = make_engine(RecordingSpotifyStub())
        candidates = [
            TrackCandidate(
                id="c1",
                uri="spotify:track:c1",
                name="Song 1",
                artist_names=["Artist One"],
                artist_ids=["a1"],
                source="spotify_search_artist",
                score=0.7,
            ),
            TrackCandidate(
                id="c2",
                uri="spotify:track:c2",
                name="Song 2",
                artist_names=["Artist Two"],
                artist_ids=["a2"],
                source="spotify_search_artist",
                score=0.6,
            ),
        ]

        result = engine._rank_and_select(
            candidates=candidates,
            liked_ids={"c1", "c2"},
            history_ids=set(),
            target_size=2,
            min_new_ratio=0.8,
        )

        self.assertEqual(len(result.tracks), 2)
        self.assertEqual(result.new_count, 0)
        self.assertEqual(result.reused_count, 2)
        self.assertIsNotNone(result.notes)
        self.assertIn("liked-track fallback", result.notes)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main(module="__main__")
|