A kind of initial commit
This commit is contained in:
251
tests/test_recommendation_engine.py
Normal file
251
tests/test_recommendation_engine.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.clients.spotify import SpotifyApiError
|
||||
from app.services.recommendation import RecommendationEngine
|
||||
from app.types import TrackCandidate
|
||||
|
||||
|
||||
class DummyLastFm:
|
||||
enabled = False
|
||||
|
||||
|
||||
class RaisingLastFm:
|
||||
enabled = True
|
||||
|
||||
async def track_similar(self, *, artist: str, track: str, limit: int = 20) -> list[dict]:
|
||||
raise RuntimeError("Last.fm key invalid")
|
||||
|
||||
async def artist_similar(self, *, artist: str, limit: int = 15) -> list[dict]:
|
||||
raise RuntimeError("Last.fm key invalid")
|
||||
|
||||
|
||||
class StaticLastFm:
|
||||
enabled = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
track_similar_results: dict[tuple[str, str], list[dict]] | None = None,
|
||||
artist_similar_results: dict[str, list[dict]] | None = None,
|
||||
) -> None:
|
||||
self.track_similar_results = track_similar_results or {}
|
||||
self.artist_similar_results = artist_similar_results or {}
|
||||
|
||||
async def track_similar(self, *, artist: str, track: str, limit: int = 20) -> list[dict]:
|
||||
return list(self.track_similar_results.get((artist, track), []))
|
||||
|
||||
async def artist_similar(self, *, artist: str, limit: int = 15) -> list[dict]:
|
||||
return list(self.artist_similar_results.get(artist, []))
|
||||
|
||||
|
||||
class RecordingSpotifyStub:
|
||||
def __init__(self) -> None:
|
||||
self.recommendation_calls: list[tuple[list[str], list[str], str | None]] = []
|
||||
self.top_tracks_calls: list[tuple[str, str]] = []
|
||||
self.search_calls: list[tuple[str, str | None]] = []
|
||||
|
||||
self.raise_recommendations = False
|
||||
self.raise_top_tracks = False
|
||||
self.search_results_by_artist: dict[str, list[dict]] = {}
|
||||
|
||||
async def get_recommendations(
|
||||
self,
|
||||
access_token: str,
|
||||
*,
|
||||
seed_tracks: list[str],
|
||||
seed_artists: list[str],
|
||||
limit: int = 100,
|
||||
market: str | None = None,
|
||||
) -> list[dict]:
|
||||
self.recommendation_calls.append((list(seed_tracks), list(seed_artists), market))
|
||||
if self.raise_recommendations:
|
||||
raise SpotifyApiError("recommendations disabled", 404, "")
|
||||
return []
|
||||
|
||||
async def get_artist_top_tracks(self, access_token: str, artist_id: str, market: str) -> list[dict]:
|
||||
self.top_tracks_calls.append((artist_id, market))
|
||||
if self.raise_top_tracks:
|
||||
raise SpotifyApiError("top tracks forbidden", 403, "")
|
||||
return []
|
||||
|
||||
async def search_track(
|
||||
self,
|
||||
access_token: str,
|
||||
*,
|
||||
track_name: str,
|
||||
artist_name: str | None = None,
|
||||
market: str | None = None,
|
||||
) -> list[dict]:
|
||||
self.search_calls.append((track_name, artist_name))
|
||||
if not artist_name:
|
||||
return []
|
||||
return list(self.search_results_by_artist.get(artist_name, []))
|
||||
|
||||
|
||||
def make_engine(spotify_stub: RecordingSpotifyStub, lastfm=None) -> RecommendationEngine:
|
||||
settings = SimpleNamespace(
|
||||
recent_days_window=5,
|
||||
spotify_default_market="US",
|
||||
default_playlist_size=30,
|
||||
min_new_ratio=0.8,
|
||||
)
|
||||
return RecommendationEngine(settings, spotify_stub, lastfm or DummyLastFm())
|
||||
|
||||
|
||||
def fake_spotify_track(track_id: str, name: str, artist_id: str, artist_name: str, popularity: int = 50) -> dict:
|
||||
return {
|
||||
"id": track_id,
|
||||
"uri": f"spotify:track:{track_id}",
|
||||
"name": name,
|
||||
"artists": [{"id": artist_id, "name": artist_name}],
|
||||
"popularity": popularity,
|
||||
}
|
||||
|
||||
|
||||
class RecommendationEngineTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_collect_candidates_limits_recommendation_seeds_to_five(self) -> None:
|
||||
spotify = RecordingSpotifyStub()
|
||||
engine = make_engine(spotify)
|
||||
seed = {
|
||||
"seed_track_ids": [f"t{i}" for i in range(10)],
|
||||
"seed_artists": [f"a{i}" for i in range(20)],
|
||||
"seed_artist_names": [],
|
||||
"recent_track_meta": {},
|
||||
}
|
||||
|
||||
candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)
|
||||
|
||||
self.assertEqual(candidates, [])
|
||||
self.assertEqual(len(spotify.recommendation_calls), 4)
|
||||
self.assertTrue(
|
||||
all((len(seed_tracks) + len(seed_artists)) <= 5 for seed_tracks, seed_artists, _ in spotify.recommendation_calls)
|
||||
)
|
||||
|
||||
async def test_collect_candidates_uses_search_artist_fallback_when_other_sources_fail(self) -> None:
|
||||
spotify = RecordingSpotifyStub()
|
||||
spotify.raise_recommendations = True
|
||||
spotify.raise_top_tracks = True
|
||||
spotify.search_results_by_artist = {
|
||||
"Artist One": [fake_spotify_track("c1", "Song 1", "ax1", "Artist One")],
|
||||
"Artist Two": [fake_spotify_track("c2", "Song 2", "ax2", "Artist Two")],
|
||||
}
|
||||
engine = make_engine(spotify)
|
||||
seed = {
|
||||
"seed_track_ids": ["t1", "t2"],
|
||||
"seed_artists": ["a1", "a2"],
|
||||
"seed_artist_names": ["Artist One", "Artist Two"],
|
||||
"recent_track_meta": {},
|
||||
}
|
||||
|
||||
candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)
|
||||
|
||||
self.assertGreaterEqual(len(spotify.search_calls), 1)
|
||||
self.assertEqual({c.id for c in candidates}, {"c1", "c2"})
|
||||
self.assertTrue(all("spotify_search_artist" in c.source for c in candidates))
|
||||
|
||||
async def test_collect_candidates_tolerates_lastfm_errors(self) -> None:
|
||||
spotify = RecordingSpotifyStub()
|
||||
spotify.raise_recommendations = True
|
||||
spotify.raise_top_tracks = True
|
||||
spotify.search_results_by_artist = {
|
||||
"Seed Artist": [fake_spotify_track("c1", "Song 1", "ax1", "Seed Artist")],
|
||||
}
|
||||
engine = make_engine(spotify, lastfm=RaisingLastFm())
|
||||
seed = {
|
||||
"seed_track_ids": ["t1"],
|
||||
"seed_artists": ["a1"],
|
||||
"seed_artist_names": ["Seed Artist"],
|
||||
"recent_track_meta": {
|
||||
"t1": {
|
||||
"id": "t1",
|
||||
"name": "Seed Track",
|
||||
"artist_names": ["Seed Artist"],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)
|
||||
|
||||
self.assertEqual({c.id for c in candidates}, {"c1"})
|
||||
self.assertTrue(any("spotify_search_artist" in c.source for c in candidates))
|
||||
|
||||
async def test_collect_candidates_uses_lastfm_artist_similar_search(self) -> None:
|
||||
spotify = RecordingSpotifyStub()
|
||||
spotify.raise_recommendations = True
|
||||
spotify.raise_top_tracks = True
|
||||
spotify.search_results_by_artist = {
|
||||
"Similar Artist": [fake_spotify_track("lf1", "LF Song", "lfa1", "Similar Artist")],
|
||||
}
|
||||
lastfm = StaticLastFm(
|
||||
artist_similar_results={
|
||||
"Seed Artist": [{"name": "Similar Artist"}],
|
||||
}
|
||||
)
|
||||
engine = make_engine(spotify, lastfm=lastfm)
|
||||
seed = {
|
||||
"seed_track_ids": [],
|
||||
"seed_artists": ["a1"],
|
||||
"seed_artist_names": ["Seed Artist"],
|
||||
"recent_track_meta": {},
|
||||
}
|
||||
|
||||
candidates = await engine._collect_candidates(access_token="token", seed=seed, market=None)
|
||||
|
||||
self.assertIn("lf1", {c.id for c in candidates})
|
||||
self.assertTrue(any("lastfm_artist_similar" in c.source for c in candidates))
|
||||
|
||||
def test_normalize_spotify_market(self) -> None:
|
||||
spotify = RecordingSpotifyStub()
|
||||
engine = make_engine(spotify)
|
||||
|
||||
self.assertIsNone(engine._normalize_spotify_market("EU"))
|
||||
self.assertIsNone(engine._normalize_spotify_market("global"))
|
||||
self.assertEqual(engine._normalize_spotify_market("de"), "DE")
|
||||
self.assertEqual(engine._normalize_spotify_market("US"), "US")
|
||||
self.assertIsNone(engine._normalize_spotify_market("USA"))
|
||||
self.assertIsNone(engine._normalize_spotify_market(""))
|
||||
|
||||
def test_rank_and_select_uses_liked_fallback_and_counts_as_reused(self) -> None:
|
||||
spotify = RecordingSpotifyStub()
|
||||
engine = make_engine(spotify)
|
||||
candidates = [
|
||||
TrackCandidate(
|
||||
id="c1",
|
||||
uri="spotify:track:c1",
|
||||
name="Song 1",
|
||||
artist_names=["Artist One"],
|
||||
artist_ids=["a1"],
|
||||
source="spotify_search_artist",
|
||||
score=0.7,
|
||||
),
|
||||
TrackCandidate(
|
||||
id="c2",
|
||||
uri="spotify:track:c2",
|
||||
name="Song 2",
|
||||
artist_names=["Artist Two"],
|
||||
artist_ids=["a2"],
|
||||
source="spotify_search_artist",
|
||||
score=0.6,
|
||||
),
|
||||
]
|
||||
|
||||
result = engine._rank_and_select(
|
||||
candidates=candidates,
|
||||
liked_ids={"c1", "c2"},
|
||||
history_ids=set(),
|
||||
target_size=2,
|
||||
min_new_ratio=0.8,
|
||||
)
|
||||
|
||||
self.assertEqual(len(result.tracks), 2)
|
||||
self.assertEqual(result.new_count, 0)
|
||||
self.assertEqual(result.reused_count, 2)
|
||||
self.assertIsNotNone(result.notes)
|
||||
self.assertIn("liked-track fallback", result.notes or "")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user