From c33e8a474cff3eddeb6967827044b746c894f1b2 Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Mon, 9 Nov 2020 00:21:28 +0100 Subject: [PATCH] changes to data model breaking change (reuires resync) but speed is improved a lot now --- .vscode/settings.json | 2 +- music_assistant.code-workspace | 2 +- music_assistant/constants.py | 2 +- music_assistant/helpers/cache.py | 108 +- music_assistant/helpers/util.py | 80 +- music_assistant/helpers/web.py | 73 +- music_assistant/managers/database.py | 1526 ++++++++--------- music_assistant/managers/metadata.py | 2 +- music_assistant/managers/music.py | 440 +++-- music_assistant/managers/players.py | 19 +- music_assistant/models/media_types.py | 119 +- music_assistant/models/player.py | 8 +- music_assistant/providers/file/__init__.py | 4 +- music_assistant/providers/qobuz/__init__.py | 260 +-- music_assistant/providers/spotify/__init__.py | 170 +- music_assistant/providers/tunein/__init__.py | 11 +- music_assistant/web/__init__.py | 3 +- music_assistant/web/endpoints/albums.py | 23 +- music_assistant/web/endpoints/artists.py | 19 +- music_assistant/web/endpoints/config.py | 14 +- music_assistant/web/endpoints/images.py | 2 +- music_assistant/web/endpoints/library.py | 44 +- music_assistant/web/endpoints/login.py | 2 +- music_assistant/web/endpoints/players.py | 35 +- music_assistant/web/endpoints/playlists.py | 17 +- music_assistant/web/endpoints/radios.py | 13 +- music_assistant/web/endpoints/search.py | 12 +- music_assistant/web/endpoints/streams.py | 2 +- music_assistant/web/endpoints/tracks.py | 13 +- music_assistant/web/endpoints/websocket.py | 20 +- requirements_dev.txt | 1 - 31 files changed, 1493 insertions(+), 1553 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 3b0a5ec1..e1a9b44c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,7 @@ "python.linting.pylintEnabled": true, "python.linting.pylintArgs": ["--rcfile=${workspaceFolder}/setup.cfg"], "python.linting.enabled": true, - "python.pythonPath": "venv/bin/python3", + "python.pythonPath": "venv/bin/python", "python.linting.flake8Enabled": true, "python.linting.flake8Args": ["--config=${workspaceFolder}/setup.cfg"], "python.linting.mypyEnabled": false, diff --git a/music_assistant.code-workspace b/music_assistant.code-workspace index 92cb8f15..c6efe05f 100644 --- a/music_assistant.code-workspace +++ b/music_assistant.code-workspace @@ -5,6 +5,6 @@ } ], "settings": { - "python.pythonPath": "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/bin/python3.7" + "python.pythonPath": "venv/bin/python" } } \ No newline at end of file diff --git a/music_assistant/constants.py b/music_assistant/constants.py index 17a48b99..c52ca59c 100755 --- a/music_assistant/constants.py +++ b/music_assistant/constants.py @@ -1,6 +1,6 @@ """All constants for Music Assistant.""" -__version__ = "0.0.62" +__version__ = "0.0.63" REQUIRED_PYTHON_VER = "3.8" # configuration keys/attributes diff --git a/music_assistant/helpers/cache.py b/music_assistant/helpers/cache.py index fdb67086..785fc700 100644 --- a/music_assistant/helpers/cache.py +++ b/music_assistant/helpers/cache.py @@ -1,11 +1,13 @@ """Provides a simple stateless caching system.""" +import asyncio import functools import logging import os import pickle import time from functools import reduce +from typing import Awaitable import aiosqlite from music_assistant.helpers.util import run_periodic @@ -21,7 +23,8 @@ class Cache: def __init__(self, mass): """Initialize our caching class.""" self.mass = mass - self._dbfile = os.path.join(mass.config.data_path, "cache.db") + self._dbfile = os.path.join(mass.config.data_path, ".cache.db") + self._mem_cache = {} async def async_setup(self): """Async initialize of cache module.""" @@ -35,7 +38,7 @@ class Cache: await db_conn.commit() self.mass.add_job(self.async_auto_cleanup()) - async def async_get(self, cache_key, checksum=""): + async def async_get(self, cache_key, checksum="", default=None): """ Get object from cache and return the results. @@ -43,40 +46,68 @@ class Cache: checkum: optional argument to check if the checksum in the cacheobject matches the checkum provided """ - result = None cur_time = int(time.time()) checksum = self._get_checksum(checksum) - sql_query = "SELECT expires, data, checksum FROM simplecache WHERE id = ?" + + # try memory cache first + cache_data = self._mem_cache.get(cache_key) + if ( + cache_data + and (not checksum or cache_data[1] == checksum) + and cache_data[2] >= cur_time + ): + return cache_data[0] + # fall back to db cache + sql_query = "SELECT data, checksum, expires FROM simplecache WHERE id = ?" async with aiosqlite.connect(self._dbfile, timeout=180) as db_conn: - db_conn.row_factory = aiosqlite.Row async with db_conn.execute(sql_query, (cache_key,)) as cursor: cache_data = await cursor.fetchone() - if not cache_data: - LOGGER.debug("no cache data for %s", cache_key) - elif cache_data["expires"] < cur_time: - LOGGER.debug("cache expired for %s", cache_key) - elif checksum and cache_data["checksum"] != checksum: - LOGGER.debug("cache checksum mismatch for %s", cache_key) - if cache_data and cache_data["expires"] > cur_time: - if checksum is None or cache_data["checksum"] == checksum: - LOGGER.debug("return cache data for %s", cache_key) - result = pickle.loads(cache_data[1]) - return result + if ( + cache_data + and (not checksum or cache_data[1] == checksum) + and cache_data[2] >= cur_time + ): + data = await asyncio.get_running_loop().run_in_executor( + None, pickle.loads, cache_data[0] + ) + # also store in memory cache for faster access + if cache_key not in self._mem_cache: + self._mem_cache[cache_key] = ( + data, + cache_data[1], + cache_data[2], + ) + return data + LOGGER.debug("no cache data for %s", cache_key) + return default async def async_set(self, cache_key, data, checksum="", expiration=(86400 * 30)): """Set data in cache.""" checksum = self._get_checksum(checksum) expires = int(time.time() + expiration) - data = pickle.dumps(data) + self._mem_cache[cache_key] = (data, checksum, expires) + data = await asyncio.get_running_loop().run_in_executor( + None, pickle.dumps, data + ) sql_query = """INSERT OR REPLACE INTO simplecache (id, expires, data, checksum) VALUES (?, ?, ?, ?)""" async with aiosqlite.connect(self._dbfile, timeout=180) as db_conn: await db_conn.execute(sql_query, (cache_key, expires, data, checksum)) await db_conn.commit() + async def async_delete(self, cache_key): + """Delete data from cache.""" + self._mem_cache.pop(cache_key, None) + sql_query = "DELETE FROM simplecache WHERE id = ?" + async with aiosqlite.connect(self._dbfile, timeout=180) as db_conn: + await db_conn.execute(sql_query, (cache_key,)) + await db_conn.commit() + @run_periodic(3600) async def async_auto_cleanup(self): """Sceduled auto cleanup task.""" + # for now we simply rest the memory cache + self._mem_cache = {} cur_timestamp = int(time.time()) LOGGER.debug("Running cleanup...") sql_query = "SELECT id, expires FROM simplecache" @@ -90,7 +121,6 @@ class Cache: if cache_data["expires"] < cur_timestamp: sql_query = "DELETE FROM simplecache WHERE id = ?" await db_conn.execute(sql_query, (cache_id,)) - LOGGER.debug("delete from db %s", cache_id) # compact db await db_conn.commit() LOGGER.debug("Auto cleanup done") @@ -104,34 +134,20 @@ class Cache: return reduce(lambda x, y: x + y, map(ord, stringinput)) -async def async_cached_generator( - cache, cache_key, coro_func, expires=(86400 * 30), checksum=None -): - """Return helper method to store results of a async generator in the cache.""" - cache_result = await cache.async_get(cache_key, checksum) - if cache_result is not None: - for item in cache_result: - yield item - else: - # nothing in cache, yield from generator and store in cache when complete - cache_result = [] - async for item in coro_func: - yield item - cache_result.append(item) - # store results in cache - await cache.async_set(cache_key, cache_result, checksum, expires) - - async def async_cached( - cache, cache_key, coro_func, expires=(86400 * 30), checksum=None + cache, + cache_key: str, + coro_func: Awaitable, + *args, + expires: int = (86400 * 30), + checksum=None ): """Return helper method to store results of a coroutine in the cache.""" cache_result = await cache.async_get(cache_key, checksum) - # normal async function if cache_result is not None: return cache_result - result = await coro_func - await cache.async_set(cache_key, cache_result, checksum, expires) + result = await coro_func(*args) + asyncio.create_task(cache.async_set(cache_key, result, checksum, expires)) return result @@ -150,11 +166,13 @@ def async_use_cache(cache_days=14, cache_checksum=None): if cachedata is not None: return cachedata result = await func(*args, **kwargs) - await method_class.cache.async_set( - cache_str, - result, - checksum=cache_checksum, - expiration=(86400 * cache_days), + asyncio.create_task( + method_class.cache.async_set( + cache_str, + result, + checksum=cache_checksum, + expiration=(86400 * cache_days), + ) ) return result diff --git a/music_assistant/helpers/util.py b/music_assistant/helpers/util.py index 3c5e9d72..b18fa8c7 100755 --- a/music_assistant/helpers/util.py +++ b/music_assistant/helpers/util.py @@ -8,8 +8,6 @@ import socket import struct import tempfile import urllib.request -from datetime import datetime -from enum import Enum from io import BytesIO from typing import Any, Callable, TypeVar @@ -90,7 +88,7 @@ def get_sort_name(name): for item in ["The ", "De ", "de ", "Les "]: if name.startswith(item): sort_name = "".join(name.split(item)[1:]) - return sort_name + return get_compare_string(sort_name) def try_parse_int(possible_int): @@ -236,28 +234,6 @@ def get_folder_size(folderpath): return total_size_gb -def serialize_values(obj): - """Recursively create serializable values for (custom) data types.""" - - def get_val(val): - if hasattr(val, "to_dict"): - return val.to_dict() - if isinstance(val, list): - return [get_val(x) for x in val] - if isinstance(val, datetime): - return val.isoformat() - if isinstance(val, dict): - return {key: get_val(value) for key, value in val.items()} - return val - - return get_val(obj) - - -def json_serializer(obj): - """Json serializer to recursively create serializable values for custom data types.""" - return ujson.dumps(serialize_values(obj)) - - def get_compare_string(input_str): """Return clean lowered string for compare actions.""" unaccented_string = unidecode.unidecode(input_str) @@ -272,14 +248,27 @@ def compare_strings(str1, str2, strict=False): return match -def merge_dict(base_dict: dict, new_dict: dict): +def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False): """Merge dict without overwriting existing values.""" + final_dict = base_dict.copy() for key, value in new_dict.items(): - if base_dict.get(key) and isinstance(value, dict): - base_dict[key] = merge_dict(base_dict[key], value) - elif not base_dict.get(key): - base_dict[key] = value - return base_dict + if final_dict.get(key) and isinstance(value, dict): + final_dict[key] = merge_dict(final_dict[key], value) + if final_dict.get(key) and isinstance(value, list): + final_dict[key] = merge_list(final_dict[key], value) + elif not final_dict.get(key) or allow_overwite: + final_dict[key] = value + return final_dict + + +def merge_list(base_list: list, new_list: list): + """Merge 2 lists.""" + final_list = [] + final_list += base_list + for item in new_list: + if item not in final_list: + final_list.append(item) + return final_list def try_load_json_file(jsonfile): @@ -310,35 +299,6 @@ async def async_yield_chunks(_obj, chunk_size): yield _obj[i : i + chunk_size] -class CustomIntEnum(int, Enum): - """Base for IntEnum with some helpers.""" - - # when serializing we prefer the string (name) representation - # internally (database) we use the int value - - def __int__(self): - """Return integer value.""" - return super().value - - def __str__(self): - """Return string value.""" - # pylint: disable=no-member - return self._name_.lower() - - @property - def value(self): - """Return the (json friendly) string name.""" - return self.__str__() - - @classmethod - def from_string(cls, string): - """Create IntEnum from it's string equivalent.""" - for key, value in cls.__dict__.items(): - if key.lower() == string or value == try_parse_int(string): - return value - return KeyError - - def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=3600): """Generate a wave header from given params.""" file = BytesIO() diff --git a/music_assistant/helpers/web.py b/music_assistant/helpers/web.py index acdfff52..55ddcf70 100644 --- a/music_assistant/helpers/web.py +++ b/music_assistant/helpers/web.py @@ -1,39 +1,17 @@ """Various helpers for web requests.""" +import asyncio import ipaddress +from datetime import datetime from functools import wraps -from typing import AsyncGenerator +from typing import Any +import ujson from aiohttp import web from music_assistant.helpers.typing import MusicAssistantType -from music_assistant.helpers.util import json_serializer from music_assistant.models.media_types import MediaType -async def async_stream_json(request: web.Request, generator: AsyncGenerator): - """Stream items from async generator as json object.""" - resp = web.StreamResponse( - status=200, reason="OK", headers={"Content-Type": "application/json"} - ) - await resp.prepare(request) - # write json open tag - await resp.write(b'{ "items": [') - count = 0 - async for item in generator: - # write each item into the items object of the json - if count: - json_response = b"," + json_serializer(item).encode() - else: - json_response = json_serializer(item).encode() - await resp.write(json_response) - count += 1 - # write json close tag - msg = '], "count": %s }' % count - await resp.write(msg.encode()) - await resp.write_eof() - return resp - - async def async_media_items_from_body(mass: MusicAssistantType, data: dict): """Convert posted body data into media items.""" if not isinstance(data, list): @@ -43,7 +21,7 @@ async def async_media_items_from_body(mass: MusicAssistantType, data: dict): media_item = await mass.music.async_get_item( item["item_id"], item["provider"], - MediaType.from_string(item["media_type"]), + MediaType(item["media_type"]), lazy=True, ) media_items.append(media_item) @@ -71,3 +49,44 @@ def require_local_subnet(func): return await func(*args, **kwargs) return wrapped + + +def serialize_values(obj): + """Recursively create serializable values for (custom) data types.""" + + def get_val(val): + if hasattr(val, "to_dict"): + return val.to_dict() + if isinstance(val, list): + return [get_val(x) for x in val] + if isinstance(val, datetime): + return val.isoformat() + if isinstance(val, dict): + return {key: get_val(value) for key, value in val.items()} + return val + + return get_val(obj) + + +def json_serializer(obj): + """Json serializer to recursively create serializable values for custom data types.""" + return ujson.dumps(serialize_values(obj)) + + +def json_response(data: Any, status: int = 200): + """Return json in web request.""" + # return web.json_response(data, dumps=json_serializer) + return web.Response( + body=json_serializer(data), status=200, content_type="application/json" + ) + + +async def async_json_response(data: Any, status: int = 200): + """Return json in web request.""" + if isinstance(data, list): + # we could potentially receive a large list of objects to serialize + # which is blocking IO so run it in executor to be safe + return await asyncio.get_running_loop().run_in_executor( + None, json_response, data + ) + return json_response(data) diff --git a/music_assistant/managers/database.py b/music_assistant/managers/database.py index 7c5da398..0ccde1ad 100755 --- a/music_assistant/managers/database.py +++ b/music_assistant/managers/database.py @@ -2,25 +2,29 @@ # pylint: disable=too-many-lines import logging import os -import sqlite3 from functools import partial from typing import List import aiosqlite -from music_assistant.helpers.util import compare_strings, get_sort_name, try_parse_int +from music_assistant.helpers.util import ( + compare_strings, + merge_dict, + merge_list, + try_parse_int, +) +from music_assistant.helpers.web import json_serializer from music_assistant.models.media_types import ( Album, - AlbumType, + AlbumArtist, Artist, - ExternalId, - MediaItem, MediaItemProviderId, MediaType, Playlist, Radio, SearchResult, Track, - TrackQuality, + TrackAlbum, + TrackArtist, ) LOGGER = logging.getLogger("database") @@ -29,7 +33,7 @@ LOGGER = logging.getLogger("database") class DbConnect: """Helper to initialize the db connection or utilize an existing one.""" - def __init__(self, dbfile: str, db_conn: sqlite3.Connection = None): + def __init__(self, dbfile: str, db_conn: aiosqlite.Connection = None): """Initialize class.""" self._db_conn_provided = db_conn is not None self._db_conn = db_conn @@ -54,129 +58,223 @@ class DatabaseManager: def __init__(self, mass): """Initialize class.""" self.mass = mass - self._dbfile = os.path.join(mass.config.data_path, "database.db") + self._dbfile = os.path.join(mass.config.data_path, "mass.db") self.db_conn = partial(DbConnect, self._dbfile) + self.cache = {} async def async_setup(self): """Async initialization.""" async with DbConnect(self._dbfile) as db_conn: await db_conn.execute( - """CREATE TABLE IF NOT EXISTS library_items( - item_id INTEGER NOT NULL, provider TEXT NOT NULL, - media_type INTEGER NOT NULL, UNIQUE(item_id, provider, media_type) - );""" + """CREATE TABLE IF NOT EXISTS provider_mappings( + item_id INTEGER NOT NULL, + media_type TEXT NOT NULL, + prov_item_id TEXT NOT NULL, + provider TEXT NOT NULL, + quality INTEGER NOT NULL, + details TEXT NULL, + UNIQUE(item_id, media_type, prov_item_id, provider, quality) + );""" ) await db_conn.execute( """CREATE TABLE IF NOT EXISTS artists( - artist_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, - sort_name TEXT, musicbrainz_id TEXT NOT NULL UNIQUE);""" + item_id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + sort_name TEXT, + musicbrainz_id TEXT NOT NULL UNIQUE, + in_library BOOLEAN DEFAULT 0, + metadata json, + provider_ids json + );""" ) await db_conn.execute( """CREATE TABLE IF NOT EXISTS albums( - album_id INTEGER PRIMARY KEY AUTOINCREMENT, artist_id INTEGER NOT NULL, - name TEXT NOT NULL, albumtype TEXT, year INTEGER, version TEXT, - UNIQUE(artist_id, name, version, year) + item_id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + sort_name TEXT, + album_type TEXT, + year INTEGER, + version TEXT, + in_library BOOLEAN DEFAULT 0, + upc TEXT, + artist json, + metadata json, + provider_ids json, + UNIQUE(item_id, name, version, year) );""" ) - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS labels( - label_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE);""" - ) - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS album_labels( - album_id INTEGER, label_id INTEGER, UNIQUE(album_id, label_id));""" - ) - await db_conn.execute( """CREATE TABLE IF NOT EXISTS tracks( - track_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, - album_id INTEGER, version TEXT, duration INTEGER, - UNIQUE(name, version, album_id, duration) + item_id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + sort_name TEXT, + version TEXT, + duration INTEGER, + in_library BOOLEAN DEFAULT 0, + isrc TEXT, + album json, + artists json, + metadata json, + provider_ids json, + UNIQUE(name, version, item_id, duration) );""" ) - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS track_artists( - track_id INTEGER, artist_id INTEGER, UNIQUE(track_id, artist_id));""" - ) - - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS tags( - tag_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE);""" - ) - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS media_tags( - item_id INTEGER, media_type INTEGER, tag_id, - UNIQUE(item_id, media_type, tag_id) - );""" - ) - - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS provider_mappings( - item_id INTEGER NOT NULL, media_type INTEGER NOT NULL, - prov_item_id TEXT NOT NULL, - provider TEXT NOT NULL, quality INTEGER NOT NULL, details TEXT NULL, - UNIQUE(item_id, media_type, prov_item_id, provider, quality) - );""" - ) - - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS metadata( - item_id INTEGER NOT NULL, media_type INTEGER NOT NULL, key TEXT NOT NULL, - value TEXT, UNIQUE(item_id, media_type, key));""" - ) - - await db_conn.execute( - """CREATE TABLE IF NOT EXISTS external_ids( - item_id INTEGER NOT NULL, media_type INTEGER NOT NULL, key TEXT NOT NULL, - value TEXT, UNIQUE(item_id, media_type, key, value));""" - ) await db_conn.execute( """CREATE TABLE IF NOT EXISTS playlists( - playlist_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, - owner TEXT NOT NULL, is_editable BOOLEAN NOT NULL, checksum TEXT NOT NULL, + item_id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + sort_name TEXT, + owner TEXT NOT NULL, + is_editable BOOLEAN NOT NULL, + checksum TEXT NOT NULL, + in_library BOOLEAN DEFAULT 0, + metadata json, + provider_ids json, UNIQUE(name, owner) );""" ) await db_conn.execute( """CREATE TABLE IF NOT EXISTS radios( - radio_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL UNIQUE);""" + item_id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + sort_name TEXT, + in_library BOOLEAN DEFAULT 0, + metadata json, + provider_ids json + );""" ) await db_conn.execute( """CREATE TABLE IF NOT EXISTS track_loudness( - provider_track_id INTEGER NOT NULL, provider TEXT NOT NULL, loudness REAL, - UNIQUE(provider_track_id, provider));""" + provider_item_id INTEGER NOT NULL, + provider TEXT NOT NULL, + loudness REAL, + UNIQUE(provider_item_id, provider));""" ) await db_conn.commit() await db_conn.execute("VACUUM;") await db_conn.commit() - async def async_get_database_id( + async def async_get_item_by_prov_id( self, provider_id: str, prov_item_id: str, media_type: MediaType, - db_conn: sqlite3.Connection = None, + db_conn: aiosqlite.Connection = None, ) -> int: - """Get the database id for the given prov_id.""" - async with DbConnect(self._dbfile, db_conn) as db_conn: - if provider_id == "database": - return prov_item_id - sql_query = """SELECT item_id FROM provider_mappings - WHERE prov_item_id = ? AND provider = ? AND media_type = ?;""" - async with db_conn.execute( - sql_query, (prov_item_id, provider_id, int(media_type)) - ) as cursor: - item_id = await cursor.fetchone() - if item_id: - return item_id[0] + """Get the database item for the given prov_id.""" + if media_type == MediaType.Artist: + return await self.async_get_artist_by_prov_id( + provider_id, prov_item_id, db_conn + ) + if media_type == MediaType.Album: + return await self.async_get_album_by_prov_id( + provider_id, prov_item_id, db_conn + ) + if media_type == MediaType.Track: + return await self.async_get_track_by_prov_id( + provider_id, prov_item_id, db_conn + ) + if media_type == MediaType.Playlist: + return await self.async_get_playlist_by_prov_id( + provider_id, prov_item_id, db_conn + ) + if media_type == MediaType.Radio: + return await self.async_get_radio_by_prov_id( + provider_id, prov_item_id, db_conn + ) + return None + + async def async_get_track_by_prov_id( + self, + provider_id: str, + prov_item_id: str, + db_conn: aiosqlite.Connection = None, + ) -> int: + """Get the database track for the given prov_id.""" + if provider_id == "database": + return await self.async_get_track(prov_item_id, db_conn=db_conn) + sql_query = f"""WHERE item_id in + (SELECT item_id FROM provider_mappings + WHERE prov_item_id = '{prov_item_id}' + AND provider = '{provider_id}' AND media_type = 'track')""" + for item in await self.async_get_tracks(sql_query, db_conn=db_conn): + return item + return None + + async def async_get_album_by_prov_id( + self, + provider_id: str, + prov_item_id: str, + db_conn: aiosqlite.Connection = None, + ) -> int: + """Get the database album for the given prov_id.""" + if provider_id == "database": + return await self.async_get_album(prov_item_id, db_conn=db_conn) + sql_query = f"""WHERE item_id in + (SELECT item_id FROM provider_mappings + WHERE prov_item_id = '{prov_item_id}' + AND provider = '{provider_id}' AND media_type = 'album')""" + for item in await self.async_get_albums(sql_query, db_conn=db_conn): + return item + return None + + async def async_get_artist_by_prov_id( + self, + provider_id: str, + prov_item_id: str, + db_conn: aiosqlite.Connection = None, + ) -> int: + """Get the database artist for the given prov_id.""" + if provider_id == "database": + return await self.async_get_artist(prov_item_id, db_conn=db_conn) + sql_query = f"""WHERE item_id in + (SELECT item_id FROM provider_mappings + WHERE prov_item_id = '{prov_item_id}' + AND provider = '{provider_id}' AND media_type = 'artist')""" + for item in await self.async_get_artists(sql_query, db_conn=db_conn): + return item + return None + + async def async_get_playlist_by_prov_id( + self, + provider_id: str, + prov_item_id: str, + db_conn: aiosqlite.Connection = None, + ) -> int: + """Get the database playlist for the given prov_id.""" + if provider_id == "database": + return await self.async_get_playlist(prov_item_id, db_conn=db_conn) + sql_query = f"""WHERE item_id in + (SELECT item_id FROM provider_mappings + WHERE prov_item_id = '{prov_item_id}' + AND provider = '{provider_id}' AND media_type = 'playlist')""" + for item in await self.async_get_playlists(sql_query, db_conn=db_conn): + return item + return None + + async def async_get_radio_by_prov_id( + self, + provider_id: str, + prov_item_id: str, + db_conn: aiosqlite.Connection = None, + ) -> int: + """Get the database radio for the given prov_id.""" + if provider_id == "database": + return await self.async_get_radio(prov_item_id, db_conn=db_conn) + sql_query = f"""WHERE item_id in + (SELECT item_id FROM provider_mappings + WHERE prov_item_id = '{prov_item_id}' + AND provider = '{provider_id}' AND media_type = 'radio')""" + for item in await self.async_get_radios(sql_query, db_conn=db_conn): + return item return None async def async_search( @@ -188,116 +286,59 @@ class DatabaseManager: searchquery = "%" + searchquery + "%" if media_types is None or MediaType.Artist in media_types: sql_query = ' WHERE name LIKE "%s"' % searchquery - result.artists = [ - item - async for item in self.async_get_artists(sql_query, db_conn=db_conn) - ] + result.artists = await self.async_get_artists( + sql_query, db_conn=db_conn + ) if media_types is None or MediaType.Album in media_types: sql_query = ' WHERE name LIKE "%s"' % searchquery - result.albums = [ - item - async for item in self.async_get_albums(sql_query, db_conn=db_conn) - ] + result.albums = await self.async_get_albums(sql_query, db_conn=db_conn) if media_types is None or MediaType.Track in media_types: sql_query = ' WHERE name LIKE "%s"' % searchquery - result.tracks = [ - item - async for item in self.async_get_tracks(sql_query, db_conn=db_conn) - ] + result.tracks = await self.async_get_tracks(sql_query, db_conn=db_conn) if media_types is None or MediaType.Playlist in media_types: sql_query = ' WHERE name LIKE "%s"' % searchquery - result.playlists = [ - item - async for item in self.async_get_playlists( - sql_query, db_conn=db_conn - ) - ] + result.playlists = await self.async_get_playlists( + sql_query, db_conn=db_conn + ) if media_types is None or MediaType.Radio in media_types: sql_query = ' WHERE name LIKE "%s"' % searchquery - result.radios = [ - item - async for item in self.async_get_radios(sql_query, db_conn=db_conn) - ] + result.radios = await self.async_get_radios(sql_query, db_conn=db_conn) return result - async def async_get_library_artists( - self, provider_id: str = None, orderby: str = "name" - ) -> List[Artist]: - """Get all library artists, optionally filtered by provider.""" - if provider_id is not None: - sql_query = f"""WHERE artist_id in (SELECT item_id FROM library_items WHERE - provider = "{provider_id}" AND media_type = {int(MediaType.Artist)})""" - else: - sql_query = f"""WHERE artist_id in - (SELECT item_id FROM library_items - WHERE media_type = {int(MediaType.Artist)})""" - async for item in self.async_get_artists( - sql_query, orderby=orderby, fulldata=True - ): - yield item + async def async_get_library_artists(self, orderby: str = "name") -> List[Artist]: + """Get all library artists.""" + sql_query = "WHERE in_library = 1" + return await self.async_get_artists(sql_query, orderby=orderby) - async def async_get_library_albums( - self, provider_id: str = None, orderby: str = "name" - ) -> List[Album]: - """Get all library albums, optionally filtered by provider.""" - if provider_id is not None: - sql_query = f"""WHERE album_id in (SELECT item_id FROM library_items - WHERE provider = "{provider_id}" AND media_type = {int(MediaType.Album)})""" - else: - sql_query = f"""WHERE album_id in - (SELECT item_id FROM library_items WHERE media_type = {int(MediaType.Album)})""" - async for item in self.async_get_albums( - sql_query, orderby=orderby, fulldata=True - ): - yield item + async def async_get_library_albums(self, orderby: str = "name") -> List[Album]: + """Get all library albums.""" + sql_query = "WHERE in_library = 1" + return await self.async_get_albums(sql_query, orderby=orderby) - async def async_get_library_tracks( - self, provider_id: str = None, orderby: str = "name" - ) -> List[Track]: - """Get all library tracks, optionally filtered by provider.""" - if provider_id is not None: - sql_query = f"""WHERE track_id in - (SELECT item_id FROM library_items WHERE provider = "{provider_id}" - AND media_type = {int(MediaType.Track)})""" - else: - sql_query = f"""WHERE track_id in - (SELECT item_id FROM library_items WHERE media_type = {int(MediaType.Track)})""" - async for item in self.async_get_tracks(sql_query, orderby=orderby): - yield item + async def async_get_library_tracks(self, orderby: str = "name") -> List[Track]: + """Get all library tracks.""" + sql_query = "WHERE in_library = 1" + return await self.async_get_tracks(sql_query, orderby=orderby) async def async_get_library_playlists( - self, provider_id: str = None, orderby: str = "name" + self, orderby: str = "name" ) -> List[Playlist]: """Fetch all playlist records from table.""" - if provider_id is not None: - sql_query = f"""WHERE playlist_id in - (SELECT item_id FROM library_items WHERE provider = "{provider_id}" - AND media_type = {int(MediaType.Playlist)})""" - else: - sql_query = f"""WHERE playlist_id in - (SELECT item_id FROM library_items WHERE media_type = {int(MediaType.Playlist)})""" - async for item in self.async_get_playlists(sql_query, orderby=orderby): - yield item + sql_query = "WHERE in_library = 1" + return await self.async_get_playlists(sql_query, orderby=orderby) async def async_get_library_radios( self, provider_id: str = None, orderby: str = "name" ) -> List[Radio]: """Fetch all radio records from table.""" - if provider_id is not None: - sql_query = f"""WHERE radio_id in - (SELECT item_id FROM library_items WHERE provider = "{provider_id}" - AND media_type = { int(MediaType.Radio)})""" - else: - sql_query = f"""WHERE radio_id in - (SELECT item_id FROM library_items WHERE media_type = {int(MediaType.Radio)})""" - async for item in self.async_get_radios(sql_query, orderby=orderby): - yield item + sql_query = "WHERE in_library = 1" + return await self.async_get_radios(sql_query, orderby=orderby) async def async_get_playlists( self, filter_query: str = None, orderby: str = "name", - db_conn: sqlite3.Connection = None, + db_conn: aiosqlite.Connection = None, ) -> List[Playlist]: """Get all playlists from database.""" async with DbConnect(self._dbfile, db_conn) as db_conn: @@ -306,41 +347,18 @@ class DatabaseManager: if filter_query: sql_query += " " + filter_query sql_query += " ORDER BY %s" % orderby - async with db_conn.execute(sql_query) as cursor: - db_rows = await cursor.fetchall() - for db_row in db_rows: - playlist = Playlist( - item_id=db_row["playlist_id"], - provider="database", - name=db_row["name"], - metadata=await self.__async_get_metadata( - db_row["playlist_id"], MediaType.Playlist, db_conn - ), - tags=await self.__async_get_tags( - db_row["playlist_id"], int(MediaType.Playlist), db_conn - ), - external_ids=await self.__async_get_external_ids( - db_row["playlist_id"], MediaType.Playlist, db_conn - ), - provider_ids=await self.__async_get_prov_ids( - db_row["playlist_id"], MediaType.Playlist, db_conn - ), - in_library=await self.__async_get_library_providers( - db_row["playlist_id"], MediaType.Playlist, db_conn - ), - is_lazy=False, - available=True, - owner=db_row["owner"], - checksum=db_row["checksum"], - is_editable=db_row["is_editable"], - ) - yield playlist - - async def async_get_playlist(self, playlist_id: int) -> Playlist: + return [ + Playlist.from_db_row(db_row) + for db_row in await db_conn.execute_fetchall(sql_query, ()) + ] + + async def async_get_playlist( + self, item_id: int, db_conn: aiosqlite.Connection = None + ) -> Playlist: """Get playlist record by id.""" - playlist_id = try_parse_int(playlist_id) - async for item in self.async_get_playlists( - f"WHERE playlist_id = {playlist_id}" + item_id = try_parse_int(item_id) + for item in await self.async_get_playlists( + f"WHERE item_id = {item_id}", db_conn=db_conn ): return item return None @@ -349,7 +367,7 @@ class DatabaseManager: self, filter_query: str = None, orderby: str = "name", - db_conn: sqlite3.Connection = None, + db_conn: aiosqlite.Connection = None, ) -> List[Radio]: """Fetch radio records from database.""" sql_query = "SELECT * FROM radios" @@ -358,37 +376,19 @@ class DatabaseManager: sql_query += " ORDER BY %s" % orderby async with DbConnect(self._dbfile, db_conn) as db_conn: db_conn.row_factory = aiosqlite.Row - async with db_conn.execute(sql_query) as cursor: - db_rows = await cursor.fetchall() - for db_row in db_rows: - radio = Radio( - item_id=db_row["radio_id"], - provider="database", - name=db_row["name"], - metadata=await self.__async_get_metadata( - db_row["radio_id"], MediaType.Radio, db_conn - ), - tags=await self.__async_get_tags( - db_row["radio_id"], MediaType.Radio, db_conn - ), - external_ids=await self.__async_get_external_ids( - db_row["radio_id"], MediaType.Radio, db_conn - ), - provider_ids=await self.__async_get_prov_ids( - db_row["radio_id"], MediaType.Radio, db_conn - ), - in_library=await self.__async_get_library_providers( - db_row["radio_id"], MediaType.Radio, db_conn - ), - is_lazy=False, - available=True, - ) - yield radio - - async def async_get_radio(self, radio_id: int) -> Playlist: + return [ + Radio.from_db_row(db_row) + for db_row in await db_conn.execute_fetchall(sql_query, ()) + ] + + async def async_get_radio( + self, item_id: int, db_conn: aiosqlite.Connection = None + ) -> Playlist: """Get radio record by id.""" - radio_id = try_parse_int(radio_id) - async for item in self.async_get_radios(f"WHERE radio_id = {radio_id}"): + item_id = try_parse_int(item_id) + for item in await self.async_get_radios( + f"WHERE item_id = {item_id}", db_conn=db_conn + ): return item return None @@ -396,85 +396,160 @@ class DatabaseManager: """Add a new playlist record to the database.""" assert playlist.name async with DbConnect(self._dbfile) as db_conn: - async with db_conn.execute( - "SELECT (playlist_id) FROM playlists WHERE name=? AND owner=?;", + db_conn.row_factory = aiosqlite.Row + cur_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM playlists WHERE name=? AND owner=?;", (playlist.name, playlist.owner), - ) as cursor: - result = await cursor.fetchone() - if result: - playlist_id = result[0] + ) + + if cur_item: # update existing - sql_query = "UPDATE playlists SET is_editable=?, checksum=? WHERE playlist_id=?;" - await db_conn.execute( - sql_query, (playlist.is_editable, playlist.checksum, playlist_id) - ) - else: - # insert playlist - sql_query = """INSERT INTO playlists (name, owner, is_editable, checksum) - VALUES(?,?,?,?);""" - async with db_conn.execute( - sql_query, - ( - playlist.name, - playlist.owner, - playlist.is_editable, - playlist.checksum, - ), - ) as cursor: - last_row_id = cursor.lastrowid - # get id from newly created item - sql_query = "SELECT (playlist_id) FROM playlists WHERE ROWID=?" - async with db_conn.execute(sql_query, (last_row_id,)) as cursor: - playlist_id = await cursor.fetchone() - playlist_id = playlist_id[0] - LOGGER.debug( - "added playlist %s to database: %s", playlist.name, playlist_id + return await self.async_update_playlist(cur_item[0], playlist) + # insert playlist + sql_query = """INSERT INTO playlists + (name, sort_name, owner, is_editable, checksum, metadata, provider_ids) + VALUES(?,?,?,?,?,?,?);""" + async with db_conn.execute( + sql_query, + ( + playlist.name, + playlist.sort_name, + playlist.owner, + playlist.is_editable, + playlist.checksum, + json_serializer(playlist.metadata), + json_serializer(playlist.provider_ids), + ), + ) as cursor: + last_row_id = cursor.lastrowid + new_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM playlists WHERE ROWID=?;", + (last_row_id,), ) - # add/update metadata await self.__async_add_prov_ids( - playlist_id, MediaType.Playlist, playlist.provider_ids, db_conn + new_item[0], MediaType.Playlist, playlist.provider_ids, db_conn ) - await self.__async_add_metadata( - playlist_id, MediaType.Playlist, playlist.metadata, db_conn + await db_conn.commit() + LOGGER.debug("added playlist %s to database", playlist.name) + # return created object + return await self.async_get_playlist(new_item[0]) + + async def async_update_playlist(self, item_id: int, playlist: Playlist): + """Update a playlist record in the database.""" + async with DbConnect(self._dbfile) as db_conn: + db_conn.row_factory = aiosqlite.Row + cur_item = Playlist.from_db_row( + await self.__execute_fetchone( + db_conn, "SELECT * FROM playlists WHERE item_id=?;", (item_id,) + ) ) - # save + metadata = merge_dict(cur_item.metadata, playlist.metadata) + provider_ids = merge_list(cur_item.provider_ids, playlist.provider_ids) + sql_query = """UPDATE playlists + SET name=?, + sort_name=?, + owner=?, + is_editable=?, + checksum=?, + metadata=?, + provider_ids=? + WHERE item_id=?;""" + await db_conn.execute( + sql_query, + ( + playlist.name, + playlist.sort_name, + playlist.owner, + playlist.is_editable, + playlist.checksum, + json_serializer(metadata), + json_serializer(provider_ids), + item_id, + ), + ) + await self.__async_add_prov_ids( + item_id, MediaType.Playlist, playlist.provider_ids, db_conn + ) + LOGGER.debug("updated playlist %s in database: %s", playlist.name, item_id) await db_conn.commit() - return playlist_id + # return updated object + return await self.async_get_playlist(item_id) async def async_add_radio(self, radio: Radio): """Add a new radio record to the database.""" assert radio.name async with DbConnect(self._dbfile) as db_conn: + db_conn.row_factory = aiosqlite.Row + cur_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM radios WHERE name=?;", + (radio.name,), + ) + if cur_item: + # update existing + return await self.async_update_radio(cur_item[0], radio) + # insert radio + sql_query = """INSERT INTO radios (name, sort_name, metadata, provider_ids) + VALUES(?,?,?);""" async with db_conn.execute( - "SELECT (radio_id) FROM radios WHERE name=?;", (radio.name,) + sql_query, + ( + radio.name, + radio.sort_name, + json_serializer(radio.metadata), + json_serializer(radio.provider_ids), + ), ) as cursor: - result = await cursor.fetchone() - if result: - radio_id = result[0] - else: - # insert radio - sql_query = "INSERT INTO radios (name) VALUES(?);" - async with db_conn.execute(sql_query, (radio.name,)) as cursor: - last_row_id = cursor.lastrowid - # await db_conn.commit() - # get id from newly created item - sql_query = "SELECT (radio_id) FROM radios WHERE ROWID=?" - async with db_conn.execute(sql_query, (last_row_id,)) as cursor: - radio_id = await cursor.fetchone() - radio_id = radio_id[0] - LOGGER.debug( - "added radio station %s to database: %s", radio.name, radio_id + last_row_id = cursor.lastrowid + new_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM radios WHERE ROWID=?;", + (last_row_id,), ) - # add/update metadata await self.__async_add_prov_ids( - radio_id, MediaType.Radio, radio.provider_ids, db_conn + new_item[0], MediaType.Radio, radio.provider_ids, db_conn ) - await self.__async_add_metadata( - radio_id, MediaType.Radio, radio.metadata, db_conn + await db_conn.commit() + LOGGER.debug("added radio %s to database", radio.name) + # return created object + return await self.async_get_radio(new_item[0]) + + async def async_update_radio(self, item_id: int, radio: Radio): + """Update a radio record in the database.""" + async with DbConnect(self._dbfile) as db_conn: + db_conn.row_factory = aiosqlite.Row + cur_item = Radio.from_db_row( + await self.__execute_fetchone( + db_conn, "SELECT * FROM radios WHERE item_id=?;", (item_id,) + ) ) - # save + metadata = merge_dict(cur_item.metadata, radio.metadata) + provider_ids = merge_list(cur_item.provider_ids, radio.provider_ids) + sql_query = """UPDATE radios + SET name=?, + sort_name=?, + metadata=?, + provider_ids=? + WHERE item_id=?;""" + await db_conn.execute( + sql_query, + ( + radio.name, + radio.sort_name, + json_serializer(metadata), + json_serializer(provider_ids), + item_id, + ), + ) + await self.__async_add_prov_ids( + item_id, MediaType.Radio, radio.provider_ids, db_conn + ) + LOGGER.debug("updated radio %s in database: %s", radio.name, item_id) await db_conn.commit() - return radio_id + # return updated object + return await self.async_get_radio(item_id) async def async_add_to_library( self, item_id: int, media_type: MediaType, provider: str @@ -482,9 +557,9 @@ class DatabaseManager: """Add an item to the library (item must already be present in the db!).""" async with DbConnect(self._dbfile) as db_conn: item_id = try_parse_int(item_id) - sql_query = """INSERT or REPLACE INTO library_items - (item_id, provider, media_type) VALUES(?,?,?);""" - await db_conn.execute(sql_query, (item_id, provider, int(media_type))) + db_name = media_type.value + "s" + sql_query = f"UPDATE {db_name} SET in_library=1 WHERE item_id=?;" + await db_conn.execute(sql_query, (item_id,)) await db_conn.commit() async def async_remove_from_library( @@ -493,120 +568,122 @@ class DatabaseManager: """Remove item from the library.""" async with DbConnect(self._dbfile) as db_conn: item_id = try_parse_int(item_id) - sql_query = "DELETE FROM library_items WHERE item_id=? AND provider=? AND media_type=?;" - await db_conn.execute(sql_query, (item_id, provider, int(media_type))) - if media_type == MediaType.Playlist: - sql_query = "DELETE FROM playlists WHERE playlist_id=?;" - await db_conn.execute(sql_query, (item_id,)) - sql_query = """DELETE FROM provider_mappings WHERE - item_id=? AND media_type=? AND provider=?;""" - await db_conn.execute(sql_query, (item_id, int(media_type), provider)) - await db_conn.commit() + db_name = media_type.value + "s" + sql_query = f"UPDATE {db_name} SET in_library=0 WHERE item_id=?;" + await db_conn.execute(sql_query, (item_id,)) + await db_conn.commit() async def async_get_artists( self, filter_query: str = None, orderby: str = "name", - fulldata=False, - db_conn: sqlite3.Connection = None, + db_conn: aiosqlite.Connection = None, ) -> List[Artist]: """Fetch artist records from database.""" + sql_query = "SELECT * FROM artists" + if filter_query: + sql_query += " " + filter_query + sql_query += " ORDER BY %s" % orderby async with DbConnect(self._dbfile, db_conn) as db_conn: db_conn.row_factory = aiosqlite.Row - sql_query = "SELECT * FROM artists" - if filter_query: - sql_query += " " + filter_query - sql_query += " ORDER BY %s" % orderby - for db_row in await db_conn.execute_fetchall(sql_query): - artist = Artist( - item_id=db_row["artist_id"], - provider="database", - name=db_row["name"], - sort_name=db_row["sort_name"], - ) - if fulldata: - artist.provider_ids = await self.__async_get_prov_ids( - db_row["artist_id"], MediaType.Artist, db_conn - ) - artist.in_library = await self.__async_get_library_providers( - db_row["artist_id"], MediaType.Artist, db_conn - ) - artist.external_ids = await self.__async_get_external_ids( - artist.item_id, MediaType.Artist, db_conn - ) - artist.metadata = await self.__async_get_metadata( - artist.item_id, MediaType.Artist, db_conn - ) - artist.tags = await self.__async_get_tags( - artist.item_id, MediaType.Artist, db_conn - ) - yield artist + return [ + Artist.from_db_row(db_row) + for db_row in await db_conn.execute_fetchall(sql_query, ()) + ] async def async_get_artist( - self, artist_id: int, fulldata=True, db_conn: sqlite3.Connection = None + self, item_id: int, db_conn: aiosqlite.Connection = None ) -> Artist: """Get artist record by id.""" - artist_id = try_parse_int(artist_id) - async for item in self.async_get_artists( - "WHERE artist_id = %d" % artist_id, fulldata=fulldata, db_conn=db_conn + item_id = try_parse_int(item_id) + for item in await self.async_get_artists( + "WHERE item_id = %d" % item_id, db_conn=db_conn ): return item return None - async def async_add_artist(self, artist: Artist) -> int: + async def async_add_artist(self, artist: Artist): """Add a new artist record to the database.""" - artist_id = None + assert artist.musicbrainz_id async with DbConnect(self._dbfile) as db_conn: - # always prefer to grab existing artist with external_id (=musicbrainz_id) - artist_id = await self.__async_get_item_by_external_id(artist, db_conn) - if not artist_id: - # insert artist - musicbrainz_id = artist.external_ids.get(ExternalId.MUSICBRAINZ) - assert musicbrainz_id # musicbrainz id is required - if not artist.sort_name: - artist.sort_name = get_sort_name(artist.name) - sql_query = "INSERT INTO artists (name, sort_name, musicbrainz_id) VALUES(?,?,?);" - async with db_conn.execute( - sql_query, (artist.name, artist.sort_name, musicbrainz_id) - ) as cursor: - last_row_id = cursor.lastrowid - await db_conn.commit() - # get id from (newly created) item - async with db_conn.execute( - "SELECT artist_id FROM artists WHERE ROWID=?;", (last_row_id,) - ) as cursor: - artist_id = await cursor.fetchone() - artist_id = artist_id[0] - # always add metadata and tags etc. because we might have received - # additional info or a match from other provider + db_conn.row_factory = aiosqlite.Row + cur_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM artists WHERE musicbrainz_id=?;", + (artist.musicbrainz_id,), + ) + if cur_item: + # update existing + return await self.async_update_artist(cur_item[0], artist) + # insert artist + sql_query = """INSERT INTO artists + (name, sort_name, musicbrainz_id, metadata, provider_ids) + VALUES(?,?,?,?,?);""" + async with db_conn.execute( + sql_query, + ( + artist.name, + artist.sort_name, + artist.musicbrainz_id, + json_serializer(artist.metadata), + json_serializer(artist.provider_ids), + ), + ) as cursor: + last_row_id = cursor.lastrowid + new_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM artists WHERE ROWID=?;", + (last_row_id,), + ) await self.__async_add_prov_ids( - artist_id, MediaType.Artist, artist.provider_ids, db_conn + new_item[0], MediaType.Artist, artist.provider_ids, db_conn ) - await self.__async_add_metadata( - artist_id, MediaType.Artist, artist.metadata, db_conn + await db_conn.commit() + LOGGER.debug("added artist %s to database", artist.name) + # return created object + return await self.async_get_artist(new_item[0]) + + async def async_update_artist(self, item_id: int, artist: Artist): + """Update a artist record in the database.""" + async with DbConnect(self._dbfile) as db_conn: + db_conn.row_factory = aiosqlite.Row + db_row = await self.__execute_fetchone( + db_conn, "SELECT * FROM artists WHERE item_id=?;", (item_id,) ) - await self.__async_add_tags( - artist_id, MediaType.Artist, artist.tags, db_conn + cur_item = Artist.from_db_row(db_row) + metadata = merge_dict(cur_item.metadata, artist.metadata) + provider_ids = merge_list(cur_item.provider_ids, artist.provider_ids) + sql_query = """UPDATE artists + SET name=?, + sort_name=?, + musicbrainz_id=?, + metadata=?, + provider_ids=? + WHERE item_id=?;""" + await db_conn.execute( + sql_query, + ( + artist.name, + artist.sort_name, + artist.musicbrainz_id, + json_serializer(metadata), + json_serializer(provider_ids), + item_id, + ), ) - await self.__async_add_external_ids( - artist_id, MediaType.Artist, artist.external_ids, db_conn + await self.__async_add_prov_ids( + item_id, MediaType.Artist, artist.provider_ids, db_conn ) - # save + LOGGER.debug("updated artist %s in database: %s", artist.name, item_id) await db_conn.commit() - LOGGER.debug( - "added artist %s (%s) to database: %s", - artist.name, - artist.provider_ids, - artist_id, - ) - return artist_id + # return updated object + return await self.async_get_artist(item_id) async def async_get_albums( self, filter_query: str = None, orderby: str = "name", - fulldata=False, - db_conn: sqlite3.Connection = None, + db_conn: aiosqlite.Connection = None, ) -> List[Album]: """Fetch all album records from the database.""" sql_query = "SELECT * FROM albums" @@ -615,464 +692,371 @@ class DatabaseManager: sql_query += " ORDER BY %s" % orderby async with DbConnect(self._dbfile, db_conn) as db_conn: db_conn.row_factory = aiosqlite.Row - for db_row in await db_conn.execute_fetchall(sql_query): - album = Album( - item_id=db_row["album_id"], - provider="database", - name=db_row["name"], - album_type=AlbumType(int(db_row["albumtype"])), - year=db_row["year"], - version=db_row["version"], - artist=await self.async_get_artist( - db_row["artist_id"], fulldata=fulldata, db_conn=db_conn - ), - ) - if fulldata: - album.provider_ids = await self.__async_get_prov_ids( - db_row["album_id"], MediaType.Album, db_conn - ) - album.in_library = await self.__async_get_library_providers( - db_row["album_id"], MediaType.Album, db_conn - ) - album.external_ids = await self.__async_get_external_ids( - album.item_id, MediaType.Album, db_conn - ) - album.metadata = await self.__async_get_metadata( - album.item_id, MediaType.Album, db_conn - ) - album.tags = await self.__async_get_tags( - album.item_id, MediaType.Album, db_conn - ) - album.labels = await self.__async_get_album_labels( - album.item_id, db_conn - ) - yield album + return [ + Album.from_db_row(db_row) + for db_row in await db_conn.execute_fetchall(sql_query, ()) + ] async def async_get_album( - self, album_id: int, fulldata=True, db_conn: sqlite3.Connection = None + self, item_id: int, db_conn: aiosqlite.Connection = None ) -> Album: """Get album record by id.""" - album_id = try_parse_int(album_id) - async for item in self.async_get_albums( - "WHERE album_id = %d" % album_id, fulldata=fulldata, db_conn=db_conn + item_id = try_parse_int(item_id) + # get from db + for item in await self.async_get_albums( + "WHERE item_id = %d" % item_id, db_conn=db_conn ): + item.artist = await self.async_get_artist(item.artist.item_id) return item return None - async def async_add_album(self, album: Album) -> int: + async def async_add_album(self, album: Album): """Add a new album record to the database.""" - assert album.name and album.artist - assert album.artist.provider == "database" - album_id = None async with DbConnect(self._dbfile) as db_conn: db_conn.row_factory = aiosqlite.Row - # always try to grab existing album with external_id - album_id = await self.__async_get_item_by_external_id(album, db_conn) - # fallback to matching on artist_id, name and version - if not album_id: - sql_query = """SELECT album_id FROM albums WHERE - artist_id=? AND name=? AND version=? AND year=? AND albumtype=?""" - async with db_conn.execute( - sql_query, + + # always try to grab existing item by external_id + cur_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM albums WHERE upc=?;", + (album.upc,), + ) + # fallback to matching on artist, name and version + if not cur_item: + cur_item = await self.__execute_fetchone( + db_conn, + """SELECT item_id FROM albums WHERE + json_extract("artist", '$.item_id') = ? + AND sort_name=? AND version=? AND year=? AND album_type=?""", ( album.artist.item_id, - album.name, + album.sort_name, album.version, int(album.year), - int(album.album_type), + album.album_type.value, ), - ) as cursor: - res = await cursor.fetchone() - if res: - album_id = res["album_id"] + ) # fallback to almost exact match - if not album_id: - sql_query = """SELECT album_id, year, version, albumtype FROM - albums WHERE artist_id=? AND name=?""" - async with db_conn.execute( - sql_query, (album.artist.item_id, album.name) - ) as cursor: - albums = await cursor.fetchall() - for result in albums: - if (not album.version and result["year"] == album.year) or ( - album.version and result["version"] == album.version + if not cur_item: + for item in await db_conn.execute_fetchall( + """SELECT * FROM albums WHERE + json_extract("artist", '$.item_id') = ? + AND sort_name = ?""", + (album.artist.item_id, album.sort_name), + ): + if (not album.version and item["year"] == album.year) or ( + album.version and item["version"] == album.version ): - album_id = result["album_id"] + cur_item = item break - # no match: insert album - if not album_id: - sql_query = """INSERT INTO albums (artist_id, name, albumtype, year, version) - VALUES(?,?,?,?,?);""" - query_params = ( - album.artist.item_id, + + if cur_item: + # update existing + return await self.async_update_album(cur_item[0], album) + # insert album + album_artist = AlbumArtist( + item_id=album.artist.item_id, + provider="database", + name=album.artist.name, + ) + sql_query = """INSERT INTO albums + (name, sort_name, album_type, year, version, upc, artist, metadata, provider_ids) + VALUES(?,?,?,?,?,?,?,?,?);""" + async with db_conn.execute( + sql_query, + ( album.name, - int(album.album_type), + album.sort_name, + album.album_type.value, album.year, album.version, + album.upc, + json_serializer(album_artist), + json_serializer(album.metadata), + json_serializer(album.provider_ids), + ), + ) as cursor: + last_row_id = cursor.lastrowid + new_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM albums WHERE ROWID=?;", + (last_row_id,), ) - async with db_conn.execute(sql_query, query_params) as cursor: - last_row_id = cursor.lastrowid - # get id from newly created item - sql_query = "SELECT (album_id) FROM albums WHERE ROWID=?" - async with db_conn.execute(sql_query, (last_row_id,)) as cursor: - album_id = await cursor.fetchone() - album_id = album_id[0] - await db_conn.commit() - # always add metadata and tags etc. because we might have received - # additional info or a match from other provider await self.__async_add_prov_ids( - album_id, MediaType.Album, album.provider_ids, db_conn + new_item[0], MediaType.Album, album.provider_ids, db_conn ) - await self.__async_add_metadata( - album_id, MediaType.Album, album.metadata, db_conn + await db_conn.commit() + LOGGER.debug("added album %s to database", album.name) + # return created object + return await self.async_get_album(new_item[0]) + + async def async_update_album(self, item_id: int, album: Album): + """Update a album record in the database.""" + async with DbConnect(self._dbfile) as db_conn: + db_conn.row_factory = aiosqlite.Row + cur_item = Album.from_db_row( + await self.__execute_fetchone( + db_conn, "SELECT * FROM albums WHERE item_id=?;", (item_id,) + ) ) - await self.__async_add_tags(album_id, MediaType.Album, album.tags, db_conn) - await self.__async_add_album_labels(album_id, album.labels, db_conn) - await self.__async_add_external_ids( - album_id, MediaType.Album, album.external_ids, db_conn + album_artist = AlbumArtist( + item_id=album.artist.item_id, + provider="database", + name=album.artist.name, ) - # save - await db_conn.commit() - LOGGER.debug( - "added album %s (%s) to database: %s", - album.name, - album.provider_ids, - album_id, + metadata = merge_dict(cur_item.metadata, album.metadata) + provider_ids = merge_list(cur_item.provider_ids, album.provider_ids) + sql_query = """UPDATE albums + SET name=?, + sort_name=?, + album_type=?, + year=?, + version=?, + upc=?, + artist=?, + metadata=?, + provider_ids=? + WHERE item_id=?;""" + await db_conn.execute( + sql_query, + ( + album.name, + album.sort_name, + album.album_type.value, + album.year, + album.version, + album.upc, + json_serializer(album_artist), + json_serializer(metadata), + json_serializer(provider_ids), + item_id, + ), + ) + await self.__async_add_prov_ids( + item_id, MediaType.Album, album.provider_ids, db_conn ) - return album_id + LOGGER.debug("updated album %s in database: %s", album.name, item_id) + await db_conn.commit() + # return updated object + return await self.async_get_album(item_id) async def async_get_tracks( self, filter_query: str = None, orderby: str = "name", - fulldata=False, - db_conn: sqlite3.Connection = None, + db_conn: aiosqlite.Connection = None, ) -> List[Track]: """Return all track records from the database.""" + sql_query = "SELECT * FROM tracks" + if filter_query: + sql_query += " " + filter_query + sql_query += " ORDER BY %s" % orderby async with DbConnect(self._dbfile, db_conn) as db_conn: db_conn.row_factory = aiosqlite.Row - sql_query = "SELECT * FROM tracks" - if filter_query: - sql_query += " " + filter_query - sql_query += " ORDER BY %s" % orderby - for db_row in await db_conn.execute_fetchall(sql_query, ()): - track = Track( - item_id=db_row["track_id"], - provider="database", - name=db_row["name"], - external_ids=await self.__async_get_external_ids( - db_row["track_id"], MediaType.Track, db_conn - ), - provider_ids=await self.__async_get_prov_ids( - db_row["track_id"], MediaType.Track, db_conn - ), - in_library=await self.__async_get_library_providers( - db_row["track_id"], MediaType.Track, db_conn - ), - duration=db_row["duration"], - version=db_row["version"], - album=await self.async_get_album( - db_row["album_id"], fulldata=fulldata, db_conn=db_conn - ), - artists=await self.__async_get_track_artists( - db_row["track_id"], db_conn=db_conn, fulldata=fulldata - ), - ) - if fulldata: - track.metadata = await self.__async_get_metadata( - db_row["track_id"], MediaType.Track, db_conn - ) - track.tags = await self.__async_get_tags( - db_row["track_id"], MediaType.Track, db_conn - ) - yield track + return [ + Track.from_db_row(db_row) + for db_row in await db_conn.execute_fetchall(sql_query, ()) + ] + + async def async_get_tracks_from_provider_ids( + self, + provider_id: str, + prov_item_ids: List[str], + ) -> dict: + """Get track records for the given prov_ids.""" + prov_item_id_str = ",".join([f'"{x}"' for x in prov_item_ids]) + sql_query = f"""WHERE item_id in + (SELECT item_id FROM provider_mappings + WHERE provider = '{provider_id}' AND media_type = 'track' + AND prov_item_id in ({prov_item_id_str}) + )""" + return await self.async_get_tracks(sql_query) async def async_get_track( - self, track_id: int, fulldata=True, db_conn: sqlite3.Connection = None + self, item_id: int, db_conn: aiosqlite.Connection = None ) -> Track: """Get track record by id.""" - track_id = try_parse_int(track_id) - async for item in self.async_get_tracks( - "WHERE track_id = %d" % track_id, fulldata=fulldata, db_conn=db_conn + item_id = try_parse_int(item_id) + for item in await self.async_get_tracks( + "WHERE item_id = %d" % item_id, db_conn=db_conn ): + item.album = await self.async_get_album(item.album.item_id) + artist_ids = [str(x.item_id) for x in item.artists] + query = "WHERE item_id in (%s)" % ",".join(artist_ids) + item.artists = await self.async_get_artists(query) return item return None - async def async_add_track(self, track: Track) -> int: + async def async_add_track(self, track: Track): """Add a new track record to the database.""" - assert track.name and track.album - assert track.album.provider == "database" - assert track.artists - for artist in track.artists: - assert artist.provider == "database" async with DbConnect(self._dbfile) as db_conn: db_conn.row_factory = aiosqlite.Row - # always try to grab existing track with external_id - track_id = await self.__async_get_item_by_external_id(track, db_conn) - # fallback to matching on album_id, name and version - if not track_id: - sql_query = "SELECT track_id, duration, version \ - FROM tracks WHERE album_id=? AND name=?" - async with db_conn.execute( - sql_query, (track.album.item_id, track.name) - ) as cursor: - results = await cursor.fetchall() - for result in results: + + # always try to grab existing item by external_id + cur_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM tracks WHERE isrc=?;", + (track.isrc,), + ) + # fallback to matching on item_id, name and version + if not cur_item: + for item in await db_conn.execute_fetchall( + """SELECT * FROM tracks WHERE + json_extract("album", '$.item_id') = ? + AND sort_name=?""", + ( + track.album.item_id, + track.sort_name, + ), + ): # we perform an additional safety check on the duration or version if ( track.version - and compare_strings(result["version"], track.version) + and compare_strings(item["version"], track.version) ) or ( ( not track.version - and not result["version"] - and abs(result["duration"] - track.duration) < 10 + and not item["version"] + and abs(item["duration"] - track.duration) < 10 ) ): - track_id = result["track_id"] + cur_item = item break - # no match found: insert track - if not track_id: - assert track.name and track.album.item_id - sql_query = "INSERT INTO tracks (name, album_id, duration, version) \ - VALUES(?,?,?,?);" - query_params = ( + + if cur_item: + # update existing + return await self.async_update_track(cur_item[0], track) + # insert track + sql_query = """INSERT INTO tracks + (name, sort_name, album, artists, duration, version, isrc, metadata, provider_ids) + VALUES(?,?,?,?,?,?,?,?,?);""" + # we store a simplified artist/album object in tracks + artists = [ + TrackArtist(item_id=x.item_id, provider="database", name=x.name) + for x in track.artists + ] + album = TrackAlbum( + item_id=track.album.item_id, provider="database", name=track.album.name + ) + async with db_conn.execute( + sql_query, + ( track.name, - track.album.item_id, + track.sort_name, + json_serializer(album), + json_serializer(artists), track.duration, track.version, + track.isrc, + json_serializer(track.metadata), + json_serializer(track.provider_ids), + ), + ) as cursor: + last_row_id = cursor.lastrowid + new_item = await self.__execute_fetchone( + db_conn, + "SELECT (item_id) FROM tracks WHERE ROWID=?;", + (last_row_id,), ) - async with db_conn.execute(sql_query, query_params) as cursor: - last_row_id = cursor.lastrowid - await db_conn.commit() - # get id from newly created item (the safe way) - async with db_conn.execute( - "SELECT track_id FROM tracks WHERE ROWID=?", (last_row_id,) - ) as cursor: - track_id = await cursor.fetchone() - track_id = track_id[0] - # always add metadata and tags etc. because we might have received - # additional info or a match from other provider - for artist in track.artists: - sql_query = "INSERT or IGNORE INTO track_artists (track_id, artist_id) VALUES(?,?);" - await db_conn.execute(sql_query, (track_id, artist.item_id)) await self.__async_add_prov_ids( - track_id, MediaType.Track, track.provider_ids, db_conn - ) - await self.__async_add_metadata( - track_id, MediaType.Track, track.metadata, db_conn + new_item[0], MediaType.Track, track.provider_ids, db_conn ) - await self.__async_add_tags(track_id, MediaType.Track, track.tags, db_conn) - await self.__async_add_external_ids( - track_id, MediaType.Track, track.external_ids, db_conn - ) - # save to db await db_conn.commit() - LOGGER.debug( - "added track %s (%s) to database: %s", - track.name, - track.provider_ids, - track_id, - ) - return track_id + LOGGER.debug("added track %s to database", track.name) + # return created object + return await self.async_get_track(new_item[0]) - async def async_update_playlist( - self, playlist_id: int, column_key: str, column_value: str - ): - """Update column of existing playlist.""" + async def async_update_track(self, item_id: int, track: Track): + """Update a track record in the database.""" async with DbConnect(self._dbfile) as db_conn: - sql_query = f"UPDATE playlists SET {column_key}=? WHERE playlist_id=?;" - await db_conn.execute(sql_query, (column_value, playlist_id)) + db_conn.row_factory = aiosqlite.Row + cur_item = Track.from_db_row( + await self.__execute_fetchone( + db_conn, "SELECT * FROM tracks WHERE item_id=?;", (item_id,) + ) + ) + metadata = merge_dict(cur_item.metadata, track.metadata) + provider_ids = merge_list(cur_item.provider_ids, track.provider_ids) + artists = [ + TrackArtist(item_id=x.item_id, provider="database", name=x.name) + for x in track.artists + ] + album = TrackAlbum( + item_id=track.album.item_id, provider="database", name=track.album.name + ) + sql_query = """UPDATE tracks + SET name=?, + sort_name=?, + album=?, + artists=?, + duration=?, + version=?, + isrc=?, + metadata=?, + provider_ids=? + WHERE item_id=?;""" + await db_conn.execute( + sql_query, + ( + track.name, + track.sort_name, + json_serializer(album), + json_serializer(artists), + track.duration, + track.version, + track.isrc, + json_serializer(metadata), + json_serializer(provider_ids), + item_id, + ), + ) + await self.__async_add_prov_ids( + item_id, MediaType.Track, track.provider_ids, db_conn + ) + LOGGER.debug("updated track %s in database: %s", track.name, item_id) await db_conn.commit() - - async def async_get_artist_tracks( - self, artist_id: int, orderby: str = "name" - ) -> List[Track]: - """Get all library tracks for the given artist.""" - artist_id = try_parse_int(artist_id) - sql_query = f"""WHERE track_id in - (SELECT track_id FROM track_artists WHERE artist_id = {artist_id})""" - async for item in self.async_get_tracks( - sql_query, orderby=orderby, fulldata=False - ): - yield item + # return updated object + return await self.async_get_track(item_id) async def async_get_artist_albums( - self, artist_id: int, orderby: str = "name" + self, item_id: int, orderby: str = "name" ) -> List[Album]: """Get all library albums for the given artist.""" - sql_query = " WHERE artist_id = %s" % artist_id - async for item in self.async_get_albums( - sql_query, orderby=orderby, fulldata=False - ): - yield item + # TODO: use json query type instead of text search + sql_query = f"WHERE json_extract(\"artist\", '$.item_id') = {item_id}" + return await self.async_get_albums(sql_query, orderby=orderby) async def async_set_track_loudness( - self, provider_track_id: str, provider: str, loudness: int + self, provider_item_id: str, provider: str, loudness: int ): """Set integrated loudness for a track in db.""" async with DbConnect(self._dbfile) as db_conn: sql_query = """INSERT or REPLACE INTO track_loudness - (provider_track_id, provider, loudness) VALUES(?,?,?);""" - await db_conn.execute(sql_query, (provider_track_id, provider, loudness)) + (provider_item_id, provider, loudness) VALUES(?,?,?);""" + await db_conn.execute(sql_query, (provider_item_id, provider, loudness)) await db_conn.commit() - async def async_get_track_loudness(self, provider_track_id, provider): + async def async_get_track_loudness(self, provider_item_id, provider): """Get integrated loudness for a track in db.""" async with DbConnect(self._dbfile) as db_conn: sql_query = """SELECT loudness FROM track_loudness WHERE - provider_track_id = ? AND provider = ?""" + provider_item_id = ? AND provider = ?""" async with db_conn.execute( - sql_query, (provider_track_id, provider) + sql_query, (provider_item_id, provider) ) as cursor: result = await cursor.fetchone() if result: return result[0] return None - async def __async_add_metadata( - self, - item_id: int, - media_type: MediaType, - metadata: dict, - db_conn: sqlite3.Connection, - ): - """Add or update metadata.""" - for key, value in metadata.items(): - if value: - sql_query = """INSERT or REPLACE INTO metadata - (item_id, media_type, key, value) VALUES(?,?,?,?);""" - await db_conn.execute(sql_query, (item_id, int(media_type), key, value)) - - async def __async_get_metadata( - self, - item_id: int, - media_type: MediaType, - db_conn: sqlite3.Connection, - filter_key: str = None, - ) -> dict: - """Get metadata for media item.""" - metadata = {} - sql_query = ( - "SELECT key, value FROM metadata WHERE item_id = ? AND media_type = ?" - ) - if filter_key: - sql_query += ' AND key = "%s"' % filter_key - async with db_conn.execute(sql_query, (item_id, int(media_type))) as cursor: - db_rows = await cursor.fetchall() - for db_row in db_rows: - key = db_row[0] - value = db_row[1] - metadata[key] = value - return metadata - - async def __async_add_tags( - self, - item_id: int, - media_type: MediaType, - tags: List[str], - db_conn: sqlite3.Connection, - ): - """Add tags to db.""" - for tag in tags: - sql_query = "INSERT or IGNORE INTO tags (name) VALUES(?);" - async with db_conn.execute(sql_query, (tag,)) as cursor: - tag_id = cursor.lastrowid - sql_query = """INSERT or IGNORE INTO media_tags - (item_id, media_type, tag_id) VALUES(?,?,?);""" - await db_conn.execute(sql_query, (item_id, int(media_type), tag_id)) - - async def __async_get_tags( - self, item_id: int, media_type: MediaType, db_conn: sqlite3.Connection - ) -> List[str]: - """Get tags for media item.""" - tags = [] - sql_query = """SELECT name FROM tags INNER JOIN media_tags ON - tags.tag_id = media_tags.tag_id WHERE item_id = ? AND media_type = ?""" - async with db_conn.execute(sql_query, (item_id, int(media_type))) as cursor: - db_rows = await cursor.fetchall() - for db_row in db_rows: - tags.append(db_row[0]) - return tags - - async def __async_add_album_labels( - self, album_id: int, labels: List[str], db_conn: sqlite3.Connection - ): - """Add labels to album in db.""" - for label in labels: - sql_query = "INSERT or IGNORE INTO labels (name) VALUES(?);" - async with db_conn.execute(sql_query, (label,)) as cursor: - label_id = cursor.lastrowid - sql_query = ( - "INSERT or IGNORE INTO album_labels (album_id, label_id) VALUES(?,?);" - ) - await db_conn.execute(sql_query, (album_id, label_id)) - - async def __async_get_album_labels( - self, album_id: int, db_conn: sqlite3.Connection - ) -> List[str]: - """Get labels for album item.""" - labels = [] - sql_query = """SELECT name FROM labels INNER JOIN album_labels - ON labels.label_id = album_labels.label_id WHERE album_id = ?""" - async with db_conn.execute(sql_query, (album_id,)) as cursor: - db_rows = await cursor.fetchall() - for db_row in db_rows: - labels.append(db_row[0]) - return labels - - async def __async_get_track_artists( - self, track_id: int, db_conn: sqlite3.Connection, fulldata: bool = False - ) -> List[Artist]: - """Get artists for track.""" - sql_query = ( - "WHERE artist_id in (SELECT artist_id FROM track_artists WHERE track_id = %s)" - % track_id - ) - return [ - item - async for item in self.async_get_artists( - sql_query, fulldata=fulldata, db_conn=db_conn - ) - ] - - async def __async_add_external_ids( - self, - item_id: int, - media_type: MediaType, - external_ids: dict, - db_conn: sqlite3.Connection, - ): - """Add or update external_ids.""" - for key, value in external_ids.items(): - sql_query = """INSERT or REPLACE INTO external_ids - (item_id, media_type, key, value) VALUES(?,?,?,?);""" - await db_conn.execute( - sql_query, (item_id, int(media_type), str(key), value) - ) - - async def __async_get_external_ids( - self, item_id: int, media_type: MediaType, db_conn: sqlite3.Connection - ) -> dict: - """Get external_ids for media item.""" - external_ids = {} - sql_query = ( - "SELECT key, value FROM external_ids WHERE item_id = ? AND media_type = ?" - ) - for db_row in await db_conn.execute_fetchall( - sql_query, (item_id, int(media_type)) - ): - external_ids[db_row[0]] = db_row[1] - return external_ids - async def __async_add_prov_ids( self, item_id: int, media_type: MediaType, provider_ids: List[MediaItemProviderId], - db_conn: sqlite3.Connection, + db_conn: aiosqlite.Connection, ): - """Add provider ids for media item to db_conn.""" + """Add provider ids for media item to database.""" for prov in provider_ids: sql_query = """INSERT OR REPLACE INTO provider_mappings @@ -1082,58 +1066,18 @@ class DatabaseManager: sql_query, ( item_id, - int(media_type), + media_type.value, prov.item_id, prov.provider, - int(prov.quality), + prov.quality, prov.details, ), ) - async def __async_get_prov_ids( - self, item_id: int, media_type: MediaType, db_conn: sqlite3.Connection - ) -> List[MediaItemProviderId]: - """Get all provider id's for media item.""" - provider_ids = [] - sql_query = "SELECT prov_item_id, provider, quality, details \ - FROM provider_mappings \ - WHERE item_id = ? AND media_type = ?" - for db_row in await db_conn.execute_fetchall( - sql_query, (item_id, int(media_type)) - ): - prov_mapping = MediaItemProviderId( - provider=db_row["provider"], - item_id=db_row["prov_item_id"], - quality=TrackQuality(db_row["quality"]), - details=db_row["details"], - ) - provider_ids.append(prov_mapping) - return provider_ids - - async def __async_get_library_providers( - self, db_item_id: int, media_type: MediaType, db_conn: sqlite3.Connection - ) -> List[str]: - """Get the providers that have this media_item added to the library.""" - providers = [] - sql_query = ( - "SELECT provider FROM library_items WHERE item_id = ? AND media_type = ?" - ) - for db_row in await db_conn.execute_fetchall( - sql_query, (db_item_id, int(media_type)) - ): - providers.append(db_row[0]) - return providers - - async def __async_get_item_by_external_id( - self, media_item: MediaItem, db_conn: sqlite3.Connection - ) -> int: - """Try to get existing item in db by matching the new item's external id's.""" - for key, value in media_item.external_ids.items(): - sql_query = "SELECT (item_id) FROM external_ids \ - WHERE media_type=? AND key=? AND value=?;" - for db_row in await db_conn.execute_fetchall( - sql_query, (int(media_item.media_type), str(key), value) - ): - if db_row: - return db_row[0] + async def __execute_fetchone( + self, db_conn: aiosqlite.Connection, query: str, query_params: tuple + ): + """Return first row of given query.""" + for item in await db_conn.execute_fetchall(query, query_params): + return item return None diff --git a/music_assistant/managers/metadata.py b/music_assistant/managers/metadata.py index aea627d5..a5c6c9c8 100755 --- a/music_assistant/managers/metadata.py +++ b/music_assistant/managers/metadata.py @@ -36,7 +36,7 @@ class MetaDataManager: break cache_key = f"{provider.id}.artist_metadata.{mb_artist_id}" res = await async_cached( - self.cache, cache_key, provider.async_get_artist_images(mb_artist_id) + self.cache, cache_key, provider.async_get_artist_images, mb_artist_id ) if res: merge_dict(metadata, res) diff --git a/music_assistant/managers/music.py b/music_assistant/managers/music.py index 42293939..e749c440 100755 --- a/music_assistant/managers/music.py +++ b/music_assistant/managers/music.py @@ -10,14 +10,13 @@ from typing import Any, List, Optional import aiohttp from music_assistant.constants import EVENT_MUSIC_SYNC_STATUS, EVENT_PROVIDER_REGISTERED -from music_assistant.helpers.cache import async_cached, async_cached_generator +from music_assistant.helpers.cache import async_cached from music_assistant.helpers.encryption import async_encrypt_string from music_assistant.helpers.musicbrainz import MusicBrainz from music_assistant.helpers.util import callback, compare_strings, run_periodic from music_assistant.models.media_types import ( Album, Artist, - ExternalId, MediaItem, MediaType, Playlist, @@ -119,17 +118,17 @@ class MusicManager: ) -> Artist: """Return artist details for the given provider artist id.""" assert item_id and provider_id - db_id = await self.mass.database.async_get_database_id( - provider_id, item_id, MediaType.Artist + db_item = await self.mass.database.async_get_artist_by_prov_id( + provider_id, item_id ) - if db_id is None: + if not db_item: # artist not yet in local database so fetch details provider = self.mass.get_provider(provider_id) if not provider.available: return None cache_key = f"{provider_id}.get_artist.{item_id}" artist = await async_cached( - self.cache, cache_key, provider.async_get_artist(item_id) + self.cache, cache_key, provider.async_get_artist, item_id ) if not artist: raise Exception( @@ -139,8 +138,8 @@ class MusicManager: self.mass.add_job(self.__async_add_artist(artist)) artist.is_lazy = True return artist - db_id = await self.__async_add_artist(artist) - return await self.mass.database.async_get_artist(db_id) + db_item = await self.__async_add_artist(artist) + return db_item async def async_get_album( self, @@ -151,10 +150,10 @@ class MusicManager: ) -> Album: """Return album details for the given provider album id.""" assert item_id and provider_id - db_id = await self.mass.database.async_get_database_id( - provider_id, item_id, MediaType.Album + db_item = await self.mass.database.async_get_album_by_prov_id( + provider_id, item_id ) - if db_id is None: + if not db_item: # album not yet in local database so fetch details if not album_details: provider = self.mass.get_provider(provider_id) @@ -162,7 +161,7 @@ class MusicManager: return None cache_key = f"{provider_id}.get_album.{item_id}" album_details = await async_cached( - self.cache, cache_key, provider.async_get_album(item_id) + self.cache, cache_key, provider.async_get_album, item_id ) if not album_details: raise Exception( @@ -172,8 +171,8 @@ class MusicManager: self.mass.add_job(self.__async_add_album(album_details)) album_details.is_lazy = True return album_details - db_id = await self.__async_add_album(album_details) - return await self.mass.database.async_get_album(db_id) + db_item = await self.__async_add_album(album_details) + return db_item async def async_get_track( self, @@ -185,19 +184,19 @@ class MusicManager: ) -> Track: """Return track details for the given provider track id.""" assert item_id and provider_id - db_id = await self.mass.database.async_get_database_id( - provider_id, item_id, MediaType.Track + db_item = await self.mass.database.async_get_track_by_prov_id( + provider_id, item_id ) - if db_id and refresh: + if db_item and refresh: # in some cases (e.g. at playback time or requesting full track info) # it's useful to have the track refreshed from the provider instead of # the database cache to make sure that the track is available and perhaps # another or a higher quality version is available. if lazy: - self.mass.add_job(self.__async_match_track(db_id)) + self.mass.add_job(self.__async_match_track(db_item)) else: - await self.__async_match_track(db_id) - if not db_id: + await self.__async_match_track(db_item) + if not db_item: # track not yet in local database so fetch details if not track_details: provider = self.mass.get_provider(provider_id) @@ -205,7 +204,7 @@ class MusicManager: return None cache_key = f"{provider_id}.get_track.{item_id}" track_details = await async_cached( - self.cache, cache_key, provider.async_get_track(item_id) + self.cache, cache_key, provider.async_get_track, item_id ) if not track_details: raise Exception( @@ -215,43 +214,43 @@ class MusicManager: self.mass.add_job(self.__async_add_track(track_details)) track_details.is_lazy = True return track_details - db_id = await self.__async_add_track(track_details) - return await self.mass.database.async_get_track(db_id, fulldata=True) + db_item = await self.__async_add_track(track_details) + return db_item async def async_get_playlist(self, item_id: str, provider_id: str) -> Playlist: """Return playlist details for the given provider playlist id.""" assert item_id and provider_id - db_id = await self.mass.database.async_get_database_id( - provider_id, item_id, MediaType.Playlist + db_item = await self.mass.database.async_get_playlist_by_prov_id( + provider_id, item_id ) - if db_id is None: + if not db_item: # item not yet in local database so fetch and store details provider = self.mass.get_provider(provider_id) if not provider.available: return None item_details = await provider.async_get_playlist(item_id) - db_id = await self.mass.database.async_add_playlist(item_details) - return await self.mass.database.async_get_playlist(db_id) + db_item = await self.mass.database.async_add_playlist(item_details) + return db_item async def async_get_radio(self, item_id: str, provider_id: str) -> Radio: """Return radio details for the given provider playlist id.""" assert item_id and provider_id - db_id = await self.mass.database.async_get_database_id( - provider_id, item_id, MediaType.Radio + db_item = await self.mass.database.async_get_radio_by_prov_id( + provider_id, item_id ) - if db_id is None: + if not db_item: # item not yet in local database so fetch and store details provider = self.mass.get_provider(provider_id) if not provider.available: return None item_details = await provider.async_get_radio(item_id) - db_id = await self.mass.database.async_add_radio(item_details) - return await self.mass.database.async_get_radio(db_id) + db_item = await self.mass.database.async_add_radio(item_details) + return db_item async def async_get_album_tracks( self, item_id: str, provider_id: str ) -> List[Track]: - """Return album tracks for the given provider album id. Generator.""" + """Return album tracks for the given provider album id.""" assert item_id and provider_id album = await self.async_get_album(item_id, provider_id) if album.provider == "database": @@ -260,54 +259,57 @@ class MusicManager: item_id = album.provider_ids[0].item_id provider = self.mass.get_provider(provider_id) cache_key = f"{provider_id}.album_tracks.{item_id}" + result = [] async with self.mass.database.db_conn() as db_conn: - async for item in async_cached_generator( - self.cache, cache_key, provider.async_get_album_tracks(item_id) + for item in await async_cached( + self.cache, cache_key, provider.async_get_album_tracks, item_id ): if not item: continue - db_id = await self.mass.database.async_get_database_id( - item.provider, item.item_id, MediaType.Track, db_conn + db_item = await self.mass.database.async_get_track_by_prov_id( + item.provider, item.item_id, db_conn ) - if db_id: + if db_item: # return database track instead if we have a match - track = await self.mass.database.async_get_track( - db_id, fulldata=False, db_conn=db_conn - ) + track = db_item track.disc_number = item.disc_number track.track_number = item.track_number else: track = item if not track.album: track.album = album - yield track + result.append(track) + return result async def async_get_album_versions( self, item_id: str, provider_id: str ) -> List[Album]: - """Return all versions of an album we can find on all providers. Generator.""" + """Return all versions of an album we can find on all providers.""" album = await self.async_get_album(item_id, provider_id) provider_ids = [ item.id for item in self.mass.get_providers(ProviderType.MUSIC_PROVIDER) ] search_query = f"{album.artist.name} - {album.name}" + result = [] for prov_id in provider_ids: provider_result = await self.async_search_provider( search_query, prov_id, [MediaType.Album], 25 ) for item in provider_result.albums: if compare_strings(item.artist.name, album.artist.name): - yield item + result.append(item) + return result async def async_get_track_versions( self, item_id: str, provider_id: str ) -> List[Track]: - """Return all versions of a track we can find on all providers. Generator.""" + """Return all versions of a track we can find on all providers.""" track = await self.async_get_track(item_id, provider_id) provider_ids = [ item.id for item in self.mass.get_providers(ProviderType.MUSIC_PROVIDER) ] search_query = f"{track.artists[0].name} - {track.name}" + result = [] for prov_id in provider_ids: provider_result = await self.async_search_provider( search_query, prov_id, [MediaType.Track], 25 @@ -318,56 +320,61 @@ class MusicManager: for artist in item.artists: # artist must match if compare_strings(artist.name, track.artists[0].name): - yield item + result.append(item) break + return result async def async_get_playlist_tracks( self, item_id: str, provider_id: str ) -> List[Track]: - """Return playlist tracks for the given provider playlist id. Generator.""" + """Return playlist tracks for the given provider playlist id.""" assert item_id and provider_id if provider_id == "database": # playlist tracks are not stored in db, we always fetch them (cached) from the provider. - db_item = await self.mass.database.async_get_playlist(item_id) - provider_id = db_item.provider_ids[0].provider - item_id = db_item.provider_ids[0].item_id - provider = self.mass.get_provider(provider_id) - playlist = await provider.async_get_playlist(item_id) + playlist = await self.mass.database.async_get_playlist(item_id) + provider_id = playlist.provider_ids[0].provider + item_id = playlist.provider_ids[0].item_id + provider = self.mass.get_provider(provider_id) + else: + provider = self.mass.get_provider(provider_id) + playlist = await provider.async_get_playlist(item_id) cache_checksum = playlist.checksum cache_key = f"{provider_id}.playlist_tracks.{item_id}" - pos = 0 - async with self.mass.database.db_conn() as db_conn: - async for item in async_cached_generator( - self.cache, - cache_key, - provider.async_get_playlist_tracks(item_id), - checksum=cache_checksum, - ): - if not item: - continue - assert item.item_id and item.provider - db_id = await self.mass.database.async_get_database_id( - item.provider, item.item_id, MediaType.Track, db_conn=db_conn - ) - if db_id: - # return database track instead if we have a match - item = await self.mass.database.async_get_track( - db_id, fulldata=False, db_conn=db_conn - ) - item.position = pos - pos += 1 - yield item + playlist_tracks = await async_cached( + self.cache, + cache_key, + provider.async_get_playlist_tracks, + item_id, + checksum=cache_checksum, + ) + db_tracks = await self.mass.database.async_get_tracks_from_provider_ids( + provider_id, [x.item_id for x in playlist_tracks] + ) + # combine provider tracks with db tracks + return [ + await self.__process_track_details(item, index, db_tracks) + for index, item in enumerate(playlist_tracks) + ] + + async def __process_track_details(self, item, position, db_tracks): + for db_track in db_tracks: + if item.item_id in [x.item_id for x in db_track.provider_ids]: + db_track.position = position + return db_track + item.position = position + return item async def async_get_artist_toptracks( self, artist_id: str, provider_id: str ) -> List[Track]: - """Return top tracks for an artist. Generator.""" + """Return top tracks for an artist.""" async with self.mass.database.db_conn() as db_conn: if provider_id == "database": # tracks from all providers item_ids = [] + result = [] artist = await self.mass.database.async_get_artist( - artist_id, True, db_conn=db_conn + artist_id, db_conn=db_conn ) for prov_id in artist.provider_ids: provider = self.mass.get_provider(prov_id.provider) @@ -376,47 +383,49 @@ class MusicManager: or MediaType.Track not in provider.supported_mediatypes ): continue - async for item in self.async_get_artist_toptracks( + for item in await self.async_get_artist_toptracks( prov_id.item_id, prov_id.provider ): if item.item_id not in item_ids: - yield item + result.append(item) item_ids.append(item.item_id) + return result else: # items from provider provider = self.mass.get_provider(provider_id) cache_key = f"{provider_id}.artist_toptracks.{artist_id}" - async for item in async_cached_generator( + result = [] + for item in await async_cached( self.cache, cache_key, - provider.async_get_artist_toptracks(artist_id), + provider.async_get_artist_toptracks, + artist_id, ): if item: assert item.item_id and item.provider - db_id = await self.mass.database.async_get_database_id( + db_item = await self.mass.database.async_get_track_by_prov_id( item.provider, item.item_id, - MediaType.Track, db_conn=db_conn, ) - if db_id: + if db_item: # return database track instead if we have a match - yield await self.mass.database.async_get_track( - db_id, fulldata=False, db_conn=db_conn - ) + result.append(db_item) else: - yield item + result.append(item) + return result async def async_get_artist_albums( self, artist_id: str, provider_id: str ) -> List[Album]: - """Return (all) albums for an artist. Generator.""" + """Return (all) albums for an artist.""" async with self.mass.database.db_conn() as db_conn: if provider_id == "database": # albums from all providers item_ids = [] + result = [] artist = await self.mass.database.async_get_artist( - artist_id, True, db_conn=db_conn + artist_id, db_conn=db_conn ) for prov_id in artist.provider_ids: provider = self.mass.get_provider(prov_id.provider) @@ -425,94 +434,70 @@ class MusicManager: or MediaType.Album not in provider.supported_mediatypes ): continue - async for item in self.async_get_artist_albums( + for item in await self.async_get_artist_albums( prov_id.item_id, prov_id.provider ): if item.item_id not in item_ids: - yield item + result.append(item) item_ids.append(item.item_id) + return result else: # items from provider provider = self.mass.get_provider(provider_id) cache_key = f"{provider_id}.artist_albums.{artist_id}" - async for item in async_cached_generator( - self.cache, cache_key, provider.async_get_artist_albums(artist_id) + result = [] + for item in await async_cached( + self.cache, cache_key, provider.async_get_artist_albums, artist_id ): assert item.item_id and item.provider - db_id = await self.mass.database.async_get_database_id( - item.provider, item.item_id, MediaType.Album, db_conn=db_conn + db_item = await self.mass.database.async_get_album_by_prov_id( + item.provider, item.item_id, db_conn=db_conn ) - if db_id: + if db_item: # return database album instead if we have a match - yield await self.mass.database.async_get_album( - db_id, db_conn=db_conn - ) + result.append(db_item) else: - yield item + result.append(item) + return result ################ GET MediaItems that are added in the library ################ - async def async_get_library_artists( - self, orderby: str = "name", provider_filter: str = None - ) -> List[Artist]: - """Return all library artists, optionally filtered by provider. Generator.""" - async for item in self.mass.database.async_get_library_artists( - provider_id=provider_filter, orderby=orderby - ): - yield item + async def async_get_library_artists(self, orderby: str = "name") -> List[Artist]: + """Return all library artists, optionally filtered by provider.""" + return await self.mass.database.async_get_library_artists(orderby=orderby) - async def async_get_library_albums( - self, orderby: str = "name", provider_filter: str = None - ) -> List[Album]: - """Return all library albums, optionally filtered by provider. Generator.""" - async for item in self.mass.database.async_get_library_albums( - provider_id=provider_filter, orderby=orderby - ): - yield item + async def async_get_library_albums(self, orderby: str = "name") -> List[Album]: + """Return all library albums, optionally filtered by provider.""" + return await self.mass.database.async_get_library_albums(orderby=orderby) - async def async_get_library_tracks( - self, orderby: str = "name", provider_filter: str = None - ) -> List[Track]: - """Return all library tracks, optionally filtered by provider. Generator.""" - async for item in self.mass.database.async_get_library_tracks( - provider_id=provider_filter, orderby=orderby - ): - yield item + async def async_get_library_tracks(self, orderby: str = "name") -> List[Track]: + """Return all library tracks, optionally filtered by provider.""" + return await self.mass.database.async_get_library_tracks(orderby=orderby) async def async_get_library_playlists( - self, orderby: str = "name", provider_filter: str = None + self, orderby: str = "name" ) -> List[Playlist]: - """Return all library playlists, optionally filtered by provider. Generator.""" - async for item in self.mass.database.async_get_library_playlists( - provider_id=provider_filter, orderby=orderby - ): - yield item + """Return all library playlists, optionally filtered by provider.""" + return await self.mass.database.async_get_library_playlists(orderby=orderby) - async def async_get_library_radios( - self, orderby: str = "name", provider_filter: str = None - ) -> List[Playlist]: - """Return all library radios, optionally filtered by provider. Generator.""" - async for item in self.mass.database.async_get_library_radios( - provider_id=provider_filter, orderby=orderby - ): - yield item + async def async_get_library_radios(self, orderby: str = "name") -> List[Playlist]: + """Return all library radios, optionally filtered by provider.""" + return await self.mass.database.async_get_library_radios(orderby=orderby) ################ ADD MediaItem(s) to database helpers ################ async def __async_add_artist(self, artist: Artist) -> int: """Add artist to local db and return the new database id.""" - musicbrainz_id = artist.external_ids.get(ExternalId.MUSICBRAINZ) - if not musicbrainz_id: - musicbrainz_id = await self.__async_get_artist_musicbrainz_id(artist) + if not artist.musicbrainz_id: + artist.musicbrainz_id = await self.__async_get_artist_musicbrainz_id(artist) # grab additional metadata - artist.external_ids[ExternalId.MUSICBRAINZ] = musicbrainz_id artist.metadata = await self.mass.metadata.async_get_artist_metadata( - musicbrainz_id, artist.metadata + artist.musicbrainz_id, artist.metadata ) - db_id = await self.mass.database.async_add_artist(artist) + db_item = await self.mass.database.async_add_artist(artist) # also fetch same artist on all providers - await self.__async_match_artist(db_id) - return db_id + await self.__async_match_artist(db_item) + return db_item async def __async_add_album(self, album: Album) -> int: """Add album to local db and return the new database id.""" @@ -520,13 +505,13 @@ class MusicManager: album.artist = await self.async_get_artist( album.artist.item_id, album.artist.provider, lazy=False ) - db_id = await self.mass.database.async_add_album(album) + db_item = await self.mass.database.async_add_album(album) # also fetch same album on all providers - await self.__async_match_album(db_id) - return db_id + await self.__async_match_album(db_item) + return db_item async def __async_add_track( - self, track: Track, album_id: Optional[str] = None + self, track: Track, album_id: Optional[int] = None ) -> int: """Add track to local db and return the new database id.""" track_artists = [] @@ -551,15 +536,15 @@ class MusicManager: track.album = await self.async_get_album( track.album.item_id, track.provider, lazy=False ) - db_id = await self.mass.database.async_add_track(track) + db_item = await self.mass.database.async_add_track(track) # also fetch same track on all providers (will also get other quality versions) - await self.__async_match_track(db_id) - return db_id + await self.__async_match_track(db_item) + return db_item async def __async_get_artist_musicbrainz_id(self, artist: Artist): """Fetch musicbrainz id by performing search using the artist name, albums and tracks.""" # try with album first - async for lookup_album in self.async_get_artist_albums( + for lookup_album in await self.async_get_artist_albums( artist.item_id, artist.provider ): if not lookup_album: @@ -567,12 +552,12 @@ class MusicManager: musicbrainz_id = await self.musicbrainz.async_get_mb_artist_id( artist.name, albumname=lookup_album.name, - album_upc=lookup_album.external_ids.get(ExternalId.UPC), + album_upc=lookup_album.upc, ) if musicbrainz_id: return musicbrainz_id # fallback to track - async for lookup_track in self.async_get_artist_toptracks( + for lookup_track in await self.async_get_artist_toptracks( artist.item_id, artist.provider ): if not lookup_track: @@ -580,7 +565,7 @@ class MusicManager: musicbrainz_id = await self.musicbrainz.async_get_mb_artist_id( artist.name, trackname=lookup_track.name, - track_isrc=lookup_track.external_ids.get(ExternalId.ISRC), + track_isrc=lookup_track.isrc, ) if musicbrainz_id: return musicbrainz_id @@ -588,18 +573,17 @@ class MusicManager: LOGGER.warning("Unable to get musicbrainz ID for artist %s !", artist.name) return artist.name - async def __async_match_artist(self, db_artist_id: int): + async def __async_match_artist(self, artist: Artist): """ Try to find matching artists on all providers for the provided (database) artist_id. This is used to link objects of different providers together. :attrib db_artist_id: Database artist_id. """ - match_job_id = f"artist.{db_artist_id}" + match_job_id = f"artist.{artist.item_id}" if match_job_id in self._match_jobs: return self._match_jobs.append(match_job_id) - artist = await self.mass.database.async_get_artist(db_artist_id) cur_providers = [item.provider for item in artist.provider_ids] for provider in self.mass.get_providers(ProviderType.MUSIC_PROVIDER): if provider.id in cur_providers: @@ -609,7 +593,7 @@ class MusicManager: ) match_found = False # try to get a match with some reference albums of this artist - async for ref_album in self.async_get_artist_albums( + for ref_album in await self.async_get_artist_albums( artist.item_id, artist.provider ): if match_found: @@ -645,7 +629,7 @@ class MusicManager: break # try to get a match with some reference tracks of this artist if not match_found: - async for search_track in self.async_get_artist_toptracks( + for search_track in await self.async_get_artist_toptracks( artist.item_id, artist.provider ): if match_found: @@ -695,18 +679,17 @@ class MusicManager: provider.name, ) - async def __async_match_album(self, db_album_id: int): + async def __async_match_album(self, album: Album): """ Try to find matching album on all providers for the provided (database) album_id. This is used to link objects of different providers/qualities together. :attrib db_album_id: Database album_id. """ - match_job_id = f"album.{db_album_id}" + match_job_id = f"album.{album.item_id}" if match_job_id in self._match_jobs: return self._match_jobs.append(match_job_id) - album = await self.mass.database.async_get_album(db_album_id) cur_providers = [item.provider for item in album.provider_ids] providers = self.mass.get_providers(ProviderType.MUSIC_PROVIDER) for provider in providers: @@ -755,18 +738,17 @@ class MusicManager: provider.name, ) - async def __async_match_track(self, db_track_id: int): + async def __async_match_track(self, track: Track): """ Try to find matching track on all providers for the provided (database) track_id. This is used to link objects of different providers/qualities together. :attrib db_track_id: Database track_id. """ - match_job_id = f"track.{db_track_id}" + match_job_id = f"track.{track.item_id}" if match_job_id in self._match_jobs: return self._match_jobs.append(match_job_id) - track = await self.mass.database.async_get_track(db_track_id, fulldata=False) for provider in self.mass.get_providers(ProviderType.MUSIC_PROVIDER): LOGGER.debug( "Trying to match track %s on provider %s", track.name, provider.name @@ -825,14 +807,14 @@ class MusicManager: async def async_get_library_playlist_by_name(self, name: str) -> Playlist: """Get in-library playlist by name.""" - async for playlist in self.async_get_library_playlists(): + for playlist in await self.async_get_library_playlists(): if playlist.name == name: return playlist return None async def async_get_radio_by_name(self, name: str) -> Radio: """Get in-library radio by name.""" - async for radio in self.async_get_library_radios(): + for radio in await self.async_get_library_radios(): if radio.name == name: return radio return None @@ -860,7 +842,10 @@ class MusicManager: return await async_cached( self.cache, cache_key, - provider.async_search(search_query, media_types, limit), + provider.async_search, + search_query, + media_types, + limit, ) async def async_global_search( @@ -952,7 +937,7 @@ class MusicManager: playlist_prov = playlist.provider_ids[0] # grab all existing track ids in the playlist so we can check for duplicates cur_playlist_track_ids = [] - async for item in self.async_get_playlist_tracks( + for item in await self.async_get_playlist_tracks( playlist_prov.item_id, playlist_prov.provider ): cur_playlist_track_ids.append(item.item_id) @@ -1034,12 +1019,10 @@ class MusicManager: return cache_file_sized # no file in cache so we should get it img_url = "" - # we only retrieve items that we already have in cache - item = None - if await self.mass.database.async_get_database_id( + # we only retrieve items that we already have in database + item = await self.mass.database.async_get_item_by_prov_id( provider_id, item_id, media_type - ): - item = await self.async_get_item(item_id, provider_id, media_type) + ) if not item: return "" if item and item.metadata.get("image"): @@ -1175,65 +1158,57 @@ class MusicManager: async def async_library_artists_sync(self, provider_id: str): """Sync library artists for given provider.""" music_provider = self.mass.get_provider(provider_id) - prev_db_ids = [ - item.item_id - async for item in self.async_get_library_artists( - provider_filter=provider_id - ) - ] + cache_key = f"library_artists_{provider_id}" + prev_db_ids = await self.mass.cache.async_get(cache_key, default=[]) cur_db_ids = [] - async for item in music_provider.async_get_library_artists(): + for item in await music_provider.async_get_library_artists(): db_item = await self.async_get_artist(item.item_id, provider_id, lazy=False) cur_db_ids.append(db_item.item_id) - if db_item.item_id not in prev_db_ids: - await self.mass.database.async_add_to_library( - db_item.item_id, MediaType.Artist, provider_id - ) + await self.mass.database.async_add_to_library( + db_item.item_id, MediaType.Artist, provider_id + ) # process deletions for db_id in prev_db_ids: if db_id not in cur_db_ids: await self.mass.database.async_remove_from_library( db_id, MediaType.Artist, provider_id ) + # store ids in cache for next sync + await self.mass.cache.async_set(cache_key, cur_db_ids) @sync_task("albums") async def async_library_albums_sync(self, provider_id: str): """Sync library albums for given provider.""" music_provider = self.mass.get_provider(provider_id) - prev_db_ids = [ - item.item_id - async for item in self.async_get_library_albums(provider_filter=provider_id) - ] + cache_key = f"library_albums_{provider_id}" + prev_db_ids = await self.mass.cache.async_get(cache_key, default=[]) cur_db_ids = [] - async for item in music_provider.async_get_library_albums(): + for item in await music_provider.async_get_library_albums(): db_album = await self.async_get_album( item.item_id, provider_id, album_details=item, lazy=False ) - if not db_album: - LOGGER.error("provider %s album: %s", provider_id, str(item)) cur_db_ids.append(db_album.item_id) - if db_album.item_id not in prev_db_ids: - await self.mass.database.async_add_to_library( - db_album.item_id, MediaType.Album, provider_id - ) + await self.mass.database.async_add_to_library( + db_album.item_id, MediaType.Album, provider_id + ) # process deletions for db_id in prev_db_ids: if db_id not in cur_db_ids: await self.mass.database.async_remove_from_library( db_id, MediaType.Album, provider_id ) + # store ids in cache for next sync + await self.mass.cache.async_set(cache_key, cur_db_ids) @sync_task("tracks") async def async_library_tracks_sync(self, provider_id: str): """Sync library tracks for given provider.""" music_provider = self.mass.get_provider(provider_id) - prev_db_ids = [ - item.item_id - async for item in self.async_get_library_tracks(provider_filter=provider_id) - ] + cache_key = f"library_tracks_{provider_id}" + prev_db_ids = await self.mass.cache.async_get(cache_key, default=[]) cur_db_ids = [] - async for item in music_provider.async_get_library_tracks(): + for item in await music_provider.async_get_library_tracks(): db_item = await self.async_get_track( item.item_id, provider_id=provider_id, lazy=False ) @@ -1248,61 +1223,52 @@ class MusicManager: await self.mass.database.async_remove_from_library( db_id, MediaType.Track, provider_id ) + # store ids in cache for next sync + await self.mass.cache.async_set(cache_key, cur_db_ids) @sync_task("playlists") async def async_library_playlists_sync(self, provider_id: str): """Sync library playlists for given provider.""" music_provider = self.mass.get_provider(provider_id) - prev_db_ids = [ - item.item_id - async for item in self.async_get_library_playlists( - provider_filter=provider_id - ) - ] + cache_key = f"library_playlists_{provider_id}" + prev_db_ids = await self.mass.cache.async_get(cache_key, default=[]) cur_db_ids = [] - async for playlist in music_provider.async_get_library_playlists(): - if playlist is None: - continue + for playlist in await music_provider.async_get_library_playlists(): # always add to db because playlist attributes could have changed - db_id = await self.mass.database.async_add_playlist(playlist) - cur_db_ids.append(db_id) - if db_id not in prev_db_ids: - await self.mass.database.async_add_to_library( - db_id, MediaType.Playlist, playlist.provider - ) - # We do not precache/store playlist tracks, these will be retrieved on request only + db_item = await self.mass.database.async_add_playlist(playlist) + cur_db_ids.append(db_item.item_id) + await self.mass.database.async_add_to_library( + db_item.item_id, MediaType.Playlist, playlist.provider + ) + # precache playlist tracks + await self.async_get_playlist_tracks(db_item.item_id, db_item.provider) # process playlist deletions for db_id in prev_db_ids: if db_id not in cur_db_ids: await self.mass.database.async_remove_from_library( db_id, MediaType.Playlist, provider_id ) + # store ids in cache for next sync + await self.mass.cache.async_set(cache_key, cur_db_ids) @sync_task("radios") async def async_library_radios_sync(self, provider_id: str): """Sync library radios for given provider.""" music_provider = self.mass.get_provider(provider_id) - prev_db_ids = [ - item.item_id - async for item in self.async_get_library_radios(provider_filter=provider_id) - ] + cache_key = f"library_radios_{provider_id}" + prev_db_ids = await self.mass.cache.async_get(cache_key, default=[]) cur_db_ids = [] - async for item in music_provider.async_get_radios(): - if not item: - continue - db_id = await self.mass.database.async_get_database_id( - item.provider, item.item_id, MediaType.Radio + for item in await music_provider.async_get_library_radios(): + db_radio = await self.async_get_radio(item.item_id, provider_id) + cur_db_ids.append(db_radio.item_id) + await self.mass.database.async_add_to_library( + db_radio.item_id, MediaType.Radio, provider_id ) - if not db_id: - db_id = await self.mass.database.async_add_radio(item) - cur_db_ids.append(db_id) - if db_id not in prev_db_ids: - await self.mass.database.async_add_to_library( - db_id, MediaType.Radio, provider_id - ) # process deletions for db_id in prev_db_ids: if db_id not in cur_db_ids: await self.mass.database.async_remove_from_library( db_id, MediaType.Radio, provider_id ) + # store ids in cache for next sync + await self.mass.cache.async_set(cache_key, cur_db_ids) diff --git a/music_assistant/managers/players.py b/music_assistant/managers/players.py index 48aeefd4..739776f0 100755 --- a/music_assistant/managers/players.py +++ b/music_assistant/managers/players.py @@ -14,12 +14,7 @@ from music_assistant.constants import ( EVENT_UNREGISTER_PLAYER_CONTROL, ) from music_assistant.helpers.typing import MusicAssistantType -from music_assistant.helpers.util import ( - async_iter_items, - callback, - run_periodic, - try_parse_int, -) +from music_assistant.helpers.util import callback, run_periodic, try_parse_int from music_assistant.models.media_types import MediaItem, MediaType, Track from music_assistant.models.player import ( PlaybackState, @@ -266,20 +261,22 @@ class PlayerManager: for media_item in media_items: # collect tracks to play if media_item.media_type == MediaType.Artist: - tracks = self.mass.music.async_get_artist_toptracks( + tracks = await self.mass.music.async_get_artist_toptracks( media_item.item_id, provider_id=media_item.provider ) elif media_item.media_type == MediaType.Album: - tracks = self.mass.music.async_get_album_tracks( + tracks = await self.mass.music.async_get_album_tracks( media_item.item_id, provider_id=media_item.provider ) elif media_item.media_type == MediaType.Playlist: - tracks = self.mass.music.async_get_playlist_tracks( + tracks = await self.mass.music.async_get_playlist_tracks( media_item.item_id, provider_id=media_item.provider ) else: - tracks = async_iter_items(media_item) # single track - async for track in tracks: + tracks = [media_item] # single track + for track in tracks: + if not track.available: + continue queue_item = QueueItem(track) # generate uri for this queue item queue_item.uri = "%s/stream/queue/%s/%s" % ( diff --git a/music_assistant/models/media_types.py b/music_assistant/models/media_types.py index 6738f4a0..1f60367a 100755 --- a/music_assistant/models/media_types.py +++ b/music_assistant/models/media_types.py @@ -1,46 +1,47 @@ """Models and helpers for media items.""" from dataclasses import dataclass, field -from enum import Enum -from typing import Any, List +from enum import Enum, IntEnum +from typing import Any, List, Mapping +import ujson from mashumaro import DataClassDictMixin -from music_assistant.helpers.util import CustomIntEnum +from music_assistant.helpers.util import get_sort_name -class MediaType(CustomIntEnum): +class MediaType(Enum): """Enum for MediaType.""" - Artist = 1 - Album = 2 - Track = 3 - Playlist = 4 - Radio = 5 + Artist = "artist" + Album = "album" + Track = "track" + Playlist = "playlist" + Radio = "radio" -class ContributorRole(CustomIntEnum): +class ContributorRole(Enum): """Enum for Contributor Role.""" - Artist = 1 - Writer = 2 - Producer = 3 + Artist = "artist" + Writer = "writer" + Producer = "producer" -class AlbumType(CustomIntEnum): +class AlbumType(Enum): """Enum for Album type.""" - Album = 1 - Single = 2 - Compilation = 3 + Album = "album" + Single = "single" + Compilation = "compilation" -class TrackQuality(CustomIntEnum): +class TrackQuality(IntEnum): """Enum for Track Quality.""" LOSSY_MP3 = 0 LOSSY_OGG = 1 LOSSY_AAC = 2 - FLAC_LOSSLESS = 6 # 44.1/48khz 16 bits HI-RES + FLAC_LOSSLESS = 6 # 44.1/48khz 16 bits FLAC_LOSSLESS_HI_RES_1 = 7 # 44.1/48khz 24 bits HI-RES FLAC_LOSSLESS_HI_RES_2 = 8 # 88.2/96khz 24 bits HI-RES FLAC_LOSSLESS_HI_RES_3 = 9 # 176/192khz 24 bits HI-RES @@ -56,14 +57,7 @@ class MediaItemProviderId(DataClassDictMixin): item_id: str quality: TrackQuality = TrackQuality.UNKNOWN details: str = None - - -class ExternalId(Enum): - """Enum with external id's.""" - - MUSICBRAINZ = "musicbrainz" - UPC = "upc" - ISRC = "isrc" + available: bool = True @dataclass @@ -74,12 +68,33 @@ class MediaItem(DataClassDictMixin): provider: str = "" name: str = "" metadata: Any = field(default_factory=dict) - tags: List[str] = field(default_factory=list) - external_ids: Any = field(default_factory=dict) provider_ids: List[MediaItemProviderId] = field(default_factory=list) - in_library: List[str] = field(default_factory=list) + in_library: bool = False is_lazy: bool = False - available: bool = True + + @classmethod + def from_db_row(cls, db_row: Mapping): + """Create MediaItem object from database row.""" + db_row = dict(db_row) + for key in ["artists", "artist", "album", "metadata", "provider_ids"]: + if key in db_row: + db_row[key] = ujson.loads(db_row[key]) + db_row["provider"] = "database" + if "in_library" in db_row: + db_row["in_library"] = bool(db_row["in_library"]) + return cls.from_dict(db_row) + + @property + def sort_name(self): + """Return sort name.""" + return get_sort_name(self.name) + + @property + def available(self): + """Return (calculated) availability.""" + for item in self.provider_ids: + if item.available: + return True @dataclass @@ -87,7 +102,17 @@ class Artist(MediaItem): """Model for an artist.""" media_type: MediaType = MediaType.Artist - sort_name: str = "" + musicbrainz_id: str = "" + + +@dataclass +class AlbumArtist(DataClassDictMixin): + """Representation of a minimized artist object.""" + + item_id: str = "" + provider: str = "" + name: str = "" + media_type: MediaType = MediaType.Artist @dataclass @@ -97,9 +122,29 @@ class Album(MediaItem): media_type: MediaType = MediaType.Album version: str = "" year: int = 0 - artist: Artist = None - labels: List[str] = field(default_factory=list) + artist: AlbumArtist = None album_type: AlbumType = AlbumType.Album + upc: str = "" + + +@dataclass +class TrackArtist(DataClassDictMixin): + """Representation of a minimized artist object.""" + + item_id: str = "" + provider: str = "" + name: str = "" + media_type: MediaType = MediaType.Artist + + +@dataclass +class TrackAlbum(DataClassDictMixin): + """Representation of a minimized album object.""" + + item_id: str = "" + provider: str = "" + name: str = "" + media_type: MediaType = MediaType.Album @dataclass @@ -109,10 +154,12 @@ class Track(MediaItem): media_type: MediaType = MediaType.Track duration: int = 0 version: str = "" - artists: List[Artist] = field(default_factory=list) - album: Album = None + artists: List[TrackArtist] = field(default_factory=list) + album: TrackAlbum = None disc_number: int = 1 track_number: int = 1 + position: int = 0 + isrc: str = "" @dataclass diff --git a/music_assistant/models/player.py b/music_assistant/models/player.py index 37450e76..3b7cf7b6 100755 --- a/music_assistant/models/player.py +++ b/music_assistant/models/player.py @@ -2,13 +2,13 @@ from abc import abstractmethod from dataclasses import dataclass -from enum import Enum +from enum import Enum, IntEnum from typing import Any, List, Optional from mashumaro import DataClassDictMixin from music_assistant.constants import EVENT_SET_PLAYER_CONTROL_STATE from music_assistant.helpers.typing import MusicAssistantType, QueueItems -from music_assistant.helpers.util import CustomIntEnum, callback +from music_assistant.helpers.util import callback from music_assistant.models.config_entry import ConfigEntry @@ -30,7 +30,7 @@ class DeviceInfo(DataClassDictMixin): manufacturer: str = "" -class PlayerFeature(CustomIntEnum): +class PlayerFeature(IntEnum): """Enum for player features.""" QUEUE = 0 @@ -270,7 +270,7 @@ class Player: self.mass.add_job(self.mass.players.async_update_player(self)) -class PlayerControlType(CustomIntEnum): +class PlayerControlType(Enum): """Enum with different player control types.""" POWER = 0 diff --git a/music_assistant/providers/file/__init__.py b/music_assistant/providers/file/__init__.py index 5e4c1798..45d96452 100644 --- a/music_assistant/providers/file/__init__.py +++ b/music_assistant/providers/file/__init__.py @@ -350,9 +350,9 @@ class FileProvider(MusicProvider): artists.append(artist) track.artists = artists if "GENRE" in song.tags: - track.tags = song.tags["GENRE"] + track.metadata["genres"] = song.tags["GENRE"] if "ISRC" in song.tags: - track.external_ids["isrc"] = song.tags["ISRC"][0] + track.isrc = song.tags["ISRC"][0] if "DISCNUMBER" in song.tags: track.disc_number = int(song.tags["DISCNUMBER"][0]) if "TRACKNUMBER" in song.tags: diff --git a/music_assistant/providers/qobuz/__init__.py b/music_assistant/providers/qobuz/__init__.py index 49e692e0..3f43b23c 100644 --- a/music_assistant/providers/qobuz/__init__.py +++ b/music_assistant/providers/qobuz/__init__.py @@ -124,130 +124,146 @@ class QobuzProvider(MusicProvider): searchresult = await self.__async_get_data("catalog/search", params) if searchresult: if "artists" in searchresult: - for item in searchresult["artists"]["items"]: - artist = await self.__async_parse_artist(item) - if artist: - result.artists.append(artist) + result.artists = [ + await self.__async_parse_artist(item) + for item in searchresult["artists"]["items"] + if (item and item["id"]) + ] if "albums" in searchresult: - for item in searchresult["albums"]["items"]: - album = await self.__async_parse_album(item) - if album: - result.albums.append(album) + result.albums = [ + await self.__async_parse_album(item) + for item in searchresult["albums"]["items"] + if (item and item["id"]) + ] if "tracks" in searchresult: - for item in searchresult["tracks"]["items"]: - track = await self.__async_parse_track(item) - if track: - result.tracks.append(track) + result.tracks = [ + await self.__async_parse_track(item) + for item in searchresult["tracks"]["items"] + if (item and item["id"]) + ] if "playlists" in searchresult: - for item in searchresult["playlists"]["items"]: - playlist = await self.__async_parse_playlist(item) - if playlist: - result.playlists.append(playlist) + result.playlists = [ + await self.__async_parse_playlist(item) + for item in searchresult["playlists"]["items"] + if (item and item["id"]) + ] return result async def async_get_library_artists(self) -> List[Artist]: """Retrieve all library artists from Qobuz.""" params = {"type": "artists"} endpoint = "favorite/getUserFavorites" - async for item in self.__async_get_all_items(endpoint, params, key="artists"): - artist = await self.__async_parse_artist(item) - if artist: - yield artist + return [ + await self.__async_parse_artist(item) + for item in await self.__async_get_all_items( + endpoint, params, key="artists" + ) + if (item and item["id"]) + ] async def async_get_library_albums(self) -> List[Album]: """Retrieve all library albums from Qobuz.""" params = {"type": "albums"} endpoint = "favorite/getUserFavorites" - async for item in self.__async_get_all_items(endpoint, params, key="albums"): - album = await self.__async_parse_album(item) - if album: - yield album + return [ + await self.__async_parse_album(item) + for item in await self.__async_get_all_items(endpoint, params, key="albums") + if (item and item["id"]) + ] async def async_get_library_tracks(self) -> List[Track]: """Retrieve library tracks from Qobuz.""" params = {"type": "tracks"} endpoint = "favorite/getUserFavorites" - async for item in self.__async_get_all_items(endpoint, params, key="tracks"): - track = await self.__async_parse_track(item) - if track: - yield track + return [ + await self.__async_parse_track(item) + for item in await self.__async_get_all_items(endpoint, params, key="tracks") + if (item and item["id"]) + ] async def async_get_library_playlists(self) -> List[Playlist]: """Retrieve all library playlists from the provider.""" endpoint = "playlist/getUserPlaylists" - async for item in self.__async_get_all_items(endpoint, key="playlists"): - playlist = await self.__async_parse_playlist(item) - if playlist: - yield playlist + return [ + await self.__async_parse_playlist(item) + for item in await self.__async_get_all_items(endpoint, key="playlists") + if (item and item["id"]) + ] async def async_get_radios(self) -> List[Radio]: """Retrieve library/subscribed radio stations from the provider.""" - yield None # TODO + return [] # TODO async def async_get_artist(self, prov_artist_id) -> Artist: """Get full artist details by id.""" params = {"artist_id": prov_artist_id} artist_obj = await self.__async_get_data("artist/get", params) - return await self.__async_parse_artist(artist_obj) + return ( + await self.__async_parse_artist(artist_obj) + if artist_obj and artist_obj["id"] + else None + ) async def async_get_album(self, prov_album_id) -> Album: """Get full album details by id.""" params = {"album_id": prov_album_id} album_obj = await self.__async_get_data("album/get", params) - return await self.__async_parse_album(album_obj) + return ( + await self.__async_parse_album(album_obj) + if album_obj and album_obj["id"] + else None + ) async def async_get_track(self, prov_track_id) -> Track: """Get full track details by id.""" params = {"track_id": prov_track_id} track_obj = await self.__async_get_data("track/get", params) - return await self.__async_parse_track(track_obj) + return ( + await self.__async_parse_track(track_obj) + if track_obj and track_obj["id"] + else None + ) async def async_get_playlist(self, prov_playlist_id) -> Playlist: """Get full playlist details by id.""" params = {"playlist_id": prov_playlist_id} playlist_obj = await self.__async_get_data("playlist/get", params) - return await self.__async_parse_playlist(playlist_obj) + return ( + await self.__async_parse_playlist(playlist_obj) + if playlist_obj and playlist_obj["id"] + else None + ) async def async_get_album_tracks(self, prov_album_id) -> List[Track]: """Get all album tracks for given album id.""" params = {"album_id": prov_album_id} - async for item in self.__async_get_all_items("album/get", params, key="tracks"): - track = await self.__async_parse_track(item) - if track: - yield track - else: - LOGGER.warning( - "Unavailable track found in album %s: %s", - prov_album_id, - item["title"], - ) + return [ + await self.__async_parse_track(item) + for item in await self.__async_get_all_items( + "album/get", params, key="tracks" + ) + if (item and item["id"]) + ] async def async_get_playlist_tracks(self, prov_playlist_id) -> List[Track]: """Get all playlist tracks for given playlist id.""" params = {"playlist_id": prov_playlist_id, "extra": "tracks"} endpoint = "playlist/get" - async for item in self.__async_get_all_items(endpoint, params, key="tracks"): - playlist_track = await self.__async_parse_track(item) - if playlist_track: - yield playlist_track - else: - LOGGER.warning( - "Unavailable track found in playlist %s: %s", - prov_playlist_id, - item["title"], - ) - # TODO: should we look for an alternative - # track version if the original is marked unavailable ? + return [ + await self.__async_parse_track(item) + for item in await self.__async_get_all_items(endpoint, params, key="tracks") + if (item and item["id"]) + ] async def async_get_artist_albums(self, prov_artist_id) -> List[Album]: """Get a list of albums for the given artist.""" params = {"artist_id": prov_artist_id, "extra": "albums"} endpoint = "artist/get" - async for item in self.__async_get_all_items(endpoint, params, key="albums"): - if str(item["artist"]["id"]) == str(prov_artist_id): - album = await self.__async_parse_album(item) - if album: - yield album + return [ + await self.__async_parse_album(item) + for item in await self.__async_get_all_items(endpoint, params, key="albums") + if (item and item["id"]) + ] async def async_get_artist_toptracks(self, prov_artist_id) -> List[Track]: """Get a list of most popular tracks for the given artist.""" @@ -256,13 +272,16 @@ class QobuzProvider(MusicProvider): artist = await self.async_get_artist(prov_artist_id) params = {"query": artist.name, "limit": 25, "type": "tracks"} searchresult = await self.__async_get_data("catalog/search", params) - for item in searchresult["tracks"]["items"]: - if "performer" in item and str(item["performer"]["id"]) == str( - prov_artist_id - ): - track = await self.__async_parse_track(item) - if track: - yield track + return [ + await self.__async_parse_track(item) + for item in searchresult["tracks"]["items"] + if ( + item + and item["id"] + and "performer" in item + and str(item["performer"]["id"]) == str(prov_artist_id) + ) + ] async def async_library_add(self, prov_item_id, media_type: MediaType): """Add item to library.""" @@ -415,15 +434,12 @@ class QobuzProvider(MusicProvider): async def __async_parse_artist(self, artist_obj): """Parse qobuz artist object to generic layout.""" - artist = Artist() - if not artist_obj or not artist_obj.get("id"): - return None - artist.item_id = str(artist_obj["id"]) - artist.provider = PROV_ID + artist = Artist( + item_id=str(artist_obj["id"]), provider=PROV_ID, name=artist_obj["name"] + ) artist.provider_ids.append( MediaItemProviderId(provider=PROV_ID, item_id=str(artist_obj["id"])) ) - artist.name = artist_obj["name"] if artist_obj.get("image"): for key in ["extralarge", "large", "medium", "small"]: if artist_obj["image"].get(key): @@ -439,19 +455,9 @@ class QobuzProvider(MusicProvider): artist.metadata["qobuz_url"] = artist_obj["url"] return artist - async def __async_parse_album(self, album_obj): + async def __async_parse_album(self, album_obj: dict, artist_obj: dict = None): """Parse qobuz album object to generic layout.""" - album = Album() - if ( - not album_obj - or not album_obj.get("id") - or not album_obj["streamable"] - or not album_obj["displayable"] - ): - # do not return unavailable items - return None - album.item_id = str(album_obj["id"]) - album.provider = PROV_ID + album = Album(item_id=str(album_obj["id"]), provider=PROV_ID) if album_obj["maximum_sampling_rate"] > 192: quality = TrackQuality.FLAC_LOSSLESS_HI_RES_4 elif album_obj["maximum_sampling_rate"] > 96: @@ -469,25 +475,28 @@ class QobuzProvider(MusicProvider): provider=PROV_ID, item_id=str(album_obj["id"]), quality=quality, - details=f'{album_obj["maximum_sampling_rate"]}kHz \ - {album_obj["maximum_bit_depth"]}bit', + details=f'{album_obj["maximum_sampling_rate"]}kHz {album_obj["maximum_bit_depth"]}bit', + available=album_obj["streamable"] and album_obj["displayable"], ) ) album.name, album.version = parse_title_and_version( album_obj["title"], album_obj.get("version") ) - album.artist = await self.__async_parse_artist(album_obj["artist"]) + if artist_obj: + album.artist = artist_obj + else: + album.artist = await self.__async_parse_artist(album_obj["artist"]) if album_obj.get("product_type", "") == "single": - album.albumtype = AlbumType.Single + album.album_type = AlbumType.Single elif ( album_obj.get("product_type", "") == "compilation" - or "Various" in album_obj["artist"]["name"] + or "Various" in album.artist.name ): - album.albumtype = AlbumType.Compilation + album.album_type = AlbumType.Compilation else: - album.albumtype = AlbumType.Album + album.album_type = AlbumType.Album if "genre" in album_obj: - album.tags = [album_obj["genre"]["name"]] + album.metadata["genre"] = album_obj["genre"]["name"] if album_obj.get("image"): for key in ["extralarge", "large", "medium", "small"]: if album_obj["image"].get(key): @@ -495,12 +504,12 @@ class QobuzProvider(MusicProvider): break if len(album_obj["upc"]) == 13: # qobuz writes ean as upc ?! - album.external_ids["ean"] = album_obj["upc"] - album.external_ids["upc"] = album_obj["upc"][1:] + album.metadata["ean"] = album_obj["upc"] + album.upc = album_obj["upc"][1:] else: - album.external_ids["upc"] = album_obj["upc"] + album.upc = album_obj["upc"] if "label" in album_obj: - album.labels = album_obj["label"]["name"].split("/") + album.metadata["label"] = album_obj["label"]["name"] if album_obj.get("released_at"): album.year = datetime.datetime.fromtimestamp(album_obj["released_at"]).year if album_obj.get("copyright"): @@ -515,17 +524,13 @@ class QobuzProvider(MusicProvider): async def __async_parse_track(self, track_obj): """Parse qobuz track object to generic layout.""" - track = Track() - if ( - not track_obj - or not track_obj.get("id") - or not track_obj["streamable"] - or not track_obj["displayable"] - ): - # do not return unavailable items - return None - track.item_id = str(track_obj["id"]) - track.provider = PROV_ID + track = Track( + item_id=str(track_obj["id"]), + provider=PROV_ID, + disc_number=track_obj["media_number"], + track_number=track_obj["track_number"], + duration=track_obj["duration"], + ) if track_obj.get("performer") and "Various " not in track_obj["performer"]: artist = await self.__async_parse_artist(track_obj["performer"]) if artist: @@ -554,19 +559,16 @@ class QobuzProvider(MusicProvider): track.name, track.version = parse_title_and_version( track_obj["title"], track_obj.get("version") ) - track.duration = track_obj["duration"] if "album" in track_obj: album = await self.__async_parse_album(track_obj["album"]) if album: track.album = album - track.disc_number = track_obj["media_number"] - track.track_number = track_obj["track_number"] if track_obj.get("hires"): track.metadata["hires"] = "true" if track_obj.get("url"): track.metadata["qobuz_url"] = track_obj["url"] if track_obj.get("isrc"): - track.external_ids["isrc"] = track_obj["isrc"] + track.isrc = track_obj["isrc"] if track_obj.get("performers"): track.metadata["performers"] = track_obj["performers"] if track_obj.get("copyright"): @@ -589,24 +591,23 @@ class QobuzProvider(MusicProvider): provider=PROV_ID, item_id=str(track_obj["id"]), quality=quality, - details=f'{track_obj["maximum_sampling_rate"]}kHz \ - {track_obj["maximum_bit_depth"]}bit', + details=f'{track_obj["maximum_sampling_rate"]}kHz {track_obj["maximum_bit_depth"]}bit', + available=track_obj["streamable"] and track_obj["displayable"], ) ) return track async def __async_parse_playlist(self, playlist_obj): """Parse qobuz playlist object to generic layout.""" - playlist = Playlist() - if not playlist_obj or not playlist_obj.get("id"): - return None - playlist.item_id = playlist_obj["id"] - playlist.provider = PROV_ID + playlist = Playlist( + item_id=playlist_obj["id"], + provider=PROV_ID, + name=playlist_obj["name"], + owner=playlist_obj["owner"]["name"], + ) playlist.provider_ids.append( MediaItemProviderId(provider=PROV_ID, item_id=str(playlist_obj["id"])) ) - playlist.name = playlist_obj["name"] - playlist.owner = playlist_obj["owner"]["name"] playlist.is_editable = ( playlist_obj["owner"]["id"] == self.__user_auth_info["user"]["id"] or playlist_obj["is_collaborative"] @@ -641,17 +642,20 @@ class QobuzProvider(MusicProvider): params = {} limit = 50 offset = 0 + all_items = [] while True: params["limit"] = limit params["offset"] = offset result = await self.__async_get_data(endpoint, params=params) offset += limit - if not result or key not in result or "items" not in result[key]: + if not result: + break + if not result.get(key) or not result[key].get("items"): break - for item in result[key]["items"]: - yield item + all_items += result[key]["items"] if len(result[key]["items"]) < limit: break + return all_items async def __async_get_data(self, endpoint, params=None, sign_request=False): """Get data from api.""" diff --git a/music_assistant/providers/spotify/__init__.py b/music_assistant/providers/spotify/__init__.py index e089db61..4da742ac 100644 --- a/music_assistant/providers/spotify/__init__.py +++ b/music_assistant/providers/spotify/__init__.py @@ -159,36 +159,39 @@ class SpotifyProvider(MusicProvider): spotify_artists = await self.__async_get_data( "me/following?type=artist&limit=50" ) - if spotify_artists: - # TODO: use cursor method to retrieve more than 50 artists - for artist_obj in spotify_artists["artists"]["items"]: - prov_artist = await self.__async_parse_artist(artist_obj) - yield prov_artist + return [ + await self.__async_parse_artist(item) + for item in spotify_artists["artists"]["items"] + if (item and item["id"]) + ] async def async_get_library_albums(self) -> List[Album]: """Retrieve library albums from the provider.""" - async for item in self.__async_get_all_items("me/albums"): - album = await self.__async_parse_album(item) - if album: - yield album + return [ + await self.__async_parse_album(item["album"]) + for item in await self.__async_get_all_items("me/albums") + if (item["album"] and item["album"]["id"]) + ] async def async_get_library_tracks(self) -> List[Track]: """Retrieve library tracks from the provider.""" - async for item in self.__async_get_all_items("me/tracks"): - track = await self.__async_parse_track(item) - if track: - yield track + return [ + await self.__async_parse_track(item["track"]) + for item in await self.__async_get_all_items("me/tracks") + if (item and item["track"]["id"]) + ] async def async_get_library_playlists(self) -> List[Playlist]: """Retrieve playlists from the provider.""" - async for item in self.__async_get_all_items("me/playlists"): - playlist = await self.__async_parse_playlist(item) - if playlist: - yield playlist + return [ + await self.__async_parse_playlist(item) + for item in await self.__async_get_all_items("me/playlists") + if (item and item["id"]) + ] async def async_get_radios(self) -> List[Radio]: """Retrieve library/subscribed radio stations from the provider.""" - yield None # TODO: Return spotify radio + return [] # TODO: Return spotify radio async def async_get_artist(self, prov_artist_id) -> Artist: """Get full artist details by id.""" @@ -212,45 +215,44 @@ class SpotifyProvider(MusicProvider): async def async_get_album_tracks(self, prov_album_id) -> List[Track]: """Get all album tracks for given album id.""" - endpoint = f"albums/{prov_album_id}/tracks" - async for track_obj in self.__async_get_all_items(endpoint): - track = await self.__async_parse_track(track_obj) - if track: - yield track + return [ + await self.__async_parse_track(item) + for item in await self.__async_get_all_items( + f"albums/{prov_album_id}/tracks" + ) + if (item and item["id"]) + ] async def async_get_playlist_tracks(self, prov_playlist_id) -> List[Track]: """Get all playlist tracks for given playlist id.""" - endpoint = f"playlists/{prov_playlist_id}/tracks" - async for track_obj in self.__async_get_all_items(endpoint): - playlist_track = await self.__async_parse_track(track_obj) - if playlist_track: - yield playlist_track - else: - LOGGER.warning( - "Unavailable track found in playlist %s: %s", - prov_playlist_id, - track_obj["track"]["name"], - ) + return [ + await self.__async_parse_track(item["track"]) + for item in await self.__async_get_all_items( + f"playlists/{prov_playlist_id}/tracks" + ) + if (item and item["track"]["id"]) + ] async def async_get_artist_albums(self, prov_artist_id) -> List[Album]: """Get a list of all albums for the given artist.""" - params = {"include_groups": "album,single,compilation"} - endpoint = f"artists/{prov_artist_id}/albums" - async for item in self.__async_get_all_items(endpoint, params): - album = await self.__async_parse_album(item) - if album: - yield album + return [ + await self.__async_parse_album(item) + for item in await self.__async_get_all_items( + f"artists/{prov_artist_id}/albums" + ) + if (item and item["id"]) + ] async def async_get_artist_toptracks(self, prov_artist_id) -> List[Track]: """Get a list of 10 most popular tracks for the given artist.""" artist = await self.async_get_artist(prov_artist_id) endpoint = f"artists/{prov_artist_id}/top-tracks" items = await self.__async_get_data(endpoint) - for item in items["tracks"]: - track = await self.__async_parse_track(item) - if track: - track.artists = [artist] - yield track + return [ + await self.__async_parse_track(item, artist=artist) + for item in items["tracks"] + if (item and item["id"]) + ] async def async_library_add(self, prov_item_id, media_type: MediaType): """Add item to library.""" @@ -335,17 +337,14 @@ class SpotifyProvider(MusicProvider): async def __async_parse_artist(self, artist_obj): """Parse spotify artist object to generic layout.""" - if not artist_obj: - return None - artist = Artist() - artist.item_id = artist_obj["id"] - artist.provider = self.id + artist = Artist( + item_id=artist_obj["id"], provider=self.id, name=artist_obj["name"] + ) artist.provider_ids.append( MediaItemProviderId(provider=PROV_ID, item_id=artist_obj["id"]) ) - artist.name = artist_obj["name"] if "genres" in artist_obj: - artist.tags = artist_obj["genres"] + artist.metadata["genres"] = artist_obj["genres"] if artist_obj.get("images"): for img in artist_obj["images"]: img_url = img["url"] @@ -358,34 +357,26 @@ class SpotifyProvider(MusicProvider): async def __async_parse_album(self, album_obj): """Parse spotify album object to generic layout.""" - if not album_obj: - return None - if "album" in album_obj: - album_obj = album_obj["album"] - if not album_obj["id"] or not album_obj.get("is_playable", True): - return None - album = Album() - album.item_id = album_obj["id"] - album.provider = self.id + album = Album(item_id=album_obj["id"], provider=self.id) album.name, album.version = parse_title_and_version(album_obj["name"]) for artist in album_obj["artists"]: album.artist = await self.__async_parse_artist(artist) if album.artist: break if album_obj["album_type"] == "single": - album.albumtype = AlbumType.Single + album.album_type = AlbumType.Single elif album_obj["album_type"] == "compilation": - album.albumtype = AlbumType.Compilation + album.album_type = AlbumType.Compilation else: - album.albumtype = AlbumType.Album + album.album_type = AlbumType.Album if "genres" in album_obj: - album.tags = album_obj["genres"] + album.metadata["genres"] = album_obj["genres"] if album_obj.get("images"): album.metadata["image"] = album_obj["images"][0]["url"] - if "external_ids" in album_obj: - album.external_ids = album_obj["external_ids"] + if "external_ids" in album_obj and album_obj["external_ids"].get("upc"): + album.upc = album_obj["external_ids"]["upc"] if "label" in album_obj: - album.labels = album_obj["label"].split("/") + album.metadata["label"] = album_obj["label"] if album_obj.get("release_date"): album.year = int(album_obj["release_date"].split("-")[0]) if album_obj.get("copyrights"): @@ -403,35 +394,31 @@ class SpotifyProvider(MusicProvider): ) return album - async def __async_parse_track(self, track_obj): + async def __async_parse_track(self, track_obj, artist=None): """Parse spotify track object to generic layout.""" - if not track_obj: - return None - if "track" in track_obj: - track_obj = track_obj["track"] - if track_obj["is_local"] or not track_obj["id"] or not track_obj["is_playable"]: - # do not return unavailable items - return None - track = Track() - track.item_id = track_obj["id"] - track.provider = self.id - for track_artist in track_obj["artists"]: + track = Track( + item_id=track_obj["id"], + provider=self.id, + duration=track_obj["duration_ms"] / 1000, + disc_number=track_obj["disc_number"], + track_number=track_obj["track_number"], + ) + if artist: + track.artists.append(artist) + for track_artist in track_obj.get("artists", []): artist = await self.__async_parse_artist(track_artist) if artist: track.artists.append(artist) track.name, track.version = parse_title_and_version(track_obj["name"]) - track.duration = track_obj["duration_ms"] / 1000 track.metadata["explicit"] = str(track_obj["explicit"]).lower() - if "external_ids" in track_obj: - track.external_ids = track_obj["external_ids"] + if "external_ids" in track_obj and "isrc" in track_obj["external_ids"]: + track.isrc = track_obj["external_ids"]["isrc"] if "album" in track_obj: track.album = await self.__async_parse_album(track_obj["album"]) if track_obj.get("copyright"): track.metadata["copyright"] = track_obj["copyright"] if track_obj.get("explicit"): track.metadata["explicit"] = True - track.disc_number = track_obj["disc_number"] - track.track_number = track_obj["track_number"] if track_obj.get("external_urls"): track.metadata["spotify_url"] = track_obj["external_urls"]["spotify"] track.provider_ids.append( @@ -439,18 +426,14 @@ class SpotifyProvider(MusicProvider): provider=PROV_ID, item_id=track_obj["id"], quality=TrackQuality.LOSSY_OGG, + available=not track_obj["is_local"] and track_obj["is_playable"], ) ) return track async def __async_parse_playlist(self, playlist_obj): """Parse spotify playlist object to generic layout.""" - - if not playlist_obj.get("id"): - return None - playlist = Playlist() - playlist.item_id = playlist_obj["id"] - playlist.provider = self.id + playlist = Playlist(item_id=playlist_obj["id"], provider=self.id) playlist.provider_ids.append( MediaItemProviderId(provider=PROV_ID, item_id=playlist_obj["id"]) ) @@ -548,6 +531,7 @@ class SpotifyProvider(MusicProvider): params = {} limit = 50 offset = 0 + all_items = [] while True: params["limit"] = limit params["offset"] = offset @@ -555,10 +539,10 @@ class SpotifyProvider(MusicProvider): offset += limit if not result or key not in result or not result[key]: break - for item in result[key]: - yield item + all_items += result[key] if len(result[key]) < limit: break + return all_items async def __async_get_data(self, endpoint, params=None): """Get data from api.""" diff --git a/music_assistant/providers/tunein/__init__.py b/music_assistant/providers/tunein/__init__.py index 3d16f2e3..dc170924 100644 --- a/music_assistant/providers/tunein/__init__.py +++ b/music_assistant/providers/tunein/__init__.py @@ -99,11 +99,12 @@ class TuneInProvider(MusicProvider): params = {"c": "presets"} result = await self.__async_get_data("Browse.ashx", params) if result and "body" in result: - for item in result["body"]: - # TODO: expand folders - if item["type"] == "audio": - radio = await self.__async_parse_radio(item) - yield radio + return [ + await self.__async_parse_radio(item) + for item in result["body"] + if item["type"] == "audio" + ] + return [] async def async_get_radio(self, prov_radio_id: str) -> Radio: """Get radio station details.""" diff --git a/music_assistant/web/__init__.py b/music_assistant/web/__init__.py index fb88beac..40215c98 100755 --- a/music_assistant/web/__init__.py +++ b/music_assistant/web/__init__.py @@ -8,7 +8,8 @@ from aiohttp import web from aiohttp_jwt import JWTMiddleware from music_assistant.constants import __version__ as MASS_VERSION from music_assistant.helpers.typing import MusicAssistantType -from music_assistant.helpers.util import get_hostname, get_ip, json_serializer +from music_assistant.helpers.util import get_hostname, get_ip +from music_assistant.helpers.web import json_serializer from .endpoints import ( albums, diff --git a/music_assistant/web/endpoints/albums.py b/music_assistant/web/endpoints/albums.py index 53d5665d..f0c68bd4 100644 --- a/music_assistant/web/endpoints/albums.py +++ b/music_assistant/web/endpoints/albums.py @@ -2,8 +2,7 @@ from aiohttp.web import Request, Response, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_stream_json +from music_assistant.helpers.web import async_json_response routes = RouteTableDef() @@ -12,8 +11,9 @@ routes = RouteTableDef() @login_required async def async_albums(request: Request): """Get all albums known in the database.""" - generator = request.app["mass"].database.async_get_albums() - return await async_stream_json(request, generator) + return await async_json_response( + await request.app["mass"].database.async_get_albums() + ) @routes.get("/api/albums/{item_id}") @@ -25,10 +25,9 @@ async def async_album(request: Request): lazy = request.rel_url.query.get("lazy", "true") != "false" if item_id is None or provider is None: return Response(text="invalid item or provider", status=501) - result = await request.app["mass"].music.async_get_album( - item_id, provider, lazy=lazy + return await async_json_response( + await request.app["mass"].music.async_get_album(item_id, provider, lazy=lazy) ) - return Response(body=json_serializer(result), content_type="application/json") @routes.get("/api/albums/{item_id}/tracks") @@ -39,8 +38,9 @@ async def async_album_tracks(request: Request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - generator = request.app["mass"].music.async_get_album_tracks(item_id, provider) - return await async_stream_json(request, generator) + return await async_json_response( + await request.app["mass"].music.async_get_album_tracks(item_id, provider) + ) @routes.get("/api/albums/{item_id}/versions") @@ -51,5 +51,6 @@ async def async_album_versions(request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - generator = request.app["mass"].music.async_get_album_versions(item_id, provider) - return await async_stream_json(request, generator) + return await async_json_response( + await request.app["mass"].music.async_get_album_versions(item_id, provider) + ) diff --git a/music_assistant/web/endpoints/artists.py b/music_assistant/web/endpoints/artists.py index 3a8d1be1..1847dff2 100644 --- a/music_assistant/web/endpoints/artists.py +++ b/music_assistant/web/endpoints/artists.py @@ -2,8 +2,7 @@ from aiohttp.web import Request, Response, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_stream_json +from music_assistant.helpers.web import async_json_response routes = RouteTableDef() @@ -12,8 +11,8 @@ routes = RouteTableDef() @login_required async def async_artists(request: Request): """Get all artists known in the database.""" - generator = request.app["mass"].database.async_get_artists() - return await async_stream_json(request, generator) + result = await request.app["mass"].database.async_get_artists() + return await async_json_response(result) @routes.get("/api/artists/{item_id}") @@ -28,7 +27,7 @@ async def async_artist(request: Request): result = await request.app["mass"].music.async_get_artist( item_id, provider, lazy=lazy ) - return Response(body=json_serializer(result), content_type="application/json") + return await async_json_response(result) @routes.get("/api/artists/{item_id}/toptracks") @@ -39,8 +38,10 @@ async def async_artist_toptracks(request: Request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - generator = request.app["mass"].music.async_get_artist_toptracks(item_id, provider) - return await async_stream_json(request, generator) + result = await request.app["mass"].music.async_get_artist_toptracks( + item_id, provider + ) + return await async_json_response(result) @routes.get("/api/artists/{item_id}/albums") @@ -51,5 +52,5 @@ async def async_artist_albums(request: Request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - generator = request.app["mass"].music.async_get_artist_albums(item_id, provider) - return await async_stream_json(request, generator) + result = await request.app["mass"].music.async_get_artist_albums(item_id, provider) + return await async_json_response(result) diff --git a/music_assistant/web/endpoints/config.py b/music_assistant/web/endpoints/config.py index 04df56d3..949dc0e7 100644 --- a/music_assistant/web/endpoints/config.py +++ b/music_assistant/web/endpoints/config.py @@ -2,7 +2,7 @@ from json.decoder import JSONDecodeError -from aiohttp.web import Request, Response, RouteTableDef, json_response +from aiohttp.web import Request, RouteTableDef from aiohttp_jwt import login_required from music_assistant.constants import ( CONF_KEY_BASE, @@ -12,7 +12,7 @@ from music_assistant.constants import ( CONF_KEY_PLAYER_SETTINGS, CONF_KEY_PLUGINS, ) -from music_assistant.helpers.util import json_serializer +from music_assistant.helpers.web import async_json_response routes = RouteTableDef() @@ -32,7 +32,7 @@ async def async_get_config(request: Request): CONF_KEY_PLAYER_SETTINGS, ] } - return Response(body=json_serializer(conf), content_type="application/json") + return await async_json_response(conf) @routes.get("/api/config/{base}") @@ -42,7 +42,7 @@ async def async_get_config_base_item(request: Request): language = request.rel_url.query.get("lang", "en") conf_base = request.match_info.get("base") conf = request.app["mass"].config[conf_base].all_items(language) - return Response(body=json_serializer(conf), content_type="application/json") + return await async_json_response(conf) @routes.get("/api/config/{base}/{item}") @@ -53,7 +53,7 @@ async def async_get_config_item(request: Request): conf_base = request.match_info.get("base") conf_item = request.match_info.get("item") conf = request.app["mass"].config[conf_base][conf_item].all_items(language) - return Response(body=json_serializer(conf), content_type="application/json") + return await async_json_response(conf) @routes.put("/api/config/{base}/{key}/{entry_key}") @@ -73,4 +73,6 @@ async def async_put_config(request: Request): .default_value ) request.app["mass"].config[conf_base][conf_key][entry_key] = new_value - return json_response(True) + return await async_json_response( + request.app["mass"].config[conf_base][conf_key][entry_key] + ) diff --git a/music_assistant/web/endpoints/images.py b/music_assistant/web/endpoints/images.py index 03ce3458..a388bae4 100644 --- a/music_assistant/web/endpoints/images.py +++ b/music_assistant/web/endpoints/images.py @@ -25,7 +25,7 @@ async def async_get_provider_icon(request: Request): async def async_get_image(request: Request): """Get (resized) thumb image.""" media_type_str = request.match_info.get("media_type") - media_type = MediaType.from_string(media_type_str) + media_type = MediaType(media_type_str) media_id = request.match_info.get("media_id") provider = request.rel_url.query.get("provider") if media_id is None or provider is None: diff --git a/music_assistant/web/endpoints/library.py b/music_assistant/web/endpoints/library.py index d6e87c1d..c0e217e1 100644 --- a/music_assistant/web/endpoints/library.py +++ b/music_assistant/web/endpoints/library.py @@ -1,9 +1,8 @@ """Library API endpoints.""" -from aiohttp.web import Request, Response, RouteTableDef +from aiohttp.web import Request, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_media_items_from_body, async_stream_json +from music_assistant.helpers.web import async_json_response, async_media_items_from_body routes = RouteTableDef() @@ -13,11 +12,10 @@ routes = RouteTableDef() async def async_library_artists(request: Request): """Get all library artists.""" orderby = request.query.get("orderby", "name") - provider_filter = request.rel_url.query.get("provider") - generator = request.app["mass"].music.async_get_library_artists( - orderby=orderby, provider_filter=provider_filter + + return await async_json_response( + await request.app["mass"].music.async_get_library_artists(orderby=orderby) ) - return await async_stream_json(request, generator) @routes.get("/api/library/albums") @@ -25,11 +23,10 @@ async def async_library_artists(request: Request): async def async_library_albums(request: Request): """Get all library albums.""" orderby = request.query.get("orderby", "name") - provider_filter = request.rel_url.query.get("provider") - generator = request.app["mass"].music.async_get_library_albums( - orderby=orderby, provider_filter=provider_filter + + return await async_json_response( + await request.app["mass"].music.async_get_library_albums(orderby=orderby) ) - return await async_stream_json(request, generator) @routes.get("/api/library/tracks") @@ -37,11 +34,10 @@ async def async_library_albums(request: Request): async def async_library_tracks(request: Request): """Get all library tracks.""" orderby = request.query.get("orderby", "name") - provider_filter = request.rel_url.query.get("provider") - generator = request.app["mass"].music.async_get_library_tracks( - orderby=orderby, provider_filter=provider_filter + + return await async_json_response( + await request.app["mass"].music.async_get_library_tracks(orderby=orderby) ) - return await async_stream_json(request, generator) @routes.get("/api/library/radios") @@ -49,11 +45,10 @@ async def async_library_tracks(request: Request): async def async_library_radios(request: Request): """Get all library radios.""" orderby = request.query.get("orderby", "name") - provider_filter = request.rel_url.query.get("provider") - generator = request.app["mass"].music.async_get_library_radios( - orderby=orderby, provider_filter=provider_filter + + return await async_json_response( + await request.app["mass"].music.async_get_library_radios(orderby=orderby) ) - return await async_stream_json(request, generator) @routes.get("/api/library/playlists") @@ -61,11 +56,10 @@ async def async_library_radios(request: Request): async def async_library_playlists(request: Request): """Get all library playlists.""" orderby = request.query.get("orderby", "name") - provider_filter = request.rel_url.query.get("provider") - generator = request.app["mass"].music.async_get_library_playlists( - orderby=orderby, provider_filter=provider_filter + + return await async_json_response( + await request.app["mass"].music.async_get_library_playlists(orderby=orderby) ) - return await async_stream_json(request, generator) @routes.put("/api/library") @@ -75,7 +69,7 @@ async def async_library_add(request: Request): body = await request.json() media_items = await async_media_items_from_body(request.app["mass"], body) result = await request.app["mass"].music.async_library_add(media_items) - return Response(body=json_serializer(result), content_type="application/json") + return await async_json_response(result) @routes.delete("/api/library") @@ -85,4 +79,4 @@ async def async_library_remove(request: Request): body = await request.json() media_items = await async_media_items_from_body(request.app["mass"], body) result = await request.app["mass"].music.async_library_remove(media_items) - return Response(body=json_serializer(result), content_type="application/json") + return await async_json_response(result) diff --git a/music_assistant/web/endpoints/login.py b/music_assistant/web/endpoints/login.py index 6dbe9ee6..23480509 100644 --- a/music_assistant/web/endpoints/login.py +++ b/music_assistant/web/endpoints/login.py @@ -5,7 +5,7 @@ import datetime import jwt from aiohttp.web import HTTPUnauthorized, Request, Response, RouteTableDef from music_assistant.helpers.typing import MusicAssistantType -from music_assistant.helpers.util import json_serializer +from music_assistant.helpers.web import json_serializer routes = RouteTableDef() diff --git a/music_assistant/web/endpoints/players.py b/music_assistant/web/endpoints/players.py index c40e018b..ebe0ba7e 100644 --- a/music_assistant/web/endpoints/players.py +++ b/music_assistant/web/endpoints/players.py @@ -2,10 +2,9 @@ from json.decoder import JSONDecodeError -from aiohttp.web import Request, Response, RouteTableDef, json_response +from aiohttp.web import Request, Response, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_media_items_from_body, async_stream_json +from music_assistant.helpers.web import async_json_response, async_media_items_from_body from music_assistant.models.player_queue import QueueOption routes = RouteTableDef() @@ -18,8 +17,9 @@ async def async_players(request: Request): """Get all playerstates.""" player_states = request.app["mass"].players.player_states player_states.sort(key=lambda x: str(x.name), reverse=False) - players = [player_state.to_dict() for player_state in player_states] - return Response(body=json_serializer(players), content_type="application/json") + return await async_json_response( + [player_state.to_dict() for player_state in player_states] + ) @routes.post("/api/players/{player_id}/cmd/{cmd}") @@ -43,7 +43,7 @@ async def async_player_command(request: Request): else: return Response(text="invalid command", status=501) result = {"success": success in [True, None]} - return Response(body=json_serializer(result), content_type="application/json") + return await async_json_response(result) @routes.post("/api/players/{player_id}/play_media/{queue_opt}") @@ -61,7 +61,7 @@ async def async_player_play_media(request: Request): player_id, media_items, queue_opt ) result = {"success": success in [True, None]} - return json_response(result) + return await async_json_response(result) @routes.get("/api/players/{player_id}/queue/items/{queue_item}") @@ -78,7 +78,7 @@ async def async_player_queue_item(request: Request): queue_item = player_queue.get_item(item_id) except ValueError: queue_item = player_queue.by_item_id(item_id) - return json_response(queue_item.to_dict()) + return await async_json_response(queue_item) @routes.get("/api/players/{player_id}/queue/items") @@ -89,12 +89,7 @@ async def async_player_queue_items(request: Request): player_queue = request.app["mass"].players.get_player_queue(player_id) if not player_queue: return Response(text="invalid player", status=404) - - async def async_queue_tracks_iter(): - for item in player_queue.items: - yield item - - return await async_stream_json(request, async_queue_tracks_iter()) + return await async_json_response(player_queue.items) @routes.get("/api/players/{player_id}/queue") @@ -105,9 +100,7 @@ async def async_player_queue(request: Request): player_queue = request.app["mass"].players.get_player_queue(player_id) if not player_queue: return Response(text="invalid player", status=404) - return Response( - body=json_serializer(player_queue.to_dict()), content_type="application/json" - ) + return await async_json_response(player_queue) @routes.put("/api/players/{player_id}/queue/{cmd}") @@ -135,9 +128,7 @@ async def async_player_queue_cmd(request: Request): await player_queue.async_move_item(cmd_args, 1) elif cmd == "next": await player_queue.async_move_item(cmd_args, 0) - return Response( - body=json_serializer(player_queue.to_dict()), content_type="application/json" - ) + return await async_json_response(player_queue) @routes.get("/api/players/{player_id}") @@ -148,6 +139,4 @@ async def async_player(request: Request): player_state = request.app["mass"].players.get_player_state(player_id) if not player_state: return Response(text="invalid player", status=404) - return Response( - body=json_serializer(player_state.to_dict()), content_type="application/json" - ) + return await async_json_response(player_state) diff --git a/music_assistant/web/endpoints/playlists.py b/music_assistant/web/endpoints/playlists.py index ac152b85..3873e0ad 100644 --- a/music_assistant/web/endpoints/playlists.py +++ b/music_assistant/web/endpoints/playlists.py @@ -1,10 +1,9 @@ """Playlists API endpoints.""" import ujson -from aiohttp.web import Request, Response, RouteTableDef, json_response +from aiohttp.web import Request, Response, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_media_items_from_body, async_stream_json +from music_assistant.helpers.web import async_json_response, async_media_items_from_body routes = RouteTableDef() @@ -18,7 +17,7 @@ async def async_playlist(request: Request): if item_id is None or provider is None: return Response(text="invalid item or provider", status=501) result = await request.app["mass"].music.async_get_playlist(item_id, provider) - return json_response(result, dumps=json_serializer) + return await async_json_response(result) @routes.get("/api/playlists/{item_id}/tracks") @@ -29,8 +28,10 @@ async def async_playlist_tracks(request: Request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - generator = request.app["mass"].music.async_get_playlist_tracks(item_id, provider) - return await async_stream_json(request, generator) + result = await request.app["mass"].music.async_get_playlist_tracks( + item_id, provider + ) + return await async_json_response(result) @routes.put("/api/playlists/{item_id}/tracks") @@ -41,7 +42,7 @@ async def async_add_playlist_tracks(request: Request): body = await request.json(loads=ujson.loads) tracks = await async_media_items_from_body(request.app["mass"], body) result = await request.app["mass"].music.async_add_playlist_tracks(item_id, tracks) - return json_response(result) + return await async_json_response(result) @routes.delete("/api/playlists/{item_id}/tracks") @@ -54,4 +55,4 @@ async def async_remove_playlist_tracks(request: Request): result = await request.app["mass"].music.async_remove_playlist_tracks( item_id, tracks ) - return json_response(result) + return await async_json_response(result) diff --git a/music_assistant/web/endpoints/radios.py b/music_assistant/web/endpoints/radios.py index 5ba2f987..1db1a163 100644 --- a/music_assistant/web/endpoints/radios.py +++ b/music_assistant/web/endpoints/radios.py @@ -2,8 +2,7 @@ from aiohttp.web import Request, Response, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_stream_json +from music_assistant.helpers.web import async_json_response routes = RouteTableDef() @@ -12,8 +11,9 @@ routes = RouteTableDef() @login_required async def async_radios(request: Request): """Get all radios known in the database.""" - generator = request.app["mass"].database.async_get_radios() - return await async_stream_json(request, generator) + return await async_json_response( + await request.app["mass"].database.async_get_radios() + ) @routes.get("/api/radios/{item_id}") @@ -24,5 +24,6 @@ async def async_radio(request: Request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - result = await request.app["mass"].music.async_get_radio(item_id, provider) - return Response(body=json_serializer(result), content_type="application/json") + return await async_json_response( + await request.app["mass"].music.async_get_radio(item_id, provider) + ) diff --git a/music_assistant/web/endpoints/search.py b/music_assistant/web/endpoints/search.py index 768c8425..2fd3d1dc 100644 --- a/music_assistant/web/endpoints/search.py +++ b/music_assistant/web/endpoints/search.py @@ -1,8 +1,8 @@ """Search API endpoints.""" -from aiohttp.web import Request, Response, RouteTableDef +from aiohttp.web import Request, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer +from music_assistant.helpers.web import async_json_response from music_assistant.models.media_types import MediaType routes = RouteTableDef() @@ -26,8 +26,8 @@ async def async_search(request: Request): media_types.append(MediaType.Playlist) if not media_types_query or "radios" in media_types_query: media_types.append(MediaType.Radio) - - result = await request.app["mass"].music.async_global_search( - searchquery, media_types, limit=limit + return await async_json_response( + await request.app["mass"].music.async_global_search( + searchquery, media_types, limit=limit + ) ) - return Response(body=json_serializer(result), content_type="application/json") diff --git a/music_assistant/web/endpoints/streams.py b/music_assistant/web/endpoints/streams.py index d62307ed..33bb301c 100644 --- a/music_assistant/web/endpoints/streams.py +++ b/music_assistant/web/endpoints/streams.py @@ -10,7 +10,7 @@ routes = RouteTableDef() @routes.get("/stream/media/{media_type}/{item_id}") async def stream_media(request: Request): """Stream a single audio track.""" - media_type = MediaType.from_string(request.match_info["media_type"]) + media_type = MediaType(request.match_info["media_type"]) if media_type not in [MediaType.Track, MediaType.Radio]: return Response(status=404, reason="Media item is not playable!") item_id = request.match_info["item_id"] diff --git a/music_assistant/web/endpoints/tracks.py b/music_assistant/web/endpoints/tracks.py index 3d990def..110c018a 100644 --- a/music_assistant/web/endpoints/tracks.py +++ b/music_assistant/web/endpoints/tracks.py @@ -2,8 +2,7 @@ from aiohttp.web import Request, Response, RouteTableDef from aiohttp_jwt import login_required -from music_assistant.helpers.util import json_serializer -from music_assistant.helpers.web import async_stream_json +from music_assistant.helpers.web import async_json_response routes = RouteTableDef() @@ -12,8 +11,8 @@ routes = RouteTableDef() @login_required async def async_tracks(request: Request): """Get all tracks known in the database.""" - generator = request.app["mass"].database.async_get_tracks() - return await async_stream_json(request, generator) + result = await request.app["mass"].database.async_get_tracks() + return await async_json_response(result) @routes.get("/api/tracks/{item_id}/versions") @@ -24,8 +23,8 @@ async def async_track_versions(request: Request): provider = request.rel_url.query.get("provider") if item_id is None or provider is None: return Response(text="invalid item_id or provider", status=501) - generator = request.app["mass"].music.async_get_track_versions(item_id, provider) - return await async_stream_json(request, generator) + result = await request.app["mass"].music.async_get_track_versions(item_id, provider) + return await async_json_response(result) @routes.get("/api/tracks/{item_id}") @@ -40,4 +39,4 @@ async def async_track(request: Request): result = await request.app["mass"].music.async_get_track( item_id, provider, lazy=lazy ) - return Response(body=json_serializer(result), content_type="application/json") + return await async_json_response(result) diff --git a/music_assistant/web/endpoints/websocket.py b/music_assistant/web/endpoints/websocket.py index b7e9aa00..62f51ea2 100644 --- a/music_assistant/web/endpoints/websocket.py +++ b/music_assistant/web/endpoints/websocket.py @@ -8,7 +8,9 @@ import ujson from aiohttp import WSMsgType from aiohttp.web import Request, RouteTableDef, WebSocketResponse from music_assistant.helpers.typing import MusicAssistantType -from music_assistant.helpers.util import json_serializer +from music_assistant.helpers.web import json_serializer + +from .login import async_get_token routes = RouteTableDef() ws_commands = dict() @@ -70,17 +72,28 @@ async def async_websocket_handler(request: Request): if not authenticated and not msg == "login": # make sure client is authenticated await async_send_message("error", "authentication required") - elif msg == "login": + elif msg == "login" and isinstance(msg_details, str): # handle login with token try: token_info = jwt.decode(msg_details, mass.web.device_id) await async_send_message("login", token_info) authenticated = True except jwt.InvalidTokenError as exc: - async_send_message( + await async_send_message( "error", "Invalid authorization token, " + str(exc) ) authenticated = False + elif msg == "login" and isinstance(msg_details, dict): + # handle login with username/password + token_info = await async_get_token( + mass, msg_details["username"], msg_details["password"] + ) + if token_info: + await async_send_message("login", token_info) + authenticated = True + else: + await async_send_message("error", "Invalid credentials") + authenticated = False elif msg in ws_commands: res = await ws_commands[msg](mass, msg_details) if res is not None: @@ -90,7 +103,6 @@ async def async_websocket_handler(request: Request): mass.add_event_listener(async_send_message, msg_details) ) await async_send_message("event listener subscribed", msg_details) - else: # simply echo the message on the eventbus request.app["mass"].signal_event(msg, msg_details) diff --git a/requirements_dev.txt b/requirements_dev.txt index 8a343542..b09b1992 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,6 +2,5 @@ -r requirements_lint.txt -r requirements_test.txt tox==3.20.1 -python-vlc -e . -- 2.34.1