87 lines
3.8 KiB
Python
87 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import secrets
|
|
from datetime import timedelta
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
|
|
from app.clients.spotify import SpotifyClient
|
|
from app.config import Settings
|
|
from app.db.repositories import AuthStateRepository, UserRepository
|
|
from app.utils.time import ensure_utc, utcnow
|
|
|
|
|
|
# Spotify OAuth scopes passed to ``SpotifyClient.build_authorize_url`` when
# a chat starts the connect flow (see ``SpotifyAuthService.create_connect_url``).
SPOTIFY_SCOPES = [
    "user-library-read",
    "user-read-recently-played",
    "playlist-modify-private",
    "playlist-modify-public",
]
|
|
|
|
|
|
class SpotifyAuthService:
    """Link Telegram chats to Spotify accounts via the OAuth authorization-code flow.

    Persists a short-lived random ``state`` per connect attempt, exchanges the
    callback ``code`` for tokens, stores tokens/profile data on the user row,
    and refreshes the access token when it is missing or close to expiry.
    """

    # How long a freshly generated OAuth ``state`` remains valid.
    AUTH_STATE_TTL = timedelta(minutes=15)
    # Refresh this long before the recorded expiry so a token handed out by
    # ensure_valid_access_token does not lapse while the caller is using it.
    ACCESS_TOKEN_LEEWAY = timedelta(seconds=60)

    def __init__(
        self,
        settings: Settings,
        spotify: SpotifyClient,
        session_factory: async_sessionmaker[AsyncSession],
    ) -> None:
        """Store collaborators.

        Args:
            settings: Application settings; ``app_timezone`` is used as the
                default timezone for users that have none.
            spotify: Client used to build authorize URLs, exchange/refresh
                tokens and fetch the current Spotify profile.
            session_factory: Factory for the sessions opened by
                ``create_connect_url`` and ``handle_callback``.
        """
        self.settings = settings
        self.spotify = spotify
        self.session_factory = session_factory

    async def create_connect_url(self, chat_id: str, username: str | None = None) -> str:
        """Return a Spotify authorize URL for the user bound to ``chat_id``.

        Gets or creates the user, fills in a default timezone, purges expired
        auth states and persists a new random ``state`` that expires after
        ``AUTH_STATE_TTL``.

        Args:
            chat_id: Telegram chat id identifying the user.
            username: Optional Telegram username forwarded to
                ``UserRepository.get_or_create_by_chat``.

        Returns:
            The URL built by ``SpotifyClient.build_authorize_url`` for the
            new state and ``SPOTIFY_SCOPES``.
        """
        async with self.session_factory() as session:
            users = UserRepository(session)
            states = AuthStateRepository(session)

            user = await users.get_or_create_by_chat(chat_id=chat_id, username=username)
            user.timezone = user.timezone or self.settings.app_timezone

            # 24 random bytes, URL-safe base64 encoded.
            state = secrets.token_urlsafe(24)
            await states.delete_expired()
            await states.create(
                state=state,
                telegram_chat_id=user.telegram_chat_id,
                expires_at=utcnow() + self.AUTH_STATE_TTL,
            )
            await session.commit()

        return self.spotify.build_authorize_url(state=state, scopes=SPOTIFY_SCOPES)

    async def handle_callback(self, code: str, state: str) -> tuple[str, str]:
        """Complete the OAuth flow for a Spotify redirect.

        Consumes the stored ``state``, exchanges ``code`` for tokens, fetches
        the Spotify profile and saves tokens/profile data on the matching user.

        Args:
            code: Authorization code from the Spotify redirect.
            state: The ``state`` value echoed back by Spotify.

        Returns:
            ``(telegram_chat_id, name)``. ``name`` is the Spotify display name,
            falling back to the Spotify user id and then to ``"Spotify user"``.

        Raises:
            ValueError: If the state is invalid/expired or no user matches it.
        """
        async with self.session_factory() as session:
            users = UserRepository(session)
            states = AuthStateRepository(session)

            auth_state = await states.pop_valid(state)
            if not auth_state:
                # Keep whatever pop_valid changed before rejecting the request.
                await session.commit()
                raise ValueError("OAuth state is invalid or expired")

            user = await users.get_by_chat_id(auth_state.telegram_chat_id)
            if not user:
                # Commit like the branch above so the popped state is not
                # restored by the rollback on session close (presumably
                # pop_valid deletes/consumes it — verify in the repository).
                await session.commit()
                raise ValueError("User not found for auth state")

            token_payload = await self.spotify.exchange_code(code)
            access_token = token_payload["access_token"]
            me = await self.spotify.get_current_user(access_token)

            user.spotify_user_id = me.get("id")
            user.spotify_access_token = access_token
            # No refresh_token in the payload -> keep the one already stored.
            user.spotify_refresh_token = token_payload.get("refresh_token") or user.spotify_refresh_token
            user.spotify_token_expires_at = self.spotify.token_expiry_from_response(token_payload)
            user.spotify_scopes = token_payload.get("scope")
            user.is_active = True
            if not user.timezone:
                user.timezone = self.settings.app_timezone

            # Read before committing: with expire_on_commit=True (the default),
            # touching an expired attribute afterwards needs an implicit lazy
            # load, which AsyncSession cannot perform (MissingGreenlet).
            telegram_chat_id = user.telegram_chat_id
            await session.commit()

        display_name = me.get("display_name") or me.get("id") or "Spotify user"
        return telegram_chat_id, display_name

    async def ensure_valid_access_token(self, session: AsyncSession, user) -> str:
        """Return a usable Spotify access token for ``user``, refreshing if needed.

        The stored token is reused while its recorded expiry is more than
        ``ACCESS_TOKEN_LEEWAY`` in the future. Otherwise the stored refresh
        token is exchanged for a new access token. The updated fields are
        flushed on ``session`` but not committed; committing is left to the caller.

        Args:
            session: Session used to flush the updated user fields (presumably
                the session ``user`` is attached to — confirm at call sites).
            user: User row carrying the ``spotify_*`` token attributes.

        Returns:
            The current, possibly refreshed, access token.

        Raises:
            RuntimeError: If a refresh is required but no refresh token is stored.
        """
        if (
            user.spotify_access_token
            and user.spotify_token_expires_at
            and ensure_utc(user.spotify_token_expires_at) - self.ACCESS_TOKEN_LEEWAY > utcnow()
        ):
            return user.spotify_access_token

        if not user.spotify_refresh_token:
            raise RuntimeError("User is not connected to Spotify")

        token_payload = await self.spotify.refresh_access_token(user.spotify_refresh_token)
        user.spotify_access_token = token_payload["access_token"]
        # Only overwrite the refresh token / scopes when the response includes them.
        if token_payload.get("refresh_token"):
            user.spotify_refresh_token = token_payload["refresh_token"]
        user.spotify_token_expires_at = self.spotify.token_expiry_from_response(token_payload)
        if token_payload.get("scope"):
            user.spotify_scopes = token_payload["scope"]
        await session.flush()
        return user.spotify_access_token
|