-# This workflows will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-
-name: Upload Python Package
+name: Publish releases to PyPI
on:
release:
- types: [published]
+ types: [published, prereleased]
jobs:
- deploy:
-
+ build-and-publish:
+ name: Builds and publishes releases to PyPI
runs-on: ubuntu-latest
-
steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v4.3.0
- with:
- python-version: '3.x'
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install setuptools wheel twine
- - name: Build and publish
- env:
- TWINE_USERNAME: __token__
- TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
- run: |
- python setup.py sdist bdist_wheel
- twine upload dist/*
+ - uses: actions/checkout@v3.3.0
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v4.5.0
+ with:
+ python-version: "3.10"
+ - name: Install build
+ run: >-
+ pip install build
+ - name: Build
+ run: >-
+ python3 -m build
+ - name: Publish release to PyPI
+ uses: pypa/gh-action-pypi-publish@v1.6.4
+ with:
+ user: __token__
+ password: ${{ secrets.PYPI_TOKEN }}
# This workflow will install Python dependencies, run tests and lint
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-name: Test with Pre-commit
+name: Test
on:
push:
- branches: [master]
+ branches: [main]
pull_request:
- branches: [master]
+ branches: [main]
jobs:
- build:
+ lint:
runs-on: ubuntu-latest
+ continue-on-error: true
+
+ steps:
+ - name: Check out code from GitHub
+ uses: actions/checkout@v3.3.0
+ - name: Set up Python
+ uses: actions/setup-python@v4.5.0
+ with:
+ python-version: "3.11"
+ - name: Install dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y ffmpeg
+ python -m pip install --upgrade pip
+ pip install -e .[server] -r requirements-test.txt
+ - name: Lint/test with pre-commit
+ run: pre-commit run --all-files
+ - name: Flake8
+ run: flake8 scripts/ music_assistant/
+ - name: Black
+ run: black --check scripts/ music_assistant/
+ - name: isort
+ run: isort --check scripts/ music_assistant/
+ - name: pylint
+ run: pylint music_assistant/
+ # - name: mypy
+ # run: mypy music_assistant/
+
+ test:
+ runs-on: ubuntu-latest
+ continue-on-error: true
strategy:
+ fail-fast: false
matrix:
- python-version: ['3.9', '3.10']
+ python-version:
+ - "3.11"
steps:
- - uses: actions/checkout@v3
+ - name: Check out code from GitHub
+ uses: actions/checkout@v3.3.0
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4.3.0
+ uses: actions/setup-python@v4.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get update
- sudo apt-get install ffmpeg
+ sudo apt-get install -y libgirepository1.0-dev
python -m pip install --upgrade pip
- pip install -r requirements_dev.txt
- pre-commit install-hooks
- - name: Lint/test with pre-commit
- run: pre-commit run --all-files
- - name: Run pylint on changed files
- run: |
- pylint -rn -sn --rcfile=pylintrc --fail-on=I $(git ls-files '*.py')
- - name: Run unit tests with pytest
- run: pytest tests/
+ pip install -e .[server] -r requirements-test.txt
+ - name: Pytest
+ run: pytest --durations 10 --cov-report term-missing --cov=music_assistant --cov-report=xml tests/server/
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.3.0
+ rev: v4.4.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
+ - id: no-commit-to-branch
+ args:
+ - --branch=main
+ - id: debug-statements
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
+ rev: 'v0.0.254'
+ hooks:
+ - id: ruff
- repo: https://github.com/psf/black
- rev: 22.10.0
+ rev: 23.1.0
hooks:
- id: black
args:
- --safe
- --quiet
- - repo: https://gitlab.com/pycqa/flake8
- rev: 3.9.2
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.2.2
hooks:
- - id: flake8
- exclude: ^examples/
+ - id: codespell
+ args: []
+ exclude_types: [csv, json]
+ exclude: ^tests/fixtures/
additional_dependencies:
- - flake8-docstrings==1.3.1
- - pydocstyle==4.0.0
- - repo: https://github.com/pre-commit/mirrors-isort
- rev: v5.10.1
- hooks:
- - id: isort
- exclude: ^examples/
- - repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.990
- hooks:
- - id: mypy
- additional_dependencies: [types-all]
- exclude: ^examples/
- - repo: https://github.com/pycqa/pydocstyle
- rev: 6.1.1
- hooks:
- - id: pydocstyle
- exclude: ^examples/|^.venv/|^.vscode/|app_vars.py
- - repo: local
- hooks:
- - id: pylint
- name: pylint
- entry: pylint
- language: system
- types: [python]
- args: ["-rn", "-sn", "--rcfile=pylintrc", "--fail-on=I", "--disable=import-error"]
- exclude: ^.venv/|^.vscode/
+ - tomli
+
+ # - repo: https://github.com/pre-commit/mirrors-mypy
+ # rev: v0.990
+ # hooks:
+ # - id: mypy
+ # additional_dependencies: [types-all]
+ # exclude: ^examples/
--- /dev/null
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python: Module",
+ "type": "python",
+ "request": "launch",
+ "module": "music_assistant",
+ "justMyCode": false,
+ "args":[
+ "--log-level", "debug"
+ ]
+ }
+ ]
+}
\ No newline at end of file
{
- "python.linting.pylintArgs": ["--rcfile=${workspaceFolder}/setup.cfg"],
- "python.linting.enabled": true,
- "python.linting.flake8Enabled": true,
- "python.linting.flake8Args": ["--config=${workspaceFolder}/setup.cfg"],
- "python.linting.mypyEnabled": true,
- "python.testing.pytestArgs": [
- "tests"
- ]
+ "[python]": {
+ "editor.formatOnSave": true,
+ "editor.codeActionsOnSave": {
+ "source.fixAll": true,
+ "source.organizeImports": true
+ }
+ },
+ "python.formatting.provider": "black"
}
+++ /dev/null
-include *.txt
-include README.rst
-include LICENSE.md
-graft music_assistant
-recursive-exclude * *.py[co]
+++ /dev/null
-"""Extended example/script to run Music Assistant with all bells and whistles."""
-import argparse
-import asyncio
-import logging
-import os
-
-from os.path import abspath, dirname
-from sys import path
-
-path.insert(1, dirname(dirname(abspath(__file__))))
-
-# pylint: disable=wrong-import-position
-from music_assistant.mass import MusicAssistant
-from music_assistant.models.config import MassConfig, MusicProviderConfig
-from music_assistant.models.enums import (
- CrossFadeMode,
- ProviderType,
- RepeatMode,
- PlayerState,
-)
-from music_assistant.models.player import Player
-
-
-parser = argparse.ArgumentParser(description="MusicAssistant")
-parser.add_argument(
- "--spotify-username",
- required=False,
- help="Spotify username",
-)
-parser.add_argument(
- "--spotify-password",
- required=False,
- help="Spotify password.",
-)
-parser.add_argument(
- "--qobuz-username",
- required=False,
- help="Qobuz username",
-)
-parser.add_argument(
- "--qobuz-password",
- required=False,
- help="Qobuz password.",
-)
-parser.add_argument(
- "--tunein-username",
- required=False,
- help="Tunein username",
-)
-parser.add_argument(
- "--musicdir",
- required=False,
- help="Directory on disk for local music library",
-)
-parser.add_argument(
- "--ytmusic-username",
- required=False,
- help="YoutubeMusic username",
-)
-parser.add_argument(
- "--ytmusic-cookie",
- required=False,
- help="YoutubeMusic cookie",
-)
-parser.add_argument(
- "--smb-username",
- required=False,
- help="SMB username",
-)
-parser.add_argument(
- "--smb-password",
- required=False,
- help="SMB password",
-)
-parser.add_argument(
- "--smb-path",
- required=False,
- help="The NetBIOS machine name of the remote server + share (e.g. \\\\machine\\share).",
-)
-parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable verbose debug logging",
-)
-args = parser.parse_args()
-
-
-# setup logger
-logging.basicConfig(
- level=logging.DEBUG if args.debug else logging.INFO,
- format="%(asctime)-15s %(levelname)-5s %(name)s -- %(message)s",
-)
-# silence some loggers
-logging.getLogger("aiorun").setLevel(logging.WARNING)
-logging.getLogger("asyncio").setLevel(logging.INFO)
-logging.getLogger("aiosqlite").setLevel(logging.WARNING)
-logging.getLogger("databases").setLevel(logging.INFO)
-logging.getLogger("SMB").setLevel(logging.INFO)
-
-
-# default database based on sqlite
-data_dir = os.getenv("APPDATA") if os.name == "nt" else os.path.expanduser("~")
-data_dir = os.path.join(data_dir, ".musicassistant")
-if not os.path.isdir(data_dir):
- os.makedirs(data_dir)
-db_file = os.path.join(data_dir, "music_assistant.db")
-
-
-mass_conf = MassConfig(
- database_url=f"sqlite:///{db_file}",
-)
-if args.spotify_username and args.spotify_password:
- mass_conf.providers.append(
- MusicProviderConfig(
- ProviderType.SPOTIFY,
- username=args.spotify_username,
- password=args.spotify_password,
- )
- )
-if args.qobuz_username and args.qobuz_password:
- mass_conf.providers.append(
- MusicProviderConfig(
- type=ProviderType.QOBUZ,
- username=args.qobuz_username,
- password=args.qobuz_password,
- )
- )
-if args.tunein_username:
- mass_conf.providers.append(
- MusicProviderConfig(
- type=ProviderType.TUNEIN,
- username=args.tunein_username,
- )
- )
-
-if args.ytmusic_username and args.ytmusic_cookie:
- mass_conf.providers.append(
- MusicProviderConfig(
- ProviderType.YTMUSIC,
- username=args.ytmusic_username,
- password=args.ytmusic_cookie,
- )
- )
-
-if args.musicdir:
- mass_conf.providers.append(
- MusicProviderConfig(type=ProviderType.FILESYSTEM_LOCAL, path=args.musicdir)
- )
-
-if args.smb_path:
- mass_conf.providers.append(
- MusicProviderConfig(
- ProviderType.FILESYSTEM_SMB,
- username=args.smb_username,
- password=args.smb_password,
- path=args.smb_path,
- )
- )
-
-
-class TestPlayer(Player):
- """Demonstatration player implementation."""
-
- def __init__(self, player_id: str):
- """Init."""
- self.player_id = player_id
- self._attr_name = player_id
- self._attr_powered = True
- self._attr_elapsed_time = 0
- self._attr_current_url = ""
- self._attr_state = PlayerState.IDLE
- self._attr_available = True
- self._attr_volume_level = 100
-
- async def play_url(self, url: str) -> None:
- """Play the specified url on the player."""
- print(f"stream url: {url}")
- self._attr_current_url = url
- self.update_state()
- # launch stream url with ffplay so we can hear it playing ;-)
- # normally this url is sent to the actual player implementation
- await asyncio.create_subprocess_shell(
- f'ffplay -hide_banner -loglevel quiet -i "{url}"'
- )
-
- async def stop(self) -> None:
- """Send STOP command to player."""
- print("stop called")
- self._attr_state = PlayerState.IDLE
- self._attr_current_url = None
- self._attr_elapsed_time = 0
- self.update_state()
-
- async def play(self) -> None:
- """Send PLAY/UNPAUSE command to player."""
- print("play called")
- self._attr_state = PlayerState.PLAYING
- self._attr_elapsed_time = 1
- self.update_state()
-
- async def pause(self) -> None:
- """Send PAUSE command to player."""
- print("pause called")
- self._attr_state = PlayerState.PAUSED
- self.update_state()
-
- async def power(self, powered: bool) -> None:
- """Send POWER command to player."""
- print(f"POWER CALLED - new power: {powered}")
- self._attr_powered = powered
- self._attr_current_url = None
- self.update_state()
-
- async def volume_set(self, volume_level: int) -> None:
- """Send volume level (0..100) command to player."""
- print(f"volume_set called - {volume_level}")
- self._attr_volume_level = volume_level
- self.update_state()
-
-
-async def main():
- """Handle main execution."""
-
- asyncio.get_event_loop().set_debug(args.debug)
-
- async with MusicAssistant(mass_conf) as mass:
-
- # run sync
- await mass.music.start_sync()
-
- # get some data
- artists = await mass.music.artists.db_items()
- artists_lib = await mass.music.artists.db_items(True)
- print(
- f"Got {artists_lib.total} artists in library (of {artists.total} total in db)"
- )
-
- albums = await mass.music.albums.db_items()
- albums_lib = await mass.music.albums.db_items(True)
- print(
- f"Got {albums_lib.total} albums in library (of {albums.total} total in db)"
- )
-
- tracks = await mass.music.tracks.db_items()
- tracks_lib = await mass.music.tracks.db_items(True)
- print(
- f"Got {tracks_lib.total} tracks in library (of {tracks.total} total in db)"
- )
-
- playlists = await mass.music.playlists.db_items()
- playlists_lib = await mass.music.playlists.db_items(True)
- print(
- f"Got {playlists_lib.total} playlists in library (of {playlists.total} total in db)"
- )
-
- # register a player
- test_player1 = TestPlayer("test1")
- test_player2 = TestPlayer("test2")
- await mass.players.register_player(test_player1)
- await mass.players.register_player(test_player2)
-
- # try to play some music
- test_player1.active_queue.settings.shuffle_enabled = True
- test_player1.active_queue.settings.repeat_mode = RepeatMode.ALL
- test_player1.active_queue.settings.crossfade_duration = 10
- test_player1.active_queue.settings.crossfade_mode = CrossFadeMode.SMART
-
- # we can send a MediaItem object (such as Artist, Album, Track, Playlist)
- # we can also send an uri, such as spotify://track/abcdfefgh
- # or database://playlist/1
- # or a list of items
- if playlists.count > 0:
- await test_player1.active_queue.play_media(playlists.items[0])
- elif tracks.count > 0:
- await test_player1.active_queue.play_media(tracks.items[0])
-
- await asyncio.sleep(3600)
-
-
-if __name__ == "__main__":
- try:
- asyncio.run(main())
- except KeyboardInterrupt:
- pass
+++ /dev/null
-"""Simple example/script to run Music Assistant with Spotify provider."""
-import argparse
-import asyncio
-import logging
-import os
-from os.path import abspath, dirname
-from sys import path
-
-path.insert(1, dirname(dirname(abspath(__file__))))
-
-# pylint: disable=wrong-import-position
-from music_assistant.mass import MusicAssistant
-from music_assistant.models.config import MassConfig, MusicProviderConfig
-from music_assistant.models.enums import ProviderType
-
-parser = argparse.ArgumentParser(description="MusicAssistant")
-parser.add_argument(
- "--username",
- required=True,
- help="Spotify username",
-)
-parser.add_argument(
- "--password",
- required=True,
- help="Spotify password.",
-)
-parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable verbose debug logging",
-)
-args = parser.parse_args()
-
-
-# setup logger
-if args.debug:
- logging.basicConfig(
- level=logging.DEBUG,
- format="%(asctime)-15s %(levelname)-5s %(name)s -- %(message)s",
- )
- # silence some loggers
- logging.getLogger("aiorun").setLevel(logging.WARNING)
- logging.getLogger("asyncio").setLevel(logging.INFO)
- logging.getLogger("aiosqlite").setLevel(logging.WARNING)
- logging.getLogger("databases").setLevel(logging.WARNING)
-
-
-# default database based on sqlite
-data_dir = os.getenv("APPDATA") if os.name == "nt" else os.path.expanduser("~")
-data_dir = os.path.join(data_dir, ".musicassistant")
-if not os.path.isdir(data_dir):
- os.makedirs(data_dir)
-db_file = os.path.join(data_dir, "music_assistant.db")
-
-mass = MusicAssistant(
- MassConfig(
- database_url=MassConfig,
- providers=[
- MusicProviderConfig(
- ProviderType.SPOTIFY,
- username=args.spotify_username,
- password=args.spotify_password,
- )
- ],
- )
-)
-
-
-async def main():
- """Handle main execution."""
-
- asyncio.get_event_loop().set_debug(args.debug)
-
- # without contextmanager we need to call the async setup
- await mass.setup()
-
- # start sync
- await mass.music.start_sync(schedule=3)
-
- # get some data
- await mass.music.artists.db_items()
- await mass.music.tracks.db_items()
- await mass.music.radio.db_items()
-
- # run for an hour until someone hits CTRL+C
- await asyncio.sleep(3600)
-
- # without contextmanager we need to call the stop
- await mass.stop()
-
-
-if __name__ == "__main__":
- try:
- asyncio.run(main())
- except KeyboardInterrupt:
- pass
"""Music Assistant: The music library manager in python."""
-
-from .mass import MusicAssistant # noqa
--- /dev/null
+"""Run the Music Assistant Server."""
+import argparse
+import asyncio
+import logging
+import os
+from logging.handlers import TimedRotatingFileHandler
+
+import coloredlogs
+from aiorun import run
+
+from music_assistant.server import MusicAssistant
+
+
+def get_arguments():
+ """Arguments handling."""
+ parser = argparse.ArgumentParser(description="MusicAssistant")
+
+ default_data_dir = os.getenv("APPDATA") if os.name == "nt" else os.path.expanduser("~")
+ default_data_dir = os.path.join(default_data_dir, ".musicassistant")
+
+ parser.add_argument(
+ "-c",
+ "--config",
+ metavar="path_to_config_dir",
+ default=default_data_dir,
+ help="Directory that contains the MusicAssistant configuration",
+ )
+ parser.add_argument(
+ "--log-level",
+ type=str,
+ default="info",
+ help="Provide logging level. Example --log-level debug, "
+ "default=info, possible=(critical, error, warning, info, debug)",
+ )
+ arguments = parser.parse_args()
+ return arguments
+
+
+def setup_logger(data_path: str, level: str = "DEBUG"):
+ """Initialize logger."""
+ logs_dir = os.path.join(data_path, "logs")
+ if not os.path.isdir(logs_dir):
+ os.mkdir(logs_dir)
+ logger = logging.getLogger()
+ log_fmt = "%(asctime)-15s %(levelname)-5s %(name)s -- %(message)s"
+ log_formatter = logging.Formatter(log_fmt)
+ consolehandler = logging.StreamHandler()
+ consolehandler.setFormatter(log_formatter)
+ consolehandler.setLevel(logging.DEBUG)
+ logger.addHandler(consolehandler)
+ log_filename = os.path.join(logs_dir, "musicassistant.log")
+ file_handler = TimedRotatingFileHandler(
+ log_filename, when="midnight", interval=1, backupCount=10
+ )
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(log_formatter)
+ logger.addHandler(file_handler)
+
+ # global level is debug by default unless overridden
+ logger.setLevel(level)
+
+ # silence some loggers
+ logging.getLogger("asyncio").setLevel(logging.WARNING)
+ logging.getLogger("aiosqlite").setLevel(logging.WARNING)
+ logging.getLogger("databases").setLevel(logging.WARNING)
+
+ # enable coloredlogs
+ coloredlogs.install(level=level, fmt=log_fmt)
+ return logger
+
+
+def main():
+ """Start MusicAssistant."""
+ # parse arguments
+ args = get_arguments()
+ data_dir = args.config
+ if not os.path.isdir(data_dir):
+ os.makedirs(data_dir)
+ # setup logger
+ log_level = args.log_level.upper()
+ logger = setup_logger(data_dir, log_level)
+ mass = MusicAssistant(data_dir)
+
+ def on_shutdown(loop):
+ logger.info("shutdown requested!")
+ loop.run_until_complete(mass.stop())
+
+ async def start_mass():
+ loop = asyncio.get_running_loop()
+ if log_level == "DEBUG":
+ loop.set_debug(True)
+ await mass.start()
+
+ run(
+ start_mass(),
+ use_uvloop=False,
+ shutdown_callback=on_shutdown,
+ executor_workers=64,
+ )
+
+
+if __name__ == "__main__":
+ main()
--- /dev/null
+"""Music Assistant: The music library manager in python."""
--- /dev/null
+"""Provide common/shared files for the Music Assistant Server and client."""
--- /dev/null
+"""Various utils/helpers."""
--- /dev/null
+"""Helpers for date and time."""
+from __future__ import annotations
+
+import datetime
+
+LOCAL_TIMEZONE = datetime.datetime.now(datetime.UTC).astimezone().tzinfo
+
+
+def utc() -> datetime.datetime:
+ """Get current UTC datetime."""
+ return datetime.datetime.now(datetime.UTC)
+
+
+def utc_timestamp() -> float:
+ """Return UTC timestamp in seconds as float."""
+ return utc().timestamp()
+
+
+def now() -> datetime.datetime:
+ """Get current datetime in local timezone."""
+ return datetime.datetime.now(LOCAL_TIMEZONE)
+
+
+def now_timestamp() -> float:
+ """Return current datetime as timestamp in local timezone."""
+ return now().timestamp()
+
+
+def future_timestamp(**kwargs) -> float:
+ """Return current timestamp + timedelta."""
+ return (now() + datetime.timedelta(**kwargs)).timestamp()
+
+
+def from_utc_timestamp(timestamp: float) -> datetime.datetime:
+ """Return datetime from UTC timestamp."""
+ return datetime.datetime.fromtimestamp(timestamp, datetime.UTC)
+
+
+def iso_from_utc_timestamp(timestamp: float) -> str:
+ """Return ISO 8601 datetime string from UTC timestamp."""
+ return from_utc_timestamp(timestamp).isoformat()
--- /dev/null
+"""Helpers to work with (de)serializing of json."""
+
+import base64
+from types import MethodType
+from typing import Any
+
+import aiofiles
+import orjson
+from _collections_abc import dict_keys, dict_values
+
+JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
+JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,)
+
+
+def json_encoder_default(obj: Any) -> Any:
+ """Convert Special objects.
+
+ Hand other objects to the original method.
+ """
+ if getattr(obj, "do_not_serialize", None):
+ return None
+ if (
+ isinstance(obj, list | set | filter | tuple | dict_values | dict_keys | dict_values)
+ or obj.__class__ == "dict_valueiterator"
+ ):
+ return list(obj)
+ if hasattr(obj, "as_dict"):
+ return obj.as_dict()
+ if hasattr(obj, "to_dict"):
+ return obj.to_dict(omit_none=True)
+ if isinstance(obj, bytes):
+ return base64.b64encode(obj).decode("ascii")
+ if isinstance(obj, MethodType):
+ return None
+ raise TypeError
+
+
+def json_dumps(data: Any) -> str:
+ """Dump json string."""
+ return orjson.dumps(
+ data,
+ option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2,
+ default=json_encoder_default,
+ ).decode("utf-8")
+
+
+json_loads = orjson.loads
+
+
+async def load_json_file(path: str) -> dict:
+ """Load JSON from file."""
+ async with aiofiles.open(path, "r") as _file:
+ content = await _file.read()
+ return json_loads(content)
--- /dev/null
+"""Helpers for creating/parsing URI's."""
+
+import os
+
+from music_assistant.common.models.enums import MediaType
+from music_assistant.common.models.errors import MusicAssistantError
+
+
+def parse_uri(uri: str) -> tuple[MediaType, str, str]:
+ """Try to parse URI to Mass identifiers.
+
+ Returns Tuple: MediaType, provider_domain, item_id
+ """
+ try:
+ if uri.startswith("https://open."):
+ # public share URL (e.g. Spotify or Qobuz, not sure about others)
+ # https://open.spotify.com/playlist/5lH9NjOeJvctAO92ZrKQNB?si=04a63c8234ac413e
+ provider_domain = uri.split(".")[1]
+ media_type_str = uri.split("/")[3]
+ media_type = MediaType(media_type_str)
+ item_id = uri.split("/")[4].split("?")[0]
+ elif uri.startswith("http://") or uri.startswith("https://"):
+ # Translate a plain URL to the URL provider
+ provider_domain = "url"
+ media_type = MediaType.UNKNOWN
+ item_id = uri
+ elif "://" in uri:
+ # music assistant-style uri
+ # provider://media_type/item_id
+ provider_domain = uri.split("://")[0]
+ media_type_str = uri.split("/")[2]
+ media_type = MediaType(media_type_str)
+ item_id = uri.split(f"{media_type_str}/")[1]
+ elif ":" in uri:
+ # spotify new-style uri
+ provider_domain, media_type_str, item_id = uri.split(":")
+ media_type = MediaType(media_type_str)
+ elif os.path.isfile(uri):
+ # Translate a local file (which is not from file provider) to the URL provider
+ provider_domain = "url"
+ media_type = MediaType.TRACK
+ item_id = uri
+ else:
+ raise KeyError
+ except (TypeError, AttributeError, ValueError, KeyError) as err:
+ raise MusicAssistantError(f"Not a valid Music Assistant uri: {uri}") from err
+ return (media_type, provider_domain, item_id)
+
+
+def create_uri(media_type: MediaType, provider_domain: str, item_id: str) -> str:
+ """Create Music Assistant URI from MediaItem values."""
+ return f"{provider_domain}://{media_type.value}/{item_id}"
--- /dev/null
+"""Helper and utility functions."""
+from __future__ import annotations
+
+import asyncio
+import os
+import platform
+import re
+import socket
+import tempfile
+from collections.abc import Callable
+from typing import Any, TypeVar
+
+import memory_tempfile
+import unidecode
+
+# pylint: disable=invalid-name
+T = TypeVar("T")
+_UNDEF: dict = {}
+CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
+CALLBACK_TYPE = Callable[[], None]
+# pylint: enable=invalid-name
+
+
+def filename_from_string(string: str) -> str:
+ """Create filename from unsafe string."""
+ keepcharacters = (" ", ".", "_")
+ return "".join(c for c in string if c.isalnum() or c in keepcharacters).rstrip()
+
+
+def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
+ """Try to parse an int."""
+ try:
+ return int(possible_int)
+ except (TypeError, ValueError):
+ return default
+
+
+def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
+ """Try to parse a float."""
+ try:
+ return float(possible_float)
+ except (TypeError, ValueError):
+ return default
+
+
+def try_parse_bool(possible_bool: Any) -> str:
+ """Try to parse a bool."""
+ if isinstance(possible_bool, bool):
+ return possible_bool
+ return possible_bool in ["true", "True", "1", "on", "ON", 1]
+
+
+def create_safe_string(input_str: str) -> str:
+ """Return clean lowered string for compare actions."""
+ input_str = input_str.lower().strip()
+ unaccented_string = unidecode.unidecode(input_str)
+ return re.sub(r"[^a-zA-Z0-9]", "", unaccented_string)
+
+
+def create_sort_name(input_str: str) -> str:
+ """Create sort name/title from string."""
+ input_str = input_str.lower().strip()
+ for item in ["the ", "de ", "les "]:
+ if input_str.startswith(item):
+ input_str = input_str.replace(item, "")
+ return input_str.strip()
+
+
+def parse_title_and_version(title: str, track_version: str = None):
+ """Try to parse clean track title and version from the title."""
+ version = ""
+ for splitter in [" (", " [", " - ", " (", " [", "-"]:
+ if splitter in title:
+ title_parts = title.split(splitter)
+ for title_part in title_parts:
+ # look for the end splitter
+ for end_splitter in [")", "]"]:
+ if end_splitter in title_part:
+ title_part = title_part.split(end_splitter)[0] # noqa: PLW2901
+ for version_str in [
+ "version",
+ "live",
+ "edit",
+ "remix",
+ "mix",
+ "acoustic",
+ "instrumental",
+ "karaoke",
+ "remaster",
+ "versie",
+ "radio",
+ "unplugged",
+ "disco",
+ "akoestisch",
+ "deluxe",
+ ]:
+ if version_str in title_part.lower():
+ version = title_part
+ title = title.split(splitter + version)[0]
+ title = clean_title(title)
+ if not version and track_version:
+ version = track_version
+ version = get_version_substitute(version).title()
+ if version == title:
+ version = ""
+ return title, version
+
+
+def clean_title(title: str) -> str:
+ """Strip unwanted additional text from title."""
+ for splitter in [" (", " [", " - ", " (", " [", "-"]:
+ if splitter in title:
+ title_parts = title.split(splitter)
+ for title_part in title_parts:
+ # look for the end splitter
+ for end_splitter in [")", "]"]:
+ if end_splitter in title_part:
+ title_part = title_part.split(end_splitter)[0] # noqa: PLW2901
+ for ignore_str in ["feat.", "featuring", "ft.", "with ", "explicit"]:
+ if ignore_str in title_part.lower():
+ return title.split(splitter + title_part)[0].strip()
+ return title.strip()
+
+
+def get_version_substitute(version_str: str):
+ """Transform provider version str to universal version type."""
+ version_str = version_str.lower()
+ # substitute edit and edition with version
+ if "edition" in version_str or "edit" in version_str:
+ version_str = version_str.replace(" edition", " version")
+ version_str = version_str.replace(" edit ", " version")
+ if version_str.startswith("the "):
+ version_str = version_str.split("the ")[1]
+ if "radio mix" in version_str:
+ version_str = "radio version"
+ elif "video mix" in version_str:
+ version_str = "video version"
+ elif "spanglish" in version_str or "spanish" in version_str:
+ version_str = "spanish version"
+ elif "remaster" in version_str:
+ version_str = "remaster"
+ return version_str.strip()
+
+
+def get_ip():
+ """Get primary IP-address for this host."""
+ # pylint: disable=broad-except,no-member
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ # doesn't even have to be reachable
+ sock.connect(("10.255.255.255", 1))
+ _ip = sock.getsockname()[0]
+ except Exception:
+ _ip = "127.0.0.1"
+ finally:
+ sock.close()
+ return _ip
+
+
+def is_port_in_use(port: int) -> bool:
+ """Check if port is in use."""
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _sock:
+ try:
+ return _sock.connect_ex(("localhost", port)) == 0
+ except socket.gaierror:
+ return True
+
+
+async def select_free_port(range_start: int, range_end: int) -> int:
+ """Automatically find available port within range."""
+
+ def _select_free_port():
+ for port in range(range_start, range_end):
+ if not is_port_in_use(port):
+ return port
+ raise OSError("No free port available")
+
+ return await asyncio.to_thread(_select_free_port)
+
+
+async def get_ip_from_host(dns_name: str) -> str:
+ """Resolve (first) IP-address for given dns name."""
+
+ def _resolve():
+ try:
+ return socket.gethostbyname(dns_name)
+ except Exception: # pylint: disable=broad-except
+ # fail gracefully!
+ return dns_name
+
+ return await asyncio.to_thread(_resolve)
+
+
+def get_ip_pton():
+ """Return socket pton for local ip."""
+ # pylint:disable=no-member
+ try:
+ return socket.inet_pton(socket.AF_INET, get_ip())
+ except OSError:
+ return socket.inet_pton(socket.AF_INET6, get_ip())
+
+
+def get_folder_size(folderpath):
+ """Return folder size in gb."""
+ total_size = 0
+ # pylint: disable=unused-variable
+ for dirpath, dirnames, filenames in os.walk(folderpath):
+ for _file in filenames:
+ _fp = os.path.join(dirpath, _file)
+ total_size += os.path.getsize(_fp)
+ # pylint: enable=unused-variable
+ total_size_gb = total_size / float(1 << 30)
+ return total_size_gb
+
+
+def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
+ """Merge dict without overwriting existing values."""
+ final_dict = base_dict.copy()
+ for key, value in new_dict.items():
+ if final_dict.get(key) and isinstance(value, dict):
+ final_dict[key] = merge_dict(final_dict[key], value)
+ if final_dict.get(key) and isinstance(value, tuple):
+ final_dict[key] = merge_tuples(final_dict[key], value)
+ if final_dict.get(key) and isinstance(value, list):
+ final_dict[key] = merge_lists(final_dict[key], value)
+ elif not final_dict.get(key) or allow_overwite:
+ final_dict[key] = value
+ return final_dict
+
+
+def merge_tuples(base: tuple, new: tuple) -> tuple:
+ """Merge 2 tuples."""
+ return tuple(x for x in base if x not in new) + tuple(new)
+
+
+def merge_lists(base: list, new: list) -> list:
+ """Merge 2 lists."""
+ return list(x for x in base if x not in new) + list(new)
+
+
+def create_tempfile():
+ """Return a (named) temporary file."""
+ if platform.system() == "Linux":
+ return memory_tempfile.MemoryTempfile(fallback=True).NamedTemporaryFile(buffering=0)
+ return tempfile.NamedTemporaryFile(buffering=0)
+
+
+def get_changed_keys(
+ dict1: dict[str, Any],
+ dict2: dict[str, Any],
+ ignore_keys: list[str] | None = None,
+) -> set[str]:
+ """Compare 2 dicts and return set of changed keys."""
+ if not dict1:
+ return set(dict2.keys())
+ if not dict2:
+ return set(dict1.keys())
+ changed_keys = set()
+ for key, value in dict2.items():
+ if ignore_keys and key in ignore_keys:
+ continue
+ if key not in dict1:
+ changed_keys.add(key)
+ elif isinstance(value, dict):
+ changed_keys.update(get_changed_keys(dict1[key], value))
+ elif dict1[key] != value:
+ changed_keys.add(key)
+ return changed_keys
+
+
+def empty_queue(q: asyncio.Queue) -> None:
+ """Empty an asyncio Queue."""
+ for _ in range(q.qsize()):
+ try:
+ q.get_nowait()
+ q.task_done()
+ except (asyncio.QueueEmpty, ValueError):
+ pass
--- /dev/null
+"""Package with all common/shared (serializable) Models (dataclassses)."""
--- /dev/null
+"""Generic models used for the (websockets) API communication."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+from mashumaro import DataClassDictMixin
+
+from music_assistant.common.models.event import MassEvent
+
+
+@dataclass
+class CommandMessage(DataClassDictMixin):
+ """Model for a Message holding a command from server to client or client to server."""
+
+ message_id: str | int
+ command: str
+ args: dict[str, Any] | None = None
+
+
+@dataclass
+class ResultMessageBase(DataClassDictMixin):
+ """Base class for a result/response of a Command Message."""
+
+ message_id: str
+
+
+@dataclass
+class SuccessResultMessage(ResultMessageBase):
+ """Message sent when a Command has been successfully executed."""
+
+ result: Any
+
+
+@dataclass
+class ErrorResultMessage(ResultMessageBase):
+ """Message sent when a command did not execute successfully."""
+
+ error_code: str
+ details: str | None = None
+
+
+EventMessage = MassEvent
+
+
+@dataclass
+class ServerInfoMessage(DataClassDictMixin):
+ """Message sent by the server with it's info when a client connects."""
+
+ server_version: str
+ schema_version: int
+
+
+MessageType = (
+ CommandMessage
+ | EventMessage
+ | SuccessResultMessage
+ | ErrorResultMessage
+ | ServerInfoMessage
+)
--- /dev/null
+"""Model and helpers for Config entries."""
+from __future__ import annotations
+
+import logging
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from types import NoneType
+from typing import Any
+
+from mashumaro import DataClassDictMixin
+
+from music_assistant.common.models.enums import ProviderType
+from music_assistant.constants import (
+ CONF_EQ_BASS,
+ CONF_EQ_MID,
+ CONF_EQ_TREBLE,
+ CONF_FLOW_MODE,
+ CONF_OUTPUT_CHANNELS,
+ CONF_VOLUME_NORMALISATION,
+ CONF_VOLUME_NORMALISATION_TARGET,
+)
+
+from .enums import ConfigEntryType
+
+LOGGER = logging.getLogger(__name__)
+
+ConfigValueType = str | int | float | bool | None
+
+ConfigEntryTypeMap = {
+ ConfigEntryType.BOOLEAN: bool,
+ ConfigEntryType.STRING: str,
+ ConfigEntryType.PASSWORD: str,
+ ConfigEntryType.INTEGER: int,
+ ConfigEntryType.FLOAT: float,
+ ConfigEntryType.LABEL: str,
+}
+
+
+@dataclass
+class ConfigValueOption(DataClassDictMixin):
+ """Model for a value with separated name/value."""
+
+ title: str
+ value: ConfigValueType
+
+
+@dataclass
+class ConfigEntry(DataClassDictMixin):
+ """Model for a Config Entry.
+
+ The definition of something that can be configured for an object (e.g. provider or player)
+ within Music Assistant (without the value).
+ """
+
+ # key: used as identifier for the entry, also for localization
+ key: str
+ type: ConfigEntryType
+ # label: default label when no translation for the key is present
+ label: str
+ default_value: ConfigValueType = None
+ required: bool = True
+ # options [optional]: select from list of possible values/options
+ options: list[ConfigValueOption] | None = None
+ # range [optional]: select values within range
+ range: tuple[int, int] | None = None
+ # description [optional]: extended description of the setting.
+ description: str | None = None
+ # help_link [optional]: link to help article.
+ help_link: str | None = None
+ # multi_value [optional]: allow multiple values from the list
+ multi_value: bool = False
+ # depends_on [optional]: needs to be set before this setting shows up in frontend
+ depends_on: str | None = None
+ # hidden: hide from UI
+ hidden: bool = False
+ # advanced: this is an advanced setting (frontend hides it in some corner)
+ advanced: bool = False
+
+
+@dataclass
+class ConfigEntryValue(ConfigEntry):
+ """Config Entry with its value parsed."""
+
+ value: ConfigValueType = None
+
+ @classmethod
+ def parse(
+ cls,
+ entry: ConfigEntry,
+ value: ConfigValueType,
+ allow_none: bool = False,
+ ) -> ConfigEntryValue:
+ """Parse ConfigEntryValue from the config entry and plain value."""
+ result = ConfigEntryValue.from_dict(entry.to_dict())
+ result.value = value
+ expected_type = ConfigEntryTypeMap.get(result.type, NoneType)
+ if result.value is None:
+ result.value = entry.default_value
+ if result.value is None and not entry.required:
+ expected_type = NoneType
+ if entry.type == ConfigEntryType.LABEL:
+ result.value = result.label
+ if not isinstance(result.value, expected_type):
+ if result.value is None and allow_none:
+ # In some cases we allow this (e.g. create default config), hence the allow_none
+ return result
+ # handle common conversions/mistakes
+ if expected_type == float and isinstance(result.value, int):
+ result.value = float(result.value)
+ return result
+ if expected_type == int and isinstance(result.value, float):
+ result.value = int(result.value)
+ return result
+ if entry.default_value:
+ LOGGER.warning(
+ "%s has unexpected type: %s, fallback to default",
+ result.key,
+ type(result.value),
+ )
+ return result
+ raise ValueError(f"{result.key} has unexpected type: {type(result.value)}")
+ return result
+
+
+@dataclass
+class Config(DataClassDictMixin):
+ """Base Configuration object."""
+
+ values: dict[str, ConfigEntryValue]
+
+ def get_value(self, key: str) -> ConfigValueType:
+ """Return config value for given key."""
+ config_value = self.values[key]
+ if config_value.type == ConfigEntryType.PASSWORD: # noqa: SIM102
+ if decrypt_callback := self.get_decrypt_callback():
+ return decrypt_callback(config_value.value)
+ return config_value.value
+
+ @classmethod
+ def parse(
+ cls,
+ config_entries: Iterable[ConfigEntry],
+ raw: dict[str, Any],
+ allow_none: bool = False,
+ decrypt_callback: Callable[[str], str] | None = None,
+ ) -> Config:
+ """Parse Config from the raw values (as stored in persistent storage)."""
+ values = {
+ x.key: ConfigEntryValue.parse(x, raw.get("values", {}).get(x.key), allow_none).to_dict()
+ for x in config_entries
+ }
+ conf = cls.from_dict({**raw, "values": values})
+ if decrypt_callback:
+ conf.set_decrypt_callback(decrypt_callback)
+ return conf
+
+ def to_raw(self) -> dict[str, Any]:
+ """Return minimized/raw dict to store in persistent storage."""
+ return {
+ **self.to_dict(),
+ "values": {x.key: x.value for x in self.values.values() if x.value != x.default_value},
+ }
+
+ def set_decrypt_callback(self, callback: Callable[[str], str]) -> None:
+ """Register callback to decrypt (password) strings."""
+ setattr(self, "decrypt_callback", callback)
+
+ def get_decrypt_callback(self) -> Callable[[str], str] | None:
+ """Get optional callback to decrypt (password) strings."""
+ return getattr(self, "decrypt_callback", None)
+
+
+@dataclass
+class ProviderConfig(Config):
+ """Provider(instance) Configuration."""
+
+ type: ProviderType
+ domain: str
+ instance_id: str
+ # enabled: boolean to indicate if the provider is enabled
+ enabled: bool = True
+ # name: an (optional) custom name for this provider instance/config
+ name: str | None = None
+
+
+@dataclass
+class PlayerConfig(Config):
+ """Player Configuration."""
+
+ provider: str
+ player_id: str
+ # enabled: boolean to indicate if the player is enabled
+ enabled: bool = True
+ # name: an (optional) custom name for this player
+ name: str | None = None
+
+
+DEFAULT_PLAYER_CONFIG_ENTRIES = (
+ ConfigEntry(
+ key=CONF_VOLUME_NORMALISATION,
+ type=ConfigEntryType.BOOLEAN,
+ label="Enable volume normalization (EBU-R128 based)",
+ default_value=True,
+ description="Enable volume normalization based on the EBU-R128 "
+ "standard without affecting dynamic range",
+ ),
+ ConfigEntry(
+ key=CONF_FLOW_MODE,
+ type=ConfigEntryType.BOOLEAN,
+ label="Enable queue flow mode",
+ default_value=False,
+ description='Enable "flow" mode where all queue tracks are sent as a continuous '
+ "audio stream. Use for players that do not natively support gapless and/or "
+ "crossfading or if the player has trouble transitioning between tracks.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key=CONF_VOLUME_NORMALISATION_TARGET,
+ type=ConfigEntryType.INTEGER,
+ range=(-30, 0),
+ default_value=-14,
+ label="Target level for volume normalisation",
+ description="Adjust average (perceived) loudness to this target level, "
+ "default is -14 LUFS",
+ depends_on=CONF_VOLUME_NORMALISATION,
+ advanced=True,
+ ),
+ ConfigEntry(
+ key=CONF_EQ_BASS,
+ type=ConfigEntryType.INTEGER,
+ range=(-10, 10),
+ default_value=0,
+ label="Equalizer: bass",
+ description="Use the builtin basic equalizer to adjust the bass of audio.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key=CONF_EQ_MID,
+ type=ConfigEntryType.INTEGER,
+ range=(-10, 10),
+ default_value=0,
+ label="Equalizer: midrange",
+ description="Use the builtin basic equalizer to adjust the midrange of audio.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key=CONF_EQ_TREBLE,
+ type=ConfigEntryType.INTEGER,
+ range=(-10, 10),
+ default_value=0,
+ label="Equalizer: treble",
+ description="Use the builtin basic equalizer to adjust the treble of audio.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key=CONF_OUTPUT_CHANNELS,
+ type=ConfigEntryType.STRING,
+ options=[
+ ConfigValueOption("Stereo (both channels)", "stereo"),
+ ConfigValueOption("Left channel", "left"),
+ ConfigValueOption("Right channel", "right"),
+ ConfigValueOption("Mono (both channels)", "mono"),
+ ],
+ default_value="stereo",
+ label="Output Channel Mode",
+ description="You can configure this player to play only the left or right channel, "
+ "for example to a create a stereo pair with 2 players.",
+ advanced=True,
+ ),
+)
--- /dev/null
+"""All enums used by the Music Assistant models."""
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, Self, TypeVar
+
+# pylint:disable=ungrouped-imports
+try:
+ from enum import StrEnum
+except AttributeError:
+ # Python 3.10 compatibility for strenum
+ _StrEnumSelfT = TypeVar("_StrEnumSelfT", bound="StrEnum")
+
+ class StrEnum(str, Enum):
+ """Partial backport of Python 3.11's StrEnum for our basic use cases."""
+
+ def __new__(
+ cls: type[_StrEnumSelfT], value: str, *args: Any, **kwargs: Any
+ ) -> _StrEnumSelfT:
+ """Create a new StrEnum instance."""
+ if not isinstance(value, str):
+ raise TypeError(f"{value!r} is not a string")
+ return super().__new__(cls, value, *args, **kwargs)
+
+ def __str__(self) -> str:
+ """Return self."""
+ return str(self)
+
+ @staticmethod
+ def _generate_next_value_(
+ name: str, start: int, count: int, last_values: list[Any] # noqa
+ ) -> Any:
+ """Make `auto()` explicitly unsupported.
+
+ We may revisit this when it's very clear that Python 3.11's
+ `StrEnum.auto()` behavior will no longer change.
+ """
+ raise TypeError("auto() is not supported by this implementation")
+
+
+class MediaType(StrEnum):
+ """StrEnum for MediaType."""
+
+ ARTIST = "artist"
+ ALBUM = "album"
+ TRACK = "track"
+ PLAYLIST = "playlist"
+ RADIO = "radio"
+ FOLDER = "folder"
+ UNKNOWN = "unknown"
+
+ @classmethod
+ @property
+ def ALL(cls: Self) -> tuple[MediaType, ...]: # noqa: N802
+ """Return all (default) MediaTypes as tuple."""
+ return (
+ MediaType.ARTIST,
+ MediaType.ALBUM,
+ MediaType.TRACK,
+ MediaType.PLAYLIST,
+ MediaType.RADIO,
+ )
+
+
+class LinkType(StrEnum):
+ """StrEnum with link types."""
+
+ WEBSITE = "website"
+ FACEBOOK = "facebook"
+ TWITTER = "twitter"
+ LASTFM = "lastfm"
+ YOUTUBE = "youtube"
+ INSTAGRAM = "instagram"
+ SNAPCHAT = "snapchat"
+ TIKTOK = "tiktok"
+ DISCOGS = "discogs"
+ WIKIPEDIA = "wikipedia"
+ ALLMUSIC = "allmusic"
+
+
+class ImageType(StrEnum):
+ """StrEnum with image types."""
+
+ THUMB = "thumb"
+ LANDSCAPE = "landscape"
+ FANART = "fanart"
+ LOGO = "logo"
+ CLEARART = "clearart"
+ BANNER = "banner"
+ CUTOUT = "cutout"
+ BACK = "back"
+ DISCART = "discart"
+ OTHER = "other"
+
+
+class AlbumType(StrEnum):
+ """StrEnum for Album type."""
+
+ ALBUM = "album"
+ SINGLE = "single"
+ COMPILATION = "compilation"
+ EP = "ep"
+ UNKNOWN = "unknown"
+
+
+class ContentType(StrEnum):
+ """Enum with audio content/container types supported by ffmpeg."""
+
+ OGG = "ogg"
+ FLAC = "flac"
+ MP3 = "mp3"
+ AAC = "aac"
+ MPEG = "mpeg"
+ ALAC = "alac"
+ WAV = "wav"
+ AIFF = "aiff"
+ WMA = "wma"
+ M4A = "m4a"
+ DSF = "dsf"
+ WAVPACK = "wv"
+ PCM_S16LE = "s16le" # PCM signed 16-bit little-endian
+ PCM_S24LE = "s24le" # PCM signed 24-bit little-endian
+ PCM_S32LE = "s32le" # PCM signed 32-bit little-endian
+ PCM_F32LE = "f32le" # PCM 32-bit floating-point little-endian
+ PCM_F64LE = "f64le" # PCM 64-bit floating-point little-endian
+ PCM = "pcm" # PCM generic (details determined later)
+ MPEG_DASH = "dash"
+ UNKNOWN = "?"
+
+ @classmethod
+ def try_parse(cls: ContentType, string: str) -> ContentType:
+ """Try to parse ContentType from (url)string/extension."""
+ tempstr = string.lower()
+ if "audio/" in tempstr:
+ tempstr = tempstr.split("/")[1]
+ for splitter in (".", ","):
+ if splitter in tempstr:
+ for val in tempstr.split(splitter):
+ try:
+ return cls(val.strip())
+ except ValueError:
+ pass
+
+ tempstr = tempstr.split("?")[0]
+ tempstr = tempstr.split("&")[0]
+ tempstr = tempstr.split(";")[0]
+ tempstr = tempstr.replace("mp4", "m4a")
+ tempstr = tempstr.replace("mpd", "dash")
+ try:
+ return cls(tempstr)
+ except ValueError:
+ return cls.UNKNOWN
+
+ def is_pcm(self) -> bool:
+ """Return if contentype is PCM."""
+ return self.name.startswith("PCM")
+
+ def is_lossless(self) -> bool:
+ """Return if format is lossless."""
+ return self.is_pcm() or self in (
+ ContentType.DSF,
+ ContentType.FLAC,
+ ContentType.AIFF,
+ ContentType.WAV,
+ )
+
+ @classmethod
+ def from_bit_depth(cls, bit_depth: int, floating_point: bool = False) -> ContentType:
+ """Return (PCM) Contenttype from PCM bit depth."""
+ if floating_point and bit_depth > 32:
+ return cls.PCM_F64LE
+ if floating_point:
+ return cls.PCM_F32LE
+ if bit_depth == 16:
+ return cls.PCM_S16LE
+ if bit_depth == 24:
+ return cls.PCM_S24LE
+ return cls.PCM_S32LE
+
+
+class QueueOption(StrEnum):
+ """StrEnum representation of the queue (play) options.
+
+ - PLAY -> Insert new item(s) in queue at the current position and start playing.
+ - REPLACE -> Replace entire queue contents with the new items and start playing from index 0.
+ - NEXT -> Insert item(s) after current playing/buffered item.
+ - REPLACE_NEXT -> Replace item(s) after current playing/buffered item.
+ - ADD -> Add new item(s) to the queue (at the end if shuffle is not enabled).
+ """
+
+ PLAY = "play"
+ REPLACE = "replace"
+ NEXT = "next"
+ REPLACE_NEXT = "replace_next"
+ ADD = "add"
+
+
+class RepeatMode(StrEnum):
+ """Enum with repeat modes."""
+
+ OFF = "off" # no repeat at all
+ ONE = "one" # repeat one/single track
+ ALL = "all" # repeat entire queue
+
+
+class PlayerState(StrEnum):
+ """StrEnum for the (playback)state of a player."""
+
+ IDLE = "idle"
+ PAUSED = "paused"
+ PLAYING = "playing"
+ OFF = "off"
+
+
+class PlayerType(StrEnum):
+ """Enum with possible Player Types.
+
+ player: A regular player.
+ group: A (dedicated) group player or playergroup.
+ """
+
+ PLAYER = "player"
+ GROUP = "group"
+
+
+class PlayerFeature(StrEnum):
+ """Enum with possible Player features.
+
+ power: The player has a dedicated power control.
+ volume: The player supports adjusting the volume.
+ mute: The player supports muting the volume.
+ sync: The player supports syncing with other players (of the same platform).
+ accurate_time: The player provides millisecond accurate timing information.
+ seek: The player supports seeking to a specific.
+ set_members: The PlayerGroup supports adding/removing members.
+ queue: The player supports (en)queuing of media items.
+ """
+
+ POWER = "power"
+ VOLUME_SET = "volume_set"
+ VOLUME_MUTE = "volume_mute"
+ PAUSE = "pause"
+ SYNC = "sync"
+ ACCURATE_TIME = "accurate_time"
+ SEEK = "seek"
+ SET_MEMBERS = "set_members"
+ QUEUE = "queue"
+
+
+class EventType(StrEnum):
+ """Enum with possible Events."""
+
+ PLAYER_ADDED = "player_added"
+ PLAYER_UPDATED = "player_updated"
+ PLAYER_REMOVED = "player_removed"
+ PLAYER_SETTINGS_UPDATED = "player_settings_updated"
+ QUEUE_ADDED = "queue_added"
+ QUEUE_UPDATED = "queue_updated"
+ QUEUE_ITEMS_UPDATED = "queue_items_updated"
+ QUEUE_TIME_UPDATED = "queue_time_updated"
+ QUEUE_SETTINGS_UPDATED = "queue_settings_updated"
+ SHUTDOWN = "application_shutdown"
+ MEDIA_ITEM_ADDED = "media_item_added"
+ MEDIA_ITEM_UPDATED = "media_item_updated"
+ MEDIA_ITEM_DELETED = "media_item_deleted"
+ PROVIDERS_UPDATED = "providers_updated"
+ PLAYER_CONFIG_UPDATED = "player_config_updated"
+ SYNC_TASKS_UPDATED = "sync_tasks_updated"
+
+
+class ProviderFeature(StrEnum):
+ """Enum with features for a Provider."""
+
+ #
+ # MUSICPROVIDER FEATURES
+ #
+
+ # browse/explore/recommendations
+ BROWSE = "browse"
+ SEARCH = "search"
+ RECOMMENDATIONS = "recommendations"
+
+ # library feature per mediatype
+ LIBRARY_ARTISTS = "library_artists"
+ LIBRARY_ALBUMS = "library_albums"
+ LIBRARY_TRACKS = "library_tracks"
+ LIBRARY_PLAYLISTS = "library_playlists"
+ LIBRARY_RADIOS = "library_radios"
+
+ # additional library features
+ ARTIST_ALBUMS = "artist_albums"
+ ARTIST_TOPTRACKS = "artist_toptracks"
+
+ # library edit (=add/remove) feature per mediatype
+ LIBRARY_ARTISTS_EDIT = "library_artists_edit"
+ LIBRARY_ALBUMS_EDIT = "library_albums_edit"
+ LIBRARY_TRACKS_EDIT = "library_tracks_edit"
+ LIBRARY_PLAYLISTS_EDIT = "library_playlists_edit"
+ LIBRARY_RADIOS_EDIT = "library_radios_edit"
+
+ # if we can grab 'similar tracks' from the music provider
+ # used to generate dynamic playlists
+ SIMILAR_TRACKS = "similar_tracks"
+
+ # playlist-specific features
+ PLAYLIST_TRACKS_EDIT = "playlist_tracks_edit"
+ PLAYLIST_CREATE = "playlist_create"
+
+ #
+ # PLAYERPROVIDER FEATURES
+ #
+ CREATE_PLAYER_CONFIG = "create_player_config"
+
+ #
+ # METADATAPROVIDER FEATURES
+ #
+ ARTIST_METADATA = "artist_metadata"
+ ALBUM_METADATA = "album_metadata"
+ TRACK_METADATA = "track_metadata"
+ GET_ARTIST_MBID = "get_artist_mbid"
+
+ #
+ # PLUGIN FEATURES
+ #
+
+
+class ProviderType(StrEnum):
+ """Enum with supported provider types."""
+
+ MUSIC = "music"
+ PLAYER = "player"
+ METADATA = "metadata"
+ PLUGIN = "plugin"
+
+
+class ConfigEntryType(StrEnum):
+ """Enum for the type of a config entry."""
+
+ BOOLEAN = "boolean"
+ STRING = "string"
+ PASSWORD = "password"
+ INTEGER = "integer"
+ FLOAT = "float"
+ LABEL = "label"
--- /dev/null
+"""Custom errors and exceptions."""
+
+
+class MusicAssistantError(Exception):
+ """Custom Exception for all errors."""
+
+ error_code = 0
+
+
+class ProviderUnavailableError(MusicAssistantError):
+ """Error raised when trying to access mediaitem of unavailable provider."""
+
+ error_code = 1
+
+
+class MediaNotFoundError(MusicAssistantError):
+ """Error raised when trying to access non existing media item."""
+
+ error_code = 2
+
+
+class InvalidDataError(MusicAssistantError):
+ """Error raised when an object has invalid data."""
+
+ error_code = 3
+
+
+class AlreadyRegisteredError(MusicAssistantError):
+ """Error raised when a duplicate music provider or player is registered."""
+
+ error_code = 4
+
+
+class SetupFailedError(MusicAssistantError):
+ """Error raised when setup of a provider or player failed."""
+
+ error_code = 5
+
+
+class LoginFailed(MusicAssistantError):
+ """Error raised when a login failed."""
+
+ error_code = 6
+
+
+class AudioError(MusicAssistantError):
+ """Error raised when an issue arrised when processing audio."""
+
+ error_code = 7
+
+
+class QueueEmpty(MusicAssistantError):
+ """Error raised when trying to start queue stream while queue is empty."""
+
+ error_code = 8
+
+
+class UnsupportedFeaturedException(MusicAssistantError):
+ """Error raised when a feature is not supported."""
+
+ error_code = 9
+
+
+class PlayerUnavailableError(MusicAssistantError):
+ """Error raised when trying to access non-existing or unavailable player."""
+
+ error_code = 10
+
+
+class PlayerCommandFailed(MusicAssistantError):
+ """Error raised when a command to a player failed execution."""
+
+ error_code = 11
+
+
+class InvalidCommand(MusicAssistantError):
+ """Error raised when an unknown command is requested on the API."""
+
+ error_code = 12
+
+
+def error_code_to_exception(error_code: int) -> MusicAssistantError:
+ """Return MusicAssistant Error (exception) from error_code."""
+ match error_code:
+ case 1:
+ return ProviderUnavailableError
+ case 2:
+ return MediaNotFoundError
+ case 3:
+ return InvalidDataError
+ case 4:
+ return AlreadyRegisteredError
+ case 5:
+ return SetupFailedError
+ case 6:
+ return LoginFailed
+ case 7:
+ return AudioError
+ case 8:
+ return QueueEmpty
+ case 9:
+ return UnsupportedFeaturedException
+ case 10:
+ return PlayerUnavailableError
+ case 11:
+ return PlayerCommandFailed
+ case 12:
+ return InvalidCommand
+ case _:
+ return MusicAssistantError
--- /dev/null
+"""Model for Music Assistant Event."""
+
+from dataclasses import dataclass
+from typing import Any
+
+from mashumaro import DataClassDictMixin
+
+from music_assistant.common.models.enums import EventType
+
+
+@dataclass
+class MassEvent(DataClassDictMixin):
+ """Representation of an Event emitted in/by Music Assistant."""
+
+ event: EventType
+ object_id: str | None = None # player_id, queue_id or uri
+ data: Any = None # optional data (such as the object)
--- /dev/null
+"""Models and helpers for media items."""
+from __future__ import annotations
+
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields
+from time import time
+from typing import Any
+
+from mashumaro import DataClassDictMixin
+
+from music_assistant.common.helpers.json import json_dumps, json_loads
+from music_assistant.common.helpers.uri import create_uri
+from music_assistant.common.helpers.util import create_sort_name, merge_lists
+from music_assistant.common.models.enums import (
+ AlbumType,
+ ContentType,
+ ImageType,
+ LinkType,
+ MediaType,
+)
+
+MetadataTypes = int | bool | str | list[str]
+
+JSON_KEYS = ("artists", "artist", "albums", "metadata", "provider_mappings")
+
+
+@dataclass(frozen=True)
+class ProviderMapping(DataClassDictMixin):
+ """Model for a MediaItem's provider mapping details."""
+
+ item_id: str
+ provider_domain: str
+ provider_instance: str
+ available: bool = True
+ # quality details (streamable content only)
+ content_type: ContentType = ContentType.UNKNOWN
+ sample_rate: int = 44100
+ bit_depth: int = 16
+ bit_rate: int = 320
+ # optional details to store provider specific details
+ details: str | None = None
+ # url = link to provider details page if exists
+ url: str | None = None
+
+ @property
+ def quality(self) -> int:
+ """Calculate quality score."""
+ if self.content_type.is_lossless():
+ return int(self.sample_rate / 1000) + self.bit_depth
+ # lossy content, bit_rate is most important score
+ # but prefer some codecs over others
+ score = self.bit_rate / 100
+ if self.content_type in (ContentType.AAC, ContentType.OGG):
+ score += 1
+ return int(score)
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.provider_domain, self.item_id))
+
+
+@dataclass(frozen=True)
+class MediaItemLink(DataClassDictMixin):
+ """Model for a link."""
+
+ type: LinkType
+ url: str
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash(self.type)
+
+
+@dataclass(frozen=True)
+class MediaItemImage(DataClassDictMixin):
+ """Model for a image."""
+
+ type: ImageType
+ url: str
+ is_file: bool = False # indicator that image is local filepath instead of url
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash(self.url)
+
+
+@dataclass
+class MediaItemMetadata(DataClassDictMixin):
+ """Model for a MediaItem's metadata."""
+
+ description: str | None = None
+ review: str | None = None
+ explicit: bool | None = None
+ images: list[MediaItemImage] | None = None
+ genres: set[str] | None = None
+ mood: str | None = None
+ style: str | None = None
+ copyright: str | None = None
+ lyrics: str | None = None
+ ean: str | None = None
+ label: str | None = None
+ links: set[MediaItemLink] | None = None
+ performers: set[str] | None = None
+ preview: str | None = None
+ replaygain: float | None = None
+ popularity: int | None = None
+ # last_refresh: timestamp the (full) metadata was last collected
+ last_refresh: int | None = None
+ # checksum: optional value to detect changes (e.g. playlists)
+ checksum: str | None = None
+
+ def update(
+ self,
+ new_values: MediaItemMetadata,
+ allow_overwrite: bool = False,
+ ) -> MediaItemMetadata:
+ """Update metadata (in-place) with new values."""
+ for fld in fields(self):
+ new_val = getattr(new_values, fld.name)
+ if new_val is None:
+ continue
+ cur_val = getattr(self, fld.name)
+ if isinstance(cur_val, list):
+ new_val = merge_lists(cur_val, new_val)
+ setattr(self, fld.name, new_val)
+ elif isinstance(cur_val, set):
+ new_val = cur_val.update(new_val)
+ setattr(self, fld.name, new_val)
+ elif cur_val is None or allow_overwrite: # noqa: SIM114
+ setattr(self, fld.name, new_val)
+ elif new_val and fld.name in ("checksum", "popularity", "last_refresh"):
+ # some fields are always allowed to be overwritten
+ # (such as checksum and last_refresh)
+ setattr(self, fld.name, new_val)
+ return self
+
+
+@dataclass
+class MediaItem(DataClassDictMixin):
+ """Base representation of a media item."""
+
+ item_id: str
+ provider: str
+ name: str
+ provider_mappings: set[ProviderMapping] = field(default_factory=set)
+
+ # optional fields below
+ metadata: MediaItemMetadata = field(default_factory=MediaItemMetadata)
+ in_library: bool = False
+ media_type: MediaType = MediaType.UNKNOWN
+ # sort_name and uri are auto generated, do not override unless really needed
+ sort_name: str | None = None
+ uri: str | None = None
+ # timestamp is used to determine when the item was added to the library
+ timestamp: int = 0
+
+ def __post_init__(self):
+ """Call after init."""
+ if not self.uri:
+ self.uri = create_uri(self.media_type, self.provider, self.item_id)
+ if not self.sort_name:
+ self.sort_name = create_sort_name(self.name)
+
+ @classmethod
+ def from_db_row(cls, db_row: Mapping):
+ """Create MediaItem object from database row."""
+ db_row = dict(db_row)
+ db_row["provider"] = "database"
+ for key in JSON_KEYS:
+ if key in db_row and db_row[key] is not None:
+ db_row[key] = json_loads(db_row[key])
+ if "in_library" in db_row:
+ db_row["in_library"] = bool(db_row["in_library"])
+ if db_row.get("albums"):
+ db_row["album"] = db_row["albums"][0]
+ db_row["disc_number"] = db_row["albums"][0]["disc_number"]
+ db_row["track_number"] = db_row["albums"][0]["track_number"]
+ db_row["item_id"] = str(db_row["item_id"])
+ return cls.from_dict(db_row)
+
+ def to_db_row(self) -> dict:
+ """Create dict from item suitable for db."""
+ return {
+ key: json_dumps(value) if key in JSON_KEYS else value
+ for key, value in self.to_dict().items()
+ if key
+ not in [
+ "item_id",
+ "provider",
+ "media_type",
+ "uri",
+ "album",
+ "position",
+ "track_number",
+ "disc_number",
+ ]
+ }
+
+ @property
+ def available(self):
+ """Return (calculated) availability."""
+ return any(x.available for x in self.provider_mappings)
+
+ @property
+ def image(self) -> MediaItemImage | None:
+ """Return (first/random) image/thumb from metadata (if any)."""
+ if self.metadata is None or self.metadata.images is None:
+ return None
+ return next((x for x in self.metadata.images if x.type == ImageType.THUMB), None)
+
+ def add_provider_mapping(self, prov_mapping: ProviderMapping) -> None:
+ """Add provider ID, overwrite existing entry."""
+ self.provider_mappings = {
+ x
+ for x in self.provider_mappings
+ if not (
+ x.item_id == prov_mapping.item_id
+ and x.provider_instance == prov_mapping.provider_instance
+ )
+ }
+ self.provider_mappings.add(prov_mapping)
+
+ @property
+ def last_refresh(self) -> int:
+ """Return timestamp the metadata was last refreshed (0 if full data never retrieved)."""
+ return self.metadata.last_refresh or 0
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.media_type, self.provider, self.item_id))
+
+
+@dataclass(frozen=True)
+class ItemMapping(DataClassDictMixin):
+ """Representation of a minimized item object."""
+
+ media_type: MediaType
+ item_id: str
+ provider: str
+ name: str
+ sort_name: str
+ uri: str
+ version: str = ""
+
+ @classmethod
+ def from_item(cls, item: MediaItem):
+ """Create ItemMapping object from regular item."""
+ return cls.from_dict(item.to_dict())
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.media_type, self.provider, self.item_id))
+
+
+@dataclass
+class Artist(MediaItem):
+ """Model for an artist."""
+
+ media_type: MediaType = MediaType.ARTIST
+ musicbrainz_id: str | None = None
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.provider, self.item_id))
+
+
+@dataclass
+class Album(MediaItem):
+ """Model for an album."""
+
+ media_type: MediaType = MediaType.ALBUM
+ version: str = ""
+ year: int | None = None
+ artists: list[Artist | ItemMapping] = field(default_factory=list)
+ album_type: AlbumType = AlbumType.UNKNOWN
+ upc: str | None = None
+ musicbrainz_id: str | None = None # release group id
+
+ @property
+ def artist(self) -> Artist | ItemMapping | None:
+ """Return (first) artist of album."""
+ if self.artists:
+ return self.artists[0]
+ return None
+
+ @artist.setter
+ def artist(self, artist: Artist | ItemMapping) -> None:
+ """Set (first/only) artist of album."""
+ self.artists = [artist]
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.provider, self.item_id))
+
+
+@dataclass(frozen=True)
+class TrackAlbumMapping(ItemMapping):
+ """Model for a track that is mapped to an album."""
+
+ disc_number: int | None = None
+ track_number: int | None = None
+
+
+@dataclass
+class Track(MediaItem):
+ """Model for a track."""
+
+ media_type: MediaType = MediaType.TRACK
+ duration: int = 0
+ version: str = ""
+ isrc: str | None = None
+ musicbrainz_id: str | None = None # Recording ID
+ artists: list[Artist | ItemMapping] = field(default_factory=list)
+ # album track only
+ album: Album | ItemMapping | None = None
+ albums: list[TrackAlbumMapping] = field(default_factory=list)
+ disc_number: int | None = None
+ track_number: int | None = None
+ # playlist track only
+ position: int | None = None
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.provider, self.item_id))
+
+ @property
+ def image(self) -> MediaItemImage | None:
+ """Return (first/random) image/thumb from metadata (if any)."""
+ if image := super().image:
+ return image
+ # fallback to album image (use getattr to guard for ItemMapping)
+ if self.album:
+ return getattr(self.album, "image", None)
+ return None
+
+ @property
+ def isrcs(self) -> tuple[str]:
+ """Split multiple values in isrc field."""
+ # sometimes the isrc contains multiple values, split by semicolon
+ if not self.isrc:
+ return tuple()
+ return tuple(self.isrc.split(";"))
+
+ @property
+ def artist(self) -> Artist | ItemMapping | None:
+ """Return (first) artist of track."""
+ if self.artists:
+ return self.artists[0]
+ return None
+
+ @artist.setter
+ def artist(self, artist: Artist | ItemMapping) -> None:
+ """Set (first/only) artist of track."""
+ self.artists = [artist]
+
+
+@dataclass
+class Playlist(MediaItem):
+ """Model for a playlist."""
+
+ media_type: MediaType = MediaType.PLAYLIST
+ owner: str = ""
+ is_editable: bool = False
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.provider, self.item_id))
+
+
+@dataclass
+class Radio(MediaItem):
+ """Model for a radio station."""
+
+ media_type: MediaType = MediaType.RADIO
+ duration: int = 172800
+
+ def to_db_row(self) -> dict:
+ """Create dict from item suitable for db."""
+ val = super().to_db_row()
+ val.pop("duration", None)
+ return val
+
+ def __hash__(self):
+ """Return custom hash."""
+ return hash((self.provider, self.item_id))
+
+
+@dataclass
+class BrowseFolder(MediaItem):
+ """Representation of a Folder used in Browse (which contains media items)."""
+
+ media_type: MediaType = MediaType.FOLDER
+ # path: the path (in uri style) to/for this browse folder
+ path: str = ""
+ # label: a labelid that needs to be translated by the frontend
+ label: str = ""
+ # subitems of this folder when expanding
+ items: list[MediaItemType | BrowseFolder] | None = None
+
+ def __post_init__(self):
+ """Call after init."""
+ super().__post_init__()
+ if not self.path:
+ self.path = f"{self.provider}://{self.item_id}"
+
+
+MediaItemType = Artist | Album | Track | Radio | Playlist | BrowseFolder
+
+
+@dataclass
+class PagedItems(DataClassDictMixin):
+ """Model for a paged listing."""
+
+ items: list[MediaItemType]
+ count: int
+ limit: int
+ offset: int
+ total: int | None = None
+
+
+def media_from_dict(media_item: dict) -> MediaItemType:
+ """Return MediaItem from dict."""
+ if media_item["media_type"] == "artist":
+ return Artist.from_dict(media_item)
+ if media_item["media_type"] == "album":
+ return Album.from_dict(media_item)
+ if media_item["media_type"] == "track":
+ return Track.from_dict(media_item)
+ if media_item["media_type"] == "playlist":
+ return Playlist.from_dict(media_item)
+ if media_item["media_type"] == "radio":
+ return Radio.from_dict(media_item)
+ return MediaItem.from_dict(media_item)
+
+
+@dataclass
+class StreamDetails(DataClassDictMixin):
+ """Model for streamdetails."""
+
+ # NOTE: the actual provider/itemid of the streamdetails may differ
+ # from the connected media_item due to track linking etc.
+ # the streamdetails are only used to provide details about the content
+ # that is going to be streamed.
+
+ # mandatory fields
+ provider: str
+ item_id: str
+ content_type: ContentType
+ media_type: MediaType = MediaType.TRACK
+ sample_rate: int = 44100
+ bit_depth: int = 16
+ channels: int = 2
+ # stream_title: radio streams can optionally set this field
+ stream_title: str | None = None
+ # duration of the item to stream, copied from media_item if omitted
+ duration: int | None = None
+ # total size in bytes of the item, calculated at eof when omitted
+ size: int | None = None
+ # expires: timestamp this streamdetails expire
+ expires: float = time() + 3600
+ # data: provider specific data (not exposed externally)
+ data: Any = None
+ # if the url/file is supported by ffmpeg directly, use direct stream
+ direct: str | None = None
+ # callback: optional callback function (or coroutine) to call when the stream completes.
+ # needed for streaming provivders to report what is playing
+ # receives the streamdetails as only argument from which to grab
+ # details such as seconds_streamed.
+ callback: Any = None
+
+ # the fields below will be set/controlled by the streamcontroller
+ queue_id: str | None = None
+ seconds_streamed: float | None = None
+ seconds_skipped: float | None = None
+ gain_correct: float | None = None
+ loudness: float | None = None
+
+ def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
+ """Exclude internal fields from dict."""
+ d.pop("data")
+ d.pop("direct")
+ d.pop("expires")
+ d.pop("queue_id")
+ d.pop("callback")
+ return d
+
+ def __str__(self):
+ """Return pretty printable string of object."""
+ return self.uri
+
+ @property
+ def uri(self) -> str:
+ """Return uri representation of item."""
+ return f"{self.provider}://{self.media_type.value}/{self.item_id}"
--- /dev/null
+"""Model(s) for Player."""
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass, field
+
+from mashumaro import DataClassDictMixin
+
+from .enums import PlayerFeature, PlayerState, PlayerType
+
+
+@dataclass(frozen=True)
+class DeviceInfo(DataClassDictMixin):
+ """Model for a player's deviceinfo."""
+
+ model: str = "unknown"
+ address: str = "unknown"
+ manufacturer: str = "unknown"
+
+
+@dataclass
+class Player(DataClassDictMixin):
+ """Representation of a Player within Music Assistant."""
+
+ player_id: str
+ provider: str
+ type: PlayerType
+ name: str
+ available: bool
+ powered: bool
+ device_info: DeviceInfo
+ supported_features: tuple[PlayerFeature, ...] = field(default=tuple())
+
+ elapsed_time: float = 0
+ elapsed_time_last_updated: float = time.time()
+ current_url: str | None = None
+ current_item_id: str | None = None
+ state: PlayerState = PlayerState.IDLE
+
+ volume_level: int = 100
+ volume_muted: bool = False
+
+ # group_childs: Return list of player group child id's or synced childs.
+ # - If this player is a dedicated group player,
+ # returns all child id's of the players in the group.
+ # - If this is a syncgroup of players from the same platform (e.g. sonos),
+ # this will return the id's of players synced to this player.
+ group_childs: list[str] = field(default_factory=list)
+
+ # active_queue: return player_id of the active queue for this player
+ # if the player is grouped and a group is active, this will be set to the group's player_id
+ # otherwise it will be set to the own player_id
+ active_queue: str = ""
+
+ # can_sync_with: return tuple of player_ids that can be synced to/with this player
+ # ususally this is just a list of all player_ids within the playerprovider
+ can_sync_with: tuple[str, ...] = field(default=tuple())
+
+ # synced_to: player_id of the player this player is currently synced to
+ # also referred to as "sync master"
+ synced_to: str | None = None
+
+ # max_sample_rate: maximum supported sample rate the player supports
+ max_sample_rate: int = 48000
+
+ # supports_24bit: bool if player supports 24bits (hi res) audio
+ supports_24bit: bool = True
+
+ # enabled: if the player is enabled
+ # will be set by the player manager based on config
+ # a disabled player is hidden in the UI and updates will not be processed
+ enabled: bool = True
+
+ # group_volume: if the player is a player group or syncgroup master,
+ # this will return the average volume of all child players
+ # if not a group player, this is just the player's volume
+ group_volume: int = 100
+
+ # display_name: return final/corrected name of the player
+ # always prefers any overridden name from settings
+ display_name: str = ""
+
+ @property
+ def corrected_elapsed_time(self) -> float:
+ """Return the corrected/realtime elapsed time."""
+ if self.state == PlayerState.PLAYING:
+ return self.elapsed_time + (time.time() - self.elapsed_time_last_updated)
+ return self.elapsed_time
--- /dev/null
+"""Model(s) for PlayerQueue."""
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass, field
+
+from mashumaro import DataClassDictMixin
+
+from music_assistant.common.models.media_items import MediaItemType
+
+from .enums import PlayerState, RepeatMode
+from .queue_item import QueueItem
+
+
+@dataclass
+class PlayerQueue(DataClassDictMixin):
+ """Representation of a PlayerQueue within Music Assistant."""
+
+ queue_id: str
+ active: bool
+ display_name: str
+ available: bool
+ items: int
+
+ shuffle_enabled: bool = False
+ repeat_mode: RepeatMode = RepeatMode.OFF
+ crossfade_enabled: bool = True
+ # current_index: index that is active (e.g. being played) by the player
+ current_index: int | None = None
+ # index_in_buffer: index that has been preloaded/buffered by the player
+ index_in_buffer: int | None = None
+ elapsed_time: float = 0
+ elapsed_time_last_updated: float = time.time()
+ state: PlayerState = PlayerState.IDLE
+ current_item: QueueItem | None = None
+ next_item: QueueItem | None = None
+ radio_source: list[MediaItemType] = field(default_factory=list)
+ announcement_in_progress: bool = False
+ flow_mode: bool = False
+
+ @property
+ def corrected_elapsed_time(self) -> float:
+ """Return the corrected/realtime elapsed time."""
+ return self.elapsed_time + (time.time() - self.elapsed_time_last_updated)
--- /dev/null
+"""Models for providers and plugins in the MA ecosystem."""
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import TypedDict
+
+from mashumaro import DataClassDictMixin
+
+from music_assistant.common.helpers.json import load_json_file
+
+from .config_entries import ConfigEntry
+from .enums import MediaType, ProviderFeature, ProviderType
+
+
+@dataclass
+class ProviderManifest(DataClassDictMixin):
+ """ProviderManifest, details of a provider."""
+
+ type: ProviderType
+ domain: str
+ name: str
+ description: str
+ codeowners: list[str]
+
+ # optional params
+ # config_entries: list of config entries required to configure/setup this provider
+ config_entries: list[ConfigEntry] = field(default_factory=list)
+ # requirements: list of (pip style) python packages required for this provider
+ requirements: list[str] = field(default_factory=list)
+ # documentation: link/url to documentation.
+ documentation: str | None = None
+ # init_class: class to initialize, within provider's package
+ # e.g. `SpotifyProvider`. (autodetect if None)
+ init_class: str | None = None
+ # multi_instance: whether multiple instances of the same provider are allowed/possible
+ multi_instance: bool = False
+ # builtin: whether this provider is a system/builtin and can not disabled/removed
+ builtin: bool = False
+ # load_by_default: load this provider by default (mostly used together with `builtin`)
+ load_by_default: bool = False
+ # depends_on: depends on another provider to function
+ depends_on: str | None = None
+
+ @classmethod
+ async def parse(cls: "ProviderManifest", manifest_file: str) -> "ProviderManifest":
+ """Parse ProviderManifest from file."""
+ manifest_dict = await load_json_file(manifest_file)
+ return cls.from_dict(manifest_dict)
+
+
+class ProviderInstance(TypedDict):
+ """Provider instance detailed dict when a provider is serialized over the api."""
+
+ type: ProviderType
+ domain: str
+ name: str
+ instance_id: str
+ supported_features: list[ProviderFeature]
+ available: bool
+ last_error: str | None
+
+
+@dataclass
+class SyncTask:
+ """Description of a Sync task/job of a musicprovider."""
+
+ provider_domain: str
+ provider_instance: str
+ media_types: tuple[MediaType]
+ task: asyncio.Task
+
+ def __post_init__(self):
+ """Execute action after initialization."""
+ # make sure that the task does not get serialized.
+ setattr(self.task, "do_not_serialize", True)
--- /dev/null
+"""Model a QueueItem."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+from uuid import uuid4
+
+from mashumaro import DataClassDictMixin
+
+from .enums import MediaType
+from .media_items import ItemMapping, MediaItemImage, Radio, StreamDetails, Track
+
+
+@dataclass
+class QueueItem(DataClassDictMixin):
+ """Representation of a queue item."""
+
+ queue_id: str
+ queue_item_id: str
+
+ name: str
+ duration: int | None
+ sort_index: int = 0
+ streamdetails: StreamDetails | None = None
+ media_item: Track | Radio | None = None
+ image: MediaItemImage | None = None
+
+ def __post_init__(self):
+ """Set default values."""
+ if self.streamdetails and self.streamdetails.stream_title:
+ self.name = self.streamdetails.stream_title
+ if not self.name:
+ self.name = self.uri
+
+ @classmethod
+ def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]:
+ """Run actions before deserialization."""
+ d.pop("streamdetails", None)
+ return d
+
+ @property
+ def uri(self) -> str:
+ """Return uri for this QueueItem (for logging purposes)."""
+ if self.media_item:
+ return self.media_item.uri
+ return self.queue_item_id
+
+ @property
+ def media_type(self) -> MediaType:
+ """Return MediaType for this QueueItem (for convenience purposes)."""
+ if self.media_item:
+ return self.media_item.media_type
+ return MediaType.UNKNOWN
+
+ @classmethod
+ def from_media_item(cls, queue_id: str, media_item: Track | Radio):
+ """Construct QueueItem from track/radio item."""
+ if media_item.media_type == MediaType.TRACK:
+ artists = "/".join(x.name for x in media_item.artists)
+ name = f"{artists} - {media_item.name}"
+ # save a lot of data/bandwidth by simplifying nested objects
+ media_item.artists = [ItemMapping.from_item(x) for x in media_item.artists]
+ if media_item.album:
+ media_item.album = ItemMapping.from_item(media_item.album)
+ media_item.albums = []
+ else:
+ name = media_item.name
+ return cls(
+ queue_id=queue_id,
+ queue_item_id=uuid4().hex,
+ name=name,
+ duration=media_item.duration,
+ media_item=media_item,
+ image=media_item.image,
+ )
"""All constants for Music Assistant."""
import pathlib
+from typing import Final
-ROOT_LOGGER_NAME = "music_assistant"
+__version__: Final[str] = "2.0.0"
-UNKNOWN_ARTIST = "Unknown Artist"
-VARIOUS_ARTISTS = "Various Artists"
-VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377"
+SCHEMA_VERSION: Final[int] = 19
+ROOT_LOGGER_NAME: Final[str] = "music_assistant"
-RESOURCES_DIR = pathlib.Path(__file__).parent.resolve().joinpath("helpers/resources")
+UNKNOWN_ARTIST: Final[str] = "Unknown Artist"
+VARIOUS_ARTISTS: Final[str] = "Various Artists"
+VARIOUS_ARTISTS_ID: Final[str] = "89ad4ac3-39f7-470e-963a-56509c546377"
-ANNOUNCE_ALERT_FILE = str(RESOURCES_DIR.joinpath("announce.mp3"))
-SILENCE_FILE = str(RESOURCES_DIR.joinpath("silence.mp3"))
-# if duration is None (e.g. radio stream) = 48 hours
-FALLBACK_DURATION = 172800
+RESOURCES_DIR: Final[pathlib.Path] = (
+ pathlib.Path(__file__).parent.resolve().joinpath("helpers/resources")
+)
+
+ANNOUNCE_ALERT_FILE: Final[str] = str(RESOURCES_DIR.joinpath("announce.mp3"))
+SILENCE_FILE: Final[str] = str(RESOURCES_DIR.joinpath("silence.mp3"))
+
+# if duration is None (e.g. radio stream):Final[str] = 48 hours
+FALLBACK_DURATION: Final[int] = 172800
# Name of the environment-variable to override base_url
-BASE_URL_OVERRIDE_ENVNAME = "MASS_BASE_URL"
+BASE_URL_OVERRIDE_ENVNAME: Final[str] = "MASS_BASE_URL"
+
+
+# config keys
+CONF_SERVER_ID: Final[str] = "server_id"
+CONF_WEB_IP: Final[str] = "webserver.ip"
+CONF_WEB_PORT: Final[str] = "webserver.port"
+CONF_DB_LIBRARY: Final[str] = "database.library"
+CONF_DB_CACHE: Final[str] = "database.cache"
+CONF_PROVIDERS: Final[str] = "providers"
+CONF_PLAYERS: Final[str] = "players"
+CONF_PATH: Final[str] = "path"
+CONF_USERNAME: Final[str] = "username"
+CONF_PASSWORD: Final[str] = "password"
+CONF_VOLUME_NORMALISATION: Final[str] = "volume_normalisation"
+CONF_VOLUME_NORMALISATION_TARGET: Final[str] = "volume_normalisation_target"
+CONF_MAX_SAMPLE_RATE: Final[str] = "max_sample_rate"
+CONF_EQ_BASS: Final[str] = "eq_bass"
+CONF_EQ_MID: Final[str] = "eq_mid"
+CONF_EQ_TREBLE: Final[str] = "eq_treble"
+CONF_OUTPUT_CHANNELS: Final[str] = "output_channels"
+CONF_FLOW_MODE: Final[str] = "flow_mode"
+
+# config default values
+DEFAULT_HOST: Final[str] = "0.0.0.0"
+DEFAULT_PORT: Final[int] = 8095
+DEFAULT_DB_LIBRARY: Final[str] = "sqlite:///[storage_path]/library.db"
+DEFAULT_DB_CACHE: Final[str] = "sqlite:///[storage_path]/cache.db"
+
+# common db tables
+DB_TABLE_TRACK_LOUDNESS: Final[str] = "track_loudness"
+DB_TABLE_PLAYLOG: Final[str] = "playlog"
+DB_TABLE_ARTISTS: Final[str] = "artists"
+DB_TABLE_ALBUMS: Final[str] = "albums"
+DB_TABLE_TRACKS: Final[str] = "tracks"
+DB_TABLE_PLAYLISTS: Final[str] = "playlists"
+DB_TABLE_RADIOS: Final[str] = "radios"
+DB_TABLE_CACHE: Final[str] = "cache"
+DB_TABLE_SETTINGS: Final[str] = "settings"
+DB_TABLE_THUMBS: Final[str] = "thumbnails"
+DB_TABLE_PROVIDER_MAPPINGS: Final[str] = "provider_mappings"
+
+# all other
+MASS_LOGO_ONLINE: Final[str] = (
+ "https://github.com/home-assistant/brands/" "raw/master/custom_integrations/mass/icon%402x.png"
+)
+++ /dev/null
-"""Package with controllers."""
+++ /dev/null
-"""Provides a simple stateless caching system."""
-from __future__ import annotations
-
-import asyncio
-import functools
-import json
-import time
-from collections import OrderedDict
-from collections.abc import MutableMapping
-from typing import TYPE_CHECKING, Any, Iterator, Optional
-
-from music_assistant.controllers.database import TABLE_CACHE
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-class CacheController:
- """Basic cache controller using both memory and database."""
-
- def __init__(self, mass: MusicAssistant) -> None:
- """Initialize our caching class."""
- self.mass = mass
- self.logger = mass.logger.getChild("cache")
- self._mem_cache = MemoryCache(500)
-
- async def setup(self) -> None:
- """Async initialize of cache module."""
- self.__schedule_cleanup_task()
-
- async def get(self, cache_key: str, checksum: Optional[str] = None, default=None):
- """
- Get object from cache and return the results.
-
- cache_key: the (unique) name of the cache object as reference
- checkum: optional argument to check if the checksum in the
- cacheobject matches the checkum provided
- """
- cur_time = int(time.time())
- if checksum is not None and not isinstance(checksum, str):
- checksum = str(checksum)
-
- # try memory cache first
- cache_data = self._mem_cache.get(cache_key)
- if (
- cache_data
- and (not checksum or cache_data[1] == checksum)
- and cache_data[2] >= cur_time
- ):
- return cache_data[0]
- # fall back to db cache
- if db_row := await self.mass.database.get_row(TABLE_CACHE, {"key": cache_key}):
- if (
- not checksum
- or db_row["checksum"] == checksum
- and db_row["expires"] >= cur_time
- ):
- try:
- data = await asyncio.get_running_loop().run_in_executor(
- None, json.loads, db_row["data"]
- )
- except Exception as exc: # pylint: disable=broad-except
- self.logger.exception(
- "Error parsing cache data for %s", cache_key, exc_info=exc
- )
- else:
- # also store in memory cache for faster access
- self._mem_cache[cache_key] = (
- data,
- db_row["checksum"],
- db_row["expires"],
- )
- return data
- return default
-
- async def set(self, cache_key, data, checksum="", expiration=(86400 * 30)):
- """Set data in cache."""
- if checksum is not None and not isinstance(checksum, str):
- checksum = str(checksum)
- expires = int(time.time() + expiration)
- self._mem_cache[cache_key] = (data, checksum, expires)
- if (expires - time.time()) < 3600 * 4:
- # do not cache items in db with short expiration
- return
- data = await asyncio.get_running_loop().run_in_executor(None, json.dumps, data)
- await self.mass.database.insert(
- TABLE_CACHE,
- {"key": cache_key, "expires": expires, "checksum": checksum, "data": data},
- allow_replace=True,
- )
-
- async def delete(self, cache_key):
- """Delete data from cache."""
- self._mem_cache.pop(cache_key, None)
- await self.mass.database.delete(TABLE_CACHE, {"key": cache_key})
-
- async def clear(self, key_filter: Optional[str] = None) -> None:
- """Clear all/partial items from cache."""
- self._mem_cache = {}
- query = f"key LIKE '%{key_filter}%'" if key_filter else None
- await self.mass.database.delete(TABLE_CACHE, query=query)
-
- async def auto_cleanup(self):
- """Sceduled auto cleanup task."""
- # for now we simply reset the memory cache
- self._mem_cache = {}
- cur_timestamp = int(time.time())
- for db_row in await self.mass.database.get_rows(TABLE_CACHE):
- # clean up db cache object only if expired
- if db_row["expires"] < cur_timestamp:
- await self.delete(db_row["key"])
-
- def __schedule_cleanup_task(self):
- """Schedule the cleanup task."""
- self.mass.add_job(self.auto_cleanup(), "Cleanup cache")
- # reschedule self
- self.mass.loop.call_later(3600, self.__schedule_cleanup_task)
-
-
-def use_cache(expiration=86400 * 30):
- """Return decorator that can be used to cache a method's result."""
-
- def wrapper(func):
- @functools.wraps(func)
- async def wrapped(*args, **kwargs):
- method_class = args[0]
- method_class_name = method_class.__class__.__name__
- cache_key_parts = [method_class_name, func.__name__]
- skip_cache = kwargs.pop("skip_cache", False)
- cache_checksum = kwargs.pop("cache_checksum", "")
- if len(args) > 1:
- cache_key_parts += args[1:]
- for key in sorted(kwargs.keys()):
- cache_key_parts.append(f"{key}{kwargs[key]}")
- cache_key = ".".join(cache_key_parts)
-
- cachedata = await method_class.cache.get(cache_key, checksum=cache_checksum)
-
- if not skip_cache and cachedata is not None:
- return cachedata
- result = await func(*args, **kwargs)
- asyncio.create_task(
- method_class.cache.set(
- cache_key, result, expiration=expiration, checksum=cache_checksum
- )
- )
- return result
-
- return wrapped
-
- return wrapper
-
-
-class MemoryCache(MutableMapping):
- """Simple limited in-memory cache implementation."""
-
- def __init__(self, maxlen: int):
- """Initialize."""
- self._maxlen = maxlen
- self.d = OrderedDict()
-
- @property
- def maxlen(self) -> int:
- """Return max length."""
- return self._maxlen
-
- def get(self, key: str, default: Any = None) -> Any:
- """Return item or default."""
- return self.d.get(key, default)
-
- def pop(self, key: str, default: Any = None) -> Any:
- """Pop item from collection."""
- return self.d.pop(key, default)
-
- def __getitem__(self, key: str) -> Any:
- """Get item."""
- self.d.move_to_end(key)
- return self.d[key]
-
- def __setitem__(self, key: str, value: Any) -> None:
- """Set item."""
- if key in self.d:
- self.d.move_to_end(key)
- elif len(self.d) == self.maxlen:
- self.d.popitem(last=False)
- self.d[key] = value
-
- def __delitem__(self, key) -> None:
- """Delete item."""
- del self.d[key]
-
- def __iter__(self) -> Iterator:
- """Iterate items."""
- return self.d.__iter__()
-
- def __len__(self) -> int:
- """Return length."""
- return len(self.d)
+++ /dev/null
-"""Database logic."""
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union
-
-from databases import Database as Db
-from sqlalchemy.sql import ClauseElement
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-SCHEMA_VERSION = 19
-
-TABLE_TRACK_LOUDNESS = "track_loudness"
-TABLE_PLAYLOG = "playlog"
-TABLE_ARTISTS = "artists"
-TABLE_ALBUMS = "albums"
-TABLE_TRACKS = "tracks"
-TABLE_PLAYLISTS = "playlists"
-TABLE_RADIOS = "radios"
-TABLE_CACHE = "cache"
-TABLE_SETTINGS = "settings"
-TABLE_THUMBS = "thumbnails"
-
-
-class DatabaseController:
- """Controller that holds the (connection to the) database."""
-
- def __init__(self, mass: MusicAssistant):
- """Initialize class."""
- self.url = mass.config.database_url
- self.mass = mass
- self.logger = mass.logger.getChild("db")
- # we maintain one global connection - otherwise we run into (dead)lock issues.
- # https://github.com/encode/databases/issues/456
- self._db = Db(self.url, timeout=360)
-
- async def setup(self) -> None:
- """Perform async initialization."""
- await self._db.connect()
- self.logger.info("Database connected.")
- await self._migrate()
-
- async def close(self) -> None:
- """Close db connection on exit."""
- self.logger.info("Database disconnected.")
- await self._db.disconnect()
-
- async def get_setting(self, key: str) -> str | None:
- """Get setting from settings table."""
- if db_row := await self.get_row(TABLE_SETTINGS, {"key": key}):
- return db_row["value"]
- return None
-
- async def set_setting(self, key: str, value: str) -> None:
- """Set setting in settings table."""
- if not isinstance(value, str):
- value = str(value)
- return await self.insert(
- TABLE_SETTINGS, {"key": key, "value": value}, allow_replace=True
- )
-
- async def get_rows(
- self,
- table: str,
- match: dict = None,
- order_by: str = None,
- limit: int = 500,
- offset: int = 0,
- ) -> List[Mapping]:
- """Get all rows for given table."""
- sql_query = f"SELECT * FROM {table}"
- if match is not None:
- sql_query += " WHERE " + " AND ".join((f"{x} = :{x}" for x in match))
- if order_by is not None:
- sql_query += f" ORDER BY {order_by}"
- sql_query += f" LIMIT {limit} OFFSET {offset}"
- return await self._db.fetch_all(sql_query, match)
-
- async def get_rows_from_query(
- self,
- query: str,
- params: Optional[dict] = None,
- limit: int = 500,
- offset: int = 0,
- ) -> List[Mapping]:
- """Get all rows for given custom query."""
- query = f"{query} LIMIT {limit} OFFSET {offset}"
- return await self._db.fetch_all(query, params)
-
- async def get_count_from_query(
- self,
- query: str,
- params: Optional[dict] = None,
- ) -> int:
- """Get row count for given custom query."""
- query = f"SELECT count() FROM ({query})"
- if result := await self._db.fetch_one(query, params):
- return result[0]
- return 0
-
- async def search(
- self, table: str, search: str, column: str = "name"
- ) -> List[Mapping]:
- """Search table by column."""
- sql_query = f"SELECT * FROM {table} WHERE {column} LIKE :search"
- params = {"search": f"%{search}%"}
- return await self._db.fetch_all(sql_query, params)
-
- async def get_row(self, table: str, match: Dict[str, Any]) -> Mapping | None:
- """Get single row for given table where column matches keys/values."""
- sql_query = f"SELECT * FROM {table} WHERE "
- sql_query += " AND ".join((f"{x} = :{x}" for x in match))
- return await self._db.fetch_one(sql_query, match)
-
- async def insert(
- self,
- table: str,
- values: Dict[str, Any],
- allow_replace: bool = False,
- ) -> Mapping:
- """Insert data in given table."""
- keys = tuple(values.keys())
- if allow_replace:
- sql_query = f'INSERT OR REPLACE INTO {table}({",".join(keys)})'
- else:
- sql_query = f'INSERT INTO {table}({",".join(keys)})'
- sql_query += f' VALUES ({",".join((f":{x}" for x in keys))})'
- await self.execute(sql_query, values)
- # return inserted/replaced item
- lookup_vals = {
- key: value
- for key, value in values.items()
- if value is not None and value != ""
- }
- return await self.get_row(table, lookup_vals)
-
- async def insert_or_replace(self, table: str, values: Dict[str, Any]) -> Mapping:
- """Insert or replace data in given table."""
- return await self.insert(table=table, values=values, allow_replace=True)
-
- async def update(
- self,
- table: str,
- match: Dict[str, Any],
- values: Dict[str, Any],
- ) -> Mapping:
- """Update record."""
- keys = tuple(values.keys())
- sql_query = f'UPDATE {table} SET {",".join((f"{x}=:{x}" for x in keys))} WHERE '
- sql_query += " AND ".join((f"{x} = :{x}" for x in match))
- await self.execute(sql_query, {**match, **values})
- # return updated item
- return await self.get_row(table, match)
-
- async def delete(
- self, table: str, match: Optional[dict] = None, query: Optional[str] = None
- ) -> None:
- """Delete data in given table."""
- assert not (query and "where" in query.lower())
- sql_query = f"DELETE FROM {table} "
- if match:
- sql_query += " WHERE " + " AND ".join((f"{x} = :{x}" for x in match))
- elif query and "query" not in query.lower():
- sql_query += "WHERE " + query
- elif query:
- sql_query += query
-
- await self.execute(sql_query, match)
-
- async def delete_where_query(self, table: str, query: Optional[str] = None) -> None:
- """Delete data in given table using given where clausule."""
- sql_query = f"DELETE FROM {table} WHERE {query}"
- await self.execute(sql_query)
-
- async def execute(
- self, query: Union[ClauseElement, str], values: dict = None
- ) -> Any:
- """Execute command on the database."""
- return await self._db.execute(query, values)
-
- async def _migrate(self):
- """Perform database migration actions if needed."""
- # always create db tables if they don't exist to prevent errors trying to access them later
- await self.__create_database_tables()
- try:
- if prev_version := await self.get_setting("version"):
- prev_version = int(prev_version)
- else:
- prev_version = 0
- except (KeyError, ValueError):
- prev_version = 0
-
- if prev_version not in (0, SCHEMA_VERSION):
- self.logger.info(
- "Performing database migration from %s to %s",
- prev_version,
- SCHEMA_VERSION,
- )
-
- if prev_version < 18:
- # too many changes, just recreate
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_ARTISTS}")
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_ALBUMS}")
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_TRACKS}")
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_PLAYLISTS}")
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_RADIOS}")
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_CACHE}")
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_THUMBS}")
- # recreate missing tables
- await self.__create_database_tables()
-
- if prev_version == 18:
- # model for provider_mapping completely changed,
- # we just drop the old provider_ids column and add the new provider_mappings column
- # this will require a full resync of all providers including matching but at least
- # the additional metadata is not lost
- await self.execute(
- f"ALTER TABLE {TABLE_ARTISTS} ADD provider_mappings json DEFAULT '[]';"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_ALBUMS} ADD provider_mappings json DEFAULT '[]';"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_TRACKS} ADD provider_mappings json DEFAULT '[]';"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_PLAYLISTS} ADD provider_mappings json DEFAULT '[]';"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_RADIOS} ADD provider_mappings json DEFAULT '[]';"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_ARTISTS} DROP column provider_ids;"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_ALBUMS} DROP column provider_ids;"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_TRACKS} DROP column provider_ids;"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_PLAYLISTS} DROP column provider_ids;"
- )
- await self.execute(
- f"ALTER TABLE {TABLE_RADIOS} DROP column provider_ids;"
- )
- await self.execute(f"DROP TABLE IF EXISTS {TABLE_CACHE}")
- # recreate missing table(s)
- await self.__create_database_tables()
-
- # store current schema version
- await self.set_setting("version", str(SCHEMA_VERSION))
- # compact db
- await self.mass.database.execute("VACUUM")
-
- async def __create_database_tables(self) -> None:
- """Init database tables."""
- await self.execute(
- """CREATE TABLE IF NOT EXISTS settings(
- key TEXT PRIMARY KEY,
- value TEXT
- );"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_TRACK_LOUDNESS}(
- item_id INTEGER NOT NULL,
- provider TEXT NOT NULL,
- loudness REAL,
- UNIQUE(item_id, provider));"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_PLAYLOG}(
- item_id INTEGER NOT NULL,
- provider TEXT NOT NULL,
- timestamp INTEGER DEFAULT 0,
- UNIQUE(item_id, provider));"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_ALBUMS}(
- item_id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL,
- sort_name TEXT NOT NULL,
- sort_artist TEXT,
- album_type TEXT,
- year INTEGER,
- version TEXT,
- in_library BOOLEAN DEFAULT 0,
- upc TEXT,
- musicbrainz_id TEXT,
- artists json,
- metadata json,
- provider_mappings json,
- timestamp INTEGER DEFAULT 0
- );"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_ARTISTS}(
- item_id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL,
- sort_name TEXT NOT NULL,
- musicbrainz_id TEXT,
- in_library BOOLEAN DEFAULT 0,
- metadata json,
- provider_mappings json,
- timestamp INTEGER DEFAULT 0
- );"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_TRACKS}(
- item_id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL,
- sort_name TEXT NOT NULL,
- sort_artist TEXT,
- sort_album TEXT,
- version TEXT,
- duration INTEGER,
- in_library BOOLEAN DEFAULT 0,
- isrc TEXT,
- musicbrainz_id TEXT,
- artists json,
- albums json,
- metadata json,
- provider_mappings json,
- timestamp INTEGER DEFAULT 0
- );"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_PLAYLISTS}(
- item_id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL,
- sort_name TEXT NOT NULL,
- owner TEXT NOT NULL,
- is_editable BOOLEAN NOT NULL,
- in_library BOOLEAN DEFAULT 0,
- metadata json,
- provider_mappings json,
- timestamp INTEGER DEFAULT 0,
- UNIQUE(name, owner)
- );"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_RADIOS}(
- item_id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL UNIQUE,
- sort_name TEXT NOT NULL,
- in_library BOOLEAN DEFAULT 0,
- metadata json,
- provider_mappings json,
- timestamp INTEGER DEFAULT 0
- );"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_CACHE}(
- key TEXT UNIQUE NOT NULL, expires INTEGER NOT NULL, data TEXT, checksum TEXT NULL)"""
- )
- await self.execute(
- f"""CREATE TABLE IF NOT EXISTS {TABLE_THUMBS}(
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- path TEXT NOT NULL,
- size INTEGER DEFAULT 0,
- data BLOB,
- UNIQUE(path, size));"""
- )
- # create indexes
- # TODO: create indexes for the json columns ?
- await self.execute(
- "CREATE INDEX IF NOT EXISTS artists_in_library_idx on artists(in_library);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS albums_in_library_idx on albums(in_library);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS tracks_in_library_idx on tracks(in_library);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS playlists_in_library_idx on playlists(in_library);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS radios_in_library_idx on radios(in_library);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS artists_sort_name_idx on artists(sort_name);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS albums_sort_name_idx on albums(sort_name);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS tracks_sort_name_idx on tracks(sort_name);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS playlists_sort_name_idx on playlists(sort_name);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS radios_sort_name_idx on radios(sort_name);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS artists_musicbrainz_id_idx on artists(musicbrainz_id);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS albums_musicbrainz_id_idx on albums(musicbrainz_id);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS tracks_musicbrainz_id_idx on tracks(musicbrainz_id);"
- )
- await self.execute(
- "CREATE INDEX IF NOT EXISTS tracks_isrc_idx on tracks(isrc);"
- )
- await self.execute("CREATE INDEX IF NOT EXISTS albums_upc_idx on albums(upc);")
+++ /dev/null
-"""Package with Media controllers."""
+++ /dev/null
-"""Manage MediaItems of type Album."""
-from __future__ import annotations
-
-import asyncio
-from random import choice, random
-from typing import TYPE_CHECKING, List, Optional, Union
-
-from music_assistant.constants import VARIOUS_ARTISTS
-from music_assistant.controllers.database import TABLE_ALBUMS, TABLE_TRACKS
-from music_assistant.controllers.media.base import MediaControllerBase
-from music_assistant.helpers.compare import compare_album, loose_compare_strings
-from music_assistant.helpers.json import json_serializer
-from music_assistant.models.enums import EventType, MusicProviderFeature, ProviderType
-from music_assistant.models.errors import (
- MediaNotFoundError,
- UnsupportedFeaturedException,
-)
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- ItemMapping,
- MediaType,
- Track,
-)
-
-if TYPE_CHECKING:
- from music_assistant.models.music_provider import MusicProvider
-
-
-class AlbumsController(MediaControllerBase[Album]):
- """Controller managing MediaItems of type Album."""
-
- db_table = TABLE_ALBUMS
- media_type = MediaType.ALBUM
- item_cls = Album
-
- async def get(self, *args, **kwargs) -> Album:
- """Return (full) details for a single media item."""
- album = await super().get(*args, **kwargs)
- # append full artist details to full album item
- if album.artist:
- album.artist = await self.mass.music.artists.get(
- album.artist.item_id, album.artist.provider
- )
- return album
-
- async def tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> List[Track]:
- """Return album tracks for the given provider album id."""
-
- if not (provider_type == ProviderType.DATABASE or provider_id == "database"):
- # return provider album tracks
- return await self._get_provider_album_tracks(
- item_id, provider_type or provider_id
- )
-
- # db_album requested: get results from first (non-file) provider
- return await self._get_db_album_tracks(item_id)
-
- async def versions(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> List[Album]:
- """Return all versions of an album we can find on all providers."""
- assert provider_type or provider_id, "Provider type or ID must be specified"
- album = await self.get(item_id, provider_type or provider_id)
- # perform a search on all provider(types) to collect all versions/variants
- provider_types = {item.type for item in self.mass.music.providers}
- search_query = f"{album.artist.name} - {album.name}"
- all_versions = {
- prov_item.item_id: prov_item
- for prov_items in await asyncio.gather(
- *[
- self.search(search_query, provider_type)
- for provider_type in provider_types
- ]
- )
- for prov_item in prov_items
- if loose_compare_strings(album.name, prov_item.name)
- }
- # make sure that the 'base' version is included
- for prov_version in album.provider_mappings:
- if prov_version.item_id in all_versions:
- continue
- album_copy = Album.from_dict(album.to_dict())
- album_copy.item_id = prov_version.item_id
- album_copy.provider = prov_version.provider_type
- album_copy.provider_mappings = {prov_version}
- all_versions[prov_version.item_id] = album_copy
-
- # return the aggregated result
- return all_versions.values()
-
- async def add(self, item: Album) -> Album:
- """Add album to local db and return the database item."""
- # grab additional metadata
- await self.mass.metadata.get_album_metadata(item)
- existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
- if existing:
- db_item = await self.update_db_item(existing.item_id, item)
- else:
- db_item = await self.add_db_item(item)
- # also fetch same album on all providers
- await self._match(db_item)
- # return final db_item after all match/metadata actions
- db_item = await self.get_db_item(db_item.item_id)
- # dump album tracks in db
- for prov_mapping in db_item.provider_mappings:
- for track in await self._get_provider_album_tracks(
- prov_mapping.item_id, prov_mapping.provider_id
- ):
- await self.mass.music.tracks.add_db_item(track)
- self.mass.signal_event(
- MassEvent(
- EventType.MEDIA_ITEM_UPDATED
- if existing
- else EventType.MEDIA_ITEM_ADDED,
- db_item.uri,
- db_item,
- )
- )
- return db_item
-
- async def add_db_item(self, item: Album, overwrite_existing: bool = False) -> Album:
- """Add a new record to the database."""
- assert item.provider_mappings, f"Album {item.name} is missing provider id(s)"
- assert item.artist, f"Album {item.name} is missing artist"
- async with self._db_add_lock:
- cur_item = None
- # always try to grab existing item by musicbrainz_id/upc
- if item.musicbrainz_id:
- match = {"musicbrainz_id": item.musicbrainz_id}
- cur_item = await self.mass.database.get_row(self.db_table, match)
- if not cur_item and item.upc:
- match = {"upc": item.upc}
- cur_item = await self.mass.database.get_row(self.db_table, match)
- if not cur_item:
- # fallback to search and match
- for row in await self.mass.database.search(self.db_table, item.name):
- row_album = Album.from_db_row(row)
- if compare_album(row_album, item):
- cur_item = row_album
- break
- if cur_item:
- # update existing
- return await self.update_db_item(
- cur_item.item_id, item, overwrite=overwrite_existing
- )
-
- # insert new item
- album_artists = await self._get_album_artists(item, cur_item)
- if album_artists:
- sort_artist = album_artists[0].sort_name
- else:
- sort_artist = ""
- new_item = await self.mass.database.insert(
- self.db_table,
- {
- **item.to_db_row(),
- "artists": json_serializer(album_artists) or None,
- "sort_artist": sort_artist,
- },
- )
- item_id = new_item["item_id"]
- self.logger.debug("added %s to database", item.name)
- # return created object
- return await self.get_db_item(item_id)
-
- async def update_db_item(
- self,
- item_id: int,
- item: Album,
- overwrite: bool = False,
- ) -> Album:
- """Update Album record in the database."""
- assert item.provider_mappings, f"Album {item.name} is missing provider id(s)"
- assert item.artist, f"Album {item.name} is missing artist"
- cur_item = await self.get_db_item(item_id)
-
- if overwrite:
- metadata = item.metadata
- metadata.last_refresh = None
- provider_mappings = item.provider_mappings
- album_artists = await self._get_album_artists(item, overwrite=True)
- else:
- metadata = cur_item.metadata.update(item.metadata, item.provider.is_file())
- provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
- album_artists = await self._get_album_artists(item, cur_item)
-
- if item.album_type != AlbumType.UNKNOWN:
- album_type = item.album_type
- else:
- album_type = cur_item.album_type
-
- if album_artists:
- sort_artist = album_artists[0].sort_name
- else:
- sort_artist = ""
-
- await self.mass.database.update(
- self.db_table,
- {"item_id": item_id},
- {
- "name": item.name if overwrite else cur_item.name,
- "sort_name": item.sort_name if overwrite else cur_item.sort_name,
- "sort_artist": sort_artist,
- "version": item.version if overwrite else cur_item.version,
- "year": item.year or cur_item.year,
- "upc": item.upc or cur_item.upc,
- "album_type": album_type.value,
- "artists": json_serializer(album_artists) or None,
- "metadata": json_serializer(metadata),
- "provider_mappings": json_serializer(provider_mappings),
- "musicbrainz_id": item.musicbrainz_id or cur_item.musicbrainz_id,
- },
- )
- self.logger.debug("updated %s in database: %s", item.name, item_id)
- return await self.get_db_item(item_id)
-
- async def delete_db_item(self, item_id: int, recursive: bool = False) -> None:
- """Delete record from the database."""
- # check album tracks
- db_rows = await self.mass.database.get_rows_from_query(
- f"SELECT item_id FROM {TABLE_TRACKS} WHERE albums LIKE '%\"{item_id}\"%'",
- limit=5000,
- )
- assert not (db_rows and not recursive), "Tracks attached to album"
- for db_row in db_rows:
- try:
- await self.mass.music.albums.delete_db_item(
- db_row["item_id"], recursive
- )
- except MediaNotFoundError:
- pass
-
- # delete the album itself from db
- await super().delete_db_item(item_id)
-
- async def _get_provider_album_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> List[Track]:
- """Return album tracks for the given provider album id."""
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if not prov:
- return []
- full_album = await self.get_provider_item(item_id, provider_id or provider_type)
- # prefer cache items (if any)
- cache_key = f"{prov.type.value}.albumtracks.{item_id}"
- cache_checksum = full_album.metadata.checksum
- if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
- return [Track.from_dict(x) for x in cache]
- # no items in cache - get listing from provider
- items = []
- for track in await prov.get_album_tracks(item_id):
- # make sure that the (full) album is stored on the tracks
- track.album = full_album
- if full_album.metadata.images:
- track.metadata.images = full_album.metadata.images
- items.append(track)
- # store (serializable items) in cache
- self.mass.create_task(
- self.mass.cache.set(
- cache_key, [x.to_dict() for x in items], checksum=cache_checksum
- )
- )
- return items
-
- async def _get_provider_dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ):
- """Generate a dynamic list of tracks based on the album content."""
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if (
- not prov
- or MusicProviderFeature.SIMILAR_TRACKS not in prov.supported_features
- ):
- return []
- album_tracks = await self._get_provider_album_tracks(
- item_id=item_id, provider_type=provider_type, provider_id=provider_id
- )
- # Grab a random track from the album that we use to obtain similar tracks for
- track = choice(album_tracks)
- # Calculate no of songs to grab from each list at a 10/90 ratio
- total_no_of_tracks = limit + limit % 2
- no_of_album_tracks = int(total_no_of_tracks * 10 / 100)
- no_of_similar_tracks = int(total_no_of_tracks * 90 / 100)
- # Grab similar tracks from the music provider
- similar_tracks = await prov.get_similar_tracks(
- prov_track_id=track.item_id, limit=no_of_similar_tracks
- )
- # Merge album content with similar tracks
- dynamic_playlist = [
- *sorted(album_tracks, key=lambda n: random())[:no_of_album_tracks],
- *sorted(similar_tracks, key=lambda n: random())[:no_of_similar_tracks],
- ]
- return sorted(dynamic_playlist, key=lambda n: random())
-
- async def _get_dynamic_tracks(self, media_item: Album, limit=25) -> List[Track]:
- """Get dynamic list of tracks for given item, fallback/default implementation."""
- # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
- raise UnsupportedFeaturedException(
- "No Music Provider found that supports requesting similar tracks."
- )
-
- async def _get_db_album_tracks(
- self,
- item_id: str,
- ) -> List[Track]:
- """Return in-database album tracks for the given database album."""
- db_album = await self.get_db_item(item_id)
- # simply grab all tracks in the db that are linked to this album
- # TODO: adjust to json query instead of text search?
- query = f"SELECT * FROM tracks WHERE albums LIKE '%\"{item_id}\"%'"
- result = []
- for track in await self.mass.music.tracks.get_db_items_by_query(query):
- if album_mapping := next(
- (x for x in track.albums if x.item_id == db_album.item_id), None
- ):
- # make sure that the full album is set on the track and prefer the album's images
- track.album = db_album
- if db_album.metadata.images:
- track.metadata.images = db_album.metadata.images
- # apply the disc and track number from the mapping
- track.disc_number = album_mapping.disc_number
- track.track_number = album_mapping.track_number
- result.append(track)
- return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0))
-
- async def _match(self, db_album: Album) -> None:
- """
- Try to find matching album on all providers for the provided (database) album.
-
- This is used to link objects of different providers/qualities together.
- """
- if db_album.provider != ProviderType.DATABASE:
- return # Matching only supported for database items
-
- async def find_prov_match(provider_type: MusicProvider):
- self.logger.debug(
- "Trying to match album %s on provider %s", db_album.name, provider.name
- )
- match_found = False
- for search_str in (
- db_album.name,
- f"{db_album.artist.name} - {db_album.name}",
- f"{db_album.artist.name} {db_album.name}",
- ):
- if match_found:
- break
- search_result = await self.search(search_str, provider.id)
- for search_result_item in search_result:
- if not search_result_item.available:
- continue
- if not compare_album(search_result_item, db_album):
- continue
- # we must fetch the full album version, search results are simplified objects
- prov_album = await self.get_provider_item(
- search_result_item.item_id, search_result_item.provider
- )
- if compare_album(prov_album, db_album):
- # 100% match, we can simply update the db with additional provider ids
- await self.update_db_item(db_album.item_id, prov_album)
- match_found = True
- return match_found
-
- # try to find match on all providers
- cur_provider_types = {x.provider_type for x in db_album.provider_mappings}
- for provider in self.mass.music.providers:
- if provider.type in cur_provider_types:
- continue
- if MusicProviderFeature.SEARCH not in provider.supported_features:
- continue
- if await find_prov_match(provider):
- cur_provider_types.add(provider.type)
- else:
- self.logger.debug(
- "Could not find match for Album %s on provider %s",
- db_album.name,
- provider.name,
- )
-
- async def _get_album_artists(
- self,
- db_album: Album,
- updated_album: Optional[Album] = None,
- overwrite: bool = False,
- ) -> List[ItemMapping]:
- """Extract (database) album artist(s) as ItemMapping."""
- album_artists = set()
- for album in (updated_album, db_album):
- if not album:
- continue
- for artist in album.artists:
- album_artists.add(await self._get_artist_mapping(artist, overwrite))
- # use intermediate set to prevent duplicates
- # filter various artists if multiple artists
- if len(album_artists) > 1:
- album_artists = {x for x in album_artists if (x.name != VARIOUS_ARTISTS)}
- return list(album_artists)
-
- async def _get_artist_mapping(
- self, artist: Union[Artist, ItemMapping], overwrite: bool = False
- ) -> ItemMapping:
- """Extract (database) track artist as ItemMapping."""
- if overwrite:
- artist = await self.mass.music.artists.add_db_item(
- artist, overwrite_existing=True
- )
- if artist.provider == ProviderType.DATABASE:
- if isinstance(artist, ItemMapping):
- return artist
- return ItemMapping.from_item(artist)
-
- if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
- artist.item_id, provider_type=artist.provider
- ):
- return ItemMapping.from_item(db_artist)
-
- db_artist = await self.mass.music.artists.add_db_item(artist)
- return ItemMapping.from_item(db_artist)
+++ /dev/null
-"""Manage MediaItems of type Artist."""
-
-import asyncio
-import itertools
-from random import choice, random
-from time import time
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
-
-from music_assistant.constants import VARIOUS_ARTISTS, VARIOUS_ARTISTS_ID
-from music_assistant.controllers.database import (
- TABLE_ALBUMS,
- TABLE_ARTISTS,
- TABLE_TRACKS,
-)
-from music_assistant.controllers.media.base import MediaControllerBase
-from music_assistant.helpers.compare import compare_strings
-from music_assistant.helpers.json import json_serializer
-from music_assistant.models.enums import EventType, MusicProviderFeature, ProviderType
-from music_assistant.models.errors import (
- MediaNotFoundError,
- UnsupportedFeaturedException,
-)
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- ItemMapping,
- MediaType,
- PagedItems,
- Track,
-)
-
-if TYPE_CHECKING:
- from music_assistant.models.music_provider import MusicProvider
-
-
-class ArtistsController(MediaControllerBase[Artist]):
- """Controller managing MediaItems of type Artist."""
-
- db_table = TABLE_ARTISTS
- media_type = MediaType.ARTIST
- item_cls = Artist
-
- async def album_artists(
- self,
- in_library: Optional[bool] = None,
- search: Optional[str] = None,
- limit: int = 500,
- offset: int = 0,
- order_by: str = "sort_name",
- ) -> PagedItems:
- """Get in-database album artists."""
- return await self.db_items(
- in_library=in_library,
- search=search,
- limit=limit,
- offset=offset,
- order_by=order_by,
- query_parts=[
- "artists.sort_name in (select albums.sort_artist from albums)"
- ],
- )
-
- async def tracks(
- self,
- item_id: Optional[str] = None,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- artist: Optional[Artist] = None,
- ) -> List[Track]:
- """Return top tracks for an artist."""
- if not artist:
- artist = await self.get(item_id, provider_type, provider_id)
- # get results from all providers
- coros = [
- self.get_provider_artist_toptracks(
- prov_mapping.item_id,
- provider_type=prov_mapping.provider_type,
- provider_id=prov_mapping.provider_id,
- cache_checksum=artist.metadata.checksum,
- )
- for prov_mapping in artist.provider_mappings
- ]
- tracks = itertools.chain.from_iterable(await asyncio.gather(*coros))
- # merge duplicates using a dict
- final_items: Dict[str, Track] = {}
- for track in tracks:
- key = f".{track.name}.{track.version}"
- if key in final_items:
- final_items[key].provider_mappings.update(track.provider_mappings)
- else:
- final_items[key] = track
- return list(final_items.values())
-
- async def albums(
- self,
- item_id: Optional[str] = None,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- artist: Optional[Artist] = None,
- ) -> List[Album]:
- """Return (all/most popular) albums for an artist."""
- if not artist:
- artist = await self.get(item_id, provider_type or provider_id)
- # get results from all providers
- coros = [
- self.get_provider_artist_albums(
- item.item_id,
- item.provider_type,
- cache_checksum=artist.metadata.checksum,
- )
- for item in artist.provider_mappings
- ]
- albums = itertools.chain.from_iterable(await asyncio.gather(*coros))
- # merge duplicates using a dict
- final_items: Dict[str, Album] = {}
- for album in albums:
- key = f".{album.name}.{album.version}"
- if key in final_items:
- final_items[key].provider_mappings.update(album.provider_mappings)
- else:
- final_items[key] = album
- if album.in_library:
- final_items[key].in_library = True
- return list(final_items.values())
-
- async def add(self, item: Artist) -> Artist:
- """Add artist to local db and return the database item."""
- # grab musicbrainz id and additional metadata
- await self.mass.metadata.get_artist_metadata(item)
- existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
- if existing:
- db_item = await self.update_db_item(existing.item_id, item)
- else:
- db_item = await self.add_db_item(item)
- # also fetch same artist on all providers
- await self.match_artist(db_item)
- # return final db_item after all match/metadata actions
- db_item = await self.get_db_item(db_item.item_id)
- self.mass.signal_event(
- MassEvent(
- EventType.MEDIA_ITEM_UPDATED
- if existing
- else EventType.MEDIA_ITEM_ADDED,
- db_item.uri,
- db_item,
- )
- )
- return db_item
-
- async def match_artist(self, db_artist: Artist):
- """
- Try to find matching artists on all providers for the provided (database) item_id.
-
- This is used to link objects of different providers together.
- """
- assert (
- db_artist.provider == ProviderType.DATABASE
- ), "Matching only supported for database items!"
- cur_provider_types = {x.provider_type for x in db_artist.provider_mappings}
- for provider in self.mass.music.providers:
- if provider.type in cur_provider_types:
- continue
- if MusicProviderFeature.SEARCH not in provider.supported_features:
- continue
- if await self._match(db_artist, provider):
- cur_provider_types.add(provider.type)
- else:
- self.logger.debug(
- "Could not find match for Artist %s on provider %s",
- db_artist.name,
- provider.name,
- )
-
- async def get_provider_artist_toptracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- cache_checksum: Any = None,
- ) -> List[Track]:
- """Return top tracks for an artist on given provider."""
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if not prov:
- return []
- # prefer cache items (if any)
- cache_key = f"{prov.type.value}.artist_toptracks.{item_id}"
- if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
- return [Track.from_dict(x) for x in cache]
- # no items in cache - get listing from provider
- if MusicProviderFeature.ARTIST_TOPTRACKS in prov.supported_features:
- items = await prov.get_artist_toptracks(item_id)
- else:
- # fallback implementation using the db
- if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
- item_id, provider_type=provider_type, provider_id=provider_id
- ):
- prov_id = provider_id or provider_type.value
- # TODO: adjust to json query instead of text search?
- query = f"SELECT * FROM tracks WHERE artists LIKE '%\"{db_artist.item_id}\"%'"
- query += f" AND provider_mappings LIKE '%\"{prov_id}\"%'"
- items = await self.mass.music.tracks.get_db_items_by_query(query)
- # store (serializable items) in cache
- self.mass.create_task(
- self.mass.cache.set(
- cache_key, [x.to_dict() for x in items], checksum=cache_checksum
- )
- )
- return items
-
- async def get_provider_artist_albums(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- cache_checksum: Any = None,
- ) -> List[Album]:
- """Return albums for an artist on given provider."""
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if not prov:
- return []
- # prefer cache items (if any)
- cache_key = f"{prov.type.value}.artist_albums.{item_id}"
- if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
- return [Album.from_dict(x) for x in cache]
- # no items in cache - get listing from provider
- if MusicProviderFeature.ARTIST_ALBUMS in prov.supported_features:
- items = await prov.get_artist_albums(item_id)
- else:
- # fallback implementation using the db
- if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
- item_id, provider_type=provider_type, provider_id=provider_id
- ):
- prov_id = provider_id or provider_type.value
- # TODO: adjust to json query instead of text search?
- query = f"SELECT * FROM albums WHERE artists LIKE '%\"{db_artist.item_id}\"%'"
- query += f" AND provider_mappings LIKE '%\"{prov_id}\"%'"
- items = await self.mass.music.albums.get_db_items_by_query(query)
- else:
- # edge case
- items = []
- # store (serializable items) in cache
- self.mass.create_task(
- self.mass.cache.set(
- cache_key, [x.to_dict() for x in items], checksum=cache_checksum
- )
- )
- return items
-
- async def add_db_item(
- self, item: Artist, overwrite_existing: bool = False
- ) -> Artist:
- """Add a new item record to the database."""
- assert isinstance(item, Artist), "Not a full Artist object"
- assert item.provider_mappings, "Artist is missing provider id(s)"
- # enforce various artists name + id
- if compare_strings(item.name, VARIOUS_ARTISTS):
- item.musicbrainz_id = VARIOUS_ARTISTS_ID
- if item.musicbrainz_id == VARIOUS_ARTISTS_ID:
- item.name = VARIOUS_ARTISTS
-
- async with self._db_add_lock:
- # always try to grab existing item by musicbrainz_id
- cur_item = None
- if item.musicbrainz_id:
- match = {"musicbrainz_id": item.musicbrainz_id}
- cur_item = await self.mass.database.get_row(self.db_table, match)
- if not cur_item:
- # fallback to exact name match
- # NOTE: we match an artist by name which could theoretically lead to collisions
- # but the chance is so small it is not worth the additional overhead of grabbing
- # the musicbrainz id upfront
- match = {"sort_name": item.sort_name}
- for row in await self.mass.database.get_rows(self.db_table, match):
- row_artist = Artist.from_db_row(row)
- if row_artist.sort_name == item.sort_name:
- cur_item = row_artist
- break
- if cur_item:
- # update existing
- return await self.update_db_item(
- cur_item.item_id, item, overwrite=overwrite_existing
- )
-
- # insert item
- if item.in_library and not item.timestamp:
- item.timestamp = int(time())
- new_item = await self.mass.database.insert(self.db_table, item.to_db_row())
- item_id = new_item["item_id"]
- self.logger.debug("added %s to database", item.name)
- # return created object
- return await self.get_db_item(item_id)
-
- async def update_db_item(
- self,
- item_id: int,
- item: Artist,
- overwrite: bool = False,
- ) -> Artist:
- """Update Artist record in the database."""
- cur_item = await self.get_db_item(item_id)
- if overwrite:
- metadata = item.metadata
- provider_mappings = item.provider_mappings
- else:
- metadata = cur_item.metadata.update(item.metadata, item.provider.is_file())
- provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
-
- # enforce various artists name + id
- if compare_strings(item.name, VARIOUS_ARTISTS):
- item.musicbrainz_id = VARIOUS_ARTISTS_ID
- if item.musicbrainz_id == VARIOUS_ARTISTS_ID:
- item.name = VARIOUS_ARTISTS
-
- await self.mass.database.update(
- self.db_table,
- {"item_id": item_id},
- {
- "name": item.name if overwrite else cur_item.name,
- "sort_name": item.sort_name if overwrite else cur_item.sort_name,
- "musicbrainz_id": item.musicbrainz_id or cur_item.musicbrainz_id,
- "metadata": json_serializer(metadata),
- "provider_mappings": json_serializer(provider_mappings),
- },
- )
- self.logger.debug("updated %s in database: %s", item.name, item_id)
- return await self.get_db_item(item_id)
-
- async def delete_db_item(self, item_id: int, recursive: bool = False) -> None:
- """Delete record from the database."""
- # check artist albums
- db_rows = await self.mass.database.get_rows_from_query(
- f"SELECT item_id FROM {TABLE_ALBUMS} WHERE artists LIKE '%\"{item_id}\"%'",
- limit=5000,
- )
- assert not (db_rows and not recursive), "Albums attached to artist"
- for db_row in db_rows:
- try:
- await self.mass.music.albums.delete_db_item(
- db_row["item_id"], recursive
- )
- except MediaNotFoundError:
- pass
-
- # check artist tracks
- db_rows = await self.mass.database.get_rows_from_query(
- f"SELECT item_id FROM {TABLE_TRACKS} WHERE artists LIKE '%\"{item_id}\"%'",
- limit=5000,
- )
- assert not (db_rows and not recursive), "Tracks attached to artist"
- for db_row in db_rows:
- try:
- await self.mass.music.albums.delete_db_item(
- db_row["item_id"], recursive
- )
- except MediaNotFoundError:
- pass
-
- # delete the artist itself from db
- await super().delete_db_item(item_id)
-
- async def _get_provider_dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ):
- """Generate a dynamic list of tracks based on the artist's top tracks."""
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if (
- not prov
- or MusicProviderFeature.SIMILAR_TRACKS not in prov.supported_features
- ):
- return []
- top_tracks = await self.get_provider_artist_toptracks(
- item_id=item_id, provider_type=provider_type, provider_id=provider_id
- )
- # Grab a random track from the album that we use to obtain similar tracks for
- track = choice(top_tracks)
- # Calculate no of songs to grab from each list at a 10/90 ratio
- total_no_of_tracks = limit + limit % 2
- no_of_artist_tracks = int(total_no_of_tracks * 10 / 100)
- no_of_similar_tracks = int(total_no_of_tracks * 90 / 100)
- # Grab similar tracks from the music provider
- similar_tracks = await prov.get_similar_tracks(
- prov_track_id=track.item_id, limit=no_of_similar_tracks
- )
- # Merge album content with similar tracks
- dynamic_playlist = [
- *sorted(top_tracks, key=lambda n: random())[:no_of_artist_tracks],
- *sorted(similar_tracks, key=lambda n: random())[:no_of_similar_tracks],
- ]
- return sorted(dynamic_playlist, key=lambda n: random())
-
- async def _get_dynamic_tracks(
- self, media_item: Artist, limit: int = 25
- ) -> List[Track]:
- """Get dynamic list of tracks for given item, fallback/default implementation."""
- # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
- raise UnsupportedFeaturedException(
- "No Music Provider found that supports requesting similar tracks."
- )
-
- async def _match(self, db_artist: Artist, provider: "MusicProvider") -> bool:
- """Try to find matching artists on given provider for the provided (database) artist."""
- self.logger.debug(
- "Trying to match artist %s on provider %s", db_artist.name, provider.name
- )
- # try to get a match with some reference tracks of this artist
- for ref_track in await self.tracks(
- db_artist.item_id, db_artist.provider, artist=db_artist
- ):
- # make sure we have a full track
- if isinstance(ref_track.album, ItemMapping):
- ref_track = await self.mass.music.tracks.get(
- ref_track.item_id, ref_track.provider
- )
- for search_str in (
- f"{db_artist.name} - {ref_track.name}",
- f"{db_artist.name} {ref_track.name}",
- ref_track.name,
- ):
- search_results = await self.mass.music.tracks.search(
- search_str, provider.type
- )
- for search_result_item in search_results:
- if search_result_item.sort_name != ref_track.sort_name:
- continue
- # get matching artist from track
- for search_item_artist in search_result_item.artists:
- if search_item_artist.sort_name != db_artist.sort_name:
- continue
- # 100% album match
- # get full artist details so we have all metadata
- prov_artist = await self.get_provider_item(
- search_item_artist.item_id, search_item_artist.provider
- )
- await self.update_db_item(db_artist.item_id, prov_artist)
- return True
- # try to get a match with some reference albums of this artist
- artist_albums = await self.albums(
- db_artist.item_id, db_artist.provider, artist=db_artist
- )
- for ref_album in artist_albums:
- if ref_album.album_type == AlbumType.COMPILATION:
- continue
- if ref_album.artist is None:
- continue
- for search_str in (
- ref_album.name,
- f"{db_artist.name} - {ref_album.name}",
- f"{db_artist.name} {ref_album.name}",
- ):
- search_result = await self.mass.music.albums.search(
- search_str, provider.type
- )
- for search_result_item in search_result:
- if search_result_item.artist is None:
- continue
- if search_result_item.sort_name != ref_album.sort_name:
- continue
- # artist must match 100%
- if (
- search_result_item.artist.sort_name
- != ref_album.artist.sort_name
- ):
- continue
- # 100% match
- # get full artist details so we have all metadata
- prov_artist = await self.get_provider_item(
- search_result_item.artist.item_id,
- search_result_item.artist.provider,
- )
- await self.update_db_item(db_artist.item_id, prov_artist)
- return True
- return False
+++ /dev/null
-"""Base (ABC) MediaType specific controller."""
-from __future__ import annotations
-
-import asyncio
-from abc import ABCMeta, abstractmethod
-from time import time
-from typing import (
- TYPE_CHECKING,
- AsyncGenerator,
- Generic,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
-)
-
-from music_assistant.helpers.json import json_serializer
-from music_assistant.models.enums import (
- EventType,
- MediaType,
- MusicProviderFeature,
- ProviderType,
-)
-from music_assistant.models.errors import MediaNotFoundError
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import (
- MediaItemType,
- PagedItems,
- Track,
- media_from_dict,
-)
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-ItemCls = TypeVar("ItemCls", bound="MediaControllerBase")
-
-REFRESH_INTERVAL = 60 * 60 * 24 * 30
-
-
-class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta):
- """Base model for controller managing a MediaType."""
-
- media_type: MediaType
- item_cls: MediaItemType
- db_table: str
-
- def __init__(self, mass: MusicAssistant):
- """Initialize class."""
- self.mass = mass
- self.logger = mass.logger.getChild(f"music.{self.media_type.value}")
- self._db_add_lock = asyncio.Lock()
-
- @abstractmethod
- async def add(self, item: ItemCls) -> ItemCls:
- """Add item to local db and return the database item."""
- raise NotImplementedError
-
- @abstractmethod
- async def add_db_item(
- self, item: ItemCls, overwrite_existing: bool = False
- ) -> ItemCls:
- """Add a new record for this mediatype to the database."""
- raise NotImplementedError
-
- @abstractmethod
- async def update_db_item(
- self,
- item_id: int,
- item: ItemCls,
- overwrite: bool = False,
- ) -> ItemCls:
- """Update record in the database, merging data."""
- raise NotImplementedError
-
- async def db_items(
- self,
- in_library: Optional[bool] = None,
- search: Optional[str] = None,
- limit: int = 500,
- offset: int = 0,
- order_by: str = "sort_name",
- query_parts: Optional[List[str]] = None,
- ) -> PagedItems:
- """Get in-database items."""
- sql_query = f"SELECT * FROM {self.db_table}"
- params = {}
- query_parts = query_parts or []
- if search:
- params["search"] = f"%{search}%"
- if self.media_type in (MediaType.ALBUM, MediaType.TRACK):
- query_parts.append("(name LIKE :search or artists LIKE :search)")
- else:
- query_parts.append("name LIKE :search")
- if in_library is not None:
- query_parts.append("in_library = :in_library")
- params["in_library"] = in_library
- if query_parts:
- sql_query += " WHERE " + " AND ".join(query_parts)
- sql_query += f" ORDER BY {order_by}"
- items = await self.get_db_items_by_query(
- sql_query, params, limit=limit, offset=offset
- )
- count = len(items)
- if 0 < count < limit:
- total = offset + count
- else:
- total = await self.mass.database.get_count_from_query(sql_query, params)
- return PagedItems(items, count, limit, offset, total)
-
- async def iter_db_items(
- self,
- in_library: Optional[bool] = None,
- search: Optional[str] = None,
- order_by: str = "sort_name",
- ) -> AsyncGenerator[ItemCls, None]:
- """Iterate all in-database items."""
- limit: int = 500
- offset: int = 0
- while True:
- next_items = await self.db_items(
- in_library=in_library,
- search=search,
- limit=limit,
- offset=offset,
- order_by=order_by,
- )
- for item in next_items.items:
- yield item
- if next_items.count < limit:
- break
- offset += limit
-
- async def get(
- self,
- provider_item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- force_refresh: bool = False,
- lazy: bool = True,
- details: ItemCls = None,
- ) -> ItemCls:
- """Return (full) details for a single media item."""
- assert (
- provider_type or provider_id
- ), "provider_type or provider_id must be supplied"
- if isinstance(provider_type, str):
- provider_type = ProviderType(provider_type)
- db_item = await self.get_db_item_by_prov_id(
- provider_item_id=provider_item_id,
- provider_type=provider_type,
- provider_id=provider_id,
- )
- if db_item and (time() - db_item.last_refresh) > REFRESH_INTERVAL:
- # it's been too long since the full metadata was last retrieved (or never at all)
- force_refresh = True
- if db_item and force_refresh:
- # get (first) provider item id belonging to this db item
- provider_id, provider_item_id = await self.get_provider_mapping(db_item)
- elif db_item:
- # we have a db item and no refreshing is needed, return the results!
- return db_item
- if not details and provider_id:
- # no details provider nor in db, fetch them from the provider
- details = await self.get_provider_item(provider_item_id, provider_id)
- if not details and provider_type:
- # check providers for given provider type one by one
- for prov in self.mass.music.providers:
- if not prov.available:
- continue
- if prov.type == provider_type:
- try:
- details = await self.get_provider_item(
- provider_item_id, prov.id
- )
- except MediaNotFoundError:
- pass
- else:
- break
- if not details:
- # we couldn't get a match from any of the providers, raise error
- raise MediaNotFoundError(
- f"Item not found: {provider_type.value or id}/{provider_item_id}"
- )
- # create job to add the item to the db, including matching metadata etc. takes some time
- # in 99% of the cases we just return lazy because we want the details as fast as possible
- # only if we really need to wait for the result (e.g. to prevent race conditions), we
- # can set lazy to false and we await to job to complete.
- add_job = self.mass.add_job(
- self.add(details),
- f"Add {details.uri} to database",
- )
- if not lazy:
- await add_job.wait()
- return add_job.result
-
- return details
-
- async def search(
- self,
- search_query: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ) -> List[ItemCls]:
- """Search database or provider with given query."""
- # create safe search string
- search_query = search_query.replace("/", " ").replace("'", "")
- if provider_type == ProviderType.DATABASE or provider_id == "database":
- return [
- self.item_cls.from_db_row(db_row)
- for db_row in await self.mass.database.search(
- self.db_table, search_query
- )
- ]
-
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if not prov or MusicProviderFeature.SEARCH not in prov.supported_features:
- return []
- if not prov.library_supported(self.media_type):
- # assume library supported also means that this mediatype is supported
- return []
-
- # prefer cache items (if any)
- cache_key = (
- f"{prov.type.value}.search.{self.media_type.value}.{search_query}.{limit}"
- )
- if cache := await self.mass.cache.get(cache_key):
- return [media_from_dict(x) for x in cache]
- # no items in cache - get listing from provider
- items = await prov.search(
- search_query,
- [self.media_type],
- limit,
- )
- # store (serializable items) in cache
- if not prov.type.is_file(): # do not cache filesystem results
- self.mass.create_task(
- self.mass.cache.set(
- cache_key, [x.to_dict() for x in items], expiration=86400 * 7
- )
- )
- return items
-
- async def add_to_library(
- self,
- provider_item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> None:
- """Add an item to the library."""
- prov_item = await self.get_db_item_by_prov_id(
- provider_item_id, provider_type=provider_type, provider_id=provider_id
- )
- if prov_item is None:
- prov_item = await self.get_provider_item(
- provider_item_id, provider_id or provider_type
- )
- if prov_item.in_library is True:
- return
- # mark as favorite/library item on provider(s)
- for prov_mapping in prov_item.provider_mappings:
- if prov := self.mass.music.get_provider(prov_mapping.provider_id):
- if not prov.library_edit_supported(self.media_type):
- continue
- await prov.library_add(provider_id.item_id, self.media_type)
- # mark as library item in internal db if db item
- if prov_item.provider == ProviderType.DATABASE:
- if not prov_item.in_library:
- prov_item.in_library = True
- await self.set_db_library(prov_item.item_id, True)
-
- async def remove_from_library(
- self,
- provider_item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> None:
- """Remove item from the library."""
- prov_item = await self.get_db_item_by_prov_id(
- provider_item_id, provider_type=provider_type, provider_id=provider_id
- )
- if prov_item is None:
- prov_item = await self.get_provider_item(
- provider_item_id, provider_id or provider_type
- )
- if prov_item.in_library is False:
- return
- # unmark as favorite/library item on provider(s)
- for prov_mapping in prov_item.provider_mappings:
- if prov := self.mass.music.get_provider(prov_mapping.provider_id):
- if not prov.library_edit_supported(self.media_type):
- continue
- await prov.library_remove(prov_mapping.item_id, self.media_type)
- # unmark as library item in internal db if db item
- if prov_item.provider == ProviderType.DATABASE:
- prov_item.in_library = False
- await self.set_db_library(prov_item.item_id, False)
-
- async def get_provider_mapping(self, item: ItemCls) -> Tuple[str, str]:
- """Return (first) provider and item id."""
- if item.provider == ProviderType.DATABASE:
- # make sure we have a full object
- item = await self.get_db_item(item.item_id)
- for prefer_file in (True, False):
- for prov_mapping in item.provider_mappings:
- # returns the first provider that is available
- if not prov_mapping.available:
- continue
- if prefer_file and not prov_mapping.provider_type.is_file():
- continue
- if self.mass.music.get_provider(prov_mapping.provider_id):
- return (prov_mapping.provider_id, prov_mapping.item_id)
- return None, None
-
- async def get_db_items_by_query(
- self,
- custom_query: Optional[str] = None,
- query_params: Optional[dict] = None,
- limit: int = 500,
- offset: int = 0,
- ) -> List[ItemCls]:
- """Fetch MediaItem records from database given a custom query."""
- return [
- self.item_cls.from_db_row(db_row)
- for db_row in await self.mass.database.get_rows_from_query(
- custom_query, query_params, limit=limit, offset=offset
- )
- ]
-
- async def get_db_item(self, item_id: Union[int, str]) -> ItemCls:
- """Get record by id."""
- match = {"item_id": int(item_id)}
- if db_row := await self.mass.database.get_row(self.db_table, match):
- return self.item_cls.from_db_row(db_row)
- raise MediaNotFoundError(f"Album not found in database: {item_id}")
-
- async def get_db_item_by_prov_id(
- self,
- provider_item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> ItemCls | None:
- """Get the database item for the given provider_id."""
- assert (
- provider_type or provider_id
- ), "provider_type or provider_id must be supplied"
- if isinstance(provider_type, str):
- provider_type = ProviderType(provider_type)
- if provider_type == ProviderType.DATABASE or provider_id == "database":
- return await self.get_db_item(provider_item_id)
- for item in await self.get_db_items_by_prov_id(
- provider_type=provider_type,
- provider_id=provider_id,
- provider_item_ids=(provider_item_id,),
- ):
- return item
- return None
-
- async def get_db_items_by_prov_id(
- self,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- provider_item_ids: Optional[Tuple[str]] = None,
- limit: int = 500,
- offset: int = 0,
- ) -> List[ItemCls]:
- """Fetch all records from database for given provider."""
- assert (
- provider_type or provider_id
- ), "provider_type or provider_id must be supplied"
- if isinstance(provider_type, str):
- provider_type = ProviderType(provider_type)
- if provider_type == ProviderType.DATABASE or provider_id == "database":
- return await self.get_db_items_by_query(limit=limit, offset=offset)
-
- query = f"SELECT * FROM {self.db_table}, json_each(provider_mappings)"
- if provider_id is not None:
- query += f" WHERE json_extract(json_each.value, '$.provider_id') = '{provider_id}'"
- elif provider_type is not None:
- query += f" WHERE json_extract(json_each.value, '$.provider_type') = '{provider_type.value}'"
- if provider_item_ids is not None:
- prov_ids = str(tuple(provider_item_ids))
- if prov_ids.endswith(",)"):
- prov_ids = prov_ids.replace(",)", ")")
- query += f" AND json_extract(json_each.value, '$.item_id') in {prov_ids}"
-
- return await self.get_db_items_by_query(query, limit=limit, offset=offset)
-
- async def set_db_library(self, item_id: int, in_library: bool) -> None:
- """Set the in-library bool on a database item."""
- match = {"item_id": item_id}
- timestamp = int(time()) if in_library else 0
- await self.mass.database.update(
- self.db_table, match, {"in_library": in_library, "timestamp": timestamp}
- )
- db_item = await self.get_db_item(item_id)
- self.mass.signal_event(
- MassEvent(EventType.MEDIA_ITEM_UPDATED, db_item.uri, db_item)
- )
-
- async def get_provider_item(
- self,
- item_id: str,
- provider_id_or_type: Union[str, ProviderType],
- ) -> ItemCls:
- """Return item details for the given provider item id."""
- if provider_id_or_type in ("database", ProviderType.DATABASE):
- item = await self.get_db_item(item_id)
- else:
- provider = self.mass.music.get_provider(provider_id_or_type)
- item = await provider.get_item(self.media_type, item_id)
- if not item:
- raise MediaNotFoundError(
- f"{self.media_type.value}//{item_id} not found on provider {provider_id_or_type}"
- )
- return item
-
- async def remove_prov_mapping(self, item_id: int, provider_id: str) -> None:
- """Remove provider id(s) from item."""
- try:
- db_item = await self.get_db_item(item_id)
- except MediaNotFoundError:
- # edge case: already deleted / race condition
- return
-
- db_item.provider_mappings = {
- x for x in db_item.provider_mappings if x.provider_id != provider_id
- }
- if not db_item.provider_mappings:
- # item has no more provider_mappings left, it is completely deleted
- try:
- await self.delete_db_item(db_item.item_id)
- except AssertionError:
- self.logger.debug(
- "Could not delete %s: it has items attached", db_item.item_id
- )
- return
-
- # update the item in db (provider_mappings column only)
- match = {"item_id": item_id}
- await self.mass.database.update(
- self.db_table,
- match,
- {"provider_mappings": json_serializer(db_item.provider_mappings)},
- )
- self.mass.signal_event(
- MassEvent(EventType.MEDIA_ITEM_UPDATED, db_item.uri, db_item)
- )
-
- self.logger.debug("removed provider %s from item id %s", provider_id, item_id)
-
- async def delete_db_item(self, item_id: int, recursive: bool = False) -> None:
- """Delete record from the database."""
- db_item = await self.get_db_item(item_id)
- assert db_item, f"Item does not exist: {item_id}"
- # delete item
- await self.mass.database.delete(
- self.db_table,
- {"item_id": int(item_id)},
- )
- # NOTE: this does not delete any references to this item in other records,
- # this is handled/overridden in the mediatype specific controllers
- self.mass.signal_event(
- MassEvent(EventType.MEDIA_ITEM_DELETED, db_item.uri, db_item)
- )
- self.logger.debug("deleted item with id %s from database", item_id)
-
- async def dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ) -> List[Track]:
- """Return a dynamic list of tracks based on the given item."""
- ref_item = await self.get(item_id, provider_type, provider_id)
- for prov_mapping in ref_item.provider_mappings:
- prov = self.mass.music.get_provider(prov_mapping.provider_id)
- if not prov.available:
- continue
- if MusicProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
- continue
- return await self._get_provider_dynamic_tracks(
- item_id=prov_mapping.item_id,
- provider_type=prov_mapping.provider_type,
- provider_id=prov_mapping.provider_id,
- limit=limit,
- )
- # Fallback to the default implementation
- return await self._get_dynamic_tracks(ref_item)
-
- @abstractmethod
- async def _get_provider_dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ) -> List[Track]:
- """Generate a dynamic list of tracks based on the item's content."""
-
- @abstractmethod
- async def _get_dynamic_tracks(
- self, media_item: ItemCls, limit: int = 25
- ) -> List[Track]:
- """Get dynamic list of tracks for given item, fallback/default implementation."""
+++ /dev/null
-"""Manage MediaItems of type Playlist."""
-from __future__ import annotations
-
-import random
-from ctypes import Union
-from time import time
-from typing import Any, List, Optional, Tuple
-
-from music_assistant.controllers.database import TABLE_PLAYLISTS
-from music_assistant.helpers.json import json_serializer
-from music_assistant.helpers.uri import create_uri
-from music_assistant.models.enums import (
- EventType,
- MediaType,
- MusicProviderFeature,
- ProviderType,
-)
-from music_assistant.models.errors import (
- InvalidDataError,
- MediaNotFoundError,
- ProviderUnavailableError,
- UnsupportedFeaturedException,
-)
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import Playlist, Track
-
-from .base import MediaControllerBase
-
-
-class PlaylistController(MediaControllerBase[Playlist]):
- """Controller managing MediaItems of type Playlist."""
-
- db_table = TABLE_PLAYLISTS
- media_type = MediaType.PLAYLIST
- item_cls = Playlist
-
- async def tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> List[Track]:
- """Return playlist tracks for the given provider playlist id."""
- playlist = await self.get(item_id, provider_type, provider_id)
- prov = next(x for x in playlist.provider_mappings)
- return await self._get_provider_playlist_tracks(
- prov.item_id,
- provider_type=prov.provider_type,
- provider_id=prov.provider_id,
- cache_checksum=playlist.metadata.checksum,
- )
-
- async def add(self, item: Playlist) -> Playlist:
- """Add playlist to local db and return the new database item."""
- item.metadata.last_refresh = int(time())
- await self.mass.metadata.get_playlist_metadata(item)
- existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
- if existing:
- db_item = await self.update_db_item(existing.item_id, item)
- else:
- db_item = await self.add_db_item(item)
- self.mass.signal_event(
- MassEvent(
- EventType.MEDIA_ITEM_UPDATED
- if existing
- else EventType.MEDIA_ITEM_ADDED,
- db_item.uri,
- db_item,
- )
- )
- return db_item
-
- async def create(
- self, name: str, prov_type_or_id: Union[ProviderType, str, None] = None
- ) -> Playlist:
- """Create new playlist."""
- # if prov_type_or_id is omitted, prefer file
- if prov_type_or_id:
- provider = self.mass.music.get_provider(prov_type_or_id)
- else:
- try:
- provider = self.mass.music.get_provider(ProviderType.FILESYSTEM_LOCAL)
- except ProviderUnavailableError:
- provider = next(
- (
- x
- for x in self.mass.music.providers
- if MusicProviderFeature.PLAYLIST_CREATE in x.supported_features
- ),
- None,
- )
- if provider is None:
- raise ProviderUnavailableError(
- "No provider available which allows playlists creation."
- )
-
- return await provider.create_playlist(name)
-
- async def add_playlist_tracks(self, db_playlist_id: str, uris: List[str]) -> None:
- """Add multiple tracks to playlist. Creates background tasks to process the action."""
- playlist = await self.get_db_item(db_playlist_id)
- if not playlist:
- raise MediaNotFoundError(f"Playlist with id {db_playlist_id} not found")
- if not playlist.is_editable:
- raise InvalidDataError(f"Playlist {playlist.name} is not editable")
- for uri in uris:
- job_desc = f"Add track {uri} to playlist {playlist.name}"
- self.mass.add_job(self.add_playlist_track(db_playlist_id, uri), job_desc)
-
- async def add_playlist_track(self, db_playlist_id: str, track_uri: str) -> None:
- """Add track to playlist - make sure we dont add duplicates."""
- # we can only edit playlists that are in the database (marked as editable)
- playlist = await self.get_db_item(db_playlist_id)
- if not playlist:
- raise MediaNotFoundError(f"Playlist with id {db_playlist_id} not found")
- if not playlist.is_editable:
- raise InvalidDataError(f"Playlist {playlist.name} is not editable")
- # make sure we have recent full track details
- track = await self.mass.music.get_item_by_uri(track_uri, lazy=False)
- assert track.media_type == MediaType.TRACK
- # a playlist can only have one provider (for now)
- playlist_prov = next(iter(playlist.provider_mappings))
- # grab all existing track ids in the playlist so we can check for duplicates
- cur_playlist_track_ids = set()
- count = 0
- for item in await self.tracks(
- playlist_prov.item_id, playlist_prov.provider_type
- ):
- count += 1
- cur_playlist_track_ids.update(
- {
- i.item_id
- for i in item.provider_mappings
- if i.provider_id == playlist_prov.provider_id
- }
- )
- # check for duplicates
- for track_prov in track.provider_mappings:
- if (
- track_prov.provider_type == playlist_prov.provider_type
- and track_prov.item_id in cur_playlist_track_ids
- ):
- raise InvalidDataError(
- "Track already exists in playlist {playlist.name}"
- )
- # add track to playlist
- # we can only add a track to a provider playlist if track is available on that provider
- # a track can contain multiple versions on the same provider
- # simply sort by quality and just add the first one (assuming track is still available)
- track_id_to_add = None
- for track_version in sorted(
- track.provider_mappings, key=lambda x: x.quality, reverse=True
- ):
- if not track.available:
- continue
- if playlist_prov.provider_type.is_file():
- # the file provider can handle uri's from all providers so simply add the uri
- track_id_to_add = track_version.url or create_uri(
- MediaType.TRACK,
- track_version.provider_type,
- track_version.item_id,
- )
- break
- if track_version.provider_type == playlist_prov.provider_type:
- track_id_to_add = track_version.item_id
- break
- if not track_id_to_add:
- raise MediaNotFoundError(
- f"Track is not available on provider {playlist_prov.provider_type}"
- )
- # actually add the tracks to the playlist on the provider
- provider = self.mass.music.get_provider(playlist_prov.provider_id)
- await provider.add_playlist_tracks(playlist_prov.item_id, [track_id_to_add])
- # invalidate cache by updating the checksum
- await self.get(
- db_playlist_id, provider_type=ProviderType.DATABASE, force_refresh=True
- )
-
- async def remove_playlist_tracks(
- self, db_playlist_id: str, positions_to_remove: Tuple[int]
- ) -> None:
- """Remove multiple tracks from playlist."""
- playlist = await self.get_db_item(db_playlist_id)
- if not playlist:
- raise MediaNotFoundError(f"Playlist with id {db_playlist_id} not found")
- if not playlist.is_editable:
- raise InvalidDataError(f"Playlist {playlist.name} is not editable")
- for prov_mapping in playlist.provider_mappings:
- provider = self.mass.music.get_provider(prov_mapping.provider_id)
- if (
- MusicProviderFeature.PLAYLIST_TRACKS_EDIT
- not in provider.supported_features
- ):
- self.logger.warning(
- "Provider %s does not support editing playlists",
- prov_mapping.provider_type.value,
- )
- continue
- await provider.remove_playlist_tracks(
- prov_mapping.item_id, positions_to_remove
- )
- # invalidate cache by updating the checksum
- await self.get(db_playlist_id, ProviderType.DATABASE, force_refresh=True)
-
- async def add_db_item(
- self, item: Playlist, overwrite_existing: bool = False
- ) -> Playlist:
- """Add a new record to the database."""
- async with self._db_add_lock:
- match = {"name": item.name, "owner": item.owner}
- if cur_item := await self.mass.database.get_row(self.db_table, match):
- # update existing
- return await self.update_db_item(
- cur_item["item_id"], item, overwrite=overwrite_existing
- )
-
- # insert new item
- new_item = await self.mass.database.insert(self.db_table, item.to_db_row())
- item_id = new_item["item_id"]
- self.logger.debug("added %s to database", item.name)
- # return created object
- return await self.get_db_item(item_id)
-
- async def update_db_item(
- self,
- item_id: int,
- item: Playlist,
- overwrite: bool = False,
- ) -> Playlist:
- """Update Playlist record in the database."""
- cur_item = await self.get_db_item(item_id)
- if overwrite:
- metadata = item.metadata
- provider_mappings = item.provider_mappings
- else:
- metadata = cur_item.metadata.update(item.metadata)
- provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
-
- await self.mass.database.update(
- self.db_table,
- {"item_id": item_id},
- {
- # always prefer name/owner from updated item here
- "name": item.name,
- "sort_name": item.sort_name,
- "owner": item.owner,
- "is_editable": item.is_editable,
- "metadata": json_serializer(metadata),
- "provider_mappings": json_serializer(provider_mappings),
- },
- )
- self.logger.debug("updated %s in database: %s", item.name, item_id)
- return await self.get_db_item(item_id)
-
- async def _get_provider_playlist_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- cache_checksum: Any = None,
- ) -> List[Track]:
- """Return album tracks for the given provider album id."""
- provider = self.mass.music.get_provider(provider_id or provider_type)
- if not provider:
- return []
- # prefer cache items (if any)
- cache_key = f"{provider.id}.playlist.{item_id}.tracks"
- if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
- return [Track.from_dict(x) for x in cache]
- # no items in cache - get listing from provider
- items = await provider.get_playlist_tracks(item_id)
- # double check if position set
- if items:
- assert (
- items[0].position is not None
- ), "Playlist items require position to be set"
- # store (serializable items) in cache
- self.mass.create_task(
- self.mass.cache.set(
- cache_key, [x.to_dict() for x in items], checksum=cache_checksum
- )
- )
- return items
-
- async def _get_provider_dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ):
- """Generate a dynamic list of tracks based on the playlist content."""
- provider = self.mass.music.get_provider(provider_id or provider_type)
- if (
- not provider
- or MusicProviderFeature.SIMILAR_TRACKS not in provider.supported_features
- ):
- return []
- playlist_tracks = await self._get_provider_playlist_tracks(
- item_id=item_id, provider_type=provider_type, provider_id=provider_id
- )
- # filter out unavailable tracks
- playlist_tracks = [x for x in playlist_tracks if x.available]
- limit = min(limit, len(playlist_tracks))
- # use set to prevent duplicates
- final_items = set()
- # to account for playlists with mixed content we grab suggestions from a few
- # random playlist tracks to prevent getting too many tracks of one of the
- # source playlist's genres.
- while len(final_items) < limit:
- # grab 5 random tracks from the playlist
- base_tracks = random.sample(playlist_tracks, 5)
- # add the source/base playlist tracks to the final list...
- final_items.update(base_tracks)
- # get 5 suggestions for one of the base tracks
- base_track = next(x for x in base_tracks if x.available)
- similar_tracks = await provider.get_similar_tracks(
- prov_track_id=base_track.item_id, limit=5
- )
- final_items.update(x for x in similar_tracks if x.available)
-
- # NOTE: In theory we can return a few more items than limit here
- # Shuffle the final items list
- return random.sample(final_items, len(final_items))
-
- async def _get_dynamic_tracks(
- self, media_item: Playlist, limit: int = 25
- ) -> List[Track]:
- """Get dynamic list of tracks for given item, fallback/default implementation."""
- # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
- raise UnsupportedFeaturedException(
- "No Music Provider found that supports requesting similar tracks."
- )
+++ /dev/null
-"""Manage MediaItems of type Radio."""
-from __future__ import annotations
-
-import asyncio
-from time import time
-from typing import List, Optional
-
-from music_assistant.controllers.database import TABLE_RADIOS
-from music_assistant.helpers.compare import loose_compare_strings
-from music_assistant.helpers.json import json_serializer
-from music_assistant.models.enums import EventType, MediaType, ProviderType
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import Radio, Track
-
-from .base import MediaControllerBase
-
-
-class RadioController(MediaControllerBase[Radio]):
- """Controller managing MediaItems of type Radio."""
-
- db_table = TABLE_RADIOS
- media_type = MediaType.RADIO
- item_cls = Radio
-
- async def versions(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> List[Radio]:
- """Return all versions of a radio station we can find on all providers."""
- assert provider_type or provider_id, "Provider type or ID must be specified"
- radio = await self.get(item_id, provider_type, provider_id)
- # perform a search on all provider(types) to collect all versions/variants
- provider_types = {item.type for item in self.mass.music.providers}
- all_versions = {
- prov_item.item_id: prov_item
- for prov_items in await asyncio.gather(
- *[
- self.search(radio.name, provider_type)
- for provider_type in provider_types
- ]
- )
- for prov_item in prov_items
- if loose_compare_strings(radio.name, prov_item.name)
- }
- # make sure that the 'base' version is included
- for prov_version in radio.provider_mappings:
- if prov_version.item_id in all_versions:
- continue
- radio_copy = Radio.from_dict(radio.to_dict())
- radio_copy.item_id = prov_version.item_id
- radio_copy.provider = prov_version.provider_type
- radio_copy.provider_mappings = {prov_version}
- all_versions[prov_version.item_id] = radio_copy
-
- # return the aggregated result
- return all_versions.values()
-
- async def add(self, item: Radio) -> Radio:
- """Add radio to local db and return the new database item."""
- item.metadata.last_refresh = int(time())
- await self.mass.metadata.get_radio_metadata(item)
- existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
- if existing:
- db_item = await self.update_db_item(existing.item_id, item)
- else:
- db_item = await self.add_db_item(item)
- self.mass.signal_event(
- MassEvent(
- EventType.MEDIA_ITEM_UPDATED
- if existing
- else EventType.MEDIA_ITEM_ADDED,
- db_item.uri,
- db_item,
- )
- )
- return db_item
-
- async def add_db_item(self, item: Radio, overwrite_existing: bool = False) -> Radio:
- """Add a new item record to the database."""
- assert item.provider_mappings
- async with self._db_add_lock:
- match = {"name": item.name}
- if cur_item := await self.mass.database.get_row(self.db_table, match):
- # update existing
- return await self.update_db_item(
- cur_item["item_id"], item, overwrite=overwrite_existing
- )
-
- # insert new item
- new_item = await self.mass.database.insert(self.db_table, item.to_db_row())
- item_id = new_item["item_id"]
- self.logger.debug("added %s to database", item.name)
- # return created object
- return await self.get_db_item(item_id)
-
- async def update_db_item(
- self,
- item_id: int,
- item: Radio,
- overwrite: bool = False,
- ) -> Radio:
- """Update Radio record in the database."""
- cur_item = await self.get_db_item(item_id)
- if overwrite:
- metadata = item.metadata
- provider_mappings = item.provider_mappings
- else:
- metadata = cur_item.metadata.update(item.metadata)
- provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
-
- match = {"item_id": item_id}
- await self.mass.database.update(
- self.db_table,
- match,
- {
- # always prefer name from updated item here
- "name": item.name,
- "sort_name": item.sort_name,
- "metadata": json_serializer(metadata),
- "provider_mappings": json_serializer(provider_mappings),
- },
- )
- self.logger.debug("updated %s in database: %s", item.name, item_id)
- return await self.get_db_item(item_id)
-
- async def _get_provider_dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ) -> List[Track]:
- """Generate a dynamic list of tracks based on the item's content."""
- raise NotImplementedError("Dynamic tracks not supported for Radio MediaItem")
-
- async def _get_dynamic_tracks(
- self, media_item: Radio, limit: int = 25
- ) -> List[Track]:
- """Get dynamic list of tracks for given item, fallback/default implementation."""
- raise NotImplementedError("Dynamic tracks not supported for Radio MediaItem")
+++ /dev/null
-"""Manage MediaItems of type Track."""
-from __future__ import annotations
-
-import asyncio
-from typing import List, Optional, Union
-
-from music_assistant.controllers.database import TABLE_TRACKS
-from music_assistant.helpers.compare import (
- compare_artists,
- compare_track,
- loose_compare_strings,
-)
-from music_assistant.helpers.json import json_serializer
-from music_assistant.models.enums import (
- EventType,
- MediaType,
- MusicProviderFeature,
- ProviderType,
-)
-from music_assistant.models.errors import (
- MediaNotFoundError,
- UnsupportedFeaturedException,
-)
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import (
- Album,
- Artist,
- ItemMapping,
- Track,
- TrackAlbumMapping,
-)
-
-from .base import MediaControllerBase
-
-
-class TracksController(MediaControllerBase[Track]):
- """Controller managing MediaItems of type Track."""
-
- db_table = TABLE_TRACKS
- media_type = MediaType.TRACK
- item_cls = Track
-
- async def get(self, *args, **kwargs) -> Track:
- """Return (full) details for a single media item."""
- track = await super().get(*args, **kwargs)
- # append full album details to full track item
- if track.album:
- try:
- track.album = await self.mass.music.albums.get(
- track.album.item_id, track.album.provider
- )
- except MediaNotFoundError:
- # edge case where playlist track has invalid albumdetails
- self.logger.warning("Unable to fetch album details %s", track.album.uri)
- # append full artist details to full track item
- full_artists = []
- for artist in track.artists:
- full_artists.append(
- await self.mass.music.artists.get(artist.item_id, artist.provider)
- )
- track.artists = full_artists
- return track
-
- async def add(self, item: Track) -> Track:
- """Add track to local db and return the new database item."""
- # make sure we have artists
- assert item.artists
- # grab additional metadata
- await self.mass.metadata.get_track_metadata(item)
- existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
- if existing:
- db_item = await self.update_db_item(existing.item_id, item)
- else:
- db_item = await self.add_db_item(item)
- # also fetch same track on all providers (will also get other quality versions)
- await self._match(db_item)
- # return final db_item after all match/metadata actions
- db_item = await self.get_db_item(db_item.item_id)
- self.mass.signal_event(
- MassEvent(
- EventType.MEDIA_ITEM_UPDATED
- if existing
- else EventType.MEDIA_ITEM_ADDED,
- db_item.uri,
- db_item,
- )
- )
- return db_item
-
- async def versions(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> List[Track]:
- """Return all versions of a track we can find on all providers."""
- assert provider_type or provider_id, "Provider type or ID must be specified"
- track = await self.get(item_id, provider_type or provider_id)
- # perform a search on all provider(types) to collect all versions/variants
- provider_types = {item.type for item in self.mass.music.providers}
- search_query = f"{track.artist.name} - {track.name}"
- all_versions = {
- prov_item.item_id: prov_item
- for prov_items in await asyncio.gather(
- *[
- self.search(search_query, provider_type)
- for provider_type in provider_types
- ]
- )
- for prov_item in prov_items
- if loose_compare_strings(track.name, prov_item.name)
- and compare_artists(prov_item.artists, track.artists, any_match=True)
- }
- # make sure that the 'base' version is included
- for prov_version in track.provider_mappings:
- if prov_version.item_id in all_versions:
- continue
- # grab full item here including album details etc
- prov_track = await self.get_provider_item(
- prov_version.item_id, prov_version.provider_id
- )
- all_versions[prov_version.item_id] = prov_track
-
- # return the aggregated result
- return all_versions.values()
-
- async def _match(self, db_track: Track) -> None:
- """
- Try to find matching track on all providers for the provided (database) track_id.
-
- This is used to link objects of different providers/qualities together.
- """
- if db_track.provider != ProviderType.DATABASE:
- return # Matching only supported for database items
- for provider in self.mass.music.providers:
- if MusicProviderFeature.SEARCH not in provider.supported_features:
- continue
- self.logger.debug(
- "Trying to match track %s on provider %s", db_track.name, provider.name
- )
- match_found = False
- for search_str in (
- db_track.name,
- f"{db_track.artists[0].name} - {db_track.name}",
- f"{db_track.artists[0].name} {db_track.name}",
- ):
- if match_found:
- break
- search_result = await self.search(search_str, provider.type)
- for search_result_item in search_result:
- if not search_result_item.available:
- continue
- if compare_track(search_result_item, db_track):
- # 100% match, we can simply update the db with additional provider ids
- match_found = True
- await self.update_db_item(db_track.item_id, search_result_item)
-
- if not match_found:
- self.logger.debug(
- "Could not find match for Track %s on provider %s",
- db_track.name,
- provider.name,
- )
-
- async def _get_provider_dynamic_tracks(
- self,
- item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 25,
- ):
- """Generate a dynamic list of tracks based on the track."""
- prov = self.mass.music.get_provider(provider_id or provider_type)
- if (
- not prov
- or MusicProviderFeature.SIMILAR_TRACKS not in prov.supported_features
- ):
- return []
- # Grab similar tracks from the music provider
- similar_tracks = await prov.get_similar_tracks(
- prov_track_id=item_id, limit=limit
- )
- return similar_tracks
-
- async def _get_dynamic_tracks(
- self, media_item: Track, limit: int = 25
- ) -> List[Track]:
- """Get dynamic list of tracks for given item, fallback/default implementation."""
- # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
- raise UnsupportedFeaturedException(
- "No Music Provider found that supports requesting similar tracks."
- )
-
- async def add_db_item(self, item: Track, overwrite_existing: bool = False) -> Track:
- """Add a new item record to the database."""
- assert isinstance(item, Track), "Not a full Track object"
- assert item.artists, "Track is missing artist(s)"
- assert item.provider_mappings, "Track is missing provider id(s)"
- async with self._db_add_lock:
- cur_item = None
-
- # always try to grab existing item by external_id
- if item.musicbrainz_id:
- match = {"musicbrainz_id": item.musicbrainz_id}
- cur_item = await self.mass.database.get_row(self.db_table, match)
- for isrc in item.isrcs:
- match = {"isrc": isrc}
- cur_item = await self.mass.database.get_row(self.db_table, match)
- if not cur_item:
- # fallback to matching
- match = {"sort_name": item.sort_name}
- for row in await self.mass.database.get_rows(self.db_table, match):
- row_track = Track.from_db_row(row)
- if compare_track(row_track, item):
- cur_item = row_track
- break
- if cur_item:
- # update existing
- return await self.update_db_item(
- cur_item.item_id, item, overwrite=overwrite_existing
- )
-
- # no existing match found: insert new item
- track_artists = await self._get_track_artists(item)
- track_albums = await self._get_track_albums(
- item, overwrite=overwrite_existing
- )
- if track_artists:
- sort_artist = track_artists[0].sort_name
- else:
- sort_artist = ""
- if track_albums:
- sort_album = track_albums[0].sort_name
- else:
- sort_album = ""
- new_item = await self.mass.database.insert(
- self.db_table,
- {
- **item.to_db_row(),
- "artists": json_serializer(track_artists),
- "albums": json_serializer(track_albums),
- "sort_artist": sort_artist,
- "sort_album": sort_album,
- },
- )
- item_id = new_item["item_id"]
- # return created object
- self.logger.debug("added %s to database: %s", item.name, item_id)
- return await self.get_db_item(item_id)
-
- async def update_db_item(
- self,
- item_id: int,
- item: Track,
- overwrite: bool = False,
- ) -> Track:
- """Update Track record in the database, merging data."""
- cur_item = await self.get_db_item(item_id)
-
- if overwrite:
- metadata = item.metadata
- provider_mappings = item.provider_mappings
- metadata.last_refresh = None
- # we store a mapping to artists/albums on the item for easier access/listings
- track_artists = await self._get_track_artists(item, overwrite=True)
- track_albums = await self._get_track_albums(item, overwrite=True)
- else:
- metadata = cur_item.metadata.update(item.metadata, item.provider.is_file())
- provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
- track_artists = await self._get_track_artists(cur_item, item)
- track_albums = await self._get_track_albums(cur_item, item)
-
- await self.mass.database.update(
- self.db_table,
- {"item_id": item_id},
- {
- "name": item.name if overwrite else cur_item.name,
- "sort_name": item.sort_name if overwrite else cur_item.sort_name,
- "version": item.version if overwrite else cur_item.version,
- "duration": item.duration if overwrite else cur_item.duration,
- "artists": json_serializer(track_artists),
- "albums": json_serializer(track_albums),
- "metadata": json_serializer(metadata),
- "provider_mappings": json_serializer(provider_mappings),
- "isrc": item.isrc or cur_item.isrc,
- },
- )
- self.logger.debug("updated %s in database: %s", item.name, item_id)
- return await self.get_db_item(item_id)
-
- async def _get_track_artists(
- self,
- base_track: Track,
- upd_track: Optional[Track] = None,
- overwrite: bool = False,
- ) -> List[ItemMapping]:
- """Extract all (unique) artists of track as ItemMapping."""
- if upd_track and upd_track.artists:
- track_artists = upd_track.artists
- else:
- track_artists = base_track.artists
- # use intermediate set to clear out duplicates
- return list(
- {await self._get_artist_mapping(x, overwrite) for x in track_artists}
- )
-
- async def _get_track_albums(
- self,
- base_track: Track,
- upd_track: Optional[Track] = None,
- overwrite: bool = False,
- ) -> List[TrackAlbumMapping]:
- """Extract all (unique) albums of track as TrackAlbumMapping."""
- track_albums: List[TrackAlbumMapping] = []
- # existing TrackAlbumMappings are starting point
- if base_track.albums:
- track_albums = base_track.albums
- elif upd_track and upd_track.albums:
- track_albums = upd_track.albums
- # append update item album if needed
- if upd_track and upd_track.album:
- mapping = await self._get_album_mapping(
- upd_track.album, overwrite=overwrite
- )
- mapping = TrackAlbumMapping.from_dict(
- {
- **mapping.to_dict(),
- "disc_number": upd_track.disc_number,
- "track_number": upd_track.track_number,
- }
- )
- if mapping not in track_albums:
- track_albums.append(mapping)
- # append base item album if needed
- elif base_track and base_track.album:
- mapping = await self._get_album_mapping(
- base_track.album, overwrite=overwrite
- )
- mapping = TrackAlbumMapping.from_dict(
- {
- **mapping.to_dict(),
- "disc_number": base_track.disc_number,
- "track_number": base_track.track_number,
- }
- )
- if mapping not in track_albums:
- track_albums.append(mapping)
-
- return track_albums
-
- async def _get_album_mapping(
- self,
- album: Union[Album, ItemMapping],
- overwrite: bool = False,
- ) -> ItemMapping:
- """Extract (database) album as ItemMapping."""
-
- if album.provider == ProviderType.DATABASE:
- if isinstance(album, ItemMapping):
- return album
- return ItemMapping.from_item(album)
-
- if overwrite:
- db_album = await self.mass.music.albums.add_db_item(
- album, overwrite_existing=True
- )
-
- if db_album := await self.mass.music.albums.get_db_item_by_prov_id(
- album.item_id, provider_type=album.provider
- ):
- return ItemMapping.from_item(db_album)
-
- db_album = await self.mass.music.albums.add_db_item(
- album, overwrite_existing=overwrite
- )
- return ItemMapping.from_item(db_album)
-
- async def _get_artist_mapping(
- self, artist: Union[Artist, ItemMapping], overwrite: bool = False
- ) -> ItemMapping:
- """Extract (database) track artist as ItemMapping."""
-
- if artist.provider == ProviderType.DATABASE:
- if isinstance(artist, ItemMapping):
- return artist
- return ItemMapping.from_item(artist)
-
- if overwrite:
- artist = await self.mass.music.artists.add_db_item(
- artist, overwrite_existing=True
- )
-
- if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
- artist.item_id, provider_type=artist.provider
- ):
- return ItemMapping.from_item(db_artist)
-
- db_artist = await self.mass.music.artists.add_db_item(artist)
- return ItemMapping.from_item(db_artist)
+++ /dev/null
-"""Package with Metadata controller and providers."""
-
-from .metadata import MetaDataController # noqa
+++ /dev/null
-"""TheAudioDb Metadata provider."""
-from __future__ import annotations
-
-from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
-
-import aiohttp
-from asyncio_throttle import Throttler
-
-from music_assistant.controllers.cache import use_cache
-from music_assistant.helpers.app_vars import ( # pylint: disable=no-name-in-module
- app_var,
-)
-from music_assistant.helpers.compare import compare_strings
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- ImageType,
- LinkType,
- MediaItemImage,
- MediaItemLink,
- MediaItemMetadata,
- Track,
-)
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-IMG_MAPPING = {
- "strArtistThumb": ImageType.THUMB,
- "strArtistLogo": ImageType.LOGO,
- "strArtistCutout": ImageType.CUTOUT,
- "strArtistClearart": ImageType.CLEARART,
- "strArtistWideThumb": ImageType.LANDSCAPE,
- "strArtistFanart": ImageType.FANART,
- "strArtistBanner": ImageType.BANNER,
- "strAlbumThumb": ImageType.THUMB,
- "strAlbumThumbHQ": ImageType.THUMB,
- "strAlbumCDart": ImageType.DISCART,
- "strAlbum3DCase": ImageType.OTHER,
- "strAlbum3DFlat": ImageType.OTHER,
- "strAlbum3DFace": ImageType.OTHER,
- "strAlbum3DThumb": ImageType.OTHER,
- "strTrackThumb": ImageType.THUMB,
- "strTrack3DCase": ImageType.OTHER,
-}
-
-LINK_MAPPING = {
- "strWebsite": LinkType.WEBSITE,
- "strFacebook": LinkType.FACEBOOK,
- "strTwitter": LinkType.TWITTER,
- "strLastFMChart": LinkType.LASTFM,
-}
-
-ALBUMTYPE_MAPPING = {
- "Single": AlbumType.SINGLE,
- "Compilation": AlbumType.COMPILATION,
- "Album": AlbumType.ALBUM,
-}
-
-
-class TheAudioDb:
- """TheAudioDb metadata provider."""
-
- def __init__(self, mass: MusicAssistant):
- """Initialize class."""
- self.mass = mass
- self.cache = mass.cache
- self.logger = mass.logger.getChild("audiodb")
- self.throttler = Throttler(rate_limit=2, period=1)
-
- async def get_artist_metadata(self, artist: Artist) -> MediaItemMetadata | None:
- """Retrieve metadata for artist on theaudiodb."""
- self.logger.debug("Fetching metadata for Artist %s on TheAudioDb", artist.name)
- if data := await self._get_data("artist-mb.php", i=artist.musicbrainz_id):
- if data.get("artists"):
- return self.__parse_artist(data["artists"][0])
- return None
-
- async def get_album_metadata(self, album: Album) -> MediaItemMetadata | None:
- """Retrieve metadata for album on theaudiodb."""
- adb_album = None
- if album.musicbrainz_id:
- result = await self._get_data("album-mb.php", i=album.musicbrainz_id)
- if result and result.get("album"):
- adb_album = result["album"][0]
- elif album.artist:
- # lookup by name
- result = await self._get_data(
- "searchalbum.php", s=album.artist.name, a=album.name
- )
- if result and result.get("album"):
- for item in result["album"]:
- if album.artist.musicbrainz_id:
- if (
- album.artist.musicbrainz_id
- != item["strMusicBrainzArtistID"]
- ):
- continue
- elif not compare_strings(
- album.artist.name, item["strArtistStripped"]
- ):
- continue
- if compare_strings(album.name, item["strAlbumStripped"]):
- adb_album = item
- break
- if adb_album:
- if not album.year:
- album.year = int(adb_album.get("intYearReleased", "0"))
- if not album.musicbrainz_id:
- album.musicbrainz_id = adb_album["strMusicBrainzID"]
- if album.artist and not album.artist.musicbrainz_id:
- album.artist.musicbrainz_id = adb_album["strMusicBrainzArtistID"]
- if album.album_type == AlbumType.UNKNOWN:
- album.album_type = ALBUMTYPE_MAPPING.get(
- adb_album.get("strReleaseFormat"), AlbumType.UNKNOWN
- )
- return self.__parse_album(adb_album)
- return None
-
- async def get_track_metadata(self, track: Track) -> MediaItemMetadata | None:
- """Retrieve metadata for track on theaudiodb."""
- adb_track = None
- if track.musicbrainz_id:
- result = await self._get_data("track-mb.php", i=track.musicbrainz_id)
- if result and result.get("track"):
- return self.__parse_track(result["track"][0])
-
- # lookup by name
- for track_artist in track.artists:
- result = await self._get_data(
- "searchtrack.php?", s=track_artist.name, t=track.name
- )
- if result and result.get("track"):
- for item in result["track"]:
- if track_artist.musicbrainz_id:
- if (
- track_artist.musicbrainz_id
- != item["strMusicBrainzArtistID"]
- ):
- continue
- elif not compare_strings(track_artist.name, item["strArtist"]):
- continue
- if compare_strings(track.name, item["strTrack"]):
- adb_track = item
- break
- if adb_track:
- if not track.musicbrainz_id:
- track.musicbrainz_id = adb_track["strMusicBrainzID"]
- if track.album and not track.album.musicbrainz_id:
- track.album.musicbrainz_id = adb_track["strMusicBrainzAlbumID"]
- if not track_artist.musicbrainz_id:
- track_artist.musicbrainz_id = adb_track["strMusicBrainzArtistID"]
-
- return self.__parse_track(adb_track)
- return None
-
- async def get_musicbrainz_id(
- self, artist: Artist, ref_albums: List[Album]
- ) -> str | None:
- """Try to discover MusicBrainz ID for an artist given some reference albums."""
- self.logger.debug(
- "Lookup MusicbrainzID for Artist %s on TheAudioDb", artist.name
- )
- musicbrainz_id = None
- if data := await self._get_data("searchalbum.php", s=artist.name):
- # NOTE: object is 'null' when no records found instead of empty array
- albums = data.get("album") or []
- for item in albums:
- if not compare_strings(item["strArtistStripped"], artist.name):
- continue
- for ref_album in ref_albums:
- if not compare_strings(item["strAlbumStripped"], ref_album.name):
- continue
- # found match - update album metadata too while we're here
- if not ref_album.musicbrainz_id:
- ref_album.metadata = self.__parse_album(item)
- await self.mass.music.albums.add_db_item(ref_album)
- musicbrainz_id = item["strMusicBrainzArtistID"]
- if musicbrainz_id:
- self.logger.debug(
- "Found MusicBrainzID for artist %s on TheAudioDb", artist.name
- )
- return musicbrainz_id
-
- def __parse_artist(self, artist_obj: Dict[str, Any]) -> MediaItemMetadata:
- """Parse audiodb artist object to MediaItemMetadata."""
- metadata = MediaItemMetadata()
- # generic data
- metadata.label = artist_obj.get("strLabel")
- metadata.style = artist_obj.get("strStyle")
- if genre := artist_obj.get("strGenre"):
- metadata.genres = {genre}
- metadata.mood = artist_obj.get("strMood")
- # links
- metadata.links = set()
- for key, link_type in LINK_MAPPING.items():
- if link := artist_obj.get(key):
- metadata.links.add(MediaItemLink(link_type, link))
- # description/biography
- if desc := artist_obj.get(
- f"strBiography{self.mass.metadata.preferred_language}"
- ):
- metadata.description = desc
- else:
- metadata.description = artist_obj.get("strBiographyEN")
- # images
- metadata.images = []
- for key, img_type in IMG_MAPPING.items():
- for postfix in ("", "2", "3", "4", "5", "6", "7", "8", "9", "10"):
- if img := artist_obj.get(f"{key}{postfix}"):
- metadata.images.append(MediaItemImage(img_type, img))
- else:
- break
- return metadata
-
- def __parse_album(self, album_obj: Dict[str, Any]) -> MediaItemMetadata:
- """Parse audiodb album object to MediaItemMetadata."""
- metadata = MediaItemMetadata()
- # generic data
- metadata.label = album_obj.get("strLabel")
- metadata.style = album_obj.get("strStyle")
- if genre := album_obj.get("strGenre"):
- metadata.genres = {genre}
- metadata.mood = album_obj.get("strMood")
- # links
- metadata.links = set()
- if link := album_obj.get("strWikipediaID"):
- metadata.links.add(
- MediaItemLink(LinkType.WIKIPEDIA, f"https://wikipedia.org/wiki/{link}")
- )
- if link := album_obj.get("strAllMusicID"):
- metadata.links.add(
- MediaItemLink(
- LinkType.ALLMUSIC, f"https://www.allmusic.com/album/{link}"
- )
- )
-
- # description
- if desc := album_obj.get(
- f"strDescription{self.mass.metadata.preferred_language}"
- ):
- metadata.description = desc
- else:
- metadata.description = album_obj.get("strDescriptionEN")
- metadata.review = album_obj.get("strReview")
- # images
- metadata.images = []
- for key, img_type in IMG_MAPPING.items():
- for postfix in ("", "2", "3", "4", "5", "6", "7", "8", "9", "10"):
- if img := album_obj.get(f"{key}{postfix}"):
- metadata.images.append(MediaItemImage(img_type, img))
- else:
- break
- return metadata
-
- def __parse_track(self, track_obj: Dict[str, Any]) -> MediaItemMetadata:
- """Parse audiodb track object to MediaItemMetadata."""
- metadata = MediaItemMetadata()
- # generic data
- metadata.lyrics = track_obj.get("strTrackLyrics")
- metadata.style = track_obj.get("strStyle")
- if genre := track_obj.get("strGenre"):
- metadata.genres = {genre}
- metadata.mood = track_obj.get("strMood")
- # description
- if desc := track_obj.get(
- f"strDescription{self.mass.metadata.preferred_language}"
- ):
- metadata.description = desc
- else:
- metadata.description = track_obj.get("strDescriptionEN")
- # images
- metadata.images = []
- for key, img_type in IMG_MAPPING.items():
- for postfix in ("", "2", "3", "4", "5", "6", "7", "8", "9", "10"):
- if img := track_obj.get(f"{key}{postfix}"):
- metadata.images.append(MediaItemImage(img_type, img))
- else:
- break
- return metadata
-
- @use_cache(86400 * 14)
- async def _get_data(self, endpoint, **kwargs) -> Optional[dict]:
- """Get data from api."""
- url = f"https://theaudiodb.com/api/v1/json/{app_var(3)}/{endpoint}"
- async with self.throttler:
- async with self.mass.http_session.get(
- url, params=kwargs, verify_ssl=False
- ) as response:
- try:
- result = await response.json()
- except (
- aiohttp.ContentTypeError,
- JSONDecodeError,
- ):
- self.logger.error("Failed to retrieve %s", endpoint)
- text_result = await response.text()
- self.logger.debug(text_result)
- return None
- except (
- aiohttp.ClientConnectorError,
- aiohttp.client_exceptions.ServerDisconnectedError,
- ):
- self.logger.warning("Failed to retrieve %s", endpoint)
- return None
- if "error" in result and "limit" in result["error"]:
- self.logger.warning(result["error"])
- return None
- return result
+++ /dev/null
-"""FanartTv Metadata provider."""
-from __future__ import annotations
-
-from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING, Optional
-
-import aiohttp
-from asyncio_throttle import Throttler
-
-from music_assistant.controllers.cache import use_cache
-from music_assistant.helpers.app_vars import ( # pylint: disable=no-name-in-module
- app_var,
-)
-from music_assistant.models.media_items import (
- Album,
- Artist,
- ImageType,
- MediaItemImage,
- MediaItemMetadata,
-)
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-# TODO: add support for personal api keys ?
-
-
-IMG_MAPPING = {
- "artistthumb": ImageType.THUMB,
- "hdmusiclogo": ImageType.LOGO,
- "musicbanner": ImageType.BANNER,
- "artistbackground": ImageType.FANART,
-}
-
-
-class FanartTv:
- """Fanart.tv metadata provider."""
-
- def __init__(self, mass: MusicAssistant):
- """Initialize class."""
- self.mass = mass
- self.cache = mass.cache
- self.logger = mass.logger.getChild("fanarttv")
- self.throttler = Throttler(rate_limit=2, period=1)
-
- async def get_artist_metadata(self, artist: Artist) -> MediaItemMetadata | None:
- """Retrieve metadata for artist on fanart.tv."""
- if not artist.musicbrainz_id:
- return
- self.logger.debug("Fetching metadata for Artist %s on Fanart.tv", artist.name)
- if data := await self._get_data(f"music/{artist.musicbrainz_id}"):
- metadata = MediaItemMetadata()
- metadata.images = []
- for key, img_type in IMG_MAPPING.items():
- items = data.get(key)
- if not items:
- continue
- for item in items:
- metadata.images.append(MediaItemImage(img_type, item["url"]))
- return metadata
- return None
-
- async def get_album_metadata(self, album: Album) -> MediaItemMetadata | None:
- """Retrieve metadata for album on fanart.tv."""
- if not album.musicbrainz_id:
- return
- self.logger.debug("Fetching metadata for Album %s on Fanart.tv", album.name)
- if data := await self._get_data(f"music/albums/{album.musicbrainz_id}"):
- if data and data.get("albums"):
- data = data["albums"][album.musicbrainz_id]
- metadata = MediaItemMetadata()
- metadata.images = []
- for key, img_type in IMG_MAPPING.items():
- items = data.get(key)
- if not items:
- continue
- for item in items:
- metadata.images.append(MediaItemImage(img_type, item["url"]))
- return metadata
- return None
-
- @use_cache(86400 * 14)
- async def _get_data(self, endpoint, **kwargs) -> Optional[dict]:
- """Get data from api."""
- url = f"http://webservice.fanart.tv/v3/{endpoint}"
- kwargs["api_key"] = app_var(4)
- async with self.throttler:
- async with self.mass.http_session.get(
- url, params=kwargs, verify_ssl=False
- ) as response:
- try:
- result = await response.json()
- except (
- aiohttp.ContentTypeError,
- JSONDecodeError,
- ):
- self.logger.error("Failed to retrieve %s", endpoint)
- text_result = await response.text()
- self.logger.debug(text_result)
- return None
- except (
- aiohttp.ClientConnectorError,
- aiohttp.client_exceptions.ServerDisconnectedError,
- ):
- self.logger.warning("Failed to retrieve %s", endpoint)
- return None
- if "error" in result and "limit" in result["error"]:
- self.logger.warning(result["error"])
- return None
- return result
+++ /dev/null
-"""All logic for metadata retrieval."""
-from __future__ import annotations
-
-from base64 import b64encode
-from time import time
-from typing import TYPE_CHECKING, Optional
-
-from music_assistant.controllers.database import TABLE_THUMBS
-from music_assistant.helpers.images import create_collage, create_thumbnail
-from music_assistant.models.enums import ImageType, MediaType
-from music_assistant.models.media_items import (
- Album,
- Artist,
- ItemMapping,
- MediaItemImage,
- MediaItemType,
- Playlist,
- Radio,
- Track,
-)
-
-from .audiodb import TheAudioDb
-from .fanarttv import FanartTv
-from .musicbrainz import MusicBrainz
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-class MetaDataController:
- """Several helpers to search and store metadata for mediaitems."""
-
- def __init__(self, mass: MusicAssistant) -> None:
- """Initialize class."""
- self.mass = mass
- self.cache = mass.cache
- self.logger = mass.logger.getChild("metadata")
- self.fanarttv = FanartTv(mass)
- self.musicbrainz = MusicBrainz(mass)
- self.audiodb = TheAudioDb(mass)
- self._pref_lang: Optional[str] = None
-
- @property
- def preferred_language(self) -> str:
- """
- Return preferred language for metadata as 2 letter country code (uppercase).
-
- Defaults to English (EN).
- """
- return self._pref_lang or "EN"
-
- @preferred_language.setter
- def preferred_language(self, lang: str) -> None:
- """
- Set preferred language to 2 letter country code.
-
- Can only be set once.
- """
- if self._pref_lang is None:
- self._pref_lang = lang.upper()
-
- async def setup(self):
- """Async initialize of module."""
-
- async def get_artist_metadata(self, artist: Artist) -> None:
- """Get/update rich metadata for an artist."""
- # set timestamp, used to determine when this function was last called
- artist.metadata.last_refresh = int(time())
-
- if not artist.musicbrainz_id:
- artist.musicbrainz_id = await self.get_artist_musicbrainz_id(artist)
-
- if artist.musicbrainz_id:
- if metadata := await self.fanarttv.get_artist_metadata(artist):
- artist.metadata.update(metadata)
- if metadata := await self.audiodb.get_artist_metadata(artist):
- artist.metadata.update(metadata)
-
- async def get_album_metadata(self, album: Album) -> None:
- """Get/update rich metadata for an album."""
- # set timestamp, used to determine when this function was last called
- album.metadata.last_refresh = int(time())
-
- if not (album.musicbrainz_id or album.artist):
- return
- if metadata := await self.audiodb.get_album_metadata(album):
- album.metadata.update(metadata)
- if metadata := await self.fanarttv.get_album_metadata(album):
- album.metadata.update(metadata)
-
- async def get_track_metadata(self, track: Track) -> None:
- """Get/update rich metadata for a track."""
- # set timestamp, used to determine when this function was last called
- track.metadata.last_refresh = int(time())
-
- if not (track.album and track.artists):
- return
- if metadata := await self.audiodb.get_track_metadata(track):
- track.metadata.update(metadata)
-
- async def get_playlist_metadata(self, playlist: Playlist) -> None:
- """Get/update rich metadata for a playlist."""
- # set timestamp, used to determine when this function was last called
- playlist.metadata.last_refresh = int(time())
- # retrieve genres from tracks
- # TODO: retrieve style/mood ?
- playlist.metadata.genres = set()
- image_urls = set()
- for track in await self.mass.music.playlists.tracks(
- playlist.item_id, playlist.provider
- ):
- if not playlist.image and track.image:
- image_urls.add(track.image.url)
- if track.media_type != MediaType.TRACK:
- # filter out radio items
- continue
- if track.metadata.genres:
- playlist.metadata.genres.update(track.metadata.genres)
- elif track.album and track.album.metadata.genres:
- playlist.metadata.genres.update(track.album.metadata.genres)
- # create collage thumb/fanart from playlist tracks
- if image_urls:
- fake_path = f"playlist_collage.{playlist.provider.value}.{playlist.item_id}"
- collage = await create_collage(self.mass, list(image_urls))
- match = {"path": fake_path, "size": 0}
- await self.mass.database.insert(
- TABLE_THUMBS, {**match, "data": collage}, allow_replace=True
- )
- playlist.metadata.images = [
- MediaItemImage(ImageType.THUMB, fake_path, True)
- ]
-
- async def get_radio_metadata(self, radio: Radio) -> None:
- """Get/update rich metadata for a radio station."""
- # NOTE: we do not have any metadata for radio so consider this future proofing ;-)
- radio.metadata.last_refresh = int(time())
-
- async def get_artist_musicbrainz_id(self, artist: Artist) -> str | None:
- """Fetch musicbrainz id by performing search using the artist name, albums and tracks."""
- ref_albums = await self.mass.music.artists.albums(artist=artist)
- # first try audiodb
- if musicbrainz_id := await self.audiodb.get_musicbrainz_id(artist, ref_albums):
- return musicbrainz_id
- # try again with musicbrainz with albums with upc
- for ref_album in ref_albums:
- if ref_album.upc:
- if musicbrainz_id := await self.musicbrainz.get_mb_artist_id(
- artist.name,
- album_upc=ref_album.upc,
- ):
- return musicbrainz_id
- if ref_album.musicbrainz_id:
- if musicbrainz_id := await self.musicbrainz.search_artist_by_album_mbid(
- artist.name, ref_album.musicbrainz_id
- ):
- return musicbrainz_id
-
- # try again with matching on track isrc
- ref_tracks = await self.mass.music.artists.tracks(artist=artist)
- for ref_track in ref_tracks:
- for isrc in ref_track.isrcs:
- if musicbrainz_id := await self.musicbrainz.get_mb_artist_id(
- artist.name,
- track_isrc=isrc,
- ):
- return musicbrainz_id
-
- # last restort: track matching by name
- for ref_track in ref_tracks:
- if musicbrainz_id := await self.musicbrainz.get_mb_artist_id(
- artist.name,
- trackname=ref_track.name,
- ):
- return musicbrainz_id
- # lookup failed
- ref_albums_str = "/".join(x.name for x in ref_albums) or "none"
- ref_tracks_str = "/".join(x.name for x in ref_tracks) or "none"
- self.logger.info(
- "Unable to get musicbrainz ID for artist %s\n"
- " - using lookup-album(s): %s\n"
- " - using lookup-track(s): %s\n",
- artist.name,
- ref_albums_str,
- ref_tracks_str,
- )
- return None
-
- async def get_image_data_for_item(
- self,
- media_item: MediaItemType,
- img_type: ImageType = ImageType.THUMB,
- size: int = 0,
- ) -> bytes | None:
- """Get image data for given MedaItem."""
- img_path = await self.get_image_url_for_item(
- media_item=media_item,
- img_type=img_type,
- allow_local=True,
- local_as_base64=False,
- )
- if not img_path:
- return None
- return await self.get_thumbnail(img_path, size)
-
- async def get_image_url_for_item(
- self,
- media_item: MediaItemType,
- img_type: ImageType = ImageType.THUMB,
- allow_local: bool = True,
- local_as_base64: bool = False,
- ) -> str | None:
- """Get url to image for given media media_item."""
- if not media_item:
- return None
- if isinstance(media_item, ItemMapping):
- media_item = await self.mass.music.get_item_by_uri(media_item.uri)
- if media_item and media_item.metadata.images:
- for img in media_item.metadata.images:
- if img.type != img_type:
- continue
- if img.is_file and not allow_local:
- continue
- if img.is_file and local_as_base64:
- # return base64 string of the image (compatible with browsers)
- return await self.get_thumbnail(img.url, base64=True)
- return img.url
-
- # retry with track's album
- if media_item.media_type == MediaType.TRACK and media_item.album:
- return await self.get_image_url_for_item(
- media_item.album, img_type, allow_local, local_as_base64
- )
-
- # try artist instead for albums
- if media_item.media_type == MediaType.ALBUM and media_item.artist:
- return await self.get_image_url_for_item(
- media_item.artist, img_type, allow_local, local_as_base64
- )
-
- # last resort: track artist(s)
- if media_item.media_type == MediaType.TRACK and media_item.artists:
- for artist in media_item.artists:
- return await self.get_image_url_for_item(
- artist, img_type, allow_local, local_as_base64
- )
-
- return None
-
- async def get_thumbnail(
- self, path: str, size: int = 0, base64: bool = False
- ) -> bytes | str:
- """Get/create thumbnail image for path (image url or local path)."""
- # check if we already have this cached in the db
- match_path = path.split("?")[0].split("&")[0]
- match = {"path": match_path, "size": size}
- if result := await self.mass.database.get_row(TABLE_THUMBS, match):
- thumbnail = result["data"]
- else:
- # create thumbnail if it doesn't exist
- thumbnail = await create_thumbnail(self.mass, path, size)
- await self.mass.database.insert(
- TABLE_THUMBS, {**match, "data": thumbnail}, allow_replace=True
- )
- if base64:
- enc_image = b64encode(thumbnail).decode()
- thumbnail = f"data:image/png;base64,{enc_image}"
- return thumbnail
+++ /dev/null
-"""Handle getting Id's from MusicBrainz."""
-from __future__ import annotations
-
-import re
-from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING
-
-import aiohttp
-from asyncio_throttle import Throttler
-
-from music_assistant.controllers.cache import use_cache
-from music_assistant.helpers.compare import compare_strings
-from music_assistant.helpers.util import create_sort_name
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-LUCENE_SPECIAL = r'([+\-&|!(){}\[\]\^"~*?:\\\/])'
-
-
-class MusicBrainz:
- """Handle getting Id's from MusicBrainz."""
-
- def __init__(self, mass: MusicAssistant):
- """Initialize class."""
- self.mass = mass
- self.cache = mass.cache
- self.logger = mass.logger.getChild("musicbrainz")
- self.throttler = Throttler(rate_limit=1, period=1)
-
- async def get_mb_artist_id(
- self,
- artistname,
- albumname=None,
- album_upc=None,
- trackname=None,
- track_isrc=None,
- ):
- """Retrieve musicbrainz artist id for the given details."""
- self.logger.debug(
- "searching musicbrainz for %s \
- (albumname: %s - album_upc: %s - trackname: %s - track_isrc: %s)",
- artistname,
- albumname,
- album_upc,
- trackname,
- track_isrc,
- )
-
- if album_upc:
- if mb_id := await self.search_artist_by_album(artistname, None, album_upc):
- self.logger.debug(
- "Got MusicbrainzArtistId for %s after search on upc %s --> %s",
- artistname,
- album_upc,
- mb_id,
- )
- return mb_id
- if track_isrc:
- if mb_id := await self.search_artist_by_track(artistname, None, track_isrc):
- self.logger.debug(
- "Got MusicbrainzArtistId for %s after search on isrc %s --> %s",
- artistname,
- track_isrc,
- mb_id,
- )
- return mb_id
- if albumname:
- if mb_id := await self.search_artist_by_album(artistname, albumname):
- self.logger.debug(
- "Got MusicbrainzArtistId for %s after search on albumname %s --> %s",
- artistname,
- albumname,
- mb_id,
- )
- return mb_id
- if trackname:
- if mb_id := await self.search_artist_by_track(artistname, trackname):
- self.logger.debug(
- "Got MusicbrainzArtistId for %s after search on trackname %s --> %s",
- artistname,
- trackname,
- mb_id,
- )
- return mb_id
- return None
-
- async def search_artist_by_album(self, artistname, albumname=None, album_upc=None):
- """Retrieve musicbrainz artist id by providing the artist name and albumname or upc."""
- for searchartist in (
- artistname,
- re.sub(LUCENE_SPECIAL, r"\\\1", create_sort_name(artistname)),
- ):
- if album_upc:
- # search by album UPC (barcode)
- query = f"barcode:{album_upc}"
- else:
- # search by name
- searchalbum = re.sub(LUCENE_SPECIAL, r"\\\1", albumname)
- query = f'artist:"{searchartist}" AND release:"{searchalbum}"'
- result = await self.get_data("release", query=query)
- if result and "releases" in result:
- for strict in (True, False):
- for item in result["releases"]:
- if not (
- album_upc
- or compare_strings(item["title"], albumname, strict)
- ):
- continue
- for artist in item["artist-credit"]:
- if compare_strings(
- artist["artist"]["name"], artistname, strict
- ):
- return artist["artist"]["id"]
- for alias in artist.get("aliases", []):
- if compare_strings(alias["name"], artistname, strict):
- return artist["id"]
- return ""
-
- async def search_artist_by_track(self, artistname, trackname=None, track_isrc=None):
- """Retrieve artist id by providing the artist name and trackname or track isrc."""
- searchartist = re.sub(LUCENE_SPECIAL, r"\\\1", artistname)
- if track_isrc:
- result = await self.get_data(f"isrc/{track_isrc}", inc="artist-credits")
- else:
- searchtrack = re.sub(LUCENE_SPECIAL, r"\\\1", trackname)
- result = await self.get_data(
- "recording", query=f'"{searchtrack}" AND artist:"{searchartist}"'
- )
- if result and "recordings" in result:
- for strict in (True, False):
- for item in result["recordings"]:
- if not (
- track_isrc or compare_strings(item["title"], trackname, strict)
- ):
- continue
- for artist in item["artist-credit"]:
- if compare_strings(
- artist["artist"]["name"], artistname, strict
- ):
- return artist["artist"]["id"]
- for alias in artist.get("aliases", []):
- if compare_strings(alias["name"], artistname, strict):
- return artist["id"]
- return ""
-
- async def search_artist_by_album_mbid(
- self, artistname, album_mbid: str
- ) -> str | None:
- """Retrieve musicbrainz artist id by providing the artist name and albumname or upc."""
- result = await self.get_data(f"release-group/{album_mbid}?inc=artist-credits")
- if result and "artist-credit" in result:
- for item in result["artist-credit"]:
- if artist := item.get("artist"):
- if compare_strings(artistname, artist["name"]):
- return artist["id"]
- return None
-
- @use_cache(86400 * 30)
- async def get_data(self, endpoint: str, **kwargs):
- """Get data from api."""
- url = f"http://musicbrainz.org/ws/2/{endpoint}"
- headers = {
- "User-Agent": "Music Assistant/1.0.0 https://github.com/music-assistant"
- }
- kwargs["fmt"] = "json"
- async with self.throttler:
- async with self.mass.http_session.get(
- url, headers=headers, params=kwargs, verify_ssl=False
- ) as response:
- try:
- result = await response.json()
- except (
- aiohttp.client_exceptions.ContentTypeError,
- JSONDecodeError,
- ) as exc:
- msg = await response.text()
- self.logger.warning("%s - %s", str(exc), msg)
- result = None
- return result
+++ /dev/null
-"""MusicController: Orchestrates all data from music providers and sync to internal database."""
-from __future__ import annotations
-
-import asyncio
-import itertools
-import statistics
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
-
-from music_assistant.controllers.database import TABLE_PLAYLOG, TABLE_TRACK_LOUDNESS
-from music_assistant.controllers.media.albums import AlbumsController
-from music_assistant.controllers.media.artists import ArtistsController
-from music_assistant.controllers.media.playlists import PlaylistController
-from music_assistant.controllers.media.radio import RadioController
-from music_assistant.controllers.media.tracks import TracksController
-from music_assistant.helpers.datetime import utc_timestamp
-from music_assistant.helpers.uri import parse_uri
-from music_assistant.models.config import MusicProviderConfig
-from music_assistant.models.enums import MediaType, MusicProviderFeature, ProviderType
-from music_assistant.models.errors import (
- MusicAssistantError,
- ProviderUnavailableError,
- SetupFailedError,
-)
-from music_assistant.models.media_items import (
- BrowseFolder,
- MediaItem,
- MediaItemType,
- media_from_dict,
-)
-from music_assistant.models.music_provider import MusicProvider
-from music_assistant.music_providers.filesystem import (
- LocalFileSystemProvider,
- SMBFileSystemProvider,
-)
-from music_assistant.music_providers.qobuz import QobuzProvider
-from music_assistant.music_providers.spotify import SpotifyProvider
-from music_assistant.music_providers.tunein import TuneInProvider
-from music_assistant.music_providers.url import URLProvider
-from music_assistant.music_providers.url.url import PROVIDER_CONFIG as URL_CONFIG
-from music_assistant.music_providers.ytmusic import YoutubeMusicProvider
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-PROV_MAP = {
- ProviderType.FILESYSTEM_LOCAL: LocalFileSystemProvider,
- ProviderType.FILESYSTEM_SMB: SMBFileSystemProvider,
- ProviderType.SPOTIFY: SpotifyProvider,
- ProviderType.QOBUZ: QobuzProvider,
- ProviderType.TUNEIN: TuneInProvider,
- ProviderType.YTMUSIC: YoutubeMusicProvider,
-}
-
-
-class MusicController:
- """Several helpers around the musicproviders."""
-
- def __init__(self, mass: MusicAssistant):
- """Initialize class."""
- self.logger = mass.logger.getChild("music")
- self.mass = mass
- self.artists = ArtistsController(mass)
- self.albums = AlbumsController(mass)
- self.tracks = TracksController(mass)
- self.radio = RadioController(mass)
- self.playlists = PlaylistController(mass)
- self._providers: Dict[str, MusicProvider] = {}
-
- async def setup(self):
- """Async initialize of module."""
- # register providers
- for prov_conf in self.mass.config.providers:
- prov_cls = PROV_MAP[prov_conf.type]
- await self._register_provider(prov_cls(self.mass, prov_conf), prov_conf)
- # always register url provider
- await self._register_provider(URLProvider(self.mass, URL_CONFIG), URL_CONFIG)
- # add job to cleanup old records from db
- self.mass.add_job(
- self._cleanup_library(),
- "Cleanup removed items from database",
- allow_duplicate=False,
- )
-
- async def start_sync(
- self,
- media_types: Optional[Tuple[MediaType]] = None,
- provider_types: Optional[Tuple[ProviderType]] = None,
- schedule: Optional[float] = None,
- ) -> None:
- """
- Start running the sync of all registred providers.
-
- media_types: only sync these media types. None for all.
- provider_types: only sync these provider types. None for all.
- schedule: schedule syncjob every X hours, set to None for just a manual sync run.
- """
-
- async def do_sync():
- while True:
- for prov in self.providers:
- if provider_types is not None and prov.type not in provider_types:
- continue
- self.mass.add_job(
- prov.sync_library(media_types),
- f"Library sync for provider {prov.name}",
- allow_duplicate=False,
- )
- if schedule is None:
- return
- await asyncio.sleep(3600 * schedule)
-
- self.mass.create_task(do_sync())
-
- @property
- def provider_count(self) -> int:
- """Return count of all registered music providers."""
- return len(self._providers)
-
- @property
- def providers(self) -> Tuple[MusicProvider]:
- """Return all (available) music providers."""
- return tuple(x for x in self._providers.values() if x.available)
-
- def get_provider(
- self, provider_id_or_type: Union[str, ProviderType]
- ) -> MusicProvider:
- """Return Music provider by id (or type)."""
- if prov := self._providers.get(provider_id_or_type):
- return prov
- for prov in self._providers.values():
- if provider_id_or_type in (prov.type, prov.id, prov.type.value):
- return prov
- raise ProviderUnavailableError(
- f"Provider {provider_id_or_type} is not available"
- )
-
- async def search(
- self,
- search_query,
- media_types: List[MediaType] = MediaType.ALL,
- limit: int = 10,
- ) -> List[MediaItemType]:
- """
- Perform global search for media items on all providers.
-
- :param search_query: Search query.
- :param media_types: A list of media_types to include.
- :param limit: number of items to return in the search (per type).
- """
- # include results from all music providers
- provider_ids = (item.id for item in self.providers)
- # TODO: sort by name and filter out duplicates ?
- return itertools.chain.from_iterable(
- await asyncio.gather(
- *[
- self.search_provider(
- search_query, media_types, provider_id=provider_id, limit=limit
- )
- for provider_id in provider_ids
- ]
- )
- )
-
- async def search_provider(
- self,
- search_query: str,
- media_types: List[MediaType] = MediaType.ALL,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- limit: int = 10,
- ) -> List[MediaItemType]:
- """
- Perform search on given provider.
-
- :param search_query: Search query
- :param provider_type: type of the provider to perform the search on.
- :param provider_id: id of the provider to perform the search on.
- :param media_types: A list of media_types to include. All types if None.
- :param limit: number of items to return in the search (per type).
- """
- assert provider_type or provider_id, "Provider needs to be supplied"
- prov = self.get_provider(provider_id or provider_type)
- if MusicProviderFeature.SEARCH not in prov.supported_features:
- return []
-
- # create safe search string
- search_query = search_query.replace("/", " ").replace("'", "")
-
- # prefer cache items (if any)
- cache_key = f"{prov.type.value}.search.{search_query}.{limit}"
- cache_key += "".join((x.value for x in media_types))
-
- if cache := await self.mass.cache.get(cache_key):
- return [media_from_dict(x) for x in cache]
- # no items in cache - get listing from provider
- items = await prov.search(
- search_query,
- media_types,
- limit,
- )
- # store (serializable items) in cache
- self.mass.create_task(
- self.mass.cache.set(
- cache_key, [x.to_dict() for x in items], expiration=86400 * 7
- )
- )
- return items
-
- async def browse(self, path: Optional[str] = None) -> BrowseFolder:
- """Browse Music providers."""
- # root level; folder per provider
- if not path or path == "root":
- return BrowseFolder(
- item_id="root",
- provider=ProviderType.DATABASE,
- path="root",
- label="browse",
- name="",
- items=[
- BrowseFolder(
- item_id="root",
- provider=prov.type,
- path=f"{prov.id}://",
- name=prov.name,
- )
- for prov in self.providers
- if MusicProviderFeature.BROWSE in prov.supported_features
- ],
- )
- # provider level
- provider_id = path.split("://", 1)[0]
- prov = self.get_provider(provider_id)
- return await prov.browse(path)
-
- async def get_item_by_uri(
- self, uri: str, force_refresh: bool = False, lazy: bool = True
- ) -> MediaItemType:
- """Fetch MediaItem by uri."""
- media_type, provider_type, item_id = parse_uri(uri)
- return await self.get_item(
- item_id=item_id,
- media_type=media_type,
- provider_type=provider_type,
- force_refresh=force_refresh,
- lazy=lazy,
- )
-
- async def get_item(
- self,
- item_id: str,
- media_type: MediaType,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- force_refresh: bool = False,
- lazy: bool = True,
- ) -> MediaItemType:
- """Get single music item by id and media type."""
- assert (
- provider_type or provider_id
- ), "provider_type or provider_id must be supplied"
- if provider_type == ProviderType.URL or provider_id == "url":
- # handle special case of 'URL' MusicProvider which allows us to play regular url's
- return await self.get_provider(ProviderType.URL).parse_item(item_id)
- ctrl = self.get_controller(media_type)
- return await ctrl.get(
- provider_item_id=item_id,
- provider_type=provider_type,
- provider_id=provider_id,
- force_refresh=force_refresh,
- lazy=lazy,
- )
-
- async def add_to_library(
- self,
- media_type: MediaType,
- provider_item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> None:
- """Add an item to the library."""
- ctrl = self.get_controller(media_type)
- await ctrl.add_to_library(
- provider_item_id, provider_type=provider_type, provider_id=provider_id
- )
-
- async def remove_from_library(
- self,
- media_type: MediaType,
- provider_item_id: str,
- provider_type: Optional[ProviderType] = None,
- provider_id: Optional[str] = None,
- ) -> None:
- """Remove item from the library."""
- ctrl = self.get_controller(media_type)
- await ctrl.remove_from_library(
- provider_item_id, provider_type=provider_type, provider_id=provider_id
- )
-
- async def delete_db_item(
- self, media_type: MediaType, db_item_id: str, recursive: bool = False
- ) -> None:
- """Remove item from the library."""
- ctrl = self.get_controller(media_type)
- await ctrl.delete_db_item(db_item_id, recursive)
-
- async def refresh_items(self, items: List[MediaItem]) -> None:
- """
- Refresh MediaItems to force retrieval of full info and matches.
-
- Creates background tasks to process the action.
- """
- for media_item in items:
- job_desc = f"Refresh metadata of {media_item.uri}"
- self.mass.add_job(self.refresh_item(media_item), job_desc)
-
- async def refresh_item(
- self,
- media_item: MediaItem,
- ):
- """Try to refresh a mediaitem by requesting it's full object or search for substitutes."""
- try:
- return await self.get_item(
- media_item.item_id,
- media_item.media_type,
- provider_type=media_item.provider,
- force_refresh=True,
- lazy=False,
- )
- except MusicAssistantError:
- pass
-
- for item in await self.search(media_item.name, [media_item.media_type], 20):
- if item.available:
- await self.get_item(
- item.item_id, item.media_type, item.provider, lazy=False
- )
-
- async def set_track_loudness(
- self, item_id: str, provider_type: ProviderType, loudness: int
- ):
- """List integrated loudness for a track in db."""
- await self.mass.database.insert(
- TABLE_TRACK_LOUDNESS,
- {"item_id": item_id, "provider": provider_type.value, "loudness": loudness},
- allow_replace=True,
- )
-
- async def get_track_loudness(
- self, provider_item_id: str, provider_type: ProviderType
- ) -> float | None:
- """Get integrated loudness for a track in db."""
- if result := await self.mass.database.get_row(
- TABLE_TRACK_LOUDNESS,
- {
- "item_id": provider_item_id,
- "provider": provider_type.value,
- },
- ):
- return result["loudness"]
- return None
-
- async def get_provider_loudness(self, provider_type: ProviderType) -> float | None:
- """Get average integrated loudness for tracks of given provider."""
- all_items = []
- if provider_type == ProviderType.URL:
- # this is not a very good idea for random urls
- return None
- for db_row in await self.mass.database.get_rows(
- TABLE_TRACK_LOUDNESS,
- {
- "provider": provider_type.value,
- },
- ):
- all_items.append(db_row["loudness"])
- if all_items:
- return statistics.fmean(all_items)
- return None
-
- async def mark_item_played(self, item_id: str, provider_type: ProviderType):
- """Mark item as played in playlog."""
- timestamp = utc_timestamp()
- await self.mass.database.insert(
- TABLE_PLAYLOG,
- {
- "item_id": item_id,
- "provider": provider_type.value,
- "timestamp": timestamp,
- },
- allow_replace=True,
- )
-
- async def library_add_items(self, items: List[MediaItem]) -> None:
- """
- Add media item(s) to the library.
-
- Creates background tasks to process the action.
- """
- for media_item in items:
- job_desc = f"Add {media_item.uri} to library"
- self.mass.add_job(
- self.add_to_library(
- media_item.media_type, media_item.item_id, media_item.provider
- ),
- job_desc,
- )
-
- async def library_remove_items(self, items: List[MediaItem]) -> None:
- """
- Remove media item(s) from the library.
-
- Creates background tasks to process the action.
- """
- for media_item in items:
- job_desc = f"Remove {media_item.uri} from library"
- self.mass.add_job(
- self.remove_from_library(
- media_item.media_type, media_item.item_id, media_item.provider
- ),
- job_desc,
- )
-
- def get_controller(
- self, media_type: MediaType
- ) -> ArtistsController | AlbumsController | TracksController | RadioController | PlaylistController:
- """Return controller for MediaType."""
- if media_type == MediaType.ARTIST:
- return self.artists
- if media_type == MediaType.ALBUM:
- return self.albums
- if media_type == MediaType.TRACK:
- return self.tracks
- if media_type == MediaType.RADIO:
- return self.radio
- if media_type == MediaType.PLAYLIST:
- return self.playlists
-
- async def _register_provider(
- self, provider: MusicProvider, conf: MusicProviderConfig
- ) -> None:
- """Register a music provider."""
- if provider.id in self._providers:
- raise SetupFailedError(
- f"Provider with id {provider.id} is already registered"
- )
- try:
- provider.config = conf
- provider.mass = self.mass
- provider.cache = self.mass.cache
- provider.logger = self.logger.getChild(provider.type.value)
- if await provider.setup():
- self._providers[provider.id] = provider
- except Exception as err: # pylint: disable=broad-except
- raise SetupFailedError(
- f"Setup failed of provider {provider.type.value}: {str(err)}"
- ) from err
-
- async def _cleanup_library(self) -> None:
- """Cleanup deleted items from library/database."""
- prev_providers = await self.mass.cache.get("prov_ids", default=[])
- cur_providers = list(self._providers.keys())
- removed_providers = {x for x in prev_providers if x not in cur_providers}
-
- for provider_id in removed_providers:
-
- # clean cache items from deleted provider(s)
- await self.mass.cache.clear(provider_id)
-
- # cleanup media items from db matched to deleted provider
- for ctrl in (
- # order is important here to recursively cleanup bottom up
- self.mass.music.radio,
- self.mass.music.playlists,
- self.mass.music.tracks,
- self.mass.music.albums,
- self.mass.music.artists,
- ):
- prov_items = await ctrl.get_db_items_by_prov_id(provider_id=provider_id)
- for item in prov_items:
- await ctrl.remove_prov_mapping(item.item_id, provider_id)
- await self.mass.cache.set("prov_ids", cur_providers)
+++ /dev/null
-"""Logic to play music from MusicProviders to supported players."""
-from __future__ import annotations
-
-import asyncio
-from typing import TYPE_CHECKING, Dict, Tuple
-
-from music_assistant.models.enums import EventType, PlayerState
-from music_assistant.models.errors import AlreadyRegisteredError
-from music_assistant.models.event import MassEvent
-from music_assistant.models.player import Player
-from music_assistant.models.player_queue import PlayerQueue
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-class PlayerController:
- """Controller holding all logic to play music from MusicProviders to supported players."""
-
- def __init__(self, mass: MusicAssistant) -> None:
- """Initialize class."""
- self.mass = mass
- self.logger = mass.logger.getChild("players")
- self._players: Dict[str, Player] = {}
- self._player_queues: Dict[str, PlayerQueue] = {}
-
- async def setup(self) -> None:
- """Async initialize of module."""
- self.mass.create_task(self._poll_players())
-
- async def cleanup(self) -> None:
- """Cleanup on exit."""
- for player_id in set(self._players.keys()):
- player = self._players.pop(player_id)
- player.on_remove()
- for queue_id in set(self._player_queues.keys()):
- self._player_queues.pop(queue_id)
-
- @property
- def players(self) -> Tuple[Player]:
- """Return all registered players."""
- return tuple(self._players.values())
-
- @property
- def player_queues(self) -> Tuple[PlayerQueue]:
- """Return all available PlayerQueue's."""
- return tuple(self._player_queues.values())
-
- def __iter__(self):
- """Iterate over (available) players."""
- return iter(self._players.values())
-
- def get_player(self, player_id: str) -> Player | None:
- """Return Player by player_id or None if not found."""
- return self._players.get(player_id)
-
- def get_player_queue(self, queue_id: str) -> PlayerQueue | None:
- """Return PlayerQueue by id or None if not found."""
- return self._player_queues.get(queue_id)
-
- def get_player_by_name(self, name: str) -> Player | None:
- """Return Player by name or None if no match is found."""
- return next((x for x in self._players.values() if x.name == name), None)
-
- async def register_player(self, player: Player) -> None:
- """Register a new player on the controller."""
- if self.mass.closed:
- return
- player_id = player.player_id
-
- if player_id in self._players:
- raise AlreadyRegisteredError(f"Player {player_id} is already registered")
-
- # make sure that the mass instance is set on the player
- player.mass = self.mass
- self._players[player_id] = player
-
- # create playerqueue for this player
- self._player_queues[player.player_id] = player_queue = PlayerQueue(
- self.mass, player_id
- )
- await player_queue.setup()
-
- self.logger.info(
- "Player registered: %s/%s",
- player_id,
- player.name,
- )
- self.mass.signal_event(
- MassEvent(EventType.PLAYER_ADDED, object_id=player.player_id, data=player)
- )
-
- async def _poll_players(self) -> None:
- """Poll players every X interval."""
- interval = 30
- cur_tick = 0
- while True:
- for player in self.players:
- if not player.available:
- continue
- if cur_tick == interval:
- self.mass.loop.call_soon(player.update_state)
- elif (
- player.active_queue.queue_id == player.player_id
- and player.active_queue.active
- and player.state == PlayerState.PLAYING
- ):
- self.mass.loop.call_soon(player.active_queue.on_player_update)
- if cur_tick == interval:
- cur_tick = 0
- else:
- cur_tick += 1
- await asyncio.sleep(1)
+++ /dev/null
-"""Controller to stream audio to players."""
-from __future__ import annotations
-
-import asyncio
-import gc
-import os
-import urllib.parse
-from time import time
-from types import CoroutineType
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, Tuple
-from uuid import uuid4
-
-from aiohttp import web
-
-from music_assistant.constants import (
- BASE_URL_OVERRIDE_ENVNAME,
- FALLBACK_DURATION,
- SILENCE_FILE,
-)
-from music_assistant.helpers.audio import (
- check_audio_support,
- crossfade_pcm_parts,
- get_chunksize,
- get_media_stream,
- get_preview_stream,
- get_stream_details,
- strip_silence,
-)
-from music_assistant.helpers.process import AsyncProcess
-from music_assistant.models.enums import (
- ContentType,
- CrossFadeMode,
- EventType,
- MediaType,
- MetadataMode,
- ProviderType,
-)
-from music_assistant.models.errors import MediaNotFoundError, QueueEmpty
-from music_assistant.models.event import MassEvent
-from music_assistant.models.player_queue import PlayerQueue
-from music_assistant.models.queue_item import QueueItem
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-ICY_CHUNKSIZE = 8192
-
-
-class StreamsController:
- """Controller to stream audio to players."""
-
- def __init__(self, mass: MusicAssistant):
- """Initialize instance."""
- self.mass = mass
- self.logger = mass.logger.getChild("stream")
- self._port = mass.config.stream_port
- self._ip = mass.config.stream_ip
- self.queue_streams: Dict[str, QueueStream] = {}
- self.announcements: Dict[str, Tuple[str]] = {}
-
- @property
- def base_url(self) -> str:
- """Return the base url for the stream engine."""
-
- if BASE_URL_OVERRIDE_ENVNAME in os.environ:
- # This is a purpously undocumented feature to override the automatic
- # generated base_url used by the streaming-devices.
- # If you need this, you know it, but you should probably try to not set it!
- # Also see https://github.com/music-assistant/hass-music-assistant/issues/802
- # and https://github.com/music-assistant/hass-music-assistant/discussions/794#discussioncomment-3331209
- return os.environ[BASE_URL_OVERRIDE_ENVNAME]
-
- return f"http://{self._ip}:{self._port}"
-
- def get_stream_url(
- self,
- stream_id: str,
- content_type: ContentType = ContentType.FLAC,
- ) -> str:
- """Generate unique stream url for the PlayerQueue Stream."""
- ext = content_type.value
- return f"{self.base_url}/{stream_id}.{ext}"
-
- def get_announcement_url(
- self,
- queue_id: str,
- urls: Tuple[str],
- content_type: ContentType,
- ) -> str:
- """Get url to announcement stream."""
- self.announcements[queue_id] = urls
- ext = content_type.value
- return f"{self.base_url}/announce/{queue_id}.{ext}"
-
- def get_control_url(self, player_id: str, cmd: str) -> str:
- """Get url to special control stream."""
- return f"{self.base_url}/control/{player_id}/{cmd}.mp3"
-
- async def get_preview_url(self, provider: ProviderType, track_id: str) -> str:
- """Return url to short preview sample."""
- track = await self.mass.music.tracks.get_provider_item(track_id, provider)
- if preview := track.metadata.preview:
- return preview
- enc_track_id = urllib.parse.quote(track_id)
- return (
- f"{self.base_url}/preview?provider={provider.value}&item_id={enc_track_id}"
- )
-
- async def setup(self) -> None:
- """Async initialize of module."""
- app = web.Application()
-
- app.router.add_get("/preview", self.serve_preview)
- app.router.add_get("/announce/{queue_id}.{fmt}", self.serve_announcement)
- app.router.add_get("/control/{player_id}/{cmd}.mp3", self.serve_control)
- app.router.add_get("/{stream_id}.{fmt}", self.serve_queue_stream)
-
- runner = web.AppRunner(app, access_log=None)
- await runner.setup()
- # set host to None to bind to all addresses on both IPv4 and IPv6
- http_site = web.TCPSite(runner, host=None, port=self._port)
- await http_site.start()
-
- async def on_shutdown_event(event: MassEvent):
- """Handle shutdown event."""
- await http_site.stop()
- await runner.cleanup()
- await app.shutdown()
- await app.cleanup()
- self.logger.debug("Streamserver exited.")
-
- self.mass.subscribe(on_shutdown_event, EventType.SHUTDOWN)
-
- ffmpeg_present, libsoxr_support = await check_audio_support(True)
- if not ffmpeg_present:
- self.logger.error(
- "FFmpeg binary not found on your system, playback will NOT work!."
- )
- elif not libsoxr_support:
- self.logger.warning(
- "FFmpeg version found without libsoxr support, "
- "highest quality audio not available. "
- )
-
- self.logger.info("Started stream server on port %s", self._port)
-
- async def serve_preview(self, request: web.Request):
- """Serve short preview sample."""
- provider_mapping = request.query["provider_mapping"]
- item_id = urllib.parse.unquote(request.query["item_id"])
- resp = web.StreamResponse(
- status=200, reason="OK", headers={"Content-Type": "audio/mp3"}
- )
- await resp.prepare(request)
- async for chunk in get_preview_stream(self.mass, provider_mapping, item_id):
- await resp.write(chunk)
- return resp
-
- async def serve_announcement(self, request: web.Request):
- """Serve announcement broadcast."""
- queue_id = request.match_info["queue_id"]
- fmt = ContentType.try_parse(request.match_info["fmt"])
- urls = self.announcements[queue_id]
-
- ffmpeg_args = ["ffmpeg", "-hide_banner", "-loglevel", "quiet"]
- for url in urls:
- ffmpeg_args += ["-i", url]
- if len(urls) > 1:
- ffmpeg_args += [
- "-filter_complex",
- f"[0:a][1:a]concat=n={len(urls)}:v=0:a=1",
- ]
- ffmpeg_args += ["-f", fmt.value, "-"]
-
- async with AsyncProcess(ffmpeg_args) as ffmpeg_proc:
- output, _ = await ffmpeg_proc.communicate()
-
- return web.Response(body=output, headers={"Content-Type": f"audio/{fmt.value}"})
-
- async def serve_control(self, request: web.Request):
- """Serve special control stream."""
- self.logger.debug(
- "Got %s request to %s from %s\nheaders: %s\n",
- request.method,
- request.path,
- request.remote,
- request.headers,
- )
- player_id = request.match_info["player_id"]
- cmd = request.match_info["cmd"]
-
- player = self.mass.players.get_player(player_id)
- if not player:
- return web.Response(status=404)
-
- queue = player.active_queue
-
- if queue and queue.stream:
- # handle next (ignore if signal_next active)
- if cmd == "next" and not queue.stream.signal_next:
- self.mass.create_task(queue.stream.queue.next())
- # handle previous
- elif cmd == "previous":
- self.mass.create_task(queue.stream.queue.previous())
-
- # always respond with silence just to prevent errors
- return web.FileResponse(SILENCE_FILE)
-
- async def serve_queue_stream(self, request: web.Request):
- """Serve queue audio stream to a single player."""
- self.logger.debug(
- "Got %s request to %s from %s\nheaders: %s\n",
- request.method,
- request.path,
- request.remote,
- request.headers,
- )
- client_id = request.query.get("player_id", request.remote)
- stream_id = request.match_info["stream_id"]
- queue_stream = self.queue_streams.get(stream_id)
-
- # try to recover from the situation where the player itself requests
- # a stream that is already done
- if queue_stream is None or queue_stream.done.is_set():
- self.logger.warning(
- "Got stream request for unknown or finished stream: %s",
- stream_id,
- )
- return web.Response(status=404)
-
- # handle a second connection for the same player
- # this probably means a client which does multiple GET requests (e.g. Kodi, Vlc)
- if client_id in queue_stream.connected_clients:
- self.logger.warning(
- "Simultanuous connections detected from %s, playback may be disturbed!",
- client_id,
- )
- client_id += uuid4().hex
- elif queue_stream.all_clients_connected.is_set():
- self.logger.warning(
- "Got stream request for already running stream: %s, playback may be disturbed!",
- stream_id,
- )
-
- # prepare request, add some DLNA/UPNP compatible headers
- headers = {
- "Content-Type": f"audio/{queue_stream.output_format.value}",
- "transferMode.dlna.org": "Streaming",
- "contentFeatures.dlna.org": "DLNA.ORG_OP=00;DLNA.ORG_CI=0;DLNA.ORG_FLAGS=0d500000000000000000000000000000",
- "Cache-Control": "no-cache",
- }
-
- # ICY-metadata headers depend on settings
- metadata_mode = queue_stream.queue.settings.metadata_mode
- if metadata_mode != MetadataMode.DISABLED:
- headers["icy-name"] = "Music Assistant"
- headers["icy-pub"] = "1"
- headers["icy-metaint"] = str(queue_stream.output_chunksize)
-
- resp = web.StreamResponse(headers=headers)
- try:
- await resp.prepare(request)
- except ConnectionResetError:
- return resp
-
- if request.method != "GET":
- # do not start stream on HEAD request
- return resp
-
- enable_icy = request.headers.get("Icy-MetaData", "") == "1"
-
- # regular streaming - each chunk is sent to the callback here
- # this chunk is already encoded to the requested audio format of choice.
- # optional ICY metadata can be sent to the client if it supports that
- async def audio_callback(chunk: bytes) -> None:
- """Call when a new audio chunk arrives."""
- # write audio
- await resp.write(chunk)
-
- # ICY metadata support
- if not enable_icy:
- return
-
- # if icy metadata is enabled, send the icy metadata after the chunk
- item_in_buf = queue_stream.queue.get_item(queue_stream.index_in_buffer)
- if item_in_buf and item_in_buf.name:
- title = item_in_buf.name
- if item_in_buf.image and not item_in_buf.image.is_file:
- image = item_in_buf.media_item.image.url
- else:
- image = ""
- else:
- title = "Music Assistant"
- image = ""
- metadata = f"StreamTitle='{title}';StreamUrl='&picture={image}';".encode()
- while len(metadata) % 16 != 0:
- metadata += b"\x00"
- length = len(metadata)
- length_b = chr(int(length / 16)).encode()
- await resp.write(length_b + metadata)
-
- await queue_stream.subscribe(client_id, audio_callback)
- await resp.write_eof()
-
- return resp
-
- async def start_queue_stream(
- self,
- queue: PlayerQueue,
- start_index: int,
- seek_position: int,
- fade_in: bool,
- output_format: ContentType,
- ) -> QueueStream:
- """Start running a queue stream."""
- # generate unique stream url
- stream_id = uuid4().hex
- # determine the pcm details based on the first track we need to stream
- try:
- first_item = queue.items[start_index]
- except (IndexError, TypeError) as err:
- raise QueueEmpty() from err
-
- streamdetails = await get_stream_details(self.mass, first_item, queue.queue_id)
-
- # work out pcm details
- if queue.settings.crossfade_mode == CrossFadeMode.ALWAYS:
- pcm_sample_rate = min(96000, queue.settings.max_sample_rate)
- pcm_bit_depth = 24
- pcm_channels = 2
- allow_resample = True
- elif streamdetails.sample_rate > queue.settings.max_sample_rate:
- pcm_sample_rate = queue.settings.max_sample_rate
- pcm_bit_depth = streamdetails.bit_depth
- pcm_channels = streamdetails.channels
- allow_resample = True
- else:
- pcm_sample_rate = streamdetails.sample_rate
- pcm_bit_depth = streamdetails.bit_depth
- pcm_channels = streamdetails.channels
- allow_resample = False
-
- self.queue_streams[stream_id] = stream = QueueStream(
- queue=queue,
- stream_id=stream_id,
- start_index=start_index,
- seek_position=seek_position,
- fade_in=fade_in,
- output_format=output_format,
- pcm_sample_rate=pcm_sample_rate,
- pcm_bit_depth=pcm_bit_depth,
- pcm_channels=pcm_channels,
- allow_resample=allow_resample,
- )
- # cleanup stale previous queue tasks
- asyncio.create_task(self.cleanup_stale())
- return stream
-
- async def cleanup_stale(self) -> None:
- """Cleanup stale/done stream tasks."""
- stale = set()
- for stream_id, stream in self.queue_streams.items():
- if stream.done.is_set() and not stream.connected_clients:
- stale.add(stream_id)
- for stream_id in stale:
- self.queue_streams.pop(stream_id, None)
-
-
-class QueueStream:
- """Representation of a (multisubscriber) Audio Queue stream."""
-
- def __init__(
- self,
- queue: PlayerQueue,
- stream_id: str,
- start_index: int,
- seek_position: int,
- fade_in: bool,
- output_format: ContentType,
- pcm_sample_rate: int,
- pcm_bit_depth: int,
- pcm_channels: int = 2,
- pcm_floating_point: bool = False,
- allow_resample: bool = False,
- ):
- """Init QueueStreamJob instance."""
- self.queue = queue
- self.stream_id = stream_id
- self.start_index = start_index
- self.seek_position = seek_position
- self.fade_in = fade_in
- self.output_format = output_format
- self.pcm_sample_rate = pcm_sample_rate
- self.pcm_bit_depth = pcm_bit_depth
- self.pcm_channels = pcm_channels
- self.pcm_floating_point = pcm_floating_point
- self.allow_resample = allow_resample
- self.url = queue.mass.streams.get_stream_url(stream_id, output_format)
-
- self.mass = queue.mass
- self.logger = self.queue.logger.getChild("stream")
- self.expected_clients = 1
- self.connected_clients: Dict[str, CoroutineType[bytes]] = {}
- self.total_seconds_streamed = 0
- self.streaming_started = 0
- self.done = asyncio.Event()
- self.all_clients_connected = asyncio.Event()
- self.index_in_buffer = start_index
- self.signal_next: Optional[int] = None
- self._runner_task: Optional[asyncio.Task] = None
- if queue.settings.metadata_mode == MetadataMode.LEGACY:
- # use the legacy/recommended metaint size of 8192 bytes
- self.output_chunksize = ICY_CHUNKSIZE
- else:
- self.output_chunksize = get_chunksize(
- output_format, pcm_sample_rate, pcm_bit_depth
- )
- self.sample_size_per_second = get_chunksize(
- ContentType.from_bit_depth(pcm_bit_depth, pcm_floating_point),
- pcm_sample_rate,
- pcm_bit_depth,
- pcm_channels,
- )
- self._runner_task = self.mass.create_task(self._queue_stream_runner())
-
- async def stop(self) -> None:
- """Stop running queue stream and cleanup."""
- self.done.set()
- if self._runner_task and not self._runner_task.done():
- self._runner_task.cancel()
- # allow some time to cleanup
- await asyncio.sleep(2)
-
- self._runner_task = None
- self.connected_clients = {}
-
- # run garbage collection manually due to the high number of
- # processed bytes blocks
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(None, gc.collect)
- self.logger.debug("Stream job %s cleaned up", self.stream_id)
-
- async def subscribe(self, client_id: str, callback: CoroutineType[bytes]) -> None:
- """Subscribe callback and wait for completion."""
- assert not self.done.is_set(), "Stream task is already finished"
- # assert client_id not in self.connected_clients, "Client already connected"
-
- self.connected_clients[client_id] = callback
- self.logger.debug("client connected: %s", client_id)
- if len(self.connected_clients) == self.expected_clients:
- self.all_clients_connected.set()
- try:
- await self.done.wait()
- finally:
- self.logger.debug("client disconnected: %s", client_id)
- self.connected_clients.pop(client_id, None)
- await self._check_stop()
-
- async def _queue_stream_runner(self) -> None:
- """Distribute audio chunks over connected client(s)."""
- # collect ffmpeg args
- input_format = ContentType.from_bit_depth(
- self.pcm_bit_depth, self.pcm_floating_point
- )
- ffmpeg_args = [
- "ffmpeg",
- "-hide_banner",
- "-loglevel",
- "quiet",
- "-ignore_unknown",
- # pcm input args
- "-f",
- input_format.value,
- "-ac",
- str(self.pcm_channels),
- "-ar",
- str(self.pcm_sample_rate),
- "-i",
- "-",
- # add metadata
- "-metadata",
- "title=Streaming from Music Assistant",
- ]
- # fade-in if needed
- if self.fade_in:
- ffmpeg_args += ["-af", "afade=t=in:st=0:d=5"]
-
- ffmpeg_args += [
- # output args
- "-f",
- self.output_format.value,
- "-compression_level",
- "0",
- "-",
- ]
- # get the raw pcm bytes from the queue stream and on-the-fly encode to wanted format
- # send the compressed/encoded stream to the client(s).
- async with AsyncProcess(ffmpeg_args, True) as ffmpeg_proc:
-
- async def writer():
- """Task that sends the raw pcm audio to the ffmpeg process."""
- async for audio_chunk in self._get_queue_stream():
- await ffmpeg_proc.write(audio_chunk)
- self.total_seconds_streamed += (
- len(audio_chunk) / self.sample_size_per_second
- )
- # write eof when last packet is received
- ffmpeg_proc.write_eof()
-
- ffmpeg_proc.attach_task(writer())
-
- # wait max 10 seconds for all client(s) to connect
- try:
- await asyncio.wait_for(self.all_clients_connected.wait(), 10)
- except asyncio.exceptions.TimeoutError:
- self.logger.warning(
- "Abort: client(s) did not connect within 10 seconds."
- )
- self.done.set()
- return
- self.logger.debug("%s clients connected", len(self.connected_clients))
- self.streaming_started = time()
-
- # Read bytes from final output and send chunk to child callback.
- chunk_num = 0
- async for chunk in ffmpeg_proc.iter_chunked(self.output_chunksize):
- chunk_num += 1
-
- if len(self.connected_clients) == 0:
- # no more clients
- if await self._check_stop():
- return
- for client_id in set(self.connected_clients.keys()):
- try:
- callback = self.connected_clients[client_id]
- await callback(chunk)
- except (
- ConnectionResetError,
- KeyError,
- BrokenPipeError,
- ):
- self.connected_clients.pop(client_id, None)
-
- # all queue data has been streamed. Either because the queue is exhausted
- # or we need to restart the stream due to decoder/sample rate mismatch
- # set event that this stream task is finished
- # if the stream is restarted by the queue manager afterwards is controlled
- # by the `signal_next` bool above.
- self.done.set()
-
- async def _get_queue_stream(
- self,
- ) -> AsyncGenerator[None, bytes]:
- """Stream the PlayerQueue's tracks as constant feed of PCM raw audio."""
- last_fadeout_part = b""
- queue_index = None
- track_count = 0
- prev_track: Optional[QueueItem] = None
-
- pcm_fmt = ContentType.from_bit_depth(self.pcm_bit_depth)
- self.logger.debug(
- "Starting Queue audio stream for Queue %s (PCM format: %s - sample rate: %s)",
- self.queue.player.name,
- pcm_fmt.value,
- self.pcm_sample_rate,
- )
-
- # stream queue tracks one by one
- while True:
- # get the (next) track in queue
- track_count += 1
- if track_count == 1:
- queue_index = self.start_index
- seek_position = self.seek_position
- else:
- next_index = self.queue.get_next_index(queue_index)
- # break here if next index does not match (e.g. when repeat enabled)!
- if next_index <= queue_index:
- self.signal_next = next_index
- break
- queue_index = next_index
- seek_position = 0
- queue_track = self.queue.get_item(queue_index)
- if not queue_track:
- self.logger.debug(
- "Abort Queue stream %s: no (more) tracks in queue",
- self.queue.queue_id,
- )
- break
- # get streamdetails
- try:
- streamdetails = await get_stream_details(
- self.mass, queue_track, self.queue.queue_id
- )
- except MediaNotFoundError as err:
- self.logger.warning(
- "Skip track %s due to missing streamdetails",
- queue_track.name,
- exc_info=err,
- )
- continue
-
- # check the PCM samplerate/bitrate
- if not self.allow_resample and streamdetails.bit_depth > self.pcm_bit_depth:
- self.signal_next = queue_index
- self.logger.debug(
- "Abort queue stream %s due to bit depth mismatch",
- self.queue.player.name,
- )
- break
- if (
- not self.allow_resample
- and streamdetails.sample_rate > self.pcm_sample_rate
- and streamdetails.sample_rate <= self.queue.settings.max_sample_rate
- ):
- self.logger.debug(
- "Abort queue stream %s due to sample rate mismatch",
- self.queue.player.name,
- )
- self.signal_next = queue_index
- break
-
- # check crossfade ability
- use_crossfade = (
- self.queue.settings.crossfade_mode != CrossFadeMode.DISABLED
- and self.queue.settings.crossfade_duration > 0
- )
- # do not crossfade tracks of same album
- if (
- use_crossfade
- and self.queue.settings.crossfade_mode != CrossFadeMode.ALWAYS
- and prev_track
- and prev_track.media_type == MediaType.TRACK
- and queue_track.media_type == MediaType.TRACK
- ):
- if (
- prev_track.media_item.album is not None
- and queue_track.media_item.album is not None
- and prev_track.media_item.album == queue_track.media_item.album
- ):
- self.logger.debug("Skipping crossfade: Tracks are from same album")
- use_crossfade = False
- prev_track = queue_track
-
- self.logger.info(
- "Start Streaming queue track: %s (%s) for queue %s - crossfade: %s",
- streamdetails.uri,
- queue_track.name,
- self.queue.player.name,
- use_crossfade,
- )
-
- # set some basic vars
- crossfade_duration = self.queue.settings.crossfade_duration
- crossfade_size = int(self.sample_size_per_second * crossfade_duration)
- queue_track.streamdetails.seconds_skipped = seek_position
- # predict total size to expect for this track from duration
- stream_duration = (
- queue_track.duration or FALLBACK_DURATION
- ) - seek_position
- # buffer_duration has some overhead to account for padded silence
- buffer_duration = (crossfade_duration + 4) if use_crossfade else 4
- # send signal that we've loaded a new track into the buffer
- self.index_in_buffer = queue_index
- self.queue.signal_update()
-
- buffer = b""
- bytes_written = 0
- chunk_num = 0
- # handle incoming audio chunks
- track_time_start = time()
- async for chunk in get_media_stream(
- self.mass,
- streamdetails,
- pcm_fmt=pcm_fmt,
- sample_rate=self.pcm_sample_rate,
- channels=self.pcm_channels,
- seek_position=seek_position,
- chunk_size=self.sample_size_per_second,
- ):
-
- chunk_num += 1
- seconds_in_buffer = len(buffer) / self.sample_size_per_second
-
- #### HANDLE FIRST PART OF TRACK
-
- # buffer full for crossfade
- if last_fadeout_part and (seconds_in_buffer >= buffer_duration):
- # strip silence of start
- first_part = await strip_silence(
- buffer + chunk, pcm_fmt, self.pcm_sample_rate
- )
- # perform crossfade
- fadein_part = first_part[:crossfade_size]
- remaining_bytes = first_part[crossfade_size:]
- crossfade_part = await crossfade_pcm_parts(
- fadein_part,
- last_fadeout_part,
- self.pcm_bit_depth,
- self.pcm_sample_rate,
- )
- # send crossfade_part
- yield crossfade_part
- bytes_written += len(crossfade_part)
- # also write the leftover bytes from the strip action
- if remaining_bytes:
- yield remaining_bytes
- bytes_written += len(remaining_bytes)
-
- # clear vars
- last_fadeout_part = b""
- buffer = b""
- continue
-
- # first part of track and we need to crossfade: fill buffer
- if last_fadeout_part:
- buffer += chunk
- continue
-
- # last part of track: fill buffer
- if buffer or (chunk_num >= (stream_duration - buffer_duration)):
- buffer += chunk
- continue
-
- # all other: middle of track or no crossfade action, just yield the audio
- yield chunk
- bytes_written += len(chunk)
- continue
-
- #### HANDLE END OF TRACK
-
- if bytes_written == 0:
- # stream error: got empy first chunk ?!
- self.logger.warning("Stream error on %s", streamdetails.uri)
- queue_track.streamdetails.seconds_streamed = 0
- continue
-
- # try to make a rough assumption of how many seconds is buffered ahead by the player(s)
- player_buffered = (
- self.total_seconds_streamed - self.queue.player.elapsed_time or 0
- )
- seconds_in_buffer = len(buffer) / self.sample_size_per_second
- process_time = round(time() - track_time_start, 2)
- # log warning if received seconds are a lot less than expected
- if (stream_duration - chunk_num) > 20:
- self.logger.warning(
- "Unexpected number of chunks received for track %s: %s/%s - "
- "process_time: %s - player_buffered: %s",
- streamdetails.uri,
- chunk_num,
- stream_duration,
- process_time,
- player_buffered,
- )
- self.logger.debug(
- "end of track reached - chunk_num: %s - crossfade_buffer: %s - "
- "stream_duration: %s - player_buffer: %s - process_time: %s",
- chunk_num,
- seconds_in_buffer,
- stream_duration,
- player_buffered,
- process_time,
- )
-
- if buffer:
- # strip silence from end of audio
- last_part = await strip_silence(
- buffer, pcm_fmt, self.pcm_sample_rate, reverse=True
- )
- if use_crossfade:
- # if crossfade is enabled, save fadeout part to pickup for next track
- if len(last_part) < crossfade_size <= len(buffer):
- # the chunk length is too short after stripping silence, only use first part
- last_fadeout_part = buffer[:crossfade_size]
- elif use_crossfade and len(last_part) > crossfade_size:
- # yield remaining bytes from strip action,
- # we only need the crossfade_size part
- last_fadeout_part = last_part[-crossfade_size:]
- remaining_bytes = last_part[:-crossfade_size]
- yield remaining_bytes
- bytes_written += len(remaining_bytes)
- elif use_crossfade:
- last_fadeout_part = last_part
- else:
- # no crossfade enabled, just yield the stripped audio data
- yield last_part
- bytes_written += len(last_part)
-
- # end of the track reached - store accurate duration
- buffer = b""
- queue_track.streamdetails.seconds_streamed = (
- bytes_written / self.sample_size_per_second
- )
- self.logger.debug(
- "Finished Streaming queue track: %s (%s) on queue %s",
- queue_track.streamdetails.uri,
- queue_track.name,
- self.queue.player.name,
- )
- # end of queue reached, pass last fadeout bits to final output
- if last_fadeout_part:
- yield last_fadeout_part
- # END OF QUEUE STREAM
- self.logger.debug("Queue stream for Queue %s finished.", self.queue.player.name)
-
- async def _check_stop(self) -> bool:
- """Schedule stop of queue stream."""
- # Stop this queue stream when no clients (re)connected within 5 seconds
- for _ in range(0, 10):
- if len(self.connected_clients) > 0:
- return False
- await asyncio.sleep(0.5)
- asyncio.create_task(self.stop())
- return True
+++ /dev/null
-"""Various utils/helpers."""
+++ /dev/null
-# pylint: skip-file
-# fmt: off
-# flake8: noqa
-(lambda __g: [(lambda __mod: [[[None for __g['app_var'], app_var.__name__ in [(lambda index: (lambda __l: [[AV(aap(__l['var'].encode()).decode()) for __l['var'] in [(vars.split('acb2')[__l['index']][::(-1)])]][0] for __l['index'] in [(index)]][0])({}), 'app_var')]][0] for __g['vars'] in [('3YTNyUDOyQTOacb2=EmN5M2YjdzMhljYzYzYhlDMmFGNlVTOmNDZwMzNxYzNacb2=UDMzEGOyADO1QWO5kDNygTMlJGN5QzNzIWOmZTOiVmMacb2yMTNzITNacb2=UDZhJmMldTZ3QTY4IjZ3kTNxYjN0czNwI2YxkTM5MjN')]][0] for __g['aap'] in [(__mod.b64decode)]][0])(__import__('base64', __g, __g, ('b64decode',), 0)) for __g['AV'] in [((lambda b, d: d.get('__metaclass__', getattr(b[0], '__class__', type(b[0])))('AV', b, d))((str,), (lambda __l: [__l for __l['__repr__'], __l['__repr__'].__name__ in [(lambda self: (lambda __l: [__name__ for __l['self'] in [(self)]][0])({}), '__repr__')]][0])({'__module__': __name__})))]][0])(globals())
+++ /dev/null
-"""Various helpers for audio manipulation."""
-from __future__ import annotations
-
-import asyncio
-import logging
-import os
-import re
-import struct
-from io import BytesIO
-from time import time
-from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Tuple
-
-import aiofiles
-from aiohttp import ClientTimeout
-
-from music_assistant.helpers.process import AsyncProcess, check_output
-from music_assistant.helpers.util import create_tempfile
-from music_assistant.models.errors import (
- AudioError,
- MediaNotFoundError,
- MusicAssistantError,
-)
-from music_assistant.models.media_items import ContentType, MediaType, StreamDetails
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
- from music_assistant.models.player_queue import QueueItem
-
-LOGGER = logging.getLogger(__name__)
-
-# pylint:disable=consider-using-f-string
-
-
-async def crossfade_pcm_parts(
- fade_in_part: bytes,
- fade_out_part: bytes,
- bit_depth: int,
- sample_rate: int,
- channels: int = 2,
-) -> bytes:
- """Crossfade two chunks of pcm/raw audio using ffmpeg."""
- sample_size = int(sample_rate * (bit_depth / 8) * channels)
- fmt = ContentType.from_bit_depth(bit_depth)
- # calculate the fade_length from the smallest chunk
- fade_length = min(len(fade_in_part), len(fade_out_part)) / sample_size
- fadeoutfile = create_tempfile()
- async with aiofiles.open(fadeoutfile.name, "wb") as outfile:
- await outfile.write(fade_out_part)
- args = [
- # generic args
- "ffmpeg",
- "-hide_banner",
- "-loglevel",
- "quiet",
- # fadeout part (as file)
- "-acodec",
- fmt.name.lower(),
- "-f",
- fmt.value,
- "-ac",
- str(channels),
- "-ar",
- str(sample_rate),
- "-i",
- fadeoutfile.name,
- # fade_in part (stdin)
- "-acodec",
- fmt.name.lower(),
- "-f",
- fmt.value,
- "-ac",
- str(channels),
- "-ar",
- str(sample_rate),
- "-i",
- "-",
- # filter args
- "-filter_complex",
- f"[0][1]acrossfade=d={fade_length}",
- # output args
- "-f",
- fmt.value,
- "-",
- ]
- async with AsyncProcess(args, True) as proc:
- crossfade_data, _ = await proc.communicate(fade_in_part)
- if crossfade_data:
- LOGGER.debug(
- "crossfaded 2 pcm chunks. fade_in_part: %s - fade_out_part: %s - fade_length: %s seconds",
- len(fade_in_part),
- len(fade_out_part),
- fade_length,
- )
- return crossfade_data
- # no crossfade_data, return original data instead
- LOGGER.debug(
- "crossfade of pcm chunks failed: not enough data? fade_in_part: %s - fade_out_part: %s",
- len(fade_in_part),
- len(fade_out_part),
- )
- return fade_out_part + fade_in_part
-
-
-async def strip_silence(
- audio_data: bytes,
- fmt: ContentType,
- sample_rate: int,
- channels: int = 2,
- reverse=False,
-) -> bytes:
- """Strip silence from (a chunk of) pcm audio."""
- # input args
- args = ["ffmpeg", "-hide_banner", "-loglevel", "quiet"]
- args += [
- "-acodec",
- fmt.name.lower(),
- "-f",
- fmt.value,
- "-ac",
- str(channels),
- "-ar",
- str(sample_rate),
- "-i",
- "-",
- ]
- # filter args
- if reverse:
- args += [
- "-af",
- "areverse,atrim=start=0.2,silenceremove=start_periods=1:start_silence=0.1:start_threshold=0.02,areverse",
- ]
- else:
- args += [
- "-af",
- "atrim=start=0.2,silenceremove=start_periods=1:start_silence=0.1:start_threshold=0.02",
- ]
- # output args
- args += ["-f", fmt.value, "-"]
- async with AsyncProcess(args, True) as proc:
- stripped_data, _ = await proc.communicate(audio_data)
- LOGGER.debug(
- "stripped silence of pcm chunk. size before: %s - after: %s",
- len(audio_data),
- len(stripped_data),
- )
- return stripped_data
-
-
-async def analyze_audio(mass: MusicAssistant, streamdetails: StreamDetails) -> None:
- """Analyze track audio, for now we only calculate EBU R128 loudness."""
-
- if streamdetails.loudness is not None:
- # only when needed we do the analyze job
- return
-
- LOGGER.debug("Start analyzing track %s", streamdetails.uri)
- # calculate BS.1770 R128 integrated loudness with ffmpeg
- started = time()
- input_file = streamdetails.direct or "-"
- proc_args = [
- "ffmpeg",
- "-i",
- input_file,
- "-f",
- streamdetails.content_type.value,
- "-af",
- "ebur128=framelog=verbose",
- "-f",
- "null",
- "-",
- ]
- async with AsyncProcess(
- proc_args,
- enable_stdin=streamdetails.direct is None,
- enable_stdout=False,
- enable_stderr=True,
- ) as ffmpeg_proc:
-
- async def writer():
- """Task that grabs the source audio and feeds it to ffmpeg."""
- music_prov = mass.music.get_provider(streamdetails.provider)
- async for audio_chunk in music_prov.get_audio_stream(streamdetails):
- await ffmpeg_proc.write(audio_chunk)
- if (time() - started) > 300:
- # just in case of endless radio stream etc
- break
- ffmpeg_proc.write_eof()
-
- if streamdetails.direct is None:
- writer_task = ffmpeg_proc.attach_task(writer())
- # wait for the writer task to finish
- await writer_task
-
- _, stderr = await ffmpeg_proc.communicate()
- try:
- loudness_str = (
- stderr.decode()
- .split("Integrated loudness")[1]
- .split("I:")[1]
- .split("LUFS")[0]
- )
- loudness = float(loudness_str.strip())
- except (IndexError, ValueError, AttributeError):
- LOGGER.warning(
- "Could not determine integrated loudness of %s - %s",
- streamdetails.uri,
- stderr.decode() or "received empty value",
- )
- else:
- streamdetails.loudness = loudness
- await mass.music.set_track_loudness(
- streamdetails.item_id, streamdetails.provider, loudness
- )
- LOGGER.debug(
- "Integrated loudness of %s is: %s",
- streamdetails.uri,
- loudness,
- )
-
-
-async def get_stream_details(
- mass: MusicAssistant, queue_item: "QueueItem", queue_id: str = ""
-) -> StreamDetails:
- """
- Get streamdetails for the given QueueItem.
-
- This is called just-in-time when a PlayerQueue wants a MediaItem to be played.
- Do not try to request streamdetails in advance as this is expiring data.
- param media_item: The MediaItem (track/radio) for which to request the streamdetails for.
- param queue_id: Optionally provide the queue_id which will play this stream.
- """
- streamdetails = None
- if queue_item.streamdetails and (time() < (queue_item.streamdetails.expires - 60)):
- # we already have fresh streamdetails, use these
- queue_item.streamdetails.seconds_skipped = None
- queue_item.streamdetails.seconds_streamed = None
- streamdetails = queue_item.streamdetails
- else:
- # fetch streamdetails from provider
- # always request the full item as there might be other qualities available
- full_item = await mass.music.get_item_by_uri(queue_item.uri)
- # sort by quality and check track availability
- for prov_media in sorted(
- full_item.provider_mappings, key=lambda x: x.quality or 0, reverse=True
- ):
- if not prov_media.available:
- continue
- # get streamdetails from provider
- music_prov = mass.music.get_provider(prov_media.provider_id)
- if not music_prov or not music_prov.available:
- continue # provider temporary unavailable ?
- try:
- streamdetails: StreamDetails = await music_prov.get_stream_details(
- prov_media.item_id
- )
- streamdetails.content_type = ContentType(streamdetails.content_type)
- except MusicAssistantError as err:
- LOGGER.warning(str(err))
- else:
- break
-
- if not streamdetails:
- raise MediaNotFoundError(f"Unable to retrieve streamdetails for {queue_item}")
-
- # set queue_id on the streamdetails so we know what is being streamed
- streamdetails.queue_id = queue_id
- # get gain correct / replaygain
- if streamdetails.gain_correct is None:
- loudness, gain_correct = await get_gain_correct(mass, streamdetails)
- streamdetails.gain_correct = gain_correct
- streamdetails.loudness = loudness
- if not streamdetails.duration:
- streamdetails.duration = queue_item.duration
- # make sure that ffmpeg handles mpeg dash streams directly
- if (
- streamdetails.content_type == ContentType.MPEG_DASH
- and streamdetails.data
- and streamdetails.data.startswith("http")
- ):
- streamdetails.direct = streamdetails.data
- # set streamdetails as attribute on the media_item
- # this way the app knows what content is playing
- queue_item.streamdetails = streamdetails
- return streamdetails
-
-
-async def get_gain_correct(
- mass: MusicAssistant, streamdetails: StreamDetails
-) -> Tuple[Optional[float], Optional[float]]:
- """Get gain correction for given queue / track combination."""
- queue = mass.players.get_player_queue(streamdetails.queue_id)
- if not queue or not queue.settings.volume_normalization_enabled:
- return (None, None)
- if streamdetails.gain_correct is not None:
- return (streamdetails.loudness, streamdetails.gain_correct)
- target_gain = queue.settings.volume_normalization_target
- track_loudness = await mass.music.get_track_loudness(
- streamdetails.item_id, streamdetails.provider
- )
- if track_loudness is None:
- # fallback to provider average
- fallback_track_loudness = await mass.music.get_provider_loudness(
- streamdetails.provider
- )
- if fallback_track_loudness is None:
- # fallback to some (hopefully sane) average value for now
- fallback_track_loudness = -8.5
- gain_correct = target_gain - fallback_track_loudness
- else:
- gain_correct = target_gain - track_loudness
- gain_correct = round(gain_correct, 2)
- return (track_loudness, gain_correct)
-
-
-def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None):
- """Generate a wave header from given params."""
- # pylint: disable=no-member
- file = BytesIO()
-
- # Generate format chunk
- format_chunk_spec = b"<4sLHHLLHH"
- format_chunk = struct.pack(
- format_chunk_spec,
- b"fmt ", # Chunk id
- 16, # Size of this chunk (excluding chunk id and this field)
- 1, # Audio format, 1 for PCM
- channels, # Number of channels
- int(samplerate), # Samplerate, 44100, 48000, etc.
- int(samplerate * channels * (bitspersample / 8)), # Byterate
- int(channels * (bitspersample / 8)), # Blockalign
- bitspersample, # 16 bits for two byte samples, etc.
- )
- # Generate data chunk
- # duration = 3600*6.7
- data_chunk_spec = b"<4sL"
- if duration is None:
- # use max value possible
- datasize = 4254768000 # = 6,7 hours at 44100/16
- else:
- # calculate from duration
- numsamples = samplerate * duration
- datasize = int(numsamples * channels * (bitspersample / 8))
- data_chunk = struct.pack(
- data_chunk_spec,
- b"data", # Chunk id
- int(datasize), # Chunk size (excluding chunk id and this field)
- )
- sum_items = [
- # "WAVE" string following size field
- 4,
- # "fmt " + chunk size field + chunk size
- struct.calcsize(format_chunk_spec),
- # Size of data chunk spec + data size
- struct.calcsize(data_chunk_spec) + datasize,
- ]
- # Generate main header
- all_chunks_size = int(sum(sum_items))
- main_header_spec = b"<4sL4s"
- main_header = struct.pack(main_header_spec, b"RIFF", all_chunks_size, b"WAVE")
- # Write all the contents in
- file.write(main_header)
- file.write(format_chunk)
- file.write(data_chunk)
-
- # return file.getvalue(), all_chunks_size + 8
- return file.getvalue()
-
-
-async def get_media_stream(
- mass: MusicAssistant,
- streamdetails: StreamDetails,
- pcm_fmt: ContentType,
- sample_rate: int,
- channels: int = 2,
- seek_position: int = 0,
- chunk_size: int = 64000,
-) -> AsyncGenerator[bytes, None]:
- """Get the PCM audio stream for the given streamdetails."""
- assert pcm_fmt.is_pcm(), "Output format must be a PCM type"
- args = await _get_ffmpeg_args(
- streamdetails,
- pcm_fmt,
- pcm_sample_rate=sample_rate,
- pcm_channels=channels,
- seek_position=seek_position,
- )
- async with AsyncProcess(
- args, enable_stdin=streamdetails.direct is None
- ) as ffmpeg_proc:
-
- LOGGER.debug("start media stream for: %s", streamdetails.uri)
-
- async def writer():
- """Task that grabs the source audio and feeds it to ffmpeg."""
- LOGGER.debug("writer started for %s", streamdetails.uri)
- music_prov = mass.music.get_provider(streamdetails.provider)
- async for audio_chunk in music_prov.get_audio_stream(
- streamdetails, seek_position
- ):
- await ffmpeg_proc.write(audio_chunk)
- # write eof when last packet is received
- ffmpeg_proc.write_eof()
- LOGGER.debug("writer finished for %s", streamdetails.uri)
-
- if streamdetails.direct is None:
- ffmpeg_proc.attach_task(writer())
-
- # yield chunks from stdout
- try:
- async for chunk in ffmpeg_proc.iter_chunked(chunk_size):
- yield chunk
-
- except (asyncio.CancelledError, GeneratorExit) as err:
- LOGGER.debug("media stream aborted for: %s", streamdetails.uri)
- raise err
- else:
- LOGGER.debug("finished media stream for: %s", streamdetails.uri)
- await mass.music.mark_item_played(
- streamdetails.item_id, streamdetails.provider
- )
- finally:
- # report playback
- if streamdetails.callback:
- mass.create_task(streamdetails.callback, streamdetails)
- # send analyze job to background worker
- if streamdetails.loudness is None:
- mass.add_job(
- analyze_audio(mass, streamdetails),
- f"Analyze audio for {streamdetails.uri}",
- )
-
-
-async def get_radio_stream(
- mass: MusicAssistant, url: str, streamdetails: StreamDetails
-) -> AsyncGenerator[bytes, None]:
- """Get radio audio stream from HTTP, including metadata retrieval."""
- headers = {"Icy-MetaData": "1"}
- timeout = ClientTimeout(total=0, connect=30, sock_read=600)
- async with mass.http_session.get(url, headers=headers, timeout=timeout) as resp:
- headers = resp.headers
- meta_int = int(headers.get("icy-metaint", "0"))
- # stream with ICY Metadata
- if meta_int:
- while True:
- audio_chunk = await resp.content.readexactly(meta_int)
- yield audio_chunk
- meta_byte = await resp.content.readexactly(1)
- meta_length = ord(meta_byte) * 16
- meta_data = await resp.content.readexactly(meta_length)
- if not meta_data:
- continue
- meta_data = meta_data.rstrip(b"\0")
- stream_title = re.search(rb"StreamTitle='([^']*)';", meta_data)
- if not stream_title:
- continue
- stream_title = stream_title.group(1).decode()
- if stream_title != streamdetails.stream_title:
- streamdetails.stream_title = stream_title
- if queue := mass.players.get_player_queue(streamdetails.queue_id):
- queue.signal_update()
- # Regular HTTP stream
- else:
- async for chunk in resp.content.iter_any():
- yield chunk
-
-
-async def get_http_stream(
- mass: MusicAssistant,
- url: str,
- streamdetails: StreamDetails,
- seek_position: int = 0,
-) -> AsyncGenerator[bytes, None]:
- """Get audio stream from HTTP."""
- if seek_position:
- assert streamdetails.duration, "Duration required for seek requests"
- # try to get filesize with a head request
- if seek_position and not streamdetails.size:
- async with mass.http_session.head(url) as resp:
- if size := resp.headers.get("Content-Length"):
- streamdetails.size = int(size)
- # headers
- headers = {}
- skip_bytes = 0
- if seek_position and streamdetails.size:
- skip_bytes = int(streamdetails.size / streamdetails.duration * seek_position)
- headers["Range"] = f"bytes={skip_bytes}-"
-
- # start the streaming from http
- buffer = b""
- buffer_all = False
- bytes_received = 0
- timeout = ClientTimeout(total=0, connect=30, sock_read=600)
- async with mass.http_session.get(url, headers=headers, timeout=timeout) as resp:
- is_partial = resp.status == 206
- buffer_all = seek_position and not is_partial
- async for chunk in resp.content.iter_any():
- bytes_received += len(chunk)
- if buffer_all and not skip_bytes:
- buffer += chunk
- continue
- if not is_partial and skip_bytes and bytes_received < skip_bytes:
- continue
- yield chunk
-
- # store size on streamdetails for later use
- if not streamdetails.size:
- streamdetails.size = bytes_received
- if buffer_all:
- skip_bytes = streamdetails.size / streamdetails.duration * seek_position
- yield buffer[:skip_bytes]
-
-
-async def get_file_stream(
- mass: MusicAssistant,
- filename: str,
- streamdetails: StreamDetails,
- seek_position: int = 0,
-) -> AsyncGenerator[bytes, None]:
- """Get audio stream from local accessible file."""
- if seek_position:
- assert streamdetails.duration, "Duration required for seek requests"
- if not streamdetails.size:
- stat = await mass.loop.run_in_executor(None, os.stat, filename)
- streamdetails.size = stat.st_size
- chunk_size = get_chunksize(streamdetails.content_type)
- async with aiofiles.open(streamdetails.data, "rb") as _file:
- if seek_position:
- seek_pos = int(
- (streamdetails.size / streamdetails.duration) * seek_position
- )
- await _file.seek(seek_pos)
- # yield chunks of data from file
- while True:
- data = await _file.read(chunk_size)
- if not data:
- break
- yield data
-
-
-async def check_audio_support(try_install: bool = False) -> Tuple[bool, bool]:
- """Check if ffmpeg is present (with/without libsoxr support)."""
- cache_key = "audio_support_cache"
- if cache := globals().get(cache_key):
- return cache
-
- # check for FFmpeg presence
- returncode, output = await check_output("ffmpeg -version")
- ffmpeg_present = returncode == 0 and "FFmpeg" in output.decode()
- if not ffmpeg_present and try_install:
- # try a few common ways to install ffmpeg
- # this all assumes we have enough rights and running on a linux based platform (or docker)
- await check_output("apt-get update && apt-get install ffmpeg")
- await check_output("apk add ffmpeg")
- # test again
- returncode, output = await check_output("ffmpeg -version")
- ffmpeg_present = returncode == 0 and "FFmpeg" in output.decode()
-
- # use globals as in-memory cache
- libsoxr_support = "enable-libsoxr" in output.decode()
- result = (ffmpeg_present, libsoxr_support)
- globals()[cache_key] = result
- return result
-
-
-async def get_preview_stream(
- mass: MusicAssistant,
- provider_mapping: str,
- track_id: str,
-) -> AsyncGenerator[bytes, None]:
- """Create a 30 seconds preview audioclip for the given streamdetails."""
- music_prov = mass.music.get_provider(provider_mapping)
-
- streamdetails = await music_prov.get_stream_details(track_id)
-
- input_args = [
- "ffmpeg",
- "-hide_banner",
- "-loglevel",
- "quiet",
- "-ignore_unknown",
- ]
- if streamdetails.direct:
- input_args += ["-ss", "30", "-i", streamdetails.direct]
- else:
- # the input is received from pipe/stdin
- if streamdetails.content_type != ContentType.UNKNOWN:
- input_args += ["-f", streamdetails.content_type.value]
- input_args += ["-i", "-"]
-
- output_args = ["-to", "30", "-f", "mp3", "-"]
- args = input_args + output_args
- async with AsyncProcess(args, True) as ffmpeg_proc:
-
- async def writer():
- """Task that grabs the source audio and feeds it to ffmpeg."""
- music_prov = mass.music.get_provider(streamdetails.provider)
- async for audio_chunk in music_prov.get_audio_stream(streamdetails, 30):
- await ffmpeg_proc.write(audio_chunk)
- # write eof when last packet is received
- ffmpeg_proc.write_eof()
-
- if not streamdetails.direct:
- ffmpeg_proc.attach_task(writer())
-
- # yield chunks from stdout
- async for chunk in ffmpeg_proc.iter_any():
- yield chunk
-
-
-async def get_silence(
- duration: int,
- output_fmt: ContentType = ContentType.WAV,
- sample_rate: int = 44100,
- bit_depth: int = 16,
- channels: int = 2,
-) -> AsyncGenerator[bytes, None]:
- """Create stream of silence, encoded to format of choice."""
-
- # wav silence = just zero's
- if output_fmt == ContentType.WAV:
- yield create_wave_header(
- samplerate=sample_rate,
- channels=2,
- bitspersample=bit_depth,
- duration=duration,
- )
- for _ in range(0, duration):
- yield b"\0" * int(sample_rate * (bit_depth / 8) * channels)
- return
-
- # use ffmpeg for all other encodings
- args = [
- "ffmpeg",
- "-hide_banner",
- "-loglevel",
- "quiet",
- "-f",
- "lavfi",
- "-i",
- f"anullsrc=r={sample_rate}:cl={'stereo' if channels == 2 else 'mono'}",
- "-t",
- str(duration),
- "-f",
- output_fmt.value,
- "-",
- ]
- async with AsyncProcess(args) as ffmpeg_proc:
- async for chunk in ffmpeg_proc.iter_any():
- yield chunk
-
-
-def get_chunksize(
- content_type: ContentType,
- sample_rate: int = 44100,
- bit_depth: int = 16,
- channels: int = 2,
- seconds: int = 1,
-) -> int:
- """Get a default chunksize for given contenttype."""
- pcm_size = int(sample_rate * (bit_depth / 8) * channels * seconds)
- if content_type.is_pcm() or content_type == ContentType.WAV:
- return pcm_size
- if content_type in (ContentType.WAV, ContentType.AIFF, ContentType.DSF):
- return pcm_size
- if content_type in (ContentType.FLAC, ContentType.WAVPACK, ContentType.ALAC):
- return int(pcm_size * 0.6)
- if content_type in (ContentType.MP3, ContentType.OGG, ContentType.M4A):
- return int(640000 * seconds)
- return 32000 * seconds
-
-
-async def _get_ffmpeg_args(
- streamdetails: StreamDetails,
- pcm_output_format: ContentType,
- pcm_sample_rate: int,
- pcm_channels: int = 2,
- seek_position: int = 0,
-) -> List[str]:
- """Collect all args to send to the ffmpeg process."""
- input_format = streamdetails.content_type
- assert pcm_output_format.is_pcm(), "Output format needs to be PCM"
-
- ffmpeg_present, libsoxr_support = await check_audio_support()
-
- if not ffmpeg_present:
- raise AudioError(
- "FFmpeg binary is missing from system."
- "Please install ffmpeg on your OS to enable playback.",
- )
- # collect input args
- input_args = [
- "ffmpeg",
- "-hide_banner",
- "-loglevel",
- "quiet",
- "-ignore_unknown",
- ]
- if streamdetails.direct:
- # ffmpeg can access the inputfile (or url) directly
- if streamdetails.direct.startswith("http"):
- # append reconnect options for direct stream from http
- input_args += [
- "-reconnect",
- "1",
- "-reconnect_streamed",
- "1",
- "-reconnect_on_network_error",
- "1",
- "-reconnect_on_http_error",
- "5xx",
- "-reconnect_delay_max",
- "10",
- ]
- if seek_position:
- input_args += ["-ss", str(seek_position)]
- input_args += ["-i", streamdetails.direct]
- else:
- # the input is received from pipe/stdin
- if streamdetails.content_type != ContentType.UNKNOWN:
- input_args += ["-f", input_format.value]
- input_args += ["-i", "-"]
-
- # collect output args
- output_args = [
- "-acodec",
- pcm_output_format.name.lower(),
- "-f",
- pcm_output_format.value,
- "-ac",
- str(pcm_channels),
- "-ar",
- str(pcm_sample_rate),
- "-",
- ]
- # collect extra and filter args
- extra_args = []
- filter_params = []
- if streamdetails.gain_correct is not None:
- filter_params.append(f"volume={streamdetails.gain_correct}dB")
- if (
- streamdetails.sample_rate != pcm_sample_rate
- and libsoxr_support
- and streamdetails.media_type == MediaType.TRACK
- ):
- # prefer libsoxr high quality resampler (if present) for sample rate conversions
- filter_params.append("aresample=resampler=soxr")
- if filter_params:
- extra_args += ["-af", ",".join(filter_params)]
-
- return input_args + extra_args + output_args
+++ /dev/null
-"""Several helper/utils to compare objects."""
-from __future__ import annotations
-
-from typing import List, Union
-
-from music_assistant.helpers.util import create_safe_string, create_sort_name
-from music_assistant.models.enums import AlbumType
-from music_assistant.models.media_items import (
- Album,
- Artist,
- ItemMapping,
- MediaItem,
- MediaItemMetadata,
- Track,
-)
-
-
-def loose_compare_strings(base: str, alt: str) -> bool:
- """Compare strings and return True even on partial match."""
- # this is used to display 'versions' of the same track/album
- # where we account for other spelling or some additional wording in the title
- word_count = len(base.split(" "))
- if word_count == 1 and len(base) < 10:
- return compare_strings(base, alt, False)
- base_comp = create_safe_string(base)
- alt_comp = create_safe_string(alt)
- if base_comp in alt_comp:
- return True
- if alt_comp in base_comp:
- return True
- return False
-
-
-def compare_strings(str1: str, str2: str, strict: bool = True) -> bool:
- """Compare strings and return True if we have an (almost) perfect match."""
- if str1 is None or str2 is None:
- return False
- # return early if total length mismatch
- if abs(len(str1) - len(str2)) > 2:
- return False
- if not strict:
- return create_safe_string(str1) == create_safe_string(str2)
- return create_sort_name(str1) == create_sort_name(str2)
-
-
-def compare_version(left_version: str, right_version: str) -> bool:
- """Compare version string."""
- if not left_version and not right_version:
- return True
- if not left_version and right_version:
- return False
- if left_version and not right_version:
- return False
- if " " not in left_version:
- return compare_strings(left_version, right_version)
- # do this the hard way as sometimes the version string is in the wrong order
- left_versions = left_version.lower().split(" ").sort()
- right_versions = right_version.lower().split(" ").sort()
- return left_versions == right_versions
-
-
-def compare_explicit(left: MediaItemMetadata, right: MediaItemMetadata) -> bool:
- """Compare if explicit is same in metadata."""
- if left.explicit is None or right.explicit is None:
- # explicitness info is not always present in metadata
- # only strict compare them if both have the info set
- return True
- return left == right
-
-
-def compare_artist(
- left_artist: Union[Artist, ItemMapping],
- right_artist: Union[Artist, ItemMapping],
-) -> bool:
- """Compare two artist items and return True if they match."""
- if left_artist is None or right_artist is None:
- return False
- # return early on exact item_id match
- if compare_item_ids(left_artist, right_artist):
- return True
-
- # prefer match on musicbrainz_id
- if getattr(left_artist, "musicbrainz_id", None) and getattr(
- right_artist, "musicbrainz_id", None
- ):
- return left_artist.musicbrainz_id == right_artist.musicbrainz_id
-
- # fallback to comparing
- return compare_strings(left_artist.name, right_artist.name, False)
-
-
-def compare_artists(
- left_artists: List[Union[Artist, ItemMapping]],
- right_artists: List[Union[Artist, ItemMapping]],
- any_match: bool = False,
-) -> bool:
- """Compare two lists of artist and return True if both lists match (exactly)."""
- matches = 0
- for left_artist in left_artists:
- for right_artist in right_artists:
- if compare_artist(left_artist, right_artist):
- if any_match:
- return True
- matches += 1
- return len(left_artists) == matches
-
-
-def compare_item_ids(
- left_item: Union[MediaItem, ItemMapping], right_item: Union[MediaItem, ItemMapping]
-) -> bool:
- """Compare item_id(s) of two media items."""
- if (
- left_item.provider == right_item.provider
- and left_item.item_id == right_item.item_id
- ):
- return True
-
- left_prov_ids = getattr(left_item, "provider_mappings", None)
- right_prov_ids = getattr(right_item, "provider_mappings", None)
-
- if left_prov_ids is not None:
- for prov_l in left_item.provider_mappings:
- if (
- prov_l.provider_type == right_item.provider
- and prov_l.item_id == right_item.item_id
- ):
- return True
-
- if right_prov_ids is not None:
- for prov_r in right_item.provider_mappings:
- if (
- prov_r.provider_type == left_item.provider
- and prov_r.item_id == left_item.item_id
- ):
- return True
-
- if left_prov_ids is not None and right_prov_ids is not None:
- for prov_l in left_item.provider_mappings:
- for prov_r in right_item.provider_mappings:
- if prov_l.provider_type != prov_r.provider_type:
- continue
- if prov_l.item_id == prov_r.item_id:
- return True
- return False
-
-
-def compare_albums(
- left_albums: List[Union[Album, ItemMapping]],
- right_albums: List[Union[Album, ItemMapping]],
-):
- """Compare two lists of albums and return True if a match was found."""
- for left_album in left_albums:
- for right_album in right_albums:
- if compare_album(left_album, right_album):
- return True
- return False
-
-
-def compare_album(
- left_album: Union[Album, ItemMapping],
- right_album: Union[Album, ItemMapping],
-):
- """Compare two album items and return True if they match."""
- if left_album is None or right_album is None:
- return False
- # return early on exact item_id match
- if compare_item_ids(left_album, right_album):
- return True
-
- # prefer match on UPC
- if getattr(left_album, "upc", None) and getattr(right_album, "upc", None):
- if (left_album.upc in right_album.upc) or (right_album.upc in left_album.upc):
- return True
- # prefer match on musicbrainz_id
- # not present on ItemMapping
- if getattr(left_album, "musicbrainz_id", None) and getattr(
- right_album, "musicbrainz_id", None
- ):
- return left_album.musicbrainz_id == right_album.musicbrainz_id
-
- # fallback to comparing
- if not compare_strings(left_album.name, right_album.name, False):
- return False
- if not compare_version(left_album.version, right_album.version):
- return False
- # compare album artist
- # Note: Not present on ItemMapping
- if hasattr(left_album, "artist") and hasattr(right_album, "artist"):
- if not compare_artist(left_album.artist, right_album.artist):
- return False
- return left_album.sort_name == right_album.sort_name
-
-
-def compare_track(left_track: Track, right_track: Track):
- """Compare two track items and return True if they match."""
- if left_track is None or right_track is None:
- return False
- # return early on exact item_id match
- if compare_item_ids(left_track, right_track):
- return True
- for left_isrc in left_track.isrcs:
- for right_isrc in right_track.isrcs:
- # ISRC is always 100% accurate match
- if left_isrc == right_isrc:
- return True
- if left_track.musicbrainz_id and right_track.musicbrainz_id:
- if left_track.musicbrainz_id == right_track.musicbrainz_id:
- # musicbrainz_id is always 100% accurate match
- return True
- # album is required for track linking
- if left_track.album is None or right_track.album is None:
- return False
- # track name must match
- if not compare_strings(left_track.name, right_track.name, False):
- return False
- # exact albumtrack match = 100% match
- if (
- compare_album(left_track.album, right_track.album)
- and left_track.track_number
- and right_track.track_number
- and left_track.disc_number == right_track.disc_number
- and left_track.track_number == right_track.track_number
- ):
- return True
- # track version must match
- if not compare_version(left_track.version, right_track.version):
- return False
- # track artist(s) must match
- if not compare_artists(left_track.artists, right_track.artists):
- return False
- # track if both tracks are (not) explicit
- if not compare_explicit(left_track.metadata, right_track.metadata):
- return False
- # exact album match = 100% match
- if left_track.albums and right_track.albums:
- for left_album in left_track.albums:
- for right_album in right_track.albums:
- if compare_album(left_album, right_album):
- return True
- # fallback: both albums are compilations and (near-exact) track duration match
- if (
- abs(left_track.duration - right_track.duration) <= 2
- and left_track.album.album_type in (AlbumType.UNKNOWN, AlbumType.COMPILATION)
- and right_track.album.album_type in (AlbumType.UNKNOWN, AlbumType.COMPILATION)
- ):
- return True
- return False
+++ /dev/null
-"""Helpers for date and time."""
-from __future__ import annotations
-
-import datetime
-
-LOCAL_TIMEZONE = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo
-
-
-def utc() -> datetime.datetime:
- """Get current UTC datetime."""
- return datetime.datetime.now(datetime.timezone.utc)
-
-
-def utc_timestamp() -> float:
- """Return UTC timestamp in seconds as float."""
- return utc().timestamp()
-
-
-def now() -> datetime.datetime:
- """Get current datetime in local timezone."""
- return datetime.datetime.now(LOCAL_TIMEZONE)
-
-
-def now_timestamp() -> float:
- """Return current datetime as timestamp in local timezone."""
- return now().timestamp()
-
-
-def future_timestamp(**kwargs) -> float:
- """Return current timestamp + timedelta."""
- return (now() + datetime.timedelta(**kwargs)).timestamp()
-
-
-def from_utc_timestamp(timestamp: float) -> datetime.datetime:
- """Return datetime from UTC timestamp."""
- return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc)
-
-
-def iso_from_utc_timestamp(timestamp: float) -> str:
- """Return ISO 8601 datetime string from UTC timestamp."""
- return from_utc_timestamp(timestamp).isoformat()
+++ /dev/null
-"""Utilities for image manipulation and retrieval."""
-from __future__ import annotations
-
-import random
-from io import BytesIO
-from typing import TYPE_CHECKING, List, Optional
-
-from PIL import Image
-
-from music_assistant.controllers.database import TABLE_THUMBS
-from music_assistant.helpers.tags import get_embedded_image
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-async def get_image_data(mass: MusicAssistant, path: str) -> bytes:
- """Create thumbnail from image url."""
- # return from db if exists
- match = {"path": path, "size": 0}
- if result := await mass.database.get_row(TABLE_THUMBS, match):
- return result["data"]
- # always try ffmpeg first to get the image because it supports
- # both online and offline image files as well as embedded images in media files
- img_data = await get_embedded_image(path)
- if img_data:
- return img_data
- # assume file from file provider, we need to fetch it here...
- for prov in mass.music.providers:
- if not prov.type.is_file():
- continue
- if not await prov.exists(path):
- continue
- path = await prov.resolve(path)
- img_data = await get_embedded_image(path)
- if img_data:
- return img_data
- raise FileNotFoundError(f"Image not found: {path}")
-
-
-async def create_thumbnail(
- mass: MusicAssistant, path: str, size: Optional[int]
-) -> bytes:
- """Create thumbnail from image url."""
- img_data = await get_image_data(mass, path)
-
- def _create_image():
- data = BytesIO(img_data)
- img = Image.open(data)
- if size:
- img.thumbnail((size, size), Image.ANTIALIAS)
- img.convert("RGB").save(data, "PNG", optimize=True)
- return data.getvalue()
-
- return await mass.loop.run_in_executor(None, _create_image)
-
-
-async def create_collage(mass: MusicAssistant, images: List[str]) -> bytes:
- """Create a basic collage image from multiple image urls."""
-
- def _new_collage():
- return Image.new("RGBA", (1500, 1500), color=(255, 255, 255, 255))
-
- collage = await mass.loop.run_in_executor(None, _new_collage)
-
- def _add_to_collage(img_data: bytes, coord_x: int, coord_y: int):
- data = BytesIO(img_data)
- photo = Image.open(data).convert("RGBA")
- photo = photo.resize((500, 500))
- collage.paste(photo, (coord_x, coord_y))
-
- for x_co in range(0, 1500, 500):
- for y_co in range(0, 1500, 500):
- img_data = await get_image_data(mass, random.choice(images))
- await mass.loop.run_in_executor(None, _add_to_collage, img_data, x_co, y_co)
-
- def _save_collage():
- final_data = BytesIO()
- collage.convert("RGB").save(final_data, "PNG", optimize=True)
- return final_data.getvalue()
-
- return await mass.loop.run_in_executor(None, _save_collage)
+++ /dev/null
-"""Various helpers for web requests."""
-from __future__ import annotations
-
-import asyncio
-import json
-
-
-def serialize_values(obj):
- """Recursively create serializable values for (custom) data types."""
-
- def get_val(val):
- if (
- isinstance(val, (list, set, filter, tuple))
- or val.__class__ == "dict_valueiterator"
- ):
- return [get_val(x) for x in val] if val else []
- if isinstance(val, dict):
- return {key: get_val(value) for key, value in val.items()}
- try:
- return val.to_dict()
- except AttributeError:
- return val
- except Exception: # pylint: disable=broad-except
- return val
-
- return get_val(obj)
-
-
-def json_serializer(data):
- """Json serializer to recursively create serializable values for custom data types."""
- return json.dumps(serialize_values(data))
-
-
-async def async_json_serializer(data):
- """Run json serializer in executor for large data."""
- if isinstance(data, list) and len(data) > 100:
- return await asyncio.get_running_loop().run_in_executor(
- None, json_serializer, data
- )
- return json_serializer(data)
+++ /dev/null
-"""Helpers for parsing playlists."""
-from __future__ import annotations
-
-import asyncio
-import logging
-from typing import TYPE_CHECKING, List
-
-import aiohttp
-
-from music_assistant.models.errors import InvalidDataError
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-LOGGER = logging.getLogger(__name__)
-
-
-async def parse_m3u(m3u_data: str) -> List[str]:
- """Parse (only) filenames/urls from m3u playlist file."""
- m3u_lines = m3u_data.splitlines()
- lines = []
- for line in m3u_lines:
- line = line.strip()
- if line.startswith("#"):
- # ignore metadata
- continue
- if len(line) != 0:
- # Get uri/path from all other, non-blank lines
- lines.append(line)
-
- return lines
-
-
-async def parse_pls(pls_data: str) -> List[str]:
- """Parse (only) filenames/urls from pls playlist file."""
- pls_lines = pls_data.splitlines()
- lines = []
- for line in pls_lines:
- line = line.strip()
- if not line.startswith("File"):
- # ignore metadata lines
- continue
- if "=" in line:
- # Get uri/path from all other, non-blank lines
- lines.append(line.split("=")[1])
-
- return lines
-
-
-async def fetch_playlist(mass: MusicAssistant, url: str) -> List[str]:
- """Parse an online m3u or pls playlist."""
-
- try:
- async with mass.http_session.get(url, timeout=5) as resp:
- charset = resp.charset or "utf-8"
- try:
- playlist_data = (await resp.content.read(64 * 1024)).decode(charset)
- except ValueError as err:
- raise InvalidDataError(f"Could not decode playlist {url}") from err
- except asyncio.TimeoutError as err:
- raise InvalidDataError(f"Timeout while fetching playlist {url}") from err
- except aiohttp.client_exceptions.ClientError as err:
- raise InvalidDataError(f"Error while fetching playlist {url}") from err
-
- if url.endswith(".m3u") or url.endswith(".m3u8"):
- playlist = await parse_m3u(playlist_data)
- else:
- playlist = await parse_pls(playlist_data)
-
- if not playlist:
- raise InvalidDataError(f"Empty playlist {url}")
-
- return playlist
+++ /dev/null
-"""
-Implementation of a (truly) non blocking subprocess.
-
-The subprocess implementation in asyncio can (still) sometimes cause deadlocks,
-even when properly handling reading/writes from different tasks.
-"""
-from __future__ import annotations
-
-import asyncio
-import logging
-from typing import AsyncGenerator, Coroutine, List, Optional, Tuple, Union
-
-from async_timeout import timeout as _timeout
-
-LOGGER = logging.getLogger(__name__)
-
-DEFAULT_CHUNKSIZE = 128000
-DEFAULT_TIMEOUT = 600
-DEFAULT_LIMIT = 64 * 1024 * 1024
-
-# pylint: disable=invalid-name
-
-
-class AsyncProcess:
- """Implementation of a (truly) non blocking subprocess."""
-
- def __init__(
- self,
- args: Union[List, str],
- enable_stdin: bool = False,
- enable_stdout: bool = True,
- enable_stderr: bool = False,
- ):
- """Initialize."""
- self._proc = None
- self._args = args
- self._enable_stdin = enable_stdin
- self._enable_stdout = enable_stdout
- self._enable_stderr = enable_stderr
- self._attached_task: asyncio.Task = None
- self.closed = False
-
- async def __aenter__(self) -> "AsyncProcess":
- """Enter context manager."""
- if "|" in self._args:
- args = " ".join(self._args)
- else:
- args = self._args
- if isinstance(args, str):
- self._proc = await asyncio.create_subprocess_shell(
- args,
- stdin=asyncio.subprocess.PIPE if self._enable_stdin else None,
- stdout=asyncio.subprocess.PIPE if self._enable_stdout else None,
- stderr=asyncio.subprocess.PIPE if self._enable_stderr else None,
- limit=DEFAULT_LIMIT,
- close_fds=True,
- )
- else:
- self._proc = await asyncio.create_subprocess_exec(
- *args,
- stdin=asyncio.subprocess.PIPE if self._enable_stdin else None,
- stdout=asyncio.subprocess.PIPE if self._enable_stdout else None,
- stderr=asyncio.subprocess.PIPE if self._enable_stderr else None,
- limit=DEFAULT_LIMIT,
- close_fds=True,
- )
-
- # Fix BrokenPipeError due to a race condition
- # by attaching a default done callback
- def _done_cb(fut: asyncio.Future):
- fut.exception()
-
- self._proc._transport._protocol._stdin_closed.add_done_callback(_done_cb)
-
- return self
-
- async def __aexit__(self, exc_type, exc_value, traceback) -> bool:
- """Exit context manager."""
- self.closed = True
- if self._attached_task:
- # cancel the attached reader/writer task
- try:
- self._attached_task.cancel()
- await self._attached_task
- except asyncio.CancelledError:
- pass
- if self._proc.returncode is None:
- # prevent subprocess deadlocking, read remaining bytes
- await self._proc.communicate()
- if self._enable_stdout and not self._proc.stdout.at_eof():
- await self._proc.stdout.read()
- if self._enable_stderr and not self._proc.stderr.at_eof():
- await self._proc.stderr.read()
- if self._proc.returncode is None:
- # just in case?
- self._proc.kill()
-
- async def iter_chunked(
- self, n: int = DEFAULT_CHUNKSIZE
- ) -> AsyncGenerator[bytes, None]:
- """Yield chunks of n size from the process stdout."""
- while True:
- chunk = await self.readexactly(n)
- if chunk == b"":
- break
- yield chunk
- if len(chunk) < n:
- break
-
- async def iter_any(self, n: int = DEFAULT_CHUNKSIZE) -> AsyncGenerator[bytes, None]:
- """Yield chunks as they come in from process stdout."""
- while True:
- chunk = await self.read(n)
- if chunk == b"":
- break
- yield chunk
-
- async def readexactly(self, n: int, timeout: int = DEFAULT_TIMEOUT) -> bytes:
- """Read exactly n bytes from the process stdout (or less if eof)."""
- try:
- async with _timeout(timeout):
- return await self._proc.stdout.readexactly(n)
- except asyncio.IncompleteReadError as err:
- return err.partial
-
- async def read(self, n: int, timeout: int = DEFAULT_TIMEOUT) -> bytes:
- """
- Read up to n bytes from the stdout stream.
-
- If n is positive, this function try to read n bytes,
- and may return less or equal bytes than requested, but at least one byte.
- If EOF was received before any byte is read, this function returns empty byte object.
- """
- async with _timeout(timeout):
- return await self._proc.stdout.read(n)
-
- async def write(self, data: bytes) -> None:
- """Write data to process stdin."""
- if self.closed or self._proc.stdin.is_closing():
- raise asyncio.CancelledError()
- self._proc.stdin.write(data)
- await self._proc.stdin.drain()
-
- def write_eof(self) -> None:
- """Write end of file to to process stdin."""
- try:
- if self._proc.stdin.can_write_eof():
- self._proc.stdin.write_eof()
- except (
- AttributeError,
- AssertionError,
- BrokenPipeError,
- RuntimeError,
- ConnectionResetError,
- ):
- # already exited, race condition
- return
-
- async def communicate(
- self, input_data: Optional[bytes] = None
- ) -> Tuple[bytes, bytes]:
- """Write bytes to process and read back results."""
- return await self._proc.communicate(input_data)
-
- def attach_task(self, coro: Coroutine) -> asyncio.Task:
- """Attach given coro func as reader/writer task to properly cancel it when needed."""
- self._attached_task = task = asyncio.create_task(coro)
- return task
-
-
-async def check_output(shell_cmd: str) -> Tuple[int, bytes]:
- """Run shell subprocess and return output."""
- proc = await asyncio.create_subprocess_shell(
- shell_cmd,
- stderr=asyncio.subprocess.STDOUT,
- stdout=asyncio.subprocess.PIPE,
- )
- stdout, _ = await proc.communicate()
- return (proc.returncode, stdout)
+++ /dev/null
-"""Helpers/utilities to parse ID3 tags from audio files with ffmpeg."""
-from __future__ import annotations
-
-import json
-import os
-from dataclasses import dataclass
-from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union
-
-from requests import JSONDecodeError
-
-from music_assistant.constants import UNKNOWN_ARTIST
-from music_assistant.helpers.process import AsyncProcess
-from music_assistant.helpers.util import try_parse_int
-from music_assistant.models.errors import InvalidDataError
-
-# the only multi-item splitter we accept is the semicolon,
-# which is also the default in Musicbrainz Picard.
-# the slash is also a common splitter but causes colissions with
-# artists actually containing a slash in the name, such as ACDC
-TAG_SPLITTER = ";"
-
-
-def split_items(org_str: str) -> Tuple[str]:
- """Split up a tags string by common splitter."""
- if not org_str:
- return tuple()
- if isinstance(org_str, list):
- return org_str
- return tuple(x.strip() for x in org_str.split(TAG_SPLITTER))
-
-
-def split_artists(org_artists: Union[str, Tuple[str]]) -> Tuple[str]:
- """Parse all artists from a string."""
- final_artists = set()
- # when not using the multi artist tag, the artist string may contain
- # multiple artistsin freeform, even featuring artists may be included in this
- # string. Try to parse the featuring artists and seperate them.
- splitters = ("featuring", " feat. ", " feat ", "feat.")
- for item in split_items(org_artists):
- for splitter in splitters:
- for subitem in item.split(splitter):
- final_artists.add(subitem.strip())
- return tuple(final_artists)
-
-
-@dataclass
-class AudioTags:
- """Audio metadata parsed from an audio file."""
-
- raw: Dict[str, Any]
- sample_rate: int
- channels: int
- bits_per_sample: int
- format: str
- bit_rate: int
- duration: Optional[int]
- tags: Dict[str, str]
- has_cover_image: bool
- filename: str
-
- @property
- def title(self) -> str:
- """Return title tag (as-is)."""
- if tag := self.tags.get("title"):
- return tag
- # fallback to parsing from filename
- title = self.filename.rsplit(os.sep, 1)[-1].split(".")[0]
- if " - " in title:
- title_parts = title.split(" - ")
- if len(title_parts) >= 2:
- return title_parts[1].strip()
- return title
-
- @property
- def album(self) -> str:
- """Return album tag (as-is) if present."""
- return self.tags.get("album")
-
- @property
- def artists(self) -> Tuple[str]:
- """Return track artists."""
- # prefer multi-artist tag
- if tag := self.tags.get("artists"):
- return split_items(tag)
- # fallback to regular artist string
- if tag := self.tags.get("artist"):
- if ";" in tag:
- return split_items(tag)
- return split_artists(tag)
- # fallback to parsing from filename
- title = self.filename.rsplit(os.sep, 1)[-1].split(".")[0]
- if " - " in title:
- title_parts = title.split(" - ")
- if len(title_parts) >= 2:
- return split_artists(title_parts[0])
- return (UNKNOWN_ARTIST,)
-
- @property
- def album_artists(self) -> Tuple[str]:
- """Return (all) album artists (if any)."""
- # prefer multi-artist tag
- if tag := self.tags.get("albumartists"):
- return split_items(tag)
- # fallback to regular artist string
- if tag := self.tags.get("albumartist"):
- if ";" in tag:
- return split_items(tag)
- return split_artists(tag)
- return tuple()
-
- @property
- def genres(self) -> Tuple[str]:
- """Return (all) genres, if any."""
- return split_items(self.tags.get("genre"))
-
- @property
- def disc(self) -> int | None:
- """Return disc tag if present."""
- if tag := self.tags.get("disc"):
- return try_parse_int(tag.split("/")[0], None)
- return None
-
- @property
- def track(self) -> int | None:
- """Return track tag if present."""
- if tag := self.tags.get("track"):
- return try_parse_int(tag.split("/")[0], None)
- return None
-
- @property
- def year(self) -> int | None:
- """Return album's year if present, parsed from date."""
- if tag := self.tags.get("originalyear"):
- return try_parse_int(tag.split("-")[0], None)
- if tag := self.tags.get("originaldate"):
- return try_parse_int(tag.split("-")[0], None)
- if tag := self.tags.get("date"):
- return try_parse_int(tag.split("-")[0], None)
- return None
-
- @property
- def musicbrainz_artistids(self) -> Tuple[str]:
- """Return musicbrainz_artistid tag(s) if present."""
- return split_items(self.tags.get("musicbrainzartistid"))
-
- @property
- def musicbrainz_albumartistids(self) -> Tuple[str]:
- """Return musicbrainz_albumartistid tag if present."""
- return split_items(self.tags.get("musicbrainzalbumartistid"))
-
- @property
- def musicbrainz_releasegroupid(self) -> str | None:
- """Return musicbrainz_releasegroupid tag if present."""
- return self.tags.get("musicbrainzreleasegroupid")
-
- @property
- def musicbrainz_trackid(self) -> str | None:
- """Return musicbrainz_trackid tag if present."""
- if tag := self.tags.get("musicbrainztrackid"):
- return tag
- return self.tags.get("musicbrainzreleasetrackid")
-
- @property
- def album_type(self) -> str | None:
- """Return albumtype tag if present."""
- if tag := self.tags.get("musicbrainzalbumtype"):
- return tag
- return self.tags.get("releasetype")
-
- @classmethod
- def parse(cls, raw: dict) -> "AudioTags":
- """Parse instance from raw ffmpeg info output."""
- audio_stream = next(x for x in raw["streams"] if x["codec_type"] == "audio")
- has_cover_image = any(
- x for x in raw["streams"] if x["codec_name"] in ("mjpeg", "png")
- )
- # convert all tag-keys (gathered from all streams) to lowercase without spaces
- tags = {}
- for stream in raw["streams"] + [raw["format"]]:
- for key, value in stream.get("tags", {}).items():
- key = key.lower().replace(" ", "").replace("_", "")
- tags[key] = value
-
- return AudioTags(
- raw=raw,
- sample_rate=int(audio_stream.get("sample_rate", 44100)),
- channels=audio_stream.get("channels", 2),
- bits_per_sample=int(
- audio_stream.get(
- "bits_per_raw_sample", audio_stream.get("bits_per_sample")
- )
- or 16
- ),
- format=raw["format"]["format_name"],
- bit_rate=int(raw["format"].get("bit_rate", 320)),
- duration=int(float(raw["format"].get("duration", 0))) or None,
- tags=tags,
- has_cover_image=has_cover_image,
- filename=raw["format"]["filename"],
- )
-
- def get(self, key: str, default=None) -> Any:
- """Get tag by key."""
- return self.tags.get(key, default)
-
-
-async def parse_tags(input_file: Union[str, AsyncGenerator[bytes, None]]) -> AudioTags:
- """
- Parse tags from a media file.
-
- input_file may be a (local) filename/url accessible by ffmpeg or
- an AsyncGenerator which yields the file contents as bytes.
- """
- file_path = input_file if isinstance(input_file, str) else "-"
-
- args = (
- "ffprobe",
- "-hide_banner",
- "-loglevel",
- "fatal",
- "-show_error",
- "-show_format",
- "-show_streams",
- "-print_format",
- "json",
- "-i",
- file_path,
- )
-
- async with AsyncProcess(
- args, enable_stdin=file_path == "-", enable_stdout=True, enable_stderr=False
- ) as proc:
-
- if file_path == "-":
- # feed the file contents to the process
- async def chunk_feeder():
- # pylint: disable=protected-access
- async for chunk in input_file:
- try:
- await proc.write(chunk)
- except BrokenPipeError:
- break # race-condition: read enough data for tags
-
- proc.attach_task(chunk_feeder())
-
- try:
- res = await proc.read(-1)
- data = json.loads(res)
- if error := data.get("error"):
- raise InvalidDataError(error["string"])
- return AudioTags.parse(data)
- except (KeyError, ValueError, JSONDecodeError, InvalidDataError) as err:
- raise InvalidDataError(
- f"Unable to retrieve info for {file_path}: {str(err)}"
- ) from err
-
-
-async def get_embedded_image(
- input_file: Union[str, AsyncGenerator[bytes, None]]
-) -> bytes | None:
- """
- Return embedded image data.
-
- input_file may be a (local) filename/url accessible by ffmpeg or
- an AsyncGenerator which yields the file contents as bytes.
- """
- file_path = input_file if isinstance(input_file, str) else "-"
- args = (
- "ffmpeg",
- "-hide_banner",
- "-loglevel",
- "fatal",
- "-i",
- file_path,
- "-map",
- "0:v",
- "-c",
- "copy",
- "-f",
- "mjpeg",
- "-",
- )
-
- async with AsyncProcess(
- args, enable_stdin=file_path == "-", enable_stdout=True, enable_stderr=False
- ) as proc:
-
- if file_path == "-":
- # feed the file contents to the process
- async for chunk in input_file:
- await proc.write(chunk)
-
- if file_path == "-":
- # feed the file contents to the process
- async def chunk_feeder():
- async for chunk in input_file:
- await proc.write(chunk)
-
- proc.attach_task(chunk_feeder())
-
- return await proc.read(-1)
+++ /dev/null
-"""Helpers for creating/parsing URI's."""
-
-import os
-from typing import Tuple
-
-from music_assistant.models.enums import MediaType, ProviderType
-from music_assistant.models.errors import MusicAssistantError
-
-
-def parse_uri(uri: str) -> Tuple[MediaType, ProviderType, str]:
- """
- Try to parse URI to Mass identifiers.
-
- Returns Tuple: MediaType, provider, item_id
- """
- try:
- if uri.startswith("https://open."):
- # public share URL (e.g. Spotify or Qobuz, not sure about others)
- # https://open.spotify.com/playlist/5lH9NjOeJvctAO92ZrKQNB?si=04a63c8234ac413e
- provider = ProviderType.parse(uri.split(".")[1])
- media_type_str = uri.split("/")[3]
- media_type = MediaType(media_type_str)
- item_id = uri.split("/")[4].split("?")[0]
- elif uri.startswith("http://") or uri.startswith("https://"):
- # Translate a plain URL to the URL provider
- provider = ProviderType.URL
- media_type = MediaType.UNKNOWN
- item_id = uri
- elif "://" in uri:
- # music assistant-style uri
- # provider://media_type/item_id
- provider = ProviderType.parse(uri.split("://")[0])
- media_type_str = uri.split("/")[2]
- media_type = MediaType(media_type_str)
- item_id = uri.split(f"{media_type_str}/")[1]
- elif ":" in uri:
- # spotify new-style uri
- provider, media_type_str, item_id = uri.split(":")
- provider = ProviderType.parse(provider)
- media_type = MediaType(media_type_str)
- elif os.path.isfile(uri):
- # Translate a local file (which is not from file provider) to the URL provider
- provider = ProviderType.URL
- media_type = MediaType.TRACK
- item_id = uri
- else:
- raise KeyError
- except (TypeError, AttributeError, ValueError, KeyError) as err:
- raise MusicAssistantError(f"Not a valid Music Assistant uri: {uri}") from err
- return (media_type, provider, item_id)
-
-
-def create_uri(media_type: MediaType, provider: ProviderType, item_id: str) -> str:
- """Create Music Assistant URI from MediaItem values."""
- return f"{provider.value}://{media_type.value}/{item_id}"
+++ /dev/null
-"""Helper and utility functions."""
-from __future__ import annotations
-
-import asyncio
-import os
-import platform
-import re
-import socket
-import tempfile
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
-
-import memory_tempfile
-import unidecode
-
-# pylint: disable=invalid-name
-T = TypeVar("T")
-_UNDEF: dict = {}
-CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
-CALLBACK_TYPE = Callable[[], None]
-# pylint: enable=invalid-name
-
-
-def filename_from_string(string: str) -> str:
- """Create filename from unsafe string."""
- keepcharacters = (" ", ".", "_")
- return "".join(c for c in string if c.isalnum() or c in keepcharacters).rstrip()
-
-
-def try_parse_int(possible_int: Any, default: Optional[int] = 0) -> Optional[int]:
- """Try to parse an int."""
- try:
- return int(possible_int)
- except (TypeError, ValueError):
- return default
-
-
-def try_parse_float(
- possible_float: Any, default: Optional[float] = 0.0
-) -> Optional[float]:
- """Try to parse a float."""
- try:
- return float(possible_float)
- except (TypeError, ValueError):
- return default
-
-
-def try_parse_bool(possible_bool: Any) -> str:
- """Try to parse a bool."""
- if isinstance(possible_bool, bool):
- return possible_bool
- return possible_bool in ["true", "True", "1", "on", "ON", 1]
-
-
-def create_safe_string(input_str: str) -> str:
- """Return clean lowered string for compare actions."""
- input_str = input_str.lower().strip()
- unaccented_string = unidecode.unidecode(input_str)
- return re.sub(r"[^a-zA-Z0-9]", "", unaccented_string)
-
-
-def create_sort_name(input_str: str) -> str:
- """Create sort name/title from string."""
- input_str = input_str.lower().strip()
- for item in ["the ", "de ", "les "]:
- if input_str.startswith(item):
- input_str = input_str.replace(item, "")
- return input_str.strip()
-
-
-def parse_title_and_version(title: str, track_version: str = None):
- """Try to parse clean track title and version from the title."""
- version = ""
- for splitter in [" (", " [", " - ", " (", " [", "-"]:
- if splitter in title:
- title_parts = title.split(splitter)
- for title_part in title_parts:
- # look for the end splitter
- for end_splitter in [")", "]"]:
- if end_splitter in title_part:
- title_part = title_part.split(end_splitter)[0]
- for version_str in [
- "version",
- "live",
- "edit",
- "remix",
- "mix",
- "acoustic",
- "instrumental",
- "karaoke",
- "remaster",
- "versie",
- "radio",
- "unplugged",
- "disco",
- "akoestisch",
- "deluxe",
- ]:
- if version_str in title_part.lower():
- version = title_part
- title = title.split(splitter + version)[0]
- title = clean_title(title)
- if not version and track_version:
- version = track_version
- version = get_version_substitute(version).title()
- if version == title:
- version = ""
- return title, version
-
-
-def clean_title(title: str) -> str:
- """Strip unwanted additional text from title."""
- for splitter in [" (", " [", " - ", " (", " [", "-"]:
- if splitter in title:
- title_parts = title.split(splitter)
- for title_part in title_parts:
- # look for the end splitter
- for end_splitter in [")", "]"]:
- if end_splitter in title_part:
- title_part = title_part.split(end_splitter)[0]
- for ignore_str in ["feat.", "featuring", "ft.", "with ", "explicit"]:
- if ignore_str in title_part.lower():
- return title.split(splitter + title_part)[0].strip()
- return title.strip()
-
-
-def get_version_substitute(version_str: str):
- """Transform provider version str to universal version type."""
- version_str = version_str.lower()
- # substitute edit and edition with version
- if "edition" in version_str or "edit" in version_str:
- version_str = version_str.replace(" edition", " version")
- version_str = version_str.replace(" edit ", " version")
- if version_str.startswith("the "):
- version_str = version_str.split("the ")[1]
- if "radio mix" in version_str:
- version_str = "radio version"
- elif "video mix" in version_str:
- version_str = "video version"
- elif "spanglish" in version_str or "spanish" in version_str:
- version_str = "spanish version"
- elif version_str.endswith("remaster"):
- version_str = "remaster"
- elif version_str.endswith("remastered"):
- version_str = "remaster"
- return version_str.strip()
-
-
-def get_ip():
- """Get primary IP-address for this host."""
- # pylint: disable=broad-except,no-member
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- try:
- # doesn't even have to be reachable
- sock.connect(("10.255.255.255", 1))
- _ip = sock.getsockname()[0]
- except Exception:
- _ip = "127.0.0.1"
- finally:
- sock.close()
- return _ip
-
-
-def is_port_in_use(port: int) -> bool:
- """Check if port is in use."""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _sock:
- try:
- return _sock.connect_ex(("localhost", port)) == 0
- except socket.gaierror:
- return True
-
-
-def select_stream_port() -> int:
- """Automatically find available stream port, prefer the default 8095."""
- for port in range(8095, 8195):
- if not is_port_in_use(port):
- return port
-
-
-async def get_ip_from_host(dns_name: str) -> str:
- """Resolve (first) IP-address for given dns name."""
-
- def _resolve():
- try:
- return socket.gethostbyname(dns_name)
- except Exception: # pylint: disable=broad-except
- # fail gracefully!
- return dns_name
-
- loop = asyncio.get_event_loop()
- return await loop.run_in_executor(None, _resolve)
-
-
-def get_folder_size(folderpath):
- """Return folder size in gb."""
- total_size = 0
- # pylint: disable=unused-variable
- for dirpath, dirnames, filenames in os.walk(folderpath):
- for _file in filenames:
- _fp = os.path.join(dirpath, _file)
- total_size += os.path.getsize(_fp)
- # pylint: enable=unused-variable
- total_size_gb = total_size / float(1 << 30)
- return total_size_gb
-
-
-def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
- """Merge dict without overwriting existing values."""
- final_dict = base_dict.copy()
- for key, value in new_dict.items():
- if final_dict.get(key) and isinstance(value, dict):
- final_dict[key] = merge_dict(final_dict[key], value)
- if final_dict.get(key) and isinstance(value, tuple):
- final_dict[key] = merge_tuples(final_dict[key], value)
- if final_dict.get(key) and isinstance(value, list):
- final_dict[key] = merge_lists(final_dict[key], value)
- elif not final_dict.get(key) or allow_overwite:
- final_dict[key] = value
- return final_dict
-
-
-def merge_tuples(base: tuple, new: tuple) -> Tuple:
- """Merge 2 tuples."""
- return tuple(x for x in base if x not in new) + tuple(new)
-
-
-def merge_lists(base: list, new: list) -> list:
- """Merge 2 lists."""
- return list(x for x in base if x not in new) + list(new)
-
-
-def create_tempfile():
- """Return a (named) temporary file."""
- if platform.system() == "Linux":
- return memory_tempfile.MemoryTempfile(fallback=True).NamedTemporaryFile(
- buffering=0
- )
- return tempfile.NamedTemporaryFile(buffering=0)
-
-
-def get_changed_keys(
- dict1: Dict[str, Any],
- dict2: Dict[str, Any],
- ignore_keys: Optional[List[str]] = None,
-) -> Set[str]:
- """Compare 2 dicts and return set of changed keys."""
- if not dict2:
- return set(dict1.keys())
- changed_keys = set()
- for key, value in dict2.items():
- if ignore_keys and key in ignore_keys:
- continue
- if key not in dict1:
- changed_keys.add(key)
- elif isinstance(value, dict):
- changed_keys.update(get_changed_keys(dict1[key], value))
- elif dict1[key] != value:
- changed_keys.add(key)
- return changed_keys
+++ /dev/null
-"""Main Music Assistant class."""
-from __future__ import annotations
-
-import asyncio
-import logging
-from collections import deque
-from functools import partial
-from time import time
-from types import TracebackType
-from typing import Any, Callable, Coroutine, Deque, List, Optional, Tuple, Type, Union
-from uuid import uuid4
-
-import aiohttp
-
-from music_assistant.constants import ROOT_LOGGER_NAME
-from music_assistant.controllers.cache import CacheController
-from music_assistant.controllers.database import DatabaseController
-from music_assistant.controllers.metadata.metadata import MetaDataController
-from music_assistant.controllers.music import MusicController
-from music_assistant.controllers.players import PlayerController
-from music_assistant.controllers.streams import StreamsController
-from music_assistant.models.background_job import BackgroundJob
-from music_assistant.models.config import MassConfig
-from music_assistant.models.enums import EventType, JobStatus
-from music_assistant.models.event import MassEvent
-
-EventCallBackType = Callable[[MassEvent], None]
-EventSubscriptionType = Tuple[
- EventCallBackType, Optional[Tuple[EventType]], Optional[Tuple[str]]
-]
-
-
-class MusicAssistant:
- """Main MusicAssistant object."""
-
- def __init__(
- self,
- config: MassConfig,
- session: Optional[aiohttp.ClientSession] = None,
- ) -> None:
- """
- Create an instance of MusicAssistant.
-
- config: Music Assistant runtimestartup Config
- session: Optionally provide an aiohttp clientsession
- """
-
- self.config = config
- self.loop: asyncio.AbstractEventLoop = None
- self.http_session: aiohttp.ClientSession = session
- self.http_session_provided = session is not None
- self.logger = logging.getLogger(ROOT_LOGGER_NAME)
-
- self._listeners = []
- self._jobs: Deque[BackgroundJob] = deque()
- self._jobs_event = asyncio.Event()
-
- # init core controllers
- self.database = DatabaseController(self)
- self.cache = CacheController(self)
- self.metadata = MetaDataController(self)
- self.music = MusicController(self)
- self.players = PlayerController(self)
- self.streams = StreamsController(self)
- self._tracked_tasks: List[asyncio.Task] = []
- self.closed = False
-
- async def setup(self) -> None:
- """Async setup of music assistant."""
- # initialize loop
- self.loop = asyncio.get_running_loop()
- # create shared aiohttp ClientSession
- if not self.http_session:
- self.http_session = aiohttp.ClientSession(
- loop=self.loop,
- connector=aiohttp.TCPConnector(ssl=False),
- )
- # setup core controllers
- await self.database.setup()
- await self.cache.setup()
- await self.music.setup()
- await self.metadata.setup()
- await self.players.setup()
- await self.streams.setup()
- self.create_task(self.__process_jobs())
-
- async def stop(self) -> None:
- """Stop running the music assistant server."""
- self.logger.info("Stop called, cleaning up...")
- await self.players.cleanup()
- # cancel all running tasks
- for task in self._tracked_tasks:
- task.cancel()
- self.signal_event(MassEvent(EventType.SHUTDOWN))
- await self.database.close()
- self.closed = True
- if self.http_session and not self.http_session_provided:
- await self.http_session.close()
-
- def signal_event(self, event: MassEvent) -> None:
- """Signal event to subscribers."""
- if self.closed:
- return
- if self.logger.isEnabledFor(logging.DEBUG):
- if event.type != EventType.QUEUE_TIME_UPDATED:
- # do not log queue time updated events because that is too chatty
- self.logger.getChild("event").debug(
- "%s %s", event.type.value, event.object_id or ""
- )
- for cb_func, event_filter, id_filter in self._listeners:
- if not (event_filter is None or event.type in event_filter):
- continue
- if not (id_filter is None or event.object_id in id_filter):
- continue
- if asyncio.iscoroutinefunction(cb_func):
- asyncio.run_coroutine_threadsafe(cb_func(event), self.loop)
- else:
- self.loop.call_soon_threadsafe(cb_func, event)
-
- def subscribe(
- self,
- cb_func: EventCallBackType,
- event_filter: Union[EventType, Tuple[EventType], None] = None,
- id_filter: Union[str, Tuple[str], None] = None,
- ) -> Callable:
- """
- Add callback to event listeners.
-
- Returns function to remove the listener.
- :param cb_func: callback function or coroutine
- :param event_filter: Optionally only listen for these events
- :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
- """
- if isinstance(event_filter, EventType):
- event_filter = (event_filter,)
- if isinstance(id_filter, str):
- id_filter = (id_filter,)
- listener = (cb_func, event_filter, id_filter)
- self._listeners.append(listener)
-
- def remove_listener():
- self._listeners.remove(listener)
-
- return remove_listener
-
- def add_job(
- self, coro: Coroutine, name: Optional[str] = None, allow_duplicate=False
- ) -> BackgroundJob:
- """Add job to be (slowly) processed in the background."""
- if not allow_duplicate:
- if existing := next((x for x in self._jobs if x.name == name), None):
- self.logger.debug("Ignored duplicate job: %s", name)
- coro.close()
- return existing
- if not name:
- name = coro.__qualname__ or coro.__name__
- job = BackgroundJob(str(uuid4()), name=name, coro=coro)
- self._jobs.append(job)
- self._jobs_event.set()
- self.signal_event(
- MassEvent(EventType.BACKGROUND_JOB_UPDATED, job.name, data=job)
- )
- return job
-
- def create_task(
- self,
- target: Coroutine,
- *args: Any,
- **kwargs: Any,
- ) -> Union[asyncio.Task, asyncio.Future]:
- """
- Create Task on (main) event loop from Callable or awaitable.
-
- Tasks created by this helper will be properly cancelled on stop.
- """
- if self.closed:
- return
-
- if asyncio.iscoroutinefunction(target):
- task = self.loop.create_task(target(*args, **kwargs))
- else:
- task = self.loop.create_task(target)
-
- def task_done_callback(*args, **kwargs):
- self._tracked_tasks.remove(task)
-
- self._tracked_tasks.append(task)
- task.add_done_callback(task_done_callback)
- return task
-
- @property
- def jobs(self) -> List[BackgroundJob]:
- """Return the pending/running background jobs."""
- return list(self._jobs)
-
- async def __process_jobs(self):
- """Process jobs in the background."""
- while True:
- await self._jobs_event.wait()
- self._jobs_event.clear()
- # make sure we're not running more jobs than allowed
- running_jobs = tuple(x for x in self._jobs if x.status == JobStatus.RUNNING)
- slots_available = self.config.max_simultaneous_jobs - len(running_jobs)
- count = 0
- while count <= slots_available:
- count += 1
- next_job = next(
- (x for x in self._jobs if x.status == JobStatus.PENDING), None
- )
- if next_job is None:
- break
- # create task from coroutine and attach task_done callback
- next_job.timestamp = time()
- next_job.status = JobStatus.RUNNING
- task = self.create_task(next_job.coro)
- task.set_name(next_job.name)
- task.add_done_callback(partial(self.__job_done_cb, job=next_job))
- self.signal_event(
- MassEvent(
- EventType.BACKGROUND_JOB_UPDATED, next_job.name, data=next_job
- )
- )
-
- def __job_done_cb(self, task: asyncio.Task, job: BackgroundJob):
- """Call when background job finishes."""
- execution_time = round(time() - job.timestamp, 2)
- job.timestamp = execution_time
- if task.cancelled():
- job.status = JobStatus.CANCELLED
- elif err := task.exception():
- job.status = JobStatus.ERROR
- self.logger.error(
- "Job [%s] failed with error %s.",
- job.name,
- str(err),
- exc_info=err,
- )
- else:
- job.result = task.result()
- job.status = JobStatus.FINISHED
- self.logger.info(
- "Finished job [%s] in %s seconds.", job.name, execution_time
- )
- self._jobs.remove(job)
- self._jobs_event.set()
- # mark job as done
- job.done()
- self.signal_event(
- MassEvent(EventType.BACKGROUND_JOB_FINISHED, job.name, data=job)
- )
-
- async def __aenter__(self) -> "MusicAssistant":
- """Return Context manager."""
- await self.setup()
- return self
-
- async def __aexit__(
- self,
- exc_type: Type[BaseException],
- exc_val: BaseException,
- exc_tb: TracebackType,
- ) -> Optional[bool]:
- """Exit context manager."""
- await self.stop()
- if exc_val:
- raise exc_val
- return exc_type
+++ /dev/null
-"""Models package."""
+++ /dev/null
-"""Model for a Background Job."""
-import asyncio
-from dataclasses import dataclass, field
-from time import time
-from typing import Any, Coroutine
-
-from music_assistant.models.enums import JobStatus
-
-
-@dataclass
-class BackgroundJob:
- """Description of a background job/task."""
-
- id: str
- coro: Coroutine
- name: str
- timestamp: float = time()
- status: JobStatus = JobStatus.PENDING
- result: Any = None
- _evt: asyncio.Event = field(init=False, default_factory=asyncio.Event)
-
- def to_dict(self):
- """Return serializable dict from object."""
- return {
- "id": self.id,
- "name": self.name,
- "timestamp": self.status.value,
- "status": self.status.value,
- }
-
- async def wait(self) -> None:
- """Wait for the job to complete."""
- await self._evt.wait()
-
- def done(self) -> None:
- """Mark job as done."""
- self._evt.set()
+++ /dev/null
-"""Model for the Music Assisant runtime config."""
-
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional
-
-from databases import DatabaseURL
-
-from music_assistant.helpers.util import get_ip, select_stream_port
-from music_assistant.models.enums import ProviderType
-
-
-@dataclass(frozen=True)
-class MusicProviderConfig:
- """Base Model for a MusicProvider config."""
-
- type: ProviderType
- enabled: bool = True
- username: Optional[str] = None
- password: Optional[str] = None
- path: Optional[str] = None
- options: Dict[str, str] = field(default_factory=dict)
- # if id is omitted, an id is generated/derived from the other params
- id: Optional[str] = None
-
- def __post_init__(self):
- """Call after init."""
- # create a default (hopefully unique enough) id from type + username/path
- if not self.id and (self.path or self.username):
- prov_id = f"{self.type.value}_"
- base_str = (self.path or self.username).lower()
- prov_id += (
- base_str.replace(".", "").replace("_", "").split("@")[0][1::2]
- ) + base_str[-1]
- super().__setattr__("id", prov_id)
- elif not self.id:
- super().__setattr__("id", self.type.value)
-
-
-@dataclass(frozen=True)
-class MassConfig:
- """Model for the Music Assisant runtime config."""
-
- database_url: DatabaseURL
-
- providers: List[MusicProviderConfig] = field(default_factory=list)
-
- # advanced settings
- max_simultaneous_jobs: int = 2
- stream_port: int = select_stream_port()
- stream_ip: str = get_ip()
+++ /dev/null
-"""All enums used by the Music Assistant models."""
-
-from enum import Enum
-from typing import List
-
-
-class MediaType(Enum):
- """Enum for MediaType."""
-
- ARTIST = "artist"
- ALBUM = "album"
- TRACK = "track"
- PLAYLIST = "playlist"
- RADIO = "radio"
- FOLDER = "folder"
- UNKNOWN = "unknown"
-
- @classmethod
- @property
- def ALL(cls) -> List["MediaType"]: # pylint: disable=invalid-name
- """Return all (default) MediaTypes as list."""
- return [
- MediaType.ARTIST,
- MediaType.ALBUM,
- MediaType.TRACK,
- MediaType.PLAYLIST,
- MediaType.RADIO,
- ]
-
-
-class LinkType(Enum):
- """Enum wth link types."""
-
- WEBSITE = "website"
- FACEBOOK = "facebook"
- TWITTER = "twitter"
- LASTFM = "lastfm"
- YOUTUBE = "youtube"
- INSTAGRAM = "instagram"
- SNAPCHAT = "snapchat"
- TIKTOK = "tiktok"
- DISCOGS = "discogs"
- WIKIPEDIA = "wikipedia"
- ALLMUSIC = "allmusic"
-
-
-class ImageType(Enum):
- """Enum wth image types."""
-
- THUMB = "thumb"
- LANDSCAPE = "landscape"
- FANART = "fanart"
- LOGO = "logo"
- CLEARART = "clearart"
- BANNER = "banner"
- CUTOUT = "cutout"
- BACK = "back"
- DISCART = "discart"
- OTHER = "other"
-
-
-class AlbumType(Enum):
- """Enum for Album type."""
-
- ALBUM = "album"
- SINGLE = "single"
- COMPILATION = "compilation"
- EP = "ep"
- UNKNOWN = "unknown"
-
-
-class ContentType(Enum):
- """Enum with audio content/container types supported by ffmpeg."""
-
- OGG = "ogg"
- FLAC = "flac"
- MP3 = "mp3"
- AAC = "aac"
- MPEG = "mpeg"
- ALAC = "alac"
- WAV = "wav"
- AIFF = "aiff"
- WMA = "wma"
- M4A = "m4a"
- DSF = "dsf"
- WAVPACK = "wv"
- PCM_S16LE = "s16le" # PCM signed 16-bit little-endian
- PCM_S24LE = "s24le" # PCM signed 24-bit little-endian
- PCM_S32LE = "s32le" # PCM signed 32-bit little-endian
- PCM_F32LE = "f32le" # PCM 32-bit floating-point little-endian
- PCM_F64LE = "f64le" # PCM 64-bit floating-point little-endian
- MPEG_DASH = "dash"
- UNKNOWN = "?"
-
- @classmethod
- def try_parse(cls: "ContentType", string: str) -> "ContentType":
- """Try to parse ContentType from (url)string/extension."""
- tempstr = string.lower()
- if "audio/" in tempstr:
- tempstr = tempstr.split("/")[1]
- for splitter in (".", ","):
- if splitter in tempstr:
- for val in tempstr.split(splitter):
- try:
- return cls(val.strip())
- except ValueError:
- pass
-
- tempstr = tempstr.split("?")[0]
- tempstr = tempstr.split("&")[0]
- tempstr = tempstr.split(";")[0]
- tempstr = tempstr.replace("mp4", "m4a")
- tempstr = tempstr.replace("mpd", "dash")
- try:
- return cls(tempstr)
- except ValueError:
- return cls.UNKNOWN
-
- def is_pcm(self) -> bool:
- """Return if contentype is PCM."""
- return self.name.startswith("PCM")
-
- def is_lossless(self) -> bool:
- """Return if format is lossless."""
- return self.is_pcm() or self in (
- ContentType.DSF,
- ContentType.FLAC,
- ContentType.AIFF,
- ContentType.WAV,
- )
-
- @classmethod
- def from_bit_depth(
- cls, bit_depth: int, floating_point: bool = False
- ) -> "ContentType":
- """Return (PCM) Contenttype from PCM bit depth."""
- if floating_point and bit_depth > 32:
- return cls.PCM_F64LE
- if floating_point:
- return cls.PCM_F32LE
- if bit_depth == 16:
- return cls.PCM_S16LE
- if bit_depth == 24:
- return cls.PCM_S24LE
- return cls.PCM_S32LE
-
-
-class QueueOption(Enum):
- """
- Enum representation of the queue (play) options.
-
- - PLAY -> Insert new item(s) in queue at the current position and start playing.
- - REPLACE -> Replace entire queue contents with the new items and start playing from index 0.
- - NEXT -> Insert item(s) after current playing/buffered item.
- - REPLACE_NEXT -> Replace item(s) after current playing/buffered item.
- - ADD -> Add new item(s) to the queue (at the end if shuffle is not enabled).
- """
-
- PLAY = "play"
- REPLACE = "replace"
- NEXT = "next"
- REPLACE_NEXT = "replace_next"
- ADD = "add"
-
-
-class CrossFadeMode(Enum):
- """
- Enum with crossfade modes.
-
- - DISABLED: no crossfading at all
- - STRICT: do not crossfade tracks of same album
- - SMART: crossfade if possible (do not crossfade different sample rates)
- - ALWAYS: all tracks - resample to fixed sample rate
- """
-
- DISABLED = "disabled"
- STRICT = "strict"
- SMART = "smart"
- ALWAYS = "always"
-
-
-class RepeatMode(Enum):
- """Enum with repeat modes."""
-
- OFF = "off" # no repeat at all
- ONE = "one" # repeat one/single track
- ALL = "all" # repeat entire queue
-
-
-class MetadataMode(Enum):
- """Enum with stream metadata modes."""
-
- DISABLED = "disabled" # do not notify icy support
- DEFAULT = "default" # enable icy if player requests it, default chunksize
- LEGACY = "legacy" # enable icy but with legacy 8kb chunksize, requires mp3
-
-
-class PlayerState(Enum):
- """Enum for the (playback)state of a player."""
-
- IDLE = "idle"
- PAUSED = "paused"
- PLAYING = "playing"
- OFF = "off"
-
-
-class EventType(Enum):
- """Enum with possible Events."""
-
- PLAYER_ADDED = "player_added"
- PLAYER_UPDATED = "player_updated"
- QUEUE_ADDED = "queue_added"
- QUEUE_UPDATED = "queue_updated"
- QUEUE_ITEMS_UPDATED = "queue_items_updated"
- QUEUE_TIME_UPDATED = "queue_time_updated"
- SHUTDOWN = "application_shutdown"
- BACKGROUND_JOB_UPDATED = "background_job_updated"
- BACKGROUND_JOB_FINISHED = "background_job_finished"
- MEDIA_ITEM_ADDED = "media_item_added"
- MEDIA_ITEM_UPDATED = "media_item_updated"
- MEDIA_ITEM_DELETED = "media_item_deleted"
-
-
-class JobStatus(Enum):
- """Enum with Job status."""
-
- PENDING = "pending"
- RUNNING = "running"
- CANCELLED = "cancelled"
- FINISHED = "success"
- ERROR = "error"
-
-
-class MusicProviderFeature(Enum):
- """Enum with features for a MusicProvider."""
-
- # browse/explore/recommendations
- BROWSE = "browse"
- SEARCH = "search"
- RECOMMENDATIONS = "recommendations"
- # library feature per mediatype
- LIBRARY_ARTISTS = "library_artists"
- LIBRARY_ALBUMS = "library_albums"
- LIBRARY_TRACKS = "library_tracks"
- LIBRARY_PLAYLISTS = "library_playlists"
- LIBRARY_RADIOS = "library_radios"
- # additional library features
- ARTIST_ALBUMS = "artist_albums"
- ARTIST_TOPTRACKS = "artist_toptracks"
- # library edit (=add/remove) feature per mediatype
- LIBRARY_ARTISTS_EDIT = "library_artists_edit"
- LIBRARY_ALBUMS_EDIT = "library_albums_edit"
- LIBRARY_TRACKS_EDIT = "library_tracks_edit"
- LIBRARY_PLAYLISTS_EDIT = "library_playlists_edit"
- LIBRARY_RADIOS_EDIT = "library_radios_edit"
- # if we can grab 'similar tracks' from the music provider
- # used to generate dynamic playlists
- SIMILAR_TRACKS = "similar_tracks"
- # playlist-specific features
- PLAYLIST_TRACKS_EDIT = "playlist_tracks_edit"
- PLAYLIST_CREATE = "playlist_create"
-
-
-class ProviderType(Enum):
- """Enum with supported music providers."""
-
- FILESYSTEM_LOCAL = "file"
- FILESYSTEM_SMB = "smb"
- FILESYSTEM_GOOGLE_DRIVE = "gdrive"
- FILESYSTEM_ONEDRIVE = "onedrive"
- SPOTIFY = "spotify"
- QOBUZ = "qobuz"
- TUNEIN = "tunein"
- YTMUSIC = "ytmusic"
- DATABASE = "database" # internal only
- URL = "url" # internal only
-
- def is_file(self) -> bool:
- """Return if type is one of the filesystem providers."""
- return self in (
- self.FILESYSTEM_LOCAL,
- self.FILESYSTEM_SMB,
- self.FILESYSTEM_GOOGLE_DRIVE,
- self.FILESYSTEM_ONEDRIVE,
- )
-
- @classmethod
- def parse(cls: "ProviderType", val: str) -> "ProviderType":
- """Try to parse ContentType from provider id."""
- if isinstance(val, ProviderType):
- return val
- for mem in ProviderType:
- if val.startswith(mem.value):
- return mem
- raise ValueError(f"Unable to parse ProviderType from {val}")
+++ /dev/null
-"""Custom errors and exceptions."""
-
-
-class MusicAssistantError(Exception):
- """Custom Exception for all errors."""
-
-
-class ProviderUnavailableError(MusicAssistantError):
- """Error raised when trying to access mediaitem of unavailable provider."""
-
-
-class MediaNotFoundError(MusicAssistantError):
- """Error raised when trying to access non existing media item."""
-
-
-class InvalidDataError(MusicAssistantError):
- """Error raised when an object has invalid data."""
-
-
-class AlreadyRegisteredError(MusicAssistantError):
- """Error raised when a duplicate music provider or player is registered."""
-
-
-class SetupFailedError(MusicAssistantError):
- """Error raised when setup of a provider or player failed."""
-
-
-class LoginFailed(MusicAssistantError):
- """Error raised when a login failed."""
-
-
-class AudioError(MusicAssistantError):
- """Error raised when an issue arrised when processing audio."""
-
-
-class QueueEmpty(MusicAssistantError):
- """Error raised when trying to start queue stream while queue is empty."""
-
-
-class UnsupportedFeaturedException(MusicAssistantError):
- """Error raised when a feature is not supported."""
+++ /dev/null
-"""Model for Music Assistant Event."""
-
-from dataclasses import dataclass
-from typing import Any, Optional
-
-from music_assistant.models.enums import EventType
-
-
-@dataclass
-class MassEvent:
- """Representation of an Event emitted in/by Music Assistant."""
-
- type: EventType
- object_id: Optional[str] = None # player_id, queue_id or uri
- data: Optional[Any] = None # optional data (such as the object)
+++ /dev/null
-"""Models and helpers for media items."""
-from __future__ import annotations
-
-from dataclasses import dataclass, field, fields
-from time import time
-from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
-
-from mashumaro import DataClassDictMixin
-
-from music_assistant.helpers.json import json
-from music_assistant.helpers.uri import create_uri
-from music_assistant.helpers.util import create_sort_name, merge_lists
-from music_assistant.models.enums import (
- AlbumType,
- ContentType,
- ImageType,
- LinkType,
- MediaType,
- ProviderType,
-)
-
-MetadataTypes = Union[int, bool, str, List[str]]
-
-JSON_KEYS = ("artists", "artist", "albums", "metadata", "provider_mappings")
-
-
-@dataclass(frozen=True)
-class ProviderMapping(DataClassDictMixin):
- """Model for a MediaItem's provider mapping details."""
-
- item_id: str
- provider_type: ProviderType
- provider_id: str
- available: bool = True
- # quality details (streamable content only)
- content_type: ContentType = ContentType.UNKNOWN
- sample_rate: int = 44100
- bit_depth: int = 16
- bit_rate: int = 320
- # optional details to store provider specific details
- details: Optional[str] = None
- # url = link to provider details page if exists
- url: Optional[str] = None
-
- @property
- def quality(self) -> int:
- """Calculate quality score."""
- if self.content_type.is_lossless():
- return int(self.sample_rate / 1000) + self.bit_depth
- # lossy content, bit_rate is most important score
- # but prefer some codecs over others
- score = self.bit_rate / 100
- if self.content_type in (ContentType.AAC, ContentType.OGG):
- score += 1
- return int(score)
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.provider_type.value, self.item_id))
-
-
-@dataclass(frozen=True)
-class MediaItemLink(DataClassDictMixin):
- """Model for a link."""
-
- type: LinkType
- url: str
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.type.value))
-
-
-@dataclass(frozen=True)
-class MediaItemImage(DataClassDictMixin):
- """Model for a image."""
-
- type: ImageType
- url: str
- is_file: bool = False # indicator that image is local filepath instead of url
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.url))
-
-
-@dataclass
-class MediaItemMetadata(DataClassDictMixin):
- """Model for a MediaItem's metadata."""
-
- description: Optional[str] = None
- review: Optional[str] = None
- explicit: Optional[bool] = None
- images: Optional[List[MediaItemImage]] = None
- genres: Optional[Set[str]] = None
- mood: Optional[str] = None
- style: Optional[str] = None
- copyright: Optional[str] = None
- lyrics: Optional[str] = None
- ean: Optional[str] = None
- label: Optional[str] = None
- links: Optional[Set[MediaItemLink]] = None
- performers: Optional[Set[str]] = None
- preview: Optional[str] = None
- replaygain: Optional[float] = None
- popularity: Optional[int] = None
- # last_refresh: timestamp the (full) metadata was last collected
- last_refresh: Optional[int] = None
- # checksum: optional value to detect changes (e.g. playlists)
- checksum: Optional[str] = None
-
- def update(
- self,
- new_values: "MediaItemMetadata",
- allow_overwrite: bool = False,
- ) -> "MediaItemMetadata":
- """Update metadata (in-place) with new values."""
- for fld in fields(self):
- new_val = getattr(new_values, fld.name)
- if new_val is None:
- continue
- cur_val = getattr(self, fld.name)
- if isinstance(cur_val, list):
- new_val = merge_lists(cur_val, new_val)
- setattr(self, fld.name, new_val)
- elif isinstance(cur_val, set):
- new_val = cur_val.update(new_val)
- setattr(self, fld.name, new_val)
- elif cur_val is None or allow_overwrite:
- setattr(self, fld.name, new_val)
- elif new_val and fld.name in ("checksum", "popularity", "last_refresh"):
- # some fields are always allowed to be overwritten (such as checksum and last_refresh)
- setattr(self, fld.name, new_val)
- return self
-
-
-@dataclass
-class MediaItem(DataClassDictMixin):
- """Base representation of a media item."""
-
- item_id: str
- provider: ProviderType
- name: str
- provider_mappings: Set[ProviderMapping] = field(default_factory=set)
-
- # optional fields below
- metadata: MediaItemMetadata = field(default_factory=MediaItemMetadata)
- in_library: bool = False
- media_type: MediaType = MediaType.UNKNOWN
- # sort_name and uri are auto generated, do not override unless really needed
- sort_name: Optional[str] = None
- uri: Optional[str] = None
- # timestamp is used to determine when the item was added to the library
- timestamp: int = 0
-
- def __post_init__(self):
- """Call after init."""
- if not self.uri:
- self.uri = create_uri(self.media_type, self.provider, self.item_id)
- if not self.sort_name:
- self.sort_name = create_sort_name(self.name)
-
- @classmethod
- def from_db_row(cls, db_row: Mapping):
- """Create MediaItem object from database row."""
- db_row = dict(db_row)
- db_row["provider"] = "database"
- for key in JSON_KEYS:
- if key in db_row and db_row[key] is not None:
- db_row[key] = json.loads(db_row[key])
- if "in_library" in db_row:
- db_row["in_library"] = bool(db_row["in_library"])
- if db_row.get("albums"):
- db_row["album"] = db_row["albums"][0]
- db_row["disc_number"] = db_row["albums"][0]["disc_number"]
- db_row["track_number"] = db_row["albums"][0]["track_number"]
- db_row["item_id"] = str(db_row["item_id"])
- return cls.from_dict(db_row)
-
- def to_db_row(self) -> dict:
- """Create dict from item suitable for db."""
- return {
- key: json.dumps(value) if key in JSON_KEYS else value
- for key, value in self.to_dict().items()
- if key
- not in [
- "item_id",
- "provider",
- "media_type",
- "uri",
- "album",
- "position",
- "track_number",
- "disc_number",
- ]
- }
-
- @property
- def available(self):
- """Return (calculated) availability."""
- return any(x.available for x in self.provider_mappings)
-
- @property
- def image(self) -> MediaItemImage | None:
- """Return (first/random) image/thumb from metadata (if any)."""
- if self.metadata is None or self.metadata.images is None:
- return None
- return next(
- (x for x in self.metadata.images if x.type == ImageType.THUMB), None
- )
-
- def add_provider_mapping(self, prov_mapping: ProviderMapping) -> None:
- """Add provider ID, overwrite existing entry."""
- self.provider_mappings = {
- x
- for x in self.provider_mappings
- if not (
- x.item_id == prov_mapping.item_id
- and x.provider_id == prov_mapping.provider_id
- )
- }
- self.provider_mappings.add(prov_mapping)
-
- @property
- def last_refresh(self) -> int:
- """Return timestamp the metadata was last refreshed (0 if full data never retrieved)."""
- return self.metadata.last_refresh or 0
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.media_type, self.provider, self.item_id))
-
-
-@dataclass(frozen=True)
-class ItemMapping(DataClassDictMixin):
- """Representation of a minimized item object."""
-
- media_type: MediaType
- item_id: str
- provider: ProviderType
- name: str
- sort_name: str
- uri: str
- version: str = ""
-
- @classmethod
- def from_item(cls, item: "MediaItem"):
- """Create ItemMapping object from regular item."""
- return cls.from_dict(item.to_dict())
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.media_type, self.provider, self.item_id))
-
-
-@dataclass
-class Artist(MediaItem):
- """Model for an artist."""
-
- media_type: MediaType = MediaType.ARTIST
- musicbrainz_id: Optional[str] = None
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.provider, self.item_id))
-
-
-@dataclass
-class Album(MediaItem):
- """Model for an album."""
-
- media_type: MediaType = MediaType.ALBUM
- version: str = ""
- year: Optional[int] = None
- artists: List[Union[Artist, ItemMapping]] = field(default_factory=list)
- album_type: AlbumType = AlbumType.UNKNOWN
- upc: Optional[str] = None
- musicbrainz_id: Optional[str] = None # release group id
-
- @property
- def artist(self) -> Artist | ItemMapping | None:
- """Return (first) artist of album."""
- if self.artists:
- return self.artists[0]
- return None
-
- @artist.setter
- def artist(self, artist: Union[Artist, ItemMapping]) -> None:
- """Set (first/only) artist of album."""
- self.artists = [artist]
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.provider, self.item_id))
-
-
-@dataclass(frozen=True)
-class TrackAlbumMapping(ItemMapping):
- """Model for a track that is mapped to an album."""
-
- disc_number: Optional[int] = None
- track_number: Optional[int] = None
-
-
-@dataclass
-class Track(MediaItem):
- """Model for a track."""
-
- media_type: MediaType = MediaType.TRACK
- duration: int = 0
- version: str = ""
- isrc: Optional[str] = None
- musicbrainz_id: Optional[str] = None # Recording ID
- artists: List[Union[Artist, ItemMapping]] = field(default_factory=list)
- # album track only
- album: Union[Album, ItemMapping, None] = None
- albums: List[TrackAlbumMapping] = field(default_factory=list)
- disc_number: Optional[int] = None
- track_number: Optional[int] = None
- # playlist track only
- position: Optional[int] = None
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.provider, self.item_id))
-
- @property
- def image(self) -> MediaItemImage | None:
- """Return (first/random) image/thumb from metadata (if any)."""
- if image := super().image:
- return image
- # fallback to album image (use getattr to guard for ItemMapping)
- if self.album:
- return getattr(self.album, "image", None)
- return None
-
- @property
- def isrcs(self) -> Tuple[str]:
- """Split multiple values in isrc field."""
- # sometimes the isrc contains multiple values, splitted by semicolon
- if not self.isrc:
- return tuple()
- return tuple(self.isrc.split(";"))
-
- @property
- def artist(self) -> Artist | ItemMapping | None:
- """Return (first) artist of track."""
- if self.artists:
- return self.artists[0]
- return None
-
- @artist.setter
- def artist(self, artist: Union[Artist, ItemMapping]) -> None:
- """Set (first/only) artist of track."""
- self.artists = [artist]
-
-
-@dataclass
-class Playlist(MediaItem):
- """Model for a playlist."""
-
- media_type: MediaType = MediaType.PLAYLIST
- owner: str = ""
- is_editable: bool = False
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.provider, self.item_id))
-
-
-@dataclass
-class Radio(MediaItem):
- """Model for a radio station."""
-
- media_type: MediaType = MediaType.RADIO
- duration: int = 172800
-
- def to_db_row(self) -> dict:
- """Create dict from item suitable for db."""
- val = super().to_db_row()
- val.pop("duration", None)
- return val
-
- def __hash__(self):
- """Return custom hash."""
- return hash((self.provider, self.item_id))
-
-
-@dataclass
-class BrowseFolder(MediaItem):
- """Representation of a Folder used in Browse (which contains media items)."""
-
- media_type: MediaType = MediaType.FOLDER
- # path: the path (in uri style) to/for this browse folder
- path: str = ""
- # label: a labelid that needs to be translated by the frontend
- label: str = ""
- # subitems of this folder when expanding
- items: Optional[List[Union[MediaItemType, BrowseFolder]]] = None
-
- def __post_init__(self):
- """Call after init."""
- super().__post_init__()
- if not self.path:
- self.path = f"{self.provider}://{self.item_id}"
-
-
-MediaItemType = Union[Artist, Album, Track, Radio, Playlist, BrowseFolder]
-
-
-@dataclass
-class PagedItems(DataClassDictMixin):
- """Model for a paged listing."""
-
- items: List[MediaItemType]
- count: int
- limit: int
- offset: int
- total: Optional[int] = None
-
-
-def media_from_dict(media_item: dict) -> MediaItemType:
- """Return MediaItem from dict."""
- if media_item["media_type"] == "artist":
- return Artist.from_dict(media_item)
- if media_item["media_type"] == "album":
- return Album.from_dict(media_item)
- if media_item["media_type"] == "track":
- return Track.from_dict(media_item)
- if media_item["media_type"] == "playlist":
- return Playlist.from_dict(media_item)
- if media_item["media_type"] == "radio":
- return Radio.from_dict(media_item)
- return MediaItem.from_dict(media_item)
-
-
-@dataclass
-class StreamDetails(DataClassDictMixin):
- """Model for streamdetails."""
-
- # NOTE: the actual provider/itemid of the streamdetails may differ
- # from the connected media_item due to track linking etc.
- # the streamdetails are only used to provide details about the content
- # that is going to be streamed.
-
- # mandatory fields
- provider: ProviderType
- item_id: str
- content_type: ContentType
- media_type: MediaType = MediaType.TRACK
- sample_rate: int = 44100
- bit_depth: int = 16
- channels: int = 2
- # stream_title: radio streams can optionally set this field
- stream_title: Optional[str] = None
- # duration of the item to stream, copied from media_item if omitted
- duration: Optional[int] = None
- # total size in bytes of the item, calculated at eof when omitted
- size: Optional[int] = None
- # expires: timestamp this streamdetails expire
- expires: float = time() + 3600
- # data: provider specific data (not exposed externally)
- data: Optional[Any] = None
- # if the url/file is supported by ffmpeg directly, use direct stream
- direct: Optional[str] = None
- # callback: optional callback function (or coroutine) to call when the stream completes.
- # needed for streaming provivders to report what is playing
- # receives the streamdetails as only argument from which to grab
- # details such as seconds_streamed.
- callback: Any = None
-
- # the fields below will be set/controlled by the streamcontroller
- queue_id: Optional[str] = None
- seconds_streamed: Optional[float] = None
- seconds_skipped: Optional[float] = None
- gain_correct: Optional[float] = None
- loudness: Optional[float] = None
-
- def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]:
- """Exclude internal fields from dict."""
- d.pop("data")
- d.pop("direct")
- d.pop("expires")
- d.pop("queue_id")
- d.pop("callback")
- return d
-
- def __str__(self):
- """Return pretty printable string of object."""
- return self.uri
-
- @property
- def uri(self) -> str:
- """Return uri representation of item."""
- return f"{self.provider.value}://{self.media_type.value}/{self.item_id}"
+++ /dev/null
-"""Model/base for a Music Provider implementation."""
-from __future__ import annotations
-
-from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Tuple
-
-from music_assistant.models.config import MusicProviderConfig
-from music_assistant.models.enums import MediaType, MusicProviderFeature, ProviderType
-from music_assistant.models.media_items import (
- Album,
- Artist,
- BrowseFolder,
- MediaItemType,
- Playlist,
- Radio,
- StreamDetails,
- Track,
-)
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
-
-class MusicProvider:
- """Model for a Music Provider."""
-
- _attr_name: str = None
- _attr_type: ProviderType = None
- _attr_available: bool = True
-
- def __init__(self, mass: MusicAssistant, config: MusicProviderConfig) -> None:
- """Initialize MusicProvider."""
- self.mass = mass
- self.config = config
- self.logger = mass.logger
- self.cache = mass.cache
-
- @property
- def supported_features(self) -> Tuple[MusicProviderFeature]:
- """Return the features supported by this MusicProvider."""
- return tuple()
-
- @abstractmethod
- async def setup(self) -> bool:
- """
- Handle async initialization of the provider.
-
- Called when provider is registered.
- """
-
- @property
- def type(self) -> ProviderType:
- """Return provider type for this provider."""
- return self._attr_type
-
- @property
- def name(self) -> str:
- """Return provider Name for this provider."""
- if sum(1 for x in self.mass.music.providers if x.type == self.type) > 1:
- append_str = self.config.path or self.config.username
- return f"{self._attr_name} ({append_str})"
- return self._attr_name
-
- @property
- def available(self) -> bool:
- """Return boolean if this provider is available/initialized."""
- return self._attr_available
-
- async def search(
- self, search_query: str, media_types=Optional[List[MediaType]], limit: int = 5
- ) -> List[MediaItemType]:
- """
- Perform search on musicprovider.
-
- :param search_query: Search query.
- :param media_types: A list of media_types to include. All types if None.
- :param limit: Number of items to return in the search (per type).
- """
- if MusicProviderFeature.SEARCH in self.supported_features:
- raise NotImplementedError
-
- async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
- """Retrieve library artists from the provider."""
- if MusicProviderFeature.LIBRARY_ARTISTS in self.supported_features:
- raise NotImplementedError
-
- async def get_library_albums(self) -> AsyncGenerator[Album, None]:
- """Retrieve library albums from the provider."""
- if MusicProviderFeature.LIBRARY_ALBUMS in self.supported_features:
- raise NotImplementedError
-
- async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
- """Retrieve library tracks from the provider."""
- if MusicProviderFeature.LIBRARY_TRACKS in self.supported_features:
- raise NotImplementedError
-
- async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
- """Retrieve library/subscribed playlists from the provider."""
- if MusicProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
- raise NotImplementedError
-
- async def get_library_radios(self) -> AsyncGenerator[Radio, None]:
- """Retrieve library/subscribed radio stations from the provider."""
- if MusicProviderFeature.LIBRARY_RADIOS in self.supported_features:
- raise NotImplementedError
-
- async def get_artist(self, prov_artist_id: str) -> Artist:
- """Get full artist details by id."""
- raise NotImplementedError
-
- async def get_artist_albums(self, prov_artist_id: str) -> List[Album]:
- """Get a list of all albums for the given artist."""
- if MusicProviderFeature.ARTIST_ALBUMS in self.supported_features:
- raise NotImplementedError
- return []
-
- async def get_artist_toptracks(self, prov_artist_id: str) -> List[Track]:
- """Get a list of most popular tracks for the given artist."""
- if MusicProviderFeature.ARTIST_TOPTRACKS in self.supported_features:
- raise NotImplementedError
- return []
-
- async def get_album(self, prov_album_id: str) -> Album:
- """Get full album details by id."""
- raise NotImplementedError
-
- async def get_track(self, prov_track_id: str) -> Track:
- """Get full track details by id."""
- raise NotImplementedError
-
- async def get_playlist(self, prov_playlist_id: str) -> Playlist:
- """Get full playlist details by id."""
- raise NotImplementedError
-
- async def get_radio(self, prov_radio_id: str) -> Radio:
- """Get full radio details by id."""
- raise NotImplementedError
-
- async def get_album_tracks(self, prov_album_id: str) -> List[Track]:
- """Get album tracks for given album id."""
- raise NotImplementedError
-
- async def get_playlist_tracks(self, prov_playlist_id: str) -> List[Track]:
- """Get all playlist tracks for given playlist id."""
- raise NotImplementedError
-
- async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool:
- """Add item to provider's library. Return true on succes."""
- if (
- media_type == MediaType.ARTIST
- and MusicProviderFeature.LIBRARY_ARTISTS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.ALBUM
- and MusicProviderFeature.LIBRARY_ALBUMS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.TRACK
- and MusicProviderFeature.LIBRARY_TRACKS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.PLAYLIST
- and MusicProviderFeature.LIBRARY_PLAYLISTS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.RADIO
- and MusicProviderFeature.LIBRARY_RADIOS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- self.logger.info(
- "Provider %s does not support library edit, "
- "the action will only be performed in the local database.",
- self.type.value,
- )
-
- async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool:
- """Remove item from provider's library. Return true on succes."""
- if (
- media_type == MediaType.ARTIST
- and MusicProviderFeature.LIBRARY_ARTISTS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.ALBUM
- and MusicProviderFeature.LIBRARY_ALBUMS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.TRACK
- and MusicProviderFeature.LIBRARY_TRACKS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.PLAYLIST
- and MusicProviderFeature.LIBRARY_PLAYLISTS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- if (
- media_type == MediaType.RADIO
- and MusicProviderFeature.LIBRARY_RADIOS_EDIT in self.supported_features
- ):
- raise NotImplementedError
- self.logger.info(
- "Provider %s does not support library edit, "
- "the action will only be performed in the local database.",
- self.type.value,
- )
-
- async def add_playlist_tracks(
- self, prov_playlist_id: str, prov_track_ids: List[str]
- ) -> None:
- """Add track(s) to playlist."""
- if MusicProviderFeature.PLAYLIST_TRACKS_EDIT in self.supported_features:
- raise NotImplementedError
-
- async def remove_playlist_tracks(
- self, prov_playlist_id: str, positions_to_remove: Tuple[int]
- ) -> None:
- """Remove track(s) from playlist."""
- if MusicProviderFeature.PLAYLIST_TRACKS_EDIT in self.supported_features:
- raise NotImplementedError
-
- async def create_playlist(self, name: str) -> Playlist:
- """Create a new playlist on provider with given name."""
- raise NotImplementedError
-
- async def get_similar_tracks(self, prov_track_id, limit=25) -> List[Track]:
- """Retrieve a dynamic list of similar tracks based on the provided track."""
- raise NotImplementedError
-
- async def get_stream_details(self, item_id: str) -> StreamDetails | None:
- """Get streamdetails for a track/radio."""
- raise NotImplementedError
-
- async def get_audio_stream(
- self, streamdetails: StreamDetails, seek_position: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Return the audio stream for the provider item."""
- raise NotImplementedError
-
- async def get_item(self, media_type: MediaType, prov_item_id: str) -> MediaItemType:
- """Get single MediaItem from provider."""
- if media_type == MediaType.ARTIST:
- return await self.get_artist(prov_item_id)
- if media_type == MediaType.ALBUM:
- return await self.get_album(prov_item_id)
- if media_type == MediaType.PLAYLIST:
- return await self.get_playlist(prov_item_id)
- if media_type == MediaType.RADIO:
- return await self.get_radio(prov_item_id)
- return await self.get_track(prov_item_id)
-
- async def browse(self, path: str) -> BrowseFolder:
- """
- Browse this provider's items.
-
- :param path: The path to browse, (e.g. provid://artists).
- """
- if MusicProviderFeature.BROWSE not in self.supported_features:
- # we may NOT use the default implementation if the provider does not support browse
- raise NotImplementedError
-
- _, subpath = path.split("://")
-
- # this reference implementation can be overridden with provider specific approach
- if not subpath:
- # return main listing
- root_items = []
- if MusicProviderFeature.LIBRARY_ARTISTS in self.supported_features:
- root_items.append(
- BrowseFolder(
- item_id="artists",
- provider=self.type,
- path=path + "artists",
- name="",
- label="artists",
- )
- )
- if MusicProviderFeature.LIBRARY_ALBUMS in self.supported_features:
- root_items.append(
- BrowseFolder(
- item_id="albums",
- provider=self.type,
- path=path + "albums",
- name="",
- label="albums",
- )
- )
- if MusicProviderFeature.LIBRARY_TRACKS in self.supported_features:
- root_items.append(
- BrowseFolder(
- item_id="tracks",
- provider=self.type,
- path=path + "tracks",
- name="",
- label="tracks",
- )
- )
- if MusicProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
- root_items.append(
- BrowseFolder(
- item_id="playlists",
- provider=self.type,
- path=path + "playlists",
- name="",
- label="playlists",
- )
- )
- if MusicProviderFeature.LIBRARY_RADIOS in self.supported_features:
- root_items.append(
- BrowseFolder(
- item_id="radios",
- provider=self.type,
- path=path + "radios",
- name="",
- label="radios",
- )
- )
- return BrowseFolder(
- item_id="root",
- provider=self.type,
- path=path,
- name=self.name,
- items=root_items,
- )
- # sublevel
- if subpath == "artists":
- return BrowseFolder(
- item_id="artists",
- provider=self.type,
- path=path,
- name="",
- label="artists",
- items=[x async for x in self.get_library_artists()],
- )
- if subpath == "albums":
- return BrowseFolder(
- item_id="albums",
- provider=self.type,
- path=path,
- name="",
- label="albums",
- items=[x async for x in self.get_library_albums()],
- )
- if subpath == "tracks":
- return BrowseFolder(
- item_id="tracks",
- provider=self.type,
- path=path,
- name="",
- label="tracks",
- items=[x async for x in self.get_library_tracks()],
- )
- if subpath == "radios":
- return BrowseFolder(
- item_id="radios",
- provider=self.type,
- path=path,
- name="",
- label="radios",
- items=[x async for x in self.get_library_radios()],
- )
- if subpath == "playlists":
- return BrowseFolder(
- item_id="playlists",
- provider=self.type,
- path=path,
- name="",
- label="playlists",
- items=[x async for x in self.get_library_playlists()],
- )
-
- async def recommendations(self) -> List[BrowseFolder]:
- """
- Get this provider's recommendations.
-
- Returns a list of BrowseFolder items with (max 25) mediaitems in the items attribute.
- """
- if MusicProviderFeature.RECOMMENDATIONS in self.supported_features:
- raise NotImplementedError
-
- async def sync_library(
- self, media_types: Optional[Tuple[MediaType]] = None
- ) -> None:
- """Run library sync for this provider."""
- # this reference implementation can be overridden with provider specific approach
- # this logic is aimed at streaming/online providers,
- # which all have more or less the same structure.
- # filesystem implementation(s) just override this.
- if media_types is None:
- media_types = (x for x in MediaType)
- for media_type in media_types:
- if not self.library_supported(media_type):
- continue
- self.logger.debug("Start sync of %s items.", media_type.value)
- controller = self.mass.music.get_controller(media_type)
- cur_db_ids = set()
- async for prov_item in self._get_library_gen(media_type)():
- prov_item: MediaItemType = prov_item
-
- db_item: MediaItemType = await controller.get_db_item_by_prov_id(
- provider_item_id=prov_item.item_id,
- provider_type=prov_item.provider,
- )
- if not db_item:
- # dump the item in the db, rich metadata is lazy loaded later
- db_item = await controller.add_db_item(prov_item)
-
- elif (
- db_item.metadata.checksum and prov_item.metadata.checksum
- ) and db_item.metadata.checksum != prov_item.metadata.checksum:
- # item checksum changed
- db_item = await controller.update_db_item(
- db_item.item_id, prov_item
- )
- # preload album/playlist tracks
- if prov_item.media_type == (MediaType.ALBUM, MediaType.PLAYLIST):
- for track in controller.tracks(
- prov_item.item_id, prov_item.provider
- ):
- await self.mass.music.tracks.add_db_item(track)
- cur_db_ids.add(db_item.item_id)
- if not db_item.in_library:
- await controller.set_db_library(db_item.item_id, True)
-
- # process deletions (= no longer in library)
- async for db_item in controller.iter_db_items(True):
- if db_item.item_id in cur_db_ids:
- continue
- for prov_mapping in db_item.provider_mappings:
- provider_types = {
- x.provider_type for x in db_item.provider_mappings
- }
- if len(provider_types) > 1:
- continue
- if prov_mapping.provider_id != self.id:
- continue
- # only mark the item as not in library and leave the metadata in db
- await controller.set_db_library(db_item.item_id, False)
-
- # DO NOT OVERRIDE BELOW
-
- @property
- def id(self) -> str:
- """Return unique provider id to distinguish multiple instances of the same provider."""
- return self.config.id
-
- def to_dict(self) -> Dict[str, Any]:
- """Return (serializable) dict representation of MusicProvider."""
- return {
- "type": self.type.value,
- "name": self.name,
- "id": self.id,
- "supported_features": [x.value for x in self.supported_features],
- }
-
- def library_supported(self, media_type: MediaType) -> bool:
- """Return if Library is supported for given MediaType on this provider."""
- if media_type == MediaType.ARTIST:
- return MusicProviderFeature.LIBRARY_ARTISTS in self.supported_features
- if media_type == MediaType.ALBUM:
- return MusicProviderFeature.LIBRARY_ALBUMS in self.supported_features
- if media_type == MediaType.TRACK:
- return MusicProviderFeature.LIBRARY_TRACKS in self.supported_features
- if media_type == MediaType.PLAYLIST:
- return MusicProviderFeature.LIBRARY_PLAYLISTS in self.supported_features
- if media_type == MediaType.RADIO:
- return MusicProviderFeature.LIBRARY_RADIOS in self.supported_features
-
- def library_edit_supported(self, media_type: MediaType) -> bool:
- """Return if Library add/remove is supported for given MediaType on this provider."""
- if media_type == MediaType.ARTIST:
- return MusicProviderFeature.LIBRARY_ARTISTS_EDIT in self.supported_features
- if media_type == MediaType.ALBUM:
- return MusicProviderFeature.LIBRARY_ALBUMS_EDIT in self.supported_features
- if media_type == MediaType.TRACK:
- return MusicProviderFeature.LIBRARY_TRACKS_EDIT in self.supported_features
- if media_type == MediaType.PLAYLIST:
- return (
- MusicProviderFeature.LIBRARY_PLAYLISTS_EDIT in self.supported_features
- )
- if media_type == MediaType.RADIO:
- return MusicProviderFeature.LIBRARY_RADIOS_EDIT in self.supported_features
-
- def _get_library_gen(self, media_type: MediaType) -> AsyncGenerator[MediaItemType]:
- """Return library generator for given media_type."""
- if media_type == MediaType.ARTIST:
- return self.get_library_artists
- if media_type == MediaType.ALBUM:
- return self.get_library_albums
- if media_type == MediaType.TRACK:
- return self.get_library_tracks
- if media_type == MediaType.PLAYLIST:
- return self.get_library_playlists
- if media_type == MediaType.RADIO:
- return self.get_library_radios
+++ /dev/null
-"""Models and helpers for a player."""
-from __future__ import annotations
-
-import asyncio
-from abc import ABC
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List
-
-from mashumaro import DataClassDictMixin
-
-from music_assistant.helpers.util import get_changed_keys
-from music_assistant.models.enums import EventType, PlayerState
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import ContentType
-
-if TYPE_CHECKING:
- from music_assistant.mass import MusicAssistant
-
- from .player_queue import PlayerQueue
-
-
-@dataclass(frozen=True)
-class DeviceInfo(DataClassDictMixin):
- """Model for a player's deviceinfo."""
-
- model: str = "unknown"
- address: str = "unknown"
- manufacturer: str = "unknown"
-
-
-class Player(ABC):
- """Model for a music player."""
-
- player_id: str
- _attr_group_members: List[str] = []
- _attr_name: str = ""
- _attr_powered: bool = False
- _attr_elapsed_time: float = 0
- _attr_current_url: str = ""
- _attr_state: PlayerState = PlayerState.IDLE
- _attr_available: bool = True
- _attr_volume_level: int = 100
- _attr_volume_muted: bool = False
- _attr_device_info: DeviceInfo = DeviceInfo()
- _attr_max_sample_rate: int = 96000
- _attr_stream_type: ContentType = ContentType.FLAC
- # below objects will be set by playermanager at register/update
- mass: MusicAssistant = None # type: ignore[assignment]
- _prev_state: dict = {}
-
- @property
- def name(self) -> bool:
- """Return player name."""
- return self._attr_name or self.player_id
-
- @property
- def powered(self) -> bool:
- """Return current power state of player."""
- return self._attr_powered
-
- @property
- def elapsed_time(self) -> float:
- """Return elapsed time of current playing media in seconds."""
- # NOTE: Make sure to provide an accurate elapsed time otherwise the
- # queue reporting of playing tracks will be wrong.
- # this attribute will be checked every second when the queue is playing
- return self._attr_elapsed_time
-
- @property
- def current_url(self) -> str:
- """Return URL that is currently loaded in the player."""
- return self._attr_current_url
-
- @property
- def state(self) -> PlayerState:
- """Return current PlayerState of player."""
- if not self.powered:
- return PlayerState.OFF
- return self._attr_state
-
- @property
- def available(self) -> bool:
- """Return current availablity of player."""
- return self._attr_available
-
- @property
- def volume_level(self) -> int:
- """Return current volume level of player (scale 0..100)."""
- return self._attr_volume_level
-
- @property
- def volume_muted(self) -> bool:
- """Return current mute mode of player."""
- return self._attr_volume_muted
-
- @property
- def device_info(self) -> DeviceInfo:
- """Return basic device/provider info for this player."""
- return self._attr_device_info
-
- async def play_url(self, url: str) -> None:
- """Play the specified url on the player."""
- raise NotImplementedError
-
- async def stop(self) -> None:
- """Send STOP command to player."""
- raise NotImplementedError
-
- async def play(self) -> None:
- """Send PLAY/UNPAUSE command to player."""
- raise NotImplementedError
-
- async def pause(self) -> None:
- """Send PAUSE command to player."""
- raise NotImplementedError
-
- async def power(self, powered: bool) -> None:
- """Send POWER command to player."""
- raise NotImplementedError
-
- async def volume_set(self, volume_level: int) -> None:
- """Send volume level (0..100) command to player."""
- raise NotImplementedError
-
- # DEFAULT PLAYER SETTINGS
-
- @property
- def max_sample_rate(self) -> int:
- """Return the (default) max supported sample rate."""
- # if a player does not report/set its supported sample rates, we use a pretty safe default
- return self._attr_max_sample_rate
-
- @property
- def stream_type(self) -> ContentType:
- """Return the default/preferred content type to use for streaming."""
- return self._attr_stream_type
-
- # GROUP PLAYER ATTRIBUTES AND METHODS (may be overridden if needed)
- # a player can optionally be a group leader (e.g. Sonos)
- # or be a group player itself (e.g. Cast)
- # support both scenarios here
-
- @property
- def is_group(self) -> bool:
- """Return if this player represents a playergroup or is grouped with other players."""
- return len(self.group_members) > 1
-
- @property
- def group_members(self) -> List[str]:
- """
- Return list of grouped players.
-
- If this player is a dedicated group player (e.g. cast), returns the grouped child id's.
- If this is a player grouped with other players within the same platform (e.g. Sonos),
- this will return the players that are currently grouped together.
- The first child id should represent the group leader.
- """
- return self._attr_group_members
-
- @property
- def group_leader(self) -> str | None:
- """Return the leader's player_id of this playergroup."""
- if group_members := self.group_members:
- return group_members[0]
- return None
-
- @property
- def is_group_leader(self) -> bool:
- """Return if this player is the leader in a playergroup."""
- return self.is_group and self.group_leader == self.player_id
-
- @property
- def is_passive(self) -> bool:
- """
- Return if this player may not accept any playback related commands.
-
- Usually this means the player is part of a playergroup but not the leader.
- """
- if self.is_group and self.player_id not in self.group_members:
- return False
- return self.is_group and not self.is_group_leader
-
- @property
- def group_name(self) -> str:
- """Return name of this grouped player."""
- if not self.is_group:
- return self.name
- # default to name of groupleader and number of childs
- num_childs = len([x for x in self.group_members if x != self.player_id])
- return f"{self.name} +{num_childs}"
-
- @property
- def group_powered(self) -> bool:
- """Calculate a group power state from the grouped members."""
- if not self.is_group:
- return self.powered
- for _ in self.get_child_players(True):
- return True
- return False
-
- @property
- def group_volume_level(self) -> int:
- """Calculate a group volume from the grouped members."""
- if not self.is_group:
- return self.volume_level
- group_volume = 0
- active_players = 0
- for child_player in self.get_child_players(True):
- group_volume += child_player.volume_level
- active_players += 1
- if active_players:
- group_volume = group_volume / active_players
- return int(group_volume)
-
- async def set_group_volume(self, volume_level: int) -> None:
- """Send volume level (0..100) command to groupplayer's member(s)."""
- # handle group volume by only applying the volume to powered members
- cur_volume = self.group_volume_level
- new_volume = volume_level
- volume_dif = new_volume - cur_volume
- if cur_volume == 0:
- volume_dif_percent = 1 + (new_volume / 100)
- else:
- volume_dif_percent = volume_dif / cur_volume
- coros = []
- for child_player in self.get_child_players(True):
- cur_child_volume = child_player.volume_level
- new_child_volume = cur_child_volume + (
- cur_child_volume * volume_dif_percent
- )
- coros.append(child_player.volume_set(new_child_volume))
- await asyncio.gather(*coros)
-
- async def set_group_power(self, powered: bool) -> None:
- """Send power command to groupplayer's member(s)."""
- coros = [
- player.power(powered) for player in self.get_child_players(not powered)
- ]
- await asyncio.gather(*coros)
-
- # SOME CONVENIENCE METHODS (may be overridden if needed)
-
- async def volume_mute(self, muted: bool) -> None:
- """Send volume mute command to player."""
- # for players that do not support mute, we fake mute with volume
- self._attr_volume_muted = muted
- if muted:
- setattr(self, "prev_volume", self.volume_level)
- else:
- await self.volume_set(getattr(self, "prev_volume", 0))
-
- async def volume_up(self, step_size: int = 5) -> None:
- """Send volume UP command to player."""
- new_level = min(self.volume_level + step_size, 100)
- return await self.volume_set(new_level)
-
- async def volume_down(self, step_size: int = 5) -> None:
- """Send volume DOWN command to player."""
- new_level = max(self.volume_level - step_size, 0)
- return await self.volume_set(new_level)
-
- async def play_pause(self) -> None:
- """Toggle play/pause on player."""
- if self.state == PlayerState.PLAYING:
- await self.pause()
- else:
- await self.play()
-
- async def power_toggle(self) -> None:
- """Toggle power on player."""
- await self.power(not self.powered)
-
- def on_update(self) -> None:
- """Call when player is about to be updated in the player manager."""
-
- def on_child_update(self, player_id: str, changed_keys: set) -> None:
- """Call when one of the child players of a playergroup updates."""
- self.update_state(skip_forward=True)
-
- def on_parent_update(self, player_id: str, changed_keys: set) -> None:
- """Call when (one of) the parent player(s) of a grouped player updates."""
- self.update_state(skip_forward=True)
-
- def on_remove(self) -> None:
- """Call when player is about to be removed (cleaned up) from player manager."""
-
- # DO NOT OVERRIDE BELOW
-
- @property
- def active_queue(self) -> PlayerQueue:
- """Return the queue that is currently active on/for this player."""
- for queue in self.mass.players.player_queues:
- if queue.stream and queue.stream.url == self.current_url:
- return queue
- return self.mass.players.get_player_queue(self.player_id)
-
- def update_state(self, skip_forward: bool = False) -> None:
- """Update current player state in the player manager."""
- if self.mass is None or self.mass.closed:
- # guard
- return
- self.on_update()
- # basic throttle: do not send state changed events if player did not change
- cur_state = self.to_dict()
- changed_keys = get_changed_keys(
- self._prev_state, cur_state, ignore_keys=["elapsed_time"]
- )
-
- if len(changed_keys) == 0:
- return
-
- # update the playerqueue
- self.mass.players.get_player_queue(self.player_id).on_player_update()
-
- self._prev_state = cur_state
- self.mass.signal_event(
- MassEvent(EventType.PLAYER_UPDATED, object_id=self.player_id, data=self)
- )
-
- if skip_forward:
- return
- if self.is_group:
- # update group player members when parent updates
- for child_player_id in self.group_members:
- if child_player_id == self.player_id:
- continue
- if player := self.mass.players.get_player(child_player_id):
- self.mass.loop.call_soon_threadsafe(
- player.on_parent_update, self.player_id, changed_keys
- )
-
- # update group player(s) when child updates
- for group_player in self.get_group_parents():
- self.mass.loop.call_soon_threadsafe(
- group_player.on_child_update, self.player_id, changed_keys
- )
-
- def to_dict(self) -> Dict[str, Any]:
- """Export object to dict."""
- return {
- "player_id": self.player_id,
- "name": self.name,
- "powered": self.powered,
- "elapsed_time": int(self.elapsed_time),
- "state": self.state.value,
- "available": self.available,
- "volume_level": int(self.volume_level),
- "is_group": self.is_group,
- "group_members": self.group_members,
- "group_leader": self.group_leader,
- "is_group_leader": self.is_group_leader,
- "is_passive": self.is_passive,
- "group_name": self.group_name,
- "group_powered": self.group_powered,
- "group_volume_level": int(self.group_volume_level),
- "device_info": self.device_info.to_dict(),
- "active_queue": self.active_queue.queue_id
- if self.active_queue
- else self.player_id,
- }
-
- def get_child_players(
- self,
- only_powered: bool = False,
- only_playing: bool = False,
- ) -> List[Player]:
- """Get players attached to a grouped player."""
- if not self.mass:
- return []
- child_players = set()
- for child_id in self.group_members:
- if child_player := self.mass.players.get_player(child_id):
- if not (not only_powered or child_player.powered):
- continue
- if not (not only_playing or child_player.state == PlayerState.PLAYING):
- continue
- child_players.add(child_player)
- return list(child_players)
-
- def get_group_parents(
- self,
- only_powered: bool = False,
- only_playing: bool = False,
- ) -> List[Player]:
- """Get players which have this player as child in a group."""
- if not self.mass:
- return []
- parent_players = set()
- for player in self.mass.players:
- if player.player_id == self.player_id:
- continue
- if self.player_id not in player.group_members:
- continue
- if not (not only_powered or player.powered):
- continue
- if not (not only_playing or player.state == PlayerState.PLAYING):
- continue
- parent_players.add(player)
- return list(parent_players)
+++ /dev/null
-"""Model for a PlayerQueue."""
-from __future__ import annotations
-
-import asyncio
-import random
-from asyncio import TimerHandle
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-
-from music_assistant.constants import ANNOUNCE_ALERT_FILE, FALLBACK_DURATION
-from music_assistant.helpers.tags import parse_tags
-from music_assistant.helpers.util import try_parse_int
-from music_assistant.models.enums import (
- EventType,
- MediaType,
- ProviderType,
- QueueOption,
- RepeatMode,
-)
-from music_assistant.models.errors import MediaNotFoundError, MusicAssistantError
-from music_assistant.models.event import MassEvent
-from music_assistant.models.media_items import MediaItemType, media_from_dict
-
-from .player import Player, PlayerState
-from .queue_item import QueueItem
-from .queue_settings import QueueSettings
-
-if TYPE_CHECKING:
- from music_assistant.controllers.streams import QueueStream
- from music_assistant.mass import MusicAssistant
-
-
-@dataclass
-class QueueSnapShot:
- """Represent a snapshot of the queue and its settings."""
-
- powered: bool
- state: PlayerState
- items: List[QueueItem]
- index: Optional[int]
- position: int
- settings: dict
- volume_level: int
- player_url: str
-
-
-class PlayerQueue:
- """Represents a PlayerQueue object."""
-
- def __init__(self, mass: MusicAssistant, player_id: str):
- """Instantiate a PlayerQueue instance."""
- self.mass = mass
- self.logger = mass.players.logger
- self.queue_id = player_id
- self._stream_id: str = ""
- self._settings = QueueSettings(self)
- self._current_index: Optional[int] = None
- self._current_item_elapsed_time: int = 0
- self._prev_item: Optional[QueueItem] = None
- self._last_player_state: Tuple[PlayerState, str] = (PlayerState.OFF, "")
- self._items: List[QueueItem] = []
- self._save_task: TimerHandle = None
- self._last_player_update: int = 0
- self._last_stream_id: str = ""
- self._snapshot: Optional[QueueSnapShot] = None
- self._radio_source: List[MediaItemType] = []
- self.announcement_in_progress: bool = False
-
- async def setup(self) -> None:
- """Handle async setup of instance."""
- await self._settings.restore()
- await self._restore_items()
- self.mass.signal_event(
- MassEvent(EventType.QUEUE_ADDED, object_id=self.queue_id, data=self)
- )
-
- @property
- def settings(self) -> QueueSettings:
- """Return settings/preferences for this PlayerQueue."""
- return self._settings
-
- @property
- def player(self) -> Player:
- """Return the player attached to this queue."""
- return self.mass.players.get_player(self.queue_id)
-
- @property
- def available(self) -> bool:
- """Return if player(queue) is available."""
- return self.player.available
-
- @property
- def stream(self) -> QueueStream | None:
- """Return the currently connected/active stream for this queue."""
- return self.mass.streams.queue_streams.get(self._stream_id)
-
- @property
- def index_in_buffer(self) -> int | None:
- """Return the item that is curently loaded into the buffer."""
- if not self._stream_id:
- return None
- if stream := self.mass.streams.queue_streams.get(self._stream_id):
- if not stream.done.is_set():
- return stream.index_in_buffer
- return self.current_index
-
- @property
- def active(self) -> bool:
- """Return if the queue is currenty active."""
- if stream := self.stream:
- if not self.stream.done.is_set():
- return True
- if not self.player.current_url:
- return False
- return stream.stream_id in self.player.current_url
- return False
-
- @property
- def elapsed_time(self) -> float:
- """Return elapsed time of current playing media in seconds."""
- if not self.active:
- return self.player.elapsed_time
- return self._current_item_elapsed_time
-
- @property
- def items(self) -> List[QueueItem]:
- """Return all items in this queue."""
- return self._items
-
- @property
- def current_index(self) -> int | None:
- """Return current index."""
- return self._current_index
-
- @property
- def current_item(self) -> QueueItem | None:
- """
- Return the current item in the queue.
-
- Returns None if queue is empty.
- """
- if self._current_index is None:
- return None
- if self._current_index >= len(self._items):
- return None
- return self._items[self._current_index]
-
- @property
- def next_item(self) -> QueueItem | None:
- """
- Return the next item in the queue.
-
- Returns None if queue is empty or no more items.
- """
- try:
- next_index = self.get_next_index(self._current_index)
- return self._items[next_index]
- except (IndexError, TypeError):
- return None
-
- def get_item(self, index: int) -> QueueItem | None:
- """Get queue item by index."""
- if index is not None and len(self._items) > index:
- return self._items[index]
- return None
-
- def item_by_id(self, queue_item_id: str) -> QueueItem | None:
- """Get item by queue_item_id from queue."""
- if not queue_item_id:
- return None
- return next((x for x in self.items if x.item_id == queue_item_id), None)
-
- def index_by_id(self, queue_item_id: str) -> Optional[int]:
- """Get index by queue_item_id."""
- for index, item in enumerate(self.items):
- if item.item_id == queue_item_id:
- return index
- return None
-
- async def play_media(
- self,
- media: str | List[str] | MediaItemType | List[MediaItemType],
- option: QueueOption = QueueOption.PLAY,
- radio_mode: bool = False,
- passive: bool = False,
- ) -> str:
- """
- Play media item(s) on the given queue.
-
- media: Media(s) that should be played (MediaItem(s) or uri's).
- queue_opt: Which enqueue mode to use.
- radio_mode: Enable radio mode for the given item(s).
- passive: If passive set to true the stream url will not be sent to the player.
- """
- if self.announcement_in_progress:
- self.logger.warning("Ignore queue command: An announcement is in progress")
- return
-
- # a single item or list of items may be provided
- if not isinstance(media, list):
- media = [media]
-
- # clear queue first if it was finished
- if self._current_index and self._current_index >= (len(self._items) - 1):
- self._current_index = None
- self._items = []
-
- # clear radio source items if needed
- if option not in (QueueOption.ADD, QueueOption.PLAY, QueueOption.NEXT):
- self._radio_source = []
-
- tracks: List[MediaItemType] = []
- for item in media:
- # parse provided uri into a MA MediaItem or Basic QueueItem from URL
- if isinstance(item, str):
- try:
- media_item = await self.mass.music.get_item_by_uri(item)
- except MusicAssistantError as err:
- # invalid MA uri or item not found error
- raise MediaNotFoundError(f"Invalid uri: {item}") from err
- elif isinstance(item, dict):
- media_item = media_from_dict(item)
- else:
- media_item = item
-
- # collect tracks to play
- ctrl = self.mass.music.get_controller(media_item.media_type)
- if radio_mode:
- self._radio_source.append(media_item)
- # if radio mode enabled, grab the first batch of tracks here
- tracks += await ctrl.dynamic_tracks(
- item_id=media_item.item_id, provider_type=media_item.provider
- )
- elif media_item.media_type in (
- MediaType.ARTIST,
- MediaType.ALBUM,
- MediaType.PLAYLIST,
- ):
- tracks += await ctrl.tracks(
- media_item.item_id, provider_type=media_item.provider
- )
- else:
- # single track or radio item
- tracks += [media_item]
-
- # only add valid/available items
- queue_items = [
- QueueItem.from_media_item(x) for x in tracks if x and x.available
- ]
-
- # load the items into the queue
- cur_index = self.index_in_buffer or self._current_index or 0
- shuffle = self.settings.shuffle_enabled and len(queue_items) >= 5
-
- # handle replace: clear all items and replace with the new items
- if option == QueueOption.REPLACE:
- await self.clear()
- await self.load(queue_items, shuffle=shuffle)
- if not passive:
- await self.play_index(0)
- # handle next: add item(s) in the index next to the playing/loaded/buffered index
- elif option == QueueOption.NEXT:
- await self.load(queue_items, insert_at_index=cur_index + 1, shuffle=shuffle)
- elif option == QueueOption.REPLACE_NEXT:
- await self.load(
- queue_items,
- insert_at_index=cur_index + 1,
- keep_remaining=False,
- shuffle=shuffle,
- )
- # handle play: replace current loaded/playing index with new item(s)
- elif option == QueueOption.PLAY:
- await self.load(queue_items, insert_at_index=cur_index, shuffle=shuffle)
- if not passive:
- await self.play_index(cur_index)
- # handle add: add/append item(s) to the remaining queue items
- elif option == QueueOption.ADD:
- shuffle = self.settings.shuffle_enabled
- if shuffle:
- # shuffle the new items with remaining queue items
- insert_at_index = cur_index + 1
- else:
- # just append at the end
- insert_at_index = len(self._items)
- await self.load(
- queue_items, insert_at_index=insert_at_index, shuffle=shuffle
- )
-
- async def _fill_radio_tracks(self) -> None:
- """Fill the Queue with (additional) Radio tracks."""
- assert self._radio_source, "No Radio item(s) loaded/active!"
- tracks: List[MediaItemType] = []
- # grab dynamic tracks for (all) source items
- # shuffle the source items, just in case
- for radio_item in random.sample(self._radio_source, len(self._radio_source)):
- ctrl = self.mass.music.get_controller(radio_item.media_type)
- tracks += await ctrl.dynamic_tracks(
- item_id=radio_item.item_id, provider_type=radio_item.provider
- )
- # make sure we do not grab too much items
- if len(tracks) >= 50:
- break
- # fill queue - filter out unavailable items
- queue_items = [QueueItem.from_media_item(x) for x in tracks if x.available]
- await self.load(
- queue_items,
- insert_at_index=len(self._items) - 1,
- )
-
- async def play_announcement(self, url: str, prepend_alert: bool = False) -> str:
- """
- Play given uri as Announcement on the queue.
-
- url: URL that should be played as announcement, can only be plain url.
- prepend_alert: Prepend the (TTS) announcement with an alert bell sound.
- """
- if self.announcement_in_progress:
- self.logger.warning(
- "Ignore queue command: An announcement is (already) in progress"
- )
- return
-
- try:
- # create snapshot
- await self.snapshot_create()
- wait_time = 2
- # stop player if needed
- if self.active and self.player.state == PlayerState.PLAYING:
- await self.stop()
- self.announcement_in_progress = True
- await asyncio.sleep(0.1)
-
- # adjust volume if needed
- if self._settings.announce_volume_increase:
- announce_volume = (
- self.player.volume_level + self._settings.announce_volume_increase
- )
- announce_volume = min(announce_volume, 100)
- announce_volume = max(announce_volume, 0)
- # turn on player if needed (might be needed before adjusting the volume)
- if not self.player.powered:
- await self.player.power(True)
- wait_time += 2
- await self.player.volume_set(announce_volume)
-
- # prepend alert sound if needed
- if prepend_alert:
- announce_urls = (ANNOUNCE_ALERT_FILE, url)
- wait_time += 2
- else:
- announce_urls = (url,)
-
- # send announcement stream to player
- announce_stream_url = self.mass.streams.get_announcement_url(
- self.queue_id, announce_urls, self._settings.stream_type
- )
- await self.player.play_url(announce_stream_url)
-
- # wait for the player to finish playing
- info = await parse_tags(url)
- wait_time += info.duration or 10
- await asyncio.sleep(wait_time)
-
- except Exception as err: # pylint: disable=broad-except
- self.logger.exception("Error while playing announcement", exc_info=err)
- finally:
- # restore queue
- self.announcement_in_progress = False
- await self.snapshot_restore()
-
- async def stop(self) -> None:
- """Stop command on queue player."""
- if self.announcement_in_progress:
- self.logger.warning("Ignore queue command: An announcement is in progress")
- return
- if stream := self.stream:
- stream.signal_next = None
- # redirect to underlying player
- await self.player.stop()
-
- async def play(self) -> None:
- """Play (unpause) command on queue player."""
- if self.announcement_in_progress:
- self.logger.warning("Ignore queue command: An announcement is in progress")
- return
- if self.player.state == PlayerState.PAUSED:
- await self.player.play()
- else:
- await self.resume()
-
- async def pause(self) -> None:
- """Pause command on queue player."""
- if self.announcement_in_progress:
- self.logger.warning("Ignore queue command: An announcement is in progress")
- return
- # redirect to underlying player
- await self.player.pause()
-
- async def play_pause(self) -> None:
- """Toggle play/pause on queue/player."""
- if self.player.state == PlayerState.PLAYING:
- await self.pause()
- return
- await self.play()
-
- async def next(self) -> None:
- """Play the next track in the queue."""
- next_index = self.get_next_index(self._current_index, True)
- if next_index is None:
- return None
- await self.play_index(next_index)
-
- async def previous(self) -> None:
- """Play the previous track in the queue."""
- if self._current_index is None:
- return
- await self.play_index(max(self._current_index - 1, 0))
-
- async def skip_ahead(self, seconds: int = 10) -> None:
- """Skip X seconds ahead in track."""
- await self.seek(self.elapsed_time + seconds)
-
- async def skip_back(self, seconds: int = 10) -> None:
- """Skip X seconds back in track."""
- await self.seek(self.elapsed_time - seconds)
-
- async def seek(self, position: int) -> None:
- """Seek to a specific position in the track (given in seconds)."""
- assert self.current_item, "No item loaded"
- assert self.current_item.media_item.media_type == MediaType.TRACK
- assert self.current_item.duration
- assert position < self.current_item.duration
- await self.play_index(self._current_index, position)
-
- async def resume(self) -> None:
- """Resume previous queue."""
- last_player_url = self._last_player_state[1]
- if last_player_url and self.mass.streams.base_url not in last_player_url:
- self.logger.info("Trying to resume non-MA content %s...", last_player_url)
- await self.player.play_url(last_player_url)
- return
- resume_item = self.current_item
- next_item = self.next_item
- resume_pos = self._current_item_elapsed_time
- if (
- resume_item
- and next_item
- and resume_item.duration
- and resume_pos > (resume_item.duration * 0.9)
- ):
- # track is already played for > 90% - skip to next
- resume_item = next_item
- resume_pos = 0
- elif self._current_index is None and len(self._items) > 0:
- # items available in queue but no previous track, start at 0
- resume_item = self.get_item(0)
- resume_pos = 0
-
- if resume_item is not None:
- resume_pos = resume_pos if resume_pos > 10 else 0
- fade_in = resume_pos > 0
- await self.play_index(resume_item.item_id, resume_pos, fade_in)
- else:
- self.logger.warning(
- "resume queue requested for %s but queue is empty", self.queue_id
- )
-
- async def snapshot_create(self) -> None:
- """Create snapshot of current Queue state."""
- self.logger.debug("Creating snapshot...")
- self._snapshot = QueueSnapShot(
- powered=self.player.powered,
- state=self.player.state,
- items=self._items,
- index=self._current_index,
- position=self._current_item_elapsed_time,
- settings=self._settings.to_dict(),
- volume_level=self.player.volume_level,
- player_url=self.player.current_url,
- )
-
- async def snapshot_restore(self) -> None:
- """Restore snapshot of Queue state."""
- assert self._snapshot, "Create snapshot before restoring it."
- try:
- # clear queue first
- await self.clear()
- # restore volume if needed
- if self._snapshot.volume_level != self.player.volume_level:
- await self.player.volume_set(self._snapshot.volume_level)
- # restore queue
- self._settings.from_dict(self._snapshot.settings)
- await self.update_items(self._snapshot.items)
- self._current_index = self._snapshot.index
- self._current_item_elapsed_time = self._snapshot.position
- self._last_player_state = (
- self._snapshot.state,
- self._snapshot.player_url,
- )
- if self._snapshot.state in (PlayerState.PLAYING, PlayerState.PAUSED):
- await self.resume()
- if self._snapshot.state == PlayerState.PAUSED:
- await self.pause()
- if not self._snapshot.powered:
- await self.player.power(False)
- # reset snapshot once restored
- self.logger.debug("Restored snapshot...")
- except Exception as err: # pylint: disable=broad-except
- self.logger.exception("Error while restoring snapshot", exc_info=err)
- finally:
- self._snapshot = None
-
- async def play_index(
- self,
- index: Union[int, str],
- seek_position: int = 0,
- fade_in: bool = False,
- passive: bool = False,
- ) -> None:
- """Play item at index (or item_id) X in queue."""
- if self.announcement_in_progress:
- self.logger.warning("Ignore queue command: An announcement is in progress")
- return
- if stream := self.stream:
- # make sure that the previous stream is not auto restarted (race condition)
- stream.signal_next = None
- if not isinstance(index, int):
- index = self.index_by_id(index)
- if index is None:
- raise FileNotFoundError(f"Unknown index/id: {index}")
- if not len(self.items) > index:
- return
- self._current_index = index
- # start the queue stream
- stream = await self.queue_stream_start(
- start_index=index,
- seek_position=int(seek_position),
- fade_in=fade_in,
- )
- # execute the play command on the player(s)
- if not passive:
- await self.player.play_url(stream.url)
-
- async def move_item(self, queue_item_id: str, pos_shift: int = 1) -> None:
- """
- Move queue item x up/down the queue.
-
- param pos_shift: move item x positions down if positive value
- move item x positions up if negative value
- move item to top of queue as next item if 0
- """
- items = self._items.copy()
- item_index = self.index_by_id(queue_item_id)
- if pos_shift == 0 and self.player.state == PlayerState.PLAYING:
- new_index = (self._current_index or 0) + 1
- elif pos_shift == 0:
- new_index = self._current_index or 0
- else:
- new_index = item_index + pos_shift
- if (new_index < (self._current_index or 0)) or (new_index > len(self.items)):
- return
- # move the item in the list
- # TODO: guard for position that is already played/buffered!
- items.insert(new_index, items.pop(item_index))
- await self.update_items(items)
-
- async def delete_item(self, queue_item_id: str) -> None:
- """Delete item (by id or index) from the queue."""
- item_index = self.index_by_id(queue_item_id)
- if self.stream and item_index <= self.index_in_buffer:
- # ignore request if track already loaded in the buffer
- # the frontend should guard so this is just in case
- self.logger.warning("delete requested for item already loaded in buffer")
- return
- self._items.pop(item_index)
- self.signal_update(True)
-
- async def load(
- self,
- queue_items: List[QueueItem],
- insert_at_index: int = 0,
- keep_remaining: bool = True,
- shuffle: bool = False,
- ) -> None:
- """
- Load new items at index.
-
- queue_items: a list of QueueItem
- insert_at_index: insert the item(s) at this index
- keep_remaining: keep the remaining items after the insert
- shuffle: (re)shuffle the items after insert index
- """
-
- # keep previous/played items, append the new ones
- prev_items = self._items[:insert_at_index]
- next_items = queue_items
-
- # if keep_remaining, append the old previous items
- if keep_remaining:
- next_items += self._items[insert_at_index:]
-
- # we set the original insert order as attribute so we can un-shuffle
- for index, item in enumerate(next_items):
- item.sort_index += insert_at_index + index
- # (re)shuffle the final batch if needed
- if shuffle:
- next_items = random.sample(next_items, len(next_items))
- await self.update_items(prev_items + next_items)
-
- async def clear(self) -> None:
- """Clear all items in the queue."""
- self._radio_source = []
- if self.player.state not in (PlayerState.IDLE, PlayerState.OFF):
- await self.stop()
- await self.update_items([])
-
- def on_player_update(self) -> None:
- """Call when player updates."""
- prev_state = self._last_player_state
- new_state = (self.player.state, self.player.current_url)
-
- # handle PlayerState changed
- if new_state[0] != prev_state[0]:
-
- # store previous state
- if self.announcement_in_progress:
- # while announcement in progress dont update the last url
- # to allow us to resume from 3rd party sources
- # https://github.com/music-assistant/hass-music-assistant/issues/697
- self._last_player_state = (new_state[0], prev_state[1])
- else:
- self._last_player_state = new_state
-
- # the queue stream was aborted on purpose and needs to restart
- if (
- prev_state[0] == PlayerState.PLAYING
- and new_state[0] == PlayerState.IDLE
- and self.stream
- and self.stream.signal_next is not None
- ):
- # the queue stream was aborted on purpose (e.g. because of sample rate mismatch)
- # we need to restart the stream with the next index
- self._current_item_elapsed_time = 0
- self.mass.create_task(self.play_index(self.stream.signal_next))
- return
-
- # queue exhausted or player turned off/stopped
- if self.stream and (
- new_state[0] in (PlayerState.IDLE, PlayerState.OFF)
- or not self.player.available
- ):
- self.stream.signal_next = None
- # handle last track of the queue, set the index to index that is out of range
- if (self._current_index or 0) >= (len(self._items) - 1):
- self._current_index += 1
-
- # always signal update if the PlayerState changed
- if new_state[0] != prev_state[0]:
- self.signal_update()
-
- # update queue details only if we're the active queue for the attached player
- if self.player.active_queue != self or not self.active:
- return
-
- track_time = self._current_item_elapsed_time
- new_item_loaded = False
- if self.player.state == PlayerState.PLAYING and self.player.elapsed_time > 0:
- new_index, track_time = self.__get_queue_stream_index()
-
- # process new index
- if self._current_index != new_index:
- # queue index updated
- self._current_index = new_index
- # watch dynamic radio items refill if needed
- fill_index = len(self._items) - 5
- if self._radio_source and (new_index >= fill_index):
- self.mass.create_task(self._fill_radio_tracks())
-
- # check if a new track is loaded, wait for the streamdetails
- if (
- self.current_item
- and self._prev_item != self.current_item
- and self.current_item.streamdetails
- ):
- # new active item in queue
- new_item_loaded = True
- self._prev_item = self.current_item
- # update vars and signal update on eventbus if needed
- prev_item_time = int(self._current_item_elapsed_time)
- self._current_item_elapsed_time = int(track_time)
-
- if new_item_loaded:
- self.signal_update()
- self.mass.create_task(self._fetch_full_details(self._current_index))
- if abs(prev_item_time - self._current_item_elapsed_time) >= 1:
- self.mass.signal_event(
- MassEvent(
- EventType.QUEUE_TIME_UPDATED,
- object_id=self.queue_id,
- data=int(self.elapsed_time),
- )
- )
-
- async def queue_stream_start(
- self,
- start_index: int,
- seek_position: int = 0,
- fade_in: bool = False,
- ) -> QueueStream:
- """Start the queue stream runner."""
- # start the queue stream background task
- stream = await self.mass.streams.start_queue_stream(
- queue=self,
- start_index=start_index,
- seek_position=seek_position,
- fade_in=fade_in,
- output_format=self._settings.stream_type,
- )
- self._stream_id = stream.stream_id
- self._current_item_elapsed_time = 0
- self._current_index = start_index
- return stream
-
- def get_next_index(self, cur_index: Optional[int], is_skip: bool = False) -> int:
- """Return the next index for the queue, accounting for repeat settings."""
- # handle repeat single track
- if self.settings.repeat_mode == RepeatMode.ONE and not is_skip:
- return cur_index
- # handle repeat all
- if (
- self.settings.repeat_mode == RepeatMode.ALL
- and self._items
- and cur_index == (len(self._items) - 1)
- ):
- return 0
- # simply return the next index. other logic is guarded to detect the index
- # being higher than the number of items to detect end of queue and/or handle repeat.
- if cur_index is None:
- return 0
- next_index = cur_index + 1
- return next_index
-
- def signal_update(self, items_changed: bool = False) -> None:
- """Signal state changed of this queue."""
- if items_changed:
- self.mass.signal_event(
- MassEvent(
- EventType.QUEUE_ITEMS_UPDATED, object_id=self.queue_id, data=self
- )
- )
- # save items
- self.mass.create_task(
- self.mass.cache.set(
- f"queue.items.{self.queue_id}",
- [x.to_dict() for x in self._items],
- )
- )
-
- # always send the base event
- self.mass.signal_event(
- MassEvent(EventType.QUEUE_UPDATED, object_id=self.queue_id, data=self)
- )
- # save state
- self.mass.create_task(
- self.mass.database.set_setting(
- f"queue.{self.queue_id}.current_index", self._current_index
- )
- )
- self.mass.create_task(
- self.mass.database.set_setting(
- f"queue.{self.queue_id}.current_item_elapsed_time",
- self._current_item_elapsed_time,
- )
- )
-
- def to_dict(self) -> Dict[str, Any]:
- """Export object to dict."""
- cur_item = self.current_item.to_dict() if self.current_item else None
- next_item = self.next_item.to_dict() if self.next_item else None
-
- return {
- "queue_id": self.queue_id,
- "player": self.player.player_id,
- "name": self.player.name,
- "active": self.active,
- "elapsed_time": int(self.elapsed_time),
- "state": self.player.state.value,
- "available": self.player.available,
- "current_index": self.current_index,
- "index_in_buffer": self.index_in_buffer,
- "current_item": cur_item,
- "next_item": next_item,
- "items": len(self._items),
- "settings": self.settings.to_dict(),
- "radio_source": [x.to_dict() for x in self._radio_source[:5]],
- }
-
- async def update_items(self, queue_items: List[QueueItem]) -> None:
- """Update the existing queue items, mostly caused by reordering."""
- self._items = queue_items
- self.signal_update(True)
-
- def __get_queue_stream_index(self) -> Tuple[int, int]:
- """Calculate current queue index and current track elapsed time."""
- # player is playing a constant stream so we need to do this the hard way
- queue_index = 0
- elapsed_time_queue = self.player.elapsed_time
- total_time = 0
- track_time = 0
- if self._items and self.stream and len(self._items) > self.stream.start_index:
- # start_index: holds the position from which the stream started
- queue_index = self.stream.start_index
- queue_track = None
- while len(self._items) > queue_index:
- # keep enumerating the queue tracks to find current track
- # starting from the start index
- queue_track = self._items[queue_index]
- if not queue_track.streamdetails:
- track_time = elapsed_time_queue - total_time
- break
- if queue_track.streamdetails.seconds_streamed is not None:
- track_duration = queue_track.streamdetails.seconds_streamed
- else:
- track_duration = queue_track.duration or FALLBACK_DURATION
- if elapsed_time_queue > (track_duration + total_time):
- # total elapsed time is more than (streamed) track duration
- # move index one up
- total_time += track_duration
- queue_index += 1
- else:
- # no more seconds left to divide, this is our track
- # account for any seeking by adding the skipped seconds
- track_sec_skipped = queue_track.streamdetails.seconds_skipped or 0
- track_time = elapsed_time_queue + track_sec_skipped - total_time
- break
- return queue_index, track_time
-
- async def _restore_items(self) -> None:
- """Try to load the saved state from cache."""
- if queue_cache := await self.mass.cache.get(f"queue.items.{self.queue_id}"):
- try:
- self._items = [QueueItem.from_dict(x) for x in queue_cache]
- except (KeyError, AttributeError, TypeError) as err:
- self.logger.warning(
- "Unable to restore queue state for queue %s",
- self.queue_id,
- exc_info=err,
- )
- else:
- # restore state too
- db_key = f"queue.{self.queue_id}.current_index"
- if db_value := await self.mass.database.get_setting(db_key):
- self._current_index = try_parse_int(db_value)
- db_key = f"queue.{self.queue_id}.current_item_elapsed_time"
- if db_value := await self.mass.database.get_setting(db_key):
- self._current_item_elapsed_time = try_parse_int(db_value)
-
- await self.settings.restore()
-
- async def _fetch_full_details(self, index: int) -> None:
- """Background task that fetches the full details of an item in the queue."""
- if not self._items or len(self._items) < (index + 1):
- return
-
- item_before = self._items[index]
-
- # check if the details are already fetched
- if item_before.media_item.provider == ProviderType.DATABASE:
- return
-
- # fetch full details here to prevent all clients do this on their own
- full_details = await self.mass.music.get_item_by_uri(
- self.current_item.media_item.uri, lazy=False
- )
- # convert to queueitem in between to minimize data
- temp_queue_item = QueueItem.from_media_item(full_details)
-
- # safe guard: check that item still matches
- # prevents race condition where items changes just while we were waiting for data
- if self._items[index].item_id != item_before.item_id:
- return
- self._items[index].media_item = temp_queue_item.media_item
- self.signal_update()
+++ /dev/null
-"""Model a QueueItem."""
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
-from uuid import uuid4
-
-from mashumaro import DataClassDictMixin
-
-from music_assistant.models.enums import MediaType
-from music_assistant.models.media_items import (
- ItemMapping,
- MediaItemImage,
- Radio,
- StreamDetails,
- Track,
-)
-
-
-@dataclass
-class QueueItem(DataClassDictMixin):
- """Representation of a queue item."""
-
- name: str = ""
- duration: Optional[int] = None
- item_id: str = ""
- sort_index: int = 0
- streamdetails: Optional[StreamDetails] = None
- media_item: Union[Track, Radio, None] = None
- image: Optional[MediaItemImage] = None
-
- def __post_init__(self):
- """Set default values."""
- if not self.item_id:
- self.item_id = str(uuid4())
- if self.streamdetails and self.streamdetails.stream_title:
- self.name = self.streamdetails.stream_title
- if not self.name:
- self.name = self.uri
-
- @classmethod
- def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]:
- """Run actions before deserialization."""
- d.pop("streamdetails", None)
- return d
-
- @property
- def uri(self) -> str:
- """Return uri for this QueueItem (for logging purposes)."""
- if self.media_item:
- return self.media_item.uri
- return self.item_id
-
- @property
- def media_type(self) -> MediaType:
- """Return MediaType for this QueueItem (for convenience purposes)."""
- if self.media_item:
- return self.media_item.media_type
- return MediaType.UNKNOWN
-
- @classmethod
- def from_media_item(cls, media_item: Track | Radio):
- """Construct QueueItem from track/radio item."""
- if media_item.media_type == MediaType.TRACK:
- artists = "/".join((x.name for x in media_item.artists))
- name = f"{artists} - {media_item.name}"
- # save a lot of data/bandwidth by simplifying nested objects
- media_item.artists = [ItemMapping.from_item(x) for x in media_item.artists]
- if media_item.album:
- media_item.album = ItemMapping.from_item(media_item.album)
- media_item.albums = []
- else:
- name = media_item.name
- return cls(
- name=name,
- duration=media_item.duration,
- media_item=media_item,
- image=media_item.image,
- )
+++ /dev/null
-"""Model for a PlayerQueue's settings."""
-from __future__ import annotations
-
-import asyncio
-import random
-from typing import TYPE_CHECKING, Any, Dict, Optional
-
-from .enums import ContentType, CrossFadeMode, MetadataMode, RepeatMode
-
-if TYPE_CHECKING:
- from .player_queue import PlayerQueue
-
-
-class QueueSettings:
- """Representation of (user adjustable) PlayerQueue settings/preferences."""
-
- def __init__(self, queue: PlayerQueue) -> None:
- """Initialize."""
- self._queue = queue
- self.mass = queue.mass
- self._repeat_mode: RepeatMode = RepeatMode.OFF
- self._shuffle_enabled: bool = False
- self._crossfade_mode: CrossFadeMode = CrossFadeMode.DISABLED
- self._crossfade_duration: int = 6
- self._volume_normalization_enabled: bool = True
- self._volume_normalization_target: int = -14
- self._stream_type: ContentType = queue.player.stream_type
- self._max_sample_rate: int = queue.player.max_sample_rate
- self._announce_volume_increase: int = 15
- self._metadata_mode: MetadataMode = MetadataMode.DEFAULT
-
- @property
- def repeat_mode(self) -> RepeatMode:
- """Return repeat enabled setting."""
- return self._repeat_mode
-
- @repeat_mode.setter
- def repeat_mode(self, mode: RepeatMode) -> None:
- """Set repeat enabled setting."""
- if self._repeat_mode != mode:
- self._repeat_mode = mode
- self._on_update("repeat_mode")
-
- @property
- def shuffle_enabled(self) -> bool:
- """Return shuffle enabled setting."""
- return self._shuffle_enabled
-
- @shuffle_enabled.setter
- def shuffle_enabled(self, enabled: bool) -> None:
- """Set shuffle enabled setting."""
- if not self._shuffle_enabled and enabled:
- # shuffle requested
- self._shuffle_enabled = True
- cur_index = self._queue.index_in_buffer
- cur_item = self._queue.get_item(cur_index)
- if cur_item is not None:
- played_items = self._queue.items[:cur_index]
- next_items = self._queue.items[cur_index + 1 :]
- # for now we use default python random function
- # can be extended with some more magic based on last_played and stuff
- next_items = random.sample(next_items, len(next_items))
- items = played_items + [cur_item] + next_items
- asyncio.create_task(self._queue.update_items(items))
- self._on_update("shuffle_enabled")
- elif self._shuffle_enabled and not enabled:
- # unshuffle
- self._shuffle_enabled = False
- cur_index = self._queue.index_in_buffer
- cur_item = self._queue.get_item(cur_index)
- if cur_item is not None:
- played_items = self._queue.items[:cur_index]
- next_items = self._queue.items[cur_index + 1 :]
- next_items.sort(key=lambda x: x.sort_index, reverse=False)
- items = played_items + [cur_item] + next_items
- asyncio.create_task(self._queue.update_items(items))
- self._on_update("shuffle_enabled")
-
- @property
- def crossfade_mode(self) -> CrossFadeMode:
- """Return crossfade mode setting."""
- return self._crossfade_mode
-
- @crossfade_mode.setter
- def crossfade_mode(self, mode: CrossFadeMode) -> None:
- """Set crossfade enabled setting."""
- if self._crossfade_mode != mode:
- # TODO: restart the queue stream if its playing
- self._crossfade_mode = mode
- self._on_update("crossfade_mode")
-
- @property
- def crossfade_duration(self) -> int:
- """Return crossfade_duration setting."""
- return self._crossfade_duration
-
- @crossfade_duration.setter
- def crossfade_duration(self, duration: int) -> None:
- """Set crossfade_duration setting (1..10 seconds)."""
- duration = max(1, duration)
- duration = min(10, duration)
- if self._crossfade_duration != duration:
- self._crossfade_duration = duration
- self._on_update("crossfade_duration")
-
- @property
- def volume_normalization_enabled(self) -> bool:
- """Return volume_normalization_enabled setting."""
- return self._volume_normalization_enabled
-
- @volume_normalization_enabled.setter
- def volume_normalization_enabled(self, enabled: bool) -> None:
- """Set volume_normalization_enabled setting."""
- if self._volume_normalization_enabled != enabled:
- self._volume_normalization_enabled = enabled
- self._on_update("volume_normalization_enabled")
-
- @property
- def volume_normalization_target(self) -> float:
- """Return volume_normalization_target setting."""
- return self._volume_normalization_target
-
- @volume_normalization_target.setter
- def volume_normalization_target(self, target: float) -> None:
- """Set volume_normalization_target setting (-40..10 LUFS)."""
- target = max(-40, target)
- target = min(10, target)
- if self._volume_normalization_target != target:
- self._volume_normalization_target = target
- self._on_update("volume_normalization_target")
-
- @property
- def stream_type(self) -> ContentType:
- """Return supported/preferred stream type for this playerqueue."""
- return self._stream_type
-
- @stream_type.setter
- def stream_type(self, value: ContentType) -> None:
- """Set supported/preferred stream type for this playerqueue."""
- if self._stream_type != value:
- self._stream_type = value
- self._on_update("stream_type")
-
- @property
- def max_sample_rate(self) -> int:
- """Return max supported/needed sample rate(s) for this playerqueue."""
- return self._max_sample_rate
-
- @max_sample_rate.setter
- def max_sample_rate(self, value: ContentType) -> None:
- """Set supported/preferred sample rate(s) for this playerqueue."""
- if self._max_sample_rate != value:
- self._max_sample_rate = value
- self._on_update("max_sample_rate")
-
- @property
- def announce_volume_increase(self) -> int:
- """Return announce_volume_increase setting (percentage relative to current)."""
- return self._announce_volume_increase
-
- @announce_volume_increase.setter
- def announce_volume_increase(self, volume_increase: int) -> None:
- """Set announce_volume_increase setting."""
- if self._announce_volume_increase != volume_increase:
- self._announce_volume_increase = volume_increase
- self._on_update("announce_volume_increase")
-
- @property
- def metadata_mode(self) -> MetadataMode:
- """Return metadata mode setting."""
- return self._metadata_mode
-
- @metadata_mode.setter
- def metadata_mode(self, mode: MetadataMode) -> None:
- """Set metadata mode setting."""
- if self._metadata_mode != mode:
- self._metadata_mode = mode
- self._on_update("metadata_mode")
-
- def to_dict(self) -> Dict[str, Any]:
- """Return dict from settings."""
- return {
- "repeat_mode": self.repeat_mode.value,
- "shuffle_enabled": self.shuffle_enabled,
- "crossfade_mode": self.crossfade_mode.value,
- "crossfade_duration": self.crossfade_duration,
- "volume_normalization_enabled": self.volume_normalization_enabled,
- "volume_normalization_target": self.volume_normalization_target,
- "stream_type": self.stream_type.value,
- "max_sample_rate": self.max_sample_rate,
- "announce_volume_increase": self.announce_volume_increase,
- "metadata_mode": self.metadata_mode.value,
- }
-
- def from_dict(self, d: Dict[str, Any]) -> None:
- """Initialize settings from dict."""
- self._repeat_mode = RepeatMode(d.get("repeat_mode", self._repeat_mode.value))
- self._shuffle_enabled = bool(d.get("shuffle_enabled", self._shuffle_enabled))
- self._crossfade_mode = CrossFadeMode(
- d.get("crossfade_mode", self._crossfade_mode.value)
- )
- self._crossfade_duration = int(
- d.get("crossfade_duration", self._crossfade_duration)
- )
- self._volume_normalization_enabled = bool(
- d.get("volume_normalization_enabled", self._volume_normalization_enabled)
- )
- self._volume_normalization_target = float(
- d.get("volume_normalization_target", self._volume_normalization_target)
- )
- self._stream_type = ContentType(d.get("stream_type", self._stream_type.value))
- self._max_sample_rate = int(d.get("max_sample_rate", self._max_sample_rate))
- self._announce_volume_increase = int(
- d.get("announce_volume_increase", self._announce_volume_increase)
- )
- self._metadata_mode = MetadataMode(
- d.get("metadata_mode", self._metadata_mode.value)
- )
-
- async def restore(self) -> None:
- """Restore state from db."""
- values = {}
- for key in self.to_dict():
- db_key = f"{self._queue.queue_id}_{key}"
- if db_value := await self.mass.database.get_setting(db_key):
- values[key] = db_value
- self.from_dict(values)
-
- def _on_update(self, changed_key: Optional[str] = None) -> None:
- """Handle state changed."""
- self._queue.signal_update()
- self.mass.create_task(self.save(changed_key))
-
- async def save(self, changed_key: Optional[str] = None) -> None:
- """Save state in db."""
- for key, value in self.to_dict().items():
- if key == changed_key or changed_key is None:
- db_key = f"{self._queue.queue_id}_{key}"
- await self.mass.database.set_setting(db_key, value)
+++ /dev/null
-"""Package with Music Provider controllers."""
+++ /dev/null
-"""Package with FileSystem Music provider(s)."""
-
-from .base import FileSystemProviderBase # noqa
-from .local import LocalFileSystemProvider # noqa
-from .smb import SMBFileSystemProvider # noqa
+++ /dev/null
-"""Filesystem musicprovider support for MusicAssistant."""
-from __future__ import annotations
-
-import os
-from abc import abstractmethod
-from dataclasses import dataclass
-from time import time
-from typing import AsyncGenerator, List, Optional, Set, Tuple
-
-import xmltodict
-
-from music_assistant.constants import VARIOUS_ARTISTS, VARIOUS_ARTISTS_ID
-from music_assistant.controllers.database import SCHEMA_VERSION
-from music_assistant.helpers.compare import compare_strings
-from music_assistant.helpers.playlists import parse_m3u, parse_pls
-from music_assistant.helpers.tags import parse_tags, split_items
-from music_assistant.helpers.util import parse_title_and_version
-from music_assistant.models.enums import MusicProviderFeature
-from music_assistant.models.errors import MediaNotFoundError, MusicAssistantError
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- BrowseFolder,
- ContentType,
- ImageType,
- MediaItemImage,
- MediaItemType,
- MediaType,
- Playlist,
- ProviderMapping,
- Radio,
- StreamDetails,
- Track,
-)
-from music_assistant.models.music_provider import MusicProvider
-
-from .helpers import get_parentdir
-
-TRACK_EXTENSIONS = ("mp3", "m4a", "mp4", "flac", "wav", "ogg", "aiff", "wma", "dsf")
-PLAYLIST_EXTENSIONS = ("m3u", "pls")
-SUPPORTED_EXTENSIONS = TRACK_EXTENSIONS + PLAYLIST_EXTENSIONS
-IMAGE_EXTENSIONS = ("jpg", "jpeg", "JPG", "JPEG", "png", "PNG", "gif", "GIF")
-
-
-@dataclass
-class FileSystemItem:
- """
- Representation of an item (file or directory) on the filesystem.
-
- - name: Name (not path) of the file (or directory).
- - path: Relative path to the item on this filesystem provider.
- - absolute_path: Absolute (provider dependent) path to this item.
- - is_file: Boolean if item is file (not directory or symlink).
- - is_dir: Boolean if item is directory (not file).
- - checksum: Checksum for this path (usually last modified time).
- - file_size : File size in number of bytes or None if unknown (or not a file).
- - local_path: Optional local accessible path to this (file)item, supported by ffmpeg.
- """
-
- name: str
- path: str
- absolute_path: str
- is_file: bool
- is_dir: bool
- checksum: str
- file_size: Optional[int] = None
- local_path: Optional[str] = None
-
- @property
- def ext(self) -> str | None:
- """Return file extension."""
- try:
- return self.name.rsplit(".", 1)[1]
- except IndexError:
- return None
-
-
-class FileSystemProviderBase(MusicProvider):
- """
- Base Implementation of a musicprovider for files.
-
- Reads ID3 tags from file and falls back to parsing filename.
- Optionally reads metadata from nfo files and images in folder structure <artist>/<album>.
- Supports m3u files only for playlists.
- Supports having URI's from streaming providers within m3u playlist.
- """
-
- @abstractmethod
- async def setup(self) -> bool:
- """Handle async initialization of the provider."""
-
- @abstractmethod
- async def listdir(
- self, path: str, recursive: bool = False
- ) -> AsyncGenerator[FileSystemItem, None]:
- """
- List contents of a given provider directory/path.
-
- Parameters:
- - path: path of the directory (relative or absolute) to list contents of.
- Empty string for provider's root.
- - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent).
-
- Returns:
- AsyncGenerator yielding FileSystemItem objects.
-
- """
- yield
-
- @abstractmethod
- async def resolve(self, file_path: str) -> FileSystemItem:
- """Resolve (absolute or relative) path to FileSystemItem."""
-
- @abstractmethod
- async def exists(self, file_path: str) -> bool:
- """Return bool is this FileSystem musicprovider has given file/dir."""
-
- @abstractmethod
- async def read_file_content(
- self, file_path: str, seek: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Yield (binary) contents of file in chunks of bytes."""
- yield
-
- @abstractmethod
- async def write_file_content(self, file_path: str, data: bytes) -> None:
- """Write entire file content as bytes (e.g. for playlists)."""
-
- ##############################################
- # DEFAULT/GENERIC IMPLEMENTATION BELOW
- # should normally not be needed to override
-
- @property
- def supported_features(self) -> Tuple[MusicProviderFeature]:
- """Return the features supported by this MusicProvider."""
- return (
- MusicProviderFeature.LIBRARY_ARTISTS,
- MusicProviderFeature.LIBRARY_ALBUMS,
- MusicProviderFeature.LIBRARY_TRACKS,
- MusicProviderFeature.LIBRARY_PLAYLISTS,
- MusicProviderFeature.PLAYLIST_TRACKS_EDIT,
- MusicProviderFeature.PLAYLIST_CREATE,
- MusicProviderFeature.BROWSE,
- MusicProviderFeature.SEARCH,
- )
-
- async def search(
- self, search_query: str, media_types=Optional[List[MediaType]], limit: int = 5
- ) -> List[MediaItemType]:
- """Perform search on this file based musicprovider."""
- result = []
- # searching the filesystem is slow and unreliable,
- # instead we make some (slow) freaking queries to the db ;-)
- params = {"name": f"%{search_query}%", "provider_id": f"%{self.id}%"}
- if media_types is None or MediaType.TRACK in media_types:
- query = "SELECT * FROM tracks WHERE name LIKE :name AND provider_mappings LIKE :provider_id"
- tracks = await self.mass.music.tracks.get_db_items_by_query(query, params)
- result += tracks
- if media_types is None or MediaType.ALBUM in media_types:
- query = "SELECT * FROM albums WHERE name LIKE :name AND provider_mappings LIKE :provider_id"
- albums = await self.mass.music.albums.get_db_items_by_query(query, params)
- result += albums
- if media_types is None or MediaType.ARTIST in media_types:
- query = "SELECT * FROM artists WHERE name LIKE :name AND provider_mappings LIKE :provider_id"
- artists = await self.mass.music.artists.get_db_items_by_query(query, params)
- result += artists
- if media_types is None or MediaType.PLAYLIST in media_types:
- query = "SELECT * FROM playlists WHERE name LIKE :name AND provider_mappings LIKE :provider_id"
- playlists = await self.mass.music.playlists.get_db_items_by_query(
- query, params
- )
- result += playlists
- return result
-
- async def browse(self, path: str) -> BrowseFolder:
- """
- Browse this provider's items.
-
- :param path: The path to browse, (e.g. provid://artists).
- """
- _, item_path = path.split("://")
- if not item_path:
- item_path = ""
- subitems = []
- async for item in self.listdir(item_path, recursive=False):
- if item.is_dir:
- subitems.append(
- BrowseFolder(
- item_id=item.path,
- provider=self.type,
- path=f"{self.id}://{item.path}",
- name=item.name,
- )
- )
- continue
-
- if "." not in item.name or not item.ext:
- # skip system files and files without extension
- continue
-
- if item.ext in TRACK_EXTENSIONS:
- if db_item := await self.mass.music.tracks.get_db_item_by_prov_id(
- item.path, provider_id=self.id
- ):
- subitems.append(db_item)
- elif track := await self.get_track(item.path):
- # make sure that the item exists
- # https://github.com/music-assistant/hass-music-assistant/issues/707
- db_item = await self.mass.music.tracks.add_db_item(track)
- subitems.append(db_item)
- continue
- if item.ext in PLAYLIST_EXTENSIONS:
- if db_item := await self.mass.music.playlists.get_db_item_by_prov_id(
- item.path, provider_id=self.id
- ):
- subitems.append(db_item)
- elif playlist := await self.get_playlist(item.path):
- # make sure that the item exists
- # https://github.com/music-assistant/hass-music-assistant/issues/707
- db_item = await self.mass.music.playlists.add_db_item(playlist)
- subitems.append(db_item)
- continue
-
- return BrowseFolder(
- item_id=item_path,
- provider=self.type,
- path=path,
- name=item_path or self.name,
- # make sure to sort the resulting listing
- items=sorted(subitems, key=lambda x: (x.name.casefold(), x.name)),
- )
-
- async def sync_library(
- self, media_types: Optional[Tuple[MediaType]] = None
- ) -> None:
- """Run library sync for this provider."""
- cache_key = f"{self.id}.checksums"
- prev_checksums = await self.mass.cache.get(cache_key, SCHEMA_VERSION)
- save_checksum_interval = 0
- if prev_checksums is None:
- prev_checksums = {}
-
- # find all music files in the music directory and all subfolders
- # we work bottom up, as-in we derive all info from the tracks
- cur_checksums = {}
- async for item in self.listdir("", recursive=True):
-
- if "." not in item.name or not item.ext:
- # skip system files and files without extension
- continue
-
- if item.ext not in SUPPORTED_EXTENSIONS:
- # unsupported file extension
- continue
-
- try:
- cur_checksums[item.path] = item.checksum
- if item.checksum == prev_checksums.get(item.path):
- continue
-
- if item.ext in TRACK_EXTENSIONS:
- # add/update track to db
- track = await self.get_track(item.path)
- # if the track was edited on disk, always overwrite existing db details
- overwrite_existing = item.path in prev_checksums
- await self.mass.music.tracks.add_db_item(
- track, overwrite_existing=overwrite_existing
- )
- elif item.ext in PLAYLIST_EXTENSIONS:
- playlist = await self.get_playlist(item.path)
- # add/update] playlist to db
- playlist.metadata.checksum = item.checksum
- # playlist is always in-library
- playlist.in_library = True
- await self.mass.music.playlists.add_db_item(playlist)
- except Exception as err: # pylint: disable=broad-except
- # we don't want the whole sync to crash on one file so we catch all exceptions here
- self.logger.exception("Error processing %s - %s", item.path, str(err))
-
- # save checksums every 100 processed items
- # this allows us to pickup where we leftoff when initial scan gets interrupted
- if save_checksum_interval == 100:
- await self.mass.cache.set(cache_key, cur_checksums, SCHEMA_VERSION)
- save_checksum_interval = 0
- else:
- save_checksum_interval += 1
-
- # store (final) checksums in cache
- await self.mass.cache.set(cache_key, cur_checksums, SCHEMA_VERSION)
- # work out deletions
- deleted_files = set(prev_checksums.keys()) - set(cur_checksums.keys())
- await self._process_deletions(deleted_files)
-
- async def _process_deletions(self, deleted_files: Set[str]) -> None:
- """Process all deletions."""
- # process deleted tracks/playlists
- for file_path in deleted_files:
-
- _, ext = file_path.rsplit(".", 1)
- if ext not in SUPPORTED_EXTENSIONS:
- # unsupported file extension
- continue
-
- if ext in PLAYLIST_EXTENSIONS:
- controller = self.mass.music.get_controller(MediaType.PLAYLIST)
- else:
- controller = self.mass.music.get_controller(MediaType.TRACK)
-
- if db_item := await controller.get_db_item_by_prov_id(file_path, self.type):
- await controller.remove_prov_mapping(db_item.item_id, self.id)
-
- async def get_artist(self, prov_artist_id: str) -> Artist:
- """Get full artist details by id."""
- db_artist = await self.mass.music.artists.get_db_item_by_prov_id(
- provider_item_id=prov_artist_id, provider_id=self.id
- )
- if db_artist is None:
- raise MediaNotFoundError(f"Artist not found: {prov_artist_id}")
- if await self.exists(prov_artist_id):
- # if path exists on disk allow parsing full details to allow refresh of metadata
- return await self._parse_artist(db_artist.name, artist_path=prov_artist_id)
- return db_artist
-
- async def get_album(self, prov_album_id: str) -> Album:
- """Get full album details by id."""
- db_album = await self.mass.music.albums.get_db_item_by_prov_id(
- provider_item_id=prov_album_id, provider_id=self.id
- )
- if db_album is None:
- raise MediaNotFoundError(f"Album not found: {prov_album_id}")
- if await self.exists(prov_album_id):
- # if path exists on disk allow parsing full details to allow refresh of metadata
- return await self._parse_album(
- db_album.name, prov_album_id, db_album.artists
- )
- return db_album
-
- async def get_track(self, prov_track_id: str) -> Track:
- """Get full track details by id."""
- if not await self.exists(prov_track_id):
- raise MediaNotFoundError(f"Track path does not exist: {prov_track_id}")
-
- file_item = await self.resolve(prov_track_id)
-
- # parse tags
- input_file = file_item.local_path or self.read_file_content(
- file_item.absolute_path
- )
- tags = await parse_tags(input_file)
-
- name, version = parse_title_and_version(tags.title)
- track = Track(
- item_id=file_item.path,
- provider=self.type,
- name=name,
- version=version,
- )
-
- # album
- if tags.album:
- # work out if we have an album folder
- album_dir = get_parentdir(file_item.path, tags.album)
-
- # album artist(s)
- if tags.album_artists:
- album_artists = []
- for index, album_artist_str in enumerate(tags.album_artists):
- # work out if we have an artist folder
- artist_dir = get_parentdir(file_item.path, album_artist_str)
- artist = await self._parse_artist(
- album_artist_str, artist_path=artist_dir
- )
- if not artist.musicbrainz_id:
- try:
- artist.musicbrainz_id = tags.musicbrainz_albumartistids[
- index
- ]
- except IndexError:
- pass
- album_artists.append(artist)
- else:
- # always fallback to various artists as album artist if user did not tag album artist
- # ID3 tag properly because we must have an album artist
- self.logger.warning(
- "%s is missing ID3 tag [albumartist], using %s as fallback",
- file_item.path,
- VARIOUS_ARTISTS,
- )
- album_artists = [await self._parse_artist(name=VARIOUS_ARTISTS)]
-
- track.album = await self._parse_album(
- tags.album,
- album_dir,
- artists=album_artists,
- )
- else:
- self.logger.warning("%s is missing ID3 tag [album]", file_item.path)
-
- # track artist(s)
- for index, track_artist_str in enumerate(tags.artists):
- # re-use album artist details if possible
- if track.album:
- if artist := next(
- (x for x in track.album.artists if x.name == track_artist_str), None
- ):
- track.artists.append(artist)
- continue
- artist = await self._parse_artist(track_artist_str)
- if not artist.musicbrainz_id:
- try:
- artist.musicbrainz_id = tags.musicbrainz_artistids[index]
- except IndexError:
- pass
- track.artists.append(artist)
-
- # cover image - prefer album image, fallback to embedded
- if track.album and track.album.image:
- track.metadata.images = [track.album.image]
- elif tags.has_cover_image:
- # we do not actually embed the image in the metadata because that would consume too
- # much space and bandwidth. Instead we set the filename as value so the image can
- # be retrieved later in realtime.
- track.metadata.images = [
- MediaItemImage(ImageType.THUMB, file_item.path, True)
- ]
- if track.album:
- # set embedded cover on album
- track.album.metadata.images = track.metadata.images
-
- # parse other info
- track.duration = tags.duration or 0
- track.metadata.genres = tags.genres
- track.disc_number = tags.disc
- track.track_number = tags.track
- track.isrc = tags.get("isrc")
- track.metadata.copyright = tags.get("copyright")
- track.metadata.lyrics = tags.get("lyrics")
- track.musicbrainz_id = tags.musicbrainz_trackid
- if track.album:
- if not track.album.musicbrainz_id:
- track.album.musicbrainz_id = tags.musicbrainz_releasegroupid
- if not track.album.year:
- track.album.year = tags.year
- if not track.album.upc:
- track.album.upc = tags.get("barcode")
- # try to parse albumtype
- if track.album and track.album.album_type == AlbumType.UNKNOWN:
- album_type = tags.album_type
- if album_type and "compilation" in album_type:
- track.album.album_type = AlbumType.COMPILATION
- elif album_type and "single" in album_type:
- track.album.album_type = AlbumType.SINGLE
- elif album_type and "album" in album_type:
- track.album.album_type = AlbumType.ALBUM
- elif track.album.sort_name in track.sort_name:
- track.album.album_type = AlbumType.SINGLE
-
- # set checksum to invalidate any cached listings
- checksum_timestamp = str(int(time()))
- track.metadata.checksum = checksum_timestamp
- if track.album:
- track.album.metadata.checksum = checksum_timestamp
- for artist in track.album.artists:
- artist.metadata.checksum = checksum_timestamp
-
- track.add_provider_mapping(
- ProviderMapping(
- item_id=file_item.path,
- provider_type=self.type,
- provider_id=self.id,
- content_type=ContentType.try_parse(tags.format),
- sample_rate=tags.sample_rate,
- bit_depth=tags.bits_per_sample,
- bit_rate=tags.bit_rate,
- )
- )
- return track
-
- async def get_playlist(self, prov_playlist_id: str) -> Playlist:
- """Get full playlist details by id."""
- if not await self.exists(prov_playlist_id):
- raise MediaNotFoundError(
- f"Playlist path does not exist: {prov_playlist_id}"
- )
-
- file_item = await self.resolve(prov_playlist_id)
- playlist = Playlist(file_item.path, provider=self.type, name=file_item.name)
- playlist.is_editable = file_item.ext != "pls" # can only edit m3u playlists
-
- playlist.add_provider_mapping(
- ProviderMapping(
- item_id=file_item.path,
- provider_type=self.type,
- provider_id=self.id,
- )
- )
- playlist.owner = self._attr_name
- checksum = f"{SCHEMA_VERSION}.{file_item.checksum}"
- playlist.metadata.checksum = checksum
- return playlist
-
- async def get_album_tracks(self, prov_album_id: str) -> List[Track]:
- """Get album tracks for given album id."""
- # filesystem items are always stored in db so we can query the database
- db_album = await self.mass.music.albums.get_db_item_by_prov_id(
- prov_album_id, provider_id=self.id
- )
- if db_album is None:
- raise MediaNotFoundError(f"Album not found: {prov_album_id}")
- # TODO: adjust to json query instead of text search
- query = f"SELECT * FROM tracks WHERE albums LIKE '%\"{db_album.item_id}\"%'"
- query += f" AND provider_mappings LIKE '%\"{self.id}\"%'"
- result = []
- for track in await self.mass.music.tracks.get_db_items_by_query(query):
- track.album = db_album
- if album_mapping := next(
- (x for x in track.albums if x.item_id == db_album.item_id), None
- ):
- track.disc_number = album_mapping.disc_number
- track.track_number = album_mapping.track_number
- result.append(track)
- return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0))
-
- async def get_playlist_tracks(self, prov_playlist_id: str) -> List[Track]:
- """Get playlist tracks for given playlist id."""
- result = []
- if not await self.exists(prov_playlist_id):
- raise MediaNotFoundError(
- f"Playlist path does not exist: {prov_playlist_id}"
- )
-
- _, ext = prov_playlist_id.rsplit(".", 1)
- try:
- # get playlist file contents
- playlist_data = b""
- async for chunk in self.read_file_content(prov_playlist_id):
- playlist_data += chunk
- playlist_data = playlist_data.decode("utf-8")
-
- if ext in ("m3u", "m3u8"):
- playlist_lines = await parse_m3u(playlist_data)
- else:
- playlist_lines = await parse_pls(playlist_data)
-
- for line_no, playlist_line in enumerate(playlist_lines):
-
- if media_item := await self._parse_playlist_line(
- playlist_line, os.path.dirname(prov_playlist_id)
- ):
- # use the linenumber as position for easier deletions
- media_item.position = line_no
- result.append(media_item)
-
- except Exception as err: # pylint: disable=broad-except
- self.logger.warning(
- "Error while parsing playlist %s", prov_playlist_id, exc_info=err
- )
- return result
-
- async def _parse_playlist_line(
- self, line: str, playlist_path: str
- ) -> Track | Radio | None:
- """Try to parse a track from a playlist line."""
- try:
- # try to treat uri as (relative) filename
- if "://" not in line:
- for filename in (line, os.path.join(playlist_path, line)):
- if not await self.exists(filename):
- continue
- return await self.get_track(filename)
- # fallback to generic uri parsing
- return await self.mass.music.get_item_by_uri(line)
- except MusicAssistantError as err:
- self.logger.warning(
- "Could not parse uri/file %s to track: %s", line, str(err)
- )
- return None
-
- async def add_playlist_tracks(
- self, prov_playlist_id: str, prov_track_ids: List[str]
- ) -> None:
- """Add track(s) to playlist."""
- if not await self.exists(prov_playlist_id):
- raise MediaNotFoundError(
- f"Playlist path does not exist: {prov_playlist_id}"
- )
- playlist_data = b""
- async for chunk in self.read_file_content(prov_playlist_id):
- playlist_data += chunk
- playlist_data = playlist_data.decode("utf-8")
- for uri in prov_track_ids:
- playlist_data += f"\n{uri}"
-
- # write playlist file
- await self.write_file_content(prov_playlist_id, playlist_data.encode("utf-8"))
-
- async def remove_playlist_tracks(
- self, prov_playlist_id: str, positions_to_remove: Tuple[int]
- ) -> None:
- """Remove track(s) from playlist."""
- if not await self.exists(prov_playlist_id):
- raise MediaNotFoundError(
- f"Playlist path does not exist: {prov_playlist_id}"
- )
- cur_lines = []
- _, ext = prov_playlist_id.rsplit(".", 1)
-
- # get playlist file contents
- playlist_data = b""
- async for chunk in self.read_file_content(prov_playlist_id):
- playlist_data += chunk
- playlist_data.decode("utf-8")
-
- if ext in ("m3u", "m3u8"):
- playlist_lines = await parse_m3u(playlist_data)
- else:
- playlist_lines = await parse_pls(playlist_data)
-
- for line_no, playlist_line in enumerate(playlist_lines):
- if line_no not in positions_to_remove:
- cur_lines.append(playlist_line)
-
- new_playlist_data = "\n".join(cur_lines)
- # write playlist file
- await self.write_file_content(
- prov_playlist_id, new_playlist_data.encode("utf-8")
- )
-
- async def create_playlist(self, name: str) -> Playlist:
- """Create a new playlist on provider with given name."""
- # creating a new playlist on the filesystem is as easy
- # as creating a new (empty) file with the m3u extension...
- filename = await self.resolve(f"{name}.m3u")
- await self.write_file_content(filename, b"")
- playlist = await self.get_playlist(filename)
- db_playlist = await self.mass.music.playlists.add_db_item(playlist)
- return db_playlist
-
- async def get_stream_details(self, item_id: str) -> StreamDetails:
- """Return the content details for the given track when it will be streamed."""
- db_item = await self.mass.music.tracks.get_db_item_by_prov_id(
- provider_item_id=item_id, provider_id=self.id
- )
- if db_item is None:
- raise MediaNotFoundError(f"Item not found: {item_id}")
-
- prov_mapping = next(
- x for x in db_item.provider_mappings if x.item_id == item_id
- )
- file_item = await self.resolve(item_id)
-
- return StreamDetails(
- provider=self.type,
- item_id=item_id,
- content_type=prov_mapping.content_type,
- media_type=MediaType.TRACK,
- duration=db_item.duration,
- size=file_item.file_size,
- sample_rate=prov_mapping.sample_rate,
- bit_depth=prov_mapping.bit_depth,
- direct=file_item.local_path,
- )
-
- async def get_audio_stream(
- self, streamdetails: StreamDetails, seek_position: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Return the audio stream for the provider item."""
- if seek_position:
- assert streamdetails.duration, "Duration required for seek requests"
- assert streamdetails.size, "Filesize required for seek requests"
- seek_bytes = int(
- (streamdetails.size / streamdetails.duration) * seek_position
- )
- else:
- seek_bytes = 0
-
- async for chunk in self.read_file_content(streamdetails.item_id, seek_bytes):
- yield chunk
-
- async def _parse_artist(
- self,
- name: Optional[str] = None,
- artist_path: Optional[str] = None,
- ) -> Artist | None:
- """Lookup metadata in Artist folder."""
- assert name or artist_path
- if not artist_path:
- artist_path = name
-
- if not name:
- name = artist_path.split(os.sep)[-1]
-
- artist = Artist(
- artist_path,
- self.type,
- name,
- provider_mappings={
- ProviderMapping(artist_path, self.type, self.id, url=artist_path)
- },
- musicbrainz_id=VARIOUS_ARTISTS_ID
- if compare_strings(name, VARIOUS_ARTISTS)
- else None,
- )
-
- if not await self.exists(artist_path):
- # return basic object if there is no dedicated artist folder
- return artist
-
- nfo_file = os.path.join(artist_path, "artist.nfo")
- if await self.exists(nfo_file):
- # found NFO file with metadata
- # https://kodi.wiki/view/NFO_files/Artists
- data = b""
- async for chunk in self.read_file_content(nfo_file):
- data += chunk
- info = await self.mass.loop.run_in_executor(None, xmltodict.parse, data)
- info = info["artist"]
- artist.name = info.get("title", info.get("name", name))
- if sort_name := info.get("sortname"):
- artist.sort_name = sort_name
- if musicbrainz_id := info.get("musicbrainzartistid"):
- artist.musicbrainz_id = musicbrainz_id
- if descripton := info.get("biography"):
- artist.metadata.description = descripton
- if genre := info.get("genre"):
- artist.metadata.genres = set(split_items(genre))
- # find local images
- artist.metadata.images = await self._get_local_images(artist_path) or None
-
- return artist
-
- async def _parse_album(
- self, name: Optional[str], album_path: Optional[str], artists: List[Artist]
- ) -> Album | None:
- """Lookup metadata in Album folder."""
- assert (name or album_path) and artists
- if not album_path:
- # create fake path
- album_path = artists[0].name + os.sep + name
-
- if not name:
- name = album_path.split(os.sep)[-1]
-
- album = Album(
- album_path,
- self.type,
- name,
- artists=artists,
- provider_mappings={
- ProviderMapping(album_path, self.type, self.id, url=album_path)
- },
- )
-
- if not await self.exists(album_path):
- # return basic object if there is no dedicated album folder
- return album
-
- nfo_file = os.path.join(album_path, "album.nfo")
- if await self.exists(nfo_file):
- # found NFO file with metadata
- # https://kodi.wiki/view/NFO_files/Artists
- data = b""
- async for chunk in self.read_file_content(nfo_file):
- data += chunk
- info = await self.mass.loop.run_in_executor(None, xmltodict.parse, data)
- info = info["album"]
- album.name = info.get("title", info.get("name", name))
- if sort_name := info.get("sortname"):
- album.sort_name = sort_name
- if musicbrainz_id := info.get("musicbrainzreleasegroupid"):
- album.musicbrainz_id = musicbrainz_id
- if mb_artist_id := info.get("musicbrainzalbumartistid"):
- if album.artist and not album.artist.musicbrainz_id:
- album.artist.musicbrainz_id = mb_artist_id
- if description := info.get("review"):
- album.metadata.description = description
- if year := info.get("year"):
- album.year = int(year)
- if genre := info.get("genre"):
- album.metadata.genres = set(split_items(genre))
- # parse name/version
- album.name, album.version = parse_title_and_version(album.name)
-
- # find local images
- album.metadata.images = await self._get_local_images(album_path) or None
-
- return album
-
- async def _get_local_images(self, folder: str) -> List[MediaItemImage]:
- """Return local images found in a given folderpath."""
- images = []
- async for item in self.listdir(folder):
- if "." not in item.path or item.is_dir:
- continue
- for ext in IMAGE_EXTENSIONS:
- if item.ext != ext:
- continue
- try:
- images.append(MediaItemImage(ImageType(item.name), item.path, True))
- except ValueError:
- if "folder" in item.name:
- images.append(MediaItemImage(ImageType.THUMB, item.path, True))
- elif "AlbumArt" in item.name:
- images.append(MediaItemImage(ImageType.THUMB, item.path, True))
- elif "Artist" in item.name:
- images.append(MediaItemImage(ImageType.THUMB, item.path, True))
- return images
+++ /dev/null
-"""Some helpers for Filesystem based Musicproviders."""
-from __future__ import annotations
-
-import asyncio
-import os
-from io import BytesIO
-from typing import Any, AsyncGenerator, Dict
-
-from smb.base import SharedFile, SMBTimeout
-from smb.smb_structs import OperationFailure
-from smb.SMBConnection import SMBConnection
-
-from music_assistant.helpers.compare import compare_strings
-from music_assistant.models.errors import LoginFailed
-
-SERVICE_NAME = "music_assistant"
-
-
-def get_parentdir(base_path: str, name: str) -> str | None:
- """Look for folder name in path (to find dedicated artist or album folder)."""
- parentdir = os.path.dirname(base_path)
- for _ in range(3):
- dirname = parentdir.rsplit(os.sep)[-1]
- if compare_strings(name, dirname, False):
- return parentdir
- parentdir = os.path.dirname(parentdir)
- return None
-
-
-def get_relative_path(base_path: str, path: str) -> str:
- """Return the relative path string for a path."""
- if path.startswith(base_path):
- path = path.split(base_path)[1]
- for sep in ("/", "\\"):
- if path.startswith(sep):
- path = path[1:]
- return path
-
-
-def get_absolute_path(base_path: str, path: str) -> str:
- """Return the absolute path string for a path."""
- if path.startswith(base_path):
- return path
- return os.path.join(base_path, path)
-
-
-class AsyncSMB:
- """Async wrapped pysmb."""
-
- def __init__(
- self,
- remote_name: str,
- service_name: str,
- username: str,
- password: str,
- target_ip: str,
- options: Dict[str, Any],
- ) -> None:
- """Initialize instance."""
- self._service_name = service_name
- self._remote_name = remote_name
- self._target_ip = target_ip
- self._username = username
- self._password = password
- self._conn = SMBConnection(
- username=self._username,
- password=self._password,
- my_name=SERVICE_NAME,
- remote_name=self._remote_name,
- # choose sane default options but allow user to override them via the options dict
- domain=options.get("domain", ""),
- use_ntlm_v2=options.get("use_ntlm_v2", False),
- sign_options=options.get("sign_options", 2),
- is_direct_tcp=options.get("is_direct_tcp", False),
- )
-
- async def list_path(self, path: str) -> list[SharedFile]:
- """Retrieve a directory listing of files/folders at *path*."""
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(
- None, self._conn.listPath, self._service_name, path
- )
-
- async def get_attributes(self, path: str) -> SharedFile:
- """Retrieve information about the file at *path* on the *service_name*."""
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(
- None, self._conn.getAttributes, self._service_name, path
- )
-
- async def retrieve_file(
- self, path: str, offset: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Retrieve file contents."""
- loop = asyncio.get_running_loop()
-
- chunk_size = 256000
- while True:
- with BytesIO() as file_obj:
- await loop.run_in_executor(
- None,
- self._conn.retrieveFileFromOffset,
- self._service_name,
- path,
- file_obj,
- offset,
- chunk_size,
- )
- file_obj.seek(0)
- chunk = file_obj.read()
- yield chunk
- offset += len(chunk)
- if len(chunk) < chunk_size:
- break
-
- async def write_file(self, path: str, data: bytes) -> SharedFile:
- """Store the contents to the file at *path*."""
- loop = asyncio.get_running_loop()
- with BytesIO() as file_obj:
- file_obj.write(data)
- file_obj.seek(0)
- await loop.run_in_executor(
- None,
- self._conn.storeFile,
- self._service_name,
- path,
- file_obj,
- )
-
- async def path_exists(self, path: str) -> bool:
- """Return bool is this FileSystem musicprovider has given file/dir."""
- loop = asyncio.get_running_loop()
- try:
- await loop.run_in_executor(
- None, self._conn.getAttributes, self._service_name, path
- )
- except (OperationFailure, SMBTimeout):
- return False
- return True
-
- async def connect(self) -> None:
- """Connect to the SMB server."""
- loop = asyncio.get_running_loop()
- try:
- assert (
- await loop.run_in_executor(None, self._conn.connect, self._target_ip)
- is True
- )
- except Exception as exc:
- raise LoginFailed(f"SMB Connect failed to {self._remote_name}") from exc
-
- async def __aenter__(self) -> "AsyncSMB":
- """Enter context manager."""
- # connect
- await self.connect()
- return self
-
- async def __aexit__(self, exc_type, exc_value, traceback) -> bool:
- """Exit context manager."""
- self._conn.close()
+++ /dev/null
-"""Filesystem musicprovider support for MusicAssistant."""
-from __future__ import annotations
-
-import asyncio
-import os
-import os.path
-from typing import AsyncGenerator
-
-import aiofiles
-from aiofiles.os import wrap
-
-from music_assistant.models.enums import ProviderType
-from music_assistant.models.errors import SetupFailedError
-
-from .base import FileSystemItem, FileSystemProviderBase
-from .helpers import get_absolute_path, get_relative_path
-
-listdir = wrap(os.listdir)
-isdir = wrap(os.path.isdir)
-isfile = wrap(os.path.isfile)
-exists = wrap(os.path.exists)
-
-
-async def create_item(base_path: str, entry: os.DirEntry) -> FileSystemItem:
- """Create FileSystemItem from os.DirEntry."""
-
- def _create_item():
- absolute_path = get_absolute_path(base_path, entry.path)
- stat = entry.stat(follow_symlinks=False)
- return FileSystemItem(
- name=entry.name,
- path=get_relative_path(base_path, entry.path),
- absolute_path=absolute_path,
- is_file=entry.is_file(follow_symlinks=False),
- is_dir=entry.is_dir(follow_symlinks=False),
- checksum=str(int(stat.st_mtime)),
- file_size=stat.st_size,
- # local filesystem is always local resolvable
- local_path=absolute_path,
- )
-
- # run in executor because strictly taken this may be blocking IO
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _create_item)
-
-
-class LocalFileSystemProvider(FileSystemProviderBase):
- """Implementation of a musicprovider for local files."""
-
- _attr_name = "Filesystem"
- _attr_type = ProviderType.FILESYSTEM_LOCAL
-
- async def setup(self) -> bool:
- """Handle async initialization of the provider."""
-
- if not await isdir(self.config.path):
- raise SetupFailedError(f"Music Directory {self.config.path} does not exist")
-
- return True
-
- async def listdir(
- self, path: str, recursive: bool = False
- ) -> AsyncGenerator[FileSystemItem, None]:
- """
- List contents of a given provider directory/path.
-
- Parameters:
- - path: path of the directory (relative or absolute) to list contents of.
- Empty string for provider's root.
- - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent).
-
- Returns:
- AsyncGenerator yielding FileSystemItem objects.
-
- """
- abs_path = get_absolute_path(self.config.path, path)
- loop = asyncio.get_running_loop()
- for entry in await loop.run_in_executor(None, os.scandir, abs_path):
- if entry.name.startswith("."):
- # skip invalid/system files and dirs
- continue
- item = await create_item(self.config.path, entry)
- if recursive and item.is_dir:
- try:
- async for subitem in self.listdir(item.absolute_path, True):
- yield subitem
- except (OSError, PermissionError) as err:
- self.logger.warning("Skip folder %s: %s", item.path, str(err))
- else:
- yield item
-
- async def resolve(
- self, file_path: str, require_local: bool = False
- ) -> FileSystemItem:
- """
- Resolve (absolute or relative) path to FileSystemItem.
-
- If want_local is True, we prefer to have the `local_path` attribute filled
- (e.g. with a tempfile), if supported by the provider/item.
- """
- absolute_path = get_absolute_path(self.config.path, file_path)
-
- def _create_item():
- stat = os.stat(absolute_path, follow_symlinks=False)
- return FileSystemItem(
- name=os.path.basename(file_path),
- path=get_relative_path(self.config.path, file_path),
- absolute_path=absolute_path,
- is_dir=os.path.isdir(absolute_path),
- is_file=os.path.isfile(absolute_path),
- checksum=str(int(stat.st_mtime)),
- file_size=stat.st_size,
- # local filesystem is always local resolvable
- local_path=absolute_path,
- )
-
- # run in executor because strictly taken this may be blocking IO
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _create_item)
-
- async def exists(self, file_path: str) -> bool:
- """Return bool is this FileSystem musicprovider has given file/dir."""
- if not file_path:
- return False # guard
- abs_path = get_absolute_path(self.config.path, file_path)
- return await exists(abs_path)
-
- async def read_file_content(
- self, file_path: str, seek: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Yield (binary) contents of file in chunks of bytes."""
- abs_path = get_absolute_path(self.config.path, file_path)
- chunk_size = 512000
- async with aiofiles.open(abs_path, "rb") as _file:
- if seek:
- await _file.seek(seek)
- # yield chunks of data from file
- while True:
- data = await _file.read(chunk_size)
- if not data:
- break
- yield data
-
- async def write_file_content(self, file_path: str, data: bytes) -> None:
- """Write entire file content as bytes (e.g. for playlists)."""
- abs_path = get_absolute_path(self.config.path, file_path)
- async with aiofiles.open(abs_path, "wb") as _file:
- await _file.write(data)
+++ /dev/null
-"""SMB filesystem provider for Music Assistant."""
-
-import contextvars
-import os
-from contextlib import asynccontextmanager
-from typing import AsyncGenerator
-
-from smb.base import SharedFile
-
-from music_assistant.helpers.util import get_ip_from_host
-from music_assistant.models.enums import ProviderType
-
-from .base import FileSystemItem, FileSystemProviderBase
-from .helpers import AsyncSMB, get_absolute_path, get_relative_path
-
-
-async def create_item(
- file_path: str, entry: SharedFile, root_path: str
-) -> FileSystemItem:
- """Create FileSystemItem from smb.SharedFile."""
-
- rel_path = get_relative_path(root_path, file_path)
- abs_path = get_absolute_path(root_path, file_path)
- return FileSystemItem(
- name=entry.filename,
- path=rel_path,
- absolute_path=abs_path,
- is_file=not entry.isDirectory,
- is_dir=entry.isDirectory,
- checksum=str(int(entry.last_write_time)),
- file_size=entry.file_size,
- )
-
-
-smb_conn_ctx = contextvars.ContextVar("smb_conn_ctx", default=None)
-
-
-class SMBFileSystemProvider(FileSystemProviderBase):
- """Implementation of an SMB File System Provider."""
-
- _attr_name = "smb"
- _attr_type = ProviderType.FILESYSTEM_SMB
- _service_name = ""
- _root_path = "/"
- _remote_name = ""
- _target_ip = ""
-
- async def setup(self) -> bool:
- """Handle async initialization of the provider."""
- # extract params from path
- if self.config.path.startswith("\\\\"):
- path_parts = self.config.path[2:].split("\\", 2)
- elif self.config.path.startswith("smb://"):
- path_parts = self.config.path[6:].split("/", 2)
- else:
- path_parts = self.config.path.split(os.sep)
- self._remote_name = path_parts[0]
- self._service_name = path_parts[1]
- if len(path_parts) > 2:
- self._root_path = os.sep + path_parts[2]
-
- default_target_ip = await get_ip_from_host(self._remote_name)
- self._target_ip = self.config.options.get("target_ip", default_target_ip)
- async with self._get_smb_connection():
- return True
-
- async def listdir(
- self,
- path: str,
- recursive: bool = False,
- ) -> AsyncGenerator[FileSystemItem, None]:
- """
- List contents of a given provider directory/path.
-
- Parameters:
- - path: path of the directory (relative or absolute) to list contents of.
- Empty string for provider's root.
- - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent).
-
- Returns:
- AsyncGenerator yielding FileSystemItem objects.
-
- """
- abs_path = get_absolute_path(self._root_path, path)
- async with self._get_smb_connection() as smb_conn:
- path_result: list[SharedFile] = await smb_conn.list_path(abs_path)
- for entry in path_result:
- if entry.filename.startswith("."):
- # skip invalid/system files and dirs
- continue
- file_path = os.path.join(path, entry.filename)
- item = await create_item(file_path, entry, self._root_path)
- if recursive and item.is_dir:
- # yield sublevel recursively
- try:
- async for subitem in self.listdir(file_path, True):
- yield subitem
- except (OSError, PermissionError) as err:
- self.logger.warning("Skip folder %s: %s", item.path, str(err))
- elif item.is_file or item.is_dir:
- yield item
-
- async def resolve(self, file_path: str) -> FileSystemItem:
- """Resolve (absolute or relative) path to FileSystemItem."""
- abs_path = get_absolute_path(self._root_path, file_path)
- async with self._get_smb_connection() as smb_conn:
- entry: SharedFile = await smb_conn.get_attributes(abs_path)
- return FileSystemItem(
- name=file_path,
- path=get_relative_path(self._root_path, file_path),
- absolute_path=abs_path,
- is_file=not entry.isDirectory,
- is_dir=entry.isDirectory,
- checksum=str(int(entry.last_write_time)),
- file_size=entry.file_size,
- )
-
- async def exists(self, file_path: str) -> bool:
- """Return bool if this FileSystem musicprovider has given file/dir."""
- abs_path = get_absolute_path(self._root_path, file_path)
- async with self._get_smb_connection() as smb_conn:
- return await smb_conn.path_exists(abs_path)
-
- async def read_file_content(
- self, file_path: str, seek: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Yield (binary) contents of file in chunks of bytes."""
- abs_path = get_absolute_path(self._root_path, file_path)
-
- async with self._get_smb_connection() as smb_conn:
- async for chunk in smb_conn.retrieve_file(abs_path, seek):
- yield chunk
-
- async def write_file_content(self, file_path: str, data: bytes) -> None:
- """Write entire file content as bytes (e.g. for playlists)."""
- abs_path = get_absolute_path(self._root_path, file_path)
- async with self._get_smb_connection() as smb_conn:
- await smb_conn.write_file(abs_path, data)
-
- @asynccontextmanager
- async def _get_smb_connection(self) -> AsyncGenerator[AsyncSMB, None]:
- """Get instance of AsyncSMB."""
-
- # for a task that consists of multiple steps,
- # the smb connection may be reused (shared through a contextvar)
- if existing := smb_conn_ctx.get():
- yield existing
- return
-
- async with AsyncSMB(
- remote_name=self._remote_name,
- service_name=self._service_name,
- username=self.config.username,
- password=self.config.password,
- target_ip=self._target_ip,
- options=self.config.options,
- ) as smb_conn:
- token = smb_conn_ctx.set(smb_conn)
- yield smb_conn
- smb_conn_ctx.reset(token)
+++ /dev/null
-"""Package with Qobuz Music provider."""
-
-from .qobuz import QobuzProvider # noqa
+++ /dev/null
-"""Qobuz musicprovider support for MusicAssistant."""
-from __future__ import annotations
-
-import datetime
-import hashlib
-import time
-from json import JSONDecodeError
-from typing import AsyncGenerator, List, Optional, Tuple
-
-import aiohttp
-from asyncio_throttle import Throttler
-
-from music_assistant.helpers.app_vars import ( # pylint: disable=no-name-in-module
- app_var,
-)
-from music_assistant.helpers.util import parse_title_and_version, try_parse_int
-from music_assistant.models.enums import MusicProviderFeature, ProviderType
-from music_assistant.models.errors import LoginFailed, MediaNotFoundError
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- ContentType,
- ImageType,
- MediaItemImage,
- MediaItemType,
- MediaType,
- Playlist,
- ProviderMapping,
- StreamDetails,
- Track,
-)
-from music_assistant.models.music_provider import MusicProvider
-
-
-class QobuzProvider(MusicProvider):
- """Provider for the Qobux music service."""
-
- _attr_type = ProviderType.QOBUZ
- _attr_name = "Qobuz"
- _user_auth_info = None
- _throttler = Throttler(rate_limit=4, period=1)
-
- @property
- def supported_features(self) -> Tuple[MusicProviderFeature]:
- """Return the features supported by this MusicProvider."""
- return (
- MusicProviderFeature.LIBRARY_ARTISTS,
- MusicProviderFeature.LIBRARY_ALBUMS,
- MusicProviderFeature.LIBRARY_TRACKS,
- MusicProviderFeature.LIBRARY_PLAYLISTS,
- MusicProviderFeature.LIBRARY_ARTISTS_EDIT,
- MusicProviderFeature.LIBRARY_ALBUMS_EDIT,
- MusicProviderFeature.LIBRARY_PLAYLISTS_EDIT,
- MusicProviderFeature.LIBRARY_TRACKS_EDIT,
- MusicProviderFeature.PLAYLIST_TRACKS_EDIT,
- MusicProviderFeature.BROWSE,
- MusicProviderFeature.SEARCH,
- MusicProviderFeature.ARTIST_ALBUMS,
- MusicProviderFeature.ARTIST_TOPTRACKS,
- )
-
- async def setup(self) -> bool:
- """Handle async initialization of the provider."""
- if not self.config.enabled:
- return False
- if not self.config.username or not self.config.password:
- raise LoginFailed("Invalid login credentials")
- # try to get a token, raise if that fails
- token = await self._auth_token()
- if not token:
- raise LoginFailed(f"Login failed for user {self.config.username}")
- return True
-
- async def search(
- self, search_query: str, media_types=Optional[List[MediaType]], limit: int = 5
- ) -> List[MediaItemType]:
- """
- Perform search on musicprovider.
-
- :param search_query: Search query.
- :param media_types: A list of media_types to include. All types if None.
- :param limit: Number of items to return in the search (per type).
- """
- result = []
- params = {"query": search_query, "limit": limit}
- if len(media_types) == 1:
- # qobuz does not support multiple searchtypes, falls back to all if no type given
- if media_types[0] == MediaType.ARTIST:
- params["type"] = "artists"
- if media_types[0] == MediaType.ALBUM:
- params["type"] = "albums"
- if media_types[0] == MediaType.TRACK:
- params["type"] = "tracks"
- if media_types[0] == MediaType.PLAYLIST:
- params["type"] = "playlists"
- if searchresult := await self._get_data("catalog/search", **params):
- if "artists" in searchresult:
- result += [
- await self._parse_artist(item)
- for item in searchresult["artists"]["items"]
- if (item and item["id"])
- ]
- if "albums" in searchresult:
- result += [
- await self._parse_album(item)
- for item in searchresult["albums"]["items"]
- if (item and item["id"])
- ]
- if "tracks" in searchresult:
- result += [
- await self._parse_track(item)
- for item in searchresult["tracks"]["items"]
- if (item and item["id"])
- ]
- if "playlists" in searchresult:
- result += [
- await self._parse_playlist(item)
- for item in searchresult["playlists"]["items"]
- if (item and item["id"])
- ]
- return result
-
- async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
- """Retrieve all library artists from Qobuz."""
- endpoint = "favorite/getUserFavorites"
- for item in await self._get_all_items(endpoint, key="artists", type="artists"):
- if item and item["id"]:
- yield await self._parse_artist(item)
-
- async def get_library_albums(self) -> AsyncGenerator[Album, None]:
- """Retrieve all library albums from Qobuz."""
- endpoint = "favorite/getUserFavorites"
- for item in await self._get_all_items(endpoint, key="albums", type="albums"):
- if item and item["id"]:
- yield await self._parse_album(item)
-
- async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
- """Retrieve library tracks from Qobuz."""
- endpoint = "favorite/getUserFavorites"
- for item in await self._get_all_items(endpoint, key="tracks", type="tracks"):
- if item and item["id"]:
- yield await self._parse_track(item)
-
- async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
- """Retrieve all library playlists from the provider."""
- endpoint = "playlist/getUserPlaylists"
- for item in await self._get_all_items(endpoint, key="playlists"):
- if item and item["id"]:
- yield await self._parse_playlist(item)
-
- async def get_artist(self, prov_artist_id) -> Artist:
- """Get full artist details by id."""
- params = {"artist_id": prov_artist_id}
- artist_obj = await self._get_data("artist/get", **params)
- return (
- await self._parse_artist(artist_obj)
- if artist_obj and artist_obj["id"]
- else None
- )
-
- async def get_album(self, prov_album_id) -> Album:
- """Get full album details by id."""
- params = {"album_id": prov_album_id}
- album_obj = await self._get_data("album/get", **params)
- return (
- await self._parse_album(album_obj)
- if album_obj and album_obj["id"]
- else None
- )
-
- async def get_track(self, prov_track_id) -> Track:
- """Get full track details by id."""
- params = {"track_id": prov_track_id}
- track_obj = await self._get_data("track/get", **params)
- return (
- await self._parse_track(track_obj)
- if track_obj and track_obj["id"]
- else None
- )
-
- async def get_playlist(self, prov_playlist_id) -> Playlist:
- """Get full playlist details by id."""
- params = {"playlist_id": prov_playlist_id}
- playlist_obj = await self._get_data("playlist/get", **params)
- return (
- await self._parse_playlist(playlist_obj)
- if playlist_obj and playlist_obj["id"]
- else None
- )
-
- async def get_album_tracks(self, prov_album_id) -> List[Track]:
- """Get all album tracks for given album id."""
- params = {"album_id": prov_album_id}
- return [
- await self._parse_track(item)
- for item in await self._get_all_items("album/get", **params, key="tracks")
- if (item and item["id"])
- ]
-
- async def get_playlist_tracks(self, prov_playlist_id) -> List[Track]:
- """Get all playlist tracks for given playlist id."""
- count = 0
- result = []
- for item in await self._get_all_items(
- "playlist/get",
- key="tracks",
- playlist_id=prov_playlist_id,
- extra="tracks",
- ):
- if not (item and item["id"]):
- continue
- track = await self._parse_track(item)
- # use count as position
- track.position = count
- result.append(track)
- count += 1
- return result
-
- async def get_artist_albums(self, prov_artist_id) -> List[Album]:
- """Get a list of albums for the given artist."""
- endpoint = "artist/get"
- return [
- await self._parse_album(item)
- for item in await self._get_all_items(
- endpoint, key="albums", artist_id=prov_artist_id, extra="albums"
- )
- if (item and item["id"] and str(item["artist"]["id"]) == prov_artist_id)
- ]
-
- async def get_artist_toptracks(self, prov_artist_id) -> List[Track]:
- """Get a list of most popular tracks for the given artist."""
- result = await self._get_data(
- "artist/get",
- artist_id=prov_artist_id,
- extra="playlists",
- offset=0,
- limit=25,
- )
- if result and result["playlists"]:
- return [
- await self._parse_track(item)
- for item in result["playlists"][0]["tracks"]["items"]
- if (item and item["id"])
- ]
- # fallback to search
- artist = await self.get_artist(prov_artist_id)
- searchresult = await self._get_data(
- "catalog/search", query=artist.name, limit=25, type="tracks"
- )
- return [
- await self._parse_track(item)
- for item in searchresult["tracks"]["items"]
- if (
- item
- and item["id"]
- and "performer" in item
- and str(item["performer"]["id"]) == str(prov_artist_id)
- )
- ]
-
- async def get_similar_artists(self, prov_artist_id):
- """Get similar artists for given artist."""
- # https://www.qobuz.com/api.json/0.2/artist/getSimilarArtists?artist_id=220020&offset=0&limit=3
-
- async def library_add(self, prov_item_id, media_type: MediaType):
- """Add item to library."""
- result = None
- if media_type == MediaType.ARTIST:
- result = await self._get_data("favorite/create", artist_id=prov_item_id)
- elif media_type == MediaType.ALBUM:
- result = await self._get_data("favorite/create", album_ids=prov_item_id)
- elif media_type == MediaType.TRACK:
- result = await self._get_data("favorite/create", track_ids=prov_item_id)
- elif media_type == MediaType.PLAYLIST:
- result = await self._get_data(
- "playlist/subscribe", playlist_id=prov_item_id
- )
- return result
-
- async def library_remove(self, prov_item_id, media_type: MediaType):
- """Remove item from library."""
- result = None
- if media_type == MediaType.ARTIST:
- result = await self._get_data("favorite/delete", artist_ids=prov_item_id)
- elif media_type == MediaType.ALBUM:
- result = await self._get_data("favorite/delete", album_ids=prov_item_id)
- elif media_type == MediaType.TRACK:
- result = await self._get_data("favorite/delete", track_ids=prov_item_id)
- elif media_type == MediaType.PLAYLIST:
- playlist = await self.get_playlist(prov_item_id)
- if playlist.is_editable:
- result = await self._get_data(
- "playlist/delete", playlist_id=prov_item_id
- )
- else:
- result = await self._get_data(
- "playlist/unsubscribe", playlist_id=prov_item_id
- )
- return result
-
- async def add_playlist_tracks(
- self, prov_playlist_id: str, prov_track_ids: List[str]
- ) -> None:
- """Add track(s) to playlist."""
- return await self._get_data(
- "playlist/addTracks",
- playlist_id=prov_playlist_id,
- track_ids=",".join(prov_track_ids),
- playlist_track_ids=",".join(prov_track_ids),
- )
-
- async def remove_playlist_tracks(
- self, prov_playlist_id: str, positions_to_remove: Tuple[int]
- ) -> None:
- """Remove track(s) from playlist."""
- playlist_track_ids = set()
- for track in await self.get_playlist_tracks(prov_playlist_id):
- if track.position in positions_to_remove:
- playlist_track_ids.add(str(track["playlist_track_id"]))
- if len(playlist_track_ids) == positions_to_remove:
- break
- return await self._get_data(
- "playlist/deleteTracks",
- playlist_id=prov_playlist_id,
- playlist_track_ids=",".join(playlist_track_ids),
- )
-
- async def get_stream_details(self, item_id: str) -> StreamDetails:
- """Return the content details for the given track when it will be streamed."""
- streamdata = None
- for format_id in [27, 7, 6, 5]:
- # it seems that simply requesting for highest available quality does not work
- # from time to time the api response is empty for this request ?!
- result = await self._get_data(
- "track/getFileUrl",
- sign_request=True,
- format_id=format_id,
- track_id=item_id,
- intent="stream",
- )
- if result and result.get("url"):
- streamdata = result
- break
- if not streamdata:
- raise MediaNotFoundError(f"Unable to retrieve stream details for {item_id}")
- if streamdata["mime_type"] == "audio/mpeg":
- content_type = ContentType.MPEG
- elif streamdata["mime_type"] == "audio/flac":
- content_type = ContentType.FLAC
- else:
- raise MediaNotFoundError(f"Unsupported mime type for {item_id}")
- # report playback started as soon as the streamdetails are requested
- self.mass.create_task(self._report_playback_started(streamdata))
- return StreamDetails(
- item_id=str(item_id),
- provider=self.type,
- content_type=content_type,
- duration=streamdata["duration"],
- sample_rate=int(streamdata["sampling_rate"] * 1000),
- bit_depth=streamdata["bit_depth"],
- data=streamdata, # we need these details for reporting playback
- expires=time.time() + 3600, # not sure about the real allowed value
- direct=streamdata["url"],
- callback=self._report_playback_stopped,
- )
-
- async def _report_playback_started(self, streamdata: dict) -> None:
- """Report playback start to qobuz."""
- # TODO: need to figure out if the streamed track is purchased by user
- # https://www.qobuz.com/api.json/0.2/purchase/getUserPurchasesIds?limit=5000&user_id=xxxxxxx
- # {"albums":{"total":0,"items":[]},"tracks":{"total":0,"items":[]},"user":{"id":xxxx,"login":"xxxxx"}}
- device_id = self._user_auth_info["user"]["device"]["id"]
- credential_id = self._user_auth_info["user"]["credential"]["id"]
- user_id = self._user_auth_info["user"]["id"]
- format_id = streamdata["format_id"]
- timestamp = int(time.time())
- events = [
- {
- "online": True,
- "sample": False,
- "intent": "stream",
- "device_id": device_id,
- "track_id": streamdata["track_id"],
- "purchase": False,
- "date": timestamp,
- "credential_id": credential_id,
- "user_id": user_id,
- "local": False,
- "format_id": format_id,
- }
- ]
- await self._post_data("track/reportStreamingStart", data=events)
-
- async def _report_playback_stopped(self, streamdetails: StreamDetails) -> None:
- """Report playback stop to qobuz."""
- user_id = self._user_auth_info["user"]["id"]
- await self._get_data(
- "/track/reportStreamingEnd",
- user_id=user_id,
- track_id=str(streamdetails.item_id),
- duration=try_parse_int(streamdetails.seconds_streamed),
- )
-
- async def _parse_artist(self, artist_obj: dict):
- """Parse qobuz artist object to generic layout."""
- artist = Artist(
- item_id=str(artist_obj["id"]), provider=self.type, name=artist_obj["name"]
- )
- artist.add_provider_mapping(
- ProviderMapping(
- item_id=str(artist_obj["id"]),
- provider_type=self.type,
- provider_id=self.id,
- url=artist_obj.get(
- "url", f'https://open.qobuz.com/artist/{artist_obj["id"]}'
- ),
- )
- )
- if img := self.__get_image(artist_obj):
- artist.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
- if artist_obj.get("biography"):
- artist.metadata.description = artist_obj["biography"].get("content")
- return artist
-
- async def _parse_album(self, album_obj: dict, artist_obj: dict = None):
- """Parse qobuz album object to generic layout."""
- if not artist_obj and "artist" not in album_obj:
- # artist missing in album info, return full abum instead
- return await self.get_album(album_obj["id"])
- name, version = parse_title_and_version(
- album_obj["title"], album_obj.get("version")
- )
- album = Album(
- item_id=str(album_obj["id"]), provider=self.type, name=name, version=version
- )
- album.add_provider_mapping(
- ProviderMapping(
- item_id=str(album_obj["id"]),
- provider_type=self.type,
- provider_id=self.id,
- available=album_obj["streamable"] and album_obj["displayable"],
- content_type=ContentType.FLAC,
- sample_rate=album_obj["maximum_sampling_rate"] * 1000,
- bit_depth=album_obj["maximum_bit_depth"],
- url=album_obj.get(
- "url", f'https://open.qobuz.com/album/{album_obj["id"]}'
- ),
- )
- )
-
- album.artist = await self._parse_artist(artist_obj or album_obj["artist"])
- if (
- album_obj.get("product_type", "") == "single"
- or album_obj.get("release_type", "") == "single"
- ):
- album.album_type = AlbumType.SINGLE
- elif (
- album_obj.get("product_type", "") == "compilation"
- or "Various" in album.artist.name
- ):
- album.album_type = AlbumType.COMPILATION
- elif (
- album_obj.get("product_type", "") == "album"
- or album_obj.get("release_type", "") == "album"
- ):
- album.album_type = AlbumType.ALBUM
- if "genre" in album_obj:
- album.metadata.genres = {album_obj["genre"]["name"]}
- if img := self.__get_image(album_obj):
- album.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
- if len(album_obj["upc"]) == 13:
- # qobuz writes ean as upc ?!
- album.upc = album_obj["upc"][1:]
- else:
- album.upc = album_obj["upc"]
- if "label" in album_obj:
- album.metadata.label = album_obj["label"]["name"]
- if album_obj.get("released_at"):
- album.year = datetime.datetime.fromtimestamp(album_obj["released_at"]).year
- if album_obj.get("copyright"):
- album.metadata.copyright = album_obj["copyright"]
- if album_obj.get("description"):
- album.metadata.description = album_obj["description"]
- return album
-
- async def _parse_track(self, track_obj: dict):
- """Parse qobuz track object to generic layout."""
- name, version = parse_title_and_version(
- track_obj["title"], track_obj.get("version")
- )
- track = Track(
- item_id=str(track_obj["id"]),
- provider=self.type,
- name=name,
- version=version,
- disc_number=track_obj["media_number"],
- track_number=track_obj["track_number"],
- duration=track_obj["duration"],
- position=track_obj.get("position"),
- )
- if track_obj.get("performer") and "Various " not in track_obj["performer"]:
- artist = await self._parse_artist(track_obj["performer"])
- if artist:
- track.artists.append(artist)
- if not track.artists:
- # try to grab artist from album
- if (
- track_obj.get("album")
- and track_obj["album"].get("artist")
- and "Various " not in track_obj["album"]["artist"]
- ):
- artist = await self._parse_artist(track_obj["album"]["artist"])
- if artist:
- track.artists.append(artist)
- if not track.artists:
- # last resort: parse from performers string
- for performer_str in track_obj["performers"].split(" - "):
- role = performer_str.split(", ")[1]
- name = performer_str.split(", ")[0]
- if "artist" in role.lower():
- artist = Artist(name, self.type, name)
- track.artists.append(artist)
- # TODO: fix grabbing composer from details
-
- if "album" in track_obj:
- album = await self._parse_album(track_obj["album"])
- if album:
- track.album = album
- if track_obj.get("isrc"):
- track.isrc = track_obj["isrc"]
- if track_obj.get("performers"):
- track.metadata.performers = {
- x.strip() for x in track_obj["performers"].split("-")
- }
- if track_obj.get("copyright"):
- track.metadata.copyright = track_obj["copyright"]
- if track_obj.get("audio_info"):
- track.metadata.replaygain = track_obj["audio_info"]["replaygain_track_gain"]
- if track_obj.get("parental_warning"):
- track.metadata.explicit = True
- if img := self.__get_image(track_obj):
- track.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
-
- track.add_provider_mapping(
- ProviderMapping(
- item_id=str(track_obj["id"]),
- provider_type=self.type,
- provider_id=self.id,
- available=track_obj["streamable"] and track_obj["displayable"],
- content_type=ContentType.FLAC,
- sample_rate=track_obj["maximum_sampling_rate"] * 1000,
- bit_depth=track_obj["maximum_bit_depth"],
- url=track_obj.get(
- "url", f'https://open.qobuz.com/track/{track_obj["id"]}'
- ),
- )
- )
- return track
-
- async def _parse_playlist(self, playlist_obj):
- """Parse qobuz playlist object to generic layout."""
- playlist = Playlist(
- item_id=str(playlist_obj["id"]),
- provider=self.type,
- name=playlist_obj["name"],
- owner=playlist_obj["owner"]["name"],
- )
- playlist.add_provider_mapping(
- ProviderMapping(
- item_id=str(playlist_obj["id"]),
- provider_type=self.type,
- provider_id=self.id,
- url=playlist_obj.get(
- "url", f'https://open.qobuz.com/playlist/{playlist_obj["id"]}'
- ),
- )
- )
- playlist.is_editable = (
- playlist_obj["owner"]["id"] == self._user_auth_info["user"]["id"]
- or playlist_obj["is_collaborative"]
- )
- if img := self.__get_image(playlist_obj):
- playlist.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
- playlist.metadata.checksum = str(playlist_obj["updated_at"])
- return playlist
-
- async def _auth_token(self):
- """Login to qobuz and store the token."""
- if self._user_auth_info:
- return self._user_auth_info["user_auth_token"]
- params = {
- "username": self.config.username,
- "password": self.config.password,
- "device_manufacturer_id": "music_assistant",
- }
- details = await self._get_data("user/login", **params)
- if details and "user" in details:
- self._user_auth_info = details
- self.logger.info(
- "Succesfully logged in to Qobuz as %s", details["user"]["display_name"]
- )
- self.mass.metadata.preferred_language = details["user"]["country_code"]
- return details["user_auth_token"]
-
- async def _get_all_items(self, endpoint, key="tracks", **kwargs):
- """Get all items from a paged list."""
- limit = 50
- offset = 0
- all_items = []
- while True:
- kwargs["limit"] = limit
- kwargs["offset"] = offset
- result = await self._get_data(endpoint, **kwargs)
- offset += limit
- if not result:
- break
- if not result.get(key) or not result[key].get("items"):
- break
- for item in result[key]["items"]:
- item["position"] = len(all_items) + 1
- all_items.append(item)
- if len(result[key]["items"]) < limit:
- break
- return all_items
-
- async def _get_data(self, endpoint, sign_request=False, **kwargs):
- """Get data from api."""
- url = f"http://www.qobuz.com/api.json/0.2/{endpoint}"
- headers = {"X-App-Id": app_var(0)}
- if endpoint != "user/login":
- auth_token = await self._auth_token()
- if not auth_token:
- self.logger.debug("Not logged in")
- return None
- headers["X-User-Auth-Token"] = auth_token
- if sign_request:
- signing_data = "".join(endpoint.split("/"))
- keys = list(kwargs.keys())
- keys.sort()
- for key in keys:
- signing_data += f"{key}{kwargs[key]}"
- request_ts = str(time.time())
- request_sig = signing_data + request_ts + app_var(1)
- request_sig = str(hashlib.md5(request_sig.encode()).hexdigest())
- kwargs["request_ts"] = request_ts
- kwargs["request_sig"] = request_sig
- kwargs["app_id"] = app_var(0)
- kwargs["user_auth_token"] = await self._auth_token()
- async with self._throttler:
- async with self.mass.http_session.get(
- url, headers=headers, params=kwargs, verify_ssl=False
- ) as response:
- try:
- # make sure status is 200
- assert response.status == 200
- result = await response.json()
- # check for error in json
- if error := result.get("error"):
- raise ValueError(error)
- if result.get("status") and "error" in result["status"]:
- raise ValueError(result["status"])
- except (
- aiohttp.ContentTypeError,
- JSONDecodeError,
- AssertionError,
- ValueError,
- ) as err:
- text = await response.text()
- self.logger.exception(
- "Error while processing %s: %s", endpoint, text, exc_info=err
- )
- return None
- return result
-
- async def _post_data(self, endpoint, params=None, data=None):
- """Post data to api."""
- if not params:
- params = {}
- if not data:
- data = {}
- url = f"http://www.qobuz.com/api.json/0.2/{endpoint}"
- params["app_id"] = app_var(0)
- params["user_auth_token"] = await self._auth_token()
- async with self.mass.http_session.post(
- url, params=params, json=data, verify_ssl=False
- ) as response:
- try:
- result = await response.json()
- # check for error in json
- if error := result.get("error"):
- raise ValueError(error)
- if result.get("status") and "error" in result["status"]:
- raise ValueError(result["status"])
- except (
- aiohttp.ContentTypeError,
- JSONDecodeError,
- AssertionError,
- ValueError,
- ) as err:
- text = await response.text()
- self.logger.exception(
- "Error while processing %s: %s", endpoint, text, exc_info=err
- )
- return None
- return result
-
- def __get_image(self, obj: dict) -> Optional[str]:
- """Try to parse image from Qobuz media object."""
- if obj.get("image"):
- for key in ["extralarge", "large", "medium", "small"]:
- if obj["image"].get(key):
- if "2a96cbd8b46e442fc41c2b86b821562f" in obj["image"][key]:
- continue
- return obj["image"][key]
- if obj.get("images300"):
- # playlists seem to use this strange format
- return obj["images300"][0]
- if obj.get("album"):
- return self.__get_image(obj["album"])
- if obj.get("artist"):
- return self.__get_image(obj["artist"])
- return None
+++ /dev/null
-"""Package with Spotify Music provider."""
-
-from .spotify import SpotifyProvider # noqa
+++ /dev/null
-"""Spotify musicprovider support for MusicAssistant."""
-from __future__ import annotations
-
-import asyncio
-import json
-import os
-import platform
-import time
-from json.decoder import JSONDecodeError
-from tempfile import gettempdir
-from typing import AsyncGenerator, List, Optional, Tuple
-
-import aiohttp
-from asyncio_throttle import Throttler
-
-from music_assistant.helpers.app_vars import ( # noqa # pylint: disable=no-name-in-module
- app_var,
-)
-from music_assistant.helpers.process import AsyncProcess
-from music_assistant.helpers.util import parse_title_and_version
-from music_assistant.models.enums import MusicProviderFeature, ProviderType
-from music_assistant.models.errors import LoginFailed, MediaNotFoundError
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- ContentType,
- ImageType,
- MediaItemImage,
- MediaItemType,
- MediaType,
- Playlist,
- ProviderMapping,
- StreamDetails,
- Track,
-)
-from music_assistant.models.music_provider import MusicProvider
-
-CACHE_DIR = gettempdir()
-
-
-class SpotifyProvider(MusicProvider):
- """Implementation of a Spotify MusicProvider."""
-
- _attr_type = ProviderType.SPOTIFY
- _attr_name = "Spotify"
- _auth_token = None
- _sp_user = None
- _librespot_bin = None
- _throttler = Throttler(rate_limit=4, period=1)
- _cache_dir = CACHE_DIR
- _ap_workaround = False
-
- @property
- def supported_features(self) -> Tuple[MusicProviderFeature]:
- """Return the features supported by this MusicProvider."""
- return (
- MusicProviderFeature.LIBRARY_ARTISTS,
- MusicProviderFeature.LIBRARY_ALBUMS,
- MusicProviderFeature.LIBRARY_TRACKS,
- MusicProviderFeature.LIBRARY_PLAYLISTS,
- MusicProviderFeature.LIBRARY_ARTISTS_EDIT,
- MusicProviderFeature.LIBRARY_ALBUMS_EDIT,
- MusicProviderFeature.LIBRARY_PLAYLISTS_EDIT,
- MusicProviderFeature.LIBRARY_TRACKS_EDIT,
- MusicProviderFeature.PLAYLIST_TRACKS_EDIT,
- MusicProviderFeature.BROWSE,
- MusicProviderFeature.SEARCH,
- MusicProviderFeature.ARTIST_ALBUMS,
- MusicProviderFeature.ARTIST_TOPTRACKS,
- MusicProviderFeature.SIMILAR_TRACKS,
- )
-
- async def setup(self) -> bool:
- """Handle async initialization of the provider."""
- if not self.config.enabled:
- return False
- # try to get a token, raise if that fails
- self._cache_dir = os.path.join(CACHE_DIR, self.id)
- # try login which will raise if it fails
- await self.login()
- return True
-
- async def search(
- self, search_query: str, media_types=Optional[List[MediaType]], limit: int = 5
- ) -> List[MediaItemType]:
- """
- Perform search on musicprovider.
-
- :param search_query: Search query.
- :param media_types: A list of media_types to include. All types if None.
- :param limit: Number of items to return in the search (per type).
- """
- result = []
- searchtypes = []
- if MediaType.ARTIST in media_types:
- searchtypes.append("artist")
- if MediaType.ALBUM in media_types:
- searchtypes.append("album")
- if MediaType.TRACK in media_types:
- searchtypes.append("track")
- if MediaType.PLAYLIST in media_types:
- searchtypes.append("playlist")
- searchtype = ",".join(searchtypes)
- search_query = search_query.replace("'", "")
- if searchresult := await self._get_data(
- "search", q=search_query, type=searchtype, limit=limit
- ):
- if "artists" in searchresult:
- result += [
- await self._parse_artist(item)
- for item in searchresult["artists"]["items"]
- if (item and item["id"])
- ]
- if "albums" in searchresult:
- result += [
- await self._parse_album(item)
- for item in searchresult["albums"]["items"]
- if (item and item["id"])
- ]
- if "tracks" in searchresult:
- result += [
- await self._parse_track(item)
- for item in searchresult["tracks"]["items"]
- if (item and item["id"])
- ]
- if "playlists" in searchresult:
- result += [
- await self._parse_playlist(item)
- for item in searchresult["playlists"]["items"]
- if (item and item["id"])
- ]
- return result
-
- async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
- """Retrieve library artists from spotify."""
- endpoint = "me/following"
- while True:
- spotify_artists = await self._get_data(
- endpoint,
- type="artist",
- limit=50,
- )
- for item in spotify_artists["artists"]["items"]:
- if item and item["id"]:
- yield await self._parse_artist(item)
- if spotify_artists["artists"]["next"]:
- endpoint = spotify_artists["artists"]["next"]
- endpoint = endpoint.replace("https://api.spotify.com/v1/", "")
- else:
- break
-
- async def get_library_albums(self) -> AsyncGenerator[Album, None]:
- """Retrieve library albums from the provider."""
- for item in await self._get_all_items("me/albums"):
- if item["album"] and item["album"]["id"]:
- yield await self._parse_album(item["album"])
-
- async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
- """Retrieve library tracks from the provider."""
- for item in await self._get_all_items("me/tracks"):
- if item and item["track"]["id"]:
- yield await self._parse_track(item["track"])
-
- async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
- """Retrieve playlists from the provider."""
- for item in await self._get_all_items("me/playlists"):
- if item and item["id"]:
- yield await self._parse_playlist(item)
-
- async def get_artist(self, prov_artist_id) -> Artist:
- """Get full artist details by id."""
- artist_obj = await self._get_data(f"artists/{prov_artist_id}")
- return await self._parse_artist(artist_obj) if artist_obj else None
-
- async def get_album(self, prov_album_id) -> Album:
- """Get full album details by id."""
- album_obj = await self._get_data(f"albums/{prov_album_id}")
- return await self._parse_album(album_obj) if album_obj else None
-
- async def get_track(self, prov_track_id) -> Track:
- """Get full track details by id."""
- track_obj = await self._get_data(f"tracks/{prov_track_id}")
- return await self._parse_track(track_obj) if track_obj else None
-
- async def get_playlist(self, prov_playlist_id) -> Playlist:
- """Get full playlist details by id."""
- playlist_obj = await self._get_data(f"playlists/{prov_playlist_id}")
- return await self._parse_playlist(playlist_obj) if playlist_obj else None
-
- async def get_album_tracks(self, prov_album_id) -> List[Track]:
- """Get all album tracks for given album id."""
- return [
- await self._parse_track(item)
- for item in await self._get_all_items(f"albums/{prov_album_id}/tracks")
- if (item and item["id"])
- ]
-
- async def get_playlist_tracks(self, prov_playlist_id) -> List[Track]:
- """Get all playlist tracks for given playlist id."""
- count = 0
- result = []
- for item in await self._get_all_items(
- f"playlists/{prov_playlist_id}/tracks",
- ):
- if not (item and item["track"] and item["track"]["id"]):
- continue
- track = await self._parse_track(item["track"])
- # use count as position
- track.position = count
- result.append(track)
- count += 1
- return result
-
- async def get_artist_albums(self, prov_artist_id) -> List[Album]:
- """Get a list of all albums for the given artist."""
- return [
- await self._parse_album(item)
- for item in await self._get_all_items(
- f"artists/{prov_artist_id}/albums?include_groups=album,single,compilation"
- )
- if (item and item["id"])
- ]
-
- async def get_artist_toptracks(self, prov_artist_id) -> List[Track]:
- """Get a list of 10 most popular tracks for the given artist."""
- artist = await self.get_artist(prov_artist_id)
- endpoint = f"artists/{prov_artist_id}/top-tracks"
- items = await self._get_data(endpoint)
- return [
- await self._parse_track(item, artist=artist)
- for item in items["tracks"]
- if (item and item["id"])
- ]
-
- async def library_add(self, prov_item_id, media_type: MediaType):
- """Add item to library."""
- result = False
- if media_type == MediaType.ARTIST:
- result = await self._put_data(
- "me/following", {"ids": prov_item_id, "type": "artist"}
- )
- elif media_type == MediaType.ALBUM:
- result = await self._put_data("me/albums", {"ids": prov_item_id})
- elif media_type == MediaType.TRACK:
- result = await self._put_data("me/tracks", {"ids": prov_item_id})
- elif media_type == MediaType.PLAYLIST:
- result = await self._put_data(
- f"playlists/{prov_item_id}/followers", data={"public": False}
- )
- return result
-
- async def library_remove(self, prov_item_id, media_type: MediaType):
- """Remove item from library."""
- result = False
- if media_type == MediaType.ARTIST:
- result = await self._delete_data(
- "me/following", {"ids": prov_item_id, "type": "artist"}
- )
- elif media_type == MediaType.ALBUM:
- result = await self._delete_data("me/albums", {"ids": prov_item_id})
- elif media_type == MediaType.TRACK:
- result = await self._delete_data("me/tracks", {"ids": prov_item_id})
- elif media_type == MediaType.PLAYLIST:
- result = await self._delete_data(f"playlists/{prov_item_id}/followers")
- return result
-
- async def add_playlist_tracks(
- self, prov_playlist_id: str, prov_track_ids: List[str]
- ):
- """Add track(s) to playlist."""
- track_uris = []
- for track_id in prov_track_ids:
- track_uris.append(f"spotify:track:{track_id}")
- data = {"uris": track_uris}
- return await self._post_data(f"playlists/{prov_playlist_id}/tracks", data=data)
-
- async def remove_playlist_tracks(
- self, prov_playlist_id: str, positions_to_remove: Tuple[int]
- ) -> None:
- """Remove track(s) from playlist."""
- track_uris = []
- for track in await self.get_playlist_tracks(prov_playlist_id):
- if track.position in positions_to_remove:
- track_uris.append({"uri": f"spotify:track:{track.item_id}"})
- if len(track_uris) == positions_to_remove:
- break
- data = {"tracks": track_uris}
- return await self._delete_data(
- f"playlists/{prov_playlist_id}/tracks", data=data
- )
-
- async def get_similar_tracks(self, prov_track_id, limit=25) -> List[Track]:
- """Retrieve a dynamic list of tracks based on the provided item."""
- endpoint = "recommendations"
- items = await self._get_data(endpoint, seed_tracks=prov_track_id, limit=limit)
- return [
- await self._parse_track(item)
- for item in items["tracks"]
- if (item and item["id"])
- ]
-
- async def get_stream_details(self, item_id: str) -> StreamDetails:
- """Return the content details for the given track when it will be streamed."""
- # make sure a valid track is requested.
- track = await self.get_track(item_id)
- if not track:
- raise MediaNotFoundError(f"track {item_id} not found")
- # make sure that the token is still valid by just requesting it
- await self.login()
- return StreamDetails(
- item_id=track.item_id,
- provider=self.type,
- content_type=ContentType.OGG,
- duration=track.duration,
- )
-
- async def get_audio_stream(
- self, streamdetails: StreamDetails, seek_position: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Return the audio stream for the provider item."""
- # make sure that the token is still valid by just requesting it
- await self.login()
- librespot = await self.get_librespot_binary()
- args = [
- librespot,
- "-c",
- self._cache_dir,
- "--pass-through",
- "-b",
- "320",
- "--single-track",
- f"spotify://track:{streamdetails.item_id}",
- ]
- if seek_position:
- args += ["--start-position", str(int(seek_position))]
- if self._ap_workaround:
- args += ["--ap-port", "12345"]
- bytes_sent = 0
- async with AsyncProcess(args) as librespot_proc:
- async for chunk in librespot_proc.iter_any():
- yield chunk
- bytes_sent += len(chunk)
-
- if bytes_sent == 0 and not self._ap_workaround:
- # AP resolve failure
- # https://github.com/librespot-org/librespot/issues/972
- # retry with ap-port set to invalid value, which will force fallback
- args += ["--ap-port", "12345"]
- async with AsyncProcess(args) as librespot_proc:
- async for chunk in librespot_proc.iter_any(64000):
- yield chunk
- self._ap_workaround = True
-
- async def _parse_artist(self, artist_obj):
- """Parse spotify artist object to generic layout."""
- artist = Artist(
- item_id=artist_obj["id"], provider=self.type, name=artist_obj["name"]
- )
- artist.add_provider_mapping(
- ProviderMapping(
- item_id=artist_obj["id"],
- provider_type=self.type,
- provider_id=self.id,
- url=artist_obj["external_urls"]["spotify"],
- )
- )
- if "genres" in artist_obj:
- artist.metadata.genres = set(artist_obj["genres"])
- if artist_obj.get("images"):
- for img in artist_obj["images"]:
- img_url = img["url"]
- if "2a96cbd8b46e442fc41c2b86b821562f" not in img_url:
- artist.metadata.images = [MediaItemImage(ImageType.THUMB, img_url)]
- break
- return artist
-
- async def _parse_album(self, album_obj: dict):
- """Parse spotify album object to generic layout."""
- name, version = parse_title_and_version(album_obj["name"])
- album = Album(
- item_id=album_obj["id"], provider=self.type, name=name, version=version
- )
- for artist_obj in album_obj["artists"]:
- album.artists.append(await self._parse_artist(artist_obj))
- if album_obj["album_type"] == "single":
- album.album_type = AlbumType.SINGLE
- elif album_obj["album_type"] == "compilation":
- album.album_type = AlbumType.COMPILATION
- elif album_obj["album_type"] == "album":
- album.album_type = AlbumType.ALBUM
- if "genres" in album_obj:
- album.metadata.genre = set(album_obj["genres"])
- if album_obj.get("images"):
- album.metadata.images = [
- MediaItemImage(ImageType.THUMB, album_obj["images"][0]["url"])
- ]
- if "external_ids" in album_obj and album_obj["external_ids"].get("upc"):
- album.upc = album_obj["external_ids"]["upc"]
- if "label" in album_obj:
- album.metadata.label = album_obj["label"]
- if album_obj.get("release_date"):
- album.year = int(album_obj["release_date"].split("-")[0])
- if album_obj.get("copyrights"):
- album.metadata.copyright = album_obj["copyrights"][0]["text"]
- if album_obj.get("explicit"):
- album.metadata.explicit = album_obj["explicit"]
- album.add_provider_mapping(
- ProviderMapping(
- item_id=album_obj["id"],
- provider_type=self.type,
- provider_id=self.id,
- content_type=ContentType.OGG,
- bit_rate=320,
- url=album_obj["external_urls"]["spotify"],
- )
- )
- return album
-
- async def _parse_track(self, track_obj, artist=None):
- """Parse spotify track object to generic layout."""
- name, version = parse_title_and_version(track_obj["name"])
- track = Track(
- item_id=track_obj["id"],
- provider=self.type,
- name=name,
- version=version,
- duration=track_obj["duration_ms"] / 1000,
- disc_number=track_obj["disc_number"],
- track_number=track_obj["track_number"],
- position=track_obj.get("position"),
- )
- if artist:
- track.artists.append(artist)
- for track_artist in track_obj.get("artists", []):
- artist = await self._parse_artist(track_artist)
- if artist and artist.item_id not in {x.item_id for x in track.artists}:
- track.artists.append(artist)
-
- track.metadata.explicit = track_obj["explicit"]
- if "preview_url" in track_obj:
- track.metadata.preview = track_obj["preview_url"]
- if "external_ids" in track_obj and "isrc" in track_obj["external_ids"]:
- track.isrc = track_obj["external_ids"]["isrc"]
- if "album" in track_obj:
- track.album = await self._parse_album(track_obj["album"])
- if track_obj["album"].get("images"):
- track.metadata.images = [
- MediaItemImage(
- ImageType.THUMB, track_obj["album"]["images"][0]["url"]
- )
- ]
- if track_obj.get("copyright"):
- track.metadata.copyright = track_obj["copyright"]
- if track_obj.get("explicit"):
- track.metadata.explicit = True
- if track_obj.get("popularity"):
- track.metadata.popularity = track_obj["popularity"]
- track.add_provider_mapping(
- ProviderMapping(
- item_id=track_obj["id"],
- provider_type=self.type,
- provider_id=self.id,
- content_type=ContentType.OGG,
- bit_rate=320,
- url=track_obj["external_urls"]["spotify"],
- available=not track_obj["is_local"] and track_obj["is_playable"],
- )
- )
- return track
-
- async def _parse_playlist(self, playlist_obj):
- """Parse spotify playlist object to generic layout."""
- playlist = Playlist(
- item_id=playlist_obj["id"],
- provider=self.type,
- name=playlist_obj["name"],
- owner=playlist_obj["owner"]["display_name"],
- )
- playlist.add_provider_mapping(
- ProviderMapping(
- item_id=playlist_obj["id"],
- provider_type=self.type,
- provider_id=self.id,
- url=playlist_obj["external_urls"]["spotify"],
- )
- )
- playlist.is_editable = (
- playlist_obj["owner"]["id"] == self._sp_user["id"]
- or playlist_obj["collaborative"]
- )
- if playlist_obj.get("images"):
- playlist.metadata.images = [
- MediaItemImage(ImageType.THUMB, playlist_obj["images"][0]["url"])
- ]
- playlist.metadata.checksum = str(playlist_obj["snapshot_id"])
- return playlist
-
- async def login(self) -> dict:
- """Log-in Spotify and return tokeninfo."""
- # return existing token if we have one in memory
- if (
- self._auth_token
- and os.path.isdir(self._cache_dir)
- and (self._auth_token["expiresAt"] > int(time.time()) + 20)
- ):
- return self._auth_token
- tokeninfo, userinfo = None, self._sp_user
- if not self.config.username or not self.config.password:
- raise LoginFailed("Invalid login credentials")
- # retrieve token with librespot
- retries = 0
- while retries < 20:
- try:
- retries += 1
- if not tokeninfo:
- tokeninfo = await asyncio.wait_for(self._get_token(), 5)
- if tokeninfo and not userinfo:
- userinfo = await asyncio.wait_for(
- self._get_data("me", tokeninfo=tokeninfo), 5
- )
- if tokeninfo and userinfo:
- # we have all info we need!
- break
- if retries > 2:
- # switch to ap workaround after 2 retries
- self._ap_workaround = True
- except asyncio.exceptions.TimeoutError:
- await asyncio.sleep(2)
- if tokeninfo and userinfo:
- self._auth_token = tokeninfo
- self._sp_user = userinfo
- self.mass.metadata.preferred_language = userinfo["country"]
- self.logger.info("Succesfully logged in to Spotify as %s", userinfo["id"])
- self._auth_token = tokeninfo
- return tokeninfo
- if tokeninfo and not userinfo:
- raise LoginFailed(
- "Unable to retrieve userdetails from Spotify API - probably just a temporary error"
- )
- if self.config.username.isnumeric():
- # a spotify free/basic account can be recognized when
- # the username consists of numbers only - check that here
- # an integer can be parsed of the username, this is a free account
- raise LoginFailed("Only Spotify Premium accounts are supported")
- raise LoginFailed(f"Login failed for user {self.config.username}")
-
- async def _get_token(self):
- """Get spotify auth token with librespot bin."""
- time_start = time.time()
- # authorize with username and password (NOTE: this can also be Spotify Connect)
- args = [
- await self.get_librespot_binary(),
- "-O",
- "-c",
- self._cache_dir,
- "-a",
- "-u",
- self.config.username,
- "-p",
- self.config.password,
- ]
- librespot = await asyncio.create_subprocess_exec(*args)
- await librespot.wait()
- # get token with (authorized) librespot
- scopes = [
- "user-read-playback-state",
- "user-read-currently-playing",
- "user-modify-playback-state",
- "playlist-read-private",
- "playlist-read-collaborative",
- "playlist-modify-public",
- "playlist-modify-private",
- "user-follow-modify",
- "user-follow-read",
- "user-library-read",
- "user-library-modify",
- "user-read-private",
- "user-read-email",
- "user-read-birthdate",
- "user-top-read",
- ]
- scope = ",".join(scopes)
- args = [
- await self.get_librespot_binary(),
- "-O",
- "-t",
- "--client-id",
- app_var(2),
- "--scope",
- scope,
- "-c",
- self._cache_dir,
- ]
- if self._ap_workaround:
- args += ["--ap-port", "12345"]
- librespot = await asyncio.create_subprocess_exec(
- *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
- )
- stdout, _ = await librespot.communicate()
- duration = round(time.time() - time_start, 2)
- try:
- result = json.loads(stdout)
- except JSONDecodeError:
- self.logger.warning(
- "Error while retrieving Spotify token after %s seconds, details: %s",
- duration,
- stdout.decode(),
- )
- return None
- self.logger.debug(
- "Retrieved Spotify token using librespot in %s seconds",
- duration,
- )
- # transform token info to spotipy compatible format
- if result and "accessToken" in result:
- tokeninfo = result
- tokeninfo["expiresAt"] = tokeninfo["expiresIn"] + int(time.time())
- return tokeninfo
- return None
-
- async def _get_all_items(self, endpoint, key="items", **kwargs) -> List[dict]:
- """Get all items from a paged list."""
- limit = 50
- offset = 0
- all_items = []
- while True:
- kwargs["limit"] = limit
- kwargs["offset"] = offset
- result = await self._get_data(endpoint, **kwargs)
- offset += limit
- if not result or key not in result or not result[key]:
- break
- for item in result[key]:
- item["position"] = len(all_items) + 1
- all_items.append(item)
- if len(result[key]) < limit:
- break
- return all_items
-
- async def _get_data(self, endpoint, tokeninfo: Optional[dict] = None, **kwargs):
- """Get data from api."""
- url = f"https://api.spotify.com/v1/{endpoint}"
- kwargs["market"] = "from_token"
- kwargs["country"] = "from_token"
- if tokeninfo is None:
- tokeninfo = await self.login()
- headers = {"Authorization": f'Bearer {tokeninfo["accessToken"]}'}
- async with self._throttler:
- time_start = time.time()
- try:
- async with self.mass.http_session.get(
- url, headers=headers, params=kwargs, verify_ssl=False, timeout=120
- ) as response:
- result = await response.json()
- if "error" in result or (
- "status" in result and "error" in result["status"]
- ):
- self.logger.error("%s - %s", endpoint, result)
- return None
- except (
- aiohttp.ContentTypeError,
- JSONDecodeError,
- ) as err:
- self.logger.error("%s - %s", endpoint, str(err))
- return None
- finally:
- self.logger.debug(
- "Processing GET/%s took %s seconds",
- endpoint,
- round(time.time() - time_start, 2),
- )
- return result
-
- async def _delete_data(self, endpoint, data=None, **kwargs):
- """Delete data from api."""
- url = f"https://api.spotify.com/v1/{endpoint}"
- token = await self.login()
- if not token:
- return None
- headers = {"Authorization": f'Bearer {token["accessToken"]}'}
- async with self.mass.http_session.delete(
- url, headers=headers, params=kwargs, json=data, verify_ssl=False
- ) as response:
- return await response.text()
-
- async def _put_data(self, endpoint, data=None, **kwargs):
- """Put data on api."""
- url = f"https://api.spotify.com/v1/{endpoint}"
- token = await self.login()
- if not token:
- return None
- headers = {"Authorization": f'Bearer {token["accessToken"]}'}
- async with self.mass.http_session.put(
- url, headers=headers, params=kwargs, json=data, verify_ssl=False
- ) as response:
- return await response.text()
-
- async def _post_data(self, endpoint, data=None, **kwargs):
- """Post data on api."""
- url = f"https://api.spotify.com/v1/{endpoint}"
- token = await self.login()
- if not token:
- return None
- headers = {"Authorization": f'Bearer {token["accessToken"]}'}
- async with self.mass.http_session.post(
- url, headers=headers, params=kwargs, json=data, verify_ssl=False
- ) as response:
- return await response.text()
-
- async def get_librespot_binary(self):
- """Find the correct librespot binary belonging to the platform."""
- if self._librespot_bin is not None:
- return self._librespot_bin
-
- async def check_librespot(librespot_path: str) -> str | None:
- try:
- librespot = await asyncio.create_subprocess_exec(
- *[librespot_path, "--check"], stdout=asyncio.subprocess.PIPE
- )
- stdout, _ = await librespot.communicate()
- if (
- librespot.returncode == 0
- and b"ok spotty" in stdout
- and b"using librespot" in stdout
- ):
- self._librespot_bin = librespot_path
- return librespot_path
- except OSError:
- return None
-
- base_path = os.path.join(os.path.dirname(__file__), "librespot")
- if platform.system() == "Windows":
- if librespot := await check_librespot(
- os.path.join(base_path, "windows", "librespot.exe")
- ):
- return librespot
- if platform.system() == "Darwin":
- # macos binary is x86_64 intel
- if librespot := await check_librespot(
- os.path.join(base_path, "osx", "librespot")
- ):
- return librespot
-
- if platform.system() == "FreeBSD":
- # FreeBSD binary is x86_64 intel
- if librespot := await check_librespot(
- os.path.join(base_path, "freebsd", "librespot")
- ):
- return librespot
-
- if platform.system() == "Linux":
- architecture = platform.machine()
- if architecture in ["AMD64", "x86_64"]:
- # generic linux x86_64 binary
- if librespot := await check_librespot(
- os.path.join(
- base_path,
- "linux",
- "librespot-x86_64",
- )
- ):
- return librespot
-
- # arm architecture... try all options one by one...
- for arch in ["aarch64", "armv7", "armhf", "arm"]:
- if librespot := await check_librespot(
- os.path.join(
- base_path,
- "linux",
- f"librespot-{arch}",
- )
- ):
- return librespot
-
- raise RuntimeError(
- f"Unable to locate Libespot for {platform.system()} ({platform.machine()})"
- )
+++ /dev/null
-"""Package with Tune-In Music provider."""
-
-from .tunein import TuneInProvider # noqa
+++ /dev/null
-"""Tune-In musicprovider support for MusicAssistant."""
-from __future__ import annotations
-
-from time import time
-from typing import AsyncGenerator, List, Optional, Tuple
-
-from asyncio_throttle import Throttler
-
-from music_assistant.helpers.audio import get_radio_stream
-from music_assistant.helpers.playlists import fetch_playlist
-from music_assistant.helpers.tags import parse_tags
-from music_assistant.helpers.util import create_sort_name
-from music_assistant.models.enums import MusicProviderFeature, ProviderType
-from music_assistant.models.errors import LoginFailed, MediaNotFoundError
-from music_assistant.models.media_items import (
- ContentType,
- ImageType,
- MediaItemImage,
- MediaType,
- ProviderMapping,
- Radio,
- StreamDetails,
-)
-from music_assistant.models.music_provider import MusicProvider
-
-
-class TuneInProvider(MusicProvider):
- """Provider implementation for Tune In."""
-
- _attr_type = ProviderType.TUNEIN
- _attr_name = "Tune-in Radio"
- _throttler = Throttler(rate_limit=1, period=1)
-
- @property
- def supported_features(self) -> Tuple[MusicProviderFeature]:
- """Return the features supported by this MusicProvider."""
- return (
- MusicProviderFeature.LIBRARY_RADIOS,
- MusicProviderFeature.BROWSE,
- )
-
- async def setup(self) -> bool:
- """Handle async initialization of the provider."""
- if not self.config.enabled:
- return False
- if not self.config.username:
- raise LoginFailed("Username is invalid")
- if "@" in self.config.username:
- self.logger.warning(
- "Emailadress detected instead of username, "
- "it is advised to use the tunein username instead of email."
- )
- return True
-
- async def get_library_radios(self) -> AsyncGenerator[Radio, None]:
- """Retrieve library/subscribed radio stations from the provider."""
-
- async def parse_items(
- items: List[dict], folder: str = None
- ) -> AsyncGenerator[Radio, None]:
- for item in items:
- item_type = item.get("type", "")
- if item_type == "audio":
- if "preset_id" not in item:
- continue
- # each radio station can have multiple streams add each one as different quality
- stream_info = await self.__get_data(
- "Tune.ashx", id=item["preset_id"]
- )
- for stream in stream_info["body"]:
- yield await self._parse_radio(item, stream, folder)
- elif item_type == "link" and item.get("item") == "url":
- # custom url
- yield await self._parse_radio(item)
- elif item_type == "link":
- # stations are in sublevel (new style)
- if sublevel := await self.__get_data(item["URL"], render="json"):
- async for subitem in parse_items(
- sublevel["body"], item["text"]
- ):
- yield subitem
- elif item.get("children"):
- # stations are in sublevel (old style ?)
- async for subitem in parse_items(item["children"], item["text"]):
- yield subitem
-
- data = await self.__get_data("Browse.ashx", c="presets")
- if data and "body" in data:
- async for item in parse_items(data["body"]):
- yield item
-
- async def get_radio(self, prov_radio_id: str) -> Radio:
- """Get radio station details."""
- if not prov_radio_id.startswith("http"):
- prov_radio_id, media_type = prov_radio_id.split("--", 1)
- params = {"c": "composite", "detail": "listing", "id": prov_radio_id}
- result = await self.__get_data("Describe.ashx", **params)
- if result and result.get("body") and result["body"][0].get("children"):
- item = result["body"][0]["children"][0]
- stream_info = await self.__get_data("Tune.ashx", id=prov_radio_id)
- for stream in stream_info["body"]:
- if stream["media_type"] != media_type:
- continue
- return await self._parse_radio(item, stream)
- # fallback - e.g. for handle custom urls ...
- async for radio in self.get_library_radios():
- if radio.item_id == prov_radio_id:
- return radio
- return None
-
- async def _parse_radio(
- self, details: dict, stream: Optional[dict] = None, folder: Optional[str] = None
- ) -> Radio:
- """Parse Radio object from json obj returned from api."""
- if "name" in details:
- name = details["name"]
- else:
- # parse name from text attr
- name = details["text"]
- if " | " in name:
- name = name.split(" | ")[1]
- name = name.split(" (")[0]
-
- if stream is None:
- # custom url (no stream object present)
- url = details["URL"]
- item_id = url
- media_info = await parse_tags(url)
- content_type = ContentType.try_parse(media_info.format)
- bit_rate = media_info.bit_rate
- else:
- url = stream["url"]
- item_id = f'{details["preset_id"]}--{stream["media_type"]}'
- content_type = ContentType.try_parse(stream["media_type"])
- bit_rate = stream.get("bitrate", 128) # TODO !
-
- radio = Radio(item_id=item_id, provider=self.type, name=name)
- radio.add_provider_mapping(
- ProviderMapping(
- item_id=item_id,
- provider_type=self.type,
- provider_id=self.id,
- content_type=content_type,
- bit_rate=bit_rate,
- details=url,
- )
- )
- # preset number is used for sorting (not present at stream time)
- preset_number = details.get("preset_number")
- if preset_number and folder:
- radio.sort_name = f'{folder}-{details["preset_number"]}'
- elif preset_number:
- radio.sort_name = details["preset_number"]
- radio.sort_name += create_sort_name(name)
- if "text" in details:
- radio.metadata.description = details["text"]
- # images
- if img := details.get("image"):
- radio.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
- if img := details.get("logo"):
- radio.metadata.images = [MediaItemImage(ImageType.LOGO, img)]
- return radio
-
- async def get_stream_details(self, item_id: str) -> StreamDetails:
- """Get streamdetails for a radio station."""
- if item_id.startswith("http"):
- # custom url
- return StreamDetails(
- provider=self.type,
- item_id=item_id,
- content_type=ContentType.UNKNOWN,
- media_type=MediaType.RADIO,
- data=item_id,
- )
- item_id, media_type = item_id.split("--", 1)
- stream_info = await self.__get_data("Tune.ashx", id=item_id)
- for stream in stream_info["body"]:
-
- if stream["media_type"] != media_type:
- continue
- # check if the radio stream is not a playlist
- url = stream["url"]
- if url.endswith("m3u8") or url.endswith("m3u") or url.endswith("pls"):
- playlist = await fetch_playlist(self.mass, url)
- url = playlist[0]
- return StreamDetails(
- provider=self.type,
- item_id=item_id,
- content_type=ContentType(stream["media_type"]),
- media_type=MediaType.RADIO,
- data=url,
- expires=time() + 24 * 3600,
- )
- raise MediaNotFoundError(f"Unable to retrieve stream details for {item_id}")
-
- async def get_audio_stream(
- self, streamdetails: StreamDetails, seek_position: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Return the audio stream for the provider item."""
- async for chunk in get_radio_stream(
- self.mass, streamdetails.data, streamdetails
- ):
- yield chunk
-
- async def __get_data(self, endpoint: str, **kwargs):
- """Get data from api."""
- if endpoint.startswith("http"):
- url = endpoint
- else:
- url = f"https://opml.radiotime.com/{endpoint}"
- kwargs["formats"] = "ogg,aac,wma,mp3"
- kwargs["username"] = self.config.username
- kwargs["partnerId"] = "1"
- kwargs["render"] = "json"
- async with self._throttler:
- async with self.mass.http_session.get(
- url, params=kwargs, verify_ssl=False
- ) as response:
- result = await response.json()
- if not result or "error" in result:
- self.logger.error(url)
- self.logger.error(kwargs)
- result = None
- return result
+++ /dev/null
-"""Package with URL Music provider."""
-
-from .url import URLProvider # noqa
+++ /dev/null
-"""Basic provider allowing for external URL's to be streamed."""
-from __future__ import annotations
-
-import os
-from typing import AsyncGenerator, Tuple
-
-from music_assistant.helpers.audio import (
- get_file_stream,
- get_http_stream,
- get_radio_stream,
-)
-from music_assistant.helpers.playlists import fetch_playlist
-from music_assistant.helpers.tags import AudioTags, parse_tags
-from music_assistant.models.config import MusicProviderConfig
-from music_assistant.models.enums import ContentType, ImageType, MediaType, ProviderType
-from music_assistant.models.media_items import (
- Artist,
- MediaItemImage,
- MediaItemType,
- ProviderMapping,
- Radio,
- StreamDetails,
- Track,
-)
-from music_assistant.models.music_provider import MusicProvider
-
-PROVIDER_CONFIG = MusicProviderConfig(ProviderType.URL)
-
-# pylint: disable=arguments-renamed
-
-
-class URLProvider(MusicProvider):
- """Music Provider for manual URL's/files added to the queue."""
-
- _attr_name: str = "URL"
- _attr_type: ProviderType = ProviderType.URL
- _attr_available: bool = True
- _full_url = {}
-
- async def setup(self) -> bool:
- """
- Handle async initialization of the provider.
-
- Called when provider is registered.
- """
- return True
-
- async def get_track(self, prov_track_id: str) -> Track:
- """Get full track details by id."""
- return await self.parse_item(prov_track_id)
-
- async def get_radio(self, prov_radio_id: str) -> Radio:
- """Get full radio details by id."""
- return await self.parse_item(prov_radio_id)
-
- async def get_artist(self, prov_artist_id: str) -> Track:
- """Get full artist details by id."""
- artist = prov_artist_id
- # this is here for compatibility reasons only
- return Artist(
- artist,
- self.type,
- artist,
- provider_mappings={
- ProviderMapping(artist, self.type, self.id, available=False)
- },
- )
-
- async def get_item(self, media_type: MediaType, prov_item_id: str) -> MediaItemType:
- """Get single MediaItem from provider."""
- if media_type == MediaType.ARTIST:
- return await self.get_artist(prov_item_id)
- if media_type == MediaType.TRACK:
- return await self.get_track(prov_item_id)
- if media_type == MediaType.RADIO:
- return await self.get_radio(prov_item_id)
- if media_type == MediaType.UNKNOWN:
- return await self.parse_item(prov_item_id)
- raise NotImplementedError
-
- async def parse_item(
- self, item_id_or_url: str, force_refresh: bool = False
- ) -> Track | Radio:
- """Parse plain URL to MediaItem of type Radio or Track."""
- item_id, url, media_info = await self._get_media_info(
- item_id_or_url, force_refresh
- )
- is_radio = media_info.get("icy-name") or not media_info.duration
- if is_radio:
- # treat as radio
- media_item = Radio(
- item_id=item_id,
- provider=self.type,
- name=media_info.get("icy-name") or media_info.title,
- )
- else:
- media_item = Track(
- item_id=item_id,
- provider=self.type,
- name=media_info.title,
- duration=int(media_info.duration or 0),
- artists=[
- await self.get_artist(artist) for artist in media_info.artists
- ],
- )
-
- media_item.provider_mappings = {
- ProviderMapping(
- item_id=item_id,
- provider_type=self.type,
- provider_id=self.id,
- content_type=ContentType.try_parse(media_info.format),
- sample_rate=media_info.sample_rate,
- bit_depth=media_info.bits_per_sample,
- bit_rate=media_info.bit_rate,
- )
- }
- if media_info.has_cover_image:
- media_item.metadata.images = [MediaItemImage(ImageType.THUMB, url, True)]
- return media_item
-
- async def _get_media_info(
- self, item_id_or_url: str, force_refresh: bool = False
- ) -> Tuple[str, str, AudioTags]:
- """Retrieve (cached) mediainfo for url."""
- # check if the radio stream is not a playlist
- if (
- item_id_or_url.endswith("m3u8")
- or item_id_or_url.endswith("m3u")
- or item_id_or_url.endswith("pls")
- ):
- playlist = await fetch_playlist(self.mass, item_id_or_url)
- url = playlist[0]
- item_id = item_id_or_url
- self._full_url[item_id] = url
- elif "?" in item_id_or_url or "&" in item_id_or_url:
- # store the 'real' full url to be picked up later
- # this makes sure that we're not storing any temporary data like auth keys etc
- # a request for an url mediaitem always passes here first before streamdetails
- url = item_id_or_url
- item_id = item_id_or_url.split("?")[0].split("&")[0]
- self._full_url[item_id] = url
- else:
- url = self._full_url.get(item_id_or_url, item_id_or_url)
- item_id = item_id_or_url
- cache_key = f"{self.type.value}.media_info.{item_id}"
- # do we have some cached info for this url ?
- cached_info = await self.mass.cache.get(cache_key)
- if cached_info and not force_refresh:
- media_info = AudioTags.parse(cached_info)
- else:
- # parse info with ffprobe (and store in cache)
- media_info = await parse_tags(url)
- if "authSig" in url:
- media_info.has_cover_image = False
- await self.mass.cache.set(cache_key, media_info.raw)
- return (item_id, url, media_info)
-
- async def get_stream_details(self, item_id: str) -> StreamDetails | None:
- """Get streamdetails for a track/radio."""
- item_id, url, media_info = await self._get_media_info(item_id)
- is_radio = media_info.get("icy-name") or not media_info.duration
- return StreamDetails(
- provider=self.type,
- item_id=item_id,
- content_type=ContentType.try_parse(media_info.format),
- media_type=MediaType.RADIO if is_radio else MediaType.TRACK,
- sample_rate=media_info.sample_rate,
- bit_depth=media_info.bits_per_sample,
- direct=None if is_radio else url,
- data=url,
- )
-
- async def get_audio_stream(
- self, streamdetails: StreamDetails, seek_position: int = 0
- ) -> AsyncGenerator[bytes, None]:
- """Return the audio stream for the provider item."""
- if streamdetails.media_type == MediaType.RADIO:
- # radio stream url
- async for chunk in get_radio_stream(
- self.mass, streamdetails.data, streamdetails
- ):
- yield chunk
- elif os.path.isfile(streamdetails.data):
- # local file
- async for chunk in get_file_stream(
- self.mass, streamdetails.data, streamdetails, seek_position
- ):
- yield chunk
- else:
- # regular stream url (without icy meta)
- async for chunk in get_http_stream(
- self.mass, streamdetails.data, streamdetails, seek_position
- ):
- yield chunk
+++ /dev/null
-"""Package with Youtube Music provider."""
-
-from .ytmusic import YoutubeMusicProvider # noqa
+++ /dev/null
-"""
-Helper module for parsing the Youtube Music API.
-
-This helpers file is an async wrapper around the excellent ytmusicapi package.
-While the ytmusicapi package does an excellent job at parsing the Youtube Music results,
-it is unfortunately not async, which is required for Music Assistant to run smoothly.
-This also nicely separates the parsing logic from the Youtube Music provider logic.
-"""
-
-import asyncio
-import json
-from time import time
-from typing import Dict, List
-
-import ytmusicapi
-
-
-async def get_artist(prov_artist_id: str) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_artist function."""
-
- def _get_artist():
- ytm = ytmusicapi.YTMusic()
- try:
- artist = ytm.get_artist(channelId=prov_artist_id)
- # ChannelId can sometimes be different and original ID is not part of the response
- artist["channelId"] = prov_artist_id
- except KeyError:
- user = ytm.get_user(channelId=prov_artist_id)
- artist = {"channelId": prov_artist_id, "name": user["name"]}
- return artist
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_artist)
-
-
-async def get_album(prov_album_id: str) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_album function."""
-
- def _get_album():
- ytm = ytmusicapi.YTMusic()
- return ytm.get_album(browseId=prov_album_id)
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_album)
-
-
-async def get_playlist(
- prov_playlist_id: str, headers: Dict[str, str], username: str
-) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_playlist function."""
-
- def _get_playlist():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- playlist = ytm.get_playlist(playlistId=prov_playlist_id)
- playlist["checksum"] = get_playlist_checksum(playlist)
- return playlist
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_playlist)
-
-
-async def get_track(prov_track_id: str) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_playlist function."""
-
- def _get_song():
- ytm = ytmusicapi.YTMusic()
- track_obj = ytm.get_song(videoId=prov_track_id)
- track = {}
- track["videoId"] = track_obj["videoDetails"]["videoId"]
- track["title"] = track_obj["videoDetails"]["title"]
- track["artists"] = [
- {
- "channelId": track_obj["videoDetails"]["channelId"],
- "name": track_obj["videoDetails"]["author"],
- }
- ]
- track["duration"] = track_obj["videoDetails"]["lengthSeconds"]
- track["thumbnails"] = track_obj["microformat"]["microformatDataRenderer"][
- "thumbnail"
- ]["thumbnails"]
- track["isAvailable"] = track_obj["playabilityStatus"]["status"] == "OK"
- return track
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_song)
-
-
-async def get_library_artists(headers: Dict[str, str], username: str) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_library_artists function."""
-
- def _get_library_artists():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- artists = ytm.get_library_subscriptions(limit=9999)
- # Sync properties with uniformal artist object
- for artist in artists:
- artist["id"] = artist["browseId"]
- artist["name"] = artist["artist"]
- del artist["browseId"]
- del artist["artist"]
- return artists
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_library_artists)
-
-
-async def get_library_albums(headers: Dict[str, str], username: str) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_library_albums function."""
-
- def _get_library_albums():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- return ytm.get_library_albums(limit=9999)
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_library_albums)
-
-
-async def get_library_playlists(
- headers: Dict[str, str], username: str
-) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_library_playlists function."""
-
- def _get_library_playlists():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- playlists = ytm.get_library_playlists(limit=9999)
- # Sync properties with uniformal playlist object
- for playlist in playlists:
- playlist["id"] = playlist["playlistId"]
- del playlist["playlistId"]
- playlist["checksum"] = get_playlist_checksum(playlist)
- return playlists
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_library_playlists)
-
-
-async def get_library_tracks(headers: Dict[str, str], username: str) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi get_library_tracks function."""
-
- def _get_library_tracks():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- tracks = ytm.get_library_songs(limit=9999)
- return tracks
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_library_tracks)
-
-
-async def library_add_remove_artist(
- headers: Dict[str, str], prov_artist_id: str, add: bool = True, username: str = None
-) -> bool:
- """Add or remove an artist to the user's library."""
-
- def _library_add_remove_artist():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- if add:
- return "actions" in ytm.subscribe_artists(channelIds=[prov_artist_id])
- if not add:
- return "actions" in ytm.unsubscribe_artists(channelIds=[prov_artist_id])
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _library_add_remove_artist)
-
-
-async def library_add_remove_album(
- headers: Dict[str, str], prov_item_id: str, add: bool = True, username: str = None
-) -> bool:
- """Add or remove an album or playlist to the user's library."""
- album = await get_album(prov_album_id=prov_item_id)
-
- def _library_add_remove_album():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- playlist_id = album["audioPlaylistId"]
- if add:
- return ytm.rate_playlist(playlist_id, "LIKE")
- if not add:
- return ytm.rate_playlist(playlist_id, "INDIFFERENT")
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _library_add_remove_album)
-
-
-async def library_add_remove_playlist(
- headers: Dict[str, str], prov_item_id: str, add: bool = True, username: str = None
-) -> bool:
- """Add or remove an album or playlist to the user's library."""
-
- def _library_add_remove_playlist():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- if add:
- return "actions" in ytm.rate_playlist(prov_item_id, "LIKE")
- if not add:
- return "actions" in ytm.rate_playlist(prov_item_id, "INDIFFERENT")
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _library_add_remove_playlist)
-
-
-async def add_remove_playlist_tracks(
- headers: Dict[str, str],
- prov_playlist_id: str,
- prov_track_ids: List[str],
- add: bool,
- username: str = None,
-) -> bool:
- """Async wrapper around adding/removing tracks to a playlist."""
-
- def _add_playlist_tracks():
- user = username if is_brand_account(username) else None
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- if add:
- return ytm.add_playlist_items(
- playlistId=prov_playlist_id, videoIds=prov_track_ids
- )
- if not add:
- return ytm.remove_playlist_items(
- playlistId=prov_playlist_id, videos=prov_track_ids
- )
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _add_playlist_tracks)
-
-
-async def get_song_radio_tracks(
- headers: Dict[str, str], username: str, prov_item_id: str, limit=25
-) -> Dict[str, str]:
- """Async wrapper around the ytmusicapi radio function."""
- user = username if is_brand_account(username) else None
-
- def _get_song_radio_tracks():
- ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
- playlist_id = f"RDAMVM{prov_item_id}"
- result = ytm.get_watch_playlist(
- videoId=prov_item_id, playlistId=playlist_id, limit=limit
- )
- # Replace inconsistensies for easier parsing
- for track in result["tracks"]:
- if track.get("thumbnail"):
- track["thumbnails"] = track["thumbnail"]
- del track["thumbnail"]
- if track.get("length"):
- track["duration"] = get_sec(track["length"])
- return result
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _get_song_radio_tracks)
-
-
-async def search(query: str, ytm_filter: str = None, limit: int = 20) -> List[Dict]:
- """Async wrapper around the ytmusicapi search function."""
-
- def _search():
- ytm = ytmusicapi.YTMusic()
- results = ytm.search(query=query, filter=ytm_filter, limit=limit)
- # Sync result properties with uniformal objects
- for result in results:
- if result["resultType"] == "artist":
- result["id"] = result["browseId"]
- result["name"] = result["artist"]
- del result["browseId"]
- del result["artist"]
- elif result["resultType"] == "playlist":
- if "playlistId" in result:
- result["id"] = result["playlistId"]
- del result["playlistId"]
- elif "browseId" in result:
- result["id"] = result["browseId"]
- del result["browseId"]
- return results
-
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, _search)
-
-
-def get_playlist_checksum(playlist_obj: dict) -> str:
- """Try to calculate a checksum so we can detect changes in a playlist."""
- for key in ("duration_seconds", "trackCount"):
- if key in playlist_obj:
- return playlist_obj[key]
- return str(int(time()))
-
-
-def is_brand_account(username: str) -> bool:
- """Check if the provided username is a brand-account."""
- return len(username) == 21 and username.isdigit()
-
-
-def get_sec(time_str):
- """Get seconds from time."""
- parts = time_str.split(":")
- if len(parts) == 3:
- return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])
- if len(parts) == 2:
- return int(parts[0]) * 60 + int(parts[1])
- return 0
+++ /dev/null
-"""Youtube Music support for MusicAssistant."""
-import re
-from operator import itemgetter
-from time import time
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
-from urllib.parse import unquote
-
-import pytube
-import ytmusicapi
-
-from music_assistant.models.enums import MusicProviderFeature, ProviderType
-from music_assistant.models.errors import (
- InvalidDataError,
- LoginFailed,
- MediaNotFoundError,
-)
-from music_assistant.models.media_items import (
- Album,
- AlbumType,
- Artist,
- ContentType,
- ImageType,
- MediaItemImage,
- MediaItemType,
- MediaType,
- Playlist,
- ProviderMapping,
- StreamDetails,
- Track,
-)
-from music_assistant.models.music_provider import MusicProvider
-from music_assistant.music_providers.ytmusic.helpers import (
- add_remove_playlist_tracks,
- get_album,
- get_artist,
- get_library_albums,
- get_library_artists,
- get_library_playlists,
- get_library_tracks,
- get_playlist,
- get_song_radio_tracks,
- get_track,
- library_add_remove_album,
- library_add_remove_artist,
- library_add_remove_playlist,
- search,
-)
-
-YT_DOMAIN = "https://www.youtube.com"
-YTM_DOMAIN = "https://music.youtube.com"
-YTM_BASE_URL = f"{YTM_DOMAIN}/youtubei/v1/"
-
-
-class YoutubeMusicProvider(MusicProvider):
- """Provider for Youtube Music."""
-
- _attr_type = ProviderType.YTMUSIC
- _attr_name = "Youtube Music"
- _headers = None
- _context = None
- _cookies = None
- _signature_timestamp = 0
- _cipher = None
-
- @property
- def supported_features(self) -> Tuple[MusicProviderFeature]:
- """Return the features supported by this MusicProvider."""
- return (
- MusicProviderFeature.LIBRARY_ARTISTS,
- MusicProviderFeature.LIBRARY_ALBUMS,
- MusicProviderFeature.LIBRARY_TRACKS,
- MusicProviderFeature.LIBRARY_PLAYLISTS,
- MusicProviderFeature.BROWSE,
- MusicProviderFeature.SEARCH,
- MusicProviderFeature.ARTIST_ALBUMS,
- MusicProviderFeature.ARTIST_TOPTRACKS,
- MusicProviderFeature.SIMILAR_TRACKS,
- )
-
- async def setup(self) -> bool:
- """Set up the YTMusic provider."""
- if not self.config.enabled:
- return False
- if not self.config.username or not self.config.password:
- raise LoginFailed("Invalid login credentials")
- await self._initialize_headers(cookie=self.config.password)
- await self._initialize_context()
- self._cookies = {"CONSENT": "YES+1"}
- self._signature_timestamp = await self._get_signature_timestamp()
- return True
-
- async def search(
- self, search_query: str, media_types=Optional[List[MediaType]], limit: int = 5
- ) -> List[MediaItemType]:
- """
- Perform search on musicprovider.
-
- :param search_query: Search query.
- :param media_types: A list of media_types to include. All types if None.
- :param limit: Number of items to return in the search (per type).
- """
- ytm_filter = None
- if len(media_types) == 1:
- # YTM does not support multiple searchtypes, falls back to all if no type given
- if media_types[0] == MediaType.ARTIST:
- ytm_filter = "artists"
- if media_types[0] == MediaType.ALBUM:
- ytm_filter = "albums"
- if media_types[0] == MediaType.TRACK:
- ytm_filter = "songs"
- if media_types[0] == MediaType.PLAYLIST:
- ytm_filter = "playlists"
- results = await search(query=search_query, ytm_filter=ytm_filter, limit=limit)
- parsed_results = []
- for result in results:
- if result["resultType"] == "artist":
- parsed_results.append(await self._parse_artist(result))
- elif result["resultType"] == "album":
- parsed_results.append(await self._parse_album(result))
- elif result["resultType"] == "playlist":
- parsed_results.append(await self._parse_playlist(result))
- elif result["resultType"] == "song":
- if track := await self._parse_track(result):
- parsed_results.append(track)
- return parsed_results
-
- async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
- """Retrieve all library artists from Youtube Music."""
- artists_obj = await get_library_artists(
- headers=self._headers, username=self.config.username
- )
- for artist in artists_obj:
- yield await self._parse_artist(artist)
-
- async def get_library_albums(self) -> AsyncGenerator[Album, None]:
- """Retrieve all library albums from Youtube Music."""
- albums_obj = await get_library_albums(
- headers=self._headers, username=self.config.username
- )
- for album in albums_obj:
- yield await self._parse_album(album, album["browseId"])
-
- async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
- """Retrieve all library playlists from the provider."""
- playlists_obj = await get_library_playlists(
- headers=self._headers, username=self.config.username
- )
- for playlist in playlists_obj:
- yield await self._parse_playlist(playlist)
-
- async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
- """Retrieve library tracks from Youtube Music."""
- tracks_obj = await get_library_tracks(
- headers=self._headers, username=self.config.username
- )
- for track in tracks_obj:
- # Library tracks sometimes do not have a valid artist id
- # In that case, call the API for track details based on track id
- try:
- yield await self._parse_track(track)
- except InvalidDataError:
- track = await self.get_track(track["videoId"])
- yield track
-
- async def get_album(self, prov_album_id) -> Album:
- """Get full album details by id."""
- album_obj = await get_album(prov_album_id=prov_album_id)
- return (
- await self._parse_album(album_obj=album_obj, album_id=prov_album_id)
- if album_obj
- else None
- )
-
- async def get_album_tracks(self, prov_album_id: str) -> List[Track]:
- """Get album tracks for given album id."""
- album_obj = await get_album(prov_album_id=prov_album_id)
- if not album_obj.get("tracks"):
- return []
- tracks = []
- for idx, track_obj in enumerate(album_obj["tracks"], 1):
- track = await self._parse_track(track_obj=track_obj)
- track.disc_number = 0
- track.track_number = idx
- tracks.append(track)
- return tracks
-
- async def get_artist(self, prov_artist_id) -> Artist:
- """Get full artist details by id."""
- artist_obj = await get_artist(prov_artist_id=prov_artist_id)
- return await self._parse_artist(artist_obj=artist_obj) if artist_obj else None
-
- async def get_track(self, prov_track_id) -> Track:
- """Get full track details by id."""
- track_obj = await get_track(prov_track_id=prov_track_id)
- return await self._parse_track(track_obj)
-
- async def get_playlist(self, prov_playlist_id) -> Playlist:
- """Get full playlist details by id."""
- playlist_obj = await get_playlist(
- prov_playlist_id=prov_playlist_id,
- headers=self._headers,
- username=self.config.username,
- )
- return await self._parse_playlist(playlist_obj)
-
- async def get_playlist_tracks(self, prov_playlist_id) -> List[Track]:
- """Get all playlist tracks for given playlist id."""
- playlist_obj = await get_playlist(
- prov_playlist_id=prov_playlist_id,
- headers=self._headers,
- username=self.config.username,
- )
- if "tracks" not in playlist_obj:
- return []
- tracks = []
- for index, track in enumerate(playlist_obj["tracks"]):
- if track["isAvailable"]:
- # Playlist tracks sometimes do not have a valid artist id
- # In that case, call the API for track details based on track id
- try:
- track = await self._parse_track(track)
- if track:
- track.position = index
- tracks.append(track)
- except InvalidDataError:
- track = await self.get_track(track["videoId"])
- if track:
- track.position = index
- tracks.append(track)
- return tracks
-
- async def get_artist_albums(self, prov_artist_id) -> List[Album]:
- """Get a list of albums for the given artist."""
- artist_obj = await get_artist(prov_artist_id=prov_artist_id)
- if "albums" in artist_obj and "results" in artist_obj["albums"]:
- albums = []
- for album_obj in artist_obj["albums"]["results"]:
- if "artists" not in album_obj:
- album_obj["artists"] = [
- {"id": artist_obj["channelId"], "name": artist_obj["name"]}
- ]
- albums.append(await self._parse_album(album_obj, album_obj["browseId"]))
- return albums
- return []
-
- async def get_artist_toptracks(self, prov_artist_id) -> List[Track]:
- """Get a list of 25 most popular tracks for the given artist."""
- artist_obj = await get_artist(prov_artist_id=prov_artist_id)
- if artist_obj.get("songs") and artist_obj["songs"].get("browseId"):
- prov_playlist_id = artist_obj["songs"]["browseId"]
- playlist_tracks = await self.get_playlist_tracks(
- prov_playlist_id=prov_playlist_id
- )
- return playlist_tracks[:25]
- return []
-
- async def library_add(self, prov_item_id, media_type: MediaType) -> None:
- """Add an item to the library."""
- result = False
- if media_type == MediaType.ARTIST:
- result = await library_add_remove_artist(
- headers=self._headers,
- prov_artist_id=prov_item_id,
- add=True,
- username=self.config.username,
- )
- elif media_type == MediaType.ALBUM:
- result = await library_add_remove_album(
- headers=self._headers,
- prov_item_id=prov_item_id,
- add=True,
- username=self.config.username,
- )
- elif media_type == MediaType.PLAYLIST:
- result = await library_add_remove_playlist(
- headers=self._headers,
- prov_item_id=prov_item_id,
- add=True,
- username=self.config.username,
- )
- elif media_type == MediaType.TRACK:
- raise NotImplementedError
- return result
-
- async def library_remove(self, prov_item_id, media_type: MediaType):
- """Remove an item from the library."""
- result = False
- if media_type == MediaType.ARTIST:
- result = await library_add_remove_artist(
- headers=self._headers,
- prov_artist_id=prov_item_id,
- add=False,
- username=self.config.username,
- )
- elif media_type == MediaType.ALBUM:
- result = await library_add_remove_album(
- headers=self._headers,
- prov_item_id=prov_item_id,
- add=False,
- username=self.config.username,
- )
- elif media_type == MediaType.PLAYLIST:
- result = await library_add_remove_playlist(
- headers=self._headers,
- prov_item_id=prov_item_id,
- add=False,
- username=self.config.username,
- )
- elif media_type == MediaType.TRACK:
- raise NotImplementedError
- return result
-
- async def add_playlist_tracks(
- self, prov_playlist_id: str, prov_track_ids: List[str]
- ) -> None:
- """Add track(s) to playlist."""
- return await add_remove_playlist_tracks(
- headers=self._headers,
- prov_playlist_id=prov_playlist_id,
- prov_track_ids=prov_track_ids,
- add=True,
- username=self.config.username,
- )
-
- async def remove_playlist_tracks(
- self, prov_playlist_id: str, positions_to_remove: Tuple[int]
- ) -> None:
- """Remove track(s) from playlist."""
- playlist_obj = await get_playlist(
- prov_playlist_id=prov_playlist_id,
- headers=self._headers,
- username=self.config.username,
- )
- if "tracks" not in playlist_obj:
- return
- tracks_to_delete = []
- for index, track in enumerate(playlist_obj["tracks"]):
- if index in positions_to_remove:
- # YT needs both the videoId and the setVideoId in order to remove
- # the track. Thus, we need to obtain the playlist details and
- # grab the info from there.
- tracks_to_delete.append(
- {"videoId": track["videoId"], "setVideoId": track["setVideoId"]}
- )
-
- return await add_remove_playlist_tracks(
- headers=self._headers,
- prov_playlist_id=prov_playlist_id,
- prov_track_ids=tracks_to_delete,
- add=False,
- username=self.config.username,
- )
-
- async def get_similar_tracks(self, prov_track_id, limit=25) -> List[Track]:
- """Retrieve a dynamic list of tracks based on the provided item."""
- result = []
- result = await get_song_radio_tracks(
- headers=self._headers,
- username=self.config.username,
- prov_item_id=prov_track_id,
- limit=limit,
- )
- if "tracks" in result:
- tracks = []
- for track in result["tracks"]:
- # Playlist tracks sometimes do not have a valid artist id
- # In that case, call the API for track details based on track id
- try:
- track = await self._parse_track(track)
- if track:
- tracks.append(track)
- except InvalidDataError:
- track = await self.get_track(track["videoId"])
- if track:
- tracks.append(track)
- return tracks
- return []
-
- async def get_stream_details(self, item_id: str) -> StreamDetails:
- """Return the content details for the given track when it will be streamed."""
- data = {
- "playbackContext": {
- "contentPlaybackContext": {
- "signatureTimestamp": self._signature_timestamp
- }
- },
- "video_id": item_id,
- }
- track_obj = await self._post_data("player", data=data)
- stream_format = await self._parse_stream_format(track_obj)
- url = await self._parse_stream_url(stream_format=stream_format, item_id=item_id)
- stream_details = StreamDetails(
- provider=self.type,
- item_id=item_id,
- content_type=ContentType.try_parse(stream_format["mimeType"]),
- direct=url,
- )
- if (
- track_obj["streamingData"].get("expiresInSeconds")
- and track_obj["streamingData"].get("expiresInSeconds").isdigit()
- ):
- stream_details.expires = time() + int(
- track_obj["streamingData"].get("expiresInSeconds")
- )
- if (
- stream_format.get("audioChannels")
- and str(stream_format.get("audioChannels")).isdigit()
- ):
- stream_details.channels = int(stream_format.get("audioChannels"))
- if (
- stream_format.get("audioSampleRate")
- and stream_format.get("audioSampleRate").isdigit()
- ):
- stream_details.sample_rate = int(stream_format.get("audioSampleRate"))
- return stream_details
-
- async def _post_data(self, endpoint: str, data: Dict[str, str], **kwargs):
- url = f"{YTM_BASE_URL}{endpoint}"
- data.update(self._context)
- async with self.mass.http_session.post(
- url,
- headers=self._headers,
- json=data,
- verify_ssl=False,
- cookies=self._cookies,
- ) as response:
- return await response.json()
-
- async def _get_data(self, url: str, params: Dict = None):
- async with self.mass.http_session.get(
- url, headers=self._headers, params=params, cookies=self._cookies
- ) as response:
- return await response.text()
-
- async def _initialize_headers(self, cookie: str) -> Dict[str, str]:
- """Return headers to include in the requests."""
- headers = {
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:72.0) Gecko/20100101 Firefox/72.0",
- "Accept": "*/*",
- "Accept-Language": "en-US,en;q=0.5",
- "Content-Type": "application/json",
- "X-Goog-AuthUser": "0",
- "x-origin": "https://music.youtube.com",
- "Cookie": cookie,
- }
- sapisid = ytmusicapi.helpers.sapisid_from_cookie(cookie)
- origin = headers.get("origin", headers.get("x-origin"))
- headers["Authorization"] = ytmusicapi.helpers.get_authorization(
- sapisid + " " + origin
- )
- self._headers = headers
-
- async def _initialize_context(self) -> Dict[str, str]:
- """Return a dict to use as a context in requests."""
- self._context = {
- "context": {
- "client": {"clientName": "WEB_REMIX", "clientVersion": "0.1"},
- "user": {},
- }
- }
-
- async def _parse_album(self, album_obj: dict, album_id: str = None) -> Album:
- """Parse a YT Album response to an Album model object."""
- album_id = album_id or album_obj.get("id") or album_obj.get("browseId")
- if "title" in album_obj:
- name = album_obj["title"]
- elif "name" in album_obj:
- name = album_obj["name"]
- album = Album(
- item_id=album_id,
- name=name,
- provider=self.type,
- )
- if album_obj.get("year") and album_obj["year"].isdigit():
- album.year = album_obj["year"]
- if "thumbnails" in album_obj:
- album.metadata.images = await self._parse_thumbnails(
- album_obj["thumbnails"]
- )
- if "description" in album_obj:
- album.metadata.description = unquote(album_obj["description"])
- if "artists" in album_obj:
- album.artists = [
- await self._parse_artist(artist)
- for artist in album_obj["artists"]
- # artist object may be missing an id
- # in that case its either a performer (like the composer) OR this
- # is a Various artists compilation album...
- if (artist.get("id") or artist["name"] == "Various Artists")
- ]
- if "type" in album_obj:
- if album_obj["type"] == "Single":
- album_type = AlbumType.SINGLE
- elif album_obj["type"] == "EP":
- album_type = AlbumType.EP
- elif album_obj["type"] == "Album":
- album_type = AlbumType.ALBUM
- else:
- album_type = AlbumType.UNKNOWN
- album.album_type = album_type
- album.add_provider_mapping(
- ProviderMapping(
- item_id=str(album_id), provider_type=self.type, provider_id=self.id
- )
- )
- return album
-
- async def _parse_artist(self, artist_obj: dict) -> Artist:
- """Parse a YT Artist response to Artist model object."""
- artist_id = None
- if "channelId" in artist_obj:
- artist_id = artist_obj["channelId"]
- elif "id" in artist_obj and artist_obj["id"]:
- artist_id = artist_obj["id"]
- elif artist_obj["name"] == "Various Artists":
- artist_id = "UCUTXlgdcKU5vfzFqHOWIvkA"
- if not artist_id:
- raise InvalidDataError("Artist does not have a valid ID")
- artist = Artist(item_id=artist_id, name=artist_obj["name"], provider=self.type)
- if "description" in artist_obj:
- artist.metadata.description = artist_obj["description"]
- if "thumbnails" in artist_obj and artist_obj["thumbnails"]:
- artist.metadata.images = await self._parse_thumbnails(
- artist_obj["thumbnails"]
- )
- artist.add_provider_mapping(
- ProviderMapping(
- item_id=str(artist_id),
- provider_type=self.type,
- provider_id=self.id,
- url=f"https://music.youtube.com/channel/{artist_id}",
- )
- )
- return artist
-
- async def _parse_playlist(self, playlist_obj: dict) -> Playlist:
- """Parse a YT Playlist response to a Playlist object."""
- playlist = Playlist(
- item_id=playlist_obj["id"], provider=self.type, name=playlist_obj["title"]
- )
- if "description" in playlist_obj:
- playlist.metadata.description = playlist_obj["description"]
- if "thumbnails" in playlist_obj and playlist_obj["thumbnails"]:
- playlist.metadata.images = await self._parse_thumbnails(
- playlist_obj["thumbnails"]
- )
- is_editable = False
- if playlist_obj.get("privacy") and playlist_obj.get("privacy") == "PRIVATE":
- is_editable = True
- playlist.is_editable = is_editable
- playlist.add_provider_mapping(
- ProviderMapping(
- item_id=playlist_obj["id"], provider_type=self.type, provider_id=self.id
- )
- )
- playlist.metadata.checksum = playlist_obj.get("checksum")
- return playlist
-
- async def _parse_track(self, track_obj: dict) -> Track:
- """Parse a YT Track response to a Track model object."""
- track = Track(
- item_id=track_obj["videoId"], provider=self.type, name=track_obj["title"]
- )
- if "artists" in track_obj:
- track.artists = [
- await self._parse_artist(artist)
- for artist in track_obj["artists"]
- if artist.get("id")
- or artist.get("channelId")
- or artist.get("name") == "Various Artists"
- ]
- if "thumbnails" in track_obj and track_obj["thumbnails"]:
- track.metadata.images = await self._parse_thumbnails(
- track_obj["thumbnails"]
- )
- if (
- track_obj.get("album")
- and track_obj.get("artists")
- and isinstance(track_obj.get("album"), dict)
- and track_obj["album"].get("id")
- ):
- album = track_obj["album"]
- album["artists"] = track_obj["artists"]
- track.album = await self._parse_album(album, album["id"])
- if "isExplicit" in track_obj:
- track.metadata.explicit = track_obj["isExplicit"]
- if "duration" in track_obj and str(track_obj["duration"]).isdigit():
- track.duration = int(track_obj["duration"])
- elif (
- "duration_seconds" in track_obj
- and str(track_obj["duration_seconds"]).isdigit()
- ):
- track.duration = int(track_obj["duration_seconds"])
- available = True
- if "isAvailable" in track_obj:
- available = track_obj["isAvailable"]
- track.add_provider_mapping(
- ProviderMapping(
- item_id=str(track_obj["videoId"]),
- provider_type=self.type,
- provider_id=self.id,
- available=available,
- content_type=ContentType.M4A,
- )
- )
- return track
-
- async def _get_signature_timestamp(self):
- """Get a signature timestamp required to generate valid stream URLs."""
- response = await self._get_data(url=YTM_DOMAIN)
- match = re.search(r'jsUrl"\s*:\s*"([^"]+)"', response)
- if match is None:
- # retry with youtube domain
- response = await self._get_data(url=YT_DOMAIN)
- match = re.search(r'jsUrl"\s*:\s*"([^"]+)"', response)
- if match is None:
- raise Exception("Could not identify the URL for base.js player.")
- url = YTM_DOMAIN + match.group(1)
- response = await self._get_data(url=url)
- match = re.search(r"signatureTimestamp[:=](\d+)", response)
- if match is None:
- raise Exception("Unable to identify the signatureTimestamp.")
- return int(match.group(1))
-
- async def _parse_stream_url(self, stream_format: dict, item_id: str) -> str:
- """Figure out the stream URL to use based on the YT track object."""
- url = None
- if stream_format.get("signatureCipher"):
- # Secured URL
- cipher_parts = {}
- for part in stream_format["signatureCipher"].split("&"):
- key, val = part.split("=", maxsplit=1)
- cipher_parts[key] = unquote(val)
- signature = await self._decipher_signature(
- ciphered_signature=cipher_parts["s"], item_id=item_id
- )
- url = cipher_parts["url"] + "&sig=" + signature
- elif stream_format.get("url"):
- # Non secured URL
- url = stream_format.get("url")
- return url
-
- @classmethod
- async def _parse_thumbnails(cls, thumbnails_obj: dict) -> List[MediaItemImage]:
- """Parse and sort a list of thumbnails and return the highest quality."""
- thumb = sorted(thumbnails_obj, key=itemgetter("width"), reverse=True)[0]
- return [MediaItemImage(ImageType.THUMB, thumb["url"])]
-
- @classmethod
- async def _parse_stream_format(cls, track_obj: dict) -> dict:
- """Grab the highest available audio stream from available streams."""
- stream_format = {}
- quality_mapper = {
- "AUDIO_QUALITY_LOW": 1,
- "AUDIO_QUALITY_MEDIUM": 2,
- "AUDIO_QUALITY_HIGH": 3,
- }
- for adaptive_format in track_obj["streamingData"]["adaptiveFormats"]:
- if adaptive_format["mimeType"].startswith("audio") and (
- not stream_format
- or quality_mapper.get(adaptive_format["audioQuality"], 0)
- > quality_mapper.get(stream_format["audioQuality"], 0)
- ):
- stream_format = adaptive_format
- if stream_format is None:
- raise MediaNotFoundError("No stream found for this track")
- return stream_format
-
- async def _decipher_signature(self, ciphered_signature: str, item_id: str):
- """Decipher the signature, required to build the Stream URL."""
-
- def _decipher():
- embed_url = f"https://www.youtube.com/embed/{item_id}"
- embed_html = pytube.request.get(embed_url)
- js_url = pytube.extract.js_url(embed_html)
- ytm_js = pytube.request.get(js_url)
- cipher = pytube.cipher.Cipher(js=ytm_js)
- return cipher
-
- if not self._cipher:
- self._cipher = await self.mass.loop.run_in_executor(None, _decipher)
- return self._cipher.get_signature(ciphered_signature)
--- /dev/null
+"""Music Assistant: The music library manager in python."""
+
+from .server import MusicAssistant # noqa
--- /dev/null
+"""Package with controllers."""
--- /dev/null
+"""Provides a simple stateless caching system."""
+from __future__ import annotations
+
+import asyncio
+import functools
+import json
+import logging
+import time
+from collections import OrderedDict
+from collections.abc import Iterator, MutableMapping
+from typing import TYPE_CHECKING, Any
+
+from music_assistant.constants import (
+ CONF_DB_CACHE,
+ DB_TABLE_CACHE,
+ DB_TABLE_SETTINGS,
+ DEFAULT_DB_CACHE,
+ ROOT_LOGGER_NAME,
+)
+from music_assistant.server.helpers.database import DatabaseConnection
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+LOGGER = logging.getLogger(f"{ROOT_LOGGER_NAME}.cache")
+SCHEMA_VERSION = 1
+
+
+class CacheController:
+ """Basic cache controller using both memory and database."""
+
+ database: DatabaseConnection | None = None
+
+ def __init__(self, mass: MusicAssistant) -> None:
+ """Initialize our caching class."""
+ self.mass = mass
+ self._mem_cache = MemoryCache(500)
+
+ async def setup(self) -> None:
+ """Async initialize of cache module."""
+ await self._setup_database()
+ self.__schedule_cleanup_task()
+
+ async def close(self) -> None:
+ """Cleanup on exit."""
+
+ async def get(self, cache_key: str, checksum: str | None = None, default=None):
+ """Get object from cache and return the results.
+
+ cache_key: the (unique) name of the cache object as reference
+ checksum: optional argument to check if the checksum in the
+ cacheobject matches the checksum provided
+ """
+ if not cache_key:
+ return None
+ cur_time = int(time.time())
+ if checksum is not None and not isinstance(checksum, str):
+ checksum = str(checksum)
+
+ # try memory cache first
+ cache_data = self._mem_cache.get(cache_key)
+ if cache_data and (not checksum or cache_data[1] == checksum) and cache_data[2] >= cur_time:
+ return cache_data[0]
+ # fall back to db cache
+ if (db_row := await self.database.get_row(DB_TABLE_CACHE, {"key": cache_key})) and (
+ not checksum or db_row["checksum"] == checksum and db_row["expires"] >= cur_time
+ ):
+ try:
+ data = await asyncio.to_thread(json.loads, db_row["data"])
+ except Exception as exc: # pylint: disable=broad-except
+ LOGGER.exception("Error parsing cache data for %s", cache_key, exc_info=exc)
+ else:
+ # also store in memory cache for faster access
+ self._mem_cache[cache_key] = (
+ data,
+ db_row["checksum"],
+ db_row["expires"],
+ )
+ return data
+ return default
+
+ async def set(self, cache_key, data, checksum="", expiration=(86400 * 30)):
+ """Set data in cache."""
+ if not cache_key:
+ return
+ if checksum is not None and not isinstance(checksum, str):
+ checksum = str(checksum)
+ expires = int(time.time() + expiration)
+ self._mem_cache[cache_key] = (data, checksum, expires)
+ if (expires - time.time()) < 3600 * 4:
+ # do not cache items in db with short expiration
+ return
+ data = await asyncio.to_thread(json.dumps, data)
+ await self.database.insert(
+ DB_TABLE_CACHE,
+ {"key": cache_key, "expires": expires, "checksum": checksum, "data": data},
+ allow_replace=True,
+ )
+
+ async def delete(self, cache_key):
+ """Delete data from cache."""
+ self._mem_cache.pop(cache_key, None)
+ await self.database.delete(DB_TABLE_CACHE, {"key": cache_key})
+
+ async def clear(self, key_filter: str | None = None) -> None:
+ """Clear all/partial items from cache."""
+ self._mem_cache = {}
+ query = f"key LIKE '%{key_filter}%'" if key_filter else None
+ await self.database.delete(DB_TABLE_CACHE, query=query)
+
+ async def auto_cleanup(self):
+ """Sceduled auto cleanup task."""
+ # for now we simply reset the memory cache
+ self._mem_cache = {}
+ cur_timestamp = int(time.time())
+ for db_row in await self.database.get_rows(DB_TABLE_CACHE):
+ # clean up db cache object only if expired
+ if db_row["expires"] < cur_timestamp:
+ await self.delete(db_row["key"])
+
+ async def _setup_database(self):
+ """Initialize database."""
+ db_url: str = self.mass.config.get(CONF_DB_CACHE, DEFAULT_DB_CACHE)
+ db_url = db_url.replace("[storage_path]", self.mass.storage_path)
+ self.database = DatabaseConnection(db_url)
+
+ # always create db tables if they don't exist to prevent errors trying to access them later
+ await self.__create_database_tables()
+ try:
+ if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
+ prev_version = int(db_row["value"])
+ else:
+ prev_version = 0
+ except (KeyError, ValueError):
+ prev_version = 0
+
+ if prev_version not in (0, SCHEMA_VERSION):
+ LOGGER.info(
+ "Performing database migration from %s to %s",
+ prev_version,
+ SCHEMA_VERSION,
+ )
+
+ if prev_version < SCHEMA_VERSION:
+ # for now just keep it simple and just recreate the table(s)
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")
+
+ # recreate missing table(s)
+ await self.__create_database_tables()
+
+ # store current schema version
+ await self.database.insert_or_replace(
+ DB_TABLE_SETTINGS,
+ {"key": "version", "value": str(SCHEMA_VERSION), "type": "str"},
+ )
+ # compact db
+ await self.database.execute("VACUUM")
+
+ async def __create_database_tables(self) -> None:
+ """Create database table(s)."""
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
+ key TEXT PRIMARY KEY,
+ value TEXT,
+ type TEXT
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_CACHE}(
+ key TEXT UNIQUE NOT NULL, expires INTEGER NOT NULL,
+ data TEXT, checksum TEXT NULL)"""
+ )
+
+ # create indexes
+ await self.database.execute(
+ f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_idx on {DB_TABLE_CACHE}(key);"
+ )
+
+ def __schedule_cleanup_task(self):
+ """Schedule the cleanup task."""
+ self.mass.create_task(self.auto_cleanup())
+ # reschedule self
+ self.mass.loop.call_later(3600, self.__schedule_cleanup_task)
+
+
+def use_cache(expiration=86400 * 30):
+ """Return decorator that can be used to cache a method's result."""
+
+ def wrapper(func):
+ @functools.wraps(func)
+ async def wrapped(*args, **kwargs):
+ method_class = args[0]
+ method_class_name = method_class.__class__.__name__
+ cache_key_parts = [method_class_name, func.__name__]
+ skip_cache = kwargs.pop("skip_cache", False)
+ cache_checksum = kwargs.pop("cache_checksum", "")
+ if len(args) > 1:
+ cache_key_parts += args[1:]
+ for key in sorted(kwargs.keys()):
+ cache_key_parts.append(f"{key}{kwargs[key]}")
+ cache_key = ".".join(cache_key_parts)
+
+ cachedata = await method_class.cache.get(cache_key, checksum=cache_checksum)
+
+ if not skip_cache and cachedata is not None:
+ return cachedata
+ result = await func(*args, **kwargs)
+ asyncio.create_task(
+ method_class.cache.set(
+ cache_key, result, expiration=expiration, checksum=cache_checksum
+ )
+ )
+ return result
+
+ return wrapped
+
+ return wrapper
+
+
+class MemoryCache(MutableMapping):
+ """Simple limited in-memory cache implementation."""
+
+ def __init__(self, maxlen: int):
+ """Initialize."""
+ self._maxlen = maxlen
+ self.d = OrderedDict()
+
+ @property
+ def maxlen(self) -> int:
+ """Return max length."""
+ return self._maxlen
+
+ def get(self, key: str, default: Any = None) -> Any:
+ """Return item or default."""
+ return self.d.get(key, default)
+
+ def pop(self, key: str, default: Any = None) -> Any:
+ """Pop item from collection."""
+ return self.d.pop(key, default)
+
+ def __getitem__(self, key: str) -> Any:
+ """Get item."""
+ self.d.move_to_end(key)
+ return self.d[key]
+
+ def __setitem__(self, key: str, value: Any) -> None:
+ """Set item."""
+ if key in self.d:
+ self.d.move_to_end(key)
+ elif len(self.d) == self.maxlen:
+ self.d.popitem(last=False)
+ self.d[key] = value
+
+ def __delitem__(self, key) -> None:
+ """Delete item."""
+ del self.d[key]
+
+ def __iter__(self) -> Iterator:
+ """Iterate items."""
+ return self.d.__iter__()
+
+ def __len__(self) -> int:
+ """Return length."""
+ return len(self.d)
--- /dev/null
+"""Logic to handle storage of persistent (configuration) settings."""
+from __future__ import annotations
+
+import asyncio
+import base64
+import logging
+import os
+from typing import TYPE_CHECKING, Any
+from uuid import uuid4
+
+import aiofiles
+from aiofiles.os import wrap
+from cryptography.fernet import Fernet
+
+from music_assistant.common.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads
+from music_assistant.common.models.config_entries import (
+ DEFAULT_PLAYER_CONFIG_ENTRIES,
+ ConfigEntryValue,
+ PlayerConfig,
+ ProviderConfig,
+)
+from music_assistant.common.models.enums import ConfigEntryType, EventType, ProviderType
+from music_assistant.common.models.errors import PlayerUnavailableError, ProviderUnavailableError
+from music_assistant.constants import CONF_PLAYERS, CONF_PROVIDERS, CONF_SERVER_ID
+from music_assistant.server.helpers.api import api_command
+
+if TYPE_CHECKING:
+ from music_assistant.server.models.player_provider import PlayerProvider
+ from music_assistant.server.server import MusicAssistant
+
+LOGGER = logging.getLogger(__name__)
+DEFAULT_SAVE_DELAY = 120
+
+isfile = wrap(os.path.isfile)
+remove = wrap(os.remove)
+rename = wrap(os.rename)
+
+
+class ConfigController:
+ """Controller that handles storage of persistent configuration settings."""
+
+ _fernet: Fernet | None = None
+
+ def __init__(self, mass: MusicAssistant) -> None:
+ """Initialize storage controller."""
+ self.mass = mass
+ self.initialized = False
+ self._data: dict[str, Any] = {}
+ self.filename = os.path.join(self.mass.storage_path, "settings.json")
+ self._timer_handle: asyncio.TimerHandle | None = None
+
+ async def setup(self) -> None:
+ """Async initialize of controller."""
+ await self._load()
+ self.initialized = True
+ # create default server ID if needed (also used for encrypting passwords)
+ server_id: str = self.get(CONF_SERVER_ID, uuid4().hex, True)
+ fernet_key = base64.urlsafe_b64encode(server_id.encode()[:32])
+ self._fernet = Fernet(fernet_key)
+ LOGGER.debug("Started.")
+
+ async def close(self) -> None:
+ """Handle logic on server stop."""
+ if not self._timer_handle:
+ # no point in forcing a save when there are no changes pending
+ return
+ await self.async_save()
+ LOGGER.debug("Stopped.")
+
+ def get(self, key: str, default: Any = None, setdefault: bool = False) -> Any:
+ """Get value(s) for a specific key/path in persistent storage."""
+ assert self.initialized, "Not yet (async) initialized"
+ # we support a multi level hierarchy by providing the key as path,
+ # with a slash (/) as splitter. Sort that out here.
+ parent = self._data
+ subkeys = key.split("/")
+ for index, subkey in enumerate(subkeys):
+ if index == (len(subkeys) - 1):
+ if setdefault:
+ parent.setdefault(subkey, default)
+ self.save()
+ value = parent.get(subkey, default)
+ if value is None:
+ # replace None with default
+ return default
+ return value
+ elif subkey not in parent:
+ # requesting subkey from a non existing parent
+ return default
+ else:
+ parent = parent[subkey]
+ return default
+
+ def set(self, key: str, value: Any) -> None:
+ """Set value(s) for a specific key/path in persistent storage."""
+ assert self.initialized, "Not yet (async) initialized"
+ # we support a multi level hierarchy by providing the key as path,
+ # with a slash (/) as splitter.
+ parent = self._data
+ subkeys = key.split("/")
+ for index, subkey in enumerate(subkeys):
+ if index == (len(subkeys) - 1):
+ cur_value = parent.get(subkey)
+ if cur_value == value:
+ # no need to save if value did not change
+ return
+ parent[subkey] = value
+ self.save()
+ else:
+ parent.setdefault(subkey, {})
+ parent = parent[subkey]
+
+ def remove(
+ self,
+ key: str,
+ ) -> None:
+ """Remove value(s) for a specific key/path in persistent storage."""
+ assert self.initialized, "Not yet (async) initialized"
+ parent = self._data
+ subkeys = key.split("/")
+ for index, subkey in enumerate(subkeys):
+ if subkey not in parent:
+ return
+ if index == (len(subkeys) - 1):
+ parent.pop(subkey)
+ else:
+ parent.setdefault(subkey, {})
+ parent = parent[subkey]
+
+ self.save()
+
+ @api_command("config/providers")
+ def get_provider_configs(
+ self,
+ provider_type: ProviderType | None = None,
+ provider_domain: str | None = None,
+ ) -> list[ProviderConfig]:
+ """Return all known provider configurations, optionally filtered by ProviderType."""
+ raw_values: dict[str, dict] = self.get(CONF_PROVIDERS, {})
+ prov_entries = {x.domain: x.config_entries for x in self.mass.get_available_providers()}
+ return [
+ ProviderConfig.parse(
+ prov_entries[prov_conf["domain"]],
+ prov_conf,
+ decrypt_callback=self.decrypt_password,
+ )
+ for prov_conf in raw_values.values()
+ if (provider_type is None or prov_conf["type"] == provider_type)
+ and (provider_domain is None or prov_conf["domain"] == provider_domain)
+ ]
+
+ @api_command("config/providers/get")
+ def get_provider_config(self, instance_id: str) -> ProviderConfig:
+ """Return configuration for a single provider."""
+ if raw_conf := self.get(f"{CONF_PROVIDERS}/{instance_id}", {}):
+ for prov in self.mass.get_available_providers():
+ if prov.domain != raw_conf["domain"]:
+ continue
+ return ProviderConfig.parse(
+ prov.config_entries,
+ raw_conf,
+ decrypt_callback=self.decrypt_password,
+ )
+ raise KeyError(f"No config found for provider id {instance_id}")
+
+ @api_command("config/providers/set")
+ def set_provider_config(self, config: ProviderConfig) -> None:
+ """Create or update ProviderConfig."""
+ # encrypt any password values
+ for val in config.values.values():
+ if val.type == ConfigEntryType.PASSWORD:
+ val.value = self.encrypt_password(val.value)
+
+ conf_key = f"{CONF_PROVIDERS}/{config.instance_id}"
+ existing = self.get(conf_key)
+ config_dict = config.to_raw()
+ if existing == config_dict:
+ # no changes
+ return
+ self.set(conf_key, config_dict)
+ # (re)load provider
+ updated_config = self.get_provider_config(config.instance_id)
+ self.mass.create_task(self.mass.load_provider(updated_config))
+
+ @api_command("config/providers/create")
+ def create_provider_config(self, provider_domain: str) -> ProviderConfig:
+ """Create default/empty ProviderConfig.
+
+ This is intended to be used as helper method to add a new provider,
+ and it performs some quick sanity checks as well as handling the
+ instance_id generation.
+ """
+ # lookup provider manifest
+ for prov in self.mass.get_available_providers():
+ if prov.domain == provider_domain:
+ manifest = prov
+ break
+ else:
+ raise KeyError(f"Unknown provider domain: {provider_domain}")
+
+ # determine instance id based on previous configs
+ existing = self.get_provider_configs(provider_domain=provider_domain)
+ if existing and not manifest.multi_instance:
+ raise ValueError(f"Provider {manifest.name} does not support multiple instances")
+
+ count = len(existing)
+ if count == 0:
+ instance_id = provider_domain
+ name = manifest.name
+ else:
+ instance_id = f"{provider_domain}{count+1}"
+ name = f"{manifest.name} {count+1}"
+
+ return ProviderConfig.parse(
+ prov.config_entries,
+ {
+ "type": manifest.type.value,
+ "domain": manifest.domain,
+ "instance_id": instance_id,
+ "name": name,
+ "values": dict(),
+ },
+ allow_none=True,
+ )
+
+ @api_command("config/providers/remove")
+ async def remove_provider_config(self, instance_id: str) -> None:
+ """Remove ProviderConfig."""
+ conf_key = f"{CONF_PROVIDERS}/{instance_id}"
+ existing = self.get(conf_key)
+ if not existing:
+ raise KeyError(f"Provider {instance_id} does not exist")
+ await self.mass.unload_provider(instance_id)
+ if existing["type"] == "music":
+ # cleanup entries in library
+ await self.mass.music.cleanup_provider(instance_id)
+ self.remove(conf_key)
+
+ @api_command("config/players")
+ def get_player_configs(self, provider: str | None = None) -> list[PlayerConfig]:
+ """Return all known player configurations, optionally filtered by provider domain."""
+ player_configs: dict[str, dict] = self.get(CONF_PLAYERS, {})
+ # we build a list of all playerids to cover both edge cases:
+ # - player does not yet have a config stored persistently
+ # - player is disabled in config and not available
+ all_player_ids = set(player_configs.keys())
+ for player in self.mass.players:
+ all_player_ids.add(player.player_id)
+ configs = [self.get_player_config(x) for x in all_player_ids]
+ if not provider:
+ return configs
+ return [x for x in configs if x.provider == provider]
+
+ @api_command("config/players/get")
+ def get_player_config(self, player_id: str) -> PlayerConfig:
+ """Return configuration for a single player."""
+ conf = self.get(f"{CONF_PLAYERS}/{player_id}")
+ if not conf:
+ player = self.mass.players.get(player_id)
+ if not player:
+ raise PlayerUnavailableError(f"Player {player_id} is not available")
+ conf = {"provider": player.provider, "player_id": player_id}
+ try:
+ prov = self.mass.get_provider(conf["provider"])
+ prov_entries = prov.get_player_config_entries(player_id)
+ except ProviderUnavailableError:
+ prov_entries = tuple()
+
+ entries = DEFAULT_PLAYER_CONFIG_ENTRIES + prov_entries
+ return PlayerConfig.parse(entries, conf)
+
+ @api_command("config/players/get_value")
+ def get_player_config_value(self, player_id: str, key: str) -> ConfigEntryValue:
+ """Return single configentry value for a player."""
+ conf = self.get(f"{CONF_PLAYERS}/{player_id}")
+ if not conf:
+ player = self.mass.players.get(player_id)
+ if not player:
+ raise PlayerUnavailableError(f"Player {player_id} is not available")
+ conf = {"provider": player.provider, "player_id": player_id, "values": {}}
+ prov = self.mass.get_provider(conf["provider"])
+ entries = DEFAULT_PLAYER_CONFIG_ENTRIES + prov.get_player_config_entries(player_id)
+ for entry in entries:
+ if entry.key == key:
+ return ConfigEntryValue.parse(entry, conf["values"].get(key))
+ raise KeyError(f"ConfigEntry {key} is invalid")
+
+ @api_command("config/players/set")
+ def set_player_config(self, config: PlayerConfig) -> None:
+ """Create or update PlayerConfig."""
+ conf_key = f"{CONF_PLAYERS}/{config.player_id}"
+ existing = self.get(conf_key)
+ config_dict = config.to_raw()
+ if existing == config_dict:
+ # no changes
+ return
+ self.set(conf_key, config_dict)
+ # send config updated event
+ self.mass.signal_event(
+ EventType.PLAYER_CONFIG_UPDATED,
+ object_id=config.player_id,
+ data=config,
+ )
+ # signal update to the player manager
+ if player := self.mass.players.get(config.player_id):
+ player.enabled = config.enabled
+ self.mass.players.update(config.player_id)
+ # signal player provider that the config changed
+ if provider := self.mass.get_provider(config.provider):
+ assert isinstance(provider, PlayerProvider)
+ provider.on_player_config_changed(config)
+
+ @api_command("config/players/create")
+ async def create_player_config(
+ self, provider_domain: str, config: PlayerConfig | None = None
+ ) -> PlayerConfig:
+ """Register a new Player(config) if the provider supports this."""
+ provider: PlayerProvider = self.mass.get_provider(provider_domain)
+ return await provider.create_player_config(config)
+
+ @api_command("config/players/remove")
+ async def remove_player_config(self, player_id: str) -> None:
+ """Remove PlayerConfig."""
+ conf_key = f"{CONF_PLAYERS}/{player_id}"
+ existing = self.get(conf_key)
+ if not existing:
+ raise KeyError(f"Player {player_id} does not exist")
+ self.remove(conf_key)
+ if provider := self.mass.get_provider(existing["provider"]):
+ assert isinstance(provider, PlayerProvider)
+ provider.on_player_config_removed(player_id)
+
+ async def _load(self) -> None:
+ """Load data from persistent storage."""
+ assert not self._data, "Already loaded"
+
+ for filename in self.filename, f"{self.filename}.backup":
+ try:
+ _filename = os.path.join(self.mass.storage_path, filename)
+ async with aiofiles.open(_filename, "r", encoding="utf-8") as _file:
+ self._data = json_loads(await _file.read())
+ return
+ except FileNotFoundError:
+ pass
+ except JSON_DECODE_EXCEPTIONS: # pylint: disable=catching-non-exception
+ LOGGER.error("Error while reading persistent storage file %s", filename)
+ else:
+ LOGGER.debug("Loaded persistent settings from %s", filename)
+ LOGGER.debug("Started with empty storage: No persistent storage file found.")
+
+ def save(self, immediate: bool = False) -> None:
+ """Schedule save of data to disk."""
+ if self._timer_handle is not None:
+ self._timer_handle.cancel()
+ self._timer_handle = None
+
+ if immediate:
+ self.mass.loop.create_task(self.async_save())
+ else:
+ # schedule the save for later
+ self._timer_handle = self.mass.loop.call_later(
+ DEFAULT_SAVE_DELAY, self.mass.create_task, self.async_save
+ )
+
+ async def async_save(self):
+ """Save persistent data to disk."""
+ filename_backup = f"{self.filename}.backup"
+ # make backup before we write a new file
+ if await isfile(self.filename):
+ if await isfile(filename_backup):
+ await remove(filename_backup)
+ await rename(self.filename, filename_backup)
+
+ async with aiofiles.open(self.filename, "w", encoding="utf-8") as _file:
+ await _file.write(json_dumps(self._data))
+ LOGGER.debug("Saved data to persistent storage")
+
+ def encrypt_password(self, str_value: str) -> str:
+ """Encrypt a (password)string with Fernet."""
+ return self._fernet.encrypt(str_value.encode()).decode()
+
+ def decrypt_password(self, encrypted_str: str) -> str:
+ """Decrypt a (password)string with Fernet."""
+ return self._fernet.decrypt(encrypted_str.encode()).decode()
--- /dev/null
+"""Package with Media controllers."""
--- /dev/null
+"""Manage MediaItems of type Album."""
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from random import choice, random
+from typing import TYPE_CHECKING
+
+from music_assistant.common.helpers.json import json_dumps
+from music_assistant.common.models.enums import EventType, ProviderFeature
+from music_assistant.common.models.errors import MediaNotFoundError, UnsupportedFeaturedException
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ ItemMapping,
+ MediaType,
+ Track,
+)
+from music_assistant.constants import DB_TABLE_ALBUMS, DB_TABLE_TRACKS, VARIOUS_ARTISTS
+from music_assistant.server.controllers.media.base import MediaControllerBase
+from music_assistant.server.helpers.compare import compare_album, loose_compare_strings
+
+if TYPE_CHECKING:
+ from music_assistant.server.models.music_provider import MusicProvider
+
+
+class AlbumsController(MediaControllerBase[Album]):
+ """Controller managing MediaItems of type Album."""
+
+ db_table = DB_TABLE_ALBUMS
+ media_type = MediaType.ALBUM
+ item_cls = Album
+
+ def __init__(self, *args, **kwargs):
+ """Initialize class."""
+ super().__init__(*args, **kwargs)
+ # register api handlers
+ self.mass.register_api_command("music/albums", self.db_items)
+ self.mass.register_api_command("music/album", self.get)
+ self.mass.register_api_command("music/album/tracks", self.tracks)
+ self.mass.register_api_command("music/album/versions", self.versions)
+ self.mass.register_api_command("music/album/update", self.update_db_item)
+ self.mass.register_api_command("music/album/delete", self.delete_db_item)
+
+ async def get(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ force_refresh: bool = False,
+ lazy: bool = True,
+ details: Album = None,
+ force_provider_item: bool = False,
+ ) -> Album:
+ """Return (full) details for a single media item."""
+ album = await super().get(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ force_refresh=force_refresh,
+ lazy=lazy,
+ details=details,
+ force_provider_item=force_provider_item,
+ )
+ # append full artist details to full album item
+ if album.artist:
+ album.artist = await self.mass.music.artists.get(
+ album.artist.item_id,
+ album.artist.provider,
+ lazy=True,
+ details=album.artist,
+ )
+ return album
+
+ async def tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> list[Track]:
+ """Return album tracks for the given provider album id."""
+ if "database" not in (provider_domain, provider_instance):
+ # return provider album tracks
+ return await self._get_provider_album_tracks(
+ item_id, provider_domain or provider_instance
+ )
+
+ # db_album requested: get results from first (non-file) provider
+ return await self._get_db_album_tracks(item_id)
+
+ async def versions(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> list[Album]:
+ """Return all versions of an album we can find on all providers."""
+ assert provider_domain or provider_instance, "Provider type or ID must be specified"
+ album = await self.get(item_id, provider_domain or provider_instance)
+ # perform a search on all provider(types) to collect all versions/variants
+ provider_domains = {item.domain for item in self.mass.music.providers}
+ search_query = f"{album.artist.name} - {album.name}"
+ all_versions = {
+ prov_item.item_id: prov_item
+ for prov_items in await asyncio.gather(
+ *[
+ self.search(search_query, provider_domain)
+ for provider_domain in provider_domains
+ ]
+ )
+ for prov_item in prov_items
+ if loose_compare_strings(album.name, prov_item.name)
+ }
+ # make sure that the 'base' version is included
+ for prov_version in album.provider_mappings:
+ if prov_version.item_id in all_versions:
+ continue
+ album_copy = Album.from_dict(album.to_dict())
+ album_copy.item_id = prov_version.item_id
+ album_copy.provider = prov_version.provider_domain
+ album_copy.provider_mappings = {prov_version}
+ all_versions[prov_version.item_id] = album_copy
+
+ # return the aggregated result
+ return all_versions.values()
+
+ async def add(self, item: Album) -> Album:
+ """Add album to local db and return the database item."""
+ # grab additional metadata
+ await self.mass.metadata.get_album_metadata(item)
+ existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
+ if existing:
+ db_item = await self.update_db_item(existing.item_id, item)
+ else:
+ db_item = await self.add_db_item(item)
+ # also fetch same album on all providers
+ await self._match(db_item)
+ # return final db_item after all match/metadata actions
+ db_item = await self.get_db_item(db_item.item_id)
+ # dump album tracks in db
+ for prov_mapping in db_item.provider_mappings:
+ for track in await self._get_provider_album_tracks(
+ prov_mapping.item_id, prov_mapping.provider_instance
+ ):
+ await self.mass.music.tracks.add_db_item(track)
+ self.mass.signal_event(
+ EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED,
+ db_item.uri,
+ db_item,
+ )
+ return db_item
+
+ async def add_db_item(self, item: Album, overwrite_existing: bool = False) -> Album:
+ """Add a new record to the database."""
+ assert item.provider_mappings, f"Album {item.name} is missing provider id(s)"
+ assert item.artist, f"Album {item.name} is missing artist"
+ async with self._db_add_lock:
+ cur_item = None
+ # always try to grab existing item by musicbrainz_id/upc
+ if item.musicbrainz_id:
+ match = {"musicbrainz_id": item.musicbrainz_id}
+ cur_item = await self.mass.music.database.get_row(self.db_table, match)
+ if not cur_item and item.upc:
+ match = {"upc": item.upc}
+ cur_item = await self.mass.music.database.get_row(self.db_table, match)
+ if not cur_item:
+ # fallback to search and match
+ for row in await self.mass.music.database.search(self.db_table, item.name):
+ row_album = Album.from_db_row(row)
+ if compare_album(row_album, item):
+ cur_item = row_album
+ break
+ if cur_item:
+ # update existing
+ return await self.update_db_item(
+ cur_item.item_id, item, overwrite=overwrite_existing
+ )
+
+ # insert new item
+ album_artists = await self._get_album_artists(item, cur_item)
+ sort_artist = album_artists[0].sort_name if album_artists else ""
+ new_item = await self.mass.music.database.insert(
+ self.db_table,
+ {
+ **item.to_db_row(),
+ "artists": json_dumps(album_artists) or None,
+ "sort_artist": sort_artist,
+ },
+ )
+ item_id = new_item["item_id"]
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, item.provider_mappings)
+ self.logger.debug("added %s to database", item.name)
+ # return created object
+ return await self.get_db_item(item_id)
+
+ async def update_db_item(
+ self,
+ item_id: int,
+ item: Album,
+ overwrite: bool = False,
+ ) -> Album:
+ """Update Album record in the database."""
+ assert item.provider_mappings, f"Album {item.name} is missing provider id(s)"
+ assert item.artist, f"Album {item.name} is missing artist"
+ cur_item = await self.get_db_item(item_id)
+
+ if overwrite:
+ metadata = item.metadata
+ metadata.last_refresh = None
+ provider_mappings = item.provider_mappings
+ album_artists = await self._get_album_artists(item, overwrite=True)
+ else:
+ is_file_provider = item.provider.startswith("filesystem")
+ metadata = cur_item.metadata.update(item.metadata, is_file_provider)
+ provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
+ album_artists = await self._get_album_artists(item, cur_item)
+
+ if item.album_type != AlbumType.UNKNOWN:
+ album_type = item.album_type
+ else:
+ album_type = cur_item.album_type
+
+ sort_artist = album_artists[0].sort_name if album_artists else ""
+
+ await self.mass.music.database.update(
+ self.db_table,
+ {"item_id": item_id},
+ {
+ "name": item.name if overwrite else cur_item.name,
+ "sort_name": item.sort_name if overwrite else cur_item.sort_name,
+ "sort_artist": sort_artist,
+ "version": item.version if overwrite else cur_item.version,
+ "year": item.year or cur_item.year,
+ "upc": item.upc or cur_item.upc,
+ "album_type": album_type,
+ "artists": json_dumps(album_artists) or None,
+ "metadata": json_dumps(metadata),
+ "provider_mappings": json_dumps(provider_mappings),
+ "musicbrainz_id": item.musicbrainz_id or cur_item.musicbrainz_id,
+ },
+ )
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, provider_mappings)
+ self.logger.debug("updated %s in database: %s", item.name, item_id)
+ return await self.get_db_item(item_id)
+
+ async def delete_db_item(self, item_id: int, recursive: bool = False) -> None:
+ """Delete record from the database."""
+ # check album tracks
+ db_rows = await self.mass.music.database.get_rows_from_query(
+ f"SELECT item_id FROM {DB_TABLE_TRACKS} WHERE albums LIKE '%\"{item_id}\"%'",
+ limit=5000,
+ )
+ assert not (db_rows and not recursive), "Tracks attached to album"
+ for db_row in db_rows:
+ with contextlib.suppress(MediaNotFoundError):
+ await self.mass.music.albums.delete_db_item(db_row["item_id"], recursive)
+
+ # delete the album itself from db
+ await super().delete_db_item(item_id)
+
+ async def _get_provider_album_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> list[Track]:
+ """Return album tracks for the given provider album id."""
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov:
+ return []
+ full_album = await self.get_provider_item(item_id, provider_instance or provider_domain)
+ # prefer cache items (if any)
+ cache_key = f"{prov.instance_id}.albumtracks.{item_id}"
+ cache_checksum = full_album.metadata.checksum
+ if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
+ return [Track.from_dict(x) for x in cache]
+ # no items in cache - get listing from provider
+ items = []
+ for track in await prov.get_album_tracks(item_id):
+ # make sure that the (full) album is stored on the tracks
+ track.album = full_album
+ if full_album.metadata.images:
+ track.metadata.images = full_album.metadata.images
+ items.append(track)
+ # store (serializable items) in cache
+ self.mass.create_task(
+ self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum)
+ )
+ return items
+
+ async def _get_provider_dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ):
+ """Generate a dynamic list of tracks based on the album content."""
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov or ProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
+ return []
+ album_tracks = await self._get_provider_album_tracks(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ # Grab a random track from the album that we use to obtain similar tracks for
+ track = choice(album_tracks)
+ # Calculate no of songs to grab from each list at a 10/90 ratio
+ total_no_of_tracks = limit + limit % 2
+ no_of_album_tracks = int(total_no_of_tracks * 10 / 100)
+ no_of_similar_tracks = int(total_no_of_tracks * 90 / 100)
+ # Grab similar tracks from the music provider
+ similar_tracks = await prov.get_similar_tracks(
+ prov_track_id=track.item_id, limit=no_of_similar_tracks
+ )
+ # Merge album content with similar tracks
+ # ruff: noqa: ARG005
+ dynamic_playlist = [
+ *sorted(album_tracks, key=lambda n: random())[:no_of_album_tracks],
+ *sorted(similar_tracks, key=lambda n: random())[:no_of_similar_tracks],
+ ]
+ return sorted(dynamic_playlist, key=lambda n: random())
+
+ async def _get_dynamic_tracks(
+ self, media_item: Album, limit: int = 25 # noqa: ARG002
+ ) -> list[Track]:
+ """Get dynamic list of tracks for given item, fallback/default implementation."""
+ # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
+ raise UnsupportedFeaturedException(
+ "No Music Provider found that supports requesting similar tracks."
+ )
+
+ async def _get_db_album_tracks(
+ self,
+ item_id: str,
+ ) -> list[Track]:
+ """Return in-database album tracks for the given database album."""
+ db_album = await self.get_db_item(item_id)
+ # simply grab all tracks in the db that are linked to this album
+ # TODO: adjust to json query instead of text search?
+ query = f"SELECT * FROM tracks WHERE albums LIKE '%\"{item_id}\"%'"
+ result = []
+ for track in await self.mass.music.tracks.get_db_items_by_query(query):
+ if album_mapping := next(
+ (x for x in track.albums if x.item_id == db_album.item_id), None
+ ):
+ # make sure that the full album is set on the track and prefer the album's images
+ track.album = db_album
+ if db_album.metadata.images:
+ track.metadata.images = db_album.metadata.images
+ # apply the disc and track number from the mapping
+ track.disc_number = album_mapping.disc_number
+ track.track_number = album_mapping.track_number
+ result.append(track)
+ return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0))
+
+ async def _match(self, db_album: Album) -> None:
+ """Try to find matching album on all providers for the provided (database) album.
+
+ This is used to link objects of different providers/qualities together.
+ """
+ if db_album.provider != "database":
+ return # Matching only supported for database items
+
+ async def find_prov_match(provider: MusicProvider):
+ self.logger.debug(
+ "Trying to match album %s on provider %s", db_album.name, provider.name
+ )
+ match_found = False
+ for search_str in (
+ db_album.name,
+ f"{db_album.artist.name} - {db_album.name}",
+ f"{db_album.artist.name} {db_album.name}",
+ ):
+ if match_found:
+ break
+ search_result = await self.search(search_str, provider.instance_id)
+ for search_result_item in search_result:
+ if not search_result_item.available:
+ continue
+ if not compare_album(search_result_item, db_album):
+ continue
+ # we must fetch the full album version, search results are simplified objects
+ prov_album = await self.get_provider_item(
+ search_result_item.item_id, search_result_item.provider
+ )
+ if compare_album(prov_album, db_album):
+ # 100% match, we can simply update the db with additional provider ids
+ await self.update_db_item(db_album.item_id, prov_album)
+ match_found = True
+ return match_found
+
+ # try to find match on all providers
+ cur_provider_domains = {x.provider_domain for x in db_album.provider_mappings}
+ for provider in self.mass.music.providers:
+ if provider.domain in cur_provider_domains:
+ continue
+ if ProviderFeature.SEARCH not in provider.supported_features:
+ continue
+ if await find_prov_match(provider):
+ cur_provider_domains.add(provider.domain)
+ else:
+ self.logger.debug(
+ "Could not find match for Album %s on provider %s",
+ db_album.name,
+ provider.name,
+ )
+
+ async def _get_album_artists(
+ self,
+ db_album: Album,
+ updated_album: Album | None = None,
+ overwrite: bool = False,
+ ) -> list[ItemMapping]:
+ """Extract (database) album artist(s) as ItemMapping."""
+ album_artists = set()
+ for album in (updated_album, db_album):
+ if not album:
+ continue
+ for artist in album.artists:
+ album_artists.add(await self._get_artist_mapping(artist, overwrite))
+ # use intermediate set to prevent duplicates
+ # filter various artists if multiple artists
+ if len(album_artists) > 1:
+ album_artists = {x for x in album_artists if (x.name != VARIOUS_ARTISTS)}
+ return list(album_artists)
+
+ async def _get_artist_mapping(
+ self, artist: Artist | ItemMapping, overwrite: bool = False
+ ) -> ItemMapping:
+ """Extract (database) track artist as ItemMapping."""
+ if overwrite:
+ artist = await self.mass.music.artists.add_db_item(artist, overwrite_existing=True)
+ if artist.provider == "database":
+ if isinstance(artist, ItemMapping):
+ return artist
+ return ItemMapping.from_item(artist)
+
+ if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
+ artist.item_id, provider_domain=artist.provider
+ ):
+ return ItemMapping.from_item(db_artist)
+
+ db_artist = await self.mass.music.artists.add_db_item(artist)
+ return ItemMapping.from_item(db_artist)
--- /dev/null
+"""Manage MediaItems of type Artist."""
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import itertools
+from random import choice, random
+from time import time
+from typing import TYPE_CHECKING, Any
+
+from music_assistant.common.helpers.json import json_dumps
+from music_assistant.common.models.enums import EventType, ProviderFeature
+from music_assistant.common.models.errors import MediaNotFoundError, UnsupportedFeaturedException
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ ItemMapping,
+ MediaType,
+ PagedItems,
+ Track,
+)
+from music_assistant.constants import VARIOUS_ARTISTS, VARIOUS_ARTISTS_ID
+from music_assistant.server.controllers.media.base import MediaControllerBase
+from music_assistant.server.controllers.music import (
+ DB_TABLE_ALBUMS,
+ DB_TABLE_ARTISTS,
+ DB_TABLE_TRACKS,
+)
+from music_assistant.server.helpers.compare import compare_strings
+
+if TYPE_CHECKING:
+ from music_assistant.server.models.music_provider import MusicProvider
+
+
+class ArtistsController(MediaControllerBase[Artist]):
+ """Controller managing MediaItems of type Artist."""
+
+ db_table = DB_TABLE_ARTISTS
+ media_type = MediaType.ARTIST
+ item_cls = Artist
+
+ def __init__(self, *args, **kwargs):
+ """Initialize class."""
+ super().__init__(*args, **kwargs)
+ # register api handlers
+ self.mass.register_api_command("music/artists", self.db_items)
+ self.mass.register_api_command("music/albumartists", self.album_artists)
+ self.mass.register_api_command("music/artist", self.get)
+ self.mass.register_api_command("music/artist/albums", self.albums)
+ self.mass.register_api_command("music/artist/tracks", self.tracks)
+ self.mass.register_api_command("music/artist/update", self.update_db_item)
+ self.mass.register_api_command("music/artist/delete", self.delete_db_item)
+
+ async def album_artists(
+ self,
+ in_library: bool | None = None,
+ search: str | None = None,
+ limit: int = 500,
+ offset: int = 0,
+ order_by: str = "sort_name",
+ ) -> PagedItems:
+ """Get in-database album artists."""
+ return await self.db_items(
+ in_library=in_library,
+ search=search,
+ limit=limit,
+ offset=offset,
+ order_by=order_by,
+ query_parts=["artists.sort_name in (select albums.sort_artist from albums)"],
+ )
+
+ async def tracks(
+ self,
+ item_id: str | None = None,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ artist: Artist | None = None,
+ ) -> list[Track]:
+ """Return top tracks for an artist."""
+ if not artist:
+ artist = await self.get(item_id, provider_domain, provider_instance)
+ # get results from all providers
+ coros = [
+ self.get_provider_artist_toptracks(
+ prov_mapping.item_id,
+ provider_domain=prov_mapping.provider_domain,
+ provider_instance=prov_mapping.provider_instance,
+ cache_checksum=artist.metadata.checksum,
+ )
+ for prov_mapping in artist.provider_mappings
+ ]
+ tracks = itertools.chain.from_iterable(await asyncio.gather(*coros))
+ # merge duplicates using a dict
+ final_items: dict[str, Track] = {}
+ for track in tracks:
+ key = f".{track.name}.{track.version}"
+ if key in final_items:
+ final_items[key].provider_mappings.update(track.provider_mappings)
+ else:
+ final_items[key] = track
+ return list(final_items.values())
+
+ async def albums(
+ self,
+ item_id: str | None = None,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ artist: Artist | None = None,
+ ) -> list[Album]:
+ """Return (all/most popular) albums for an artist."""
+ if not artist:
+ artist = await self.get(item_id, provider_domain or provider_instance)
+ # get results from all providers
+ coros = [
+ self.get_provider_artist_albums(
+ item.item_id,
+ item.provider_domain,
+ cache_checksum=artist.metadata.checksum,
+ )
+ for item in artist.provider_mappings
+ ]
+ albums = itertools.chain.from_iterable(await asyncio.gather(*coros))
+ # merge duplicates using a dict
+ final_items: dict[str, Album] = {}
+ for album in albums:
+ key = f".{album.name}.{album.version}"
+ if key in final_items:
+ final_items[key].provider_mappings.update(album.provider_mappings)
+ else:
+ final_items[key] = album
+ if album.in_library:
+ final_items[key].in_library = True
+ return list(final_items.values())
+
+ async def add(self, item: Artist) -> Artist:
+ """Add artist to local db and return the database item."""
+ # grab musicbrainz id and additional metadata
+ await self.mass.metadata.get_artist_metadata(item)
+ existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
+ if existing:
+ db_item = await self.update_db_item(existing.item_id, item)
+ else:
+ db_item = await self.add_db_item(item)
+ # also fetch same artist on all providers
+ await self.match_artist(db_item)
+ # return final db_item after all match/metadata actions
+ db_item = await self.get_db_item(db_item.item_id)
+ self.mass.signal_event(
+ EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED,
+ db_item.uri,
+ db_item,
+ )
+ return db_item
+
+ async def match_artist(self, db_artist: Artist):
+ """Try to find matching artists on all providers for the provided (database) item_id.
+
+ This is used to link objects of different providers together.
+ """
+ assert db_artist.provider == "database", "Matching only supported for database items!"
+ cur_provider_domains = {x.provider_domain for x in db_artist.provider_mappings}
+ for provider in self.mass.music.providers:
+ if provider.domain in cur_provider_domains:
+ continue
+ if ProviderFeature.SEARCH not in provider.supported_features:
+ continue
+ if await self._match(db_artist, provider):
+ cur_provider_domains.add(provider.domain)
+ else:
+ self.logger.debug(
+ "Could not find match for Artist %s on provider %s",
+ db_artist.name,
+ provider.name,
+ )
+
+ async def get_provider_artist_toptracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ cache_checksum: Any = None,
+ ) -> list[Track]:
+ """Return top tracks for an artist on given provider."""
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov:
+ return []
+ # prefer cache items (if any)
+ cache_key = f"{prov.instance_id}.artist_toptracks.{item_id}"
+ if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
+ return [Track.from_dict(x) for x in cache]
+ # no items in cache - get listing from provider
+ if ProviderFeature.ARTIST_TOPTRACKS in prov.supported_features:
+ items = await prov.get_artist_toptracks(item_id)
+ else:
+ # fallback implementation using the db
+ items = []
+ if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
+ item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ ):
+ prov_id = provider_instance or provider_domain
+ # TODO: adjust to json query instead of text search?
+ query = f"SELECT * FROM tracks WHERE artists LIKE '%\"{db_artist.item_id}\"%'"
+ query += f" AND provider_mappings LIKE '%\"{prov_id}\"%'"
+ items = await self.mass.music.tracks.get_db_items_by_query(query)
+ # store (serializable items) in cache
+ self.mass.create_task(
+ self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum)
+ )
+ return items
+
+ async def get_provider_artist_albums(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ cache_checksum: Any = None,
+ ) -> list[Album]:
+ """Return albums for an artist on given provider."""
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov:
+ return []
+ # prefer cache items (if any)
+ cache_key = f"{prov.instance_id}.artist_albums.{item_id}"
+ if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
+ return [Album.from_dict(x) for x in cache]
+ # no items in cache - get listing from provider
+ if ProviderFeature.ARTIST_ALBUMS in prov.supported_features:
+ items = await prov.get_artist_albums(item_id)
+ else:
+ # fallback implementation using the db
+ if db_artist := await self.mass.music.artists.get_db_item_by_prov_id( # noqa: PLR5501
+ item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ ):
+ prov_id = provider_instance or provider_domain
+ # TODO: adjust to json query instead of text search?
+ query = f"SELECT * FROM albums WHERE artists LIKE '%\"{db_artist.item_id}\"%'"
+ query += f" AND provider_mappings LIKE '%\"{prov_id}\"%'"
+ items = await self.mass.music.albums.get_db_items_by_query(query)
+ else:
+ # edge case
+ items = []
+ # store (serializable items) in cache
+ self.mass.create_task(
+ self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum)
+ )
+ return items
+
+ async def add_db_item(self, item: Artist, overwrite_existing: bool = False) -> Artist:
+ """Add a new item record to the database."""
+ assert isinstance(item, Artist), "Not a full Artist object"
+ assert item.provider_mappings, "Artist is missing provider id(s)"
+ # enforce various artists name + id
+ if compare_strings(item.name, VARIOUS_ARTISTS):
+ item.musicbrainz_id = VARIOUS_ARTISTS_ID
+ if item.musicbrainz_id == VARIOUS_ARTISTS_ID:
+ item.name = VARIOUS_ARTISTS
+
+ async with self._db_add_lock:
+ # always try to grab existing item by musicbrainz_id
+ cur_item = None
+ if item.musicbrainz_id:
+ match = {"musicbrainz_id": item.musicbrainz_id}
+ cur_item = await self.mass.music.database.get_row(self.db_table, match)
+ if not cur_item:
+ # fallback to exact name match
+ # NOTE: we match an artist by name which could theoretically lead to collisions
+ # but the chance is so small it is not worth the additional overhead of grabbing
+ # the musicbrainz id upfront
+ match = {"sort_name": item.sort_name}
+ for row in await self.mass.music.database.get_rows(self.db_table, match):
+ row_artist = Artist.from_db_row(row)
+ if row_artist.sort_name == item.sort_name:
+ cur_item = row_artist
+ break
+ if cur_item:
+ # update existing
+ return await self.update_db_item(
+ cur_item.item_id, item, overwrite=overwrite_existing
+ )
+
+ # insert item
+ if item.in_library and not item.timestamp:
+ item.timestamp = int(time())
+ new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row())
+ item_id = new_item["item_id"]
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, item.provider_mappings)
+ self.logger.debug("added %s to database", item.name)
+ # return created object
+ return await self.get_db_item(item_id)
+
+ async def update_db_item(
+ self,
+ item_id: int,
+ item: Artist,
+ overwrite: bool = False,
+ ) -> Artist:
+ """Update Artist record in the database."""
+ cur_item = await self.get_db_item(item_id)
+ if overwrite:
+ metadata = item.metadata
+ provider_mappings = item.provider_mappings
+ else:
+ is_file_provider = item.provider.startswith("filesystem")
+ metadata = cur_item.metadata.update(item.metadata, is_file_provider)
+ provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
+
+ # enforce various artists name + id
+ if compare_strings(item.name, VARIOUS_ARTISTS):
+ item.musicbrainz_id = VARIOUS_ARTISTS_ID
+ if item.musicbrainz_id == VARIOUS_ARTISTS_ID:
+ item.name = VARIOUS_ARTISTS
+
+ await self.mass.music.database.update(
+ self.db_table,
+ {"item_id": item_id},
+ {
+ "name": item.name if overwrite else cur_item.name,
+ "sort_name": item.sort_name if overwrite else cur_item.sort_name,
+ "musicbrainz_id": item.musicbrainz_id or cur_item.musicbrainz_id,
+ "metadata": json_dumps(metadata),
+ "provider_mappings": json_dumps(provider_mappings),
+ },
+ )
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, provider_mappings)
+ self.logger.debug("updated %s in database: %s", item.name, item_id)
+ return await self.get_db_item(item_id)
+
+ async def delete_db_item(self, item_id: int, recursive: bool = False) -> None:
+ """Delete record from the database."""
+ # check artist albums
+ db_rows = await self.mass.music.database.get_rows_from_query(
+ f"SELECT item_id FROM {DB_TABLE_ALBUMS} WHERE artists LIKE '%\"{item_id}\"%'",
+ limit=5000,
+ )
+ assert not (db_rows and not recursive), "Albums attached to artist"
+ for db_row in db_rows:
+ with contextlib.suppress(MediaNotFoundError):
+ await self.mass.music.albums.delete_db_item(db_row["item_id"], recursive)
+
+ # check artist tracks
+ db_rows = await self.mass.music.database.get_rows_from_query(
+ f"SELECT item_id FROM {DB_TABLE_TRACKS} WHERE artists LIKE '%\"{item_id}\"%'",
+ limit=5000,
+ )
+ assert not (db_rows and not recursive), "Tracks attached to artist"
+ for db_row in db_rows:
+ with contextlib.suppress(MediaNotFoundError):
+ await self.mass.music.albums.delete_db_item(db_row["item_id"], recursive)
+
+ # delete the artist itself from db
+ await super().delete_db_item(item_id)
+
+ async def _get_provider_dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ):
+ """Generate a dynamic list of tracks based on the artist's top tracks."""
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov or ProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
+ return []
+ top_tracks = await self.get_provider_artist_toptracks(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ # Grab a random track from the album that we use to obtain similar tracks for
+ track = choice(top_tracks)
+ # Calculate no of songs to grab from each list at a 10/90 ratio
+ total_no_of_tracks = limit + limit % 2
+ no_of_artist_tracks = int(total_no_of_tracks * 10 / 100)
+ no_of_similar_tracks = int(total_no_of_tracks * 90 / 100)
+ # Grab similar tracks from the music provider
+ similar_tracks = await prov.get_similar_tracks(
+ prov_track_id=track.item_id, limit=no_of_similar_tracks
+ )
+ # Merge album content with similar tracks
+ dynamic_playlist = [
+ *sorted(top_tracks, key=lambda n: random())[:no_of_artist_tracks], # noqa: ARG005
+ *sorted(similar_tracks, key=lambda n: random())[:no_of_similar_tracks], # noqa: ARG005
+ ]
+ return sorted(dynamic_playlist, key=lambda n: random()) # noqa: ARG005
+
+ async def _get_dynamic_tracks(
+ self, media_item: Artist, limit: int = 25 # noqa: ARG002
+ ) -> list[Track]:
+ """Get dynamic list of tracks for given item, fallback/default implementation."""
+ # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
+ raise UnsupportedFeaturedException(
+ "No Music Provider found that supports requesting similar tracks."
+ )
+
+ async def _match(self, db_artist: Artist, provider: MusicProvider) -> bool:
+ """Try to find matching artists on given provider for the provided (database) artist."""
+ self.logger.debug("Trying to match artist %s on provider %s", db_artist.name, provider.name)
+ # try to get a match with some reference tracks of this artist
+ for ref_track in await self.tracks(db_artist.item_id, db_artist.provider, artist=db_artist):
+ # make sure we have a full track
+ if isinstance(ref_track.album, ItemMapping):
+ ref_track = await self.mass.music.tracks.get( # noqa: PLW2901
+ ref_track.item_id, ref_track.provider
+ )
+ for search_str in (
+ f"{db_artist.name} - {ref_track.name}",
+ f"{db_artist.name} {ref_track.name}",
+ ref_track.name,
+ ):
+ search_results = await self.mass.music.tracks.search(search_str, provider.domain)
+ for search_result_item in search_results:
+ if search_result_item.sort_name != ref_track.sort_name:
+ continue
+ # get matching artist from track
+ for search_item_artist in search_result_item.artists:
+ if search_item_artist.sort_name != db_artist.sort_name:
+ continue
+ # 100% album match
+ # get full artist details so we have all metadata
+ prov_artist = await self.get_provider_item(
+ search_item_artist.item_id, search_item_artist.provider
+ )
+ await self.update_db_item(db_artist.item_id, prov_artist)
+ return True
+ # try to get a match with some reference albums of this artist
+ artist_albums = await self.albums(db_artist.item_id, db_artist.provider, artist=db_artist)
+ for ref_album in artist_albums:
+ if ref_album.album_type == AlbumType.COMPILATION:
+ continue
+ if ref_album.artist is None:
+ continue
+ for search_str in (
+ ref_album.name,
+ f"{db_artist.name} - {ref_album.name}",
+ f"{db_artist.name} {ref_album.name}",
+ ):
+ search_result = await self.mass.music.albums.search(search_str, provider.domain)
+ for search_result_item in search_result:
+ if search_result_item.artist is None:
+ continue
+ if search_result_item.sort_name != ref_album.sort_name:
+ continue
+ # artist must match 100%
+ if search_result_item.artist.sort_name != ref_album.artist.sort_name:
+ continue
+ # 100% match
+ # get full artist details so we have all metadata
+ prov_artist = await self.get_provider_item(
+ search_result_item.artist.item_id,
+ search_result_item.artist.provider,
+ )
+ await self.update_db_item(db_artist.item_id, prov_artist)
+ return True
+ return False
--- /dev/null
+"""Base (ABC) MediaType specific controller."""
+from __future__ import annotations
+
+import asyncio
+import logging
+from abc import ABCMeta, abstractmethod
+from collections.abc import AsyncGenerator
+from time import time
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+from music_assistant.common.helpers.json import json_dumps
+from music_assistant.common.models.enums import EventType, MediaType, ProviderFeature
+from music_assistant.common.models.errors import MediaNotFoundError
+from music_assistant.common.models.media_items import (
+ MediaItemType,
+ PagedItems,
+ ProviderMapping,
+ Track,
+ media_from_dict,
+)
+from music_assistant.constants import DB_TABLE_PROVIDER_MAPPINGS, ROOT_LOGGER_NAME
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+ItemCls = TypeVar("ItemCls", bound="MediaItemType")
+
+REFRESH_INTERVAL = 60 * 60 * 24 * 30
+
+
+class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta):
+ """Base model for controller managing a MediaType."""
+
+ media_type: MediaType
+ item_cls: MediaItemType
+ db_table: str
+
+ def __init__(self, mass: MusicAssistant):
+ """Initialize class."""
+ self.mass = mass
+ self.logger = logging.getLogger(f"{ROOT_LOGGER_NAME}.music.{self.media_type.value}")
+ self._db_add_lock = asyncio.Lock()
+
+ @abstractmethod
+ async def add(self, item: ItemCls) -> ItemCls:
+ """Add item to local db and return the database item."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def add_db_item(self, item: ItemCls, overwrite_existing: bool = False) -> ItemCls:
+ """Add a new record for this mediatype to the database."""
+ raise NotImplementedError
+
+ @abstractmethod
+ async def update_db_item(
+ self,
+ item_id: int,
+ item: ItemCls,
+ overwrite: bool = False,
+ ) -> ItemCls:
+ """Update record in the database, merging data."""
+ raise NotImplementedError
+
+ async def db_items(
+ self,
+ in_library: bool | None = None,
+ search: str | None = None,
+ limit: int = 500,
+ offset: int = 0,
+ order_by: str = "sort_name",
+ query_parts: list[str] | None = None,
+ ) -> PagedItems:
+ """Get in-database items."""
+ sql_query = f"SELECT * FROM {self.db_table}"
+ params = {}
+ query_parts = query_parts or []
+ if search:
+ params["search"] = f"%{search}%"
+ if self.media_type in (MediaType.ALBUM, MediaType.TRACK):
+ query_parts.append("(name LIKE :search or artists LIKE :search)")
+ else:
+ query_parts.append("name LIKE :search")
+ if in_library is not None:
+ query_parts.append("in_library = :in_library")
+ params["in_library"] = in_library
+ if query_parts:
+ sql_query += " WHERE " + " AND ".join(query_parts)
+ sql_query += f" ORDER BY {order_by}"
+ items = await self.get_db_items_by_query(sql_query, params, limit=limit, offset=offset)
+ count = len(items)
+ if 0 < count < limit:
+ total = offset + count
+ else:
+ total = await self.mass.music.database.get_count_from_query(sql_query, params)
+ return PagedItems(items, count, limit, offset, total)
+
+ async def iter_db_items(
+ self,
+ in_library: bool | None = None,
+ search: str | None = None,
+ order_by: str = "sort_name",
+ ) -> AsyncGenerator[ItemCls, None]:
+ """Iterate all in-database items."""
+ limit: int = 500
+ offset: int = 0
+ while True:
+ next_items = await self.db_items(
+ in_library=in_library,
+ search=search,
+ limit=limit,
+ offset=offset,
+ order_by=order_by,
+ )
+ for item in next_items.items:
+ yield item
+ if next_items.count < limit:
+ break
+ offset += limit
+
+ async def get(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ force_refresh: bool = False,
+ lazy: bool = True,
+ details: ItemCls = None,
+ force_provider_item: bool = False,
+ ) -> ItemCls:
+ """Return (full) details for a single media item."""
+ assert (
+ provider_domain or provider_instance
+ ), "provider_domain or provider_instance must be supplied"
+ if force_provider_item:
+ return await self.get_provider_item(item_id, provider_instance)
+ db_item = await self.get_db_item_by_prov_id(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ if db_item and (time() - db_item.last_refresh) > REFRESH_INTERVAL:
+ # it's been too long since the full metadata was last retrieved (or never at all)
+ force_refresh = True
+ if db_item and force_refresh:
+ # get (first) provider item id belonging to this db item
+ provider_instance, item_id = await self.get_provider_mapping(db_item)
+ elif db_item:
+ # we have a db item and no refreshing is needed, return the results!
+ return db_item
+ if not details and provider_instance:
+ # no details provider nor in db, fetch them from the provider
+ details = await self.get_provider_item(item_id, provider_instance)
+ if not details and provider_domain:
+ # check providers for given provider domain one by one
+ for prov in self.mass.music.providers:
+ if not prov.available:
+ continue
+ if prov.domain == provider_domain:
+ try:
+ details = await self.get_provider_item(item_id, prov.domain)
+ except MediaNotFoundError:
+ pass
+ else:
+ break
+ if not details:
+ # we couldn't get a match from any of the providers, raise error
+ raise MediaNotFoundError(f"Item not found: {provider_domain or id}/{item_id}")
+ # create task to add the item to the db, including matching metadata etc. takes some time
+ # in 99% of the cases we just return lazy because we want the details as fast as possible
+ # only if we really need to wait for the result (e.g. to prevent race conditions), we
+ # can set lazy to false and we await to job to complete.
+ add_task = self.mass.create_task(self.add(details))
+ if not lazy:
+ await add_task
+ return add_task.result()
+
+ return details
+
+ async def search(
+ self,
+ search_query: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ) -> list[ItemCls]:
+ """Search database or provider with given query."""
+ # create safe search string
+ search_query = search_query.replace("/", " ").replace("'", "")
+ if "database" in (provider_domain, provider_instance):
+ return [
+ self.item_cls.from_db_row(db_row)
+ for db_row in await self.mass.music.database.search(self.db_table, search_query)
+ ]
+
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov or ProviderFeature.SEARCH not in prov.supported_features:
+ return []
+ if not prov.library_supported(self.media_type):
+ # assume library supported also means that this mediatype is supported
+ return []
+
+ # prefer cache items (if any)
+ cache_key = f"{prov.instance_id}.search.{self.media_type.value}.{search_query}.{limit}"
+ if cache := await self.mass.cache.get(cache_key):
+ return [media_from_dict(x) for x in cache]
+ # no items in cache - get listing from provider
+ items = await prov.search(
+ search_query,
+ [self.media_type],
+ limit,
+ )
+ # store (serializable items) in cache
+ if not prov.domain.startswith("filesystem"): # do not cache filesystem results
+ self.mass.create_task(
+ self.mass.cache.set(cache_key, [x.to_dict() for x in items], expiration=86400 * 7)
+ )
+ return items
+
+ async def add_to_library(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> None:
+ """Add an item to the library."""
+ prov_item = await self.get_db_item_by_prov_id(
+ item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ if prov_item is None:
+ prov_item = await self.get_provider_item(item_id, provider_instance or provider_domain)
+ if prov_item.in_library is True:
+ return
+ # mark as favorite/library item on provider(s)
+ for prov_mapping in prov_item.provider_mappings:
+ if prov := self.mass.get_provider(prov_mapping.provider_instance):
+ if not prov.library_edit_supported(self.media_type):
+ continue
+ await prov.library_add(provider_instance.item_id, self.media_type)
+ # mark as library item in internal db if db item
+ if prov_item.provider == "database" and not prov_item.in_library:
+ prov_item.in_library = True
+ await self.set_db_library(prov_item.item_id, True)
+
+ async def remove_from_library(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> None:
+ """Remove item from the library."""
+ prov_item = await self.get_db_item_by_prov_id(
+ item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ if prov_item is None:
+ prov_item = await self.get_provider_item(item_id, provider_instance or provider_domain)
+ if prov_item.in_library is False:
+ return
+ # unmark as favorite/library item on provider(s)
+ for prov_mapping in prov_item.provider_mappings:
+ if prov := self.mass.get_provider(prov_mapping.provider_instance):
+ if not prov.library_edit_supported(self.media_type):
+ continue
+ await prov.library_remove(prov_mapping.item_id, self.media_type)
+ # unmark as library item in internal db if db item
+ if prov_item.provider == "database":
+ prov_item.in_library = False
+ await self.set_db_library(prov_item.item_id, False)
+
+ async def get_provider_mapping(self, item: ItemCls) -> tuple[str, str]:
+ """Return (first) provider and item id."""
+ if item.provider == "database":
+ # make sure we have a full object
+ item = await self.get_db_item(item.item_id)
+ for prefer_file in (True, False):
+ for prov_mapping in item.provider_mappings:
+ # returns the first provider that is available
+ if not prov_mapping.available:
+ continue
+ if prefer_file and not prov_mapping.provider_domain.startswith("filesystem"):
+ continue
+ if self.mass.get_provider(prov_mapping.provider_instance):
+ return (prov_mapping.provider_instance, prov_mapping.item_id)
+ return None, None
+
+ async def get_db_items_by_query(
+ self,
+ custom_query: str | None = None,
+ query_params: dict | None = None,
+ limit: int = 500,
+ offset: int = 0,
+ ) -> list[ItemCls]:
+ """Fetch MediaItem records from database given a custom query."""
+ return [
+ self.item_cls.from_db_row(db_row)
+ for db_row in await self.mass.music.database.get_rows_from_query(
+ custom_query, query_params, limit=limit, offset=offset
+ )
+ ]
+
+ async def get_db_item(self, item_id: int | str) -> ItemCls:
+ """Get record by id."""
+ match = {"item_id": int(item_id)}
+ if db_row := await self.mass.music.database.get_row(self.db_table, match):
+ return self.item_cls.from_db_row(db_row)
+ raise MediaNotFoundError(f"Album not found in database: {item_id}")
+
+ async def get_db_item_by_prov_id(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> ItemCls | None:
+ """Get the database item for the given provider_instance."""
+ assert (
+ provider_domain or provider_instance
+ ), "provider_domain or provider_instance must be supplied"
+ if "database" in (provider_domain, provider_instance):
+ return await self.get_db_item(item_id)
+ for item in await self.get_db_items_by_prov_id(
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ provider_item_ids=(item_id,),
+ ):
+ return item
+ return None
+
+ async def get_db_items_by_prov_id(
+ self,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ provider_item_ids: tuple[str] | None = None,
+ limit: int = 500,
+ offset: int = 0,
+ ) -> list[ItemCls]:
+ """Fetch all records from database for given provider."""
+ assert (
+ provider_domain or provider_instance
+ ), "provider_domain or provider_instance must be supplied"
+ if "database" in (provider_domain, provider_instance):
+ return await self.get_db_items_by_query(limit=limit, offset=offset)
+
+ # we use the separate provider_mappings table to perform quick lookups
+ # from provider id's to database id's because this is faster
+ # (and more compatible) than querying the provider_mappings json column
+ subquery = f"SELECT item_id FROM {DB_TABLE_PROVIDER_MAPPINGS} "
+ if provider_instance is not None:
+ subquery += f"WHERE provider_instance = '{provider_instance}'"
+ elif provider_domain is not None:
+ subquery += f"WHERE provider_domain = '{provider_domain}'"
+ if provider_item_ids is not None:
+ prov_ids = str(tuple(provider_item_ids))
+ if prov_ids.endswith(",)"):
+ prov_ids = prov_ids.replace(",)", ")")
+ subquery += f" AND provider_item_id in {prov_ids}"
+ query = f"SELECT * FROM {self.db_table} WHERE item_id in ({subquery})"
+ return await self.get_db_items_by_query(query, limit=limit, offset=offset)
+
+ async def iter_db_items_by_prov_id(
+ self,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ provider_item_ids: tuple[str] | None = None,
+ limit: int = 500,
+ offset: int = 0,
+ ) -> AsyncGenerator[ItemCls, None]:
+ """Iterate all records from database for given provider."""
+ limit: int = 500
+ offset: int = 0
+ while True:
+ next_items = await self.get_db_items_by_prov_id(
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ provider_item_ids=provider_item_ids,
+ limit=limit,
+ offset=offset,
+ )
+ for item in next_items:
+ yield item
+ if len(next_items) < limit:
+ break
+ offset += limit
+
+ async def set_db_library(self, item_id: int, in_library: bool) -> None:
+ """Set the in-library bool on a database item."""
+ match = {"item_id": item_id}
+ timestamp = int(time()) if in_library else 0
+ await self.mass.music.database.update(
+ self.db_table, match, {"in_library": in_library, "timestamp": timestamp}
+ )
+ db_item = await self.get_db_item(item_id)
+ self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, db_item.uri, db_item)
+
+ async def get_provider_item(
+ self,
+ item_id: str,
+ provider_domain_or_instance_id: str,
+ ) -> ItemCls:
+ """Return item details for the given provider item id."""
+ if provider_domain_or_instance_id == "database":
+ item = await self.get_db_item(item_id)
+ else:
+ provider = self.mass.get_provider(provider_domain_or_instance_id)
+ item = await provider.get_item(self.media_type, item_id)
+ if not item:
+ raise MediaNotFoundError(
+ f"{self.media_type.value}//{item_id} not found on provider {provider_domain_or_instance_id}" # noqa: E501
+ )
+ return item
+
+ async def remove_prov_mapping(self, item_id: int, provider_instance: str) -> None:
+ """Remove provider id(s) from item."""
+ try:
+ db_item = await self.get_db_item(item_id)
+ except MediaNotFoundError:
+ # edge case: already deleted / race condition
+ return
+
+ # update provider_mappings table
+ await self.mass.music.database.delete(
+ DB_TABLE_PROVIDER_MAPPINGS,
+ {
+ "media_type": self.media_type.value,
+ "item_id": int(item_id),
+ "provider_instance": provider_instance,
+ },
+ )
+
+ # update the item in db (provider_mappings column only)
+ db_item.provider_mappings = {
+ x for x in db_item.provider_mappings if x.provider_instance != provider_instance
+ }
+ match = {"item_id": item_id}
+ await self.mass.music.database.update(
+ self.db_table,
+ match,
+ {"provider_mappings": json_dumps(db_item.provider_mappings)},
+ )
+ self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, db_item.uri, db_item)
+
+ # NOTE: If the item has no providers left we leave an orphan item in the db
+ # to easily reinstate when a new provider attaches to it.
+
+ self.logger.debug("removed provider %s from item id %s", provider_instance, item_id)
+
+ async def delete_db_item(self, item_id: int, recursive: bool = False) -> None: # noqa: ARG002
+ """Delete record from the database."""
+ db_item = await self.get_db_item(item_id)
+ assert db_item, f"Item does not exist: {item_id}"
+ # delete item
+ await self.mass.music.database.delete(
+ self.db_table,
+ {"item_id": int(item_id)},
+ )
+ # update provider_mappings table
+ await self.mass.music.database.delete(
+ DB_TABLE_PROVIDER_MAPPINGS,
+ {"media_type": self.media_type.value, "item_id": int(item_id)},
+ )
+ # NOTE: this does not delete any references to this item in other records,
+ # this is handled/overridden in the mediatype specific controllers
+ self.mass.signal_event(EventType.MEDIA_ITEM_DELETED, db_item.uri, db_item)
+ self.logger.debug("deleted item with id %s from database", item_id)
+
+ async def dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ) -> list[Track]:
+ """Return a dynamic list of tracks based on the given item."""
+ ref_item = await self.get(item_id, provider_domain, provider_instance)
+ for prov_mapping in ref_item.provider_mappings:
+ prov = self.mass.get_provider(prov_mapping.provider_instance)
+ if not prov.available:
+ continue
+ if ProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
+ continue
+ return await self._get_provider_dynamic_tracks(
+ item_id=prov_mapping.item_id,
+ provider_domain=prov_mapping.provider_domain,
+ provider_instance=prov_mapping.provider_instance,
+ limit=limit,
+ )
+ # Fallback to the default implementation
+ return await self._get_dynamic_tracks(ref_item)
+
+ @abstractmethod
+ async def _get_provider_dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ) -> list[Track]:
+ """Generate a dynamic list of tracks based on the item's content."""
+
+ @abstractmethod
+ async def _get_dynamic_tracks(self, media_item: ItemCls, limit: int = 25) -> list[Track]:
+ """Get dynamic list of tracks for given item, fallback/default implementation."""
+
+ async def _set_provider_mappings(
+ self, item_id: int, provider_mappings: list[ProviderMapping]
+ ) -> None:
+ """Update the provider_items table for the media item."""
+ # clear all records first
+ await self.mass.music.database.delete(
+ DB_TABLE_PROVIDER_MAPPINGS,
+ {"media_type": self.media_type.value, "item_id": int(item_id)},
+ )
+ # add entries
+ for provider_mapping in provider_mappings:
+ await self.mass.music.database.insert_or_replace(
+ DB_TABLE_PROVIDER_MAPPINGS,
+ {
+ "media_type": self.media_type.value,
+ "item_id": item_id,
+ "provider_domain": provider_mapping.provider_domain,
+ "provider_instance": provider_mapping.provider_instance,
+ "provider_item_id": provider_mapping.item_id,
+ },
+ )
--- /dev/null
+"""Manage MediaItems of type Playlist."""
+from __future__ import annotations
+
+import random
+from time import time
+from typing import Any
+
+from music_assistant.common.helpers.json import json_dumps
+from music_assistant.common.helpers.uri import create_uri
+from music_assistant.common.models.enums import EventType, MediaType, ProviderFeature
+from music_assistant.common.models.errors import (
+ InvalidDataError,
+ MediaNotFoundError,
+ ProviderUnavailableError,
+ UnsupportedFeaturedException,
+)
+from music_assistant.common.models.media_items import Playlist, Track
+from music_assistant.constants import DB_TABLE_PLAYLISTS
+
+from .base import MediaControllerBase
+
+
+class PlaylistController(MediaControllerBase[Playlist]):
+ """Controller managing MediaItems of type Playlist."""
+
+ db_table = DB_TABLE_PLAYLISTS
+ media_type = MediaType.PLAYLIST
+ item_cls = Playlist
+
+ def __init__(self, *args, **kwargs):
+ """Initialize class."""
+ super().__init__(*args, **kwargs)
+ # register api handlers
+ self.mass.register_api_command("music/playlists", self.db_items)
+ self.mass.register_api_command("music/playlist", self.get)
+ self.mass.register_api_command("music/playlist/tracks", self.tracks)
+ self.mass.register_api_command("music/playlist/tracks/add", self.add_playlist_tracks)
+ self.mass.register_api_command("music/playlist/tracks/remove", self.remove_playlist_tracks)
+ self.mass.register_api_command("music/playlist/update", self.update_db_item)
+ self.mass.register_api_command("music/playlist/delete", self.delete_db_item)
+ self.mass.register_api_command("music/playlist/create", self.create)
+
+ async def tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> list[Track]:
+ """Return playlist tracks for the given provider playlist id."""
+ playlist = await self.get(item_id, provider_domain, provider_instance)
+ prov = next(x for x in playlist.provider_mappings)
+ return await self._get_provider_playlist_tracks(
+ prov.item_id,
+ provider_domain=prov.provider_domain,
+ provider_instance=prov.provider_instance,
+ cache_checksum=playlist.metadata.checksum,
+ )
+
+ async def add(self, item: Playlist) -> Playlist:
+ """Add playlist to local db and return the new database item."""
+ item.metadata.last_refresh = int(time())
+ await self.mass.metadata.get_playlist_metadata(item)
+ existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
+ if existing:
+ db_item = await self.update_db_item(existing.item_id, item)
+ else:
+ db_item = await self.add_db_item(item)
+ self.mass.signal_event(
+ EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED,
+ db_item.uri,
+ db_item,
+ )
+ return db_item
+
+ async def create(self, name: str, provider: str | None = None) -> Playlist:
+ """Create new playlist."""
+ # if provider is omitted, just pick first provider
+ if provider:
+ provider = self.mass.get_provider(provider)
+ else:
+ provider = next(
+ (
+ x
+ for x in self.mass.music.providers
+ if ProviderFeature.PLAYLIST_CREATE in x.supported_features
+ ),
+ None,
+ )
+ if provider is None:
+ raise ProviderUnavailableError(
+ "No provider available which allows playlists creation."
+ )
+
+ return await provider.create_playlist(name)
+
+ async def add_playlist_tracks(self, db_playlist_id: str, uris: list[str]) -> None:
+ """Add multiple tracks to playlist. Creates background tasks to process the action."""
+ playlist = await self.get_db_item(db_playlist_id)
+ if not playlist:
+ raise MediaNotFoundError(f"Playlist with id {db_playlist_id} not found")
+ if not playlist.is_editable:
+ raise InvalidDataError(f"Playlist {playlist.name} is not editable")
+ for uri in uris:
+ self.mass.create_task(self.add_playlist_track(db_playlist_id, uri))
+
+ async def add_playlist_track(self, db_playlist_id: str, track_uri: str) -> None:
+ """Add track to playlist - make sure we dont add duplicates."""
+ # we can only edit playlists that are in the database (marked as editable)
+ playlist = await self.get_db_item(db_playlist_id)
+ if not playlist:
+ raise MediaNotFoundError(f"Playlist with id {db_playlist_id} not found")
+ if not playlist.is_editable:
+ raise InvalidDataError(f"Playlist {playlist.name} is not editable")
+ # make sure we have recent full track details
+ track = await self.mass.music.get_item_by_uri(track_uri, lazy=False)
+ assert track.media_type == MediaType.TRACK
+ # a playlist can only have one provider (for now)
+ playlist_prov = next(iter(playlist.provider_mappings))
+ # grab all existing track ids in the playlist so we can check for duplicates
+ cur_playlist_track_ids = set()
+ count = 0
+ for item in await self.tracks(playlist_prov.item_id, playlist_prov.provider_domain):
+ count += 1
+ cur_playlist_track_ids.update(
+ {
+ i.item_id
+ for i in item.provider_mappings
+ if i.provider_instance == playlist_prov.provider_instance
+ }
+ )
+ # check for duplicates
+ for track_prov in track.provider_mappings:
+ if (
+ track_prov.provider_domain == playlist_prov.provider_domain
+ and track_prov.item_id in cur_playlist_track_ids
+ ):
+ raise InvalidDataError("Track already exists in playlist {playlist.name}")
+ # add track to playlist
+ # we can only add a track to a provider playlist if track is available on that provider
+ # a track can contain multiple versions on the same provider
+ # simply sort by quality and just add the first one (assuming track is still available)
+ track_id_to_add = None
+ for track_version in sorted(track.provider_mappings, key=lambda x: x.quality, reverse=True):
+ if not track.available:
+ continue
+ if playlist_prov.provider_domain.startswith("filesystem"):
+ # the file provider can handle uri's from all providers so simply add the uri
+ track_id_to_add = track_version.url or create_uri(
+ MediaType.TRACK,
+ track_version.provider_domain,
+ track_version.item_id,
+ )
+ break
+ if track_version.provider_domain == playlist_prov.provider_domain:
+ track_id_to_add = track_version.item_id
+ break
+ if not track_id_to_add:
+ raise MediaNotFoundError(
+ f"Track is not available on provider {playlist_prov.provider_domain}"
+ )
+ # actually add the tracks to the playlist on the provider
+ provider = self.mass.get_provider(playlist_prov.provider_instance)
+ await provider.add_playlist_tracks(playlist_prov.item_id, [track_id_to_add])
+ # invalidate cache by updating the checksum
+ await self.get(db_playlist_id, provider_domain="database", force_refresh=True)
+
+ async def remove_playlist_tracks(
+ self, db_playlist_id: str, positions_to_remove: tuple[int]
+ ) -> None:
+ """Remove multiple tracks from playlist."""
+ playlist = await self.get_db_item(db_playlist_id)
+ if not playlist:
+ raise MediaNotFoundError(f"Playlist with id {db_playlist_id} not found")
+ if not playlist.is_editable:
+ raise InvalidDataError(f"Playlist {playlist.name} is not editable")
+ for prov_mapping in playlist.provider_mappings:
+ provider = self.mass.get_provider(prov_mapping.provider_instance)
+ if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features:
+ self.logger.warning(
+ "Provider %s does not support editing playlists",
+ prov_mapping.provider_domain,
+ )
+ continue
+ await provider.remove_playlist_tracks(prov_mapping.item_id, positions_to_remove)
+ # invalidate cache by updating the checksum
+ await self.get(db_playlist_id, "database", force_refresh=True)
+
+ async def add_db_item(self, item: Playlist, overwrite_existing: bool = False) -> Playlist:
+ """Add a new record to the database."""
+ async with self._db_add_lock:
+ match = {"name": item.name, "owner": item.owner}
+ if cur_item := await self.mass.music.database.get_row(self.db_table, match):
+ # update existing
+ return await self.update_db_item(
+ cur_item["item_id"], item, overwrite=overwrite_existing
+ )
+
+ # insert new item
+ new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row())
+ item_id = new_item["item_id"]
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, item.provider_mappings)
+ self.logger.debug("added %s to database", item.name)
+ # return created object
+ return await self.get_db_item(item_id)
+
+ async def update_db_item(
+ self,
+ item_id: int,
+ item: Playlist,
+ overwrite: bool = False,
+ ) -> Playlist:
+ """Update Playlist record in the database."""
+ cur_item = await self.get_db_item(item_id)
+ if overwrite:
+ metadata = item.metadata
+ provider_mappings = item.provider_mappings
+ else:
+ metadata = cur_item.metadata.update(item.metadata)
+ provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
+
+ await self.mass.music.database.update(
+ self.db_table,
+ {"item_id": item_id},
+ {
+ # always prefer name/owner from updated item here
+ "name": item.name,
+ "sort_name": item.sort_name,
+ "owner": item.owner,
+ "is_editable": item.is_editable,
+ "metadata": json_dumps(metadata),
+ "provider_mappings": json_dumps(provider_mappings),
+ },
+ )
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, provider_mappings)
+ self.logger.debug("updated %s in database: %s", item.name, item_id)
+ return await self.get_db_item(item_id)
+
+ async def _get_provider_playlist_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ cache_checksum: Any = None,
+ ) -> list[Track]:
+ """Return album tracks for the given provider album id."""
+ provider = self.mass.get_provider(provider_instance or provider_domain)
+ if not provider:
+ return []
+ # prefer cache items (if any)
+ cache_key = f"{provider.instance_id}.playlist.{item_id}.tracks"
+ if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
+ return [Track.from_dict(x) for x in cache]
+ # no items in cache - get listing from provider
+ items = await provider.get_playlist_tracks(item_id)
+ # double check if position set
+ if items:
+ assert items[0].position is not None, "Playlist items require position to be set"
+ # store (serializable items) in cache
+ self.mass.create_task(
+ self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum)
+ )
+ return items
+
+ async def _get_provider_dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ):
+ """Generate a dynamic list of tracks based on the playlist content."""
+ provider = self.mass.get_provider(provider_instance or provider_domain)
+ if not provider or ProviderFeature.SIMILAR_TRACKS not in provider.supported_features:
+ return []
+ playlist_tracks = await self._get_provider_playlist_tracks(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ # filter out unavailable tracks
+ playlist_tracks = [x for x in playlist_tracks if x.available]
+ limit = min(limit, len(playlist_tracks))
+ # use set to prevent duplicates
+ final_items = set()
+ # to account for playlists with mixed content we grab suggestions from a few
+ # random playlist tracks to prevent getting too many tracks of one of the
+ # source playlist's genres.
+ while len(final_items) < limit:
+ # grab 5 random tracks from the playlist
+ base_tracks = random.sample(playlist_tracks, 5)
+ # add the source/base playlist tracks to the final list...
+ final_items.update(base_tracks)
+ # get 5 suggestions for one of the base tracks
+ base_track = next(x for x in base_tracks if x.available)
+ similar_tracks = await provider.get_similar_tracks(
+ prov_track_id=base_track.item_id, limit=5
+ )
+ final_items.update(x for x in similar_tracks if x.available)
+
+ # NOTE: In theory we can return a few more items than limit here
+ # Shuffle the final items list
+ return random.sample(final_items, len(final_items))
+
+ async def _get_dynamic_tracks(
+ self, media_item: Playlist, limit: int = 25 # noqa: ARG002
+ ) -> list[Track]:
+ """Get dynamic list of tracks for given item, fallback/default implementation."""
+ # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
+ raise UnsupportedFeaturedException(
+ "No Music Provider found that supports requesting similar tracks."
+ )
--- /dev/null
+"""Manage MediaItems of type Radio."""
+from __future__ import annotations
+
+import asyncio
+from time import time
+
+from music_assistant.common.helpers.json import json_dumps
+from music_assistant.common.models.enums import EventType, MediaType
+from music_assistant.common.models.media_items import Radio, Track
+from music_assistant.constants import DB_TABLE_RADIOS
+from music_assistant.server.helpers.compare import loose_compare_strings
+
+from .base import MediaControllerBase
+
+
+class RadioController(MediaControllerBase[Radio]):
+ """Controller managing MediaItems of type Radio."""
+
+ db_table = DB_TABLE_RADIOS
+ media_type = MediaType.RADIO
+ item_cls = Radio
+
+ def __init__(self, *args, **kwargs):
+ """Initialize class."""
+ super().__init__(*args, **kwargs)
+ # register api handlers
+ self.mass.register_api_command("music/radios", self.db_items)
+ self.mass.register_api_command("music/radio", self.get)
+ self.mass.register_api_command("music/radio/versions", self.versions)
+ self.mass.register_api_command("music/radio/update", self.update_db_item)
+ self.mass.register_api_command("music/radio/delete", self.delete_db_item)
+
+ async def versions(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> list[Radio]:
+ """Return all versions of a radio station we can find on all providers."""
+ assert provider_domain or provider_instance, "Provider type or ID must be specified"
+ radio = await self.get(item_id, provider_domain, provider_instance)
+ # perform a search on all provider(types) to collect all versions/variants
+ provider_domains = {prov.domain for prov in self.mass.music.providers}
+ all_versions = {
+ prov_item.item_id: prov_item
+ for prov_items in await asyncio.gather(
+ *[self.search(radio.name, provider_domain) for provider_domain in provider_domains]
+ )
+ for prov_item in prov_items
+ if loose_compare_strings(radio.name, prov_item.name)
+ }
+ # make sure that the 'base' version is included
+ for prov_version in radio.provider_mappings:
+ if prov_version.item_id in all_versions:
+ continue
+ radio_copy = Radio.from_dict(radio.to_dict())
+ radio_copy.item_id = prov_version.item_id
+ radio_copy.provider = prov_version.provider_domain
+ radio_copy.provider_mappings = {prov_version}
+ all_versions[prov_version.item_id] = radio_copy
+
+ # return the aggregated result
+ return all_versions.values()
+
+ async def add(self, item: Radio) -> Radio:
+ """Add radio to local db and return the new database item."""
+ item.metadata.last_refresh = int(time())
+ await self.mass.metadata.get_radio_metadata(item)
+ existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
+ if existing:
+ db_item = await self.update_db_item(existing.item_id, item)
+ else:
+ db_item = await self.add_db_item(item)
+ self.mass.signal_event(
+ EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED,
+ db_item.uri,
+ db_item,
+ )
+ return db_item
+
+ async def add_db_item(self, item: Radio, overwrite_existing: bool = False) -> Radio:
+ """Add a new item record to the database."""
+ assert item.provider_mappings
+ async with self._db_add_lock:
+ match = {"name": item.name}
+ if cur_item := await self.mass.music.database.get_row(self.db_table, match):
+ # update existing
+ return await self.update_db_item(
+ cur_item["item_id"], item, overwrite=overwrite_existing
+ )
+
+ # insert new item
+ new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row())
+ item_id = new_item["item_id"]
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, item.provider_mappings)
+ self.logger.debug("added %s to database", item.name)
+ # return created object
+ return await self.get_db_item(item_id)
+
+ async def update_db_item(
+ self,
+ item_id: int,
+ item: Radio,
+ overwrite: bool = False,
+ ) -> Radio:
+ """Update Radio record in the database."""
+ cur_item = await self.get_db_item(item_id)
+ if overwrite:
+ metadata = item.metadata
+ provider_mappings = item.provider_mappings
+ else:
+ metadata = cur_item.metadata.update(item.metadata)
+ provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
+
+ match = {"item_id": item_id}
+ await self.mass.music.database.update(
+ self.db_table,
+ match,
+ {
+ # always prefer name from updated item here
+ "name": item.name,
+ "sort_name": item.sort_name,
+ "metadata": json_dumps(metadata),
+ "provider_mappings": json_dumps(provider_mappings),
+ },
+ )
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, provider_mappings)
+ self.logger.debug("updated %s in database: %s", item.name, item_id)
+ return await self.get_db_item(item_id)
+
+ async def _get_provider_dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ) -> list[Track]:
+ """Generate a dynamic list of tracks based on the item's content."""
+ raise NotImplementedError("Dynamic tracks not supported for Radio MediaItem")
+
+ async def _get_dynamic_tracks(self, media_item: Radio, limit: int = 25) -> list[Track]:
+ """Get dynamic list of tracks for given item, fallback/default implementation."""
+ raise NotImplementedError("Dynamic tracks not supported for Radio MediaItem")
--- /dev/null
+"""Manage MediaItems of type Track."""
+from __future__ import annotations
+
+import asyncio
+
+from music_assistant.common.helpers.json import json_dumps
+from music_assistant.common.models.enums import EventType, MediaType, ProviderFeature
+from music_assistant.common.models.errors import MediaNotFoundError, UnsupportedFeaturedException
+from music_assistant.common.models.media_items import (
+ Album,
+ Artist,
+ ItemMapping,
+ Track,
+ TrackAlbumMapping,
+)
+from music_assistant.constants import DB_TABLE_TRACKS
+from music_assistant.server.helpers.compare import (
+ compare_artists,
+ compare_track,
+ loose_compare_strings,
+)
+
+from .base import MediaControllerBase
+
+
+class TracksController(MediaControllerBase[Track]):
+ """Controller managing MediaItems of type Track."""
+
+ db_table = DB_TABLE_TRACKS
+ media_type = MediaType.TRACK
+ item_cls = Track
+
+ def __init__(self, *args, **kwargs):
+ """Initialize class."""
+ super().__init__(*args, **kwargs)
+ # register api handlers
+ self.mass.register_api_command("music/tracks", self.db_items)
+ self.mass.register_api_command("music/track", self.get)
+ self.mass.register_api_command("music/track/versions", self.versions)
+ self.mass.register_api_command("music/track/update", self.update_db_item)
+ self.mass.register_api_command("music/track/delete", self.delete_db_item)
+
+ async def get(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ force_refresh: bool = False,
+ lazy: bool = True,
+ details: Track = None,
+ force_provider_item: bool = False,
+ ) -> Track:
+ """Return (full) details for a single media item."""
+ track = await super().get(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ force_refresh=force_refresh,
+ lazy=lazy,
+ details=details,
+ force_provider_item=force_provider_item,
+ )
+ # append full album details to full track item
+ if track.album:
+ try:
+ track.album = await self.mass.music.albums.get(
+ track.album.item_id,
+ track.album.provider,
+ lazy=True,
+ details=track.album,
+ )
+ except MediaNotFoundError:
+ # edge case where playlist track has invalid albumdetails
+ self.logger.warning("Unable to fetch album details %s", track.album.uri)
+ # append full artist details to full track item
+ full_artists = []
+ for artist in track.artists:
+ full_artists.append(
+ await self.mass.music.artists.get(
+ artist.item_id, artist.provider, lazy=True, details=artist
+ )
+ )
+ track.artists = full_artists
+ return track
+
+ async def add(self, item: Track) -> Track:
+ """Add track to local db and return the new database item."""
+ # make sure we have artists
+ assert item.artists
+ # grab additional metadata
+ await self.mass.metadata.get_track_metadata(item)
+ existing = await self.get_db_item_by_prov_id(item.item_id, item.provider)
+ if existing:
+ db_item = await self.update_db_item(existing.item_id, item)
+ else:
+ db_item = await self.add_db_item(item)
+ # also fetch same track on all providers (will also get other quality versions)
+ await self._match(db_item)
+ # return final db_item after all match/metadata actions
+ db_item = await self.get_db_item(db_item.item_id)
+ self.mass.signal_event(
+ EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED,
+ db_item.uri,
+ db_item,
+ )
+ return db_item
+
+ async def versions(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> list[Track]:
+ """Return all versions of a track we can find on all providers."""
+ assert provider_domain or provider_instance, "Provider type or ID must be specified"
+ track = await self.get(item_id, provider_domain or provider_instance)
+ # perform a search on all provider(types) to collect all versions/variants
+ provider_domains = {prov.domain for prov in self.mass.music.providers}
+ search_query = f"{track.artist.name} - {track.name}"
+ all_versions = {
+ prov_item.item_id: prov_item
+ for prov_items in await asyncio.gather(
+ *[
+ self.search(search_query, provider_domain)
+ for provider_domain in provider_domains
+ ]
+ )
+ for prov_item in prov_items
+ if loose_compare_strings(track.name, prov_item.name)
+ and compare_artists(prov_item.artists, track.artists, any_match=True)
+ }
+ # make sure that the 'base' version is included
+ for prov_version in track.provider_mappings:
+ if prov_version.item_id in all_versions:
+ continue
+ # grab full item here including album details etc
+ prov_track = await self.get_provider_item(
+ prov_version.item_id, prov_version.provider_instance
+ )
+ all_versions[prov_version.item_id] = prov_track
+
+ # return the aggregated result
+ return all_versions.values()
+
+ async def get_preview_url(self, provider_domain: str, item_id: str) -> str:
+ """Return url to short preview sample."""
+ track = await self.get_provider_item(item_id, provider_domain)
+ # prefer provider-provided preview
+ if preview := track.metadata.preview:
+ return preview
+ # fallback to a preview/sample hosted by our own webserver
+ return self.mass.streams.get_preview_url(provider_domain, item_id)
+
+ async def _match(self, db_track: Track) -> None:
+ """Try to find matching track on all providers for the provided (database) track_id.
+
+ This is used to link objects of different providers/qualities together.
+ """
+ if db_track.provider != "database":
+ return # Matching only supported for database items
+ for provider in self.mass.music.providers:
+ if ProviderFeature.SEARCH not in provider.supported_features:
+ continue
+ self.logger.debug(
+ "Trying to match track %s on provider %s", db_track.name, provider.name
+ )
+ match_found = False
+ for search_str in (
+ db_track.name,
+ f"{db_track.artists[0].name} - {db_track.name}",
+ f"{db_track.artists[0].name} {db_track.name}",
+ ):
+ if match_found:
+ break
+ search_result = await self.search(search_str, provider.domain)
+ for search_result_item in search_result:
+ if not search_result_item.available:
+ continue
+ if compare_track(search_result_item, db_track):
+ # 100% match, we can simply update the db with additional provider ids
+ match_found = True
+ await self.update_db_item(db_track.item_id, search_result_item)
+
+ if not match_found:
+ self.logger.debug(
+ "Could not find match for Track %s on provider %s",
+ db_track.name,
+ provider.name,
+ )
+
+ async def _get_provider_dynamic_tracks(
+ self,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 25,
+ ):
+ """Generate a dynamic list of tracks based on the track."""
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if not prov or ProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
+ return []
+ # Grab similar tracks from the music provider
+ similar_tracks = await prov.get_similar_tracks(prov_track_id=item_id, limit=limit)
+ return similar_tracks
+
+ async def _get_dynamic_tracks(
+ self, media_item: Track, limit: int = 25 # noqa: ARG002
+ ) -> list[Track]:
+ """Get dynamic list of tracks for given item, fallback/default implementation."""
+ # TODO: query metadata provider(s) to get similar tracks (or tracks from similar artists)
+ raise UnsupportedFeaturedException(
+ "No Music Provider found that supports requesting similar tracks."
+ )
+
+ async def add_db_item(self, item: Track, overwrite_existing: bool = False) -> Track:
+ """Add a new item record to the database."""
+ assert isinstance(item, Track), "Not a full Track object"
+ assert item.artists, "Track is missing artist(s)"
+ assert item.provider_mappings, "Track is missing provider id(s)"
+ async with self._db_add_lock:
+ cur_item = None
+
+ # always try to grab existing item by external_id
+ if item.musicbrainz_id:
+ match = {"musicbrainz_id": item.musicbrainz_id}
+ cur_item = await self.mass.music.database.get_row(self.db_table, match)
+ for isrc in item.isrcs:
+ match = {"isrc": isrc}
+ cur_item = await self.mass.music.database.get_row(self.db_table, match)
+ if not cur_item:
+ # fallback to matching
+ match = {"sort_name": item.sort_name}
+ for row in await self.mass.music.database.get_rows(self.db_table, match):
+ row_track = Track.from_db_row(row)
+ if compare_track(row_track, item):
+ cur_item = row_track
+ break
+ if cur_item:
+ # update existing
+ return await self.update_db_item(
+ cur_item.item_id, item, overwrite=overwrite_existing
+ )
+
+ # no existing match found: insert new item
+ track_artists = await self._get_track_artists(item)
+ track_albums = await self._get_track_albums(item, overwrite=overwrite_existing)
+ sort_artist = track_artists[0].sort_name if track_artists else ""
+ sort_album = track_albums[0].sort_name if track_albums else ""
+ new_item = await self.mass.music.database.insert(
+ self.db_table,
+ {
+ **item.to_db_row(),
+ "artists": json_dumps(track_artists),
+ "albums": json_dumps(track_albums),
+ "sort_artist": sort_artist,
+ "sort_album": sort_album,
+ },
+ )
+ item_id = new_item["item_id"]
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, item.provider_mappings)
+ # return created object
+ self.logger.debug("added %s to database: %s", item.name, item_id)
+ return await self.get_db_item(item_id)
+
+ async def update_db_item(
+ self,
+ item_id: int,
+ item: Track,
+ overwrite: bool = False,
+ ) -> Track:
+ """Update Track record in the database, merging data."""
+ cur_item = await self.get_db_item(item_id)
+
+ if overwrite:
+ metadata = item.metadata
+ provider_mappings = item.provider_mappings
+ metadata.last_refresh = None
+ # we store a mapping to artists/albums on the item for easier access/listings
+ track_artists = await self._get_track_artists(item, overwrite=True)
+ track_albums = await self._get_track_albums(item, overwrite=True)
+ else:
+ metadata = cur_item.metadata.update(item.metadata, "file" in item.provider)
+ provider_mappings = {*cur_item.provider_mappings, *item.provider_mappings}
+ track_artists = await self._get_track_artists(cur_item, item)
+ track_albums = await self._get_track_albums(cur_item, item)
+
+ await self.mass.music.database.update(
+ self.db_table,
+ {"item_id": item_id},
+ {
+ "name": item.name if overwrite else cur_item.name,
+ "sort_name": item.sort_name if overwrite else cur_item.sort_name,
+ "version": item.version if overwrite else cur_item.version,
+ "duration": item.duration if overwrite else cur_item.duration,
+ "artists": json_dumps(track_artists),
+ "albums": json_dumps(track_albums),
+ "metadata": json_dumps(metadata),
+ "provider_mappings": json_dumps(provider_mappings),
+ "isrc": item.isrc or cur_item.isrc,
+ },
+ )
+ # update/set provider_mappings table
+ await self._set_provider_mappings(item_id, provider_mappings)
+ self.logger.debug("updated %s in database: %s", item.name, item_id)
+ return await self.get_db_item(item_id)
+
+ async def _get_track_artists(
+ self,
+ base_track: Track,
+ upd_track: Track | None = None,
+ overwrite: bool = False,
+ ) -> list[ItemMapping]:
+ """Extract all (unique) artists of track as ItemMapping."""
+ track_artists = upd_track.artists if upd_track and upd_track.artists else base_track.artists
+ # use intermediate set to clear out duplicates
+ return list({await self._get_artist_mapping(x, overwrite) for x in track_artists})
+
+ async def _get_track_albums(
+ self,
+ base_track: Track,
+ upd_track: Track | None = None,
+ overwrite: bool = False,
+ ) -> list[TrackAlbumMapping]:
+ """Extract all (unique) albums of track as TrackAlbumMapping."""
+ track_albums: list[TrackAlbumMapping] = []
+ # existing TrackAlbumMappings are starting point
+ if base_track.albums:
+ track_albums = base_track.albums
+ elif upd_track and upd_track.albums:
+ track_albums = upd_track.albums
+ # append update item album if needed
+ if upd_track and upd_track.album:
+ mapping = await self._get_album_mapping(upd_track.album, overwrite=overwrite)
+ mapping = TrackAlbumMapping.from_dict(
+ {
+ **mapping.to_dict(),
+ "disc_number": upd_track.disc_number,
+ "track_number": upd_track.track_number,
+ }
+ )
+ if mapping not in track_albums:
+ track_albums.append(mapping)
+ # append base item album if needed
+ elif base_track and base_track.album:
+ mapping = await self._get_album_mapping(base_track.album, overwrite=overwrite)
+ mapping = TrackAlbumMapping.from_dict(
+ {
+ **mapping.to_dict(),
+ "disc_number": base_track.disc_number,
+ "track_number": base_track.track_number,
+ }
+ )
+ if mapping not in track_albums:
+ track_albums.append(mapping)
+
+ return track_albums
+
+ async def _get_album_mapping(
+ self,
+ album: Album | ItemMapping,
+ overwrite: bool = False,
+ ) -> ItemMapping:
+ """Extract (database) album as ItemMapping."""
+ if album.provider == "database":
+ if isinstance(album, ItemMapping):
+ return album
+ return ItemMapping.from_item(album)
+
+ if overwrite:
+ db_album = await self.mass.music.albums.add_db_item(album, overwrite_existing=True)
+
+ if db_album := await self.mass.music.albums.get_db_item_by_prov_id(
+ album.item_id, provider_domain=album.provider
+ ):
+ return ItemMapping.from_item(db_album)
+
+ db_album = await self.mass.music.albums.add_db_item(album, overwrite_existing=overwrite)
+ return ItemMapping.from_item(db_album)
+
+ async def _get_artist_mapping(
+ self, artist: Artist | ItemMapping, overwrite: bool = False
+ ) -> ItemMapping:
+ """Extract (database) track artist as ItemMapping."""
+ if artist.provider == "database":
+ if isinstance(artist, ItemMapping):
+ return artist
+ return ItemMapping.from_item(artist)
+
+ if overwrite:
+ artist = await self.mass.music.artists.add_db_item(artist, overwrite_existing=True)
+
+ if db_artist := await self.mass.music.artists.get_db_item_by_prov_id(
+ artist.item_id, provider_domain=artist.provider
+ ):
+ return ItemMapping.from_item(db_artist)
+
+ db_artist = await self.mass.music.artists.add_db_item(artist)
+ return ItemMapping.from_item(db_artist)
--- /dev/null
+"""All logic for metadata retrieval."""
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from base64 import b64encode
+from random import shuffle
+from time import time
+from typing import TYPE_CHECKING
+
+import aiofiles
+from aiohttp import web
+
+from music_assistant.common.models.enums import (
+ ImageType,
+ MediaType,
+ ProviderFeature,
+ ProviderType,
+)
+from music_assistant.common.models.media_items import (
+ Album,
+ Artist,
+ ItemMapping,
+ MediaItemImage,
+ MediaItemType,
+ Playlist,
+ Radio,
+ Track,
+)
+from music_assistant.constants import ROOT_LOGGER_NAME
+from music_assistant.server.helpers.images import create_collage, get_image_thumb
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+ from music_assistant.server.models.metadata_provider import MetadataProvider
+
+LOGGER = logging.getLogger(f"{ROOT_LOGGER_NAME}.metadata")
+
+
+class MetaDataController:
+ """Several helpers to search and store metadata for mediaitems."""
+
+ def __init__(self, mass: MusicAssistant) -> None:
+ """Initialize class."""
+ self.mass = mass
+ self.cache = mass.cache
+ self._pref_lang: str | None = None
+ self.scan_busy: bool = False
+
+ async def setup(self) -> None:
+ """Async initialize of module."""
+ self.mass.webapp.router.add_get("/imageproxy", self._handle_imageproxy)
+
+ async def close(self) -> None:
+ """Handle logic on server stop."""
+
+ @property
+ def providers(self) -> list[MetadataProvider]:
+ """Return all loaded/running MetadataProviders."""
+ return self.mass.get_providers(ProviderType.METADATA) # type: ignore[return-value]
+
+ @property
+ def preferred_language(self) -> str:
+ """Return preferred language for metadata as 2 letter country code (uppercase).
+
+ Defaults to English (EN).
+ """
+ return self._pref_lang or "EN"
+
+ @preferred_language.setter
+ def preferred_language(self, lang: str) -> None:
+ """Set preferred language to 2 letter country code.
+
+ Can only be set once.
+ """
+ if self._pref_lang is None:
+ self._pref_lang = lang.upper()
+
+ def start_scan(self) -> None:
+ """Start background scan for missing Artist metadata."""
+
+ async def scan_artist_metadata():
+ """Background task that scans for artists missing metadata on filesystem providers."""
+ if self.scan_busy:
+ return
+
+ LOGGER.info("Start scan for missing artist metadata")
+ self.scan_busy = True
+ for prov in self.mass.music.providers:
+ if not prov.is_file():
+ continue
+ async for artist in self.mass.music.artists.iter_db_items_by_prov_id(
+ provider_instance=prov.instance_id
+ ):
+ if artist.metadata.last_refresh is not None:
+ continue
+ # simply grabbing the full artist will trigger a full fetch
+ await self.mass.music.artists.get(artist.item_id, artist.provider, lazy=False)
+ # this is slow on purpose to not cause stress on the metadata providers
+ await asyncio.sleep(30)
+ self.scan_busy = False
+ LOGGER.info("Finished scan for missing artist metadata")
+
+ self.mass.create_task(scan_artist_metadata)
+
+ async def get_artist_metadata(self, artist: Artist) -> None:
+ """Get/update rich metadata for an artist."""
+ # set timestamp, used to determine when this function was last called
+ artist.metadata.last_refresh = int(time())
+
+ if not artist.musicbrainz_id:
+ artist.musicbrainz_id = await self.get_artist_musicbrainz_id(artist)
+
+ if not artist.musicbrainz_id:
+ return
+
+ # collect metadata from all providers
+ for provider in self.providers:
+ if ProviderFeature.ARTIST_METADATA not in provider.supported_features:
+ continue
+ if metadata := await provider.get_artist_metadata(artist):
+ artist.metadata.update(metadata)
+ LOGGER.debug(
+ "Fetched metadata for Artist %s on provider %s",
+ artist.name,
+ provider.name,
+ )
+
+ async def get_album_metadata(self, album: Album) -> None:
+ """Get/update rich metadata for an album."""
+ # set timestamp, used to determine when this function was last called
+ album.metadata.last_refresh = int(time())
+ # ensure the album has a musicbrainz id or artist
+ if not (album.musicbrainz_id or album.artist):
+ return
+ # collect metadata from all providers
+ for provider in self.providers:
+ if ProviderFeature.ALBUM_METADATA not in provider.supported_features:
+ continue
+ if metadata := await provider.get_album_metadata(album):
+ album.metadata.update(metadata)
+ LOGGER.debug(
+ "Fetched metadata for Album %s on provider %s",
+ album.name,
+ provider.name,
+ )
+
+ async def get_track_metadata(self, track: Track) -> None:
+ """Get/update rich metadata for a track."""
+ # set timestamp, used to determine when this function was last called
+ track.metadata.last_refresh = int(time())
+
+ if not (track.album and track.artists):
+ return
+ # collect metadata from all providers
+ for provider in self.providers:
+ if ProviderFeature.TRACK_METADATA not in provider.supported_features:
+ continue
+ if metadata := await provider.get_track_metadata(track):
+ track.metadata.update(metadata)
+ LOGGER.debug(
+ "Fetched metadata for Track %s on provider %s",
+ track.name,
+ provider.name,
+ )
+
+ async def get_playlist_metadata(self, playlist: Playlist) -> None:
+ """Get/update rich metadata for a playlist."""
+ # set timestamp, used to determine when this function was last called
+ playlist.metadata.last_refresh = int(time())
+ # retrieve genres from tracks
+ # TODO: retrieve style/mood ?
+ playlist.metadata.genres = set()
+ image_urls = set()
+ for track in await self.mass.music.playlists.tracks(playlist.item_id, playlist.provider):
+ if not playlist.image and track.image:
+ image_urls.add(track.image.url)
+ if track.media_type != MediaType.TRACK:
+ # filter out radio items
+ continue
+ assert isinstance(track, Track)
+ assert isinstance(track.album, Album)
+ if track.metadata.genres:
+ playlist.metadata.genres.update(track.metadata.genres)
+ elif track.album and track.album.metadata.genres:
+ playlist.metadata.genres.update(track.album.metadata.genres)
+ # create collage thumb/fanart from playlist tracks
+ if image_urls:
+ img_path = f"playlist.{playlist.provider}.{playlist.item_id}.png"
+ img_path = os.path.join(self.mass.storage_path, img_path)
+ img_data = await create_collage(self.mass, list(image_urls))
+ async with aiofiles.open(img_path, "wb") as _file:
+ await _file.write(img_data)
+ playlist.metadata.images = [MediaItemImage(ImageType.THUMB, img_path, True)]
+
+ async def get_radio_metadata(self, radio: Radio) -> None:
+ """Get/update rich metadata for a radio station."""
+ # NOTE: we do not have any metadata for radio so consider this future proofing ;-)
+ radio.metadata.last_refresh = int(time())
+
+ async def get_artist_musicbrainz_id(self, artist: Artist) -> str | None:
+ """Fetch musicbrainz id by performing search using the artist name, albums and tracks."""
+ ref_albums = await self.mass.music.artists.albums(artist=artist)
+ ref_tracks = await self.mass.music.artists.tracks(artist=artist)
+
+ # randomize providers so average the load
+ providers = self.providers
+ shuffle(providers)
+
+ # try all providers one by one until we have a match
+ for provider in providers:
+ if ProviderFeature.GET_ARTIST_MBID not in provider.supported_features:
+ continue
+ if musicbrainz_id := await provider.get_musicbrainz_artist_id(
+ artist, ref_albums=ref_albums, ref_tracks=ref_tracks
+ ):
+ LOGGER.debug(
+ "Fetched MusicBrainz ID for Artist %s on provider %s",
+ artist.name,
+ provider.name,
+ )
+ return musicbrainz_id
+
+ # lookup failed
+ ref_albums_str = "/".join(x.name for x in ref_albums) or "none"
+ ref_tracks_str = "/".join(x.name for x in ref_tracks) or "none"
+ LOGGER.info(
+ "Unable to get musicbrainz ID for artist %s\n"
+ " - using lookup-album(s): %s\n"
+ " - using lookup-track(s): %s\n",
+ artist.name,
+ ref_albums_str,
+ ref_tracks_str,
+ )
+ return None
+
+ async def get_image_data_for_item(
+ self,
+ media_item: MediaItemType,
+ img_type: ImageType = ImageType.THUMB,
+ size: int = 0,
+ ) -> bytes | None:
+ """Get image data for given MedaItem."""
+ img_path = await self.get_image_url_for_item(
+ media_item=media_item,
+ img_type=img_type,
+ allow_local=True,
+ local_as_base64=False,
+ )
+ if not img_path:
+ return None
+ return await self.get_thumbnail(img_path, size)
+
+ async def get_image_url_for_item(
+ self,
+ media_item: MediaItemType,
+ img_type: ImageType = ImageType.THUMB,
+ allow_local: bool = True,
+ local_as_base64: bool = False,
+ ) -> str | None:
+ """Get url to image for given media media_item."""
+ if not media_item:
+ return None
+ if isinstance(media_item, ItemMapping):
+ media_item = await self.mass.music.get_item_by_uri(media_item.uri)
+ if media_item and media_item.metadata.images:
+ for img in media_item.metadata.images:
+ if img.type != img_type:
+ continue
+ if img.is_file and not allow_local:
+ continue
+ if img.is_file and local_as_base64:
+ # return base64 string of the image (compatible with browsers)
+ return await self.get_thumbnail(img.url, base64=True)
+ return img.url
+
+ # retry with track's album
+ if media_item.media_type == MediaType.TRACK and media_item.album:
+ return await self.get_image_url_for_item(
+ media_item.album, img_type, allow_local, local_as_base64
+ )
+
+ # try artist instead for albums
+ if media_item.media_type == MediaType.ALBUM and media_item.artist:
+ return await self.get_image_url_for_item(
+ media_item.artist, img_type, allow_local, local_as_base64
+ )
+
+ # last resort: track artist(s)
+ if media_item.media_type == MediaType.TRACK and media_item.artists:
+ for artist in media_item.artists:
+ return await self.get_image_url_for_item(
+ artist, img_type, allow_local, local_as_base64
+ )
+
+ return None
+
+ async def get_thumbnail(
+ self, path: str, size: int | None = None, base64: bool = False
+ ) -> bytes | str:
+ """Get/create thumbnail image for path (image url or local path)."""
+ thumbnail = await get_image_thumb(self.mass, path, size)
+ if base64:
+ enc_image = b64encode(thumbnail).decode()
+ thumbnail = f"data:image/png;base64,{enc_image}"
+ return thumbnail
+
+ async def _handle_imageproxy(self, request: web.Request) -> web.Response:
+ """Handle request for image proxy."""
+ path = request.query["path"]
+ size = int(request.query.get("size", "0"))
+ image_data = await self.get_thumbnail(path, size)
+ # we set the cache header to 1 year (forever)
+ # the client can use the checksum value to refresh when content changes
+ return web.Response(
+ body=image_data,
+ headers={"Cache-Control": "max-age=31536000"},
+ content_type="image/png",
+ )
--- /dev/null
+"""MusicController: Orchestrates all data from music providers and sync to internal database."""
+from __future__ import annotations
+
+import asyncio
+import itertools
+import logging
+import statistics
+from typing import TYPE_CHECKING
+
+from music_assistant.common.helpers.datetime import utc_timestamp
+from music_assistant.common.helpers.uri import parse_uri
+from music_assistant.common.models.enums import EventType, MediaType, ProviderFeature, ProviderType
+from music_assistant.common.models.errors import MusicAssistantError
+from music_assistant.common.models.media_items import (
+ BrowseFolder,
+ MediaItem,
+ MediaItemType,
+ media_from_dict,
+)
+from music_assistant.common.models.provider import SyncTask
+from music_assistant.constants import (
+ CONF_DB_LIBRARY,
+ DB_TABLE_ALBUMS,
+ DB_TABLE_ARTISTS,
+ DB_TABLE_PLAYLISTS,
+ DB_TABLE_PLAYLOG,
+ DB_TABLE_PROVIDER_MAPPINGS,
+ DB_TABLE_RADIOS,
+ DB_TABLE_SETTINGS,
+ DB_TABLE_TRACK_LOUDNESS,
+ DB_TABLE_TRACKS,
+ DEFAULT_DB_LIBRARY,
+ ROOT_LOGGER_NAME,
+ SCHEMA_VERSION,
+)
+from music_assistant.server.helpers.api import api_command
+from music_assistant.server.helpers.database import DatabaseConnection
+from music_assistant.server.models.music_provider import MusicProvider
+
+from .media.albums import AlbumsController
+from .media.artists import ArtistsController
+from .media.playlists import PlaylistController
+from .media.radio import RadioController
+from .media.tracks import TracksController
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+LOGGER = logging.getLogger(f"{ROOT_LOGGER_NAME}.music")
+
+
+class MusicController:
+ """Several helpers around the musicproviders."""
+
+ database: DatabaseConnection | None = None
+
+ def __init__(self, mass: MusicAssistant):
+ """Initialize class."""
+ self.mass = mass
+ self.artists = ArtistsController(mass)
+ self.albums = AlbumsController(mass)
+ self.tracks = TracksController(mass)
+ self.radio = RadioController(mass)
+ self.playlists = PlaylistController(mass)
+ self.in_progress_syncs: list[SyncTask] = []
+
+ async def setup(self):
+ """Async initialize of module."""
+ # setup library database
+ await self._setup_database()
+
+ async def close(self) -> None:
+ """Cleanup on exit."""
+
+ @property
+ def providers(self) -> list[MusicProvider]:
+ """Return all loaded/running MusicProviders (instances)."""
+ return self.mass.get_providers(ProviderType.MUSIC)
+
+ @api_command("music/sync")
+ async def start_sync(
+ self,
+ media_types: list[MediaType] | None = None,
+ providers: list[str] | None = None,
+ ) -> None:
+ """Start running the sync of (all or selected) musicproviders.
+
+ media_types: only sync these media types. None for all.
+ providers: only sync these provider instances. None for all.
+ """
+ if media_types is None:
+ media_types = MediaType.ALL
+ if providers is None:
+ providers = [x.instance_id for x in self.providers]
+
+ for provider in self.providers:
+ if provider.instance_id not in providers:
+ continue
+ self._start_provider_sync(provider.instance_id, media_types)
+ # trgger metadata scan after provider sync completed
+ self.mass.metadata.start_scan()
+
+ @api_command("music/synctasks")
+ def get_running_sync_tasks(self) -> list[SyncTask]:
+ """Return list with providers that are currently syncing."""
+ return self.in_progress_syncs
+
+ @api_command("music/search")
+ async def search(
+ self,
+ search_query: str,
+ media_types: list[MediaType] = MediaType.ALL,
+ limit: int = 10,
+ ) -> list[MediaItemType]:
+ """Perform global search for media items on all providers.
+
+ :param search_query: Search query.
+ :param media_types: A list of media_types to include.
+ :param limit: number of items to return in the search (per type).
+ """
+ # include results from all music providers
+ provider_instances = (item.instance_id for item in self.providers)
+ # TODO: sort by name and filter out duplicates ?
+ return list(
+ itertools.chain.from_iterable(
+ await asyncio.gather(
+ *[
+ self.search_provider(
+ search_query,
+ media_types,
+ provider_instance=provider_instance,
+ limit=limit,
+ )
+ for provider_instance in provider_instances
+ ]
+ )
+ )
+ )
+
+ async def search_provider(
+ self,
+ search_query: str,
+ media_types: list[MediaType] = MediaType.ALL,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ limit: int = 10,
+ ) -> list[MediaItemType]:
+ """Perform search on given provider.
+
+ :param search_query: Search query
+ :param provider_domain: domain of the provider to perform the search on.
+ :param provider_instance: instance id of the provider to perform the search on.
+ :param media_types: A list of media_types to include. All types if None.
+ :param limit: number of items to return in the search (per type).
+ """
+ assert provider_domain or provider_instance, "Provider needs to be supplied"
+ prov = self.mass.get_provider(provider_instance or provider_domain)
+ if ProviderFeature.SEARCH not in prov.supported_features:
+ return []
+
+ # create safe search string
+ search_query = search_query.replace("/", " ").replace("'", "")
+
+ # prefer cache items (if any)
+ cache_key = f"{prov.instance_id}.search.{search_query}.{limit}"
+ cache_key += "".join(x for x in media_types)
+
+ if cache := await self.mass.cache.get(cache_key):
+ return [media_from_dict(x) for x in cache]
+ # no items in cache - get listing from provider
+ items = await prov.search(
+ search_query,
+ media_types,
+ limit,
+ )
+ # store (serializable items) in cache
+ self.mass.create_task(
+ self.mass.cache.set(cache_key, [x.to_dict() for x in items], expiration=86400 * 7)
+ )
+ return items
+
+ @api_command("music/browse")
+ async def browse(self, path: str | None = None) -> BrowseFolder:
+ """Browse Music providers."""
+ # root level; folder per provider
+ if not path or path == "root":
+ return BrowseFolder(
+ item_id="root",
+ provider="database",
+ path="root",
+ label="browse",
+ name="",
+ items=[
+ BrowseFolder(
+ item_id="root",
+ provider=prov.domain,
+ path=f"{prov.instance_id}://",
+ name=prov.name,
+ )
+ for prov in self.providers
+ if ProviderFeature.BROWSE in prov.supported_features
+ ],
+ )
+ # provider level
+ provider_instance = path.split("://", 1)[0]
+ prov = self.mass.get_provider(provider_instance)
+ return await prov.browse(path)
+
+ @api_command("music/item_by_uri")
+ async def get_item_by_uri(
+ self, uri: str, force_refresh: bool = False, lazy: bool = True
+ ) -> MediaItemType:
+ """Fetch MediaItem by uri."""
+ media_type, provider_domain, item_id = parse_uri(uri)
+ return await self.get_item(
+ media_type=media_type,
+ item_id=item_id,
+ provider_domain=provider_domain,
+ force_refresh=force_refresh,
+ lazy=lazy,
+ )
+
+ @api_command("music/item")
+ async def get_item(
+ self,
+ media_type: MediaType,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ force_refresh: bool = False,
+ lazy: bool = True,
+ ) -> MediaItemType:
+ """Get single music item by id and media type."""
+ assert (
+ provider_domain or provider_instance
+ ), "provider_domain or provider_instance must be supplied"
+ if "url" in (provider_domain, provider_instance):
+ # handle special case of 'URL' MusicProvider which allows us to play regular url's
+ return await self.mass.get_provider("url").parse_item(item_id)
+ ctrl = self.get_controller(media_type)
+ return await ctrl.get(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ force_refresh=force_refresh,
+ lazy=lazy,
+ )
+
+ @api_command("music/library/add")
+ async def add_to_library(
+ self,
+ media_type: MediaType,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> None:
+ """Add an item to the library."""
+ ctrl = self.get_controller(media_type)
+ await ctrl.add_to_library(
+ item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+
+ @api_command("music/library/add_items")
+ async def add_items_to_library(self, items: list[str | MediaItemType]) -> None:
+ """Add multiple items to the library (provide uri or MediaItem)."""
+ tasks = []
+ for item in items:
+ if isinstance(item, str):
+ item = await self.get_item_by_uri(item) # noqa: PLW2901
+ tasks.append(
+ self.mass.create_task(
+ self.add_to_library(
+ media_type=item.media_type,
+ item_id=item.item_id,
+ provider_domain=item.provider,
+ )
+ )
+ )
+ await asyncio.gather(*tasks)
+
+ @api_command("music/library/remove")
+ async def remove_from_library(
+ self,
+ media_type: MediaType,
+ item_id: str,
+ provider_domain: str | None = None,
+ provider_instance: str | None = None,
+ ) -> None:
+ """Remove item from the library."""
+ ctrl = self.get_controller(media_type)
+ await ctrl.remove_from_library(
+ item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+
+ @api_command("music/library/remove_items")
+ async def remove_items_from_library(self, items: list[str | MediaItemType]) -> None:
+ """Remove multiple items from the library (provide uri or MediaItem)."""
+ tasks = []
+ for item in items:
+ if isinstance(item, str):
+ item = await self.get_item_by_uri(item) # noqa: PLW2901
+ tasks.append(
+ self.mass.create_task(
+ self.remove_from_library(
+ media_type=item.media_type,
+ item_id=item.item_id,
+ provider_domain=item.provider,
+ )
+ )
+ )
+ await asyncio.gather(*tasks)
+
+ @api_command("music/delete_db_item")
+ async def delete_db_item(
+ self, media_type: MediaType, db_item_id: str | int, recursive: bool = False
+ ) -> None:
+ """Remove item from the database."""
+ ctrl = self.get_controller(media_type)
+ await ctrl.delete_db_item(db_item_id, recursive)
+
+ async def refresh_items(self, items: list[MediaItem]) -> None:
+ """Refresh MediaItems to force retrieval of full info and matches.
+
+ Creates background tasks to process the action.
+ """
+ for media_item in items:
+ self.mass.create_task(self.refresh_item(media_item))
+
+ async def refresh_item(
+ self,
+ media_item: MediaItem,
+ ):
+ """Try to refresh a mediaitem by requesting it's full object or search for substitutes."""
+ try:
+ return await self.get_item(
+ media_item.media_type,
+ media_item.item_id,
+ provider_domain=media_item.provider,
+ force_refresh=True,
+ lazy=False,
+ )
+ except MusicAssistantError:
+ pass
+
+ for item in await self.search(media_item.name, [media_item.media_type], 20):
+ if item.available:
+ await self.get_item(item.media_type, item.item_id, item.provider, lazy=False)
+ return None
+
+ async def set_track_loudness(self, item_id: str, provider_domain: str, loudness: int):
+ """List integrated loudness for a track in db."""
+ await self.database.insert(
+ DB_TABLE_TRACK_LOUDNESS,
+ {"item_id": item_id, "provider": provider_domain, "loudness": loudness},
+ allow_replace=True,
+ )
+
+ async def get_track_loudness(self, item_id: str, provider_domain: str) -> float | None:
+ """Get integrated loudness for a track in db."""
+ if result := await self.database.get_row(
+ DB_TABLE_TRACK_LOUDNESS,
+ {
+ "item_id": item_id,
+ "provider": provider_domain,
+ },
+ ):
+ return result["loudness"]
+ return None
+
+ async def get_provider_loudness(self, provider_domain: str) -> float | None:
+ """Get average integrated loudness for tracks of given provider."""
+ all_items = []
+ if provider_domain == "url":
+ # this is not a very good idea for random urls
+ return None
+ for db_row in await self.database.get_rows(
+ DB_TABLE_TRACK_LOUDNESS,
+ {
+ "provider": provider_domain,
+ },
+ ):
+ all_items.append(db_row["loudness"])
+ if all_items:
+ return statistics.fmean(all_items)
+ return None
+
+ async def mark_item_played(self, item_id: str, provider_domain: str):
+ """Mark item as played in playlog."""
+ timestamp = utc_timestamp()
+ await self.database.insert(
+ DB_TABLE_PLAYLOG,
+ {
+ "item_id": item_id,
+ "provider": provider_domain,
+ "timestamp": timestamp,
+ },
+ allow_replace=True,
+ )
+
+ async def library_add_items(self, items: list[MediaItem]) -> None:
+ """Add media item(s) to the library.
+
+ Creates background tasks to process the action.
+ """
+ for media_item in items:
+ self.mass.create_task(
+ self.add_to_library(media_item.media_type, media_item.item_id, media_item.provider)
+ )
+
+ async def library_remove_items(self, items: list[MediaItem]) -> None:
+ """Remove media item(s) from the library.
+
+ Creates background tasks to process the action.
+ """
+ for media_item in items:
+ self.mass.create_task(
+ self.remove_from_library(
+ media_item.media_type, media_item.item_id, media_item.provider
+ )
+ )
+
+ def get_controller(
+ self, media_type: MediaType
+ ) -> (
+ ArtistsController
+ | AlbumsController
+ | TracksController
+ | RadioController
+ | PlaylistController
+ ): # noqa: E501
+ """Return controller for MediaType."""
+ if media_type == MediaType.ARTIST:
+ return self.artists
+ if media_type == MediaType.ALBUM:
+ return self.albums
+ if media_type == MediaType.TRACK:
+ return self.tracks
+ if media_type == MediaType.RADIO:
+ return self.radio
+ if media_type == MediaType.PLAYLIST:
+ return self.playlists
+ return None
+
+ def _start_provider_sync(self, provider_instance: str, media_types: tuple[MediaType, ...]):
+ """Start sync task on provider and track progress."""
+ # check if we're not already running a sync task for this provider/mediatype
+ for sync_task in self.in_progress_syncs:
+ if sync_task.provider_instance != provider_instance:
+ continue
+ for media_type in media_types:
+ if media_type in sync_task.media_types:
+ LOGGER.debug(
+ "Skip sync task for %s because another task is already in progress",
+ provider_instance,
+ )
+ return
+
+ # we keep track of running sync tasks
+ provider = self.mass.get_provider(provider_instance)
+ task = self.mass.create_task(provider.sync_library(media_types))
+ sync_spec = SyncTask(
+ provider_domain=provider.domain,
+ provider_instance=provider.instance_id,
+ media_types=media_types,
+ task=task,
+ )
+ self.in_progress_syncs.append(sync_spec)
+
+ self.mass.signal_event(EventType.SYNC_TASKS_UPDATED, data=self.in_progress_syncs)
+
+ def on_sync_task_done(task: asyncio.Task): # noqa: ARG001
+ self.in_progress_syncs.remove(sync_spec)
+ self.mass.signal_event(EventType.SYNC_TASKS_UPDATED, data=self.in_progress_syncs)
+
+ task.add_done_callback(on_sync_task_done)
+
+ async def cleanup_provider(self, provider_instance: str) -> None:
+ """Cleanup provider records from the database."""
+ # clean cache items from deleted provider(s)
+ await self.mass.cache.clear(provider_instance)
+
+ # cleanup media items from db matched to deleted provider
+ for ctrl in (
+ # order is important here to recursively cleanup bottom up
+ self.mass.music.radio,
+ self.mass.music.playlists,
+ self.mass.music.tracks,
+ self.mass.music.albums,
+ self.mass.music.artists,
+ ):
+ prov_items = await ctrl.get_db_items_by_prov_id(provider_instance=provider_instance)
+ for item in prov_items:
+ await ctrl.remove_prov_mapping(item.item_id, provider_instance)
+
+ async def _setup_database(self):
+ """Initialize database."""
+ db_url: str = self.mass.config.get(CONF_DB_LIBRARY, DEFAULT_DB_LIBRARY)
+ db_url = db_url.replace("[storage_path]", self.mass.storage_path)
+ self.database = DatabaseConnection(db_url)
+
+ # always create db tables if they don't exist to prevent errors trying to access them later
+ await self.__create_database_tables()
+ try:
+ if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
+ prev_version = int(db_row["value"])
+ else:
+ prev_version = 0
+ except (KeyError, ValueError):
+ prev_version = 0
+
+ if prev_version not in (0, SCHEMA_VERSION):
+ LOGGER.info(
+ "Performing database migration from %s to %s",
+ prev_version,
+ SCHEMA_VERSION,
+ )
+
+ if prev_version < SCHEMA_VERSION:
+ # for now just keep it simple and just recreate the tables
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_ARTISTS}")
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_ALBUMS}")
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_TRACKS}")
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_PLAYLISTS}")
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_RADIOS}")
+
+ # recreate missing tables
+ await self.__create_database_tables()
+
+ # store current schema version
+ await self.database.insert_or_replace(
+ DB_TABLE_SETTINGS,
+ {"key": "version", "value": str(SCHEMA_VERSION), "type": "str"},
+ )
+ # compact db
+ await self.database.execute("VACUUM")
+
+ async def __create_database_tables(self) -> None:
+ """Create database tables."""
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
+ key TEXT PRIMARY KEY,
+ value TEXT,
+ type TEXT
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_TRACK_LOUDNESS}(
+ item_id INTEGER NOT NULL,
+ provider TEXT NOT NULL,
+ loudness REAL,
+ UNIQUE(item_id, provider));"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_PLAYLOG}(
+ item_id INTEGER NOT NULL,
+ provider TEXT NOT NULL,
+ timestamp INTEGER DEFAULT 0,
+ UNIQUE(item_id, provider));"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_ALBUMS}(
+ item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ sort_name TEXT NOT NULL,
+ sort_artist TEXT,
+ album_type TEXT,
+ year INTEGER,
+ version TEXT,
+ in_library BOOLEAN DEFAULT 0,
+ upc TEXT,
+ musicbrainz_id TEXT,
+ artists json,
+ metadata json,
+ provider_mappings json,
+ timestamp INTEGER DEFAULT 0
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_ARTISTS}(
+ item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ sort_name TEXT NOT NULL,
+ musicbrainz_id TEXT,
+ in_library BOOLEAN DEFAULT 0,
+ metadata json,
+ provider_mappings json,
+ timestamp INTEGER DEFAULT 0
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_TRACKS}(
+ item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ sort_name TEXT NOT NULL,
+ sort_artist TEXT,
+ sort_album TEXT,
+ version TEXT,
+ duration INTEGER,
+ in_library BOOLEAN DEFAULT 0,
+ isrc TEXT,
+ musicbrainz_id TEXT,
+ artists json,
+ albums json,
+ metadata json,
+ provider_mappings json,
+ timestamp INTEGER DEFAULT 0
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_PLAYLISTS}(
+ item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ sort_name TEXT NOT NULL,
+ owner TEXT NOT NULL,
+ is_editable BOOLEAN NOT NULL,
+ in_library BOOLEAN DEFAULT 0,
+ metadata json,
+ provider_mappings json,
+ timestamp INTEGER DEFAULT 0,
+ UNIQUE(name, owner)
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_RADIOS}(
+ item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL UNIQUE,
+ sort_name TEXT NOT NULL,
+ in_library BOOLEAN DEFAULT 0,
+ metadata json,
+ provider_mappings json,
+ timestamp INTEGER DEFAULT 0
+ );"""
+ )
+ await self.database.execute(
+ f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_PROVIDER_MAPPINGS}(
+ media_type TEXT NOT NULL,
+ item_id INTEGER NOT NULL,
+ provider_domain TEXT NOT NULL,
+ provider_instance TEXT NOT NULL,
+ provider_item_id TEXT NOT NULL,
+ UNIQUE(media_type, item_id, provider_instance,
+ provider_item_id, provider_item_id)
+ );"""
+ )
+
+ # create indexes
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS artists_in_library_idx on artists(in_library);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS albums_in_library_idx on albums(in_library);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS tracks_in_library_idx on tracks(in_library);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS playlists_in_library_idx on playlists(in_library);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS radios_in_library_idx on radios(in_library);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS artists_sort_name_idx on artists(sort_name);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS albums_sort_name_idx on albums(sort_name);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS tracks_sort_name_idx on tracks(sort_name);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS playlists_sort_name_idx on playlists(sort_name);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS radios_sort_name_idx on radios(sort_name);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS artists_musicbrainz_id_idx on artists(musicbrainz_id);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS albums_musicbrainz_id_idx on albums(musicbrainz_id);"
+ )
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS tracks_musicbrainz_id_idx on tracks(musicbrainz_id);"
+ )
+ await self.database.execute("CREATE INDEX IF NOT EXISTS tracks_isrc_idx on tracks(isrc);")
+ await self.database.execute("CREATE INDEX IF NOT EXISTS albums_upc_idx on albums(upc);")
--- /dev/null
+"""Logic to play music from MusicProviders to supported players."""
+from __future__ import annotations
+
+import logging
+import random
+import time
+from typing import TYPE_CHECKING
+
+from music_assistant.common.helpers.util import get_changed_keys
+from music_assistant.common.models.enums import (
+ EventType,
+ MediaType,
+ PlayerFeature,
+ PlayerState,
+ QueueOption,
+ RepeatMode,
+)
+from music_assistant.common.models.errors import (
+ MediaNotFoundError,
+ MusicAssistantError,
+ QueueEmpty,
+)
+from music_assistant.common.models.media_items import MediaItemType, media_from_dict
+from music_assistant.common.models.player_queue import PlayerQueue
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.constants import (
+ CONF_FLOW_MODE,
+ FALLBACK_DURATION,
+ ROOT_LOGGER_NAME,
+)
+from music_assistant.server.helpers.api import api_command
+
+if TYPE_CHECKING:
+ from collections.abc import Iterator
+
+ from music_assistant.common.models.player import Player
+
+ from .players import PlayerController
+
+LOGGER = logging.getLogger(f"{ROOT_LOGGER_NAME}.players.queue")
+
+
+class PlayerQueuesController:
+ """Controller holding all logic to enqueue music for players."""
+
+ def __init__(self, players: PlayerController) -> None:
+ """Initialize class."""
+ self.players = players
+ self.mass = players.mass
+ self._queues: dict[str, PlayerQueue] = {}
+ self._queue_items: dict[str, list[QueueItem]] = {}
+ self._prev_states: dict[str, dict] = {}
+
+ async def close(self) -> None:
+ """Cleanup on exit."""
+ # stop all playback
+ for queue in self.all():
+ if queue.state not in (PlayerState.PLAYING, PlayerState.PAUSED):
+ continue
+ await self.stop(queue.queue_id)
+
+ def __iter__(self) -> Iterator[PlayerQueue]:
+ """Iterate over (available) players."""
+ return iter(self._queues.values())
+
+ @api_command("players/queue/all")
+ def all(self) -> tuple[PlayerQueue]:
+ """Return all registered PlayerQueues."""
+ return tuple(self._queues.values())
+
+ @api_command("players/queue/get")
+ def get(self, queue_id: str) -> PlayerQueue | None:
+ """Return PlayerQueue by queue_id or None if not found."""
+ return self._queues.get(queue_id)
+
+ @api_command("players/queue/items")
+ def items(self, queue_id: str) -> list[QueueItem]:
+ """Return all QueueItems for given PlayerQueue."""
+ return self._queue_items.get(queue_id, [])
+
+ @api_command("players/queue/get_active_queue")
+ def get_active_queue(self, player_id: str) -> PlayerQueue:
+ """Return the current active/synced queue for a player."""
+ player = self.mass.players.get(player_id)
+ return self.get(player.active_queue)
+
+ # Queue commands
+
+ @api_command("players/queue/shuffle")
+ def set_shuffle(self, queue_id: str, shuffle_enabled: bool) -> None:
+ """Configure shuffle setting on the the queue."""
+ queue = self._queues[queue_id]
+ if queue.shuffle_enabled == shuffle_enabled:
+ return # no change
+
+ queue.shuffle_enabled = shuffle_enabled
+ queue_items = self._queue_items[queue_id]
+ cur_index = queue.index_in_buffer
+ if cur_index is not None:
+ next_index = cur_index + 1
+ next_items = queue_items[next_index:]
+ else:
+ next_items = []
+ next_index = 0
+
+ if not shuffle_enabled:
+ # shuffle disabled, try to restore original sort order of the remaining items
+ next_items.sort(key=lambda x: x.sort_index, reverse=False)
+ self.load(
+ queue_id=queue_id,
+ queue_items=next_items,
+ insert_at_index=next_index,
+ keep_remaining=False,
+ shuffle=shuffle_enabled,
+ )
+
+ @api_command("players/queue/repeat")
+ def set_repeat(self, queue_id: str, repeat_mode: RepeatMode) -> None:
+ """Configure repeat setting on the the queue."""
+ queue = self._queues[queue_id]
+ if queue.repeat_mode == repeat_mode:
+ return # no change
+ queue.repeat_mode = repeat_mode
+ self.signal_update(queue_id)
+
+ @api_command("players/queue/crossfade")
+ def set_crossfade(self, queue_id: str, crossfade_enabled: bool) -> None:
+ """Configure crossfade setting on the the queue."""
+ queue = self._queues[queue_id]
+ if queue.crossfade_enabled == crossfade_enabled:
+ return # no change
+ queue.crossfade_enabled = crossfade_enabled
+ self.signal_update(queue_id)
+
+ @api_command("players/queue/play_media")
+ async def play_media(
+ self,
+ queue_id: str,
+ media: MediaItemType | list[MediaItemType] | str | list[str],
+ option: QueueOption = QueueOption.PLAY,
+ radio_mode: bool = False,
+ ) -> None:
+ """Play media item(s) on the given queue.
+
+ - media: Media that should be played (MediaItem(s) or uri's).
+ - queue_opt: Which enqueue mode to use.
+ - radio_mode: Enable radio mode for the given item(s).
+ """
+ # pylint: disable=too-many-branches
+ queue = self._queues[queue_id]
+ if queue.announcement_in_progress:
+ LOGGER.warning("Ignore queue command: An announcement is in progress")
+ return
+
+ # a single item or list of items may be provided
+ if not isinstance(media, list):
+ media = [media]
+
+ # clear queue first if it was finished
+ if queue.current_index and queue.current_index >= (len(self._queue_items[queue_id]) - 1):
+ queue.current_index = None
+ self._queue_items[queue_id] = []
+
+ # clear radio source items if needed
+ if option not in (QueueOption.ADD, QueueOption.PLAY, QueueOption.NEXT):
+ queue.radio_source = []
+
+ tracks: list[MediaItemType] = []
+ for item in media:
+ # parse provided uri into a MA MediaItem or Basic QueueItem from URL
+ if isinstance(item, str):
+ try:
+ media_item = await self.mass.music.get_item_by_uri(item)
+ except MusicAssistantError as err:
+ # invalid MA uri or item not found error
+ raise MediaNotFoundError(f"Invalid uri: {item}") from err
+ elif isinstance(item, dict):
+ media_item = media_from_dict(item)
+ else:
+ media_item = item
+
+ # collect tracks to play
+ ctrl = self.mass.music.get_controller(media_item.media_type)
+ if radio_mode:
+ queue.radio_source.append(media_item)
+ # if radio mode enabled, grab the first batch of tracks here
+ tracks += await ctrl.dynamic_tracks(
+ item_id=media_item.item_id, provider_domain=media_item.provider
+ )
+ elif media_item.media_type in (
+ MediaType.ARTIST,
+ MediaType.ALBUM,
+ MediaType.PLAYLIST,
+ ):
+ tracks += await ctrl.tracks(media_item.item_id, provider_domain=media_item.provider)
+ else:
+ # single track or radio item
+ tracks += [media_item]
+
+ # only add valid/available items
+ queue_items = [QueueItem.from_media_item(queue_id, x) for x in tracks if x and x.available]
+
+ # load the items into the queue
+ cur_index = queue.index_in_buffer or 0
+ shuffle = queue.shuffle_enabled and len(queue_items) >= 5
+
+ # handle replace: clear all items and replace with the new items
+ if option == QueueOption.REPLACE:
+ self.clear(queue_id)
+ self.load(queue_id, queue_items=queue_items, shuffle=shuffle)
+ await self.play_index(queue_id, 0)
+ # handle next: add item(s) in the index next to the playing/loaded/buffered index
+ elif option == QueueOption.NEXT:
+ await self.load(
+ queue_id,
+ queue_items=queue_items,
+ insert_at_index=cur_index + 1,
+ shuffle=shuffle,
+ )
+ elif option == QueueOption.REPLACE_NEXT:
+ self.load(
+ queue_id,
+ queue_items=queue_items,
+ insert_at_index=cur_index + 1,
+ keep_remaining=False,
+ shuffle=shuffle,
+ )
+ # handle play: replace current loaded/playing index with new item(s)
+ elif option == QueueOption.PLAY:
+ self.load(
+ queue_id,
+ queue_items=queue_items,
+ insert_at_index=cur_index,
+ shuffle=shuffle,
+ )
+ await self.play_index(queue_id, cur_index)
+ # handle add: add/append item(s) to the remaining queue items
+ elif option == QueueOption.ADD:
+ if queue.shuffle_enabled:
+ # shuffle the new items with remaining queue items
+ insert_at_index = cur_index + 1
+ else:
+ # just append at the end
+ insert_at_index = len(self._queue_items[queue_id])
+ self.load(
+ queue_id=queue_id,
+ queue_items=queue_items,
+ insert_at_index=insert_at_index,
+ shuffle=queue.shuffle_enabled,
+ )
+
+ @api_command("players/queue/move_item")
+ def move_item(self, queue_id: str, queue_item_id: str, pos_shift: int = 1) -> None:
+ """
+ Move queue item x up/down the queue.
+
+ - queue_id: id of the queue to process this request.
+ - queue_item_id: the item_id of the queueitem that needs to be moved.
+ - pos_shift: move item x positions down if positive value
+ - pos_shift: move item x positions up if negative value
+ - pos_shift: move item to top of queue as next item if 0.
+ """
+ queue = self._queues[queue_id]
+ item_index = self.index_by_id(queue_id, queue_item_id)
+ if item_index <= queue.index_in_buffer:
+ raise IndexError(f"{item_index} is already played/buffered")
+
+ queue_items = self._queue_items[queue_id]
+ queue_items = queue_items.copy()
+
+ if pos_shift == 0 and queue.state == PlayerState.PLAYING:
+ new_index = (queue.current_index or 0) + 1
+ elif pos_shift == 0:
+ new_index = queue.current_index or 0
+ else:
+ new_index = item_index + pos_shift
+ if (new_index < (queue.current_index or 0)) or (new_index > len(queue_items)):
+ return
+ # move the item in the list
+ queue_items.insert(new_index, queue_items.pop(item_index))
+ self.update_items(queue_id, queue_items)
+
+ @api_command("players/queue/delete_item")
+ def delete_item(self, queue_id: str, item_id_or_index: int | str) -> None:
+ """Delete item (by id or index) from the queue."""
+ if isinstance(item_id_or_index, str):
+ item_index = self.index_by_id(queue_id, item_id_or_index)
+ else:
+ item_index = item_id_or_index
+ queue = self._queues[queue_id]
+ if item_index <= queue.index_in_buffer:
+ # ignore request if track already loaded in the buffer
+ # the frontend should guard so this is just in case
+ LOGGER.warning("delete requested for item already loaded in buffer")
+ return
+ queue_items = self._queue_items[queue_id]
+ queue_items.pop(item_index)
+ self.update_items(queue_id, queue_items)
+
+ @api_command("players/queue/clear")
+ def clear(self, queue_id: str) -> None:
+ """Clear all items in the queue."""
+ queue = self._queues[queue_id]
+ queue.radio_source = []
+ if queue.state not in (PlayerState.IDLE, PlayerState.OFF):
+ self.mass.create_task(self.stop(queue_id))
+ queue.current_index = None
+ queue.index_in_buffer = None
+ self.update_items(queue_id, [])
+
+ @api_command("players/queue/stop")
+ async def stop(self, queue_id: str) -> None:
+ """
+ Handle STOP command for given queue.
+
+ - queue_id: queue_id of the playerqueue to handle the command.
+ """
+ if self._queues[queue_id].announcement_in_progress:
+ LOGGER.warning("Ignore queue command for %s because an announcement is in progress.")
+ return
+ # simply forward the command to underlying player
+ await self.players.cmd_stop(queue_id)
+
+ @api_command("players/queue/play")
+ async def play(self, queue_id: str) -> None:
+ """
+ Handle PLAY command for given queue.
+
+ - queue_id: queue_id of the playerqueue to handle the command.
+ """
+ if self._queues[queue_id].announcement_in_progress:
+ LOGGER.warning("Ignore queue command for %s because an announcement is in progress.")
+ return
+ if self._queues[queue_id].state == PlayerState.PAUSED:
+ # simply forward the command to underlying player
+ await self.players.cmd_play(queue_id)
+ else:
+ await self.resume(queue_id)
+
+ @api_command("players/queue/pause")
+ async def pause(self, queue_id: str) -> None:
+ """Handle PAUSE command for given queue.
+
+ - queue_id: queue_id of the playerqueue to handle the command.
+ """
+ if self._queues[queue_id].announcement_in_progress:
+ LOGGER.warning("Ignore queue command for %s because an announcement is in progress.")
+ return
+ # simply forward the command to underlying player
+ await self.players.cmd_pause(queue_id)
+
+ @api_command("players/queue/play_pause")
+ async def play_pause(self, queue_id: str) -> None:
+ """Toggle play/pause on given playerqueue.
+
+ - queue_id: queue_id of the queue to handle the command.
+ """
+ if self._queues[queue_id].state == PlayerState.PLAYING:
+ await self.pause(queue_id)
+ return
+ await self.play(queue_id)
+
+ @api_command("players/queue/next")
+ async def next(self, queue_id: str) -> None:
+ """Handle NEXT TRACK command for given queue.
+
+ - queue_id: queue_id of the queue to handle the command.
+ """
+ current_index = self._queues[queue_id].current_index
+ next_index = self.get_next_index(queue_id, current_index, True)
+ if next_index is None:
+ return
+ await self.play_index(queue_id, next_index)
+
+ @api_command("players/queue/previous")
+ async def previous(self, queue_id: str) -> None:
+ """Handle PREVIOUS TRACK command for given queue.
+
+ - queue_id: queue_id of the queue to handle the command.
+ """
+ current_index = self._queues[queue_id].current_index
+ if current_index is None:
+ return
+ await self.play_index(queue_id, max(current_index - 1, 0))
+
+ @api_command("players/queue/skip")
+ async def skip(self, queue_id: str, seconds: int = 10) -> None:
+ """Handle SKIP command for given queue.
+
+ - queue_id: queue_id of the queue to handle the command.
+ - seconds: number of seconds to skip in track. Use negative value to skip back.
+ """
+ await self.seek(queue_id, self._queues[queue_id].elapsed_time + seconds)
+
+ @api_command("players/queue/seek")
+ async def seek(self, queue_id: str, position: int = 10) -> None:
+ """Handle SEEK command for given queue.
+
+ - queue_id: queue_id of the queue to handle the command.
+ - position: position in seconds to seek to in the current playing item.
+ """
+ queue = self._queues[queue_id]
+ assert queue.current_item, "No item loaded"
+ assert queue.current_item.media_item.media_type == MediaType.TRACK
+ assert queue.current_item.duration
+ assert position < queue.current_item.duration
+ player = self.mass.players.get(queue_id)
+ if PlayerFeature.SEEK in player.supported_features:
+ player_prov = self.mass.players.get_player_provider(queue_id)
+ await player_prov.cmd_seek(player.player_id, position)
+ return
+ await self.play_index(queue_id, queue.current_index, position)
+
+ @api_command("players/queue/resume")
+ async def resume(self, queue_id: str) -> None:
+ """Handle RESUME command for given queue.
+
+ - queue_id: queue_id of the queue to handle the command.
+ """
+ queue = self._queues[queue_id]
+ queue_items = self._queue_items[queue_id]
+ resume_item = queue.current_item
+ next_item = queue.next_item
+ resume_pos = queue.elapsed_time
+ if (
+ resume_item
+ and next_item
+ and resume_item.duration
+ and resume_pos > (resume_item.duration * 0.9)
+ ):
+ # track is already played for > 90% - skip to next
+ resume_item = next_item
+ resume_pos = 0
+ elif queue.current_index is None and len(queue_items) > 0:
+ # items available in queue but no previous track, start at 0
+ resume_item = self.get_item(queue_id, 0)
+ resume_pos = 0
+
+ if resume_item is not None:
+ resume_pos = resume_pos if resume_pos > 10 else 0
+ fade_in = resume_pos > 0
+ await self.play_index(queue_id, resume_item.queue_item_id, resume_pos, fade_in)
+ else:
+ raise QueueEmpty(f"Resume queue requested but queue {queue_id} is empty")
+
+ @api_command("players/queue/play_index")
+ async def play_index(
+ self,
+ queue_id: str,
+ index: int | str,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ ) -> None:
+ """Play item at index (or item_id) X in queue."""
+ queue = self._queues[queue_id]
+ if queue.announcement_in_progress:
+ LOGGER.warning("Ignore queue command for %s because an announcement is in progress.")
+ return
+ if isinstance(index, str):
+ index = self.index_by_id(queue_id, index)
+ queue_item = self.get_item(queue_id, index)
+ if queue_item is None:
+ raise FileNotFoundError(f"Unknown index/id: {index}")
+ queue.current_index = index
+ queue.index_in_buffer = index
+ # power on player if needed
+ await self.mass.players.cmd_power(queue_id, True)
+ # execute the play_media command on the player(s)
+ player_prov = self.mass.players.get_player_provider(queue_id)
+ flow_mode = self.mass.config.get_player_config_value(queue.queue_id, CONF_FLOW_MODE)
+ queue.flow_mode = flow_mode.value
+ await player_prov.cmd_play_media(
+ queue_id,
+ queue_item=queue_item,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ flow_mode=flow_mode.value,
+ )
+
+ # Interaction with player
+
+ async def on_player_register(self, player: Player) -> None:
+ """Register PlayerQueue for given player/queue id."""
+ queue_id = player.player_id
+ # try to restore previous state
+ if prev_state := await self.mass.cache.get(f"queue.state.{queue_id}"):
+ queue = PlayerQueue.from_dict(prev_state)
+ prev_items = await self.mass.cache.get(f"queue.items.{queue_id}", default=[])
+ queue_items = [QueueItem.from_dict(x) for x in prev_items]
+ else:
+ queue = PlayerQueue(
+ queue_id=queue_id,
+ active=False,
+ display_name=player.display_name,
+ available=player.available,
+ items=0,
+ )
+ queue_items = []
+
+ self._queues[queue_id] = queue
+ self._queue_items[queue_id] = queue_items
+ # always call update to calculate state etc
+ self.on_player_update(player, {})
+
+ def on_player_update(self, player: Player, changed_keys: set[str]) -> None:
+ """Call when a PlayerQueue needs to be updated (e.g. when player updates)."""
+ if player.player_id not in self._queues:
+ self.mass.create_task(self.on_player_register(player))
+ return
+ queue_id = player.player_id
+ player = self.players.get(queue_id)
+ queue = self._queues[queue_id]
+
+ # copy most properties from the player
+ queue.display_name = player.display_name
+ queue.available = player.available
+ queue.items = len(self._queue_items[queue_id])
+ queue.state = player.state
+ queue.elapsed_time = int(player.corrected_elapsed_time)
+ queue.elapsed_time_last_updated = time.time()
+
+ # determine if this queue is currently active for this player
+ queue.active = player.active_queue == queue.queue_id
+ if queue.active:
+ # update current item from player report
+ player_item_index = self.index_by_id(queue_id, player.current_item_id)
+ if player_item_index is not None:
+ if queue.flow_mode:
+ # flow mode active, calculate current item
+ (
+ queue.current_index,
+ queue.elapsed_time,
+ ) = self.__get_queue_stream_index(queue, player, player_item_index)
+ else:
+ queue.current_index = player_item_index
+
+ queue.current_item = self.get_item(queue_id, queue.current_index)
+ queue.next_item = self.get_next_item(queue_id)
+
+ # correct elapsed time when seeking
+ if (
+ queue.current_item
+ and queue.current_item.streamdetails
+ and queue.current_item.streamdetails.seconds_skipped
+ and not queue.flow_mode
+ ):
+ queue.elapsed_time += queue.current_item.streamdetails.seconds_skipped
+
+ # basic throttle: do not send state changed events if queue did not actually change
+ prev_state = self._prev_states.get(queue_id, {})
+ new_state = self._queues[queue_id].to_dict()
+ changed_keys = get_changed_keys(prev_state, new_state)
+ self._prev_states[queue_id] = new_state
+
+ if len(changed_keys) == 0:
+ return
+
+ if "elapsed_time" in changed_keys:
+ self.mass.signal_event(
+ EventType.QUEUE_TIME_UPDATED,
+ object_id=queue_id,
+ data=queue.elapsed_time,
+ )
+ # do not send full updates if only time was updated
+ if changed_keys in (
+ {"elapsed_time_last_updated"},
+ {
+ "elapsed_time",
+ "elapsed_time_last_updated",
+ },
+ ):
+ # ignore
+ return
+
+ # only signal queue updated event if other properties than elapsed_time updated
+ self.signal_update(queue_id)
+ # watch dynamic radio items refill if needed
+ if "current_index" in changed_keys:
+ fill_index = len(self._queue_items[queue_id]) - 5
+ if queue.radio_source and (queue.current_index >= fill_index):
+ self.mass.create_task(self._fill_radio_tracks(queue_id))
+
+ def on_player_remove(self, player_id: str) -> None:
+ """Call when a player is removed from the registry."""
+ self.mass.create_task(self.mass.cache.delete(f"queue.state.{player_id}"))
+ self.mass.create_task(self.mass.cache.delete(f"queue.items.{player_id}"))
+ self._queues.pop(player_id, None)
+ self._queue_items.pop(player_id, None)
+
+ def player_ready_for_next_track(
+ self, queue_or_player_id: str, current_item_id: str | None = None
+ ) -> tuple[QueueItem, bool]:
+ """Call when a player is ready to load the next track into the buffer.
+
+ The result is a tuple of the next QueueItem to Play,
+ and a bool if the player should crossfade (if supported).
+ Raises QueueEmpty if there are no more tracks left.
+
+ NOTE: The player(s) should resolve the stream URL for the QueueItem,
+ just like with the play_media call.
+ """
+ queue = self.get_active_queue(queue_or_player_id)
+ if current_item_id is None:
+ cur_index = queue.current_index
+ else:
+ cur_index = self.index_by_id(queue.queue_id, current_item_id)
+ cur_item = self.get_item(queue.queue_id, cur_index)
+ next_index = self.get_next_index(queue.queue_id, cur_index)
+ next_item = self.get_item(queue.queue_id, next_index)
+ if not next_item:
+ raise QueueEmpty("No more tracks left in the queue.")
+ queue.index_in_buffer = next_index
+ # work out crossfade
+ crossfade = queue.crossfade_enabled
+ if (
+ cur_item.media_type == MediaType.TRACK
+ and next_item.media_type == MediaType.TRACK
+ and cur_item.media_item.album == next_item.media_item.album
+ ):
+ # disable crossfade if playing tracks from same album
+ # TODO: make this a bit more intelligent.
+ crossfade = False
+ return (next_item, crossfade)
+
+ # Main queue manipulation methods
+
+ def load(
+ self,
+ queue_id: str,
+ queue_items: list[QueueItem],
+ insert_at_index: int = 0,
+ keep_remaining: bool = True,
+ shuffle: bool = False,
+ ) -> None:
+ """Load new items at index.
+
+ - queue_id: id of the queue to process this request.
+ - queue_items: a list of QueueItems
+ - insert_at_index: insert the item(s) at this index
+ - keep_remaining: keep the remaining items after the insert
+ - shuffle: (re)shuffle the items after insert index
+ """
+ # keep previous/played items, append the new ones
+ prev_items = self._queue_items[queue_id][:insert_at_index]
+ next_items = queue_items
+
+ # if keep_remaining, append the old previous items
+ if keep_remaining:
+ next_items += prev_items[insert_at_index:]
+
+ # we set the original insert order as attribute so we can un-shuffle
+ for index, item in enumerate(next_items):
+ item.sort_index += insert_at_index + index
+ # (re)shuffle the final batch if needed
+ if shuffle:
+ next_items = random.sample(next_items, len(next_items))
+ self.update_items(queue_id, prev_items + next_items)
+
+ def update_items(self, queue_id: str, queue_items: list[QueueItem]) -> None:
+ """Update the existing queue items, mostly caused by reordering."""
+ self._queue_items[queue_id] = queue_items
+ self.signal_update(queue_id, True)
+
+ # Helper methods
+
+ def get_item(self, queue_id: str, item_id_or_index: int | str | None) -> QueueItem | None:
+ """Get queue item by index or item_id."""
+ if item_id_or_index is None:
+ return None
+ queue_items = self._queue_items[queue_id]
+ if isinstance(item_id_or_index, int) and len(queue_items) > item_id_or_index:
+ return queue_items[item_id_or_index]
+ if isinstance(item_id_or_index, str):
+ return next((x for x in queue_items if x.queue_item_id == item_id_or_index), None)
+ return None
+
+ def index_by_id(self, queue_id: str, queue_item_id: str) -> int | None:
+ """Get index by queue_item_id."""
+ queue_items = self._queue_items[queue_id]
+ for index, item in enumerate(queue_items):
+ if item.queue_item_id == queue_item_id:
+ return index
+ return None
+
+ def get_next_index(self, queue_id: str, cur_index: int | None, is_skip: bool = False) -> int:
+ """Return the next index for the queue, accounting for repeat settings."""
+ queue = self._queues[queue_id]
+ queue_items = self._queue_items[queue_id]
+ # handle repeat single track
+ if queue.repeat_mode == RepeatMode.ONE and not is_skip:
+ return cur_index
+ # handle repeat all
+ if (
+ queue.repeat_mode == RepeatMode.ALL
+ and queue_items
+ and cur_index == (len(queue_items) - 1)
+ ):
+ return 0
+ # simply return the next index. other logic is guarded to detect the index
+ # being higher than the number of items to detect end of queue and/or handle repeat.
+ if cur_index is None:
+ return 0
+ next_index = cur_index + 1
+ return next_index
+
+ def get_next_item(self, queue_id: str, cur_index: int | None = None) -> QueueItem | None:
+ """Return next QueueItem for given queue."""
+ queue = self._queues[queue_id]
+ if cur_index is None:
+ cur_index = queue.current_index
+ next_index = self.get_next_index(queue_id, queue.current_index)
+ return self.get_item(queue_id, next_index)
+
+ def signal_update(self, queue_id: str, items_changed: bool = False) -> None:
+ """Signal state changed of given queue."""
+ queue = self._queues[queue_id]
+ if items_changed:
+ self.mass.signal_event(EventType.QUEUE_ITEMS_UPDATED, object_id=queue_id, data=queue)
+ # save items in cache
+ self.mass.create_task(
+ self.mass.cache.set(
+ f"queue.items.{queue_id}",
+ [x.to_dict() for x in self._queue_items[queue_id]],
+ )
+ )
+
+ # always send the base event
+ self.mass.signal_event(EventType.QUEUE_UPDATED, object_id=queue_id, data=queue)
+ # save state
+ self.mass.create_task(
+ self.mass.cache.set(
+ f"queue.state.{queue_id}",
+ queue.to_dict(),
+ )
+ )
+
+ async def _fill_radio_tracks(self, queue_id: str) -> None:
+ """Fill a Queue with (additional) Radio tracks."""
+ queue = self._queues[queue_id]
+ assert queue.radio_source, "No Radio item(s) loaded/active!"
+ tracks: list[MediaItemType] = []
+ # grab dynamic tracks for (all) source items
+ # shuffle the source items, just in case
+ for radio_item in random.sample(queue.radio_source, len(queue.radio_source)):
+ ctrl = self.mass.music.get_controller(radio_item.media_type)
+ tracks += await ctrl.dynamic_tracks(
+ item_id=radio_item.item_id, provider_domain=radio_item.provider
+ )
+ # make sure we do not grab too much items
+ if len(tracks) >= 50:
+ break
+ # fill queue - filter out unavailable items
+ queue_items = [QueueItem.from_media_item(queue_id, x) for x in tracks if x.available]
+ self.load(
+ queue_id,
+ queue_items,
+ insert_at_index=len(self._queue_items[queue_id]) - 1,
+ )
+
+ def __get_queue_stream_index(
+ self, queue: PlayerQueue, player: Player, start_index: int
+ ) -> tuple[int, int]:
+ """Calculate current queue index and current track elapsed time."""
+ # player is playing a constant stream so we need to do this the hard way
+ queue_index = 0
+ elapsed_time_queue = player.corrected_elapsed_time
+ total_time = 0
+ track_time = 0
+ queue_items = self._queue_items[queue.queue_id]
+ if queue_items and len(queue_items) > start_index:
+ # start_index: holds the position from which the flow stream started
+ queue_index = start_index
+ queue_track = None
+ while len(queue_items) > queue_index:
+ # keep enumerating the queue tracks to find current track
+ # starting from the start index
+ queue_track = queue_items[queue_index]
+ if not queue_track.streamdetails:
+ track_time = elapsed_time_queue - total_time
+ break
+ if queue_track.streamdetails.seconds_streamed is not None:
+ track_duration = queue_track.streamdetails.seconds_streamed
+ else:
+ track_duration = queue_track.duration or FALLBACK_DURATION
+ if elapsed_time_queue > (track_duration + total_time):
+ # total elapsed time is more than (streamed) track duration
+ # move index one up
+ total_time += track_duration
+ queue_index += 1
+ else:
+ # no more seconds left to divide, this is our track
+ # account for any seeking by adding the skipped seconds
+ track_sec_skipped = queue_track.streamdetails.seconds_skipped or 0
+ track_time = elapsed_time_queue + track_sec_skipped - total_time
+ break
+ return queue_index, track_time
--- /dev/null
+"""Logic to play music from MusicProviders to supported players."""
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, cast
+
+from music_assistant.common.helpers.util import get_changed_keys
+from music_assistant.common.models.enums import (
+ EventType,
+ PlayerFeature,
+ PlayerState,
+ PlayerType,
+ ProviderType,
+)
+from music_assistant.common.models.errors import (
+ AlreadyRegisteredError,
+ PlayerCommandFailed,
+ PlayerUnavailableError,
+ UnsupportedFeaturedException,
+)
+from music_assistant.common.models.player import Player
+from music_assistant.constants import CONF_PLAYERS, ROOT_LOGGER_NAME
+from music_assistant.server.helpers.api import api_command
+from music_assistant.server.models.player_provider import PlayerProvider
+
+from .player_queues import PlayerQueuesController
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+LOGGER = logging.getLogger(f"{ROOT_LOGGER_NAME}.players")
+
+
+class PlayerController:
+ """Controller holding all logic to control registered players."""
+
+ def __init__(self, mass: MusicAssistant) -> None:
+ """Initialize class."""
+ self.mass = mass
+ self._players: dict[str, Player] = {}
+ self._prev_states: dict[str, dict] = {}
+ self.queues = PlayerQueuesController(self)
+
+ async def setup(self) -> None:
+ """Async initialize of module."""
+ self.mass.create_task(self._poll_players())
+
+ async def close(self) -> None:
+ """Cleanup on exit."""
+ await self.queues.close()
+
+ @property
+ def providers(self) -> list[PlayerProvider]:
+ """Return all loaded/running MusicProviders."""
+ return self.mass.get_providers(ProviderType.MUSIC) # type: ignore=return-value
+
+ def __iter__(self) -> Iterator[Player]:
+ """Iterate over (available) players."""
+ return iter(self._players.values())
+
+ @api_command("players/all")
+ def all(self) -> tuple[Player]:
+ """Return all registered players."""
+ return tuple(self._players.values())
+
+ @api_command("players/get")
+ def get(
+ self,
+ player_id: str,
+ raise_unavailable: bool = False,
+ ) -> Player:
+ """Return Player by player_id."""
+ if player := self._players.get(player_id):
+ if not player.available and raise_unavailable:
+ raise PlayerUnavailableError(f"Player {player_id} is not available")
+ return player
+ raise PlayerUnavailableError(f"Player {player_id} does not exist")
+
+ @api_command("players/get_by_name")
+ def get_by_name(self, name: str) -> Player | None:
+ """Return Player by name or None if no match is found."""
+ return next((x for x in self._players.values() if x.name == name), None)
+
+ @api_command("players/set")
+ def set(self, player: Player) -> None:
+ """Set/Update player details on the controller."""
+ if player.player_id not in self._players:
+ # new player
+ self.register(player)
+ return
+ self._players[player.player_id] = player
+ self.update(player.player_id)
+
+ @api_command("players/register")
+ def register(self, player: Player) -> None:
+ """Register a new player on the controller."""
+ if self.mass.closing:
+ return
+ player_id = player.player_id
+
+ if player_id in self._players:
+ raise AlreadyRegisteredError(f"Player {player_id} is already registered")
+
+ # register playerqueue for this player
+ self.mass.create_task(self.queues.on_player_register(player))
+
+ self._players[player_id] = player
+
+ LOGGER.info(
+ "Player registered: %s/%s",
+ player_id,
+ player.name,
+ )
+ self.mass.signal_event(EventType.PLAYER_ADDED, object_id=player.player_id, data=player)
+
+ @api_command("players/register_or_update")
+ def register_or_update(self, player: Player) -> None:
+ """Register a new player on the controller or update existing one."""
+ if self.mass.closing:
+ return
+
+ if player.player_id in self._players:
+ self.update(player.player_id)
+ return
+
+ self.register(player)
+
+ @api_command("players/remove")
+ def remove(self, player_id: str) -> None:
+ """Remove a player from the registry."""
+ player = self._players.pop(player_id, None)
+ if player is None:
+ return
+ LOGGER.info("Player removed: %s", player.name)
+ self.queues.on_player_remove(player_id)
+ self.mass.config.remove(f"players/{player_id}")
+ self._prev_states.pop(player_id, None)
+ self.mass.signal_event(EventType.PLAYER_REMOVED, player_id)
+
+ @api_command("players/update")
+ def update(self, player_id: str, skip_forward: bool = False) -> None:
+ """Update player state."""
+ if player_id not in self._players:
+ return
+ player = self._players[player_id]
+ # calculate active_queue
+ player.active_queue = self._get_active_queue(player)
+ # calculate group volume
+ player.group_volume = self._get_group_volume_level(player)
+ # prefer any overridden name from config
+ player.display_name = (
+ self.mass.config.get(f"{CONF_PLAYERS}/{player_id}/name")
+ or player.name
+ or player.player_id
+ )
+ # set player state to off if player is not powered
+ if player.powered and player.state == PlayerState.OFF:
+ player.state = PlayerState.IDLE
+ elif not player.powered:
+ player.state = PlayerState.OFF
+ # basic throttle: do not send state changed events if player did not actually change
+ prev_state = self._prev_states.get(player_id, {})
+ new_state = self._players[player_id].to_dict()
+ changed_keys = get_changed_keys(
+ prev_state,
+ new_state,
+ ignore_keys=["elapsed_time", "elapsed_time_last_updated"],
+ )
+ self._prev_states[player_id] = new_state
+
+ if not player.enabled and "enabled" not in changed_keys:
+ # ignore updates for disabled players
+ return
+
+ # always signal update to the playerqueue
+ self.queues.on_player_update(player, changed_keys)
+
+ if len(changed_keys) == 0:
+ return
+
+ self.mass.signal_event(EventType.PLAYER_UPDATED, object_id=player_id, data=player)
+
+ if skip_forward:
+ return
+ if player.type == PlayerType.GROUP:
+ # update group player child's when parent updates
+ for child_player_id in player.group_childs:
+ if child_player_id == player_id:
+ continue
+ self.update(child_player_id, skip_forward=True)
+
+ # update group player(s) when child updates
+ for group_player in self._get_player_groups(player_id):
+ self.update(group_player.player_id, skip_forward=True)
+
+ def get_player_provider(self, player_id: str) -> PlayerProvider:
+ """Return PlayerProvider for given player."""
+ player = self._players[player_id]
+ player_provider = self.mass.get_provider(player.provider)
+ return cast(PlayerProvider, player_provider)
+
+ # Player commands
+
+ @api_command("players/cmd/stop")
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ player_id = self._check_redirect(player_id)
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_stop(player_id)
+
+ @api_command("players/cmd/play")
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY (unpause) command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ player_id = self._check_redirect(player_id)
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_play(player_id)
+
+ @api_command("players/cmd/pause")
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ player_id = self._check_redirect(player_id)
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_pause(player_id)
+
+ async def _watch_pause(_player_id: str) -> None:
+ player = self.get(_player_id)
+ count = 0
+ # wait for pause
+ while count < 5 and player.state == PlayerState.PLAYING:
+ count += 1
+ await asyncio.sleep(1)
+ # wait for unpause
+ if player.state != PlayerState.PAUSED:
+ return
+ count = 0
+ while count < 30 and player.state == PlayerState.PAUSED:
+ count += 1
+ await asyncio.sleep(1)
+ # if player is still paused when the limit is reached, send stop
+ if player.state == PlayerState.PAUSED:
+ await self.cmd_stop(_player_id)
+
+ # we auto stop a player from paused when its paused for 30 seconds
+ self.mass.create_task(_watch_pause(player_id))
+
+ @api_command("players/cmd/play_pause")
+ async def cmd_play_pause(self, player_id: str) -> None:
+ """Toggle play/pause on given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ player = self.get(player_id, True)
+ if player.state == PlayerState.PLAYING:
+ await self.cmd_pause(player_id)
+ else:
+ await self.cmd_play(player_id)
+
+ @api_command("players/cmd/power")
+ async def cmd_power(self, player_id: str, powered: bool) -> None:
+ """Send POWER command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ - powered: bool if player should be powered on or off.
+ """
+ # TODO: Implement PlayerControl
+ # TODO: Handle group power
+ player = self.get(player_id, True)
+ if player.powered == powered:
+ return
+ # stop player at power off
+ if not powered and player.state in (PlayerState.PLAYING, PlayerState.PAUSED):
+ await self.cmd_stop(player_id)
+ if PlayerFeature.POWER not in player.supported_features:
+ player.powered = powered
+ self.update(player_id)
+ return
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_power(player_id, powered)
+
+ @api_command("players/cmd/volume_set")
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ - volume_level: volume level (0..100) to set on the player.
+ """
+ # TODO: Implement PlayerControl
+ player = self.get(player_id, True)
+ if PlayerFeature.VOLUME_SET not in player.supported_features:
+ LOGGER.warning(
+ "Volume set command called but player %s does not support volume",
+ player_id,
+ )
+ player.volume_level = volume_level
+ self.update(player_id)
+ return
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_volume_set(player_id, volume_level)
+
+ @api_command("players/cmd/group_volume")
+ async def cmd_group_volume(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given playergroup.
+
+ Will send the new (average) volume level to group child's.
+ - player_id: player_id of the playergroup to handle the command.
+ - volume_level: volume level (0..100) to set on the player.
+ """
+ group_player = self.get(player_id, True)
+ assert group_player
+ # handle group volume by only applying the volume to powered members
+ cur_volume = group_player.volume_level
+ new_volume = volume_level
+ volume_dif = new_volume - cur_volume
+ volume_dif_percent = 1 + new_volume / 100 if cur_volume == 0 else volume_dif / cur_volume
+ coros = []
+ for child_player in self._get_child_players(group_player, True):
+ cur_child_volume = child_player.volume_level
+ new_child_volume = int(cur_child_volume + (cur_child_volume * volume_dif_percent))
+ coros.append(self.cmd_volume_set(child_player.player_id, new_child_volume))
+ await asyncio.gather(*coros)
+
+ @api_command("players/cmd/volume_mute")
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME_MUTE command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ - muted: bool if player should be muted.
+ """
+ player = self.get(player_id, True)
+ assert player
+ if PlayerFeature.VOLUME_MUTE not in player.supported_features:
+ LOGGER.warning("Mute command called but player %s does not support muting", player_id)
+ player.volume_muted = muted
+ self.update(player_id)
+ return
+ # TODO: Implement PlayerControl
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_volume_mute(player_id, muted)
+
+ @api_command("players/cmd/sync")
+ async def cmd_sync(self, player_id: str, target_player: str) -> None:
+ """Handle SYNC command for given player.
+
+ Join/add the given player(id) to the given (master) player/sync group.
+ If the player is already synced to another player, it will be unsynced there first.
+ If the target player itself is already synced to another player, this will fail.
+ If the player can not be synced with the given target player, this will fail.
+
+ - player_id: player_id of the player to handle the command.
+ - target_player: player_id of the syncgroup master or group player.
+ """
+ child_player = self.get(player_id, True)
+ parent_player = self.get(target_player, True)
+ assert child_player
+ assert parent_player
+ if PlayerFeature.SYNC not in child_player.supported_features:
+ raise UnsupportedFeaturedException(
+ f"Player {child_player.name} does not support (un)sync commands"
+ )
+ if PlayerFeature.SYNC not in parent_player.supported_features:
+ raise UnsupportedFeaturedException(
+ f"Player {parent_player.name} does not support (un)sync commands"
+ )
+ if parent_player.synced_to is not None:
+ raise PlayerCommandFailed(
+ f"Player {target_player} is already synced to another player."
+ )
+ if player_id not in parent_player.can_sync_with:
+ raise PlayerCommandFailed(f"Player {player_id} can not be synced to {target_player}.")
+ if child_player.synced_to:
+ if child_player.synced_to == parent_player.player_id:
+ # nothing to do: already synced to this parent
+ return
+ # player already synced, unsync first
+ await self.cmd_unsync(child_player.player_id)
+ # stop child player if it is currently playing
+ if child_player.state == PlayerState.PLAYING:
+ await self.cmd_stop(player_id)
+ # all checks passed, forward command to the player provider
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_sync(player_id, target_player)
+
+ @api_command("players/cmd/unsync")
+ async def cmd_unsync(self, player_id: str) -> None:
+ """Handle UNSYNC command for given player.
+
+ Remove the given player from any syncgroups it currently is synced to.
+ If the player is not currently synced to any other player,
+ this will silently be ignored.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ player = self.get(player_id, True)
+ if PlayerFeature.SYNC not in player.supported_features:
+ raise UnsupportedFeaturedException(f"Player {player.name} does not support syncing")
+ if not player.synced_to:
+ LOGGER.info(
+ "Ignoring command to unsync player %s "
+ "because it is currently not part of a (sync)group."
+ )
+ return
+
+ # all checks passed, forward command to the player provider
+ player_provider = self.get_player_provider(player_id)
+ await player_provider.cmd_unsync(player_id)
+
+ def _check_redirect(self, player_id: str) -> str:
+ """Check if playback related command should be redirected."""
+ player = self.get(player_id, True)
+ if player.synced_to:
+ sync_master = self.get(player.synced_to, True)
+ LOGGER.warning(
+ "Player %s is synced to %s and can not accept "
+ "playback related commands itself, "
+ "redirected the command to the sync leader.",
+ player.name,
+ sync_master.name,
+ )
+ return player.synced_to
+ return player_id
+
+ def _get_player_groups(self, player_id: str) -> tuple[Player, ...]:
+ """Return all (player_ids of) any groupplayers the given player belongs to."""
+ return tuple(x for x in self if player_id in x.group_childs)
+
+ def _get_active_queue(self, player: Player) -> str:
+ """Return the active_queue id for given player."""
+ # if player is synced, return master/group leader
+ if player.synced_to and player.synced_to in self._players:
+ return self._get_active_queue(self.get(player.synced_to))
+ # iterate player groups to find out if one is playing
+ if group_players := self._get_player_groups(player.player_id):
+ # prefer the first playing (or paused) group parent
+ for group_player in group_players:
+ if group_player.state in (PlayerState.PLAYING, PlayerState.PAUSED):
+ return group_player.player_id
+ # fallback to the first powered group player
+ for group_player in group_players:
+ if group_player.powered:
+ return group_player.player_id
+ # defaults to the player's own player id
+ return player.player_id
+
+ def _get_group_volume_level(self, player: Player) -> int:
+ """Calculate a group volume from the grouped members."""
+ if not player.group_childs:
+ # player is not a group
+ return player.volume_level
+ # calculate group volume from all (turned on) players
+ group_volume = 0
+ active_players = 0
+ for child_player in self._get_child_players(player, True):
+ group_volume += child_player.volume_level
+ active_players += 1
+ if active_players:
+ group_volume = group_volume / active_players
+ return int(group_volume)
+
+ def _get_child_players(
+ self,
+ player: Player,
+ only_powered: bool = False,
+ only_playing: bool = False,
+ ) -> list[Player]:
+ """Get (child) players attached to a grouped player."""
+ child_players: list[Player] = []
+ if not player.group_childs:
+ # player is not a group
+ return child_players
+ if player.type != PlayerType.GROUP:
+ # if the player is not a dedicated player group,
+ # it is the master in a sync group and thus always present as child player
+ child_players.append(player)
+ for child_id in player.group_childs:
+ if child_player := self.get(child_id):
+ if not (not only_powered or child_player.powered):
+ continue
+ if not (
+ not only_playing
+ or child_player.state in (PlayerState.PLAYING, PlayerState.PAUSED)
+ ):
+ continue
+ child_players.append(child_player)
+ return child_players
+
+ async def _poll_players(self) -> None:
+ """Background task that polls players for updates."""
+ count = 0
+ while True:
+ count += 1
+ for player_id, player in self._players.items():
+ # if the player is playing, update elapsed time every tick
+ # to ensure the queue has accurate details
+ player_playing = (
+ player.active_queue == player.player_id and player.state == PlayerState.PLAYING
+ )
+ if player_playing:
+ self.update(player_id)
+ # Poll player;
+ # - every 360 seconds if the player if not powered
+ # - every 30 seconds if the player is powered
+ # - every 10 seconds if the player is playing
+ if (
+ (player.powered and count % 30 == 0)
+ or (player_playing and count % 10 == 0)
+ or count == 360
+ ):
+ if player_prov := self.get_player_provider(player_id):
+ try:
+ await player_prov.poll_player(player_id)
+ except PlayerUnavailableError:
+ player.available = False
+ self.update(player_id)
+ except Exception as err: # pylint: disable=broad-except
+ LOGGER.warning(
+ "Error while requesting latest state from player %s: %s",
+ player.display_name,
+ str(err),
+ exc_info=err,
+ )
+ if count >= 360:
+ count = 0
+ await asyncio.sleep(1)
--- /dev/null
+"""Controller to stream audio to players."""
+from __future__ import annotations
+
+import asyncio
+import logging
+import urllib.parse
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Any
+
+import shortuuid
+from aiohttp import web
+
+from music_assistant.common.helpers.util import empty_queue
+from music_assistant.common.models.enums import ContentType, PlayerState
+from music_assistant.common.models.errors import MediaNotFoundError, QueueEmpty
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.constants import (
+ CONF_EQ_BASS,
+ CONF_EQ_MID,
+ CONF_EQ_TREBLE,
+ CONF_OUTPUT_CHANNELS,
+ ROOT_LOGGER_NAME,
+)
+from music_assistant.server.helpers.audio import (
+ check_audio_support,
+ crossfade_pcm_parts,
+ get_media_stream,
+ get_preview_stream,
+ get_stream_details,
+)
+from music_assistant.server.helpers.process import AsyncProcess
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.player import Player
+ from music_assistant.server import MusicAssistant
+
+LOGGER = logging.getLogger(f"{ROOT_LOGGER_NAME}.streams")
+
+
+class StreamJob:
+ """Representation of a (multisubscriber) Audio Queue (item)stream job/task.
+
+ The whole idea here is that in case of a player (sync)group,
+ all players receive the exact same PCM audio chunks from the source audio.
+ A StreamJob is tied to a queueitem,
+ meaning that streaming of each QueueItem will have its own StreamJob.
+ In case a QueueItem is restarted (e.g. when seeking), a new StreamJob will be created.
+ """
+
+ def __init__(
+ self,
+ queue_item: QueueItem,
+ pcm_sample_rate: int,
+ pcm_bit_depth: int,
+ audio_source: AsyncGenerator[bytes, None] | None = None,
+ flow_mode: bool = False,
+ ) -> None:
+ """Initialize MultiQueue instance."""
+ self.queue_item = queue_item
+ self.audio_source = audio_source
+ # internally all audio within MA is raw PCM, hence the pcm details
+ self.pcm_sample_rate = pcm_sample_rate
+ self.pcm_bit_depth = pcm_bit_depth
+ self.pcm_sample_size = int(pcm_sample_rate * (pcm_bit_depth / 8) * 2)
+ self.stream_id = shortuuid.uuid()
+ self.expected_consumers: set[str] = set()
+ self.flow_mode = flow_mode
+ self.subscribers: dict[str, asyncio.Queue[bytes]] = {}
+ self._all_clients_connected = asyncio.Event()
+ self._audio_task: asyncio.Task | None = None
+ self.seen_players: set[str] = set()
+
+ @property
+ def finished(self) -> bool:
+ """Return if this StreamJob is finished."""
+ if self._audio_task is None:
+ return False
+ if not self._all_clients_connected.is_set():
+ return False
+ return self._audio_task.cancelled() or self._audio_task.done()
+
+ @property
+ def pending(self) -> bool:
+ """Return if this Job is pending start."""
+ return not self._all_clients_connected.is_set()
+
+ @property
+ def running(self) -> bool:
+ """Return if this Job is running."""
+ return not self.finished and not self.pending
+
+ async def subscribe(self, player_id: str) -> AsyncGenerator[bytes, None]:
+ """Subscribe consumer and iterate incoming chunks on the queue."""
+ self.start()
+ self.seen_players.add(player_id)
+ try:
+ sub_queue = asyncio.Queue(3)
+
+ # some checks
+ assert player_id not in self.subscribers, "No duplicate subscriptions allowed"
+ assert not self.finished, "Already finished"
+ assert not self.running, "Already running"
+
+ self.subscribers[player_id] = sub_queue
+ if len(self.subscribers) == len(self.expected_consumers):
+ # we reached the number of expected subscribers, set event
+ # so that chunks can be pushed
+ self._all_clients_connected.set()
+ else:
+ # wait until all expected subscribers arrived
+ # TODO: handle edge case where a player does not connect at all ?!
+ await self._all_clients_connected.wait()
+
+ # keep reading audio chunks from the queue until we receive an empty one
+ while True:
+ chunk = await sub_queue.get()
+ if chunk == b"":
+ # EOF chunk received
+ break
+ yield chunk
+ finally:
+ # some delay here to detect misbehaving (reconnecting) players
+ await asyncio.sleep(2)
+ empty_queue(sub_queue)
+ self.subscribers.pop(player_id)
+ await asyncio.sleep(2)
+ # check if this was the last subscriber and we should cancel
+ if len(self.subscribers) == 0 and self._audio_task and not self.finished:
+ self._audio_task.cancel()
+
+ async def _put_data(self, data: Any, timeout: float = 600) -> None:
+ """Put chunk of data to all subscribers."""
+ async with asyncio.timeout(timeout):
+ async with asyncio.TaskGroup() as tg:
+ for sub_id in self.subscribers:
+ sub_queue = self.subscribers[sub_id]
+ tg.create_task(sub_queue.put(data))
+
+ async def _stream_job_runner(self) -> None:
+ """Feed audio chunks to StreamJob subscribers."""
+ chunk_num = 0
+ async for chunk in self.audio_source:
+ chunk_num += 1
+ if chunk_num == 1:
+ # wait until all expected clients are connected
+ try:
+ async with asyncio.timeout(10):
+ await self._all_clients_connected.wait()
+ except TimeoutError as err:
+ if len(self.subscribers) == 0:
+ raise TimeoutError("Clients did not connect within 10 seconds.") from err
+ self._all_clients_connected.set()
+ LOGGER.warning(
+ "Starting stream job %s but not all clients connected within 10 seconds."
+ )
+
+ await self._put_data(chunk)
+
+ # mark EOF with empty chunk
+ await self._put_data(b"")
+
+ def start(self) -> None:
+ """Start running the stream job."""
+ if self._audio_task:
+ return
+ self._audio_task = asyncio.create_task(self._stream_job_runner())
+
+
+class StreamsController:
+ """Controller to stream audio to players."""
+
+ def __init__(self, mass: MusicAssistant):
+ """Initialize instance."""
+ self.mass = mass
+ # streamjobs contains all active stream jobs
+ # there may be multiple jobs for the same queue item (e.g. when seeking)
+ # the key is the (unique) stream_id for the StreamJob
+ self.stream_jobs: dict[str, StreamJob] = {}
+ # some players do multiple GET requests for the same audio stream
+ # to determine content type or content length
+ # we try to detect/report these players and workaround it.
+ # if a player_id is in the below set of player_ids, the first GET request
+ # of that player will be ignored and audio is served only in the 2nd request
+ self.workaround_players: set[str] = set()
+
+ async def setup(self) -> None:
+ """Async initialize of module."""
+ self.mass.webapp.router.add_get("/stream/preview", self._serve_preview)
+ self.mass.webapp.router.add_get(
+ "/stream/{player_id}/{queue_item_id}/{stream_id}.{fmt}",
+ self._serve_queue_stream,
+ )
+
+ ffmpeg_present, libsoxr_support = await check_audio_support(True)
+ if not ffmpeg_present:
+ LOGGER.error("FFmpeg binary not found on your system, playback will NOT work!.")
+ elif not libsoxr_support:
+ LOGGER.warning(
+ "FFmpeg version found without libsoxr support, "
+ "highest quality audio not available. "
+ )
+ await self._cleanup_stale()
+ LOGGER.info("Started stream controller")
+
+ async def close(self) -> None:
+ """Cleanup on exit."""
+
+ async def resolve_stream_url(
+ self,
+ queue_item: QueueItem,
+ player_id: str,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ content_type: ContentType = ContentType.WAV,
+ auto_start_runner: bool = True,
+ flow_mode: bool = False,
+ ) -> str:
+ """Resolve the stream URL for the given QueueItem.
+
+ This is called just-in-time by the player implementation to get the URL to the audio.
+ It will create a StreamJob which is a background task responsible for feeding
+ the PCM audio chunks to the consumer(s).
+
+ - queue_item: the QueueItem that is about to be played (or buffered).
+ - player_id: the player_id of the player that will play the stream.
+ In case of a multi subscriber stream (e.g. sync/groups),
+ call resolve for every child player.
+ - seek_position: start playing from this specific position.
+ - fade_in: fade in the music at start (e.g. at resume).
+ - content_type: Encode the stream in the given format.
+ - auto_start_runner: Start the audio stream in advance (stream track now).
+ - flow_mode: enable flow mode where the queue tracks are streamed as continuous stream.
+ """
+ # check if there is already a pending job
+ for stream_job in self.stream_jobs.values():
+ if stream_job.finished or stream_job.running:
+ continue
+ if stream_job.queue_item.queue_id != queue_item.queue_id:
+ continue
+ if stream_job.queue_item.queue_item_id != queue_item.queue_item_id:
+ continue
+ # if we hit this point, we have a match
+ break
+ else:
+ # register a new stream job
+ if flow_mode:
+ # flow mode streamjob
+ sample_rate = 48000 # hardcoded for now
+ bit_depth = 24 # hardcoded for now
+ stream_job = StreamJob(
+ queue_item=queue_item,
+ pcm_sample_rate=sample_rate,
+ pcm_bit_depth=bit_depth,
+ flow_mode=True,
+ )
+ stream_job.audio_source = self._get_flow_stream(
+ stream_job, seek_position=seek_position, fade_in=fade_in
+ )
+ else:
+ # regular streamjob
+ streamdetails = await get_stream_details(self.mass, queue_item)
+ stream_job = StreamJob(
+ queue_item=queue_item,
+ audio_source=get_media_stream(
+ self.mass,
+ streamdetails=streamdetails,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ ),
+ pcm_sample_rate=streamdetails.sample_rate,
+ pcm_bit_depth=streamdetails.bit_depth,
+ )
+
+ stream_job.expected_consumers.add(player_id)
+ self.stream_jobs[stream_job.stream_id] = stream_job
+ if auto_start_runner:
+ stream_job.start()
+
+ # generate player-specific URL for the stream job
+ fmt = content_type.value
+ url = f"{self.mass.base_url}/stream/{player_id}/{queue_item.queue_item_id}/{stream_job.stream_id}.{fmt}" # noqa: E501
+ return url
+
+ async def get_preview_url(self, provider: str, track_id: str) -> str:
+ """Return url to short preview sample."""
+ enc_track_id = urllib.parse.quote(track_id)
+ return f"{self.mass.base_url}/preview?provider={provider}&item_id={enc_track_id}"
+
+ async def _serve_queue_stream(self, request: web.Request) -> web.Response:
+ """Serve Queue Stream audio to player(s)."""
+ LOGGER.debug(
+ "Got %s request to %s from %s\nheaders: %s\n",
+ request.method,
+ request.path,
+ request.remote,
+ request.headers,
+ )
+ player_id = request.match_info["player_id"]
+ player = self.mass.players.get(player_id)
+ if not player:
+ raise web.HTTPNotFound(reason=f"Unknown player_id: {player_id}")
+ stream_id = request.match_info["stream_id"]
+ stream_job = self.stream_jobs.get(stream_id)
+ if not stream_job or stream_job.finished:
+ # Player is trying to play a stream that already exited
+ if player.state == PlayerState.PAUSED:
+ await self.mass.players.queues.resume(player_id)
+ LOGGER.warning(
+ "Got stream request for an already finished stream job for player %s",
+ player.display_name,
+ )
+ raise web.HTTPNotFound(reason=f"Unknown stream_id: {stream_id}")
+
+ output_format_str = request.match_info["fmt"]
+ output_format = ContentType.try_parse(output_format_str)
+ output_sample_rate = min(stream_job.pcm_sample_rate, player.max_sample_rate)
+ player_max_bit_depth = 32 if player.supports_24bit else 16
+ output_bit_depth = min(stream_job.pcm_bit_depth, player_max_bit_depth)
+ if output_format == ContentType.PCM:
+ # resolve generic pcm type
+ output_format = ContentType.from_bit_depth(output_bit_depth)
+ if output_format.is_pcm() or output_format == ContentType.WAV:
+ output_channels = self.mass.config.get_player_config_value(
+ player_id, CONF_OUTPUT_CHANNELS
+ ).value
+ channels = 1 if output_channels != "stereo" else 2
+ output_format_str = (
+ f"x-wav;codec=pcm;rate={output_sample_rate};"
+ f"bitrate={output_bit_depth};channels={channels}"
+ )
+
+ # prepare request, add some DLNA/UPNP compatible headers
+ enable_icy = request.headers.get("Icy-MetaData", "") == "1"
+ icy_meta_interval = 65536 if output_format.is_lossless() else 8192
+ headers = {
+ "Content-Type": f"audio/{output_format_str}",
+ "transferMode.dlna.org": "Streaming",
+ "contentFeatures.dlna.org": "DLNA.ORG_OP=00;DLNA.ORG_CI=0;DLNA.ORG_FLAGS=0d500000000000000000000000000000", # noqa: E501
+ "Cache-Control": "no-cache",
+ "Connection": "close",
+ "icy-name": "Music Assistant",
+ "icy-pub": "1",
+ }
+ if enable_icy:
+ headers["icy-metaint"] = str(icy_meta_interval)
+
+ resp = web.StreamResponse(
+ status=200,
+ reason="OK",
+ headers=headers,
+ )
+ await resp.prepare(request)
+
+ # return early if this is only a HEAD request
+ if request.method == "HEAD":
+ return resp
+
+ # handler workaround for players that do 2 multiple GET requests
+ # for the same audio stream (because of the missing duration/length)
+ if player_id in self.workaround_players and player_id not in stream_job.seen_players:
+ stream_job.seen_players.add(player_id)
+ return resp
+
+ # guard for the same player connecting multiple times for the same stream
+ if player_id in stream_job.subscribers:
+ LOGGER.error(
+ "Player %s is making multiple requests for the same stream,"
+ " please create an issue report on the Music Assistant issue tracker.",
+ player.display_name,
+ )
+ # add the player to the list of players that need the workaround
+ self.workaround_players.add(player_id)
+ raise web.HTTPBadRequest(reason="Multiple connections are not allowed.")
+ if stream_job.running:
+ LOGGER.error(
+ "Player %s is making a request for an already running stream,"
+ " please create an issue report on the Music Assistant issue tracker.",
+ player.display_name,
+ )
+ raise web.HTTPBadRequest(reason="Stream is already running.")
+
+ # all checks passed, start streaming!
+ LOGGER.debug("Start serving audio stream %s to %s", stream_id, player.name)
+
+ # collect player specific ffmpeg args to re-encode the source PCM stream
+ ffmpeg_args = self._get_player_ffmpeg_args(
+ player,
+ input_sample_rate=stream_job.pcm_sample_rate,
+ input_bit_depth=stream_job.pcm_bit_depth,
+ output_format=output_format,
+ output_sample_rate=output_sample_rate,
+ )
+
+ async with AsyncProcess(ffmpeg_args, True) as ffmpeg_proc:
+ # feed stdin with pcm audio chunks from origin
+ async def read_audio():
+ async for chunk in stream_job.subscribe(player_id):
+ await ffmpeg_proc.write(chunk)
+ ffmpeg_proc.write_eof()
+
+ ffmpeg_proc.attach_task(read_audio())
+
+ # read final chunks from stdout
+ iterator = (
+ ffmpeg_proc.iter_chunked(icy_meta_interval)
+ if enable_icy
+ else ffmpeg_proc.iter_any()
+ )
+
+ bytes_streamed = 0
+
+ async for chunk in iterator:
+ try:
+ await resp.write(chunk)
+ bytes_streamed += len(chunk)
+
+ # do not allow the player to prebuffer more than 10 seconds
+ seconds_streamed = int(bytes_streamed / stream_job.pcm_sample_size)
+ if (
+ seconds_streamed > 10
+ and (seconds_streamed - player.corrected_elapsed_time) > 10
+ ):
+ await asyncio.sleep(1)
+
+ if not enable_icy:
+ continue
+
+ # if icy metadata is enabled, send the icy metadata after the chunk
+ item_in_buf = stream_job.queue_item
+ if item_in_buf and item_in_buf.streamdetails.stream_title:
+ title = item_in_buf.streamdetails.stream_title
+ elif item_in_buf and item_in_buf.name:
+ title = item_in_buf.name
+ else:
+ title = "Music Assistant"
+ metadata = f"StreamTitle='{title}';".encode()
+ while len(metadata) % 16 != 0:
+ metadata += b"\x00"
+ length = len(metadata)
+ length_b = chr(int(length / 16)).encode()
+ await resp.write(length_b + metadata)
+
+ except (BrokenPipeError, ConnectionResetError):
+ # connection lost
+ break
+
+ return resp
+
+ async def _get_flow_stream(
+ self,
+ stream_job: StreamJob,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ ) -> AsyncGenerator[bytes, None]:
+ """Get a flow stream of all tracks in the queue."""
+ # ruff: noqa: PLR0915
+ queue_id = stream_job.queue_item.queue_id
+ queue = self.mass.players.queues.get(queue_id)
+ queue_track = None
+ last_fadeout_part = b""
+
+ LOGGER.info("Start Queue Flow stream for Queue %s", queue.display_name)
+
+ while True:
+ # get (next) queue item to stream
+ if queue_track is None:
+ queue_track = stream_job.queue_item
+ use_crossfade = queue.crossfade_enabled
+ else:
+ seek_position = 0
+ fade_in = False
+ try:
+ (
+ queue_track,
+ use_crossfade,
+ ) = self.mass.players.queues.player_ready_for_next_track(
+ queue_id, queue_track.queue_item_id
+ )
+ except QueueEmpty:
+ break
+ # store reference to the current queueitem on the streamjob
+ stream_job.queue_item = queue_track
+
+ # get streamdetails
+ try:
+ streamdetails = await get_stream_details(self.mass, queue_track)
+ except MediaNotFoundError as err:
+ # streamdetails retrieval failed, skip to next track instead of bailing out...
+ LOGGER.warning(
+ "Skip track %s due to missing streamdetails",
+ queue_track.name,
+ exc_info=err,
+ )
+ continue
+
+ LOGGER.debug(
+ "Start Streaming queue track: %s (%s) for queue %s - crossfade: %s",
+ streamdetails.uri,
+ queue_track.name,
+ queue.display_name,
+ use_crossfade,
+ )
+
+ # set some basic vars
+ sample_rate = stream_job.pcm_sample_rate
+ bit_depth = stream_job.pcm_bit_depth
+ pcm_sample_size = int(sample_rate * (bit_depth / 8) * 2)
+ crossfade_duration = 10
+ crossfade_size = int(pcm_sample_size * crossfade_duration)
+ queue_track.streamdetails.seconds_skipped = seek_position
+ buffer_size = crossfade_size if use_crossfade else int(pcm_sample_size * 2)
+
+ buffer = b""
+ bytes_written = 0
+ chunk_num = 0
+ # handle incoming audio chunks
+ async for chunk in get_media_stream(
+ self.mass,
+ streamdetails,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ sample_rate=sample_rate,
+ bit_depth=bit_depth,
+ # only allow strip silence from begin if track is being crossfaded
+ strip_silence_begin=last_fadeout_part != b"",
+ ):
+ chunk_num += 1
+
+ #### HANDLE FIRST PART OF TRACK
+
+ # buffer full for crossfade
+ if last_fadeout_part and (len(buffer) >= buffer_size):
+ first_part = buffer + chunk
+ # perform crossfade
+ fadein_part = first_part[:crossfade_size]
+ remaining_bytes = first_part[crossfade_size:]
+ crossfade_part = await crossfade_pcm_parts(
+ fadein_part,
+ last_fadeout_part,
+ bit_depth,
+ sample_rate,
+ )
+ # send crossfade_part
+ yield crossfade_part
+ bytes_written += len(crossfade_part)
+ # also write the leftover bytes from the strip action
+ if remaining_bytes:
+ yield remaining_bytes
+ bytes_written += len(remaining_bytes)
+
+ # clear vars
+ last_fadeout_part = b""
+ buffer = b""
+ continue
+
+ # enough data in buffer, feed to output
+ if len(buffer) >= (buffer_size * 2):
+ yield buffer[:buffer_size]
+ bytes_written += buffer_size
+ buffer = buffer[buffer_size:] + chunk
+ continue
+
+ # all other: fill buffer
+ buffer += chunk
+ continue
+
+ #### HANDLE END OF TRACK
+
+ if bytes_written == 0:
+ # stream error: got empty first chunk ?!
+ LOGGER.warning("Stream error on %s", streamdetails.uri)
+ queue_track.streamdetails.seconds_streamed = 0
+ continue
+
+ if buffer and use_crossfade:
+ # if crossfade is enabled, save fadeout part to pickup for next track
+ last_fadeout_part = buffer[-crossfade_size:]
+ remaining_bytes = buffer[:-crossfade_size]
+ yield remaining_bytes
+ bytes_written += len(remaining_bytes)
+ elif buffer:
+ # no crossfade enabled, just yield the buffer last part
+ yield buffer
+ bytes_written += len(buffer)
+
+ # end of the track reached - store accurate duration
+ queue_track.streamdetails.seconds_streamed = bytes_written / pcm_sample_size
+ LOGGER.debug(
+ "Finished Streaming queue track: %s (%s) on queue %s",
+ queue_track.streamdetails.uri,
+ queue_track.name,
+ queue.display_name,
+ )
+
+ LOGGER.info("Finished Queue Flow stream for Queue %s", queue.display_name)
+
+ async def _serve_preview(self, request: web.Request):
+ """Serve short preview sample."""
+ provider_mapping = request.query["provider_mapping"]
+ item_id = urllib.parse.unquote(request.query["item_id"])
+ resp = web.StreamResponse(status=200, reason="OK", headers={"Content-Type": "audio/mp3"})
+ await resp.prepare(request)
+ async for chunk in get_preview_stream(self.mass, provider_mapping, item_id):
+ await resp.write(chunk)
+ return resp
+
+ def _get_player_ffmpeg_args(
+ self,
+ player: Player,
+ input_sample_rate: int,
+ input_bit_depth: int,
+ output_format: ContentType,
+ output_sample_rate: int,
+ ) -> list[str]:
+ """Get player specific arguments for the given (pcm) input and output details."""
+ player_conf = self.mass.config.get_player_config(player.player_id)
+ conf_channels = player_conf.get_value(CONF_OUTPUT_CHANNELS)
+ # generic args
+ generic_args = [
+ "ffmpeg",
+ "-hide_banner",
+ "-loglevel",
+ "quiet",
+ "-ignore_unknown",
+ ]
+ # input args
+ input_args = [
+ "-f",
+ ContentType.from_bit_depth(input_bit_depth).value,
+ "-ac",
+ "2",
+ "-ar",
+ str(input_sample_rate),
+ "-i",
+ "-",
+ ]
+ # output args
+ output_args = [
+ # output args
+ "-f",
+ output_format.value,
+ "-ac",
+ "1" if conf_channels != "stereo" else "2",
+ "-ar",
+ str(output_sample_rate),
+ "-compression_level",
+ "0",
+ "-",
+ ]
+ # collect extra and filter args
+ # TODO: add convolution/DSP/roomcorrections here!
+ extra_args = []
+ filter_params = []
+
+ # the below is a very basic 3-band equalizer,
+ # this could be a lot more sophisticated at some point
+ if eq_bass := player_conf.get_value(CONF_EQ_BASS):
+ filter_params.append(f"equalizer=frequency=100:width=200:width_type=h:gain={eq_bass}")
+ if eq_mid := player_conf.get_value(CONF_EQ_MID):
+ filter_params.append(f"equalizer=frequency=900:width=1800:width_type=h:gain={eq_mid}")
+ if eq_treble := player_conf.get_value(CONF_EQ_TREBLE):
+ filter_params.append(
+ f"equalizer=frequency=9000:width=18000:width_type=h:gain={eq_treble}"
+ )
+ # handle output mixing only left or right
+ if conf_channels == "left":
+ filter_params.append("pan=mono|c0=FL")
+ elif conf_channels == "right":
+ filter_params.append("pan=mono|c0=FR")
+
+ if filter_params:
+ extra_args += ["-af", ",".join(filter_params)]
+
+ return generic_args + input_args + extra_args + output_args
+
+ async def _cleanup_stale(self) -> None:
+ """Cleanup stale/done stream tasks."""
+ stale = set()
+ for stream_id, job in self.stream_jobs.items():
+ if job.finished:
+ stale.add(stream_id)
+ for stream_id in stale:
+ self.stream_jobs.pop(stream_id, None)
+
+ # reschedule self to run every 5 minutes
+ def reschedule():
+ self.mass.create_task(self._cleanup_stale())
+
+ self.mass.loop.call_later(300, reschedule)
--- /dev/null
+"""Various server-seecific utils/helpers."""
--- /dev/null
+"""Several helpers for the WebSockets API."""
+from __future__ import annotations
+
+import asyncio
+import inspect
+import logging
+import weakref
+from collections.abc import Callable, Coroutine
+from concurrent import futures
+from contextlib import suppress
+from dataclasses import MISSING, dataclass
+from datetime import datetime
+from enum import Enum
+from types import NoneType, UnionType
+from typing import TYPE_CHECKING, Any, Final, TypeVar, Union, get_args, get_origin, get_type_hints
+
+from aiohttp import WSMsgType, web
+
+from music_assistant.common.helpers.json import json_dumps, json_loads
+from music_assistant.common.models.api import (
+ CommandMessage,
+ ErrorResultMessage,
+ MessageType,
+ ServerInfoMessage,
+ SuccessResultMessage,
+)
+from music_assistant.common.models.errors import InvalidCommand
+from music_assistant.common.models.event import MassEvent
+from music_assistant.constants import __version__
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+MAX_PENDING_MSG = 512
+CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
+API_SCHEMA_VERSION = 1
+
+LOGGER = logging.getLogger(__name__)
+DEBUG = False # Set to True to enable very verbose logging of all incoming/outgoing messages
+
+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
+@dataclass
+class APICommandHandler:
+ """Model for an API command handler."""
+
+ command: str
+ signature: inspect.Signature
+ type_hints: dict[str, Any]
+ target: Callable[..., Coroutine[Any, Any, Any]]
+
+ @classmethod
+ def parse(
+ cls, command: str, func: Callable[..., Coroutine[Any, Any, Any]]
+ ) -> APICommandHandler:
+ """Parse APICommandHandler by providing a function."""
+ return APICommandHandler(
+ command=command,
+ signature=inspect.signature(func),
+ type_hints=get_type_hints(func),
+ target=func,
+ )
+
+
+def api_command(command: str) -> Callable[[_F], _F]:
+ """Decorate a function as API route/command."""
+
+ def decorate(func: _F) -> _F:
+ func.api_cmd = command # type: ignore[attr-defined]
+ return func
+
+ return decorate
+
+
+def parse_arguments(
+ func_sig: inspect.Signature,
+ func_types: dict[str, Any],
+ args: dict | None,
+ strict: bool = False,
+) -> dict[str, Any]:
+ """Parse (and convert) incoming arguments to correct types."""
+ if args is None:
+ args = {}
+ final_args = {}
+ # ignore extra args if not strict
+ if strict:
+ for key, value in args.items():
+ if key not in func_sig.parameters:
+ raise KeyError("Invalid parameter: '%s'" % key)
+ # parse arguments to correct type
+ for name, param in func_sig.parameters.items():
+ value = args.get(name)
+ default = MISSING if param.default is inspect.Parameter.empty else param.default
+ final_args[name] = parse_value(name, value, func_types[name], default)
+ return final_args
+
+
+def mount_websocket_api(mass: MusicAssistant, path: str) -> None:
+ """Mount the websocket endpoint."""
+ clients: weakref.WeakSet[WebsocketClientHandler] = weakref.WeakSet()
+
+ async def _handle_ws(request: web.Request) -> web.WebSocketResponse:
+ connection = WebsocketClientHandler(mass, request)
+ try:
+ clients.add(connection)
+ return await connection.handle_client()
+ finally:
+ clients.remove(connection)
+
+ async def _handle_shutdown(app: web.Application) -> None: # noqa: ARG001
+ for client in set(clients):
+ await client.disconnect()
+
+ mass.webapp.on_shutdown.append(_handle_shutdown)
+ mass.webapp.router.add_route("GET", path, _handle_ws)
+
+
+class WebSocketLogAdapter(logging.LoggerAdapter):
+ """Add connection id to websocket log messages."""
+
+ def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
+ """Add connid to websocket log messages."""
+ return f'[{self.extra["connid"]}] {msg}', kwargs
+
+
+class WebsocketClientHandler:
+ """Handle an active websocket client connection."""
+
+ def __init__(self, mass: MusicAssistant, request: web.Request) -> None:
+ """Initialize an active connection."""
+ self.mass = mass
+ self.request = request
+ self.wsock = web.WebSocketResponse(heartbeat=55)
+ self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
+ self._handle_task: asyncio.Task | None = None
+ self._writer_task: asyncio.Task | None = None
+ self._logger = WebSocketLogAdapter(LOGGER, {"connid": id(self)})
+
+ async def disconnect(self) -> None:
+ """Disconnect client."""
+ self._cancel()
+ if self._writer_task is not None:
+ await self._writer_task
+
+ async def handle_client(self) -> web.WebSocketResponse:
+ """Handle a websocket response."""
+ # ruff: noqa: PLR0915
+ request = self.request
+ wsock = self.wsock
+ try:
+ async with asyncio.timeout(10):
+ await wsock.prepare(request)
+ except asyncio.TimeoutError:
+ self._logger.warning("Timeout preparing request from %s", request.remote)
+ return wsock
+
+ self._logger.debug("Connection from %s", request.remote)
+ self._handle_task = asyncio.current_task()
+ self._writer_task = asyncio.create_task(self._writer())
+
+ # send server(version) info when client connects
+ self._send_message(
+ ServerInfoMessage(server_version=__version__, schema_version=API_SCHEMA_VERSION)
+ )
+
+ # forward all events to clients
+ def handle_event(event: MassEvent) -> None:
+ self._send_message(event)
+
+ unsub_callback = self.mass.subscribe(handle_event)
+
+ disconnect_warn = None
+
+ try:
+ while not wsock.closed:
+ msg = await wsock.receive()
+
+ if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
+ break
+
+ if msg.type != WSMsgType.TEXT:
+ disconnect_warn = "Received non-Text message."
+ break
+
+ if DEBUG:
+ self._logger.debug("Received: %s", msg.data)
+
+ try:
+ command_msg = CommandMessage.from_dict(json_loads(msg.data))
+ except ValueError:
+ disconnect_warn = f"Received invalid JSON: {msg.data}"
+ break
+
+ self._handle_command(command_msg)
+
+ except asyncio.CancelledError:
+ self._logger.debug("Connection closed by client")
+
+ except Exception: # pylint: disable=broad-except
+ self._logger.exception("Unexpected error inside websocket API")
+
+ finally:
+ # Handle connection shutting down.
+ unsub_callback()
+ self._logger.debug("Unsubscribed from events")
+
+ try:
+ self._to_write.put_nowait(None)
+ # Make sure all error messages are written before closing
+ await self._writer_task
+ await wsock.close()
+ except asyncio.QueueFull: # can be raised by put_nowait
+ self._writer_task.cancel()
+
+ finally:
+ if disconnect_warn is None:
+ self._logger.debug("Disconnected")
+ else:
+ self._logger.warning("Disconnected: %s", disconnect_warn)
+
+ return wsock
+
+ def _handle_command(self, msg: CommandMessage) -> None:
+ """Handle an incoming command from the client."""
+ self._logger.debug("Handling command %s", msg.command)
+
+ # work out handler for the given path/command
+ handler = self.mass.command_handlers.get(msg.command)
+
+ if handler is None:
+ self._send_message(
+ ErrorResultMessage(
+ msg.message_id,
+ InvalidCommand.error_code,
+ f"Invalid command: {msg.command}",
+ )
+ )
+ self._logger.warning("Invalid command: %s", msg.command)
+ return
+
+ # schedule task to handle the command
+ asyncio.create_task(self._run_handler(handler, msg))
+
+ async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None:
+ try:
+ args = parse_arguments(handler.signature, handler.type_hints, msg.args)
+ result = handler.target(**args)
+ if asyncio.iscoroutine(result):
+ result = await result
+ self._send_message(SuccessResultMessage(msg.message_id, result))
+ except Exception as err: # pylint: disable=broad-except
+ self._logger.exception("Error handling message: %s", msg)
+ self._send_message(
+ ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), str(err))
+ )
+
+ async def _writer(self) -> None:
+ """Write outgoing messages."""
+ # Exceptions if Socket disconnected or cancelled by connection handler
+ with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
+ while not self.wsock.closed:
+ if (process := await self._to_write.get()) is None:
+ break
+
+ if not isinstance(process, str):
+ message: str = process()
+ else:
+ message = process
+ if DEBUG:
+ self._logger.debug("Writing: %s", message)
+ await self.wsock.send_str(message)
+
+ def _send_message(self, message: MessageType) -> None:
+ """Send a message to the client.
+
+ Closes connection if the client is not reading the messages.
+
+ Async friendly.
+ """
+ _message = json_dumps(message)
+
+ try:
+ self._to_write.put_nowait(_message)
+ except asyncio.QueueFull:
+ self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG)
+
+ self._cancel()
+
+ def _cancel(self) -> None:
+ """Cancel the connection."""
+ if self._handle_task is not None:
+ self._handle_task.cancel()
+ if self._writer_task is not None:
+ self._writer_task.cancel()
+
+
+def parse_utc_timestamp(datetime_string: str) -> datetime:
+ """Parse datetime from string."""
+ return datetime.fromisoformat(datetime_string.replace("Z", "+00:00"))
+
+
+def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) -> Any:
+ """Try to parse a value from raw (json) data and type annotations."""
+ if isinstance(value, dict) and hasattr(value_type, "from_dict"):
+ if "media_type" in value and value["media_type"] != value_type.media_type:
+ raise ValueError("Invalid MediaType")
+ return value_type.from_dict(value)
+
+ if value is None and not isinstance(default, type(MISSING)):
+ return default
+ if value is None and value_type is NoneType:
+ return None
+ origin = get_origin(value_type)
+ if origin is list:
+ return [
+ parse_value(name, subvalue, get_args(value_type)[0])
+ for subvalue in value
+ if subvalue is not None
+ ]
+ elif origin is dict:
+ subkey_type = get_args(value_type)[0]
+ subvalue_type = get_args(value_type)[1]
+ return {
+ parse_value(subkey, subkey, subkey_type): parse_value(
+ f"{subkey}.value", subvalue, subvalue_type
+ )
+ for subkey, subvalue in value.items()
+ }
+ elif origin is Union or origin is UnionType:
+ # try all possible types
+ sub_value_types = get_args(value_type)
+ for sub_arg_type in sub_value_types:
+ if value is NoneType and sub_arg_type is NoneType:
+ return value
+ # try them all until one succeeds
+ try:
+ return parse_value(name, value, sub_arg_type)
+ except (KeyError, TypeError, ValueError):
+ pass
+ # if we get to this point, all possibilities failed
+ # find out if we should raise or log this
+ err = (
+ f"Value {value} of type {type(value)} is invalid for {name}, "
+ f"expected value of type {value_type}"
+ )
+ if NoneType not in sub_value_types:
+ # raise exception, we have no idea how to handle this value
+ raise TypeError(err)
+ # failed to parse the (sub) value but None allowed, log only
+ logging.getLogger(__name__).warn(err)
+ return None
+ elif origin is type:
+ return eval(value)
+ if value_type is Any:
+ return value
+ if value is None and value_type is not NoneType:
+ raise KeyError(f"`{name}` of type `{value_type}` is required.")
+
+ try:
+ if issubclass(value_type, Enum): # type: ignore[arg-type]
+ return value_type(value) # type: ignore[operator]
+ if issubclass(value_type, datetime): # type: ignore[arg-type]
+ return parse_utc_timestamp(value)
+ except TypeError:
+ # happens if value_type is not a class
+ pass
+
+ if value_type is float and isinstance(value, int):
+ return float(value)
+ if value_type is int and isinstance(value, str) and value.isnumeric():
+ return int(value)
+ if not isinstance(value, value_type): # type: ignore[arg-type]
+ raise TypeError(
+ f"Value {value} of type {type(value)} is invalid for {name}, "
+ f"expected value of type {value_type}"
+ )
+ return value
--- /dev/null
+# pylint: skip-file
+# fmt: off
+# flake8: noqa
+# type: ignore
+(lambda __g: [(lambda __mod: [[[None for __g['app_var'], app_var.__name__ in [(lambda index: (lambda __l: [[AV(aap(__l['var'].encode()).decode()) for __l['var'] in [(vars.split('acb2')[__l['index']][::(-1)])]][0] for __l['index'] in [(index)]][0])({}), 'app_var')]][0] for __g['vars'] in [('3YTNyUDOyQTOacb2=EmN5M2YjdzMhljYzYzYhlDMmFGNlVTOmNDZwMzNxYzNacb2=UDMzEGOyADO1QWO5kDNygTMlJGN5QzNzIWOmZTOiVmMacb2yMTNzITNacb2=UDZhJmMldTZ3QTY4IjZ3kTNxYjN0czNwI2YxkTM5MjN')]][0] for __g['aap'] in [(__mod.b64decode)]][0])(__import__('base64', __g, __g, ('b64decode',), 0)) for __g['AV'] in [((lambda b, d: d.get('__metaclass__', getattr(b[0], '__class__', type(b[0])))('AV', b, d))((str,), (lambda __l: [__l for __l['__repr__'], __l['__repr__'].__name__ in [(lambda self: (lambda __l: [__name__ for __l['self'] in [(self)]][0])({}), '__repr__')]][0])({'__module__': __name__})))]][0])(globals())
--- /dev/null
+"""Various helpers for audio manipulation."""
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import re
+import struct
+from collections.abc import AsyncGenerator
+from io import BytesIO
+from time import time
+from typing import TYPE_CHECKING
+
+import aiofiles
+from aiohttp import ClientTimeout
+
+from music_assistant.common.helpers.util import create_tempfile
+from music_assistant.common.models.errors import AudioError, MediaNotFoundError, MusicAssistantError
+from music_assistant.common.models.media_items import ContentType, MediaType, StreamDetails
+from music_assistant.constants import CONF_VOLUME_NORMALISATION, CONF_VOLUME_NORMALISATION_TARGET
+from music_assistant.server.helpers.process import AsyncProcess, check_output
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.player_queue import QueueItem
+ from music_assistant.server import MusicAssistant
+
+LOGGER = logging.getLogger(__name__)
+
+# pylint:disable=consider-using-f-string
+
+
+async def crossfade_pcm_parts(
+ fade_in_part: bytes,
+ fade_out_part: bytes,
+ bit_depth: int,
+ sample_rate: int,
+) -> bytes:
+ """Crossfade two chunks of pcm/raw audio using ffmpeg."""
+ sample_size = int(sample_rate * (bit_depth / 8) * 2)
+ fmt = ContentType.from_bit_depth(bit_depth)
+ # calculate the fade_length from the smallest chunk
+ fade_length = min(len(fade_in_part), len(fade_out_part)) / sample_size
+ fadeoutfile = create_tempfile()
+ async with aiofiles.open(fadeoutfile.name, "wb") as outfile:
+ await outfile.write(fade_out_part)
+ args = [
+ # generic args
+ "ffmpeg",
+ "-hide_banner",
+ "-loglevel",
+ "quiet",
+ # fadeout part (as file)
+ "-acodec",
+ fmt.name.lower(),
+ "-f",
+ fmt,
+ "-ac",
+ "2",
+ "-ar",
+ str(sample_rate),
+ "-i",
+ fadeoutfile.name,
+ # fade_in part (stdin)
+ "-acodec",
+ fmt.name.lower(),
+ "-f",
+ fmt,
+ "-ac",
+ "2",
+ "-ar",
+ str(sample_rate),
+ "-i",
+ "-",
+ # filter args
+ "-filter_complex",
+ f"[0][1]acrossfade=d={fade_length}",
+ # output args
+ "-f",
+ fmt,
+ "-",
+ ]
+ async with AsyncProcess(args, True) as proc:
+ crossfade_data, _ = await proc.communicate(fade_in_part)
+ if crossfade_data:
+ LOGGER.debug(
+ "crossfaded 2 pcm chunks. fade_in_part: %s - "
+ "fade_out_part: %s - fade_length: %s seconds",
+ len(fade_in_part),
+ len(fade_out_part),
+ fade_length,
+ )
+ return crossfade_data
+ # no crossfade_data, return original data instead
+ LOGGER.debug(
+ "crossfade of pcm chunks failed: not enough data? "
+ "fade_in_part: %s - fade_out_part: %s",
+ len(fade_in_part),
+ len(fade_out_part),
+ )
+ return fade_out_part + fade_in_part
+
+
+async def strip_silence(
+ mass: MusicAssistant, # noqa: ARG001
+ audio_data: bytes,
+ sample_rate: int,
+ bit_depth: int,
+ reverse: bool = False,
+) -> bytes:
+ """Strip silence from begin or end of pcm audio using ffmpeg."""
+ fmt = ContentType.from_bit_depth(bit_depth)
+ args = ["ffmpeg", "-hide_banner", "-loglevel", "quiet"]
+ args += [
+ "-acodec",
+ fmt.name.lower(),
+ "-f",
+ fmt,
+ "-ac",
+ "2",
+ "-ar",
+ str(sample_rate),
+ "-i",
+ "-",
+ ]
+ # filter args
+ if reverse:
+ args += [
+ "-af",
+ "areverse,atrim=start=0.2,silenceremove=start_periods=1:start_silence=0.1:start_threshold=0.02,areverse",
+ ]
+ else:
+ args += [
+ "-af",
+ "atrim=start=0.2,silenceremove=start_periods=1:start_silence=0.1:start_threshold=0.02",
+ ]
+ # output args
+ args += ["-f", fmt, "-"]
+ async with AsyncProcess(args, True) as proc:
+ stripped_data, _ = await proc.communicate(audio_data)
+
+ # return stripped audio
+ bytes_stripped = len(audio_data) - len(stripped_data)
+ if LOGGER.isEnabledFor(logging.DEBUG):
+ pcm_sample_size = int(sample_rate * (bit_depth / 8) * 2)
+ seconds_stripped = round(bytes_stripped / pcm_sample_size, 2)
+ location = "end" if reverse else "begin"
+ LOGGER.debug(
+ "stripped %s seconds of silence from %s of pcm audio. bytes stripped: %s",
+ seconds_stripped,
+ location,
+ bytes_stripped,
+ )
+ return stripped_data
+
+
+async def analyze_audio(mass: MusicAssistant, streamdetails: StreamDetails) -> None:
+ """Analyze track audio, for now we only calculate EBU R128 loudness."""
+ if streamdetails.loudness is not None:
+ # only when needed we do the analyze job
+ return
+
+ LOGGER.debug("Start analyzing audio for %s", streamdetails.uri)
+ # calculate BS.1770 R128 integrated loudness with ffmpeg
+ input_file = streamdetails.direct or "-"
+ proc_args = [
+ "ffmpeg",
+ "-t",
+ "300", # limit to 5 minutes to prevent OOM
+ "-i",
+ input_file,
+ "-f",
+ streamdetails.content_type,
+ "-af",
+ "ebur128=framelog=verbose",
+ "-f",
+ "null",
+ "-",
+ ]
+ async with AsyncProcess(
+ proc_args,
+ enable_stdin=streamdetails.direct is None,
+ enable_stdout=False,
+ enable_stderr=True,
+ ) as ffmpeg_proc:
+
+ async def writer():
+ """Task that grabs the source audio and feeds it to ffmpeg."""
+ music_prov = mass.get_provider(streamdetails.provider)
+ chunk_count = 0
+ async for audio_chunk in music_prov.get_audio_stream(streamdetails):
+ chunk_count += 1
+ await ffmpeg_proc.write(audio_chunk)
+ if chunk_count == 300:
+ # safety guard: max (more or less) 5 minutes seconds of audio may be analyzed
+ break
+ ffmpeg_proc.write_eof()
+
+ if streamdetails.direct is None:
+ writer_task = ffmpeg_proc.attach_task(writer())
+ # wait for the writer task to finish
+ await writer_task
+
+ _, stderr = await ffmpeg_proc.communicate()
+ try:
+ loudness_str = (
+ stderr.decode().split("Integrated loudness")[1].split("I:")[1].split("LUFS")[0]
+ )
+ loudness = float(loudness_str.strip())
+ except (IndexError, ValueError, AttributeError):
+ LOGGER.warning(
+ "Could not determine integrated loudness of %s - %s",
+ streamdetails.uri,
+ stderr.decode() or "received empty value",
+ )
+ else:
+ streamdetails.loudness = loudness
+ await mass.music.set_track_loudness(
+ streamdetails.item_id, streamdetails.provider, loudness
+ )
+ LOGGER.debug(
+ "Integrated loudness of %s is: %s",
+ streamdetails.uri,
+ loudness,
+ )
+
+
+async def get_stream_details(mass: MusicAssistant, queue_item: QueueItem) -> StreamDetails:
+ """Get streamdetails for the given QueueItem.
+
+ This is called just-in-time when a PlayerQueue wants a MediaItem to be played.
+ Do not try to request streamdetails in advance as this is expiring data.
+ param media_item: The QueueItem for which to request the streamdetails for.
+ """
+ streamdetails = None
+ if queue_item.streamdetails and (time() < (queue_item.streamdetails.expires - 360)):
+ # we already have fresh streamdetails, use these
+ queue_item.streamdetails.seconds_skipped = None
+ queue_item.streamdetails.seconds_streamed = None
+ streamdetails = queue_item.streamdetails
+ else:
+ # fetch streamdetails from provider
+ # always request the full item as there might be other qualities available
+ full_item = await mass.music.get_item_by_uri(queue_item.uri)
+ # sort by quality and check track availability
+ for prov_media in sorted(
+ full_item.provider_mappings, key=lambda x: x.quality or 0, reverse=True
+ ):
+ if not prov_media.available:
+ continue
+ # get streamdetails from provider
+ music_prov = mass.get_provider(prov_media.provider_instance)
+ if not music_prov:
+ continue # provider not available ?
+ try:
+ streamdetails: StreamDetails = await music_prov.get_stream_details(
+ prov_media.item_id
+ )
+ streamdetails.content_type = ContentType(streamdetails.content_type)
+ except MusicAssistantError as err:
+ LOGGER.warning(str(err))
+ else:
+ break
+
+ if not streamdetails:
+ raise MediaNotFoundError(f"Unable to retrieve streamdetails for {queue_item}")
+
+ # set queue_id on the streamdetails so we know what is being streamed
+ streamdetails.queue_id = queue_item.queue_id
+ # get gain correct / replaygain
+ if streamdetails.gain_correct is None:
+ loudness, gain_correct = await get_gain_correct(mass, streamdetails)
+ streamdetails.gain_correct = gain_correct
+ streamdetails.loudness = loudness
+ if not streamdetails.duration:
+ streamdetails.duration = queue_item.duration
+ # make sure that ffmpeg handles mpeg dash streams directly
+ if (
+ streamdetails.content_type == ContentType.MPEG_DASH
+ and streamdetails.data
+ and streamdetails.data.startswith("http")
+ ):
+ streamdetails.direct = streamdetails.data
+ # set streamdetails as attribute on the media_item
+ # this way the app knows what content is playing
+ queue_item.streamdetails = streamdetails
+ return streamdetails
+
+
+async def get_gain_correct(
+ mass: MusicAssistant, streamdetails: StreamDetails
+) -> tuple[float | None, float | None]:
+ """Get gain correction for given queue / track combination."""
+ player_settings = mass.config.get_player_config(streamdetails.queue_id)
+ if not player_settings or not player_settings.get_value(CONF_VOLUME_NORMALISATION):
+ return (None, None)
+ if streamdetails.gain_correct is not None:
+ return (streamdetails.loudness, streamdetails.gain_correct)
+ target_gain = player_settings.get_value(CONF_VOLUME_NORMALISATION_TARGET)
+ track_loudness = await mass.music.get_track_loudness(
+ streamdetails.item_id, streamdetails.provider
+ )
+ if track_loudness is None:
+ # fallback to provider average
+ fallback_track_loudness = await mass.music.get_provider_loudness(streamdetails.provider)
+ if fallback_track_loudness is None:
+ # fallback to some (hopefully sane) average value for now
+ fallback_track_loudness = -8.5
+ gain_correct = target_gain - fallback_track_loudness
+ else:
+ gain_correct = target_gain - track_loudness
+ gain_correct = round(gain_correct, 2)
+ return (track_loudness, gain_correct)
+
+
+def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None):
+ """Generate a wave header from given params."""
+ # pylint: disable=no-member
+ file = BytesIO()
+
+ # Generate format chunk
+ format_chunk_spec = b"<4sLHHLLHH"
+ format_chunk = struct.pack(
+ format_chunk_spec,
+ b"fmt ", # Chunk id
+ 16, # Size of this chunk (excluding chunk id and this field)
+ 1, # Audio format, 1 for PCM
+ channels, # Number of channels
+ int(samplerate), # Samplerate, 44100, 48000, etc.
+ int(samplerate * channels * (bitspersample / 8)), # Byterate
+ int(channels * (bitspersample / 8)), # Blockalign
+ bitspersample, # 16 bits for two byte samples, etc.
+ )
+ # Generate data chunk
+ # duration = 3600*6.7
+ data_chunk_spec = b"<4sL"
+ if duration is None:
+ # use max value possible
+ datasize = 4254768000 # = 6,7 hours at 44100/16
+ else:
+ # calculate from duration
+ numsamples = samplerate * duration
+ datasize = int(numsamples * channels * (bitspersample / 8))
+ data_chunk = struct.pack(
+ data_chunk_spec,
+ b"data", # Chunk id
+ int(datasize), # Chunk size (excluding chunk id and this field)
+ )
+ sum_items = [
+ # "WAVE" string following size field
+ 4,
+ # "fmt " + chunk size field + chunk size
+ struct.calcsize(format_chunk_spec),
+ # Size of data chunk spec + data size
+ struct.calcsize(data_chunk_spec) + datasize,
+ ]
+ # Generate main header
+ all_chunks_size = int(sum(sum_items))
+ main_header_spec = b"<4sL4s"
+ main_header = struct.pack(main_header_spec, b"RIFF", all_chunks_size, b"WAVE")
+ # Write all the contents in
+ file.write(main_header)
+ file.write(format_chunk)
+ file.write(data_chunk)
+
+ # return file.getvalue(), all_chunks_size + 8
+ return file.getvalue()
+
+
+async def get_media_stream(
+ mass: MusicAssistant,
+ streamdetails: StreamDetails,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ sample_rate: int | None = None,
+ bit_depth: int | None = None,
+ strip_silence_begin: bool = False,
+ strip_silence_end: bool = True,
+) -> AsyncGenerator[bytes, None]:
+ """Get the (PCM) audio stream for the given streamdetails.
+
+ Other than stripping silence at end and beginning and optional
+ volume normalization this is the pure, unaltered audio data as PCM chunks.
+ """
+ bytes_sent = 0
+ streamdetails.seconds_skipped = seek_position
+ is_radio = streamdetails.media_type == MediaType.RADIO or not streamdetails.duration
+ if is_radio or seek_position:
+ strip_silence_begin = False
+
+ sample_rate = sample_rate or streamdetails.sample_rate
+ bit_depth = bit_depth or streamdetails.bit_depth
+ # chunk size = 2 seconds of pcm audio
+ pcm_sample_size = int(sample_rate * (bit_depth / 8) * 2)
+ chunk_size = pcm_sample_size * (1 if is_radio else 2)
+ expected_chunks = int((streamdetails.duration or 0) / 2)
+ if expected_chunks < 60:
+ strip_silence_end = False
+
+ # collect all arguments for ffmpeg
+ args = await _get_ffmpeg_args(
+ streamdetails=streamdetails,
+ sample_rate=sample_rate,
+ bit_depth=bit_depth,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ )
+
+ async with AsyncProcess(args, enable_stdin=streamdetails.direct is None) as ffmpeg_proc:
+ LOGGER.debug("start media stream for: %s", streamdetails.uri)
+
+ async def writer():
+ """Task that grabs the source audio and feeds it to ffmpeg."""
+ LOGGER.debug("writer started for %s", streamdetails.uri)
+ music_prov = mass.get_provider(streamdetails.provider)
+ async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_position):
+ await ffmpeg_proc.write(audio_chunk)
+ # write eof when last packet is received
+ ffmpeg_proc.write_eof()
+ LOGGER.debug("writer finished for %s", streamdetails.uri)
+
+ if streamdetails.direct is None:
+ ffmpeg_proc.attach_task(writer())
+
+ # get pcm chunks from stdout
+ # we always stay one chunk behind to properly detect end of chunks
+ # so we can strip silence at the beginning and end of a track
+ prev_chunk = b""
+ chunk_num = 0
+ try:
+ async for chunk in ffmpeg_proc.iter_chunked(chunk_size):
+ chunk_num += 1
+ if strip_silence_begin and chunk_num == 2:
+ # first 2 chunks received, strip silence of beginning
+ stripped_audio = await strip_silence(
+ mass,
+ prev_chunk + chunk,
+ sample_rate=sample_rate,
+ bit_depth=bit_depth,
+ )
+ yield stripped_audio
+ bytes_sent += len(stripped_audio)
+ prev_chunk = b""
+ del stripped_audio
+ continue
+ if strip_silence_end and chunk_num >= (expected_chunks - 6):
+ # last part of the track, collect multiple chunks to strip silence later
+ prev_chunk += chunk
+ continue
+
+ # middle part of the track, send previous chunk and collect current chunk
+ if prev_chunk:
+ yield prev_chunk
+ bytes_sent += len(prev_chunk)
+
+ prev_chunk = chunk
+
+ # all chunks received, strip silence of last part
+ if strip_silence_end:
+ stripped_audio = await strip_silence(
+ mass,
+ prev_chunk,
+ sample_rate=sample_rate,
+ bit_depth=bit_depth,
+ reverse=True,
+ )
+ yield stripped_audio
+ bytes_sent += len(stripped_audio)
+ del stripped_audio
+ else:
+ yield prev_chunk
+ bytes_sent += len(prev_chunk)
+
+ del prev_chunk
+
+ # update duration details based on the actual pcm data we sent
+ streamdetails.seconds_streamed = bytes_sent / pcm_sample_size
+
+ except (asyncio.CancelledError, GeneratorExit) as err:
+ LOGGER.debug("media stream aborted for: %s", streamdetails.uri)
+ raise err
+ else:
+ LOGGER.debug("finished media stream for: %s", streamdetails.uri)
+ await mass.music.mark_item_played(streamdetails.item_id, streamdetails.provider)
+ finally:
+ # report playback
+ if streamdetails.callback:
+ mass.create_task(streamdetails.callback, streamdetails)
+ # send analyze job to background worker
+ if streamdetails.loudness is None:
+ mass.create_task(analyze_audio(mass, streamdetails))
+
+
+async def get_radio_stream(
+ mass: MusicAssistant, url: str, streamdetails: StreamDetails
+) -> AsyncGenerator[bytes, None]:
+ """Get radio audio stream from HTTP, including metadata retrieval."""
+ headers = {"Icy-MetaData": "1"}
+ timeout = ClientTimeout(total=0, connect=30, sock_read=600)
+ async with mass.http_session.get(url, headers=headers, timeout=timeout) as resp:
+ headers = resp.headers
+ meta_int = int(headers.get("icy-metaint", "0"))
+ # stream with ICY Metadata
+ if meta_int:
+ while True:
+ audio_chunk = await resp.content.readexactly(meta_int)
+ yield audio_chunk
+ meta_byte = await resp.content.readexactly(1)
+ meta_length = ord(meta_byte) * 16
+ meta_data = await resp.content.readexactly(meta_length)
+ if not meta_data:
+ continue
+ meta_data = meta_data.rstrip(b"\0")
+ stream_title = re.search(rb"StreamTitle='([^']*)';", meta_data)
+ if not stream_title:
+ continue
+ stream_title = stream_title.group(1).decode()
+ if stream_title != streamdetails.stream_title:
+ streamdetails.stream_title = stream_title
+ # Regular HTTP stream
+ else:
+ async for chunk in resp.content.iter_any():
+ yield chunk
+
+
+async def get_http_stream(
+ mass: MusicAssistant,
+ url: str,
+ streamdetails: StreamDetails,
+ seek_position: int = 0,
+) -> AsyncGenerator[bytes, None]:
+ """Get audio stream from HTTP."""
+ if seek_position:
+ assert streamdetails.duration, "Duration required for seek requests"
+ # try to get filesize with a head request
+ if seek_position and not streamdetails.size:
+ async with mass.http_session.head(url) as resp:
+ if size := resp.headers.get("Content-Length"):
+ streamdetails.size = int(size)
+ # headers
+ headers = {}
+ skip_bytes = 0
+ if seek_position and streamdetails.size:
+ skip_bytes = int(streamdetails.size / streamdetails.duration * seek_position)
+ headers["Range"] = f"bytes={skip_bytes}-"
+
+ # start the streaming from http
+ buffer = b""
+ buffer_all = False
+ bytes_received = 0
+ timeout = ClientTimeout(total=0, connect=30, sock_read=600)
+ async with mass.http_session.get(url, headers=headers, timeout=timeout) as resp:
+ is_partial = resp.status == 206
+ buffer_all = seek_position and not is_partial
+ async for chunk in resp.content.iter_any():
+ bytes_received += len(chunk)
+ if buffer_all and not skip_bytes:
+ buffer += chunk
+ continue
+ if not is_partial and skip_bytes and bytes_received < skip_bytes:
+ continue
+ yield chunk
+
+ # store size on streamdetails for later use
+ if not streamdetails.size:
+ streamdetails.size = bytes_received
+ if buffer_all:
+ skip_bytes = streamdetails.size / streamdetails.duration * seek_position
+ yield buffer[:skip_bytes]
+
+
+async def get_file_stream(
+ mass: MusicAssistant, # noqa: ARG001
+ filename: str,
+ streamdetails: StreamDetails,
+ seek_position: int = 0,
+) -> AsyncGenerator[bytes, None]:
+ """Get audio stream from local accessible file."""
+ if seek_position:
+ assert streamdetails.duration, "Duration required for seek requests"
+ if not streamdetails.size:
+ stat = await asyncio.to_thread(os.stat, filename)
+ streamdetails.size = stat.st_size
+ chunk_size = get_chunksize(streamdetails.content_type)
+ async with aiofiles.open(streamdetails.data, "rb") as _file:
+ if seek_position:
+ seek_pos = int((streamdetails.size / streamdetails.duration) * seek_position)
+ await _file.seek(seek_pos)
+ # yield chunks of data from file
+ while True:
+ data = await _file.read(chunk_size)
+ if not data:
+ break
+ yield data
+
+
+async def check_audio_support(try_install: bool = False) -> tuple[bool, bool]:
+ """Check if ffmpeg is present (with/without libsoxr support)."""
+ cache_key = "audio_support_cache"
+ if cache := globals().get(cache_key):
+ return cache
+
+ # check for FFmpeg presence
+ returncode, output = await check_output("ffmpeg -version")
+ ffmpeg_present = returncode == 0 and "FFmpeg" in output.decode()
+ if not ffmpeg_present and try_install:
+ # try a few common ways to install ffmpeg
+ # this all assumes we have enough rights and running on a linux based platform (or docker)
+ await check_output("apt-get update && apt-get install ffmpeg")
+ await check_output("apk add ffmpeg")
+ # test again
+ returncode, output = await check_output("ffmpeg -version")
+ ffmpeg_present = returncode == 0 and "FFmpeg" in output.decode()
+
+ # use globals as in-memory cache
+ libsoxr_support = "enable-libsoxr" in output.decode()
+ result = (ffmpeg_present, libsoxr_support)
+ globals()[cache_key] = result
+ return result
+
+
+async def get_preview_stream(
+ mass: MusicAssistant,
+ provider_mapping: str,
+ track_id: str,
+) -> AsyncGenerator[bytes, None]:
+ """Create a 30 seconds preview audioclip for the given streamdetails."""
+ music_prov = mass.get_provider(provider_mapping)
+
+ streamdetails = await music_prov.get_stream_details(track_id)
+
+ input_args = [
+ "ffmpeg",
+ "-hide_banner",
+ "-loglevel",
+ "quiet",
+ "-ignore_unknown",
+ ]
+ if streamdetails.direct:
+ input_args += ["-ss", "30", "-i", streamdetails.direct]
+ else:
+ # the input is received from pipe/stdin
+ if streamdetails.content_type != ContentType.UNKNOWN:
+ input_args += ["-f", streamdetails.content_type]
+ input_args += ["-i", "-"]
+
+ output_args = ["-to", "30", "-f", "mp3", "-"]
+ args = input_args + output_args
+ async with AsyncProcess(args, True) as ffmpeg_proc:
+
+ async def writer():
+ """Task that grabs the source audio and feeds it to ffmpeg."""
+ music_prov = mass.get_provider(streamdetails.provider)
+ async for audio_chunk in music_prov.get_audio_stream(streamdetails, 30):
+ await ffmpeg_proc.write(audio_chunk)
+ # write eof when last packet is received
+ ffmpeg_proc.write_eof()
+
+ if not streamdetails.direct:
+ ffmpeg_proc.attach_task(writer())
+
+ # yield chunks from stdout
+ async for chunk in ffmpeg_proc.iter_any():
+ yield chunk
+
+
+async def get_silence(
+ duration: int,
+ output_fmt: ContentType = ContentType.WAV,
+ sample_rate: int = 44100,
+ bit_depth: int = 16,
+) -> AsyncGenerator[bytes, None]:
+ """Create stream of silence, encoded to format of choice."""
+ # wav silence = just zero's
+ if output_fmt == ContentType.WAV:
+ yield create_wave_header(
+ samplerate=sample_rate,
+ channels=2,
+ bitspersample=bit_depth,
+ duration=duration,
+ )
+ for _ in range(0, duration):
+ yield b"\0" * int(sample_rate * (bit_depth / 8) * 2)
+ return
+
+ # use ffmpeg for all other encodings
+ args = [
+ "ffmpeg",
+ "-hide_banner",
+ "-loglevel",
+ "quiet",
+ "-f",
+ "lavfi",
+ "-i",
+ f"anullsrc=r={sample_rate}:cl={'stereo'}",
+ "-t",
+ str(duration),
+ "-f",
+ output_fmt,
+ "-",
+ ]
+ async with AsyncProcess(args) as ffmpeg_proc:
+ async for chunk in ffmpeg_proc.iter_any():
+ yield chunk
+
+
+def get_chunksize(
+ content_type: ContentType,
+ sample_rate: int = 44100,
+ bit_depth: int = 16,
+ seconds: int = 1,
+) -> int:
+ """Get a default chunksize for given contenttype."""
+ pcm_size = int(sample_rate * (bit_depth / 8) * 2 * seconds)
+ if content_type.is_pcm() or content_type == ContentType.WAV:
+ return pcm_size
+ if content_type in (ContentType.WAV, ContentType.AIFF, ContentType.DSF):
+ return pcm_size
+ if content_type in (ContentType.FLAC, ContentType.WAVPACK, ContentType.ALAC):
+ return int(pcm_size * 0.6)
+ if content_type in (ContentType.MP3, ContentType.OGG, ContentType.M4A):
+ return int(640000 * seconds)
+ return 32000 * seconds
+
+
+async def _get_ffmpeg_args(
+ streamdetails: StreamDetails,
+ sample_rate: int,
+ bit_depth: int,
+ seek_position: int = 0,
+ fade_in: bool = False,
+) -> list[str]:
+ """Collect all args to send to the ffmpeg process."""
+ input_format = streamdetails.content_type
+
+ ffmpeg_present, libsoxr_support = await check_audio_support()
+
+ if not ffmpeg_present:
+ raise AudioError(
+ "FFmpeg binary is missing from system."
+ "Please install ffmpeg on your OS to enable playback.",
+ )
+ # generic args
+ generic_args = [
+ "ffmpeg",
+ "-hide_banner",
+ "-loglevel",
+ "quiet",
+ "-ignore_unknown",
+ ]
+ # collect input args
+ input_args = []
+ if streamdetails.direct:
+ # ffmpeg can access the inputfile (or url) directly
+ if streamdetails.direct.startswith("http"):
+ # append reconnect options for direct stream from http
+ input_args += [
+ "-reconnect",
+ "1",
+ "-reconnect_streamed",
+ "1",
+ "-reconnect_on_network_error",
+ "1",
+ "-reconnect_on_http_error",
+ "5xx",
+ "-reconnect_delay_max",
+ "10",
+ ]
+ if seek_position:
+ input_args += ["-ss", str(seek_position)]
+ input_args += ["-i", streamdetails.direct]
+ else:
+ # the input is received from pipe/stdin
+ if streamdetails.content_type != ContentType.UNKNOWN:
+ input_args += ["-f", input_format]
+ input_args += ["-i", "-"]
+
+ pcm_output_format = ContentType.from_bit_depth(bit_depth)
+ # collect output args
+ output_args = [
+ "-acodec",
+ pcm_output_format.name.lower(),
+ "-f",
+ pcm_output_format,
+ "-ac",
+ "2", # to simplify things, we always output 2 channels
+ "-ar",
+ str(sample_rate),
+ "-",
+ ]
+ # collect extra and filter args
+ extra_args = []
+ filter_params = []
+ if streamdetails.gain_correct is not None:
+ filter_params.append(f"volume={streamdetails.gain_correct}dB")
+ if (
+ streamdetails.sample_rate != sample_rate
+ and libsoxr_support
+ and streamdetails.media_type == MediaType.TRACK
+ ):
+ # prefer libsoxr high quality resampler (if present) for sample rate conversions
+ filter_params.append("aresample=resampler=soxr")
+ if fade_in:
+ filter_params.append("afade=type=in:start_time=0:duration=3")
+ if filter_params:
+ extra_args += ["-af", ",".join(filter_params)]
+
+ return generic_args + input_args + extra_args + output_args
--- /dev/null
+"""Several helper/utils to compare objects."""
+from __future__ import annotations
+
+from music_assistant.common.helpers.util import create_safe_string, create_sort_name
+from music_assistant.common.models.enums import AlbumType
+from music_assistant.common.models.media_items import (
+ Album,
+ Artist,
+ ItemMapping,
+ MediaItem,
+ MediaItemMetadata,
+ Track,
+)
+
+
+def loose_compare_strings(base: str, alt: str) -> bool:
+ """Compare strings and return True even on partial match."""
+ # this is used to display 'versions' of the same track/album
+ # where we account for other spelling or some additional wording in the title
+ word_count = len(base.split(" "))
+ if word_count == 1 and len(base) < 10:
+ return compare_strings(base, alt, False)
+ base_comp = create_safe_string(base)
+ alt_comp = create_safe_string(alt)
+ if base_comp in alt_comp:
+ return True
+ if alt_comp in base_comp:
+ return True
+ return False
+
+
+def compare_strings(str1: str, str2: str, strict: bool = True) -> bool:
+ """Compare strings and return True if we have an (almost) perfect match."""
+ if str1 is None or str2 is None:
+ return False
+ # return early if total length mismatch
+ if abs(len(str1) - len(str2)) > 2:
+ return False
+ if not strict:
+ return create_safe_string(str1) == create_safe_string(str2)
+ return create_sort_name(str1) == create_sort_name(str2)
+
+
+def compare_version(left_version: str, right_version: str) -> bool:
+ """Compare version string."""
+ if not left_version and not right_version:
+ return True
+ if not left_version and right_version:
+ return False
+ if left_version and not right_version:
+ return False
+ if " " not in left_version:
+ return compare_strings(left_version, right_version)
+ # do this the hard way as sometimes the version string is in the wrong order
+ left_versions = left_version.lower().split(" ").sort()
+ right_versions = right_version.lower().split(" ").sort()
+ return left_versions == right_versions
+
+
+def compare_explicit(left: MediaItemMetadata, right: MediaItemMetadata) -> bool:
+ """Compare if explicit is same in metadata."""
+ if left.explicit is None or right.explicit is None:
+ # explicitness info is not always present in metadata
+ # only strict compare them if both have the info set
+ return True
+ return left == right
+
+
+def compare_artist(
+ left_artist: Artist | ItemMapping,
+ right_artist: Artist | ItemMapping,
+) -> bool:
+ """Compare two artist items and return True if they match."""
+ if left_artist is None or right_artist is None:
+ return False
+ # return early on exact item_id match
+ if compare_item_ids(left_artist, right_artist):
+ return True
+
+ # prefer match on musicbrainz_id
+ if getattr(left_artist, "musicbrainz_id", None) and getattr(
+ right_artist, "musicbrainz_id", None
+ ):
+ return left_artist.musicbrainz_id == right_artist.musicbrainz_id
+
+ # fallback to comparing
+ return compare_strings(left_artist.name, right_artist.name, False)
+
+
+def compare_artists(
+ left_artists: list[Artist | ItemMapping],
+ right_artists: list[Artist | ItemMapping],
+ any_match: bool = False,
+) -> bool:
+ """Compare two lists of artist and return True if both lists match (exactly)."""
+ matches = 0
+ for left_artist in left_artists:
+ for right_artist in right_artists:
+ if compare_artist(left_artist, right_artist):
+ if any_match:
+ return True
+ matches += 1
+ return len(left_artists) == matches
+
+
+def compare_item_ids(
+ left_item: MediaItem | ItemMapping, right_item: MediaItem | ItemMapping
+) -> bool:
+ """Compare item_id(s) of two media items."""
+ if left_item.provider == right_item.provider and left_item.item_id == right_item.item_id:
+ return True
+
+ left_prov_ids = getattr(left_item, "provider_mappings", None)
+ right_prov_ids = getattr(right_item, "provider_mappings", None)
+
+ if left_prov_ids is not None:
+ for prov_l in left_item.provider_mappings:
+ if (
+ prov_l.provider_domain == right_item.provider
+ and prov_l.item_id == right_item.item_id
+ ):
+ return True
+
+ if right_prov_ids is not None:
+ for prov_r in right_item.provider_mappings:
+ if prov_r.provider_domain == left_item.provider and prov_r.item_id == left_item.item_id:
+ return True
+
+ if left_prov_ids is not None and right_prov_ids is not None:
+ for prov_l in left_item.provider_mappings:
+ for prov_r in right_item.provider_mappings:
+ if prov_l.provider_domain != prov_r.provider_domain:
+ continue
+ if prov_l.item_id == prov_r.item_id:
+ return True
+ return False
+
+
+def compare_albums(
+ left_albums: list[Album | ItemMapping],
+ right_albums: list[Album | ItemMapping],
+):
+ """Compare two lists of albums and return True if a match was found."""
+ for left_album in left_albums:
+ for right_album in right_albums:
+ if compare_album(left_album, right_album):
+ return True
+ return False
+
+
+def compare_album(
+ left_album: Album | ItemMapping,
+ right_album: Album | ItemMapping,
+):
+ """Compare two album items and return True if they match."""
+ if left_album is None or right_album is None:
+ return False
+ # return early on exact item_id match
+ if compare_item_ids(left_album, right_album):
+ return True
+
+ # prefer match on UPC
+ if (
+ isinstance(left_album, Album)
+ and isinstance(right_album, Album)
+ and left_album.upc
+ and right_album.upc
+ and ((left_album.upc in right_album.upc) or (right_album.upc in left_album.upc))
+ ):
+ return True
+ # prefer match on musicbrainz_id
+ # not present on ItemMapping
+ if getattr(left_album, "musicbrainz_id", None) and getattr(right_album, "musicbrainz_id", None):
+ return left_album.musicbrainz_id == right_album.musicbrainz_id
+
+ # fallback to comparing
+ if not compare_strings(left_album.name, right_album.name, False):
+ return False
+ if not compare_version(left_album.version, right_album.version):
+ return False
+ # compare album artist
+ # Note: Not present on ItemMapping
+ if (
+ hasattr(left_album, "artist")
+ and hasattr(right_album, "artist")
+ and not compare_artist(left_album.artist, right_album.artist)
+ ):
+ return False
+ return left_album.sort_name == right_album.sort_name
+
+
+def compare_track(left_track: Track, right_track: Track):
+ """Compare two track items and return True if they match."""
+ if left_track is None or right_track is None:
+ return False
+ # return early on exact item_id match
+ if compare_item_ids(left_track, right_track):
+ return True
+ for left_isrc in left_track.isrcs:
+ for right_isrc in right_track.isrcs:
+ # ISRC is always 100% accurate match
+ if left_isrc == right_isrc:
+ return True
+ if (
+ left_track.musicbrainz_id
+ and right_track.musicbrainz_id
+ and left_track.musicbrainz_id == right_track.musicbrainz_id
+ ):
+ # musicbrainz_id is always 100% accurate match
+ return True
+ # album is required for track linking
+ if left_track.album is None or right_track.album is None:
+ return False
+ # track name must match
+ if not compare_strings(left_track.name, right_track.name, False):
+ return False
+ # exact albumtrack match = 100% match
+ if (
+ compare_album(left_track.album, right_track.album)
+ and left_track.track_number
+ and right_track.track_number
+ and left_track.disc_number == right_track.disc_number
+ and left_track.track_number == right_track.track_number
+ ):
+ return True
+ # track version must match
+ if not compare_version(left_track.version, right_track.version):
+ return False
+ # track artist(s) must match
+ if not compare_artists(left_track.artists, right_track.artists):
+ return False
+ # track if both tracks are (not) explicit
+ if not compare_explicit(left_track.metadata, right_track.metadata):
+ return False
+ # exact album match = 100% match
+ if left_track.albums and right_track.albums:
+ for left_album in left_track.albums:
+ for right_album in right_track.albums:
+ if compare_album(left_album, right_album):
+ return True
+ # fallback: both albums are compilations and (near-exact) track duration match
+ if (
+ abs(left_track.duration - right_track.duration) <= 2
+ and left_track.album.album_type in (AlbumType.UNKNOWN, AlbumType.COMPILATION)
+ and right_track.album.album_type in (AlbumType.UNKNOWN, AlbumType.COMPILATION)
+ ):
+ return True
+ return False
--- /dev/null
+"""Database helpers and logic."""
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import Any
+
+from databases import Database as Db
+from databases import DatabaseURL
+from sqlalchemy.sql import ClauseElement
+
+
+class DatabaseConnection:
+ """Class that holds the (connection to the) database with some convenience helper functions."""
+
+ def __init__(self, url: DatabaseURL):
+ """Initialize class."""
+ self.url = url
+ # we maintain one global connection - otherwise we run into (dead)lock issues.
+ # https://github.com/encode/databases/issues/456
+ self._db = Db(self.url, timeout=360)
+
+ async def setup(self) -> None:
+ """Perform async initialization."""
+ await self._db.connect()
+
+ async def close(self) -> None:
+ """Close db connection on exit."""
+ await self._db.disconnect()
+
+ async def get_rows(
+ self,
+ table: str,
+ match: dict = None,
+ order_by: str = None,
+ limit: int = 500,
+ offset: int = 0,
+ ) -> list[Mapping]:
+ """Get all rows for given table."""
+ sql_query = f"SELECT * FROM {table}"
+ if match is not None:
+ sql_query += " WHERE " + " AND ".join(f"{x} = :{x}" for x in match)
+ if order_by is not None:
+ sql_query += f" ORDER BY {order_by}"
+ sql_query += f" LIMIT {limit} OFFSET {offset}"
+ return await self._db.fetch_all(sql_query, match)
+
+ async def get_rows_from_query(
+ self,
+ query: str,
+ params: dict | None = None,
+ limit: int = 500,
+ offset: int = 0,
+ ) -> list[Mapping]:
+ """Get all rows for given custom query."""
+ query = f"{query} LIMIT {limit} OFFSET {offset}"
+ return await self._db.fetch_all(query, params)
+
+ async def get_count_from_query(
+ self,
+ query: str,
+ params: dict | None = None,
+ ) -> int:
+ """Get row count for given custom query."""
+ query = f"SELECT count() FROM ({query})"
+ if result := await self._db.fetch_one(query, params):
+ return result[0]
+ return 0
+
+ async def get_count(
+ self,
+ table: str,
+ ) -> int:
+ """Get row count for given table."""
+ query = f"SELECT count(*) FROM {table}"
+ if result := await self._db.fetch_one(query):
+ return result[0]
+ return 0
+
+ async def search(self, table: str, search: str, column: str = "name") -> list[Mapping]:
+ """Search table by column."""
+ sql_query = f"SELECT * FROM {table} WHERE {column} LIKE :search"
+ params = {"search": f"%{search}%"}
+ return await self._db.fetch_all(sql_query, params)
+
+ async def get_row(self, table: str, match: dict[str, Any]) -> Mapping | None:
+ """Get single row for given table where column matches keys/values."""
+ sql_query = f"SELECT * FROM {table} WHERE "
+ sql_query += " AND ".join(f"{x} = :{x}" for x in match)
+ return await self._db.fetch_one(sql_query, match)
+
+ async def insert(
+ self,
+ table: str,
+ values: dict[str, Any],
+ allow_replace: bool = False,
+ ) -> Mapping:
+ """Insert data in given table."""
+ keys = tuple(values.keys())
+ if allow_replace:
+ sql_query = f'INSERT OR REPLACE INTO {table}({",".join(keys)})'
+ else:
+ sql_query = f'INSERT INTO {table}({",".join(keys)})'
+ sql_query += f' VALUES ({",".join((f":{x}" for x in keys))})'
+ await self.execute(sql_query, values)
+ # return inserted/replaced item
+ lookup_vals = {
+ key: value for key, value in values.items() if value is not None and value != ""
+ }
+ return await self.get_row(table, lookup_vals)
+
+ async def insert_or_replace(self, table: str, values: dict[str, Any]) -> Mapping:
+ """Insert or replace data in given table."""
+ return await self.insert(table=table, values=values, allow_replace=True)
+
+ async def update(
+ self,
+ table: str,
+ match: dict[str, Any],
+ values: dict[str, Any],
+ ) -> Mapping:
+ """Update record."""
+ keys = tuple(values.keys())
+ sql_query = f'UPDATE {table} SET {",".join((f"{x}=:{x}" for x in keys))} WHERE '
+ sql_query += " AND ".join(f"{x} = :{x}" for x in match)
+ await self.execute(sql_query, {**match, **values})
+ # return updated item
+ return await self.get_row(table, match)
+
+ async def delete(self, table: str, match: dict | None = None, query: str | None = None) -> None:
+ """Delete data in given table."""
+ assert not (query and "where" in query.lower())
+ sql_query = f"DELETE FROM {table} "
+ if match:
+ sql_query += " WHERE " + " AND ".join(f"{x} = :{x}" for x in match)
+ elif query and "query" not in query.lower():
+ sql_query += "WHERE " + query
+ elif query:
+ sql_query += query
+
+ await self.execute(sql_query, match)
+
+ async def delete_where_query(self, table: str, query: str | None = None) -> None:
+ """Delete data in given table using given where clausule."""
+ sql_query = f"DELETE FROM {table} WHERE {query}"
+ await self.execute(sql_query)
+
+ async def execute(self, query: ClauseElement | str, values: dict = None) -> Any:
+ """Execute command on the database."""
+ return await self._db.execute(query, values)
--- /dev/null
+"""Helper(s) to create DIDL Lite metadata for Sonos/DLNA players."""
+from __future__ import annotations
+
+import datetime
+from typing import TYPE_CHECKING
+
+from music_assistant.common.models.enums import MediaType
+from music_assistant.constants import MASS_LOGO_ONLINE
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.queue_item import QueueItem
+
+# ruff: noqa: E501
+
+
+def create_didl_metadata(url: str, queue_item: QueueItem, flow_mode: bool = False) -> str:
+ """Create DIDL metadata string from url and QueueItem."""
+ ext = url.split(".")[-1]
+ is_radio = queue_item.media_type != MediaType.TRACK or not queue_item.duration
+
+ if flow_mode:
+ return (
+ '<DIDL-Lite xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:upnp="urn:schemas-upnp-org:metadata-1-0/upnp/" xmlns="urn:schemas-upnp-org:metadata-1-0/DIDL-Lite/" xmlns:dlna="urn:schemas-dlna-org:metadata-1-0/">'
+ f'<item id="{queue_item.queue_item_id}" parentID="0" restricted="1">'
+ f"<dc:title>Music Assistant</dc:title>"
+ f"<upnp:albumArtURI>{MASS_LOGO_ONLINE}</upnp:albumArtURI>"
+ f"<dc:queueItemId>{queue_item.queue_item_id}</dc:queueItemId>"
+ "<upnp:class>object.item.audioItem.audioBroadcast</upnp:class>"
+ f"<upnp:mimeType>audio/{ext}</upnp:mimeType>"
+ f'<res protocolInfo="http-get:*:audio/{ext}:DLNA.ORG_OP=00;DLNA.ORG_CI=0;DLNA.ORG_FLAGS=0d500000000000000000000000000000">{url}</res>'
+ "</item>"
+ "</DIDL-Lite>"
+ )
+ if is_radio:
+ # radio or other non-track item
+ return (
+ '<DIDL-Lite xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:upnp="urn:schemas-upnp-org:metadata-1-0/upnp/" xmlns="urn:schemas-upnp-org:metadata-1-0/DIDL-Lite/" xmlns:dlna="urn:schemas-dlna-org:metadata-1-0/">'
+ f'<item id="{queue_item.queue_item_id}" parentID="0" restricted="1">'
+ f"<dc:title>{_escape_str(queue_item.name)}</dc:title>"
+ f"<upnp:albumArtURI>{queue_item.image.url}</upnp:albumArtURI>"
+ f"<dc:queueItemId>{queue_item.queue_item_id}</dc:queueItemId>"
+ "<upnp:class>object.item.audioItem.audioBroadcast</upnp:class>"
+ f"<upnp:mimeType>audio/{ext}</upnp:mimeType>"
+ f'<res protocolInfo="http-get:*:audio/{ext}:DLNA.ORG_OP=00;DLNA.ORG_CI=0;DLNA.ORG_FLAGS=0d500000000000000000000000000000">{url}</res>'
+ "</item>"
+ "</DIDL-Lite>"
+ )
+ title = _escape_str(queue_item.media_item.name)
+ artist = _escape_str(queue_item.media_item.artist.name)
+ album = _escape_str(queue_item.media_item.album.name)
+ item_class = "object.item.audioItem.musicTrack"
+ duration_str = str(datetime.timedelta(seconds=queue_item.duration))
+ return (
+ '<DIDL-Lite xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:upnp="urn:schemas-upnp-org:metadata-1-0/upnp/" xmlns="urn:schemas-upnp-org:metadata-1-0/DIDL-Lite/" xmlns:dlna="urn:schemas-dlna-org:metadata-1-0/">'
+ f'<item id="{queue_item.queue_item_id}" parentID="0" restricted="1">'
+ f"<dc:title>{title}</dc:title>"
+ f"<dc:creator>{artist}</dc:creator>"
+ f"<upnp:album>{album}</upnp:album>"
+ f"<upnp:artist>{artist}</upnp:artist>"
+ f"<upnp:duration>{queue_item.duration}</upnp:duration>"
+ "<upnp:playlistTitle>Music Assistant</upnp:playlistTitle>"
+ f"<dc:queueItemId>{queue_item.queue_item_id}</dc:queueItemId>"
+ f"<upnp:albumArtURI>{queue_item.image.url}</upnp:albumArtURI>"
+ f"<upnp:class>{item_class}</upnp:class>"
+ f"<upnp:mimeType>audio/{ext}</upnp:mimeType>"
+ f'<res duration="{duration_str}" protocolInfo="http-get:*:audio/{ext}:DLNA.ORG_OP=00;DLNA.ORG_CI=0;DLNA.ORG_FLAGS=0d500000000000000000000000000000">{url}</res>'
+ "</item>"
+ "</DIDL-Lite>"
+ )
+
+
+def _escape_str(data: str) -> str:
+ """Create DIDL-safe string."""
+ data = data.replace("&", "&")
+ data = data.replace(">", ">")
+ data = data.replace("<", "<")
+ return data
--- /dev/null
+"""Utilities for image manipulation and retrieval."""
+from __future__ import annotations
+
+import asyncio
+import random
+from io import BytesIO
+from typing import TYPE_CHECKING
+
+from PIL import Image
+
+from music_assistant.server.helpers.tags import get_embedded_image
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+
+async def get_image_data(mass: MusicAssistant, path: str) -> bytes:
+ """Create thumbnail from image url."""
+ # always try ffmpeg first to get the image because it supports
+ # both online and offline image files as well as embedded images in media files
+ img_data = await get_embedded_image(path)
+ if img_data:
+ return img_data
+ # assume file from file provider, we need to fetch it here...
+ for prov in mass.music.providers:
+ if not prov.domain.startswith("filesystem"):
+ continue
+ if not await prov.exists(path):
+ continue
+ path = await prov.resolve(path)
+ img_data = await get_embedded_image(path.local_path)
+ if img_data:
+ return img_data
+ raise FileNotFoundError(f"Image not found: {path}")
+
+
+async def get_image_thumb(mass: MusicAssistant, path: str, size: int | None) -> bytes:
+ """Get (optimized) PNG thumbnail from image url."""
+ img_data = await get_image_data(mass, path)
+
+ def _create_image():
+ data = BytesIO()
+ img = Image.open(BytesIO(img_data))
+ if size:
+ img.thumbnail((size, size), Image.ANTIALIAS)
+ img.convert("RGB").save(data, "PNG", optimize=True)
+ return data.getvalue()
+
+ return await asyncio.to_thread(_create_image)
+
+
+async def create_collage(mass: MusicAssistant, images: list[str]) -> bytes:
+ """Create a basic collage image from multiple image urls."""
+
+ def _new_collage():
+ return Image.new("RGBA", (1500, 1500), color=(255, 255, 255, 255))
+
+ collage = await asyncio.to_thread(_new_collage)
+
+ def _add_to_collage(img_data: bytes, coord_x: int, coord_y: int):
+ data = BytesIO(img_data)
+ photo = Image.open(data).convert("RGBA")
+ photo = photo.resize((500, 500))
+ collage.paste(photo, (coord_x, coord_y))
+
+ for x_co in range(0, 1500, 500):
+ for y_co in range(0, 1500, 500):
+ img_data = await get_image_data(mass, random.choice(images))
+ await asyncio.to_thread(_add_to_collage, img_data, x_co, y_co)
+
+ def _save_collage():
+ final_data = BytesIO()
+ collage.convert("RGB").save(final_data, "PNG", optimize=True)
+ return final_data.getvalue()
+
+ return await asyncio.to_thread(_save_collage)
--- /dev/null
+"""Helpers for parsing playlists."""
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import TYPE_CHECKING
+
+import aiohttp
+
+from music_assistant.common.models.errors import InvalidDataError
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+async def parse_m3u(m3u_data: str) -> list[str]:
+ """Parse (only) filenames/urls from m3u playlist file."""
+ m3u_lines = m3u_data.splitlines()
+ lines = []
+ for line in m3u_lines:
+ line = line.strip() # noqa: PLW2901
+ if line.startswith("#"):
+ # ignore metadata
+ continue
+ if len(line) != 0:
+ # Get uri/path from all other, non-blank lines
+ lines.append(line)
+
+ return lines
+
+
+async def parse_pls(pls_data: str) -> list[str]:
+ """Parse (only) filenames/urls from pls playlist file."""
+ pls_lines = pls_data.splitlines()
+ lines = []
+ for line in pls_lines:
+ line = line.strip() # noqa: PLW2901
+ if not line.startswith("File"):
+ # ignore metadata lines
+ continue
+ if "=" in line:
+ # Get uri/path from all other, non-blank lines
+ lines.append(line.split("=")[1])
+
+ return lines
+
+
+async def fetch_playlist(mass: MusicAssistant, url: str) -> list[str]:
+ """Parse an online m3u or pls playlist."""
+ try:
+ async with mass.http_session.get(url, timeout=5) as resp:
+ charset = resp.charset or "utf-8"
+ try:
+ playlist_data = (await resp.content.read(64 * 1024)).decode(charset)
+ except ValueError as err:
+ raise InvalidDataError(f"Could not decode playlist {url}") from err
+ except asyncio.TimeoutError as err:
+ raise InvalidDataError(f"Timeout while fetching playlist {url}") from err
+ except aiohttp.client_exceptions.ClientError as err:
+ raise InvalidDataError(f"Error while fetching playlist {url}") from err
+
+ if url.endswith(".m3u") or url.endswith(".m3u8"):
+ playlist = await parse_m3u(playlist_data)
+ else:
+ playlist = await parse_pls(playlist_data)
+
+ if not playlist:
+ raise InvalidDataError(f"Empty playlist {url}")
+
+ return playlist
--- /dev/null
+"""Implementation of a (truly) non blocking subprocess.
+
+The subprocess implementation in asyncio can (still) sometimes cause deadlocks,
+even when properly handling reading/writes from different tasks.
+"""
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections.abc import AsyncGenerator, Coroutine
+
+LOGGER = logging.getLogger(__name__)
+
+DEFAULT_CHUNKSIZE = 128000
+DEFAULT_TIMEOUT = 600
+
+# pylint: disable=invalid-name
+
+
+class AsyncProcess:
+ """Implementation of a (truly) non blocking subprocess."""
+
+ def __init__(
+ self,
+ args: list | str,
+ enable_stdin: bool = False,
+ enable_stdout: bool = True,
+ enable_stderr: bool = False,
+ ):
+ """Initialize."""
+ self._proc = None
+ self._args = args
+ self._enable_stdin = enable_stdin
+ self._enable_stdout = enable_stdout
+ self._enable_stderr = enable_stderr
+ self._attached_task: asyncio.Task = None
+ self.closed = False
+
+ async def __aenter__(self) -> AsyncProcess:
+ """Enter context manager."""
+ args = " ".join(self._args) if "|" in self._args else self._args
+ if isinstance(args, str):
+ self._proc = await asyncio.create_subprocess_shell(
+ args,
+ stdin=asyncio.subprocess.PIPE if self._enable_stdin else None,
+ stdout=asyncio.subprocess.PIPE if self._enable_stdout else None,
+ stderr=asyncio.subprocess.PIPE if self._enable_stderr else None,
+ close_fds=True,
+ )
+ else:
+ self._proc = await asyncio.create_subprocess_exec(
+ *args,
+ stdin=asyncio.subprocess.PIPE if self._enable_stdin else None,
+ stdout=asyncio.subprocess.PIPE if self._enable_stdout else None,
+ stderr=asyncio.subprocess.PIPE if self._enable_stderr else None,
+ close_fds=True,
+ )
+
+ # Fix BrokenPipeError due to a race condition
+ # by attaching a default done callback
+ def _done_cb(fut: asyncio.Future):
+ fut.exception()
+
+ self._proc._transport._protocol._stdin_closed.add_done_callback(_done_cb)
+
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback) -> bool:
+ """Exit context manager."""
+ self.closed = True
+ if self._attached_task:
+ # cancel the attached reader/writer task
+ try:
+ self._attached_task.cancel()
+ await self._attached_task
+ except asyncio.CancelledError:
+ pass
+ if self._proc.returncode is None:
+ # prevent subprocess deadlocking, read remaining bytes
+ await self._proc.communicate()
+ if self._enable_stdout and not self._proc.stdout.at_eof():
+ await self._proc.stdout.read()
+ if self._enable_stderr and not self._proc.stderr.at_eof():
+ await self._proc.stderr.read()
+ if self._proc.returncode is None:
+ # just in case?
+ self._proc.kill()
+
+ async def iter_chunked(self, n: int = DEFAULT_CHUNKSIZE) -> AsyncGenerator[bytes, None]:
+ """Yield chunks of n size from the process stdout."""
+ while True:
+ chunk = await self.readexactly(n)
+ if chunk == b"":
+ break
+ yield chunk
+ if len(chunk) < n:
+ break
+
+ async def iter_any(self, n: int = DEFAULT_CHUNKSIZE) -> AsyncGenerator[bytes, None]:
+ """Yield chunks as they come in from process stdout."""
+ while True:
+ chunk = await self.read(n)
+ if chunk == b"":
+ break
+ yield chunk
+
+ async def readexactly(self, n: int, timeout: int = DEFAULT_TIMEOUT) -> bytes:
+ """Read exactly n bytes from the process stdout (or less if eof)."""
+ try:
+ async with asyncio.timeout(timeout):
+ return await self._proc.stdout.readexactly(n)
+ except asyncio.IncompleteReadError as err:
+ return err.partial
+
+ async def read(self, n: int, timeout: int = DEFAULT_TIMEOUT) -> bytes:
+ """Read up to n bytes from the stdout stream.
+
+ If n is positive, this function try to read n bytes,
+ and may return less or equal bytes than requested, but at least one byte.
+ If EOF was received before any byte is read, this function returns empty byte object.
+ """
+ async with asyncio.timeout(timeout):
+ return await self._proc.stdout.read(n)
+
+ async def write(self, data: bytes) -> None:
+ """Write data to process stdin."""
+ if self.closed or self._proc.stdin.is_closing():
+ raise asyncio.CancelledError()
+ self._proc.stdin.write(data)
+ try:
+ await self._proc.stdin.drain()
+ except BrokenPipeError:
+ raise asyncio.CancelledError()
+
+ def write_eof(self) -> None:
+ """Write end of file to to process stdin."""
+ try:
+ if self._proc.stdin.can_write_eof():
+ self._proc.stdin.write_eof()
+ except (
+ AttributeError,
+ AssertionError,
+ BrokenPipeError,
+ RuntimeError,
+ ConnectionResetError,
+ ):
+ # already exited, race condition
+ return
+
+ async def communicate(self, input_data: bytes | None = None) -> tuple[bytes, bytes]:
+ """Write bytes to process and read back results."""
+ return await self._proc.communicate(input_data)
+
+ def attach_task(self, coro: Coroutine) -> asyncio.Task:
+ """Attach given coro func as reader/writer task to properly cancel it when needed."""
+ self._attached_task = task = asyncio.create_task(coro)
+ return task
+
+
+async def check_output(shell_cmd: str) -> tuple[int, bytes]:
+ """Run shell subprocess and return output."""
+ proc = await asyncio.create_subprocess_shell(
+ shell_cmd,
+ stderr=asyncio.subprocess.STDOUT,
+ stdout=asyncio.subprocess.PIPE,
+ )
+ stdout, _ = await proc.communicate()
+ return (proc.returncode, stdout)
--- /dev/null
+"""Helpers/utilities to parse ID3 tags from audio files with ffmpeg."""
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from json import JSONDecodeError
+from typing import Any
+
+from music_assistant.common.helpers.util import try_parse_int
+from music_assistant.common.models.errors import InvalidDataError
+from music_assistant.constants import UNKNOWN_ARTIST
+from music_assistant.server.helpers.process import AsyncProcess
+
+# the only multi-item splitter we accept is the semicolon,
+# which is also the default in Musicbrainz Picard.
+# the slash is also a common splitter but causes colissions with
+# artists actually containing a slash in the name, such as ACDC
+TAG_SPLITTER = ";"
+
+
+def split_items(org_str: str) -> tuple[str]:
+ """Split up a tags string by common splitter."""
+ if not org_str:
+ return tuple()
+ if isinstance(org_str, list):
+ return org_str
+ return tuple(x.strip() for x in org_str.split(TAG_SPLITTER))
+
+
+def split_artists(org_artists: str | tuple[str]) -> tuple[str]:
+ """Parse all artists from a string."""
+ final_artists = set()
+ # when not using the multi artist tag, the artist string may contain
+ # multiple artists in freeform, even featuring artists may be included in this
+ # string. Try to parse the featuring artists and separate them.
+ splitters = ("featuring", " feat. ", " feat ", "feat.")
+ for item in split_items(org_artists):
+ for splitter in splitters:
+ for subitem in item.split(splitter):
+ final_artists.add(subitem.strip())
+ return tuple(final_artists)
+
+
+@dataclass
+class AudioTags:
+ """Audio metadata parsed from an audio file."""
+
+ raw: dict[str, Any]
+ sample_rate: int
+ channels: int
+ bits_per_sample: int
+ format: str
+ bit_rate: int
+ duration: int | None
+ tags: dict[str, str]
+ has_cover_image: bool
+ filename: str
+
+ @property
+ def title(self) -> str:
+ """Return title tag (as-is)."""
+ if tag := self.tags.get("title"):
+ return tag
+ # fallback to parsing from filename
+ title = self.filename.rsplit(os.sep, 1)[-1].split(".")[0]
+ if " - " in title:
+ title_parts = title.split(" - ")
+ if len(title_parts) >= 2:
+ return title_parts[1].strip()
+ return title
+
+ @property
+ def album(self) -> str:
+ """Return album tag (as-is) if present."""
+ return self.tags.get("album")
+
+ @property
+ def artists(self) -> tuple[str]:
+ """Return track artists."""
+ # prefer multi-artist tag
+ if tag := self.tags.get("artists"):
+ return split_items(tag)
+ # fallback to regular artist string
+ if tag := self.tags.get("artist"):
+ if ";" in tag:
+ return split_items(tag)
+ return split_artists(tag)
+ # fallback to parsing from filename
+ title = self.filename.rsplit(os.sep, 1)[-1].split(".")[0]
+ if " - " in title:
+ title_parts = title.split(" - ")
+ if len(title_parts) >= 2:
+ return split_artists(title_parts[0])
+ return (UNKNOWN_ARTIST,)
+
+ @property
+ def album_artists(self) -> tuple[str]:
+ """Return (all) album artists (if any)."""
+ # prefer multi-artist tag
+ if tag := self.tags.get("albumartists"):
+ return split_items(tag)
+ # fallback to regular artist string
+ if tag := self.tags.get("albumartist"):
+ if ";" in tag:
+ return split_items(tag)
+ return split_artists(tag)
+ return tuple()
+
+ @property
+ def genres(self) -> tuple[str]:
+ """Return (all) genres, if any."""
+ return split_items(self.tags.get("genre"))
+
+ @property
+ def disc(self) -> int | None:
+ """Return disc tag if present."""
+ if tag := self.tags.get("disc"):
+ return try_parse_int(tag.split("/")[0], None)
+ return None
+
+ @property
+ def track(self) -> int | None:
+ """Return track tag if present."""
+ if tag := self.tags.get("track"):
+ return try_parse_int(tag.split("/")[0], None)
+ return None
+
+ @property
+ def year(self) -> int | None:
+ """Return album's year if present, parsed from date."""
+ if tag := self.tags.get("originalyear"):
+ return try_parse_int(tag.split("-")[0], None)
+ if tag := self.tags.get("originaldate"):
+ return try_parse_int(tag.split("-")[0], None)
+ if tag := self.tags.get("date"):
+ return try_parse_int(tag.split("-")[0], None)
+ return None
+
+ @property
+ def musicbrainz_artistids(self) -> tuple[str]:
+ """Return musicbrainz_artistid tag(s) if present."""
+ return split_items(self.tags.get("musicbrainzartistid"))
+
+ @property
+ def musicbrainz_albumartistids(self) -> tuple[str]:
+ """Return musicbrainz_albumartistid tag if present."""
+ return split_items(self.tags.get("musicbrainzalbumartistid"))
+
+ @property
+ def musicbrainz_releasegroupid(self) -> str | None:
+ """Return musicbrainz_releasegroupid tag if present."""
+ return self.tags.get("musicbrainzreleasegroupid")
+
+ @property
+ def musicbrainz_trackid(self) -> str | None:
+ """Return musicbrainz_trackid tag if present."""
+ if tag := self.tags.get("musicbrainztrackid"):
+ return tag
+ return self.tags.get("musicbrainzreleasetrackid")
+
+ @property
+ def album_type(self) -> str | None:
+ """Return albumtype tag if present."""
+ if tag := self.tags.get("musicbrainzalbumtype"):
+ return tag
+ return self.tags.get("releasetype")
+
+ @classmethod
+ def parse(cls, raw: dict) -> AudioTags:
+ """Parse instance from raw ffmpeg info output."""
+ audio_stream = next(x for x in raw["streams"] if x["codec_type"] == "audio")
+ has_cover_image = any(x for x in raw["streams"] if x["codec_name"] in ("mjpeg", "png"))
+ # convert all tag-keys (gathered from all streams) to lowercase without spaces
+ tags = {}
+ for stream in raw["streams"] + [raw["format"]]:
+ for key, value in stream.get("tags", {}).items():
+ key = key.lower().replace(" ", "").replace("_", "") # noqa: PLW2901
+ tags[key] = value
+
+ return AudioTags(
+ raw=raw,
+ sample_rate=int(audio_stream.get("sample_rate", 44100)),
+ channels=audio_stream.get("channels", 2),
+ bits_per_sample=int(
+ audio_stream.get("bits_per_raw_sample", audio_stream.get("bits_per_sample")) or 16
+ ),
+ format=raw["format"]["format_name"],
+ bit_rate=int(raw["format"].get("bit_rate", 320)),
+ duration=int(float(raw["format"].get("duration", 0))) or None,
+ tags=tags,
+ has_cover_image=has_cover_image,
+ filename=raw["format"]["filename"],
+ )
+
+ def get(self, key: str, default=None) -> Any:
+ """Get tag by key."""
+ return self.tags.get(key, default)
+
+
+async def parse_tags(input_file: str | AsyncGenerator[bytes, None]) -> AudioTags:
+ """Parse tags from a media file.
+
+ input_file may be a (local) filename/url accessible by ffmpeg or
+ an AsyncGenerator which yields the file contents as bytes.
+ """
+ file_path = input_file if isinstance(input_file, str) else "-"
+
+ args = (
+ "ffprobe",
+ "-hide_banner",
+ "-loglevel",
+ "fatal",
+ "-show_error",
+ "-show_format",
+ "-show_streams",
+ "-print_format",
+ "json",
+ "-i",
+ file_path,
+ )
+
+ async with AsyncProcess(
+ args, enable_stdin=file_path == "-", enable_stdout=True, enable_stderr=False
+ ) as proc:
+ if file_path == "-":
+ # feed the file contents to the process
+ async def chunk_feeder():
+ # pylint: disable=protected-access
+ async for chunk in input_file:
+ try:
+ await proc.write(chunk)
+ except BrokenPipeError:
+ break # race-condition: read enough data for tags
+
+ proc.attach_task(chunk_feeder())
+
+ try:
+ res = await proc.read(-1)
+ data = json.loads(res)
+ if error := data.get("error"):
+ raise InvalidDataError(error["string"])
+ return AudioTags.parse(data)
+ except (KeyError, ValueError, JSONDecodeError, InvalidDataError) as err:
+ raise InvalidDataError(f"Unable to retrieve info for {file_path}: {str(err)}") from err
+
+
+async def get_embedded_image(input_file: str | AsyncGenerator[bytes, None]) -> bytes | None:
+ """Return embedded image data.
+
+ input_file may be a (local) filename/url accessible by ffmpeg or
+ an AsyncGenerator which yields the file contents as bytes.
+ """
+ file_path = input_file if isinstance(input_file, str) else "-"
+ args = (
+ "ffmpeg",
+ "-hide_banner",
+ "-loglevel",
+ "fatal",
+ "-i",
+ file_path,
+ "-map",
+ "0:v",
+ "-c",
+ "copy",
+ "-f",
+ "mjpeg",
+ "-",
+ )
+
+ async with AsyncProcess(
+ args, enable_stdin=file_path == "-", enable_stdout=True, enable_stderr=False
+ ) as proc:
+ if file_path == "-":
+ # feed the file contents to the process
+ async for chunk in input_file:
+ await proc.write(chunk)
+
+ if file_path == "-":
+ # feed the file contents to the process
+ async def chunk_feeder():
+ async for chunk in input_file:
+ await proc.write(chunk)
+
+ proc.attach_task(chunk_feeder())
+
+ return await proc.read(-1)
--- /dev/null
+"""Various (server-only) tools and helpers."""
+
+import asyncio
+import logging
+
+LOGGER = logging.getLogger(__name__)
+
+
+async def install_package(package: str) -> None:
+ """Install package with pip, raise when install failed."""
+ cmd = f"python3 -m pip install {package}"
+ proc = await asyncio.create_subprocess_shell(
+ cmd, stderr=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.DEVNULL
+ )
+
+ _, stderr = await proc.communicate()
+
+ if proc.returncode != 0:
+ msg = f"Failed to install package {package}\n{stderr.decode()}"
+ raise RuntimeError(msg)
--- /dev/null
+"""Server specific/only models."""
--- /dev/null
+"""Model/base for a Metadata Provider implementation."""
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import TYPE_CHECKING
+
+from music_assistant.common.models.enums import ProviderFeature
+
+from .provider import Provider
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.media_items import Album, Artist, MediaItemMetadata, Track
+
+# ruff: noqa: ARG001, ARG002
+
+
+class MetadataProvider(Provider):
+ """Base representation of a Metadata Provider (controller).
+
+ Metadata Provider implementations should inherit from this base model.
+ """
+
+ _attr_supported_features: tuple[ProviderFeature, ...] = (
+ ProviderFeature.ARTIST_METADATA,
+ ProviderFeature.ALBUM_METADATA,
+ ProviderFeature.TRACK_METADATA,
+ ProviderFeature.GET_ARTIST_MBID,
+ )
+
+ async def get_artist_metadata(self, artist: Artist) -> MediaItemMetadata | None:
+ """Retrieve metadata for an artist on this Metadata provider."""
+ if ProviderFeature.ARTIST_METADATA in self.supported_features:
+ raise NotImplementedError
+ return
+
+ async def get_album_metadata(self, album: Album) -> MediaItemMetadata | None:
+ """Retrieve metadata for an album on this Metadata provider."""
+ if ProviderFeature.ALBUM_METADATA in self.supported_features:
+ raise NotImplementedError
+ return
+
+ async def get_track_metadata(self, track: Track) -> MediaItemMetadata | None:
+ """Retrieve metadata for a track on this Metadata provider."""
+ if ProviderFeature.TRACK_METADATA in self.supported_features:
+ raise NotImplementedError
+ return
+
+ async def get_musicbrainz_artist_id(
+ self, artist: Artist, ref_albums: Iterable[Album], ref_tracks: Iterable[Track]
+ ) -> str | None:
+ """Discover MusicBrainzArtistId for an artist given some reference albums/tracks."""
+ if ProviderFeature.GET_ARTIST_MBID in self.supported_features:
+ raise NotImplementedError
+ return
--- /dev/null
+"""Model/base for a Music Provider implementation."""
+from __future__ import annotations
+
+from collections.abc import AsyncGenerator
+
+from music_assistant.common.models.enums import MediaType, ProviderFeature
+from music_assistant.common.models.media_items import (
+ Album,
+ Artist,
+ BrowseFolder,
+ MediaItemType,
+ Playlist,
+ Radio,
+ StreamDetails,
+ Track,
+)
+
+from .provider import Provider
+
+# ruff: noqa: ARG001, ARG002
+
+
+class MusicProvider(Provider):
+ """Base representation of a Music Provider (controller).
+
+ Music Provider implementations should inherit from this base model.
+ """
+
+ async def search(
+ self,
+ search_query: str,
+ media_types: list[MediaType] | None = None,
+ limit: int = 5,
+ ) -> list[MediaItemType]:
+ """Perform search on musicprovider.
+
+ :param search_query: Search query.
+ :param media_types: A list of media_types to include. All types if None.
+ :param limit: Number of items to return in the search (per type).
+ """
+ if ProviderFeature.SEARCH in self.supported_features:
+ raise NotImplementedError
+ return []
+
+ async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
+ """Retrieve library artists from the provider."""
+ if ProviderFeature.LIBRARY_ARTISTS in self.supported_features:
+ raise NotImplementedError
+ yield # type: ignore
+
+ async def get_library_albums(self) -> AsyncGenerator[Album, None]:
+ """Retrieve library albums from the provider."""
+ if ProviderFeature.LIBRARY_ALBUMS in self.supported_features:
+ raise NotImplementedError
+ yield # type: ignore
+
+ async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
+ """Retrieve library tracks from the provider."""
+ if ProviderFeature.LIBRARY_TRACKS in self.supported_features:
+ raise NotImplementedError
+ yield # type: ignore
+
+ async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
+ """Retrieve library/subscribed playlists from the provider."""
+ if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
+ raise NotImplementedError
+ yield # type: ignore
+
+ async def get_library_radios(self) -> AsyncGenerator[Radio, None]:
+ """Retrieve library/subscribed radio stations from the provider."""
+ if ProviderFeature.LIBRARY_RADIOS in self.supported_features:
+ raise NotImplementedError
+ yield # type: ignore
+
+ async def get_artist(self, prov_artist_id: str) -> Artist:
+ """Get full artist details by id."""
+ raise NotImplementedError
+
+ async def get_artist_albums(self, prov_artist_id: str) -> list[Album]:
+ """Get a list of all albums for the given artist."""
+ if ProviderFeature.ARTIST_ALBUMS in self.supported_features:
+ raise NotImplementedError
+ return []
+
+ async def get_artist_toptracks(self, prov_artist_id: str) -> list[Track]:
+ """Get a list of most popular tracks for the given artist."""
+ if ProviderFeature.ARTIST_TOPTRACKS in self.supported_features:
+ raise NotImplementedError
+ return []
+
+ async def get_album(self, prov_album_id: str) -> Album: # type: ignore[return]
+ """Get full album details by id."""
+ if ProviderFeature.LIBRARY_ALBUMS in self.supported_features:
+ raise NotImplementedError
+
+ async def get_track(self, prov_track_id: str) -> Track: # type: ignore[return]
+ """Get full track details by id."""
+ if ProviderFeature.LIBRARY_TRACKS in self.supported_features:
+ raise NotImplementedError
+
+ async def get_playlist(self, prov_playlist_id: str) -> Playlist: # type: ignore[return]
+ """Get full playlist details by id."""
+ if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
+ raise NotImplementedError
+
+ async def get_radio(self, prov_radio_id: str) -> Radio: # type: ignore[return]
+ """Get full radio details by id."""
+ if ProviderFeature.LIBRARY_RADIOS in self.supported_features:
+ raise NotImplementedError
+
+ async def get_album_tracks(self, prov_album_id: str) -> list[Track]: # type: ignore[return]
+ """Get album tracks for given album id."""
+ if ProviderFeature.LIBRARY_ALBUMS in self.supported_features:
+ raise NotImplementedError
+
+ async def get_playlist_tracks( # type: ignore[return]
+ self, prov_playlist_id: str
+ ) -> list[Track]:
+ """Get all playlist tracks for given playlist id."""
+ if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
+ raise NotImplementedError
+
+ async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool:
+ """Add item to provider's library. Return true on success."""
+ if (
+ media_type == MediaType.ARTIST
+ and ProviderFeature.LIBRARY_ARTISTS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.ALBUM
+ and ProviderFeature.LIBRARY_ALBUMS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.TRACK
+ and ProviderFeature.LIBRARY_TRACKS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.PLAYLIST
+ and ProviderFeature.LIBRARY_PLAYLISTS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.RADIO
+ and ProviderFeature.LIBRARY_RADIOS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ self.logger.info(
+ "Provider %s does not support library edit, "
+ "the action will only be performed in the local database.",
+ self.name,
+ )
+ return True
+
+ async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool:
+ """Remove item from provider's library. Return true on success."""
+ if (
+ media_type == MediaType.ARTIST
+ and ProviderFeature.LIBRARY_ARTISTS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.ALBUM
+ and ProviderFeature.LIBRARY_ALBUMS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.TRACK
+ and ProviderFeature.LIBRARY_TRACKS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.PLAYLIST
+ and ProviderFeature.LIBRARY_PLAYLISTS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ if (
+ media_type == MediaType.RADIO
+ and ProviderFeature.LIBRARY_RADIOS_EDIT in self.supported_features
+ ):
+ raise NotImplementedError
+ self.logger.info(
+ "Provider %s does not support library edit, "
+ "the action will only be performed in the local database.",
+ self.name,
+ )
+ return True
+
+ async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None:
+ """Add track(s) to playlist."""
+ if ProviderFeature.PLAYLIST_TRACKS_EDIT in self.supported_features:
+ raise NotImplementedError
+
+ async def remove_playlist_tracks(
+ self, prov_playlist_id: str, positions_to_remove: tuple[int]
+ ) -> None:
+ """Remove track(s) from playlist."""
+ if ProviderFeature.PLAYLIST_TRACKS_EDIT in self.supported_features:
+ raise NotImplementedError
+
+ async def create_playlist(self, name: str) -> Playlist: # type: ignore[return]
+ """Create a new playlist on provider with given name."""
+ if ProviderFeature.PLAYLIST_CREATE in self.supported_features:
+ raise NotImplementedError
+
+ async def get_similar_tracks( # type: ignore[return]
+ self, prov_track_id: str, limit: int = 25
+ ) -> list[Track]:
+ """Retrieve a dynamic list of similar tracks based on the provided track."""
+ if ProviderFeature.SIMILAR_TRACKS in self.supported_features:
+ raise NotImplementedError
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails | None:
+ """Get streamdetails for a track/radio."""
+ raise NotImplementedError
+
+ async def get_audio_stream( # type: ignore[return]
+ self, streamdetails: StreamDetails, seek_position: int = 0
+ ) -> AsyncGenerator[bytes, None]:
+ """Return the audio stream for the provider item."""
+ if streamdetails.direct is None:
+ raise NotImplementedError
+
+ async def get_item(self, media_type: MediaType, prov_item_id: str) -> MediaItemType:
+ """Get single MediaItem from provider."""
+ if media_type == MediaType.ARTIST:
+ return await self.get_artist(prov_item_id)
+ if media_type == MediaType.ALBUM:
+ return await self.get_album(prov_item_id)
+ if media_type == MediaType.PLAYLIST:
+ return await self.get_playlist(prov_item_id)
+ if media_type == MediaType.RADIO:
+ return await self.get_radio(prov_item_id)
+ return await self.get_track(prov_item_id)
+
+ async def browse(self, path: str) -> BrowseFolder:
+ """Browse this provider's items.
+
+ :param path: The path to browse, (e.g. provid://artists).
+ """
+ if ProviderFeature.BROWSE not in self.supported_features:
+ # we may NOT use the default implementation if the provider does not support browse
+ raise NotImplementedError
+
+ _, subpath = path.split("://")
+
+ # this reference implementation can be overridden with a provider specific approach
+ if not subpath:
+ # return main listing
+ root_items: list[BrowseFolder] = []
+ if ProviderFeature.LIBRARY_ARTISTS in self.supported_features:
+ root_items.append(
+ BrowseFolder(
+ item_id="artists",
+ provider=self.domain,
+ path=path + "artists",
+ name="",
+ label="artists",
+ )
+ )
+ if ProviderFeature.LIBRARY_ALBUMS in self.supported_features:
+ root_items.append(
+ BrowseFolder(
+ item_id="albums",
+ provider=self.domain,
+ path=path + "albums",
+ name="",
+ label="albums",
+ )
+ )
+ if ProviderFeature.LIBRARY_TRACKS in self.supported_features:
+ root_items.append(
+ BrowseFolder(
+ item_id="tracks",
+ provider=self.domain,
+ path=path + "tracks",
+ name="",
+ label="tracks",
+ )
+ )
+ if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
+ root_items.append(
+ BrowseFolder(
+ item_id="playlists",
+ provider=self.domain,
+ path=path + "playlists",
+ name="",
+ label="playlists",
+ )
+ )
+ if ProviderFeature.LIBRARY_RADIOS in self.supported_features:
+ root_items.append(
+ BrowseFolder(
+ item_id="radios",
+ provider=self.domain,
+ path=path + "radios",
+ name="",
+ label="radios",
+ )
+ )
+ return BrowseFolder(
+ item_id="root",
+ provider=self.domain,
+ path=path,
+ name=self.name,
+ items=root_items,
+ )
+ # sublevel
+ if subpath == "artists":
+ return BrowseFolder(
+ item_id="artists",
+ provider=self.domain,
+ path=path,
+ name="",
+ label="artists",
+ items=[x async for x in self.get_library_artists()],
+ )
+ if subpath == "albums":
+ return BrowseFolder(
+ item_id="albums",
+ provider=self.domain,
+ path=path,
+ name="",
+ label="albums",
+ items=[x async for x in self.get_library_albums()],
+ )
+ if subpath == "tracks":
+ return BrowseFolder(
+ item_id="tracks",
+ provider=self.domain,
+ path=path,
+ name="",
+ label="tracks",
+ items=[x async for x in self.get_library_tracks()],
+ )
+ if subpath == "radios":
+ return BrowseFolder(
+ item_id="radios",
+ provider=self.domain,
+ path=path,
+ name="",
+ label="radios",
+ items=[x async for x in self.get_library_radios()],
+ )
+ if subpath == "playlists":
+ return BrowseFolder(
+ item_id="playlists",
+ provider=self.domain,
+ path=path,
+ name="",
+ label="playlists",
+ items=[x async for x in self.get_library_playlists()],
+ )
+ raise KeyError("Invalid subpath")
+
+ async def recommendations(self) -> list[BrowseFolder]:
+ """Get this provider's recommendations.
+
+ Returns a list of BrowseFolder items with (max 25) mediaitems in the items attribute.
+ """
+ if ProviderFeature.RECOMMENDATIONS in self.supported_features:
+ raise NotImplementedError
+ return []
+
+ async def sync_library(self, media_types: tuple[MediaType, ...] | None = None) -> None:
+ """Run library sync for this provider."""
+ # this reference implementation can be overridden with provider specific approach
+ # this logic is aimed at streaming/online providers,
+ # which all have more or less the same structure.
+ # filesystem implementation(s) just override this.
+ if media_types is None:
+ media_types = tuple(x for x in MediaType)
+ for media_type in media_types:
+ if not self.library_supported(media_type):
+ continue
+ self.logger.debug("Start sync of %s items.", media_type.value)
+ controller = self.mass.music.get_controller(media_type)
+ cur_db_ids = set()
+ async for prov_item in self._get_library_gen(media_type):
+ db_item: MediaItemType = await controller.get_db_item_by_prov_id(
+ item_id=prov_item.item_id,
+ provider_domain=prov_item.provider,
+ )
+ if not db_item:
+ # dump the item in the db, rich metadata is lazy loaded later
+ db_item = await controller.add_db_item(prov_item)
+
+ elif (
+ db_item.metadata.checksum and prov_item.metadata.checksum
+ ) and db_item.metadata.checksum != prov_item.metadata.checksum:
+ # item checksum changed
+ db_item = await controller.update_db_item(db_item.item_id, prov_item)
+ # preload album/playlist tracks
+ if prov_item.media_type == (MediaType.ALBUM, MediaType.PLAYLIST):
+ for track in controller.tracks(prov_item.item_id, prov_item.provider):
+ await self.mass.music.tracks.add_db_item(track)
+ cur_db_ids.add(db_item.item_id)
+ if not db_item.in_library:
+ await controller.set_db_library(db_item.item_id, True)
+
+ # process deletions (= no longer in library)
+ async for db_item in controller.iter_db_items(True):
+ if db_item.item_id in cur_db_ids:
+ continue
+ for prov_mapping in db_item.provider_mappings:
+ provider_domains = {x.provider_domain for x in db_item.provider_mappings}
+ if len(provider_domains) > 1:
+ continue
+ if prov_mapping.provider_instance != self.instance_id:
+ continue
+ # only mark the item as not in library and leave the metadata in db
+ await controller.set_db_library(db_item.item_id, False)
+
+ def is_file(self) -> bool:
+ """Return if this is a FileSystem based provider."""
+ # override this is needed
+ return self.domain.startswith("filesystem")
+
+ # DO NOT OVERRIDE BELOW
+
+ def library_supported(self, media_type: MediaType) -> bool:
+ """Return if Library is supported for given MediaType on this provider."""
+ if media_type == MediaType.ARTIST:
+ return ProviderFeature.LIBRARY_ARTISTS in self.supported_features
+ if media_type == MediaType.ALBUM:
+ return ProviderFeature.LIBRARY_ALBUMS in self.supported_features
+ if media_type == MediaType.TRACK:
+ return ProviderFeature.LIBRARY_TRACKS in self.supported_features
+ if media_type == MediaType.PLAYLIST:
+ return ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features
+ if media_type == MediaType.RADIO:
+ return ProviderFeature.LIBRARY_RADIOS in self.supported_features
+ return False
+
+ def library_edit_supported(self, media_type: MediaType) -> bool:
+ """Return if Library add/remove is supported for given MediaType on this provider."""
+ if media_type == MediaType.ARTIST:
+ return ProviderFeature.LIBRARY_ARTISTS_EDIT in self.supported_features
+ if media_type == MediaType.ALBUM:
+ return ProviderFeature.LIBRARY_ALBUMS_EDIT in self.supported_features
+ if media_type == MediaType.TRACK:
+ return ProviderFeature.LIBRARY_TRACKS_EDIT in self.supported_features
+ if media_type == MediaType.PLAYLIST:
+ return ProviderFeature.LIBRARY_PLAYLISTS_EDIT in self.supported_features
+ if media_type == MediaType.RADIO:
+ return ProviderFeature.LIBRARY_RADIOS_EDIT in self.supported_features
+ return False
+
+ def _get_library_gen(self, media_type: MediaType) -> AsyncGenerator[MediaItemType, None]:
+ """Return library generator for given media_type."""
+ if media_type == MediaType.ARTIST:
+ return self.get_library_artists()
+ if media_type == MediaType.ALBUM:
+ return self.get_library_albums()
+ if media_type == MediaType.TRACK:
+ return self.get_library_tracks()
+ if media_type == MediaType.PLAYLIST:
+ return self.get_library_playlists()
+ if media_type == MediaType.RADIO:
+ return self.get_library_radios()
+ raise NotImplementedError
--- /dev/null
+"""Model/base for a Metadata Provider implementation."""
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.player import Player
+from music_assistant.common.models.queue_item import QueueItem
+
+from .provider import Provider
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.config_entries import ConfigEntry, PlayerConfig
+
+# ruff: noqa: ARG001, ARG002
+
+
+class PlayerProvider(Provider):
+ """Base representation of a Player Provider (controller).
+
+ Player Provider implementations should inherit from this base model.
+ """
+
+ def get_player_config_entries(self, player_id: str) -> tuple[ConfigEntry, ...]:
+ """Return all (provider/player specific) Config Entries for the given player (if any)."""
+ return tuple()
+
+ def on_player_config_changed(self, config: PlayerConfig) -> None:
+ """Call (by config manager) when the configuration of a player changes."""
+
+ def on_player_config_removed(self, player_id: str) -> None:
+ """Call (by config manager) when the configuration of a player is removed."""
+
+ async def create_player_config(self, config: PlayerConfig | None = None) -> PlayerConfig:
+ """Handle CREATE_PLAYER flow for this player provider.
+
+ Allows manually registering/creating a player,
+ for example by manually entering an IP address etc.
+
+ Called by the Config manager without a value to get the PlayerConfig to show in the UI.
+ Called with PlayerConfig value with the submitted values.
+ """
+ # will only be called if the provider has the ADD_PLAYER feature set.
+ if ProviderFeature.CREATE_PLAYER_CONFIG in self.supported_features:
+ raise NotImplementedError
+
+ @abstractmethod
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+
+ @abstractmethod
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY (unpause) command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+
+ @abstractmethod
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ """
+
+ @abstractmethod
+ async def cmd_play_media(
+ self,
+ player_id: str,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Send PLAY MEDIA command to given player.
+
+ This is called when the Queue wants the player to start playing a specific QueueItem.
+ The player implementation can decide how to process the request, such as playing
+ queue items one-by-one or enqueue all/some items.
+
+ - player_id: player_id of the player to handle the command.
+ - queue_item: the QueueItem to start playing on the player.
+ - seek_position: start playing from this specific position.
+ - fade_in: fade in the music at start (e.g. at resume).
+ - flow_mode: enable flow mode where the queue tracks are streamed as continuous stream.
+ """
+
+ async def cmd_power(self, player_id: str, powered: bool) -> None:
+ """Send POWER command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ - powered: bool if player should be powered on or off.
+ """
+ # will only be called for players with Power feature set.
+
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ - volume_level: volume level (0..100) to set on the player.
+ """
+ # will only be called for players with Volume feature set.
+
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME MUTE command to given player.
+
+ - player_id: player_id of the player to handle the command.
+ - muted: bool if player should be muted.
+ """
+ # will only be called for players with Mute feature set.
+
+ async def cmd_seek(self, player_id: str, position: int) -> None:
+ """Handle SEEK command for given queue.
+
+ - player_id: player_id of the player to handle the command.
+ - position: position in seconds to seek to in the current playing item.
+ """
+ # will only be called for players with Seek feature set.
+
+ async def cmd_sync(self, player_id: str, target_player: str) -> None:
+ """Handle SYNC command for given player.
+
+ Join/add the given player(id) to the given (master) player/sync group.
+
+ - player_id: player_id of the player to handle the command.
+ - target_player: player_id of the syncgroup master or group player.
+ """
+ # will only be called for players with SYNC feature set.
+
+ async def cmd_unsync(self, player_id: str) -> None:
+ """Handle UNSYNC command for given player.
+
+ Remove the given player from any syncgroups it currently is synced to.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ # will only be called for players with SYNC feature set.
+
+ async def poll_player(self, player_id: str) -> None:
+ """Poll player for state updates.
+
+ This is called by the Player Manager;
+ - every 360 seconds if the player if not powered
+ - every 30 seconds if the player is powered
+ - every 10 seconds if the player is playing
+
+ Use this method to request any info that is not automatically updated and/or
+ to detect if the player is still alive.
+ If this method raises the PlayerUnavailable exception,
+ the player is marked as unavailable until
+ the next successful poll or event where it becomes available again.
+ If the player does not need any polling, simply do not override this method.
+ """
+
+ # DO NOT OVERRIDE BELOW
+
+ @property
+ def players(self) -> list[Player]:
+ """Return all players belonging to this provider."""
+ # pylint: disable=no-member
+ return [player for player in self.mass.players if player.provider == self.domain]
--- /dev/null
+"""Model/base for a Plugin Provider implementation."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .provider import Provider
+
+if TYPE_CHECKING:
+ pass
+
+# ruff: noqa: ARG001, ARG002
+
+
+class PluginProvider(Provider):
+ """
+ Base representation of a Plugin for Music Assistant.
+
+ Plugin Provider implementations should inherit from this base model.
+ """
--- /dev/null
+"""Model/base for a Provider implementation within Music Assistant."""
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from music_assistant.common.models.config_entries import ConfigEntryValue, ProviderConfig
+from music_assistant.common.models.enums import ProviderFeature, ProviderType
+from music_assistant.common.models.provider import ProviderInstance, ProviderManifest
+from music_assistant.constants import ROOT_LOGGER_NAME
+
+if TYPE_CHECKING:
+ from music_assistant.server import MusicAssistant
+
+# noqa: ARG001
+
+
+class Provider:
+ """Base representation of a Provider implementation within Music Assistant."""
+
+ _attr_supported_features: tuple[ProviderFeature, ...] = tuple()
+
+ def __init__(
+ self, mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig
+ ) -> None:
+ """Initialize MusicProvider."""
+ self.mass = mass
+ self.manifest = manifest
+ self.config = config
+ self.logger = logging.getLogger(f"{ROOT_LOGGER_NAME}.providers.{self.domain}")
+ self.cache = mass.cache
+ self.available = False
+ self.last_error = None
+
+ @property
+ def supported_features(self) -> tuple[ProviderFeature, ...]:
+ """Return the features supported by this MusicProvider."""
+ return self._attr_supported_features
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider.
+
+ Called when provider is registered (or its config updated).
+ """
+
+ async def close(self) -> None:
+ """Handle close/cleanup of the provider.
+
+ Called when provider is deregistered (e.g. MA exiting or config reloading).
+ """
+
+ @property
+ def type(self) -> ProviderType:
+ """Return type of this provider."""
+ return self.manifest.type
+
+ @property
+ def domain(self) -> str:
+ """Return domain for this provider."""
+ return self.manifest.domain
+
+ @property
+ def instance_id(self) -> str:
+ """Return instance_id for this provider(instance)."""
+ return self.config.instance_id
+
+ @property
+ def name(self) -> str:
+ """Return (custom) friendly name for this provider instance."""
+ if self.config.name:
+ return self.config.name
+ inst_count = len([x for x in self.mass.music.providers if x.domain == self.domain])
+ if inst_count > 1:
+ postfix = self.instance_id[:-8]
+ return f"{self.manifest.name}.{postfix}"
+ return self.manifest.name
+
+ @property
+ def config_entries(self) -> list[ConfigEntryValue]:
+ """Return list of all ConfigEntries including values for this provider(instance)."""
+ return [
+ ConfigEntryValue.parse(x, self.config.values.get(x.key))
+ for x in self.manifest.config_entries
+ ]
+
+ def to_dict(self, *args, **kwargs) -> ProviderInstance: # noqa: ARG002
+ """Return Provider(instance) as serializable dict."""
+ return {
+ "type": self.type.value,
+ "domain": self.domain,
+ "name": self.name,
+ "instance_id": self.instance_id,
+ "supported_features": [x.value for x in self.supported_features],
+ "available": self.available,
+ "last_error": self.last_error,
+ }
--- /dev/null
+"""Package with Music Provider controllers."""
--- /dev/null
+"""Airplay Player provider.
+
+This is more like a "virtual" player provider, running on top of slimproto.
+It uses the amazing work of Philippe44 who created a bridge from airplay to slimoproto.
+https://github.com/philippe44/LMS-Raop
+"""
+from __future__ import annotations
+
+import asyncio
+import os
+import platform
+import xml.etree.ElementTree as ET # noqa: N817
+from typing import TYPE_CHECKING
+
+import aiofiles
+
+from music_assistant.common.models.config_entries import ConfigEntry
+from music_assistant.common.models.enums import ConfigEntryType
+from music_assistant.common.models.errors import PlayerUnavailableError
+from music_assistant.common.models.player import DeviceInfo, Player
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.server.models.player_provider import PlayerProvider
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.config_entries import PlayerConfig
+ from music_assistant.server.providers.slimproto import SlimprotoProvider
+
+
+PLAYER_CONFIG_ENTRIES = (
+ ConfigEntry(
+ key="airplay_label",
+ type=ConfigEntryType.LABEL,
+ label="Airplay specific settings",
+ description="Configure Airplay specific settings. "
+ "Note that changing any airplay specific setting, will reconnect all players.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key="read_ahead",
+ type=ConfigEntryType.INTEGER,
+ range=(0, 2000),
+ default_value=500,
+ label="Read ahead buffer",
+ description="Sets the number of milliseconds of audio buffer in the player. "
+ "This is important to absorb network throughput jitter. "
+ "Note that the resume after pause will be skipping that amount of time "
+ "and volume changes will be delayed by the same amount, when using digital volume.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key="encryption",
+ type=ConfigEntryType.BOOLEAN,
+ default_value=False,
+ label="Enable encryption",
+ description="Enable encrypted communication with the player, "
+ "some (3rd party) players require this.",
+ advanced=True,
+ ),
+ ConfigEntry(
+ key="alac_encode",
+ type=ConfigEntryType.BOOLEAN,
+ default_value=True,
+ label="Enable compression",
+ description="Save some network bandwidth by sending the audio as "
+ "(lossless) ALAC at the cost of a bit CPU.",
+ advanced=True,
+ ),
+)
+
+
+class AirplayProvider(PlayerProvider):
+ """Player provider for Airplay based players, using the slimproto bridge."""
+
+ _bridge_bin: str | None = None
+ _bridge_proc: asyncio.subprocess.Process | None = None
+ _closing: bool = False
+ _config_file: str | None = None
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self._config_file = os.path.join(self.mass.storage_path, "airplay_bridge.xml")
+ # locate the raopbridge binary (will raise if that fails)
+ self._bridge_bin = await self._get_bridge_binary()
+ # make sure that slimproto provider is loaded
+ slimproto_prov: SlimprotoProvider = self.mass.get_provider("slimproto")
+ assert slimproto_prov, "This provider depends on the SlimProto provider."
+ # register as virtual provider on slimproto provider
+ slimproto_prov.register_virtual_provider(
+ "RaopBridge",
+ self._handle_player_register_callback,
+ self._handle_player_update_callback,
+ )
+ await self._check_config_xml()
+ # start running the bridge
+ asyncio.create_task(self._bridge_process_runner())
+
+ async def close(self) -> None:
+ """Handle close/cleanup of the provider."""
+ self._closing = True
+ await self._stop_bridge()
+
+ def get_player_config_entries(self, player_id: str) -> tuple[ConfigEntry]:
+ """Return all (provider/player specific) Config Entries for the given player (if any)."""
+ slimproto_prov = self.mass.get_provider("slimproto")
+ base_entries = slimproto_prov.get_player_config_entries(player_id)
+ return tuple(base_entries + PLAYER_CONFIG_ENTRIES)
+
+ def on_player_config_changed(self, config: PlayerConfig) -> None:
+ """Call (by config manager) when the configuration of a player changes."""
+ # forward to slimproto too
+ slimproto_prov = self.mass.get_provider("slimproto")
+ slimproto_prov.on_player_config_changed(config)
+
+ async def update_config():
+ # stop bridge (it will be auto restarted)
+ # TODO: only restart bridge if actual xml values changed
+ await self._stop_bridge()
+ # update the config
+ await self._check_config_xml()
+
+ asyncio.create_task(update_config())
+
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_stop(player_id)
+
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_play(player_id)
+
+ async def cmd_play_media(
+ self,
+ player_id: str,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Send PLAY MEDIA command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_play_media(
+ player_id,
+ queue_item=queue_item,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ flow_mode=flow_mode,
+ )
+
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_pause(player_id)
+
+ async def cmd_power(self, player_id: str, powered: bool) -> None:
+ """Send POWER command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_power(player_id, powered)
+
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_volume_set(player_id, volume_level)
+
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME MUTE command to given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_volume_mute(player_id, muted)
+
+ async def cmd_sync(self, player_id: str, target_player: str) -> None:
+ """Handle SYNC command for given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_sync(player_id, target_player)
+
+ async def cmd_unsync(self, player_id: str) -> None:
+ """Handle UNSYNC command for given player."""
+ # simply forward to underlying slimproto player
+ slimproto_prov = self.mass.get_provider("slimproto")
+ await slimproto_prov.cmd_unsync(player_id)
+
+ def _handle_player_register_callback(self, player: Player) -> None:
+ """Handle player register callback from slimproto source player."""
+ # TODO: Can we get better device info from mDNS ?
+ player.provider = self.domain
+ player.device_info = DeviceInfo(
+ model="Airplay device",
+ address=player.device_info.address,
+ manufacturer="Generic",
+ )
+ player.supports_24bit = False
+
+ def _handle_player_update_callback(self, player: Player) -> None:
+ """Handle player update callback from slimproto source player."""
+ # we could override anything on the player object here
+
+ async def _get_bridge_binary(self):
+ """Find the correct bridge binary belonging to the platform."""
+ # ruff: noqa: SIM102
+
+ async def check_bridge_binary(bridge_binary_path: str) -> str | None:
+ try:
+ bridge_binary = await asyncio.create_subprocess_exec(
+ *[bridge_binary_path, "-t", "-x", self._config_file],
+ stdout=asyncio.subprocess.PIPE,
+ )
+ stdout, _ = await bridge_binary.communicate()
+ if (
+ bridge_binary.returncode == 1
+ and b"This program is free software: you can redistribute it and/or modify"
+ in stdout
+ ):
+ self._bridge_bin = bridge_binary_path
+ return bridge_binary_path
+ except OSError:
+ return None
+
+ base_path = os.path.join(os.path.dirname(__file__), "bin")
+ if platform.system() == "Windows" and (
+ bridge_binary := await check_bridge_binary(
+ os.path.join(base_path, "squeeze2raop-static.exe")
+ )
+ ):
+ return bridge_binary
+ if platform.system() == "Darwin":
+ # macos binary is autoselect x86_64/arm64
+ if bridge_binary := await check_bridge_binary(
+ os.path.join(base_path, "squeeze2raop-macos-static")
+ ):
+ return bridge_binary
+
+ if platform.system() == "FreeBSD":
+ # FreeBSD binary is x86_64 intel
+ if bridge_binary := await check_bridge_binary(
+ os.path.join(base_path, "squeeze2raop-freebsd-x86_64-static")
+ ):
+ return bridge_binary
+
+ if platform.system() == "Linux":
+ architecture = platform.machine()
+ if architecture in ["AMD64", "x86_64"]:
+ # generic linux x86_64 binary
+ if bridge_binary := await check_bridge_binary(
+ os.path.join(
+ base_path,
+ "squeeze2raop-linux-x86_64-static",
+ )
+ ):
+ return bridge_binary
+
+ # other linux architecture... try all options one by one...
+ for arch in ["arm64", "arm", "armv6", "mips", "sparc64", "x86"]:
+ if bridge_binary := await check_bridge_binary(
+ os.path.join(base_path, f"squeeze2raop-linux-{arch}-static")
+ ):
+ return bridge_binary
+
+ raise RuntimeError(
+ f"Unable to locate RaopBridge for {platform.system()} ({platform.machine()})"
+ )
+
+ async def _bridge_process_runner(self) -> None:
+ """Run the bridge binary in the background."""
+ log_file = os.path.join(self.mass.storage_path, "airplay_bridge.log")
+ self.logger.debug(
+ "Starting Airplay bridge using config file %s",
+ self._config_file,
+ )
+ args = [
+ self._bridge_bin,
+ "-s",
+ "localhost",
+ "-x",
+ self._config_file,
+ "-f",
+ log_file,
+ "-I",
+ "-Z",
+ "-d",
+ "all=info",
+ ]
+ start_success = False
+ while True:
+ try:
+ self._bridge_proc = await asyncio.create_subprocess_shell(
+ " ".join(args),
+ stdout=asyncio.subprocess.DEVNULL,
+ stderr=asyncio.subprocess.DEVNULL,
+ )
+ await self._bridge_proc.wait()
+ except Exception as err:
+ if not start_success:
+ raise err
+ self.logger.exception("Error in Airplay bridge", exc_info=err)
+ else:
+ self.logger.debug("Airplay Bridge process stopped")
+ if self._closing:
+ break
+ await asyncio.sleep(1)
+
+ async def _stop_bridge(self) -> None:
+ """Stop the bridge process."""
+ if self._bridge_proc:
+ try:
+ self._bridge_proc.terminate()
+ await self._bridge_proc.wait()
+ except ProcessLookupError:
+ pass
+
+ async def _check_config_xml(self, recreate: bool = False) -> None:
+ """Check the bridge config XML file."""
+ if recreate or not os.path.isfile(self._config_file):
+ if os.path.isfile(self._config_file):
+ os.remove(self._config_file)
+ # discover players and create default config file
+ args = [
+ self._bridge_bin,
+ "-i",
+ self._config_file,
+ ]
+ proc = await asyncio.create_subprocess_shell(
+ " ".join(args),
+ stdout=asyncio.subprocess.DEVNULL,
+ stderr=asyncio.subprocess.DEVNULL,
+ )
+ await proc.wait()
+
+ # read xml file's data
+ async with aiofiles.open(self._config_file, "r") as _file:
+ xml_data = await _file.read()
+
+ try:
+ xml_root = ET.XML(xml_data)
+ except ET.ParseError as err:
+ if recreate:
+ raise err
+ await self._check_config_xml(True)
+ return
+
+ # set codecs and sample rate to airplay default
+ common_elem = xml_root.find("common")
+ common_elem.find("codecs").text = "pcm"
+ common_elem.find("sample_rate").text = "44100"
+ common_elem.find("resample").text = "0"
+ # get/set all device configs
+ for device_elem in xml_root.findall("device"):
+ player_id = device_elem.find("mac").text
+ try:
+ player_conf = self.mass.config.get_player_config(player_id)
+ except PlayerUnavailableError:
+ player_conf = None
+ # prefer name from UDN because default name is often wrong
+ udn = device_elem.find("udn").text
+ udn_name = udn.split("@")[1].split("._")[0]
+ device_elem.find("name").text = udn_name
+ device_elem.find("enabled").text = (
+ "1" if (not player_conf or player_conf.enabled) else "0"
+ )
+
+ for conf_entry in PLAYER_CONFIG_ENTRIES:
+ if conf_entry.type == ConfigEntryType.LABEL:
+ continue
+ if player_conf:
+ conf_val = player_conf.get_value(conf_entry.key)
+ else:
+ conf_val = conf_entry.default_value
+ xml_elem = device_elem.find(conf_entry.key)
+ if xml_elem is None:
+ xml_elem = ET.SubElement(device_elem, conf_entry.key)
+ if conf_entry.type == ConfigEntryType.BOOLEAN:
+ xml_elem.text = "1" if conf_val else "0"
+ else:
+ xml_elem.text = str(conf_val)
+
+ # save config file
+ async with aiofiles.open(self._config_file, "w") as _file:
+ await _file.write(ET.tostring(xml_root).decode())
--- /dev/null
+{
+ "type": "player",
+ "domain": "airplay",
+ "name": "Airplay",
+ "description": "Support for players that support the Airplay protocol.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ ],
+ "requirements": [],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": false,
+ "load_by_default": true,
+ "depends_on": "slimproto"
+}
--- /dev/null
+"""Chromecast Player provider for Music Assistant, utilizing the pychromecast library."""
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+from dataclasses import dataclass
+from logging import Logger
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+from pychromecast import (
+ APP_BUBBLEUPNP,
+ APP_MEDIA_RECEIVER,
+ Chromecast,
+ get_chromecast_from_cast_info,
+)
+from pychromecast.controllers.media import STREAM_TYPE_BUFFERED, STREAM_TYPE_LIVE
+from pychromecast.controllers.multizone import MultizoneController, MultizoneManager
+from pychromecast.discovery import CastBrowser, SimpleCastListener
+from pychromecast.models import CastInfo
+from pychromecast.socket_client import (
+ CONNECTION_STATUS_CONNECTED,
+ CONNECTION_STATUS_DISCONNECTED,
+)
+
+from music_assistant.common.models.enums import (
+ ContentType,
+ MediaType,
+ PlayerFeature,
+ PlayerState,
+ PlayerType,
+)
+from music_assistant.common.models.errors import PlayerUnavailableError, QueueEmpty
+from music_assistant.common.models.player import DeviceInfo, Player
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.constants import MASS_LOGO_ONLINE
+from music_assistant.server.helpers.compare import compare_strings
+from music_assistant.server.models.player_provider import PlayerProvider
+from music_assistant.server.providers.chromecast.helpers import (
+ CastStatusListener,
+ ChromecastInfo,
+)
+
+if TYPE_CHECKING:
+ from pychromecast.controllers.media import MediaStatus
+ from pychromecast.controllers.receiver import CastStatus
+ from pychromecast.socket_client import ConnectionStatus
+
+
+PLAYER_CONFIG_ENTRIES = tuple()
+
+
+@dataclass
+class CastPlayer:
+ """Wrapper around Chromecast with some additional attributes."""
+
+ player_id: str
+ cast_info: ChromecastInfo
+ cc: Chromecast
+ player: Player
+ logger: Logger
+ is_stereo_pair: bool = False
+ status_listener: CastStatusListener | None = None
+ mz_controller: MultizoneController | None = None
+ next_item: str | None = None
+ flow_mode_active: bool = False
+
+
+class ChromecastProvider(PlayerProvider):
+ """Player provider for Chromecast based players."""
+
+ mz_mgr: MultizoneManager | None = None
+ browser: CastBrowser | None = None
+ castplayers: dict[str, CastPlayer]
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self.castplayers = {}
+ # silence the cast logger a bit
+ logging.getLogger("pychromecast.socket_client").setLevel(logging.INFO)
+ logging.getLogger("pychromecast.controllers").setLevel(logging.INFO)
+ self.mz_mgr = MultizoneManager()
+ self.browser = CastBrowser(
+ SimpleCastListener(
+ add_callback=self._on_chromecast_discovered,
+ remove_callback=self._on_chromecast_removed,
+ update_callback=self._on_chromecast_discovered,
+ ),
+ self.mass.zeroconf,
+ )
+ # start discovery in executor
+ await self.mass.loop.run_in_executor(None, self.browser.start_discovery)
+
+ async def close(self) -> None:
+ """Handle close/cleanup of the provider."""
+ if not self.browser:
+ return
+ # stop discovery
+ await self.mass.loop.run_in_executor(None, self.browser.stop_discovery)
+ # stop all chromecasts
+ for castplayer in list(self.castplayers.values()):
+ await self._disconnect_chromecast(castplayer)
+
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player."""
+ castplayer = self.castplayers[player_id]
+ await asyncio.to_thread(castplayer.cc.media_controller.stop)
+
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY command to given player."""
+ castplayer = self.castplayers[player_id]
+ await asyncio.to_thread(castplayer.cc.media_controller.play)
+
+ async def cmd_play_media(
+ self,
+ player_id: str,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Send PLAY MEDIA command to given player."""
+ castplayer = self.castplayers[player_id]
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=queue_item,
+ player_id=player_id,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ # prefer FLAC as it seems to work on all CC players
+ content_type=ContentType.FLAC,
+ flow_mode=flow_mode,
+ )
+ castplayer.flow_mode_active = flow_mode
+
+ # in flow mode, we just send the url and the metadata is of no use
+ if flow_mode:
+ await asyncio.to_thread(
+ castplayer.cc.play_media,
+ url,
+ content_type="audio/flac",
+ title="Music Assistant",
+ thumb=MASS_LOGO_ONLINE,
+ media_info={
+ "customData": {
+ "queue_item_id": queue_item.queue_item_id,
+ }
+ },
+ )
+ return
+
+ cc_queue_items = [self._create_queue_item(queue_item, url)]
+ queuedata = {
+ "type": "QUEUE_LOAD",
+ "repeatMode": "REPEAT_OFF", # handled by our queue controller
+ "shuffle": False, # handled by our queue controller
+ "queueType": "PLAYLIST",
+ "startIndex": 0, # Item index to play after this request or keep same item if undefined
+ "items": cc_queue_items,
+ }
+ # make sure that media controller app is launched
+ await self._launch_app(castplayer)
+ # send queue info to the CC
+ castplayer.next_item = None
+ media_controller = castplayer.cc.media_controller
+ await asyncio.to_thread(media_controller.send_message, queuedata, True)
+
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player."""
+ castplayer = self.castplayers[player_id]
+ await asyncio.to_thread(castplayer.cc.media_controller.pause)
+
+ async def cmd_power(self, player_id: str, powered: bool) -> None:
+ """Send POWER command to given player."""
+ castplayer = self.castplayers[player_id]
+ if powered:
+ await self._launch_app(castplayer)
+ else:
+ await asyncio.to_thread(castplayer.cc.quit_app)
+
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player."""
+ castplayer = self.castplayers[player_id]
+ await asyncio.to_thread(castplayer.cc.set_volume, volume_level / 100)
+
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME MUTE command to given player."""
+ castplayer = self.castplayers[player_id]
+ await asyncio.to_thread(castplayer.cc.set_volume_muted, muted)
+
+ async def poll_player(self, player_id: str) -> None:
+ """Poll player for state updates.
+
+ This is called by the Player Manager;
+ - every 360 seconds if the player if not powered
+ - every 30 seconds if the player is powered
+ - every 10 seconds if the player is playing
+
+ Use this method to request any info that is not automatically updated and/or
+ to detect if the player is still alive.
+ If this method raises the PlayerUnavailable exception,
+ the player is marked as unavailable until
+ the next successful poll or event where it becomes available again.
+ If the player does not need any polling, simply do not override this method.
+ """
+ castplayer = self.castplayers[player_id]
+ try:
+ await asyncio.to_thread(castplayer.cc.media_controller.update_status)
+ except ConnectionResetError as err:
+ raise PlayerUnavailableError from err
+
+ ### Discovery callbacks
+
+ def _on_chromecast_discovered(self, uuid, _):
+ """Handle Chromecast discovered callback."""
+ if self.mass.closing:
+ return
+
+ disc_info: CastInfo = self.browser.devices[uuid]
+
+ if disc_info.uuid is None:
+ self.logger.error("Discovered chromecast without uuid %s", disc_info)
+ return
+
+ self.logger.debug("Discovered new or updated chromecast %s", disc_info)
+ player_id = str(disc_info.uuid)
+
+ castplayer = self.castplayers.get(player_id)
+ if not castplayer:
+ cast_info = ChromecastInfo.from_cast_info(disc_info)
+ cast_info.fill_out_missing_chromecast_info(self.mass.zeroconf)
+ if cast_info.is_dynamic_group:
+ self.logger.warning("Discovered a dynamic cast group which will be ignored.")
+ return
+
+ # Instantiate chromecast object
+ castplayer = CastPlayer(
+ player_id,
+ cast_info=cast_info,
+ cc=get_chromecast_from_cast_info(
+ disc_info,
+ self.mass.zeroconf,
+ ),
+ player=Player(
+ player_id=player_id,
+ provider=self.domain,
+ type=PlayerType.GROUP if cast_info.is_audio_group else PlayerType.PLAYER,
+ name=cast_info.friendly_name,
+ available=False,
+ powered=False,
+ device_info=DeviceInfo(
+ model=cast_info.model_name,
+ address=cast_info.host,
+ manufacturer=cast_info.manufacturer,
+ ),
+ supported_features=(
+ PlayerFeature.POWER,
+ PlayerFeature.VOLUME_MUTE,
+ PlayerFeature.VOLUME_SET,
+ ),
+ max_sample_rate=96000,
+ ),
+ logger=self.logger.getChild(cast_info.friendly_name),
+ )
+ self.castplayers[player_id] = castplayer
+
+ castplayer.status_listener = CastStatusListener(self, castplayer, self.mz_mgr)
+ if cast_info.is_audio_group:
+ mz_controller = MultizoneController(cast_info.uuid)
+ castplayer.cc.register_handler(mz_controller)
+ castplayer.mz_controller = mz_controller
+ castplayer.cc.start()
+
+ self.mass.loop.call_soon_threadsafe(self.mass.players.register, castplayer.player)
+
+ # if player was already added, the player will take care of reconnects itself.
+ castplayer.cast_info.update(disc_info)
+ self.mass.loop.call_soon_threadsafe(self.mass.players.update, player_id)
+
+ def _on_chromecast_removed(self, uuid, service, cast_info): # noqa: ARG002
+ """Handle zeroconf discovery of a removed Chromecast."""
+ # noqa: ARG001
+ player_id = str(service[1])
+ friendly_name = service[3]
+ self.logger.debug("Chromecast removed: %s - %s", friendly_name, player_id)
+ # we ignore this event completely as the Chromecast socket client handles this itself
+
+ ### Callbacks from Chromecast Statuslistener
+
+ def on_new_cast_status(self, castplayer: CastPlayer, status: CastStatus) -> None:
+ """Handle updated CastStatus."""
+ castplayer.logger.debug(
+ "Received cast status - app_id: %s - volume: %s",
+ status.app_id,
+ status.volume_level,
+ )
+ castplayer.player.name = castplayer.cast_info.friendly_name
+ castplayer.player.powered = status.app_id in (
+ "705D30C6",
+ APP_MEDIA_RECEIVER,
+ APP_BUBBLEUPNP,
+ )
+ castplayer.is_stereo_pair = (
+ castplayer.cast_info.is_audio_group
+ and castplayer.mz_controller
+ and castplayer.mz_controller.members
+ and compare_strings(castplayer.mz_controller.members[0], castplayer.player_id)
+ )
+ castplayer.player.volume_level = int(status.volume_level * 100)
+ castplayer.player.volume_muted = status.volume_muted
+ if castplayer.is_stereo_pair:
+ castplayer.player.type = PlayerType.PLAYER
+ self.mass.loop.call_soon_threadsafe(self.mass.players.update, castplayer.player_id)
+
+ def on_new_media_status(self, castplayer: CastPlayer, status: MediaStatus):
+ """Handle updated MediaStatus."""
+ castplayer.logger.debug("Received media status update: %s", status.player_state)
+ prev_item_id = castplayer.player.current_item_id
+ # player state
+ if status.player_is_playing:
+ castplayer.player.state = PlayerState.PLAYING
+ elif status.player_is_paused:
+ castplayer.player.state = PlayerState.PAUSED
+ else:
+ castplayer.player.state = PlayerState.IDLE
+
+ # elapsed time
+ castplayer.player.elapsed_time_last_updated = time.time()
+ if status.player_is_playing:
+ castplayer.player.elapsed_time = status.adjusted_current_time
+ else:
+ castplayer.player.elapsed_time = status.current_time
+
+ # current media
+ queue_item_id = status.media_custom_data.get("queue_item_id")
+ castplayer.player.current_item_id = queue_item_id
+ castplayer.player.current_url = status.content_id
+ self.mass.loop.call_soon_threadsafe(self.mass.players.update, castplayer.player_id)
+
+ # enqueue next item if needed
+ if castplayer.player.state == PlayerState.PLAYING and (
+ prev_item_id != castplayer.player.current_item_id
+ or not castplayer.next_item
+ or castplayer.next_item == castplayer.player.current_item_id
+ ):
+ asyncio.run_coroutine_threadsafe(
+ self._enqueue_next_track(castplayer, queue_item_id), self.mass.loop
+ )
+
+ def on_new_connection_status(self, castplayer: CastPlayer, status: ConnectionStatus) -> None:
+ """Handle updated ConnectionStatus."""
+ castplayer.logger.debug("Received connection status update - status: %s", status.status)
+
+ if status.status == CONNECTION_STATUS_DISCONNECTED:
+ castplayer.player.available = False
+ self.mass.loop.call_soon_threadsafe(self.mass.players.update, castplayer.player_id)
+ return
+
+ new_available = status.status == CONNECTION_STATUS_CONNECTED
+ if new_available != castplayer.player.available:
+ self.logger.debug(
+ "[%s] Cast device availability changed: %s",
+ castplayer.cast_info.friendly_name,
+ status.status,
+ )
+ castplayer.player.available = new_available
+ castplayer.player.device_info = DeviceInfo(
+ model=castplayer.cast_info.model_name,
+ address=castplayer.cast_info.host,
+ manufacturer=castplayer.cast_info.manufacturer,
+ )
+ self.mass.loop.call_soon_threadsafe(self.mass.players.update, castplayer.player_id)
+ if new_available and not castplayer.cast_info.is_audio_group:
+ # Poll current group status
+ for group_uuid in self.mz_mgr.get_multizone_memberships(castplayer.cast_info.uuid):
+ group_media_controller = self.mz_mgr.get_multizone_mediacontroller(group_uuid)
+ if not group_media_controller:
+ continue
+ self.on_multizone_new_media_status(
+ castplayer, group_uuid, group_media_controller.status
+ )
+
+ def on_multizone_new_media_status(
+ self, castplayer: CastPlayer, group_uuid: UUID, media_status: MediaStatus # noqa: ARG002
+ ):
+ """Handle updates of audio group media status."""
+ castplayer.logger.debug("Received multizone media status update")
+ # self.mz_media_status[group_uuid] = media_status
+ # self.mz_media_status_received[group_uuid] = dt_util.utcnow()
+ # self.schedule_update_ha_state()
+
+ ### Helpers / utils
+
+ async def _enqueue_next_track(self, castplayer: CastPlayer, current_queue_item_id: str) -> None:
+ """Enqueue the next track of the MA queue on the CC queue."""
+ if castplayer.flow_mode_active:
+ # not possible when we're in flow mode
+ return
+
+ if not current_queue_item_id:
+ return # guard
+ try:
+ next_item, crossfade = self.mass.players.queues.player_ready_for_next_track(
+ castplayer.player_id, current_queue_item_id
+ )
+ except QueueEmpty:
+ return
+
+ if castplayer.next_item == next_item.queue_item_id:
+ return # already set ?!
+ castplayer.next_item = next_item.queue_item_id
+
+ if crossfade:
+ self.logger.warning(
+ "Crossfade requested but Chromecast does not support crossfading,"
+ " consider using flow mode to enable crossfade on a Chromecast."
+ )
+
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=next_item,
+ player_id=castplayer.player_id,
+ content_type=ContentType.FLAC,
+ auto_start_runner=False,
+ )
+ cc_queue_items = [self._create_queue_item(next_item, url)]
+
+ queuedata = {
+ "type": "QUEUE_INSERT",
+ "insertBefore": None,
+ "items": cc_queue_items,
+ }
+ media_controller = castplayer.cc.media_controller
+ queuedata["mediaSessionId"] = media_controller.status.media_session_id
+
+ await asyncio.sleep(0.5) # throttle commands to CC a bit or it will crash
+ await asyncio.to_thread(media_controller.send_message, queuedata, True)
+
+ async def _launch_app(self, castplayer: CastPlayer) -> None:
+ """Launch the default Media Receiver App on a Chromecast."""
+ event = asyncio.Event()
+
+ def launched_callback():
+ self.mass.loop.call_soon_threadsafe(event.set)
+
+ def launch():
+ # controller = BubbleUPNPController()
+ # castplayer.cc.register_handler(controller)
+ # controller.launch(launched_callback)
+ castplayer.cc.media_controller.launch(launched_callback)
+
+ castplayer.logger.debug("Launching BubbleUPNPController as active app.")
+ await self.mass.loop.run_in_executor(None, launch)
+ await event.wait()
+
+ async def _disconnect_chromecast(self, castplayer: CastPlayer) -> None:
+ """Disconnect Chromecast object if it is set."""
+ castplayer.logger.debug("Disconnecting from chromecast socket")
+ await self.mass.loop.run_in_executor(None, castplayer.cc.disconnect, 10)
+ castplayer.mz_controller = None
+ castplayer.status_listener.invalidate()
+ castplayer.status_listener = None
+ self.castplayers.pop(castplayer.player_id, None)
+
+ @staticmethod
+ def _create_queue_item(queue_item: QueueItem, stream_url: str):
+ """Create CC queue item from MA QueueItem."""
+ duration = int(queue_item.duration) if queue_item.duration else None
+ if queue_item.media_type == MediaType.TRACK:
+ stream_type = STREAM_TYPE_BUFFERED
+ metadata = {
+ "metadataType": 3,
+ "albumName": queue_item.media_item.album.name,
+ "songName": queue_item.media_item.name,
+ "artist": queue_item.media_item.artist.name,
+ "title": queue_item.name,
+ "images": [{"url": queue_item.image.url}] if queue_item.image else None,
+ }
+ else:
+ stream_type = STREAM_TYPE_LIVE
+ metadata = {
+ "metadataType": 0,
+ "title": queue_item.name,
+ "images": [{"url": queue_item.image.url}] if queue_item.image else None,
+ }
+ return {
+ "autoplay": True,
+ "preloadTime": 10,
+ "playbackDuration": duration,
+ "startTime": 0,
+ "activeTrackIds": [],
+ "media": {
+ "contentId": stream_url,
+ "customData": {
+ "uri": queue_item.uri,
+ "queue_item_id": queue_item.queue_item_id,
+ },
+ "contentType": "audio/flac",
+ "streamType": stream_type,
+ "metadata": metadata,
+ "duration": duration,
+ },
+ }
--- /dev/null
+"""Helpers to deal with Cast devices."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Self
+from uuid import UUID
+
+from pychromecast import dial
+from pychromecast.const import CAST_TYPE_GROUP
+
+if TYPE_CHECKING:
+ from pychromecast.controllers.media import MediaStatus
+ from pychromecast.controllers.multizone import MultizoneManager
+ from pychromecast.controllers.receiver import CastStatus
+ from pychromecast.models import CastInfo
+ from pychromecast.socket_client import ConnectionStatus
+ from zeroconf import Zeroconf
+
+ from . import CastPlayer, ChromecastProvider
+
+DEFAULT_PORT = 8009
+
+
+@dataclass
+class ChromecastInfo:
+ """Class to hold all data about a chromecast for creating connections.
+
+ This also has the same attributes as the mDNS fields by zeroconf.
+ """
+
+ services: set
+ uuid: UUID
+ model_name: str
+ friendly_name: str
+ host: str
+ port: int
+ cast_type: str | None = None
+ manufacturer: str | None = None
+ is_dynamic_group: bool | None = None
+
+ @property
+ def is_audio_group(self) -> bool:
+ """Return if the cast is an audio group."""
+ return self.cast_type == CAST_TYPE_GROUP
+
+ @classmethod
+ def from_cast_info(cls: Self, cast_info: CastInfo) -> Self:
+ """Instantiate ChromecastInfo from CastInfo."""
+ return cls(**cast_info._asdict())
+
+ def update(self, cast_info: CastInfo) -> None:
+ """Update ChromecastInfo from CastInfo."""
+ for key, value in cast_info._asdict().items():
+ if not value:
+ continue
+ setattr(self, key, value)
+
+ def fill_out_missing_chromecast_info(self, zconf: Zeroconf) -> None:
+ """
+ Return a new ChromecastInfo object with missing attributes filled in.
+
+ Uses blocking HTTP / HTTPS.
+ """
+ if self.cast_type is None or self.manufacturer is None:
+ # Manufacturer and cast type is not available in mDNS data,
+ # get it over HTTP
+ cast_info = dial.get_cast_type(
+ self,
+ zconf=zconf,
+ )
+ self.cast_type = cast_info.cast_type
+ self.manufacturer = cast_info.manufacturer
+
+ if not self.is_audio_group or self.is_dynamic_group is not None:
+ # We have all information, no need to check HTTP API.
+ return
+
+ # Fill out missing group information via HTTP API.
+ http_group_status = dial.get_multizone_status(
+ None,
+ services=self.services,
+ zconf=zconf,
+ )
+ if http_group_status is not None:
+ self.is_dynamic_group = any(
+ g.uuid == self.uuid for g in http_group_status.dynamic_groups
+ )
+
+
+class CastStatusListener:
+ """
+ Helper class to handle pychromecast status callbacks.
+
+ Necessary because a CastDevice entity can create a new socket client
+ and therefore callbacks from multiple chromecast connections can
+ potentially arrive. This class allows invalidating past chromecast objects.
+ """
+
+ def __init__(
+ self,
+ prov: ChromecastProvider,
+ castplayer: CastPlayer,
+ mz_mgr: MultizoneManager,
+ mz_only=False,
+ ):
+ """Initialize the status listener."""
+ self.prov = prov
+ self.castplayer = castplayer
+ self._uuid = castplayer.cc.uuid
+ self._valid = True
+ self._mz_mgr = mz_mgr
+
+ if self.castplayer.cast_info.is_audio_group:
+ self._mz_mgr.add_multizone(castplayer.cc)
+ if mz_only:
+ return
+
+ castplayer.cc.register_status_listener(self)
+ castplayer.cc.socket_client.media_controller.register_status_listener(self)
+ castplayer.cc.register_connection_listener(self)
+ if not self.castplayer.cast_info.is_audio_group:
+ self._mz_mgr.register_listener(castplayer.cc.uuid, self)
+
+ def new_cast_status(self, status: CastStatus) -> None:
+ """Handle updated CastStatus."""
+ if self._valid:
+ self.prov.on_new_cast_status(self.castplayer, status)
+
+ def new_media_status(self, status: MediaStatus) -> None:
+ """Handle updated MediaStatus."""
+ if self._valid:
+ self.prov.on_new_media_status(self.castplayer, status)
+
+ def new_connection_status(self, status: ConnectionStatus) -> None:
+ """Handle updated ConnectionStatus."""
+ if self._valid:
+ self.prov.on_new_connection_status(self.castplayer, status)
+
+ @staticmethod
+ def added_to_multizone(group_uuid):
+ """Handle the cast added to a group."""
+ print("##### added_to_multizone: %s" % group_uuid)
+
+ def removed_from_multizone(self, group_uuid):
+ """Handle the cast removed from a group."""
+ if self._valid:
+ # self._cast_device.multizone_new_media_status(group_uuid, None)
+ print("##### removed_from_multizone: %s" % group_uuid)
+
+ def multizone_new_cast_status(self, group_uuid, cast_status): # noqa: ARG002
+ """Handle reception of a new CastStatus for a group."""
+ print("##### multizone_new_cast_status: %s" % group_uuid)
+
+ def multizone_new_media_status(self, group_uuid, media_status): # noqa: ARG002
+ """Handle reception of a new MediaStatus for a group."""
+ if self._valid:
+ # self._cast_device.multizone_new_media_status(group_uuid, media_status)
+ print("##### multizone_new_media_status: %s" % group_uuid)
+
+ def invalidate(self):
+ """
+ Invalidate this status listener.
+
+ All following callbacks won't be forwarded.
+ """
+ # pylint: disable=protected-access
+ if self.castplayer.cast_info.is_audio_group:
+ self._mz_mgr.remove_multizone(self._uuid)
+ else:
+ self._mz_mgr.deregister_listener(self._uuid, self)
+ self._valid = False
--- /dev/null
+{
+ "type": "player",
+ "domain": "chromecast",
+ "name": "Chromecast",
+ "description": "Support for Chromecast based players.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ ],
+ "requirements": ["PyChromecast==13.0.4"],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": false,
+ "load_by_default": true
+}
--- /dev/null
+"""DLNA/uPNP Player provider for Music Assistant.
+
+Most of this code is based on the implementation within Home Assistant:
+https://github.com/home-assistant/core/blob/dev/homeassistant/components/dlna_dmr
+
+All rights/credits reserved.
+"""
+from __future__ import annotations
+
+import asyncio
+import functools
+import logging
+import time
+from collections.abc import Awaitable, Callable, Coroutine, Sequence
+from dataclasses import dataclass, field
+from typing import Any, Concatenate, ParamSpec, TypeVar
+
+from async_upnp_client.aiohttp import AiohttpSessionRequester
+from async_upnp_client.client import UpnpRequester, UpnpService, UpnpStateVariable
+from async_upnp_client.client_factory import UpnpFactory
+from async_upnp_client.exceptions import UpnpError, UpnpResponseError
+from async_upnp_client.profiles.dlna import DmrDevice, TransportState
+from async_upnp_client.search import async_search
+from async_upnp_client.utils import CaseInsensitiveDict
+
+from music_assistant.common.models.enums import ContentType, PlayerFeature, PlayerState, PlayerType
+from music_assistant.common.models.errors import PlayerUnavailableError, QueueEmpty
+from music_assistant.common.models.player import DeviceInfo, Player
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.server.helpers.didl_lite import create_didl_metadata
+from music_assistant.server.models.player_provider import PlayerProvider
+
+from .helpers import DLNANotifyServer
+
+PLAYER_FEATURES = (
+ PlayerFeature.SET_MEMBERS,
+ PlayerFeature.SYNC,
+ PlayerFeature.VOLUME_MUTE,
+ PlayerFeature.VOLUME_SET,
+)
+PLAYER_CONFIG_ENTRIES = tuple() # we don't have any player config entries (for now)
+
+_DLNAPlayerProviderT = TypeVar("_DLNAPlayerProviderT", bound="DLNAPlayerProvider")
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
+
+def catch_request_errors(
+ func: Callable[Concatenate[_DLNAPlayerProviderT, _P], Awaitable[_R]]
+) -> Callable[Concatenate[_DLNAPlayerProviderT, _P], Coroutine[Any, Any, _R | None]]:
+ """Catch UpnpError errors."""
+
+ @functools.wraps(func)
+ async def wrapper(self: _DLNAPlayerProviderT, *args: _P.args, **kwargs: _P.kwargs) -> _R | None:
+ """Catch UpnpError errors and check availability before and after request."""
+ player_id = kwargs["player_id"] if "player_id" in kwargs else args[0]
+ dlna_player = self.dlnaplayers[player_id]
+ if not dlna_player.available:
+ self.logger.warning("Device disappeared when trying to call %s", func.__name__)
+ return None
+ try:
+ return await func(self, *args, **kwargs)
+ except UpnpError as err:
+ dlna_player.force_poll = True
+ self.logger.error("Error during call %s: %r", func.__name__, err)
+ return None
+
+ return wrapper
+
+
+@dataclass
+class DLNAPlayer:
+ """Class that holds all dlna variables for a player."""
+
+ udn: str # = player_id
+ player: Player # mass player
+ description_url: str # last known location (description.xml) url
+
+ device: DmrDevice | None = None
+ lock: asyncio.Lock = field(
+ default_factory=asyncio.Lock
+ ) # Held when connecting or disconnecting the device
+ force_poll: bool = False
+ ssdp_connect_failed: bool = False
+
+ # Track BOOTID in SSDP advertisements for device changes
+ bootid: int | None = None
+ last_seen: float = field(default_factory=time.time)
+ next_item: str | None = None
+ supports_next_uri = True
+ end_of_track_reached = False
+
+ def update_attributes(self):
+ """Update attributes of the MA Player from DLNA state."""
+ # generic attributes
+
+ if self.available:
+ self.player.available = True
+ self.player.name = self.device.name
+ self.player.volume_level = int((self.device.volume_level or 0) * 100)
+ self.player.volume_muted = self.device.is_volume_muted or False
+ self.player.state = self.get_state(self.device)
+ self.player.supported_features = self.get_supported_features(self.device)
+ self.player.current_url = self.device.current_track_uri or ""
+ self.player.elapsed_time = float(self.device.media_position or 0)
+ if self.device.media_position_updated_at is not None:
+ self.player.elapsed_time_last_updated = (
+ self.device.media_position_updated_at.timestamp()
+ )
+ self.player.current_item_id = self.device._get_current_track_meta_data("queue_item_id")
+ if self.device.media_duration and self.player.corrected_elapsed_time:
+ self.end_of_track_reached = (
+ self.device.media_duration - self.player.corrected_elapsed_time
+ ) < 15
+ else:
+ # device is unavailable
+ self.player.available = False
+
+ @property
+ def available(self) -> bool:
+ """Device is available when we have a connection to it."""
+ return self.device is not None and self.device.profile_device.available
+
+ @staticmethod
+ def get_state(device: DmrDevice) -> PlayerState:
+ """Return current PlayerState of the player."""
+ if device.transport_state is None:
+ return PlayerState.IDLE
+ if device.transport_state in (
+ TransportState.PLAYING,
+ TransportState.TRANSITIONING,
+ ):
+ return PlayerState.PLAYING
+ if device.transport_state in (
+ TransportState.PAUSED_PLAYBACK,
+ TransportState.PAUSED_RECORDING,
+ ):
+ return PlayerState.PAUSED
+ if device.transport_state == TransportState.VENDOR_DEFINED:
+ # Unable to map this state to anything reasonable, fallback to idle
+ return PlayerState.IDLE
+
+ return PlayerState.IDLE
+
+ @staticmethod
+ def get_supported_features(device: DmrDevice) -> set(PlayerFeature):
+ """Get player features that are supported at this moment.
+
+ Supported features may change as the device enters different states.
+ """
+ supported_features = set()
+
+ if device.has_volume_level:
+ supported_features.add(PlayerFeature.VOLUME_SET)
+ if device.has_volume_mute:
+ supported_features.add(PlayerFeature.VOLUME_MUTE)
+
+ if device.can_seek_rel_time or device.can_seek_abs_time:
+ supported_features.add(PlayerFeature.SEEK)
+
+ return supported_features
+
+
+class DLNAPlayerProvider(PlayerProvider):
+ """DLNA Player provider."""
+
+ dlnaplayers: dict[str, DLNAPlayer] | None = None
+ _discovery_running: bool = False
+
+ lock: asyncio.Lock
+ requester: UpnpRequester
+ upnp_factory: UpnpFactory
+ notify_server: DLNANotifyServer
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self.dlnaplayers = {}
+ self.lock = asyncio.Lock()
+ self.requester = AiohttpSessionRequester(self.mass.http_session, with_sleep=True)
+ # silence the async_upnp_client logger a bit
+ logging.getLogger("async_upnp_client").setLevel(logging.INFO)
+ logging.getLogger("charset_normalizer").setLevel(logging.INFO)
+
+ self.upnp_factory = UpnpFactory(self.requester, non_strict=True)
+ self.notify_server = DLNANotifyServer(self.requester, self.mass)
+ self.mass.create_task(self._run_discovery())
+
+ @catch_request_errors
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player."""
+ dlna_player = self.dlnaplayers[player_id]
+ dlna_player.end_of_track_reached = False
+ dlna_player.next_item = None
+ assert dlna_player.device is not None
+ await dlna_player.device.async_stop()
+
+ @catch_request_errors
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY command to given player."""
+ dlna_player = self.dlnaplayers[player_id]
+ assert dlna_player.device is not None
+ await dlna_player.device.async_play()
+
+ @catch_request_errors
+ async def cmd_play_media(
+ self,
+ player_id: str,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Send PLAY MEDIA command to given player."""
+ dlna_player = self.dlnaplayers[player_id]
+
+ # always clear queue (by sending stop) first
+ if dlna_player.device.can_stop:
+ await self.cmd_stop(player_id)
+
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=queue_item,
+ player_id=dlna_player.udn,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ content_type=ContentType.FLAC,
+ flow_mode=flow_mode,
+ )
+
+ didl_metadata = create_didl_metadata(url, queue_item, flow_mode)
+ await dlna_player.device.async_set_transport_uri(url, queue_item.name, didl_metadata)
+ # Play it
+ await dlna_player.device.async_wait_for_can_play(10)
+ await dlna_player.device.async_play()
+ # force poll the device
+ for sleep in (0, 1, 2):
+ await asyncio.sleep(sleep)
+ dlna_player.force_poll = True
+ await self.poll_player(dlna_player.udn)
+
+ @catch_request_errors
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player."""
+ dlna_player = self.dlnaplayers[player_id]
+ assert dlna_player.device is not None
+ if dlna_player.device.can_pause:
+ await dlna_player.device.async_pause()
+ else:
+ await dlna_player.device.async_stop()
+
+ @catch_request_errors
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player."""
+ dlna_player = self.dlnaplayers[player_id]
+ assert dlna_player.device is not None
+ await dlna_player.device.async_set_volume_level(volume_level / 100)
+
+ @catch_request_errors
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME MUTE command to given player."""
+ dlna_player = self.dlnaplayers[player_id]
+ assert dlna_player.device is not None
+ await dlna_player.device.async_mute_volume(muted)
+
+ async def poll_player(self, player_id: str) -> None:
+ """Poll player for state updates.
+
+ This is called by the Player Manager;
+ - every 360 seconds if the player if not powered
+ - every 30 seconds if the player is powered
+ - every 10 seconds if the player is playing
+
+ Use this method to request any info that is not automatically updated and/or
+ to detect if the player is still alive.
+ If this method raises the PlayerUnavailable exception,
+ the player is marked as unavailable until
+ the next successful poll or event where it becomes available again.
+ If the player does not need any polling, simply do not override this method.
+ """
+ dlna_player = self.dlnaplayers[player_id]
+
+ # try to reconnect the device if the connection was lost
+ if not dlna_player.device:
+ if not dlna_player.force_poll:
+ return
+ try:
+ await self._device_connect(dlna_player)
+ except UpnpError as err:
+ raise PlayerUnavailableError from err
+
+ assert dlna_player.device is not None
+
+ try:
+ now = time.time()
+ do_ping = dlna_player.force_poll or (now - dlna_player.last_seen) > 60
+ await dlna_player.device.async_update(do_ping=do_ping)
+ dlna_player.last_seen = now if do_ping else dlna_player.last_seen
+ except UpnpError as err:
+ self.logger.debug("Device unavailable: %r", err)
+ await self._device_disconnect(dlna_player)
+ raise PlayerUnavailableError from err
+ finally:
+ dlna_player.force_poll = False
+
+ async def _run_discovery(self) -> None:
+ """Discover DLNA players on the network."""
+ if self._discovery_running:
+ return
+ try:
+ self._discovery_running = True
+ self.logger.debug("DLNA discovery started...")
+ discovered_devices: set[str] = set()
+
+ async def on_response(discovery_info: CaseInsensitiveDict):
+ """Process discovered device from ssdp search."""
+ ssdp_st: str = discovery_info.get("st", discovery_info.get("nt"))
+ if not ssdp_st:
+ return
+
+ if "MediaRenderer" not in ssdp_st:
+ # we're only interested in MediaRenderer devices
+ return
+
+ ssdp_usn: str = discovery_info["usn"]
+ ssdp_udn: str | None = discovery_info.get("_udn")
+ if not ssdp_udn and ssdp_usn.startswith("uuid:"):
+ ssdp_udn = ssdp_usn.split("::")[0]
+
+ if ssdp_udn in discovered_devices:
+ # already processed this device
+ return
+
+ discovered_devices.add(ssdp_udn)
+
+ await self._device_discovered(ssdp_udn, discovery_info["location"])
+
+ await async_search(on_response, 30)
+
+ finally:
+ self._discovery_running = False
+
+ def reschedule():
+ self.mass.create_task(self._run_discovery())
+
+ # reschedule self once finished
+ self.mass.loop.call_later(300, reschedule)
+
+ async def _device_disconnect(self, dlna_player: DLNAPlayer) -> None:
+ """
+ Destroy connections to the device now that it's not available.
+
+ Also call when removing this entity from MA to clean up connections.
+ """
+ async with dlna_player.lock:
+ if not dlna_player.device:
+ self.logger.debug("Disconnecting from device that's not connected")
+ return
+
+ self.logger.debug("Disconnecting from %s", dlna_player.device.name)
+
+ dlna_player.device.on_event = None
+ old_device = dlna_player.device
+ dlna_player.device = None
+ await old_device.async_unsubscribe_services()
+
+ await self._async_release_event_notifier(dlna_player.event_addr)
+
+ async def _device_discovered(self, udn: str, description_url: str) -> None:
+ """Handle discovered DLNA player."""
+ if dlna_player := self.dlnaplayers.get(udn):
+ # existing player
+ if dlna_player.description_url == description_url and dlna_player.player.available:
+ # nothing to do, device is already connected
+ return
+ # update description url to newly discovered one
+ dlna_player.description_url = description_url
+ else:
+ # new player detected, setup our DLNAPlayer wrapper
+ dlna_player = DLNAPlayer(
+ udn=udn,
+ player=Player(
+ player_id=udn,
+ provider=self.domain,
+ type=PlayerType.PLAYER,
+ name=udn,
+ available=False,
+ powered=False,
+ supported_features=PLAYER_FEATURES,
+ # device info will be discovered later after connect
+ device_info=DeviceInfo(
+ model="unknown",
+ address=description_url,
+ manufacturer="unknown",
+ ),
+ ),
+ description_url=description_url,
+ )
+ self.dlnaplayers[udn] = dlna_player
+
+ await self._device_connect(dlna_player)
+ dlna_player.update_attributes()
+ self.mass.players.register_or_update(dlna_player.player)
+
+ async def _device_connect(self, dlna_player: DLNAPlayer) -> None:
+ """Connect DLNA/DMR Device."""
+ self.logger.debug("Connecting to device at %s", dlna_player.description_url)
+
+ async with dlna_player.lock:
+ if dlna_player.device:
+ self.logger.debug("Trying to connect when device already connected")
+ return
+
+ # Connect to the base UPNP device
+ upnp_device = await self.upnp_factory.async_create_device(dlna_player.description_url)
+
+ # Create profile wrapper
+ dlna_player.device = DmrDevice(upnp_device, self.notify_server.event_handler)
+
+ # Subscribe to event notifications
+ try:
+ dlna_player.device.on_event = self._handle_event
+ await dlna_player.device.async_subscribe_services(auto_resubscribe=True)
+ except UpnpResponseError as err:
+ # Device rejected subscription request. This is OK, variables
+ # will be polled instead.
+ self.logger.debug("Device rejected subscription: %r", err)
+ except UpnpError as err:
+ # Don't leave the device half-constructed
+ dlna_player.device.on_event = None
+ dlna_player.device = None
+ self.logger.debug("Error while subscribing during device connect: %r", err)
+ raise
+ else:
+ # connect was successful, update device info
+ dlna_player.player.device_info = DeviceInfo(
+ model=dlna_player.device.model_name,
+ address=dlna_player.device.device.presentation_url
+ or dlna_player.description_url,
+ manufacturer=dlna_player.device.manufacturer,
+ )
+
+ def _handle_event(
+ self,
+ service: UpnpService,
+ state_variables: Sequence[UpnpStateVariable],
+ ) -> None:
+ """Handle state variable(s) changed event from DLNA device."""
+ udn = service.device.udn
+
+ dlna_player = self.dlnaplayers[udn]
+ self.logger.debug(
+ "Received event for Player %s: %s",
+ dlna_player.player.display_name,
+ service,
+ )
+
+ if not state_variables:
+ # Indicates a failure to resubscribe, check if device is still available
+ dlna_player.force_poll = True
+ return
+
+ if service.service_id == "urn:upnp-org:serviceId:AVTransport":
+ for state_variable in state_variables:
+ # Force a state refresh when player begins or pauses playback
+ # to update the position info.
+ if state_variable.name == "TransportState" and state_variable.value in (
+ TransportState.PLAYING,
+ TransportState.PAUSED_PLAYBACK,
+ ):
+ dlna_player.force_poll = True
+ self.mass.create_task(self.poll_player(dlna_player.udn))
+
+ dlna_player.last_seen = time.time()
+ self.mass.create_task(self._update_player(dlna_player))
+
+ async def _enqueue_next_track(
+ self, dlna_player: DLNAPlayer, current_queue_item_id: str
+ ) -> None:
+ """Enqueue the next track of the MA queue on the CC queue."""
+ if not current_queue_item_id:
+ return # guard
+ if not self.mass.players.queues.get_item(dlna_player.udn, current_queue_item_id):
+ return # guard
+ try:
+ next_item, crossfade = self.mass.players.queues.player_ready_for_next_track(
+ dlna_player.udn, current_queue_item_id
+ )
+ except QueueEmpty:
+ return
+
+ if dlna_player.next_item == next_item.queue_item_id:
+ return # already set ?!
+ dlna_player.next_item = next_item.queue_item_id
+
+ # no need to try setting the next url if we already know the player does not support it
+ if not dlna_player.supports_next_uri:
+ return
+
+ # send queue item to dlna queue
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=next_item,
+ player_id=dlna_player.udn,
+ content_type=ContentType.FLAC,
+ # DLNA pre-caches pretty aggressively so do not yet start the runner
+ auto_start_runner=False,
+ )
+ didl_metadata = create_didl_metadata(url, next_item)
+ try:
+ await dlna_player.device.async_set_next_transport_uri(
+ url, next_item.name, didl_metadata
+ )
+ except UpnpError:
+ dlna_player.supports_next_uri = False
+ self.logger.info("Player does not support next uri")
+
+ self.logger.debug(
+ "Enqued next track (%s) to player %s",
+ next_item.name,
+ dlna_player.player.display_name,
+ )
+
+ async def _update_player(self, dlna_player: DLNAPlayer) -> None:
+ """Update DLNA Player."""
+ prev_item_id = dlna_player.player.current_item_id
+ prev_url = dlna_player.player.current_url
+ prev_state = dlna_player.player.state
+ dlna_player.update_attributes()
+ current_item_id = dlna_player.player.current_item_id
+ current_url = dlna_player.player.current_url
+ current_state = dlna_player.player.state
+
+ if (prev_url != current_url) or (prev_state != current_state):
+ # fetch track details on state or url change
+ dlna_player.force_poll = True
+
+ # let the MA player manager work out if something actually updated
+ self.mass.players.update(dlna_player.udn)
+
+ # enqueue next item if needed
+ if dlna_player.player.state == PlayerState.PLAYING and (
+ prev_item_id != current_item_id
+ or not dlna_player.next_item
+ or dlna_player.next_item == current_item_id
+ ):
+ self.mass.create_task(self._enqueue_next_track(dlna_player, current_item_id))
+ # if player does not support next uri, manual play it
+ if (
+ not dlna_player.supports_next_uri
+ and prev_state == PlayerState.PLAYING
+ and current_state == PlayerState.IDLE
+ and dlna_player.next_item
+ and dlna_player.end_of_track_reached
+ ):
+ await self.mass.players.queues.play_index(dlna_player.udn, dlna_player.next_item)
+ dlna_player.end_of_track_reached = False
+ dlna_player.next_item = None
--- /dev/null
+"""Various helpers and utils for the DLNA Player Provider."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from aiohttp.web import Request, Response
+from async_upnp_client.event_handler import UpnpEventHandler, UpnpNotifyServer
+
+if TYPE_CHECKING:
+ from async_upnp_client.client import UpnpRequester
+
+ from music_assistant.server import MusicAssistant
+
+
+class DLNANotifyServer(UpnpNotifyServer):
+ """Notify server for async_upnp_client which uses the MA webserver."""
+
+ def __init__(
+ self,
+ requester: UpnpRequester,
+ mass: MusicAssistant,
+ ) -> None:
+ """Initialize."""
+ self.mass = mass
+ self.event_handler = UpnpEventHandler(self, requester)
+ self.mass.webapp.router.add_route("NOTIFY", "/notify", self._handle_request)
+
+ async def _handle_request(self, request: Request) -> Response:
+ """Handle incoming requests."""
+ headers = request.headers
+ body = await request.text()
+
+ if request.method != "NOTIFY":
+ return Response(status=405)
+
+ status = await self.event_handler.handle_notify(headers, body)
+
+ return Response(status=status)
+
+ @property
+ def callback_url(self) -> str:
+ """Return callback URL on which we are callable."""
+ return f"{self.mass.base_url}/notify"
--- /dev/null
+{
+ "type": "player",
+ "domain": "dlna",
+ "name": "UPnP/DLNA Player provider",
+ "description": "Support for players that are compatible with the UPnP/DLNA (DMR) standard.",
+ "codeowners": ["@music-assistant"],
+ "config_entries": [
+ ],
+ "requirements": ["async-upnp-client==0.33.1", "getmac==0.8.2"],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": false,
+ "load_by_default": true
+}
--- /dev/null
+"""Fanart.tv Metadata provider for Music Assistant."""
+from __future__ import annotations
+
+from json import JSONDecodeError
+from typing import TYPE_CHECKING
+
+import aiohttp.client_exceptions
+from asyncio_throttle import Throttler
+
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.media_items import ImageType, MediaItemImage, MediaItemMetadata
+from music_assistant.server.controllers.cache import use_cache
+from music_assistant.server.helpers.app_vars import app_var # pylint: disable=no-name-in-module
+from music_assistant.server.models.metadata_provider import MetadataProvider
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.media_items import Album, Artist
+
+# TODO: add support for personal api keys ?
+
+
+IMG_MAPPING = {
+ "artistthumb": ImageType.THUMB,
+ "hdmusiclogo": ImageType.LOGO,
+ "musicbanner": ImageType.BANNER,
+ "artistbackground": ImageType.FANART,
+}
+
+
+class FanartTvMetadataProvider(MetadataProvider):
+ """Fanart.tv Metadata provider."""
+
+ throttler: Throttler
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self.cache = self.mass.cache
+ self.throttler = Throttler(rate_limit=2, period=1)
+ self._attr_supported_features = (
+ ProviderFeature.ARTIST_METADATA,
+ ProviderFeature.ALBUM_METADATA,
+ )
+
+ async def get_artist_metadata(self, artist: Artist) -> MediaItemMetadata | None:
+ """Retrieve metadata for artist on fanart.tv."""
+ if not artist.musicbrainz_id:
+ return None
+ self.logger.debug("Fetching metadata for Artist %s on Fanart.tv", artist.name)
+ if data := await self._get_data(f"music/{artist.musicbrainz_id}"):
+ metadata = MediaItemMetadata()
+ metadata.images = []
+ for key, img_type in IMG_MAPPING.items():
+ items = data.get(key)
+ if not items:
+ continue
+ for item in items:
+ metadata.images.append(MediaItemImage(img_type, item["url"]))
+ return metadata
+ return None
+
+ async def get_album_metadata(self, album: Album) -> MediaItemMetadata | None:
+ """Retrieve metadata for album on fanart.tv."""
+ if not album.musicbrainz_id:
+ return None
+ self.logger.debug("Fetching metadata for Album %s on Fanart.tv", album.name)
+ if data := await self._get_data(f"music/albums/{album.musicbrainz_id}"): # noqa: SIM102
+ if data and data.get("albums"):
+ data = data["albums"][album.musicbrainz_id]
+ metadata = MediaItemMetadata()
+ metadata.images = []
+ for key, img_type in IMG_MAPPING.items():
+ items = data.get(key)
+ if not items:
+ continue
+ for item in items:
+ metadata.images.append(MediaItemImage(img_type, item["url"]))
+ return metadata
+ return None
+
+ @use_cache(86400 * 14)
+ async def _get_data(self, endpoint, **kwargs) -> dict | None:
+ """Get data from api."""
+ url = f"http://webservice.fanart.tv/v3/{endpoint}"
+ kwargs["api_key"] = app_var(4)
+ async with self.throttler:
+ async with self.mass.http_session.get(url, params=kwargs, verify_ssl=False) as response:
+ try:
+ result = await response.json()
+ except (
+ aiohttp.client_exceptions.ContentTypeError,
+ JSONDecodeError,
+ ):
+ self.logger.error("Failed to retrieve %s", endpoint)
+ text_result = await response.text()
+ self.logger.debug(text_result)
+ return None
+ except (
+ aiohttp.client_exceptions.ClientConnectorError,
+ aiohttp.client_exceptions.ServerDisconnectedError,
+ ):
+ self.logger.warning("Failed to retrieve %s", endpoint)
+ return None
+ if "error" in result and "limit" in result["error"]:
+ self.logger.warning(result["error"])
+ return None
+ return result
--- /dev/null
+{
+ "type": "metadata",
+ "domain": "fanarttv",
+ "name": "fanart.tv Metadata provider",
+ "description": "fanart.tv is a community database of artwork for movies, tv series and music.",
+ "codeowners": ["@music-assistant"],
+ "config_entries": [
+ ],
+ "requirements": [],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""Filesystem musicprovider support for MusicAssistant."""
+from __future__ import annotations
+
+import asyncio
+import os
+import os.path
+from collections.abc import AsyncGenerator
+
+import aiofiles
+from aiofiles.os import wrap
+
+from music_assistant.common.models.errors import SetupFailedError
+from music_assistant.constants import CONF_PATH
+
+from .base import FileSystemItem, FileSystemProviderBase
+from .helpers import get_absolute_path, get_relative_path
+
+listdir = wrap(os.listdir)
+isdir = wrap(os.path.isdir)
+isfile = wrap(os.path.isfile)
+exists = wrap(os.path.exists)
+
+
+async def create_item(base_path: str, entry: os.DirEntry) -> FileSystemItem:
+ """Create FileSystemItem from os.DirEntry."""
+
+ def _create_item():
+ absolute_path = get_absolute_path(base_path, entry.path)
+ stat = entry.stat(follow_symlinks=False)
+ return FileSystemItem(
+ name=entry.name,
+ path=get_relative_path(base_path, entry.path),
+ absolute_path=absolute_path,
+ is_file=entry.is_file(follow_symlinks=False),
+ is_dir=entry.is_dir(follow_symlinks=False),
+ checksum=str(int(stat.st_mtime)),
+ file_size=stat.st_size,
+ # local filesystem is always local resolvable
+ local_path=absolute_path,
+ )
+
+ # run in thread because strictly taken this may be blocking IO
+ return await asyncio.to_thread(_create_item)
+
+
+class LocalFileSystemProvider(FileSystemProviderBase):
+ """Implementation of a musicprovider for local files."""
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ conf_path = self.config.get_value(CONF_PATH)
+ if not await isdir(conf_path):
+ raise SetupFailedError(f"Music Directory {conf_path} does not exist")
+
+ async def listdir(
+ self, path: str, recursive: bool = False
+ ) -> AsyncGenerator[FileSystemItem, None]:
+ """List contents of a given provider directory/path.
+
+ Parameters
+ ----------
+ - path: path of the directory (relative or absolute) to list contents of.
+ Empty string for provider's root.
+ - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent).
+
+ Returns:
+ -------
+ AsyncGenerator yielding FileSystemItem objects.
+
+ """
+ abs_path = get_absolute_path(self.config.get_value(CONF_PATH), path)
+ for entry in await asyncio.to_thread(os.scandir, abs_path):
+ if entry.name.startswith("."):
+ # skip invalid/system files and dirs
+ continue
+ item = await create_item(self.config.get_value(CONF_PATH), entry)
+ if recursive and item.is_dir:
+ try:
+ async for subitem in self.listdir(item.absolute_path, True):
+ yield subitem
+ except (OSError, PermissionError) as err:
+ self.logger.warning("Skip folder %s: %s", item.path, str(err))
+ else:
+ yield item
+
+ async def resolve(
+ self, file_path: str, require_local: bool = False # noqa: ARG002
+ ) -> FileSystemItem:
+ """Resolve (absolute or relative) path to FileSystemItem.
+
+ If require_local is True, we prefer to have the `local_path` attribute filled
+ (e.g. with a tempfile), if supported by the provider/item.
+ """
+ absolute_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+
+ def _create_item():
+ stat = os.stat(absolute_path, follow_symlinks=False)
+ return FileSystemItem(
+ name=os.path.basename(file_path),
+ path=get_relative_path(self.config.get_value(CONF_PATH), file_path),
+ absolute_path=absolute_path,
+ is_dir=os.path.isdir(absolute_path),
+ is_file=os.path.isfile(absolute_path),
+ checksum=str(int(stat.st_mtime)),
+ file_size=stat.st_size,
+ # local filesystem is always local resolvable
+ local_path=absolute_path,
+ )
+
+ # run in thread because strictly taken this may be blocking IO
+ return await asyncio.to_thread(_create_item)
+
+ async def exists(self, file_path: str) -> bool:
+ """Return bool is this FileSystem musicprovider has given file/dir."""
+ if not file_path:
+ return False # guard
+ abs_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+ return await exists(abs_path)
+
+ async def read_file_content(self, file_path: str, seek: int = 0) -> AsyncGenerator[bytes, None]:
+ """Yield (binary) contents of file in chunks of bytes."""
+ abs_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+ chunk_size = 512000
+ async with aiofiles.open(abs_path, "rb") as _file:
+ if seek:
+ await _file.seek(seek)
+ # yield chunks of data from file
+ while True:
+ data = await _file.read(chunk_size)
+ if not data:
+ break
+ yield data
+
+ async def write_file_content(self, file_path: str, data: bytes) -> None:
+ """Write entire file content as bytes (e.g. for playlists)."""
+ abs_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+ async with aiofiles.open(abs_path, "wb") as _file:
+ await _file.write(data)
--- /dev/null
+"""Filesystem musicprovider support for MusicAssistant."""
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import os
+from abc import abstractmethod
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from time import time
+
+import xmltodict
+
+from music_assistant.common.helpers.util import parse_title_and_version
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.errors import MediaNotFoundError, MusicAssistantError
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ BrowseFolder,
+ ContentType,
+ ImageType,
+ MediaItemImage,
+ MediaItemType,
+ MediaType,
+ Playlist,
+ ProviderMapping,
+ Radio,
+ StreamDetails,
+ Track,
+)
+from music_assistant.constants import SCHEMA_VERSION, VARIOUS_ARTISTS, VARIOUS_ARTISTS_ID
+from music_assistant.server.helpers.compare import compare_strings
+from music_assistant.server.helpers.playlists import parse_m3u, parse_pls
+from music_assistant.server.helpers.tags import parse_tags, split_items
+from music_assistant.server.models.music_provider import MusicProvider
+
+from .helpers import get_parentdir
+
+TRACK_EXTENSIONS = ("mp3", "m4a", "mp4", "flac", "wav", "ogg", "aiff", "wma", "dsf")
+PLAYLIST_EXTENSIONS = ("m3u", "pls")
+SUPPORTED_EXTENSIONS = TRACK_EXTENSIONS + PLAYLIST_EXTENSIONS
+IMAGE_EXTENSIONS = ("jpg", "jpeg", "JPG", "JPEG", "png", "PNG", "gif", "GIF")
+
+
+@dataclass
+class FileSystemItem:
+ """Representation of an item (file or directory) on the filesystem.
+
+ - name: Name (not path) of the file (or directory).
+ - path: Relative path to the item on this filesystem provider.
+ - absolute_path: Absolute (provider dependent) path to this item.
+ - is_file: Boolean if item is file (not directory or symlink).
+ - is_dir: Boolean if item is directory (not file).
+ - checksum: Checksum for this path (usually last modified time).
+ - file_size : File size in number of bytes or None if unknown (or not a file).
+ - local_path: Optional local accessible path to this (file)item, supported by ffmpeg.
+ """
+
+ name: str
+ path: str
+ absolute_path: str
+ is_file: bool
+ is_dir: bool
+ checksum: str
+ file_size: int | None = None
+ local_path: str | None = None
+
+ @property
+ def ext(self) -> str | None:
+ """Return file extension."""
+ try:
+ return self.name.rsplit(".", 1)[1]
+ except IndexError:
+ return None
+
+
+class FileSystemProviderBase(MusicProvider):
+ """Base Implementation of a musicprovider for files.
+
+ Reads ID3 tags from file and falls back to parsing filename.
+ Optionally reads metadata from nfo files and images in folder structure <artist>/<album>.
+ Supports m3u files only for playlists.
+ Supports having URI's from streaming providers within m3u playlist.
+ """
+
+ _attr_supported_features = (
+ ProviderFeature.LIBRARY_ARTISTS,
+ ProviderFeature.LIBRARY_ALBUMS,
+ ProviderFeature.LIBRARY_TRACKS,
+ ProviderFeature.LIBRARY_PLAYLISTS,
+ ProviderFeature.PLAYLIST_TRACKS_EDIT,
+ ProviderFeature.PLAYLIST_CREATE,
+ ProviderFeature.BROWSE,
+ ProviderFeature.SEARCH,
+ )
+
+ @abstractmethod
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+
+ @abstractmethod
+ async def listdir(
+ self, path: str, recursive: bool = False
+ ) -> AsyncGenerator[FileSystemItem, None]:
+ """List contents of a given provider directory/path.
+
+ Parameters
+ ----------
+ - path: path of the directory (relative or absolute) to list contents of.
+ Empty string for provider's root.
+ - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent).
+
+ Returns:
+ -------
+ AsyncGenerator yielding FileSystemItem objects.
+
+ """
+
+ @abstractmethod
+ async def resolve(self, file_path: str) -> FileSystemItem:
+ """Resolve (absolute or relative) path to FileSystemItem."""
+
+ @abstractmethod
+ async def exists(self, file_path: str) -> bool:
+ """Return bool is this FileSystem musicprovider has given file/dir."""
+
+ @abstractmethod
+ async def read_file_content(self, file_path: str, seek: int = 0) -> AsyncGenerator[bytes, None]:
+ """Yield (binary) contents of file in chunks of bytes."""
+
+ @abstractmethod
+ async def write_file_content(self, file_path: str, data: bytes) -> None:
+ """Write entire file content as bytes (e.g. for playlists)."""
+
+ ##############################################
+ # DEFAULT/GENERIC IMPLEMENTATION BELOW
+ # should normally not be needed to override
+
+ async def search(
+ self, search_query: str, media_types=list[MediaType] | None, limit: int = 5 # noqa: ARG002
+ ) -> list[MediaItemType]:
+ """Perform search on this file based musicprovider."""
+ result: list[MediaItemType] = []
+ # searching the filesystem is slow and unreliable,
+ # instead we make some (slow) freaking queries to the db ;-)
+ params = {
+ "name": f"%{search_query}%",
+ "provider_instance": f"%{self.instance_id}%",
+ }
+ # ruff: noqa: E501
+ if media_types is None or MediaType.TRACK in media_types:
+ query = "SELECT * FROM tracks WHERE name LIKE :name AND provider_mappings LIKE :provider_instance"
+ tracks = await self.mass.music.tracks.get_db_items_by_query(query, params)
+ result += tracks
+ if media_types is None or MediaType.ALBUM in media_types:
+ query = "SELECT * FROM albums WHERE name LIKE :name AND provider_mappings LIKE :provider_instance"
+ albums = await self.mass.music.albums.get_db_items_by_query(query, params)
+ result += albums
+ if media_types is None or MediaType.ARTIST in media_types:
+ query = "SELECT * FROM artists WHERE name LIKE :name AND provider_mappings LIKE :provider_instance"
+ artists = await self.mass.music.artists.get_db_items_by_query(query, params)
+ result += artists
+ if media_types is None or MediaType.PLAYLIST in media_types:
+ query = "SELECT * FROM playlists WHERE name LIKE :name AND provider_mappings LIKE :provider_instance"
+ playlists = await self.mass.music.playlists.get_db_items_by_query(query, params)
+ result += playlists
+ return result
+
+ async def browse(self, path: str) -> BrowseFolder:
+ """Browse this provider's items.
+
+ :param path: The path to browse, (e.g. provid://artists).
+ """
+ _, item_path = path.split("://")
+ if not item_path:
+ item_path = ""
+ subitems = []
+ async for item in self.listdir(item_path, recursive=False):
+ if item.is_dir:
+ subitems.append(
+ BrowseFolder(
+ item_id=item.path,
+ provider=self.domain,
+ path=f"{self.instance_id}://{item.path}",
+ name=item.name,
+ )
+ )
+ continue
+
+ if "." not in item.name or not item.ext:
+ # skip system files and files without extension
+ continue
+
+ if item.ext in TRACK_EXTENSIONS:
+ if db_item := await self.mass.music.tracks.get_db_item_by_prov_id(
+ item.path, provider_instance=self.instance_id
+ ):
+ subitems.append(db_item)
+ elif track := await self.get_track(item.path):
+ # make sure that the item exists
+ # https://github.com/music-assistant/hass-music-assistant/issues/707
+ db_item = await self.mass.music.tracks.add_db_item(track)
+ subitems.append(db_item)
+ continue
+ if item.ext in PLAYLIST_EXTENSIONS:
+ if db_item := await self.mass.music.playlists.get_db_item_by_prov_id(
+ item.path, provider_instance=self.instance_id
+ ):
+ subitems.append(db_item)
+ elif playlist := await self.get_playlist(item.path):
+ # make sure that the item exists
+ # https://github.com/music-assistant/hass-music-assistant/issues/707
+ db_item = await self.mass.music.playlists.add_db_item(playlist)
+ subitems.append(db_item)
+ continue
+
+ return BrowseFolder(
+ item_id=item_path,
+ provider=self.domain,
+ path=path,
+ name=item_path or self.name,
+ # make sure to sort the resulting listing
+ items=sorted(subitems, key=lambda x: (x.name.casefold(), x.name)),
+ )
+
+ async def sync_library(
+ self, media_types: tuple[MediaType, ...] | None = None # noqa: ARG002
+ ) -> None:
+ """Run library sync for this provider."""
+ cache_key = f"{self.instance_id}.checksums"
+ prev_checksums = await self.mass.cache.get(cache_key, SCHEMA_VERSION)
+ save_checksum_interval = 0
+ if prev_checksums is None:
+ prev_checksums = {}
+
+ # find all music files in the music directory and all subfolders
+ # we work bottom up, as-in we derive all info from the tracks
+ cur_checksums = {}
+ async for item in self.listdir("", recursive=True):
+ if "." not in item.name or not item.ext:
+ # skip system files and files without extension
+ continue
+
+ if item.ext not in SUPPORTED_EXTENSIONS:
+ # unsupported file extension
+ continue
+
+ try:
+ cur_checksums[item.path] = item.checksum
+ if item.checksum == prev_checksums.get(item.path):
+ continue
+
+ if item.ext in TRACK_EXTENSIONS:
+ # add/update track to db
+ track = await self.get_track(item.path)
+ # if the track was edited on disk, always overwrite existing db details
+ overwrite_existing = item.path in prev_checksums
+ await self.mass.music.tracks.add_db_item(
+ track, overwrite_existing=overwrite_existing
+ )
+ elif item.ext in PLAYLIST_EXTENSIONS:
+ playlist = await self.get_playlist(item.path)
+ # add/update] playlist to db
+ playlist.metadata.checksum = item.checksum
+ # playlist is always in-library
+ playlist.in_library = True
+ await self.mass.music.playlists.add_db_item(playlist)
+ except Exception as err: # pylint: disable=broad-except
+ # we don't want the whole sync to crash on one file so we catch all exceptions here
+ self.logger.exception("Error processing %s - %s", item.path, str(err))
+
+ # save checksums every 100 processed items
+ # this allows us to pickup where we leftoff when initial scan gets interrupted
+ if save_checksum_interval == 100:
+ await self.mass.cache.set(cache_key, cur_checksums, SCHEMA_VERSION)
+ save_checksum_interval = 0
+ else:
+ save_checksum_interval += 1
+
+ # store (final) checksums in cache
+ await self.mass.cache.set(cache_key, cur_checksums, SCHEMA_VERSION)
+ # work out deletions
+ deleted_files = set(prev_checksums.keys()) - set(cur_checksums.keys())
+ await self._process_deletions(deleted_files)
+
+ async def _process_deletions(self, deleted_files: set[str]) -> None:
+ """Process all deletions."""
+ # process deleted tracks/playlists
+ for file_path in deleted_files:
+ _, ext = file_path.rsplit(".", 1)
+ if ext not in SUPPORTED_EXTENSIONS:
+ # unsupported file extension
+ continue
+
+ if ext in PLAYLIST_EXTENSIONS:
+ controller = self.mass.music.get_controller(MediaType.PLAYLIST)
+ else:
+ controller = self.mass.music.get_controller(MediaType.TRACK)
+
+ if db_item := await controller.get_db_item_by_prov_id(
+ file_path, provider_instance=self.instance_id
+ ):
+ await controller.remove_prov_mapping(db_item.item_id, self.instance_id)
+
+ async def get_artist(self, prov_artist_id: str) -> Artist:
+ """Get full artist details by id."""
+ db_artist = await self.mass.music.artists.get_db_item_by_prov_id(
+ item_id=prov_artist_id, provider_instance=self.instance_id
+ )
+ if db_artist is None:
+ raise MediaNotFoundError(f"Artist not found: {prov_artist_id}")
+ if await self.exists(prov_artist_id):
+ # if path exists on disk allow parsing full details to allow refresh of metadata
+ return await self._parse_artist(db_artist.name, artist_path=prov_artist_id)
+ return db_artist
+
+ async def get_album(self, prov_album_id: str) -> Album:
+ """Get full album details by id."""
+ db_album = await self.mass.music.albums.get_db_item_by_prov_id(
+ item_id=prov_album_id, provider_instance=self.instance_id
+ )
+ if db_album is None:
+ raise MediaNotFoundError(f"Album not found: {prov_album_id}")
+ if await self.exists(prov_album_id):
+ # if path exists on disk allow parsing full details to allow refresh of metadata
+ return await self._parse_album(db_album.name, prov_album_id, db_album.artists)
+ return db_album
+
+ async def get_track(self, prov_track_id: str) -> Track:
+ """Get full track details by id."""
+ # ruff: noqa: PLR0915, PLR0912
+ if not await self.exists(prov_track_id):
+ raise MediaNotFoundError(f"Track path does not exist: {prov_track_id}")
+
+ file_item = await self.resolve(prov_track_id)
+
+ # parse tags
+ input_file = file_item.local_path or self.read_file_content(file_item.absolute_path)
+ tags = await parse_tags(input_file)
+
+ name, version = parse_title_and_version(tags.title)
+ track = Track(
+ item_id=file_item.path,
+ provider=self.domain,
+ name=name,
+ version=version,
+ )
+
+ # album
+ if tags.album:
+ # work out if we have an album folder
+ album_dir = get_parentdir(file_item.path, tags.album)
+
+ # album artist(s)
+ if tags.album_artists:
+ album_artists = []
+ for index, album_artist_str in enumerate(tags.album_artists):
+ # work out if we have an artist folder
+ artist_dir = get_parentdir(file_item.path, album_artist_str)
+ artist = await self._parse_artist(album_artist_str, artist_path=artist_dir)
+ if not artist.musicbrainz_id:
+ with contextlib.suppress(IndexError):
+ artist.musicbrainz_id = tags.musicbrainz_albumartistids[index]
+ album_artists.append(artist)
+ else:
+ # always fallback to various artists as album artist if user did not tag album artist
+ # ID3 tag properly because we must have an album artist
+ self.logger.warning(
+ "%s is missing ID3 tag [albumartist], using %s as fallback",
+ file_item.path,
+ VARIOUS_ARTISTS,
+ )
+ album_artists = [await self._parse_artist(name=VARIOUS_ARTISTS)]
+
+ track.album = await self._parse_album(
+ tags.album,
+ album_dir,
+ artists=album_artists,
+ )
+ else:
+ self.logger.warning("%s is missing ID3 tag [album]", file_item.path)
+
+ # track artist(s)
+ for index, track_artist_str in enumerate(tags.artists):
+ # re-use album artist details if possible
+ if track.album and (
+ artist := next((x for x in track.album.artists if x.name == track_artist_str), None)
+ ):
+ track.artists.append(artist)
+ continue
+ artist = await self._parse_artist(track_artist_str)
+ if not artist.musicbrainz_id:
+ with contextlib.suppress(IndexError):
+ artist.musicbrainz_id = tags.musicbrainz_artistids[index]
+ track.artists.append(artist)
+
+ # cover image - prefer album image, fallback to embedded
+ if track.album and track.album.image:
+ track.metadata.images = [track.album.image]
+ elif tags.has_cover_image:
+ # we do not actually embed the image in the metadata because that would consume too
+ # much space and bandwidth. Instead we set the filename as value so the image can
+ # be retrieved later in realtime.
+ track.metadata.images = [MediaItemImage(ImageType.THUMB, file_item.path, True)]
+ if track.album:
+ # set embedded cover on album
+ track.album.metadata.images = track.metadata.images
+
+ # parse other info
+ track.duration = tags.duration or 0
+ track.metadata.genres = tags.genres
+ track.disc_number = tags.disc
+ track.track_number = tags.track
+ track.isrc = tags.get("isrc")
+ track.metadata.copyright = tags.get("copyright")
+ track.metadata.lyrics = tags.get("lyrics")
+ track.musicbrainz_id = tags.musicbrainz_trackid
+ if track.album:
+ if not track.album.musicbrainz_id:
+ track.album.musicbrainz_id = tags.musicbrainz_releasegroupid
+ if not track.album.year:
+ track.album.year = tags.year
+ if not track.album.upc:
+ track.album.upc = tags.get("barcode")
+ # try to parse albumtype
+ if track.album and track.album.album_type == AlbumType.UNKNOWN:
+ album_type = tags.album_type
+ if album_type and "compilation" in album_type:
+ track.album.album_type = AlbumType.COMPILATION
+ elif album_type and "single" in album_type:
+ track.album.album_type = AlbumType.SINGLE
+ elif album_type and "album" in album_type:
+ track.album.album_type = AlbumType.ALBUM
+ elif track.album.sort_name in track.sort_name:
+ track.album.album_type = AlbumType.SINGLE
+
+ # set checksum to invalidate any cached listings
+ checksum_timestamp = str(int(time()))
+ track.metadata.checksum = checksum_timestamp
+ if track.album:
+ track.album.metadata.checksum = checksum_timestamp
+ for artist in track.album.artists:
+ artist.metadata.checksum = checksum_timestamp
+
+ track.add_provider_mapping(
+ ProviderMapping(
+ item_id=file_item.path,
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ content_type=ContentType.try_parse(tags.format),
+ sample_rate=tags.sample_rate,
+ bit_depth=tags.bits_per_sample,
+ bit_rate=tags.bit_rate,
+ )
+ )
+ return track
+
+ async def get_playlist(self, prov_playlist_id: str) -> Playlist:
+ """Get full playlist details by id."""
+ if not await self.exists(prov_playlist_id):
+ raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}")
+
+ file_item = await self.resolve(prov_playlist_id)
+ playlist = Playlist(file_item.path, provider=self.domain, name=file_item.name)
+ playlist.is_editable = file_item.ext != "pls" # can only edit m3u playlists
+
+ playlist.add_provider_mapping(
+ ProviderMapping(
+ item_id=file_item.path,
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ )
+ )
+ playlist.owner = self.name
+ checksum = f"{SCHEMA_VERSION}.{file_item.checksum}"
+ playlist.metadata.checksum = checksum
+ return playlist
+
+ async def get_album_tracks(self, prov_album_id: str) -> list[Track]:
+ """Get album tracks for given album id."""
+ # filesystem items are always stored in db so we can query the database
+ db_album = await self.mass.music.albums.get_db_item_by_prov_id(
+ prov_album_id, provider_instance=self.instance_id
+ )
+ if db_album is None:
+ raise MediaNotFoundError(f"Album not found: {prov_album_id}")
+ # TODO: adjust to json query instead of text search
+ query = f"SELECT * FROM tracks WHERE albums LIKE '%\"{db_album.item_id}\"%'"
+ query += f" AND provider_mappings LIKE '%\"{self.instance_id}\"%'"
+ result = []
+ for track in await self.mass.music.tracks.get_db_items_by_query(query):
+ track.album = db_album
+ if album_mapping := next(
+ (x for x in track.albums if x.item_id == db_album.item_id), None
+ ):
+ track.disc_number = album_mapping.disc_number
+ track.track_number = album_mapping.track_number
+ result.append(track)
+ return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0))
+
+ async def get_playlist_tracks(self, prov_playlist_id: str) -> list[Track]:
+ """Get playlist tracks for given playlist id."""
+ result = []
+ if not await self.exists(prov_playlist_id):
+ raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}")
+
+ _, ext = prov_playlist_id.rsplit(".", 1)
+ try:
+ # get playlist file contents
+ playlist_data = b""
+ async for chunk in self.read_file_content(prov_playlist_id):
+ playlist_data += chunk
+ playlist_data = playlist_data.decode("utf-8")
+
+ if ext in ("m3u", "m3u8"):
+ playlist_lines = await parse_m3u(playlist_data)
+ else:
+ playlist_lines = await parse_pls(playlist_data)
+
+ for line_no, playlist_line in enumerate(playlist_lines):
+ if media_item := await self._parse_playlist_line(
+ playlist_line, os.path.dirname(prov_playlist_id)
+ ):
+ # use the linenumber as position for easier deletions
+ media_item.position = line_no
+ result.append(media_item)
+
+ except Exception as err: # pylint: disable=broad-except
+ self.logger.warning("Error while parsing playlist %s", prov_playlist_id, exc_info=err)
+ return result
+
+ async def _parse_playlist_line(self, line: str, playlist_path: str) -> Track | Radio | None:
+ """Try to parse a track from a playlist line."""
+ try:
+ # try to treat uri as (relative) filename
+ if "://" not in line:
+ for filename in (line, os.path.join(playlist_path, line)):
+ if not await self.exists(filename):
+ continue
+ return await self.get_track(filename)
+ # fallback to generic uri parsing
+ return await self.mass.music.get_item_by_uri(line)
+ except MusicAssistantError as err:
+ self.logger.warning("Could not parse uri/file %s to track: %s", line, str(err))
+ return None
+
+ async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None:
+ """Add track(s) to playlist."""
+ if not await self.exists(prov_playlist_id):
+ raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}")
+ playlist_data = b""
+ async for chunk in self.read_file_content(prov_playlist_id):
+ playlist_data += chunk
+ playlist_data = playlist_data.decode("utf-8")
+ for uri in prov_track_ids:
+ playlist_data += f"\n{uri}"
+
+ # write playlist file
+ await self.write_file_content(prov_playlist_id, playlist_data.encode("utf-8"))
+
+ async def remove_playlist_tracks(
+ self, prov_playlist_id: str, positions_to_remove: tuple[int]
+ ) -> None:
+ """Remove track(s) from playlist."""
+ if not await self.exists(prov_playlist_id):
+ raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}")
+ cur_lines = []
+ _, ext = prov_playlist_id.rsplit(".", 1)
+
+ # get playlist file contents
+ playlist_data = b""
+ async for chunk in self.read_file_content(prov_playlist_id):
+ playlist_data += chunk
+ playlist_data.decode("utf-8")
+
+ if ext in ("m3u", "m3u8"):
+ playlist_lines = await parse_m3u(playlist_data)
+ else:
+ playlist_lines = await parse_pls(playlist_data)
+
+ for line_no, playlist_line in enumerate(playlist_lines):
+ if line_no not in positions_to_remove:
+ cur_lines.append(playlist_line)
+
+ new_playlist_data = "\n".join(cur_lines)
+ # write playlist file
+ await self.write_file_content(prov_playlist_id, new_playlist_data.encode("utf-8"))
+
+ async def create_playlist(self, name: str) -> Playlist:
+ """Create a new playlist on provider with given name."""
+ # creating a new playlist on the filesystem is as easy
+ # as creating a new (empty) file with the m3u extension...
+ filename = await self.resolve(f"{name}.m3u")
+ await self.write_file_content(filename, b"")
+ playlist = await self.get_playlist(filename)
+ db_playlist = await self.mass.music.playlists.add_db_item(playlist)
+ return db_playlist
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails:
+ """Return the content details for the given track when it will be streamed."""
+ db_item = await self.mass.music.tracks.get_db_item_by_prov_id(
+ item_id=item_id, provider_instance=self.instance_id
+ )
+ if db_item is None:
+ raise MediaNotFoundError(f"Item not found: {item_id}")
+
+ prov_mapping = next(x for x in db_item.provider_mappings if x.item_id == item_id)
+ file_item = await self.resolve(item_id)
+
+ return StreamDetails(
+ provider=self.domain,
+ item_id=item_id,
+ content_type=prov_mapping.content_type,
+ media_type=MediaType.TRACK,
+ duration=db_item.duration,
+ size=file_item.file_size,
+ sample_rate=prov_mapping.sample_rate,
+ bit_depth=prov_mapping.bit_depth,
+ direct=file_item.local_path,
+ )
+
+ async def get_audio_stream(
+ self, streamdetails: StreamDetails, seek_position: int = 0
+ ) -> AsyncGenerator[bytes, None]:
+ """Return the audio stream for the provider item."""
+ if seek_position:
+ assert streamdetails.duration, "Duration required for seek requests"
+ assert streamdetails.size, "Filesize required for seek requests"
+ seek_bytes = int((streamdetails.size / streamdetails.duration) * seek_position)
+ else:
+ seek_bytes = 0
+
+ async for chunk in self.read_file_content(streamdetails.item_id, seek_bytes):
+ yield chunk
+
+ async def _parse_artist(
+ self,
+ name: str | None = None,
+ artist_path: str | None = None,
+ ) -> Artist | None:
+ """Lookup metadata in Artist folder."""
+ assert name or artist_path
+ if not artist_path:
+ artist_path = name
+
+ if not name:
+ name = artist_path.split(os.sep)[-1]
+
+ artist = Artist(
+ artist_path,
+ self.domain,
+ name,
+ provider_mappings={
+ ProviderMapping(artist_path, self.domain, self.instance_id, url=artist_path)
+ },
+ musicbrainz_id=VARIOUS_ARTISTS_ID if compare_strings(name, VARIOUS_ARTISTS) else None,
+ )
+
+ if not await self.exists(artist_path):
+ # return basic object if there is no dedicated artist folder
+ return artist
+
+ nfo_file = os.path.join(artist_path, "artist.nfo")
+ if await self.exists(nfo_file):
+ # found NFO file with metadata
+ # https://kodi.wiki/view/NFO_files/Artists
+ data = b""
+ async for chunk in self.read_file_content(nfo_file):
+ data += chunk
+ info = await asyncio.to_thread(xmltodict.parse, data)
+ info = info["artist"]
+ artist.name = info.get("title", info.get("name", name))
+ if sort_name := info.get("sortname"):
+ artist.sort_name = sort_name
+ if musicbrainz_id := info.get("musicbrainzartistid"):
+ artist.musicbrainz_id = musicbrainz_id
+ if description := info.get("biography"):
+ artist.metadata.description = description
+ if genre := info.get("genre"):
+ artist.metadata.genres = set(split_items(genre))
+ # find local images
+ artist.metadata.images = await self._get_local_images(artist_path) or None
+
+ return artist
+
+ async def _parse_album(
+ self, name: str | None, album_path: str | None, artists: list[Artist]
+ ) -> Album | None:
+ """Lookup metadata in Album folder."""
+ assert (name or album_path) and artists
+ if not album_path:
+ # create fake path
+ album_path = artists[0].name + os.sep + name
+
+ if not name:
+ name = album_path.split(os.sep)[-1]
+
+ album = Album(
+ album_path,
+ self.domain,
+ name,
+ artists=artists,
+ provider_mappings={
+ ProviderMapping(album_path, self.domain, self.instance_id, url=album_path)
+ },
+ )
+
+ if not await self.exists(album_path):
+ # return basic object if there is no dedicated album folder
+ return album
+
+ nfo_file = os.path.join(album_path, "album.nfo")
+ if await self.exists(nfo_file):
+ # found NFO file with metadata
+ # https://kodi.wiki/view/NFO_files/Artists
+ data = b""
+ async for chunk in self.read_file_content(nfo_file):
+ data += chunk
+ info = await asyncio.to_thread(xmltodict.parse, data)
+ info = info["album"]
+ album.name = info.get("title", info.get("name", name))
+ if sort_name := info.get("sortname"):
+ album.sort_name = sort_name
+ if musicbrainz_id := info.get("musicbrainzreleasegroupid"):
+ album.musicbrainz_id = musicbrainz_id
+ if mb_artist_id := info.get("musicbrainzalbumartistid"): # noqa: SIM102
+ if album.artist and not album.artist.musicbrainz_id:
+ album.artist.musicbrainz_id = mb_artist_id
+ if description := info.get("review"):
+ album.metadata.description = description
+ if year := info.get("year"):
+ album.year = int(year)
+ if genre := info.get("genre"):
+ album.metadata.genres = set(split_items(genre))
+ # parse name/version
+ album.name, album.version = parse_title_and_version(album.name)
+
+ # find local images
+ album.metadata.images = await self._get_local_images(album_path) or None
+
+ return album
+
+ async def _get_local_images(self, folder: str) -> list[MediaItemImage]:
+ """Return local images found in a given folderpath."""
+ images = []
+ async for item in self.listdir(folder):
+ if "." not in item.path or item.is_dir:
+ continue
+ for ext in IMAGE_EXTENSIONS:
+ if item.ext != ext:
+ continue
+ try:
+ images.append(MediaItemImage(ImageType(item.name), item.path, True))
+ except ValueError:
+ if "folder" in item.name or "AlbumArt" in item.name or "Artist" in item.name:
+ images.append(MediaItemImage(ImageType.THUMB, item.path, True))
+ return images
--- /dev/null
+"""Some helpers for Filesystem based Musicproviders."""
+from __future__ import annotations
+
+import os
+
+from music_assistant.server.helpers.compare import compare_strings
+
+
+def get_parentdir(base_path: str, name: str) -> str | None:
+ """Look for folder name in path (to find dedicated artist or album folder)."""
+ parentdir = os.path.dirname(base_path)
+ for _ in range(3):
+ dirname = parentdir.rsplit(os.sep)[-1]
+ if compare_strings(name, dirname, False):
+ return parentdir
+ parentdir = os.path.dirname(parentdir)
+ return None
+
+
+def get_relative_path(base_path: str, path: str) -> str:
+ """Return the relative path string for a path."""
+ if path.startswith(base_path):
+ path = path.split(base_path)[1]
+ for sep in ("/", "\\"):
+ if path.startswith(sep):
+ path = path[1:]
+ return path
+
+
+def get_absolute_path(base_path: str, path: str) -> str:
+ """Return the absolute path string for a path."""
+ if path.startswith(base_path):
+ return path
+ return os.path.join(base_path, path)
--- /dev/null
+{
+ "type": "music",
+ "domain": "filesystem_local",
+ "name": "Local Filesystem",
+ "description": "Support for music files that are present on a local accessible disk/folder.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ {
+ "key": "path",
+ "type": "string",
+ "label": "Path",
+ "default_value": "/music"
+ }
+ ],
+
+ "requirements": [],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/820",
+ "multi_instance": true,
+ "init_class": "LocalFileSystemProvider"
+}
--- /dev/null
+"""SMB filesystem provider for Music Assistant."""
+
+import contextvars
+import os
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
+from smb.base import SharedFile
+
+from music_assistant.common.helpers.util import get_ip_from_host
+from music_assistant.constants import CONF_PASSWORD, CONF_PATH, CONF_USERNAME
+from music_assistant.server.providers.filesystem_local.base import (
+ FileSystemItem,
+ FileSystemProviderBase,
+)
+from music_assistant.server.providers.filesystem_local.helpers import (
+ get_absolute_path,
+ get_relative_path,
+)
+
+from .helpers import AsyncSMB
+
+
+async def create_item(file_path: str, entry: SharedFile, root_path: str) -> FileSystemItem:
+ """Create FileSystemItem from smb.SharedFile."""
+ rel_path = get_relative_path(root_path, file_path)
+ abs_path = get_absolute_path(root_path, file_path)
+ return FileSystemItem(
+ name=entry.filename,
+ path=rel_path,
+ absolute_path=abs_path,
+ is_file=not entry.isDirectory,
+ is_dir=entry.isDirectory,
+ checksum=str(int(entry.last_write_time)),
+ file_size=entry.file_size,
+ )
+
+
+smb_conn_ctx = contextvars.ContextVar("smb_conn_ctx", default=None)
+
+
+class SMBFileSystemProvider(FileSystemProviderBase):
+ """Implementation of an SMB File System Provider."""
+
+ _service_name = ""
+ _root_path = "/"
+ _remote_name = ""
+ _target_ip = ""
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ # extract params from path
+ if self.config.get_value(CONF_PATH).startswith("\\\\"):
+ path_parts = self.config.get_value(CONF_PATH)[2:].split("\\", 2)
+ if self.config.get_value(CONF_PATH).startswith("//"):
+ path_parts = self.config.get_value(CONF_PATH)[2:].split("/", 2)
+ elif self.config.get_value(CONF_PATH).startswith("smb://"):
+ path_parts = self.config.get_value(CONF_PATH)[6:].split("/", 2)
+ else:
+ path_parts = self.config.get_value(CONF_PATH).split(os.sep)
+ self._remote_name = path_parts[0]
+ self._service_name = path_parts[1]
+ if len(path_parts) > 2:
+ self._root_path = os.sep + path_parts[2]
+
+ default_target_ip = await get_ip_from_host(self._remote_name)
+ self._target_ip = self.config.get_value("target_ip") or default_target_ip
+ async with self._get_smb_connection():
+ # test connection and return
+ return
+
+ async def listdir(
+ self,
+ path: str,
+ recursive: bool = False,
+ ) -> AsyncGenerator[FileSystemItem, None]:
+ """List contents of a given provider directory/path.
+
+ Parameters
+ ----------
+ - path: path of the directory (relative or absolute) to list contents of.
+ Empty string for provider's root.
+ - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent)
+
+ Returns:
+ -------
+ AsyncGenerator yielding FileSystemItem objects.
+
+ """
+ abs_path = get_absolute_path(self._root_path, path)
+ async with self._get_smb_connection() as smb_conn:
+ path_result: list[SharedFile] = await smb_conn.list_path(abs_path)
+ for entry in path_result:
+ if entry.filename.startswith("."):
+ # skip invalid/system files and dirs
+ continue
+ file_path = os.path.join(path, entry.filename)
+ item = await create_item(file_path, entry, self._root_path)
+ if recursive and item.is_dir:
+ # yield sublevel recursively
+ try:
+ async for subitem in self.listdir(file_path, True):
+ yield subitem
+ except (OSError, PermissionError) as err:
+ self.logger.warning("Skip folder %s: %s", item.path, str(err))
+ elif item.is_file or item.is_dir:
+ yield item
+
+ async def resolve(self, file_path: str) -> FileSystemItem:
+ """Resolve (absolute or relative) path to FileSystemItem."""
+ abs_path = get_absolute_path(self._root_path, file_path)
+ async with self._get_smb_connection() as smb_conn:
+ entry: SharedFile = await smb_conn.get_attributes(abs_path)
+ return FileSystemItem(
+ name=file_path,
+ path=get_relative_path(self._root_path, file_path),
+ absolute_path=abs_path,
+ is_file=not entry.isDirectory,
+ is_dir=entry.isDirectory,
+ checksum=str(int(entry.last_write_time)),
+ file_size=entry.file_size,
+ )
+
+ async def exists(self, file_path: str) -> bool:
+ """Return bool if this FileSystem musicprovider has given file/dir."""
+ abs_path = get_absolute_path(self._root_path, file_path)
+ async with self._get_smb_connection() as smb_conn:
+ return await smb_conn.path_exists(abs_path)
+
+ async def read_file_content(self, file_path: str, seek: int = 0) -> AsyncGenerator[bytes, None]:
+ """Yield (binary) contents of file in chunks of bytes."""
+ abs_path = get_absolute_path(self._root_path, file_path)
+
+ async with self._get_smb_connection() as smb_conn:
+ async for chunk in smb_conn.retrieve_file(abs_path, seek):
+ yield chunk
+
+ async def write_file_content(self, file_path: str, data: bytes) -> None:
+ """Write entire file content as bytes (e.g. for playlists)."""
+ abs_path = get_absolute_path(self._root_path, file_path)
+ async with self._get_smb_connection() as smb_conn:
+ await smb_conn.write_file(abs_path, data)
+
+ @asynccontextmanager
+ async def _get_smb_connection(self) -> AsyncGenerator[AsyncSMB, None]:
+ """Get instance of AsyncSMB."""
+ # for a task that consists of multiple steps,
+ # the smb connection may be reused (shared through a contextvar)
+ if existing := smb_conn_ctx.get():
+ yield existing
+ return
+
+ async with AsyncSMB(
+ remote_name=self._remote_name,
+ service_name=self._service_name,
+ username=self.config.get_value(CONF_USERNAME),
+ password=self.config.get_value(CONF_PASSWORD),
+ target_ip=self._target_ip,
+ options={key: value.value for key, value in self.config.values.items()},
+ ) as smb_conn:
+ token = smb_conn_ctx.set(smb_conn)
+ yield smb_conn
+ smb_conn_ctx.reset(token)
--- /dev/null
+"""Some helpers for Filesystem based Musicproviders."""
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncGenerator
+from io import BytesIO
+from typing import Any
+
+from smb.base import SharedFile, SMBTimeout
+from smb.smb_structs import OperationFailure
+from smb.SMBConnection import SMBConnection
+
+from music_assistant.common.models.errors import LoginFailed
+
+SERVICE_NAME = "music_assistant"
+
+
+class AsyncSMB:
+ """Async wrapped pysmb."""
+
+ def __init__(
+ self,
+ remote_name: str,
+ service_name: str,
+ username: str,
+ password: str,
+ target_ip: str,
+ options: dict[str, Any],
+ ) -> None:
+ """Initialize instance."""
+ self._service_name = service_name
+ self._remote_name = remote_name
+ self._target_ip = target_ip
+ self._username = username
+ self._password = password
+ self._conn = SMBConnection(
+ username=self._username,
+ password=self._password,
+ my_name=SERVICE_NAME,
+ remote_name=self._remote_name,
+ # choose sane default options but allow user to override them via the options dict
+ domain=options.get("domain", ""),
+ use_ntlm_v2=options.get("use_ntlm_v2", False),
+ sign_options=options.get("sign_options", 2),
+ is_direct_tcp=options.get("is_direct_tcp", False),
+ )
+
+ async def list_path(self, path: str) -> list[SharedFile]:
+ """Retrieve a directory listing of files/folders at *path*."""
+ return await asyncio.to_thread(self._conn.listPath, self._service_name, path)
+
+ async def get_attributes(self, path: str) -> SharedFile:
+ """Retrieve information about the file at *path* on the *service_name*."""
+ return await asyncio.to_thread(self._conn.getAttributes, self._service_name, path)
+
+ async def retrieve_file(self, path: str, offset: int = 0) -> AsyncGenerator[bytes, None]:
+ """Retrieve file contents."""
+ chunk_size = 256000
+ while True:
+ with BytesIO() as file_obj:
+ await asyncio.to_thread(
+ self._conn.retrieveFileFromOffset,
+ self._service_name,
+ path,
+ file_obj,
+ offset,
+ chunk_size,
+ )
+ file_obj.seek(0)
+ chunk = file_obj.read()
+ yield chunk
+ offset += len(chunk)
+ if len(chunk) < chunk_size:
+ break
+
+ async def write_file(self, path: str, data: bytes) -> SharedFile:
+ """Store the contents to the file at *path*."""
+ with BytesIO() as file_obj:
+ file_obj.write(data)
+ file_obj.seek(0)
+ await asyncio.to_thread(
+ self._conn.storeFile,
+ self._service_name,
+ path,
+ file_obj,
+ )
+
+ async def path_exists(self, path: str) -> bool:
+ """Return bool is this FileSystem musicprovider has given file/dir."""
+ try:
+ await asyncio.to_thread(self._conn.getAttributes, self._service_name, path)
+ except (OperationFailure, SMBTimeout):
+ return False
+ return True
+
+ async def connect(self) -> None:
+ """Connect to the SMB server."""
+ try:
+ assert await asyncio.to_thread(self._conn.connect, self._target_ip) is True
+ except Exception as exc:
+ raise LoginFailed(f"SMB Connect failed to {self._remote_name}") from exc
+
+ async def __aenter__(self) -> AsyncSMB:
+ """Enter context manager."""
+ # connect
+ await self.connect()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback) -> bool:
+ """Exit context manager."""
+ self._conn.close()
--- /dev/null
+{
+ "type": "music",
+ "domain": "filesystem_smb",
+ "name": "SMB Filesystem",
+ "description": "Support for music files that are present on remote SMB/CIFS share.",
+ "codeowners": ["@MarvinSchenkel", "@marcelveldt"],
+ "config_entries": [
+ {
+ "key": "path",
+ "type": "string",
+ "label": "Path",
+ "description": "Full SMB path to the files, e.g. \\\\server\\share\folder or smb://server/share"
+ },
+ {
+ "key": "username",
+ "type": "string",
+ "label": "Username"
+ },
+ {
+ "key": "password",
+ "type": "password",
+ "label": "Password"
+ },
+ {
+ "key": "target_ip",
+ "type": "string",
+ "label": "Target IP",
+ "description": "Use in case of DNS resolve issues. Connect to this IP instead of the DNS name.",
+ "advanced": true,
+ "required": false
+ },
+ {
+ "key": "domain",
+ "type": "string",
+ "label": "Domain",
+ "default_value": "",
+ "description": "The network domain. On windows, it is known as the workgroup. Usually, it is safe to leave this parameter as an empty string.",
+ "advanced": true,
+ "required": false
+ },
+ {
+ "key": "use_ntlm_v2",
+ "type": "boolean",
+ "label": "Use NTLM v2",
+ "default_value": false,
+ "description": "Indicates whether pysmb should be NTLMv1 or NTLMv2 authentication algorithm for authentication. The choice of NTLMv1 and NTLMv2 is configured on the remote server, and there is no mechanism to auto-detect which algorithm has been configured. Hence, we can only “guess” or try both algorithms. On Sambda, Windows Vista and Windows 7, NTLMv2 is enabled by default. On Windows XP, we can use NTLMv1 before NTLMv2.",
+ "advanced": true,
+ "required": false
+ },
+ {
+ "key": "sign_options",
+ "type": "integer",
+ "label": "Sign Options",
+ "default_value": 2,
+ "description": "Determines whether SMB messages will be signed. Default is SIGN_WHEN_REQUIRED. If SIGN_WHEN_REQUIRED (value=2), SMB messages will only be signed when remote server requires signing. If SIGN_WHEN_SUPPORTED (value=1), SMB messages will be signed when remote server supports signing but not requires signing. If SIGN_NEVER (value=0), SMB messages will never be signed regardless of remote server’s configurations; access errors will occur if the remote server requires signing.",
+ "advanced": true,
+ "required": false,
+ "options": [
+ { "title": "SIGN_NEVER", "value": 0 },
+ { "title": "SIGN_WHEN_SUPPORTED", "value": 1 },
+ { "title": "SIGN_WHEN_REQUIRED", "value": 2 }
+ ]
+ },
+ {
+ "key": "is_direct_tcp",
+ "type": "boolean",
+ "label": "Use Direct TCP",
+ "default_value": false,
+ "description": "Controls whether the NetBIOS over TCP/IP (is_direct_tcp=False) or the newer Direct hosting of SMB over TCP/IP (is_direct_tcp=True) will be used for the communication. The default parameter is False which will use NetBIOS over TCP/IP for wider compatibility (TCP port: 139).",
+ "advanced": true,
+ "required": false
+ }
+ ],
+
+ "requirements": ["pysmb==1.2.9.1"],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/820",
+ "multi_instance": true,
+ "init_class": "SMBFileSystemProvider"
+}
--- /dev/null
+"""The default Music Assistant (web) frontend, hosted within the server."""
+from __future__ import annotations
+
+import os
+from functools import partial
+
+from aiohttp import web
+from music_assistant_frontend import where
+
+from music_assistant.server.models.plugin import PluginProvider
+
+
+class Frontend(PluginProvider):
+ """The default Music Assistant (web) frontend, hosted within the server."""
+
+ async def setup(self) -> None:
+ """Handle async initialization of the plugin."""
+ frontend_dir = where()
+ for filename in next(os.walk(frontend_dir))[2]:
+ if filename.endswith(".py"):
+ continue
+ filepath = os.path.join(frontend_dir, filename)
+ handler = partial(self.serve_static, filepath)
+ self.mass.webapp.router.add_get(f"/{filename}", handler)
+ print(filename)
+
+ # add assets subdir as static
+ self.mass.webapp.router.add_static(
+ "/assets", os.path.join(frontend_dir, "assets"), name="assets"
+ )
+
+ # add index
+ handler = partial(self.serve_static, os.path.join(frontend_dir, "index.html"))
+ self.mass.webapp.router.add_get("/", handler)
+
+ async def serve_static(self, file_path: str, _request: web.Request) -> web.FileResponse:
+ """Serve file response."""
+ return web.FileResponse(file_path)
--- /dev/null
+{
+ "type": "plugin",
+ "domain": "frontend",
+ "name": "Frontend",
+ "description": "The default Music Assistant (web) frontend, written in Vue, hosted within the Music Assistant server.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ ],
+
+ "requirements": ["music-assistant-frontend==20230308.0"],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""JSON-RPC API which is more or less compatible with Logitech Media Server."""
+from __future__ import annotations
+
+from typing import Any
+
+from aiohttp import web
+
+from music_assistant.common.helpers.json import json_dumps, json_loads
+from music_assistant.common.models.enums import PlayerState
+from music_assistant.server.models.plugin import PluginProvider
+
+from .models import (
+ CommandErrorMessage,
+ CommandMessage,
+ CommandResultMessage,
+ PlayerItem,
+ PlayersResponse,
+ PlayerStatusResponse,
+ player_item_from_mass,
+ player_status_from_mass,
+)
+
+# ruff: noqa: ARG002, E501
+
+ArgsType = list[int | str]
+KwargsType = dict[str, Any]
+
+
+def parse_value(raw_value: int | str) -> int | str | tuple[str, int | str]:
+ """
+ Transform API param into a usable value.
+
+ Integer values are sometimes sent as string so we try to parse that.
+ """
+ if isinstance(raw_value, str):
+ if ":" in raw_value:
+ # this is a key:value value
+ key, val = raw_value.split(":")
+ return (key, val)
+ if raw_value.isnumeric():
+ # this is an integer sent as string
+ return int(raw_value)
+ return raw_value
+
+
+def parse_args(raw_values: list[int | str]) -> tuple[ArgsType, KwargsType]:
+ """Pargse Args and Kwargs from raw CLI params."""
+ args: ArgsType = []
+ kwargs: KwargsType = {}
+ for raw_value in raw_values:
+ value = parse_value(raw_value)
+ if isinstance(value, tuple):
+ kwargs[value[0]] = value[1]
+ else:
+ args.append(value)
+ return (args, kwargs)
+
+
+class JSONRPCApi(PluginProvider):
+ """Basic JSON-RPC API implementation, (partly) compatible with Logitech Media Server."""
+
+ async def setup(self) -> None:
+ """Handle async initialization of the plugin."""
+ self.mass.webapp.router.add_get("/jsonrpc.js", self._handle_jsonrpc)
+ self.mass.webapp.router.add_post("/jsonrpc.js", self._handle_jsonrpc)
+
+ async def _handle_jsonrpc(self, request: web.Request) -> web.Response:
+ """Handle request for image proxy."""
+ command_msg: CommandMessage = await request.json(loads=json_loads)
+ self.logger.debug("Received request: %s", command_msg)
+
+ if command_msg["method"] == "slim.request":
+ # Slim request handler
+ # {"method":"slim.request","id":1,"params":["aa:aa:ca:5a:94:4c",["status","-", 2, "tags:xcfldatgrKN"]]}
+ player_id = command_msg["params"][0]
+ command = str(command_msg["params"][1][0])
+ args, kwargs = parse_args(command_msg["params"][1][1:])
+
+ if handler := getattr(self, f"_handle_{command}", None):
+ # run handler for command
+ self.logger.debug(
+ "Handling JSON-RPC-request (player: %s command: %s - args: %s - kwargs: %s)",
+ player_id,
+ command,
+ str(args),
+ str(kwargs),
+ )
+ cmd_result = handler(player_id, *args, **kwargs)
+ if cmd_result is None:
+ cmd_result = {}
+ elif not isinstance(cmd_result, dict):
+ # individual values are returned with underscore ?!
+ cmd_result = {f"_{command}": cmd_result}
+ result: CommandResultMessage = {
+ **command_msg,
+ "result": cmd_result,
+ }
+ else:
+ # no handler found
+ self.logger.warning("No handler for %s", command)
+ result: CommandErrorMessage = {
+ **command_msg,
+ "error": {"code": -1, "message": "Invalid command"},
+ }
+ # return the response to the client
+ return web.json_response(result, dumps=json_dumps)
+
+ def _handle_players(
+ self,
+ player_id: str,
+ start_index: int | str = 0,
+ limit: int = 999,
+ **kwargs,
+ ) -> PlayersResponse:
+ """Handle players command."""
+ players: list[PlayerItem] = []
+ for index, mass_player in enumerate(self.mass.players.all()):
+ if isinstance(start_index, int) and index < start_index:
+ continue
+ if len(players) > limit:
+ break
+ players.append(player_item_from_mass(start_index + index, mass_player))
+ return PlayersResponse(count=len(players), players_loop=players)
+
+ def _handle_status(
+ self,
+ player_id: str,
+ *args,
+ start_index: int | str = "-",
+ limit: int = 2,
+ tags: str = "xcfldatgrKN",
+ **kwargs,
+ ) -> PlayerStatusResponse:
+ """Handle player status command."""
+ player = self.mass.players.get(player_id)
+ assert player is not None
+ queue = self.mass.players.queues.get_active_queue(player_id)
+ assert queue is not None
+ if start_index == "-":
+ start_index = queue.current_index or 0
+ queue_items = self.mass.players.queues.items(queue.queue_id)[
+ start_index : start_index + limit
+ ]
+ # we ignore the tags, just always send all info
+ return player_status_from_mass(player=player, queue=queue, queue_items=queue_items)
+
+ def _handle_mixer(
+ self,
+ player_id: str,
+ subcommand: str,
+ *args,
+ **kwargs,
+ ) -> int | None:
+ """Handle player mixer command."""
+ arg = args[0] if args else "?"
+ player = self.mass.players.get(player_id)
+ assert player is not None
+
+ # <playerid> mixer volume <0 .. 100|-100 .. +100|?>
+ if subcommand == "volume" and isinstance(arg, int):
+ self.mass.create_task(self.mass.players.cmd_volume_set, player_id, arg)
+ return
+ if subcommand == "volume" and arg == "?":
+ return player.volume_level
+ if subcommand == "volume" and "+" in arg:
+ volume_level = min(100, player.volume_level + int(arg.split("+")[1]))
+ self.mass.create_task(self.mass.players.cmd_volume_set, player_id, volume_level)
+ return
+ if subcommand == "volume" and "-" in arg:
+ volume_level = max(0, player.volume_level - int(arg.split("-")[1]))
+ self.mass.create_task(self.mass.players.cmd_volume_set, player_id, volume_level)
+ return
+
+ # <playerid> mixer muting <0|1|toggle|?|>
+ if subcommand == "muting" and isinstance(arg, int):
+ self.mass.create_task(self.mass.players.cmd_volume_mute, player_id, int(arg))
+ return
+ if subcommand == "muting" and arg == "toggle":
+ self.mass.create_task(
+ self.mass.players.cmd_volume_mute, player_id, not player.volume_muted
+ )
+ return
+ if subcommand == "muting":
+ return int(player.volume_muted)
+
+ def _handle_time(self, player_id: str, number: str | int) -> int | None:
+ """Handle player `time` command."""
+ # <playerid> time <number|-number|+number|?>
+ # The "time" command allows you to query the current number of seconds that the
+ # current song has been playing by passing in a "?".
+ # You may jump to a particular position in a song by specifying a number of seconds
+ # to seek to. You may also jump to a relative position within a song by putting an
+ # explicit "-" or "+" character before a number of seconds you would like to seek.
+ player_queue = self.mass.players.queues.get_active_queue(player_id)
+ assert player_queue is not None
+
+ if number == "?":
+ return int(player_queue.corrected_elapsed_time)
+
+ if isinstance(number, str) and "+" in number or "-" in number:
+ jump = int(number.split("+")[1])
+ self.mass.create_task(self.mass.players.queues.skip, jump)
+ else:
+ self.mass.create_task(self.mass.players.queues.seek, number)
+
+ def _handle_playlist(
+ self,
+ player_id: str,
+ subcommand: str,
+ *args,
+ **kwargs,
+ ) -> int | None:
+ """Handle player `playlist` command."""
+ arg = args[0] if args else "?"
+ queue = self.mass.players.queues.get_active_queue(player_id)
+ assert queue is not None
+
+ # <playerid> playlist index <index|+index|-index|?> <fadeInSecs>
+ if subcommand == "index" and isinstance(arg, int):
+ self.mass.create_task(self.mass.players.queues.play_index, player_id, arg)
+ return
+ if subcommand == "index" and arg == "?":
+ return queue.current_index
+ if subcommand == "index" and "+" in arg:
+ next_index = (queue.current_index or 0) + int(arg.split("+")[1])
+ self.mass.create_task(self.mass.players.queues.play_index, player_id, next_index)
+ return
+ if subcommand == "index" and "-" in arg:
+ next_index = (queue.current_index or 0) - int(arg.split("-")[1])
+ self.mass.create_task(self.mass.players.queues.play_index, player_id, next_index)
+ return
+
+ self.logger.warning("Unhandled command: playlist/%s", subcommand)
+
+ def _handle_play(
+ self,
+ player_id: str,
+ *args,
+ **kwargs,
+ ) -> int | None:
+ """Handle player `play` command."""
+ queue = self.mass.players.queues.get_active_queue(player_id)
+ assert queue is not None
+ self.mass.create_task(self.mass.players.queues.play, player_id)
+
+ def _handle_stop(
+ self,
+ player_id: str,
+ *args,
+ **kwargs,
+ ) -> int | None:
+ """Handle player `stop` command."""
+ queue = self.mass.players.queues.get_active_queue(player_id)
+ assert queue is not None
+ self.mass.create_task(self.mass.players.queues.stop, player_id)
+
+ def _handle_pause(
+ self,
+ player_id: str,
+ force: int = 0,
+ *args,
+ **kwargs,
+ ) -> int | None:
+ """Handle player `stop` command."""
+ queue = self.mass.players.queues.get_active_queue(player_id)
+ assert queue is not None
+
+ if force or queue.state == PlayerState.PLAYING:
+ self.mass.create_task(self.mass.players.queues.pause, player_id)
+ else:
+ self.mass.create_task(self.mass.players.queues.play, player_id)
--- /dev/null
+{
+ "type": "plugin",
+ "domain": "json_rpc",
+ "name": "JSON-RPC API",
+ "description": "Basic JSON-RPC API implementation, (partly) compatible with Logitech Media Server.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ ],
+
+ "requirements": [],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""Models used for the JSON-RPC API."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, TypedDict
+
+from music_assistant.common.models.enums import MediaType, PlayerState, RepeatMode
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.player import Player
+ from music_assistant.common.models.player_queue import PlayerQueue
+ from music_assistant.common.models.queue_item import QueueItem
+
+# ruff: noqa: UP013
+
+PLAYMODE_MAP = {
+ PlayerState.IDLE: "stop",
+ PlayerState.PLAYING: "play",
+ PlayerState.OFF: "stop",
+ PlayerState.PAUSED: "pause",
+}
+
+REPEATMODE_MAP = {RepeatMode.OFF: 0, RepeatMode.ONE: 1, RepeatMode.ALL: 2}
+
+
+class CommandMessage(TypedDict):
+ """Representation of Base JSON RPC Command Message."""
+
+ # https://www.jsonrpc.org/specification
+
+ id: int | str
+ method: str
+ params: list[str | int | list[str | int]]
+
+
+class CommandResultMessage(CommandMessage):
+ """Representation of JSON RPC Result Message."""
+
+ result: Any
+
+
+class ErrorDetails(TypedDict):
+ """Representation of JSON RPC ErrorDetails."""
+
+ code: int
+ message: str
+
+
+class CommandErrorMessage(CommandMessage, TypedDict):
+ """Base Representation of JSON RPC Command Message."""
+
+ id: int | str | None
+ error: ErrorDetails
+
+
+PlayerItem = TypedDict(
+ "PlayerItem",
+ {
+ "playerindex": int,
+ "playerid": str,
+ "name": str,
+ "modelname": str,
+ "connected": int,
+ "isplaying": int,
+ "power": int,
+ "model": str,
+ "canpoweroff": int,
+ "firmware": int,
+ "isplayer": int,
+ "displaytype": str,
+ "uuid": str | None,
+ "seq_no": int,
+ "ip": str,
+ },
+)
+
+
+def player_item_from_mass(playerindex: int, player: Player) -> PlayerItem:
+ """Parse PlayerItem for the Json RPC interface from MA QueueItem."""
+ return {
+ "playerindex": playerindex,
+ "playerid": player.player_id,
+ "name": player.display_name,
+ "modelname": player.device_info.model,
+ "connected": int(player.available),
+ "isplaying": 1 if player.state == PlayerState.PLAYING else 0,
+ "power": int(player.powered),
+ "model": "squeezelite",
+ "canpoweroff": 1,
+ "firmware": 0,
+ "isplayer": 1,
+ "displaytype": None,
+ "uuid": None,
+ "seq_no": 0,
+ "ip": player.device_info.address,
+ }
+
+
+PlayersResponse = TypedDict(
+ "PlayersResponse",
+ {
+ "count": int,
+ "players_loop": list[PlayerItem],
+ },
+)
+
+
+PlaylistItem = TypedDict(
+ "PlaylistItem",
+ {
+ "playlist index": int,
+ "id": str,
+ "title": str,
+ "artist": str,
+ "remote": int,
+ "remote_title": str,
+ "artwork_url": str,
+ "bitrate": str,
+ "duration": str | int | None,
+ "coverid": str,
+ },
+)
+
+
+def playlist_item_from_mass(queue_item: QueueItem, index: int = 0) -> PlaylistItem:
+ """Parse PlaylistItem for the Json RPC interface from MA QueueItem."""
+ if queue_item.media_item and queue_item.media_type == MediaType.TRACK:
+ artist = queue_item.media_item.artist.name
+ album = queue_item.media_item.album.name
+ title = queue_item.media_item.name
+ elif queue_item.streamdetails and queue_item.streamdetails.stream_title:
+ if " - " in queue_item.streamdetails.stream_title:
+ artist, title = queue_item.streamdetails.stream_title.split(" - ")
+ else:
+ artist = ""
+ title = queue_item.streamdetails.stream_title
+ album = queue_item.name
+ else:
+ artist = ""
+ album = ""
+ title = queue_item.name
+ return {
+ "playlist index": index,
+ "id": queue_item.queue_item_id,
+ "title": title,
+ "artist": artist,
+ "album": album,
+ "genre": "",
+ "remote": 0,
+ "remote_title": queue_item.streamdetails.stream_title if queue_item.streamdetails else "",
+ "artwork_url": queue_item.image.url if queue_item.image else "",
+ "bitrate": "",
+ "duration": queue_item.duration or 0,
+ "coverid": "-94099753136392",
+ }
+
+
+PlayerStatusResponse = TypedDict(
+ "PlayerStatusResponse",
+ {
+ "time": int,
+ "mode": str,
+ "sync_slaves": str,
+ "playlist_cur_index": int | None,
+ "player_name": str,
+ "sync_master": str,
+ "player_connected": int,
+ "power": int,
+ "mixer volume": int,
+ "playlist repeat": int,
+ "playlist shuffle": int,
+ "playlist mode": str,
+ "player_ip": str,
+ "remoteMeta": dict | None,
+ "digital_volume_control": int,
+ "playlist_timestamp": float,
+ "current_title": str,
+ "duration": int,
+ "seq_no": int,
+ "remote": int,
+ "can_seek": int,
+ "signalstrength": int,
+ "rate": int,
+ "playlist_tracks": int,
+ "playlist_loop": list[PlaylistItem],
+ },
+)
+
+
+def player_status_from_mass(
+ player: Player, queue: PlayerQueue, queue_items: list[QueueItem]
+) -> PlayerStatusResponse:
+ """Parse PlayerStatusResponse for the Json RPC interface from MA info."""
+ return {
+ "time": queue.corrected_elapsed_time,
+ "mode": PLAYMODE_MAP[queue.state],
+ "sync_slaves": ",".join(player.group_childs),
+ "playlist_cur_index": queue.current_index,
+ "player_name": player.display_name,
+ "sync_master": player.synced_to or "",
+ "player_connected": int(player.available),
+ "mixer volume": player.volume_level,
+ "power": int(player.powered),
+ "digital_volume_control": 1,
+ "playlist_timestamp": 0, # TODO !
+ "current_title": queue.current_item.queue_item_id
+ if queue.current_item
+ else "Music Assistant",
+ "duration": queue.current_item.duration if queue.current_item else 0,
+ "playlist repeat": REPEATMODE_MAP[queue.repeat_mode],
+ "playlist shuffle": int(queue.shuffle_enabled),
+ "playlist mode": "off",
+ "player_ip": player.device_info.address,
+ "seq_no": 0,
+ "remote": 0,
+ "can_seek": 1,
+ "signalstrength": 0,
+ "rate": 1,
+ "playlist_tracks": queue.items,
+ "playlist_loop": [
+ playlist_item_from_mass(item, queue.current_index + index)
+ for index, item in enumerate(queue_items)
+ ],
+ }
--- /dev/null
+"""The Musicbrainz Metadata provider for Music Assistant.
+
+At this time only used for retrieval of ID's but to be expanded to fetch metadata too.
+"""
+from __future__ import annotations
+
+import re
+from collections.abc import Iterable
+from json import JSONDecodeError
+from typing import TYPE_CHECKING, Any
+
+import aiohttp.client_exceptions
+from asyncio_throttle import Throttler
+
+from music_assistant.common.helpers.util import create_sort_name
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.server.controllers.cache import use_cache
+from music_assistant.server.helpers.compare import compare_strings
+from music_assistant.server.models.metadata_provider import MetadataProvider
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.media_items import Album, Artist, Track
+
+
+LUCENE_SPECIAL = r'([+\-&|!(){}\[\]\^"~*?:\\\/])'
+
+
+class MusicbrainzProvider(MetadataProvider):
+ """The Musicbrainz Metadata provider."""
+
+ throttler: Throttler
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self.cache = self.mass.cache
+ self.throttler = Throttler(rate_limit=1, period=1)
+ self._attr_supported_features = (ProviderFeature.GET_ARTIST_MBID,)
+
+ async def get_musicbrainz_artist_id(
+ self, artist: Artist, ref_albums: Iterable[Album], ref_tracks: Iterable[Track]
+ ) -> str | None:
+ """Discover MusicBrainzArtistId for an artist given some reference albums/tracks."""
+ for ref_album in ref_albums:
+ # try matching on album musicbrainz id
+ if ref_album.musicbrainz_id: # noqa: SIM102
+ if musicbrainz_id := await self._search_artist_by_album_mbid(
+ artistname=artist.name, album_mbid=ref_album.musicbrainz_id
+ ):
+ return musicbrainz_id
+ # try matching on album upc
+ if ref_album.upc and (
+ musicbrainz_id := await self._search_artist_by_album(
+ artistname=artist.name,
+ album_upc=ref_album.upc,
+ )
+ ):
+ return musicbrainz_id
+
+ # try again with matching on track isrc
+ for ref_track in ref_tracks:
+ for isrc in ref_track.isrcs:
+ if musicbrainz_id := await self._search_artist_by_track(
+ artistname=artist.name,
+ track_isrc=isrc,
+ ):
+ return musicbrainz_id
+
+ # last restort: track matching by name
+ for ref_track in ref_tracks:
+ if musicbrainz_id := await self._search_artist_by_track(
+ artistname=artist.name,
+ trackname=ref_track.name,
+ ):
+ return musicbrainz_id
+
+ return None
+
+ async def _search_artist_by_album(
+ self,
+ artistname: str,
+ albumname: str | None = None,
+ album_upc: str | None = None,
+ ) -> str | None:
+ """Retrieve musicbrainz artist id by providing the artist name and albumname or upc."""
+ assert albumname or album_upc
+ for searchartist in (
+ artistname,
+ re.sub(LUCENE_SPECIAL, r"\\\1", create_sort_name(artistname)),
+ ):
+ if album_upc:
+ # search by album UPC (barcode)
+ query = f"barcode:{album_upc}"
+ elif albumname:
+ # search by name
+ searchalbum = re.sub(LUCENE_SPECIAL, r"\\\1", albumname)
+ query = f'artist:"{searchartist}" AND release:"{searchalbum}"'
+ result = await self.get_data("release", query=query)
+ if result and "releases" in result:
+ for strict in (True, False):
+ for item in result["releases"]:
+ if not (
+ album_upc
+ or (albumname and compare_strings(item["title"], albumname, strict))
+ ):
+ continue
+ for artist in item["artist-credit"]:
+ if compare_strings(artist["artist"]["name"], artistname, strict):
+ return artist["artist"]["id"] # type: ignore[no-any-return]
+ for alias in artist.get("aliases", []):
+ if compare_strings(alias["name"], artistname, strict):
+ return artist["id"] # type: ignore[no-any-return]
+ return None
+
+ async def _search_artist_by_track(
+ self,
+ artistname: str,
+ trackname: str | None = None,
+ track_isrc: str | None = None,
+ ) -> str | None:
+ """Retrieve artist id by providing the artist name and trackname or track isrc."""
+ assert trackname or track_isrc
+ searchartist = re.sub(LUCENE_SPECIAL, r"\\\1", artistname)
+ if track_isrc:
+ result = await self.get_data(f"isrc/{track_isrc}", inc="artist-credits")
+ elif trackname:
+ searchtrack = re.sub(LUCENE_SPECIAL, r"\\\1", trackname)
+ result = await self.get_data(
+ "recording", query=f'"{searchtrack}" AND artist:"{searchartist}"'
+ )
+ if result and "recordings" in result:
+ for strict in (True, False):
+ for item in result["recordings"]:
+ if not (
+ track_isrc
+ or (trackname and compare_strings(item["title"], trackname, strict))
+ ):
+ continue
+ for artist in item["artist-credit"]:
+ if compare_strings(artist["artist"]["name"], artistname, strict):
+ return artist["artist"]["id"] # type: ignore[no-any-return]
+ for alias in artist["artist"].get("aliases", []):
+ if compare_strings(alias["name"], artistname, strict):
+ return artist["artist"]["id"] # type: ignore[no-any-return]
+ return None
+
+ async def _search_artist_by_album_mbid(self, artistname: str, album_mbid: str) -> str | None:
+ """Retrieve musicbrainz artist id by providing the artist name and albumname or upc."""
+ result = await self.get_data(f"release-group/{album_mbid}?inc=artist-credits")
+ if result and "artist-credit" in result:
+ for item in result["artist-credit"]:
+ if (artist := item.get("artist")) and compare_strings(artistname, artist["name"]):
+ return artist["id"] # type: ignore[no-any-return]
+ return None
+
+ @use_cache(86400 * 30)
+ async def get_data(self, endpoint: str, **kwargs: dict[str, Any]) -> Any:
+ """Get data from api."""
+ url = f"http://musicbrainz.org/ws/2/{endpoint}"
+ headers = {"User-Agent": "Music Assistant/1.0.0 https://github.com/music-assistant"}
+ kwargs["fmt"] = "json" # type: ignore[assignment]
+ async with self.throttler:
+ async with self.mass.http_session.get(
+ url, headers=headers, params=kwargs, verify_ssl=False
+ ) as response:
+ try:
+ result = await response.json()
+ except (
+ aiohttp.client_exceptions.ContentTypeError,
+ JSONDecodeError,
+ ) as exc:
+ msg = await response.text()
+ self.logger.warning("%s - %s", str(exc), msg)
+ result = None
+ return result
--- /dev/null
+{
+ "type": "metadata",
+ "domain": "musicbrainz",
+ "name": "MusicBrainz Metadata provider",
+ "description": "MusicBrainz is an open music encyclopedia that collects music metadata and makes it available to the public.",
+ "codeowners": ["@music-assistant"],
+ "config_entries": [
+ ],
+ "requirements": [],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""Qobuz musicprovider support for MusicAssistant."""
+from __future__ import annotations
+
+import datetime
+import hashlib
+import time
+from collections.abc import AsyncGenerator
+from json import JSONDecodeError
+
+import aiohttp
+from asyncio_throttle import Throttler
+
+from music_assistant.common.helpers.util import parse_title_and_version, try_parse_int
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.errors import LoginFailed, MediaNotFoundError
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ ContentType,
+ ImageType,
+ MediaItemImage,
+ MediaItemType,
+ MediaType,
+ Playlist,
+ ProviderMapping,
+ StreamDetails,
+ Track,
+)
+from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
+from music_assistant.server.helpers.app_vars import app_var # pylint: disable=no-name-in-module
+from music_assistant.server.models.music_provider import MusicProvider
+
+
+class QobuzProvider(MusicProvider):
+ """Provider for the Qobux music service."""
+
+ _user_auth_info: str | None = None
+ _throttler: Throttler
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self._throttler = Throttler(rate_limit=4, period=1)
+ self._attr_supported_features = (
+ ProviderFeature.LIBRARY_ARTISTS,
+ ProviderFeature.LIBRARY_ALBUMS,
+ ProviderFeature.LIBRARY_TRACKS,
+ ProviderFeature.LIBRARY_PLAYLISTS,
+ ProviderFeature.LIBRARY_ARTISTS_EDIT,
+ ProviderFeature.LIBRARY_ALBUMS_EDIT,
+ ProviderFeature.LIBRARY_PLAYLISTS_EDIT,
+ ProviderFeature.LIBRARY_TRACKS_EDIT,
+ ProviderFeature.PLAYLIST_TRACKS_EDIT,
+ ProviderFeature.BROWSE,
+ ProviderFeature.SEARCH,
+ ProviderFeature.ARTIST_ALBUMS,
+ ProviderFeature.ARTIST_TOPTRACKS,
+ )
+ if not self.config.get_value(CONF_USERNAME) or not self.config.get_value(CONF_PASSWORD):
+ raise LoginFailed("Invalid login credentials")
+ # try to get a token, raise if that fails
+ token = await self._auth_token()
+ if not token:
+ raise LoginFailed(f"Login failed for user {self.config.get_value(CONF_USERNAME)}")
+
+ async def search(
+ self, search_query: str, media_types=list[MediaType] | None, limit: int = 5
+ ) -> list[MediaItemType]:
+ """Perform search on musicprovider.
+
+ :param search_query: Search query.
+ :param media_types: A list of media_types to include. All types if None.
+ :param limit: Number of items to return in the search (per type).
+ """
+ result = []
+ params = {"query": search_query, "limit": limit}
+ if len(media_types) == 1:
+ # qobuz does not support multiple searchtypes, falls back to all if no type given
+ if media_types[0] == MediaType.ARTIST:
+ params["type"] = "artists"
+ if media_types[0] == MediaType.ALBUM:
+ params["type"] = "albums"
+ if media_types[0] == MediaType.TRACK:
+ params["type"] = "tracks"
+ if media_types[0] == MediaType.PLAYLIST:
+ params["type"] = "playlists"
+ if searchresult := await self._get_data("catalog/search", **params):
+ if "artists" in searchresult:
+ result += [
+ await self._parse_artist(item)
+ for item in searchresult["artists"]["items"]
+ if (item and item["id"])
+ ]
+ if "albums" in searchresult:
+ result += [
+ await self._parse_album(item)
+ for item in searchresult["albums"]["items"]
+ if (item and item["id"])
+ ]
+ if "tracks" in searchresult:
+ result += [
+ await self._parse_track(item)
+ for item in searchresult["tracks"]["items"]
+ if (item and item["id"])
+ ]
+ if "playlists" in searchresult:
+ result += [
+ await self._parse_playlist(item)
+ for item in searchresult["playlists"]["items"]
+ if (item and item["id"])
+ ]
+ return result
+
+ async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
+ """Retrieve all library artists from Qobuz."""
+ endpoint = "favorite/getUserFavorites"
+ for item in await self._get_all_items(endpoint, key="artists", type="artists"):
+ if item and item["id"]:
+ yield await self._parse_artist(item)
+
+ async def get_library_albums(self) -> AsyncGenerator[Album, None]:
+ """Retrieve all library albums from Qobuz."""
+ endpoint = "favorite/getUserFavorites"
+ for item in await self._get_all_items(endpoint, key="albums", type="albums"):
+ if item and item["id"]:
+ yield await self._parse_album(item)
+
+ async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
+ """Retrieve library tracks from Qobuz."""
+ endpoint = "favorite/getUserFavorites"
+ for item in await self._get_all_items(endpoint, key="tracks", type="tracks"):
+ if item and item["id"]:
+ yield await self._parse_track(item)
+
+ async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
+ """Retrieve all library playlists from the provider."""
+ endpoint = "playlist/getUserPlaylists"
+ for item in await self._get_all_items(endpoint, key="playlists"):
+ if item and item["id"]:
+ yield await self._parse_playlist(item)
+
+ async def get_artist(self, prov_artist_id) -> Artist:
+ """Get full artist details by id."""
+ params = {"artist_id": prov_artist_id}
+ artist_obj = await self._get_data("artist/get", **params)
+ return await self._parse_artist(artist_obj) if artist_obj and artist_obj["id"] else None
+
+ async def get_album(self, prov_album_id) -> Album:
+ """Get full album details by id."""
+ params = {"album_id": prov_album_id}
+ album_obj = await self._get_data("album/get", **params)
+ return await self._parse_album(album_obj) if album_obj and album_obj["id"] else None
+
+ async def get_track(self, prov_track_id) -> Track:
+ """Get full track details by id."""
+ params = {"track_id": prov_track_id}
+ track_obj = await self._get_data("track/get", **params)
+ return await self._parse_track(track_obj) if track_obj and track_obj["id"] else None
+
+ async def get_playlist(self, prov_playlist_id) -> Playlist:
+ """Get full playlist details by id."""
+ params = {"playlist_id": prov_playlist_id}
+ playlist_obj = await self._get_data("playlist/get", **params)
+ return (
+ await self._parse_playlist(playlist_obj)
+ if playlist_obj and playlist_obj["id"]
+ else None
+ )
+
+ async def get_album_tracks(self, prov_album_id) -> list[Track]:
+ """Get all album tracks for given album id."""
+ params = {"album_id": prov_album_id}
+ return [
+ await self._parse_track(item)
+ for item in await self._get_all_items("album/get", **params, key="tracks")
+ if (item and item["id"])
+ ]
+
+ async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ """Get all playlist tracks for given playlist id."""
+ count = 0
+ result = []
+ for item in await self._get_all_items(
+ "playlist/get",
+ key="tracks",
+ playlist_id=prov_playlist_id,
+ extra="tracks",
+ ):
+ if not (item and item["id"]):
+ continue
+ track = await self._parse_track(item)
+ # use count as position
+ track.position = count
+ result.append(track)
+ count += 1
+ return result
+
+ async def get_artist_albums(self, prov_artist_id) -> list[Album]:
+ """Get a list of albums for the given artist."""
+ endpoint = "artist/get"
+ return [
+ await self._parse_album(item)
+ for item in await self._get_all_items(
+ endpoint, key="albums", artist_id=prov_artist_id, extra="albums"
+ )
+ if (item and item["id"] and str(item["artist"]["id"]) == prov_artist_id)
+ ]
+
+ async def get_artist_toptracks(self, prov_artist_id) -> list[Track]:
+ """Get a list of most popular tracks for the given artist."""
+ result = await self._get_data(
+ "artist/get",
+ artist_id=prov_artist_id,
+ extra="playlists",
+ offset=0,
+ limit=25,
+ )
+ if result and result["playlists"]:
+ return [
+ await self._parse_track(item)
+ for item in result["playlists"][0]["tracks"]["items"]
+ if (item and item["id"])
+ ]
+ # fallback to search
+ artist = await self.get_artist(prov_artist_id)
+ searchresult = await self._get_data(
+ "catalog/search", query=artist.name, limit=25, type="tracks"
+ )
+ return [
+ await self._parse_track(item)
+ for item in searchresult["tracks"]["items"]
+ if (
+ item
+ and item["id"]
+ and "performer" in item
+ and str(item["performer"]["id"]) == str(prov_artist_id)
+ )
+ ]
+
+ async def get_similar_artists(self, prov_artist_id):
+ """Get similar artists for given artist."""
+ # https://www.qobuz.com/api.json/0.2/artist/getSimilarArtists?artist_id=220020&offset=0&limit=3
+
+ async def library_add(self, prov_item_id, media_type: MediaType):
+ """Add item to library."""
+ result = None
+ if media_type == MediaType.ARTIST:
+ result = await self._get_data("favorite/create", artist_id=prov_item_id)
+ elif media_type == MediaType.ALBUM:
+ result = await self._get_data("favorite/create", album_ids=prov_item_id)
+ elif media_type == MediaType.TRACK:
+ result = await self._get_data("favorite/create", track_ids=prov_item_id)
+ elif media_type == MediaType.PLAYLIST:
+ result = await self._get_data("playlist/subscribe", playlist_id=prov_item_id)
+ return result
+
+ async def library_remove(self, prov_item_id, media_type: MediaType):
+ """Remove item from library."""
+ result = None
+ if media_type == MediaType.ARTIST:
+ result = await self._get_data("favorite/delete", artist_ids=prov_item_id)
+ elif media_type == MediaType.ALBUM:
+ result = await self._get_data("favorite/delete", album_ids=prov_item_id)
+ elif media_type == MediaType.TRACK:
+ result = await self._get_data("favorite/delete", track_ids=prov_item_id)
+ elif media_type == MediaType.PLAYLIST:
+ playlist = await self.get_playlist(prov_item_id)
+ if playlist.is_editable:
+ result = await self._get_data("playlist/delete", playlist_id=prov_item_id)
+ else:
+ result = await self._get_data("playlist/unsubscribe", playlist_id=prov_item_id)
+ return result
+
+ async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None:
+ """Add track(s) to playlist."""
+ return await self._get_data(
+ "playlist/addTracks",
+ playlist_id=prov_playlist_id,
+ track_ids=",".join(prov_track_ids),
+ playlist_track_ids=",".join(prov_track_ids),
+ )
+
+ async def remove_playlist_tracks(
+ self, prov_playlist_id: str, positions_to_remove: tuple[int]
+ ) -> None:
+ """Remove track(s) from playlist."""
+ playlist_track_ids = set()
+ for track in await self.get_playlist_tracks(prov_playlist_id):
+ if track.position in positions_to_remove:
+ playlist_track_ids.add(str(track["playlist_track_id"]))
+ if len(playlist_track_ids) == positions_to_remove:
+ break
+ return await self._get_data(
+ "playlist/deleteTracks",
+ playlist_id=prov_playlist_id,
+ playlist_track_ids=",".join(playlist_track_ids),
+ )
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails:
+ """Return the content details for the given track when it will be streamed."""
+ streamdata = None
+ for format_id in [27, 7, 6, 5]:
+ # it seems that simply requesting for highest available quality does not work
+ # from time to time the api response is empty for this request ?!
+ result = await self._get_data(
+ "track/getFileUrl",
+ sign_request=True,
+ format_id=format_id,
+ track_id=item_id,
+ intent="stream",
+ )
+ if result and result.get("url"):
+ streamdata = result
+ break
+ if not streamdata:
+ raise MediaNotFoundError(f"Unable to retrieve stream details for {item_id}")
+ if streamdata["mime_type"] == "audio/mpeg":
+ content_type = ContentType.MPEG
+ elif streamdata["mime_type"] == "audio/flac":
+ content_type = ContentType.FLAC
+ else:
+ raise MediaNotFoundError(f"Unsupported mime type for {item_id}")
+ # report playback started as soon as the streamdetails are requested
+ self.mass.create_task(self._report_playback_started(streamdata))
+ return StreamDetails(
+ item_id=str(item_id),
+ provider=self.domain,
+ content_type=content_type,
+ duration=streamdata["duration"],
+ sample_rate=int(streamdata["sampling_rate"] * 1000),
+ bit_depth=streamdata["bit_depth"],
+ data=streamdata, # we need these details for reporting playback
+ expires=time.time() + 3600, # not sure about the real allowed value
+ direct=streamdata["url"],
+ callback=self._report_playback_stopped,
+ )
+
+ async def _report_playback_started(self, streamdata: dict) -> None:
+ """Report playback start to qobuz."""
+ # TODO: need to figure out if the streamed track is purchased by user
+ # https://www.qobuz.com/api.json/0.2/purchase/getUserPurchasesIds?limit=5000&user_id=xxxxxxx
+ # {"albums":{"total":0,"items":[]},
+ # "tracks":{"total":0,"items":[]},"user":{"id":xxxx,"login":"xxxxx"}}
+ device_id = self._user_auth_info["user"]["device"]["id"]
+ credential_id = self._user_auth_info["user"]["credential"]["id"]
+ user_id = self._user_auth_info["user"]["id"]
+ format_id = streamdata["format_id"]
+ timestamp = int(time.time())
+ events = [
+ {
+ "online": True,
+ "sample": False,
+ "intent": "stream",
+ "device_id": device_id,
+ "track_id": streamdata["track_id"],
+ "purchase": False,
+ "date": timestamp,
+ "credential_id": credential_id,
+ "user_id": user_id,
+ "local": False,
+ "format_id": format_id,
+ }
+ ]
+ await self._post_data("track/reportStreamingStart", data=events)
+
+ async def _report_playback_stopped(self, streamdetails: StreamDetails) -> None:
+ """Report playback stop to qobuz."""
+ user_id = self._user_auth_info["user"]["id"]
+ await self._get_data(
+ "/track/reportStreamingEnd",
+ user_id=user_id,
+ track_id=str(streamdetails.item_id),
+ duration=try_parse_int(streamdetails.seconds_streamed),
+ )
+
+ async def _parse_artist(self, artist_obj: dict):
+ """Parse qobuz artist object to generic layout."""
+ artist = Artist(
+ item_id=str(artist_obj["id"]), provider=self.domain, name=artist_obj["name"]
+ )
+ artist.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(artist_obj["id"]),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ url=artist_obj.get("url", f'https://open.qobuz.com/artist/{artist_obj["id"]}'),
+ )
+ )
+ if img := self.__get_image(artist_obj):
+ artist.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
+ if artist_obj.get("biography"):
+ artist.metadata.description = artist_obj["biography"].get("content")
+ return artist
+
+ async def _parse_album(self, album_obj: dict, artist_obj: dict = None):
+ """Parse qobuz album object to generic layout."""
+ if not artist_obj and "artist" not in album_obj:
+ # artist missing in album info, return full abum instead
+ return await self.get_album(album_obj["id"])
+ name, version = parse_title_and_version(album_obj["title"], album_obj.get("version"))
+ album = Album(
+ item_id=str(album_obj["id"]),
+ provider=self.domain,
+ name=name,
+ version=version,
+ )
+ album.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(album_obj["id"]),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ available=album_obj["streamable"] and album_obj["displayable"],
+ content_type=ContentType.FLAC,
+ sample_rate=album_obj["maximum_sampling_rate"] * 1000,
+ bit_depth=album_obj["maximum_bit_depth"],
+ url=album_obj.get("url", f'https://open.qobuz.com/album/{album_obj["id"]}'),
+ )
+ )
+
+ album.artist = await self._parse_artist(artist_obj or album_obj["artist"])
+ if (
+ album_obj.get("product_type", "") == "single"
+ or album_obj.get("release_type", "") == "single"
+ ):
+ album.album_type = AlbumType.SINGLE
+ elif album_obj.get("product_type", "") == "compilation" or "Various" in album.artist.name:
+ album.album_type = AlbumType.COMPILATION
+ elif (
+ album_obj.get("product_type", "") == "album"
+ or album_obj.get("release_type", "") == "album"
+ ):
+ album.album_type = AlbumType.ALBUM
+ if "genre" in album_obj:
+ album.metadata.genres = {album_obj["genre"]["name"]}
+ if img := self.__get_image(album_obj):
+ album.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
+ if len(album_obj["upc"]) == 13:
+ # qobuz writes ean as upc ?!
+ album.upc = album_obj["upc"][1:]
+ else:
+ album.upc = album_obj["upc"]
+ if "label" in album_obj:
+ album.metadata.label = album_obj["label"]["name"]
+ if album_obj.get("released_at"):
+ album.year = datetime.datetime.fromtimestamp(album_obj["released_at"]).year
+ if album_obj.get("copyright"):
+ album.metadata.copyright = album_obj["copyright"]
+ if album_obj.get("description"):
+ album.metadata.description = album_obj["description"]
+ return album
+
+ async def _parse_track(self, track_obj: dict):
+ """Parse qobuz track object to generic layout."""
+ # pylint: disable=too-many-branches
+ name, version = parse_title_and_version(track_obj["title"], track_obj.get("version"))
+ track = Track(
+ item_id=str(track_obj["id"]),
+ provider=self.domain,
+ name=name,
+ version=version,
+ disc_number=track_obj["media_number"],
+ track_number=track_obj["track_number"],
+ duration=track_obj["duration"],
+ position=track_obj.get("position"),
+ )
+ if track_obj.get("performer") and "Various " not in track_obj["performer"]:
+ artist = await self._parse_artist(track_obj["performer"])
+ if artist:
+ track.artists.append(artist)
+ # try to grab artist from album
+ if not track.artists and (
+ track_obj.get("album")
+ and track_obj["album"].get("artist")
+ and "Various " not in track_obj["album"]["artist"]
+ ):
+ artist = await self._parse_artist(track_obj["album"]["artist"])
+ if artist:
+ track.artists.append(artist)
+ if not track.artists:
+ # last resort: parse from performers string
+ for performer_str in track_obj["performers"].split(" - "):
+ role = performer_str.split(", ")[1]
+ name = performer_str.split(", ")[0]
+ if "artist" in role.lower():
+ artist = Artist(name, self.domain, name)
+ track.artists.append(artist)
+ # TODO: fix grabbing composer from details
+
+ if "album" in track_obj:
+ album = await self._parse_album(track_obj["album"])
+ if album:
+ track.album = album
+ if track_obj.get("isrc"):
+ track.isrc = track_obj["isrc"]
+ if track_obj.get("performers"):
+ track.metadata.performers = {x.strip() for x in track_obj["performers"].split("-")}
+ if track_obj.get("copyright"):
+ track.metadata.copyright = track_obj["copyright"]
+ if track_obj.get("audio_info"):
+ track.metadata.replaygain = track_obj["audio_info"]["replaygain_track_gain"]
+ if track_obj.get("parental_warning"):
+ track.metadata.explicit = True
+ if img := self.__get_image(track_obj):
+ track.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
+
+ track.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(track_obj["id"]),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ available=track_obj["streamable"] and track_obj["displayable"],
+ content_type=ContentType.FLAC,
+ sample_rate=track_obj["maximum_sampling_rate"] * 1000,
+ bit_depth=track_obj["maximum_bit_depth"],
+ url=track_obj.get("url", f'https://open.qobuz.com/track/{track_obj["id"]}'),
+ )
+ )
+ return track
+
+ async def _parse_playlist(self, playlist_obj):
+ """Parse qobuz playlist object to generic layout."""
+ playlist = Playlist(
+ item_id=str(playlist_obj["id"]),
+ provider=self.domain,
+ name=playlist_obj["name"],
+ owner=playlist_obj["owner"]["name"],
+ )
+ playlist.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(playlist_obj["id"]),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ url=playlist_obj.get(
+ "url", f'https://open.qobuz.com/playlist/{playlist_obj["id"]}'
+ ),
+ )
+ )
+ playlist.is_editable = (
+ playlist_obj["owner"]["id"] == self._user_auth_info["user"]["id"]
+ or playlist_obj["is_collaborative"]
+ )
+ if img := self.__get_image(playlist_obj):
+ playlist.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
+ playlist.metadata.checksum = str(playlist_obj["updated_at"])
+ return playlist
+
+ async def _auth_token(self):
+ """Login to qobuz and store the token."""
+ if self._user_auth_info:
+ return self._user_auth_info["user_auth_token"]
+ params = {
+ "username": self.config.get_value(CONF_USERNAME),
+ "password": self.config.get_value(CONF_PASSWORD),
+ "device_manufacturer_id": "music_assistant",
+ }
+ details = await self._get_data("user/login", **params)
+ if details and "user" in details:
+ self._user_auth_info = details
+ self.logger.info(
+ "Successfully logged in to Qobuz as %s", details["user"]["display_name"]
+ )
+ self.mass.metadata.preferred_language = details["user"]["country_code"]
+ return details["user_auth_token"]
+ return None
+
+ async def _get_all_items(self, endpoint, key="tracks", **kwargs):
+ """Get all items from a paged list."""
+ limit = 50
+ offset = 0
+ all_items = []
+ while True:
+ kwargs["limit"] = limit
+ kwargs["offset"] = offset
+ result = await self._get_data(endpoint, **kwargs)
+ offset += limit
+ if not result:
+ break
+ if not result.get(key) or not result[key].get("items"):
+ break
+ for item in result[key]["items"]:
+ item["position"] = len(all_items) + 1
+ all_items.append(item)
+ if len(result[key]["items"]) < limit:
+ break
+ return all_items
+
+ async def _get_data(self, endpoint, sign_request=False, **kwargs):
+ """Get data from api."""
+ # pylint: disable=too-many-branches
+ url = f"http://www.qobuz.com/api.json/0.2/{endpoint}"
+ headers = {"X-App-Id": app_var(0)}
+ if endpoint != "user/login":
+ auth_token = await self._auth_token()
+ if not auth_token:
+ self.logger.debug("Not logged in")
+ return None
+ headers["X-User-Auth-Token"] = auth_token
+ if sign_request:
+ signing_data = "".join(endpoint.split("/"))
+ keys = list(kwargs.keys())
+ keys.sort()
+ for key in keys:
+ signing_data += f"{key}{kwargs[key]}"
+ request_ts = str(time.time())
+ request_sig = signing_data + request_ts + app_var(1)
+ request_sig = str(hashlib.md5(request_sig.encode()).hexdigest())
+ kwargs["request_ts"] = request_ts
+ kwargs["request_sig"] = request_sig
+ kwargs["app_id"] = app_var(0)
+ kwargs["user_auth_token"] = await self._auth_token()
+ async with self._throttler:
+ async with self.mass.http_session.get(
+ url, headers=headers, params=kwargs, verify_ssl=False
+ ) as response:
+ try:
+ # make sure status is 200
+ assert response.status == 200
+ result = await response.json()
+ # check for error in json
+ if error := result.get("error"):
+ raise ValueError(error)
+ if result.get("status") and "error" in result["status"]:
+ raise ValueError(result["status"])
+ except (
+ aiohttp.ContentTypeError,
+ JSONDecodeError,
+ AssertionError,
+ ValueError,
+ ) as err:
+ text = await response.text()
+ self.logger.exception(
+ "Error while processing %s: %s", endpoint, text, exc_info=err
+ )
+ return None
+ return result
+
+ async def _post_data(self, endpoint, params=None, data=None):
+ """Post data to api."""
+ if not params:
+ params = {}
+ if not data:
+ data = {}
+ url = f"http://www.qobuz.com/api.json/0.2/{endpoint}"
+ params["app_id"] = app_var(0)
+ params["user_auth_token"] = await self._auth_token()
+ async with self.mass.http_session.post(
+ url, params=params, json=data, verify_ssl=False
+ ) as response:
+ try:
+ result = await response.json()
+ # check for error in json
+ if error := result.get("error"):
+ raise ValueError(error)
+ if result.get("status") and "error" in result["status"]:
+ raise ValueError(result["status"])
+ except (
+ aiohttp.ContentTypeError,
+ JSONDecodeError,
+ AssertionError,
+ ValueError,
+ ) as err:
+ text = await response.text()
+ self.logger.exception("Error while processing %s: %s", endpoint, text, exc_info=err)
+ return None
+ return result
+
+ def __get_image(self, obj: dict) -> str | None:
+ """Try to parse image from Qobuz media object."""
+ if obj.get("image"):
+ for key in ["extralarge", "large", "medium", "small"]:
+ if obj["image"].get(key):
+ if "2a96cbd8b46e442fc41c2b86b821562f" in obj["image"][key]:
+ continue
+ return obj["image"][key]
+ if obj.get("images300"):
+ # playlists seem to use this strange format
+ return obj["images300"][0]
+ if obj.get("album"):
+ return self.__get_image(obj["album"])
+ if obj.get("artist"):
+ return self.__get_image(obj["artist"])
+ return None
--- /dev/null
+{
+ "type": "music",
+ "domain": "qobuz",
+ "name": "Qobuz",
+ "description": "Qobuz support for Music Assistant: Lossless (and hi-res) Music provider.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ {
+ "key": "username",
+ "type": "string",
+ "label": "Username"
+ },
+ {
+ "key": "password",
+ "type": "password",
+ "label": "Password"
+ }
+ ],
+
+ "requirements": [],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/817",
+ "multi_instance": true
+}
--- /dev/null
+"""Base/builtin provider with support for players using slimproto."""
+from __future__ import annotations
+
+import asyncio
+import time
+import urllib.parse
+from collections import deque
+from collections.abc import Callable, Generator
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from aioslimproto.client import PlayerState as SlimPlayerState
+from aioslimproto.client import SlimClient
+from aioslimproto.client import TransitionType as SlimTransition
+from aioslimproto.const import EventType as SlimEventType
+from aioslimproto.discovery import start_discovery
+
+from music_assistant.common.helpers.util import select_free_port
+from music_assistant.common.models.config_entries import ConfigEntry
+from music_assistant.common.models.enums import (
+ ConfigEntryType,
+ ContentType,
+ PlayerFeature,
+ PlayerState,
+ PlayerType,
+)
+from music_assistant.common.models.errors import PlayerUnavailableError, QueueEmpty
+from music_assistant.common.models.player import DeviceInfo, Player
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.server.models.player_provider import PlayerProvider
+from music_assistant.server.providers.json_rpc import parse_args
+
+if TYPE_CHECKING:
+ from music_assistant.common.models.config_entries import PlayerConfig
+ from music_assistant.server.providers.json_rpc import JSONRPCApi
+
+# sync constants
+MIN_DEVIATION_ADJUST = 10 # 10 milliseconds
+MAX_DEVIATION_ADJUST = 20000 # 10 seconds
+MIN_REQ_PLAYPOINTS = 8 # we need at least 8 measurements
+
+# TODO: Implement display support
+
+STATE_MAP = {
+ SlimPlayerState.BUFFERING: PlayerState.PLAYING,
+ SlimPlayerState.PAUSED: PlayerState.PAUSED,
+ SlimPlayerState.PLAYING: PlayerState.PLAYING,
+ SlimPlayerState.STOPPED: PlayerState.IDLE,
+}
+
+
+@dataclass
+class SyncPlayPoint:
+ """Simple structure to describe a Sync Playpoint."""
+
+ timestamp: float
+ item_id: str
+ diff: int
+
+
+CONF_SYNC_ADJUST = "sync_adjust"
+
+SLIM_PLAYER_CONFIG_ENTRIES = (
+ ConfigEntry(
+ key=CONF_SYNC_ADJUST,
+ type=ConfigEntryType.INTEGER,
+ range=(0, 1500),
+ default_value=0,
+ label="Correct synchronization delay",
+ description="If this player is playing audio synced with other players "
+ "and you always hear the audio too late on this player, you can shift the audio a bit.",
+ advanced=True,
+ ),
+)
+
+
+class SlimprotoProvider(PlayerProvider):
+ """Base/builtin provider for players using the SLIM protocol (aka slimproto)."""
+
+ _socket_servers: tuple[asyncio.Server | asyncio.BaseTransport]
+ _socket_clients: dict[str, SlimClient]
+ _sync_playpoints: dict[str, deque[SyncPlayPoint]]
+ _sync_adjusts: dict[str, int]
+ _virtual_providers: dict[str, tuple[Callable, Callable]]
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self._socket_clients = {}
+ self._sync_playpoints = {}
+ self._sync_adjusts = {}
+ self._virtual_providers = {}
+ # autodiscovery of the slimproto server does not work
+ # when the port is not the default (3483) so we hardcode it for now
+ slimproto_port = 3483
+ cli_port = await select_free_port(9090, 9190)
+ self.logger.info("Starting SLIMProto server on port %s", slimproto_port)
+ self._socket_servers = (
+ # start slimproto server
+ await asyncio.start_server(self._create_client, "0.0.0.0", slimproto_port),
+ # setup discovery
+ await start_discovery(slimproto_port, cli_port, self.mass.port),
+ # setup (telnet) cli for players requesting basic info on that port
+ await asyncio.start_server(self._handle_cli_client, "0.0.0.0", cli_port),
+ )
+
+ async def close(self) -> None:
+ """Handle close/cleanup of the provider."""
+ if hasattr(self, "_socket_clients"):
+ for client in list(self._socket_clients.values()):
+ client.disconnect()
+ self._socket_clients = {}
+ if hasattr(self, "_socket_servers"):
+ for _server in self._socket_servers:
+ _server.close()
+ self._socket_servers = None
+
+ async def _create_client(
+ self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
+ ) -> None:
+ """Create player from new connection on the socket."""
+ if self.mass.closing:
+ return
+ addr = writer.get_extra_info("peername")
+ self.logger.debug("Socket client connected: %s", addr)
+
+ def client_callback(
+ event_type: SlimEventType, client: SlimClient, data: Any = None # noqa: ARG001
+ ):
+ if event_type == SlimEventType.PLAYER_DISCONNECTED:
+ self._handle_disconnected(client)
+ return
+
+ if event_type == SlimEventType.PLAYER_CONNECTED:
+ self._handle_connected(client)
+
+ if event_type == SlimEventType.PLAYER_DECODER_READY:
+ self.mass.create_task(self._handle_decoder_ready(client))
+ return
+
+ if event_type == SlimEventType.PLAYER_HEARTBEAT:
+ self._handle_player_heartbeat(client)
+ return
+
+ # ignore some uninteresting events
+ if event_type in (
+ SlimEventType.PLAYER_CLI_EVENT,
+ SlimEventType.PLAYER_DECODER_ERROR,
+ ):
+ return
+
+ # forward player update to MA player controller
+ self._handle_player_update(client)
+
+ # construct SlimClient from socket client
+ SlimClient(reader, writer, client_callback)
+
+ def get_player_config_entries(self, player_id: str) -> tuple[ConfigEntry]: # noqa: ARG002
+ """Return all (provider/player specific) Config Entries for the given player (if any)."""
+ return SLIM_PLAYER_CONFIG_ENTRIES
+
+ def on_player_config_changed(self, config: PlayerConfig) -> None:
+ """Call (by config manager) when the configuration of a player changes."""
+ # during synced playback this value is requested multiple times a second,
+ # so we cache it in a quick lookup dict
+ self._sync_adjusts[config.player_id] = config.get_value(CONF_SYNC_ADJUST)
+
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player."""
+ # forward command to player and any connected sync child's
+ for client in self._get_sync_clients(player_id):
+ if client.state == SlimPlayerState.STOPPED:
+ continue
+ await client.stop()
+ # workaround: some players do not send an event when playback stopped
+ client._process_stat_stmu(b"")
+
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY command to given player."""
+ # forward command to player and any connected sync child's
+ for client in self._get_sync_clients(player_id):
+ if client.state not in (
+ SlimPlayerState.PAUSED,
+ SlimPlayerState.BUFFERING,
+ ):
+ continue
+ await client.play()
+
+ async def cmd_play_media(
+ self,
+ player_id: str,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Send PLAY MEDIA command to given player.
+
+ This is called when the Queue wants the player to start playing a specific QueueItem.
+ The player implementation can decide how to process the request, such as playing
+ queue items one-by-one or enqueue all/some items.
+
+ - player_id: player_id of the player to handle the command.
+ - queue_item: the QueueItem to start playing on the player.
+ - seek_position: start playing from this specific position.
+ - fade_in: fade in the music at start (e.g. at resume).
+ """
+ # send stop first
+ await self.cmd_stop(player_id)
+
+ player = self.mass.players.get(player_id)
+ # make sure that the (master) player is powered
+ # powering any client players must be done in other ways
+ if not player.synced_to:
+ await self._socket_clients[player_id].power(True)
+
+ # forward command to player and any connected sync child's
+ for client in self._get_sync_clients(player_id):
+ await self._handle_play_media(
+ client,
+ queue_item=queue_item,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ send_flush=True,
+ flow_mode=flow_mode,
+ )
+
+ async def _handle_play_media(
+ self,
+ client: SlimClient,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ send_flush: bool = True,
+ crossfade: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Handle PlayMedia on slimproto player(s)."""
+ player_id = client.player_id
+ # pick codec based on capabilities
+ codec_map = (
+ ("flc", ContentType.FLAC),
+ ("pcm", ContentType.PCM),
+ ("mp3", ContentType.MP3),
+ )
+ for fmt, fmt_type in codec_map:
+ if fmt in client.supported_codecs:
+ content_type = fmt_type
+ break
+ else:
+ self.logger.debug("Could not auto determine supported codec, fallback to PCM")
+ content_type = ContentType.PCM
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=queue_item,
+ player_id=player_id,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ content_type=content_type,
+ flow_mode=flow_mode,
+ )
+ await client.play_url(
+ url=url,
+ mime_type=f"audio/{content_type.value}",
+ metadata={"item_id": queue_item.queue_item_id},
+ send_flush=send_flush,
+ transition=SlimTransition.CROSSFADE if crossfade else SlimTransition.NONE,
+ transition_duration=10 if crossfade else 0,
+ )
+
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player."""
+ # forward command to player and any connected sync child's
+ for client in self._get_sync_clients(player_id):
+ if client.state not in (
+ SlimPlayerState.PLAYING,
+ SlimPlayerState.BUFFERING,
+ ):
+ continue
+ await client.pause()
+
+ async def cmd_power(self, player_id: str, powered: bool) -> None:
+ """Send POWER command to given player."""
+ if client := self._socket_clients.get(player_id):
+ await client.power(powered)
+ # TODO: unsync client at poweroff if synced
+
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player."""
+ if client := self._socket_clients.get(player_id):
+ await client.volume_set(volume_level)
+
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME MUTE command to given player."""
+ if client := self._socket_clients.get(player_id):
+ await client.mute(muted)
+
+ async def cmd_sync(self, player_id: str, target_player: str) -> None:
+ """Handle SYNC command for given player."""
+ child_player = self.mass.players.get(player_id)
+ assert child_player
+ parent_player = self.mass.players.get(target_player)
+ assert parent_player
+ parent_player.group_childs.append(child_player.player_id)
+ child_player.synced_to = parent_player.player_id
+ self.mass.players.update(child_player.player_id)
+ self.mass.players.update(parent_player.player_id)
+ if parent_player.state == PlayerState.PLAYING:
+ # playback needs to be restarted to get all players in sync
+ # TODO: If there is any need, we could make this smarter where the new
+ # sync child waits for the next track.
+ await self.mass.players.queues.resume(parent_player.player_id)
+
+ async def cmd_unsync(self, player_id: str) -> None:
+ """Handle UNSYNC command for given player."""
+ child_player = self.mass.players.get(player_id)
+ parent_player = self.mass.players.get(child_player.synced_to)
+ if child_player.state == PlayerState.PLAYING:
+ await self.cmd_stop(child_player.player_id)
+ child_player.synced_to = None
+ parent_player.group_childs.remove(child_player.player_id)
+ self.mass.players.update(child_player.player_id)
+ self.mass.players.update(parent_player.player_id)
+
+ def register_virtual_provider(
+ self,
+ player_model: str,
+ register_callback: Callable,
+ update_callback: Callable,
+ ) -> None:
+ """Register a virtual provider based on slimproto, such as the airplay bridge."""
+ self._virtual_providers[player_model] = (
+ register_callback,
+ update_callback,
+ )
+
+ def _handle_player_update(self, client: SlimClient) -> None:
+ """Process SlimClient update/add to Player controller."""
+ player_id = client.player_id
+ virtual_provider_info = self._virtual_providers.get(client.device_model)
+ try:
+ player = self.mass.players.get(player_id, raise_unavailable=False)
+ except PlayerUnavailableError:
+ # player does not yet exist, create it
+ player = Player(
+ player_id=player_id,
+ provider=self.domain,
+ type=PlayerType.PLAYER,
+ name=client.name,
+ available=True,
+ powered=client.powered,
+ device_info=DeviceInfo(
+ model=client.device_model,
+ address=client.device_address,
+ manufacturer=client.device_type,
+ ),
+ supported_features=(
+ PlayerFeature.ACCURATE_TIME,
+ PlayerFeature.POWER,
+ PlayerFeature.SYNC,
+ PlayerFeature.VOLUME_MUTE,
+ PlayerFeature.VOLUME_SET,
+ ),
+ max_sample_rate=int(client.max_sample_rate),
+ )
+ if virtual_provider_info:
+ # if this player is part of a virtual provider run the callback
+ virtual_provider_info[0](player)
+ self.mass.players.register(player)
+
+ # update player state on player events
+ player.available = True
+ player.current_url = client.current_url
+ player.current_item_id = (
+ client.current_metadata["item_id"] if client.current_metadata else None
+ )
+ player.name = client.name
+ player.powered = client.powered
+ player.state = STATE_MAP[client.state]
+ player.volume_level = client.volume_level
+ player.volume_muted = client.muted
+ # set all existing player ids in `can_sync_with` field
+ player.can_sync_with = tuple(
+ x.player_id for x in self._socket_clients.values() if x.player_id != player_id
+ )
+ if virtual_provider_info:
+ # if this player is part of a virtual provider run the callback
+ virtual_provider_info[1](player)
+ self.mass.players.update(player_id)
+
+ def _handle_player_heartbeat(self, client: SlimClient) -> None:
+ """Process SlimClient elapsed_time update."""
+ if client.state != SlimPlayerState.PLAYING:
+ # ignore server heartbeats
+ return
+
+ player = self.mass.players.get(client.player_id)
+ sync_master_id = player.synced_to
+
+ # elapsed time change on the time will be auto picked up
+ # by the player manager.
+ player.elapsed_time = client.elapsed_seconds
+ player.elapsed_time_last_updated = time.time()
+
+ # handle sync
+ if not sync_master_id:
+ # we only correct sync child's, not the sync master itself
+ return
+ if sync_master_id not in self._socket_clients:
+ return # just here as a guard as bad things can happen
+
+ sync_master = self._socket_clients[sync_master_id]
+
+ # we collect a few playpoints of the player to determine
+ # average lag/drift so we can adjust accordingly
+ sync_playpoints = self._sync_playpoints.setdefault(
+ client.player_id, deque(maxlen=MIN_REQ_PLAYPOINTS)
+ )
+
+ # make sure client has loaded the same track as sync master
+ client_item_id = client.current_metadata["item_id"] if client.current_metadata else None
+ master_item_id = (
+ sync_master.current_metadata["item_id"] if sync_master.current_metadata else None
+ )
+ if client_item_id != master_item_id:
+ sync_playpoints.clear()
+ return
+
+ last_playpoint = sync_playpoints[-1] if sync_playpoints else None
+ if last_playpoint and (time.time() - last_playpoint.timestamp) > 10:
+ # last playpoint is too old, invalidate
+ sync_playpoints.clear()
+ if last_playpoint and last_playpoint.item_id != client.current_metadata["item_id"]:
+ # item has changed, invalidate
+ sync_playpoints.clear()
+
+ diff = int(
+ self._get_corrected_elapsed_milliseconds(sync_master)
+ - self._get_corrected_elapsed_milliseconds(client)
+ )
+
+ if abs(diff) > MAX_DEVIATION_ADJUST:
+ # safety guard when player is transitioning or something is just plain wrong
+ sync_playpoints.clear()
+ return
+
+ # we can now append the current playpoint to our list
+ sync_playpoints.append(SyncPlayPoint(time.time(), client.current_metadata["item_id"], diff))
+
+ if len(sync_playpoints) < MIN_REQ_PLAYPOINTS:
+ return
+
+ # if we have enough playpoints, get the average value
+ prev_diffs = [x.diff for x in sync_playpoints]
+ avg_diff = sum(prev_diffs) / len(prev_diffs)
+ delta = abs(avg_diff)
+
+ if delta < MIN_DEVIATION_ADJUST:
+ return
+
+ # handle player lagging behind, fix with skip_ahead
+ if avg_diff > 0:
+ self.logger.debug("%s resync: skipAhead %sms", player.display_name, delta)
+ sync_playpoints.clear()
+ asyncio.create_task(self._skip_over(client.player_id, delta))
+ else:
+ # handle player is drifting too far ahead, use pause_for to adjust
+ self.logger.debug("%s resync: pauseFor %sms", player.display_name, delta)
+ sync_playpoints.clear()
+ asyncio.create_task(self._pause_for(client.player_id, delta))
+
+ async def _handle_decoder_ready(self, client: SlimClient) -> None:
+ """Handle decoder ready event, player is ready for the next track."""
+ if not client.current_metadata:
+ return
+ try:
+ next_item, crossfade = self.mass.players.queues.player_ready_for_next_track(
+ client.player_id, client.current_metadata["item_id"]
+ )
+ await self._handle_play_media(client, next_item, send_flush=False, crossfade=crossfade)
+ except QueueEmpty:
+ pass
+
+ def _handle_connected(self, client: SlimClient) -> None:
+ """Handle a client connected event."""
+ player_id = client.player_id
+ prev = self._socket_clients.pop(player_id, None)
+ if prev is not None:
+ # player reconnected while we did not yet cleanup the old socket
+ prev.disconnect()
+ self._socket_clients[player_id] = client
+ if prev is None:
+ # update existing players so they can update their `can_sync_with` field
+ for client in self._socket_clients.values():
+ self._handle_player_update(client)
+ # precache player config
+ self.on_player_config_changed(self.mass.config.get_player_config(player_id))
+
+ def _handle_disconnected(self, client: SlimClient) -> None:
+ """Handle a client disconnected event."""
+ player_id = client.player_id
+ prev = self._socket_clients.pop(player_id, None)
+ if prev is None:
+ # already cleaned up
+ return
+ if player := self.mass.players.get(player_id):
+ player.available = False
+ self.mass.players.update(player_id)
+
+ async def _pause_for(self, client_id: str, millis: int) -> None:
+ """Handle pause for x amount of time to help with syncing."""
+ client = self._socket_clients[client_id]
+ # https://wiki.slimdevices.com/index.php/SlimProto_TCP_protocol.html#u.2C_p.2C_a_.26_t_commands_and_replay_gain_field§
+ await client.send_strm(b"p", replay_gain=int(millis))
+
+ async def _skip_over(self, client_id: str, millis: int) -> None:
+ """Handle skip for x amount of time to help with syncing."""
+ client = self._socket_clients[client_id]
+ # https://wiki.slimdevices.com/index.php/SlimProto_TCP_protocol.html#u.2C_p.2C_a_.26_t_commands_and_replay_gain_field
+ await client.send_strm(b"a", replay_gain=int(millis))
+
+ def _get_sync_clients(self, player_id: str) -> Generator[SlimClient]:
+ """Get all sync clients for a player."""
+ player = self.mass.players.get(player_id)
+ for child_id in [player.player_id] + player.group_childs:
+ if client := self._socket_clients.get(child_id):
+ if not player_id and not client.powered:
+ # only powered child's
+ continue
+ yield client
+
+ def _get_corrected_elapsed_milliseconds(self, client: SlimClient) -> int:
+ """Return corrected elapsed milliseconds."""
+ sync_delay = self._sync_adjusts[client.player_id]
+ if sync_delay != 0:
+ return client.elapsed_milliseconds - sync_delay
+ return client.elapsed_milliseconds
+
+ async def _handle_cli_client(
+ self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
+ ) -> None:
+ """Handle new connection on the legacy CLI."""
+ # https://raw.githubusercontent.com/Logitech/slimserver/public/7.8/HTML/EN/html/docs/cli-api.html
+ self.logger.info("Client connected on Telnet CLI")
+ try:
+ while True:
+ raw_request = await reader.readline()
+ raw_request = raw_request.strip().decode("utf-8")
+ # request comes in as url encoded strings, separated by space
+ raw_params = [urllib.parse.unquote(x) for x in raw_request.split(" ")]
+ # the first param is either a macaddress or a command
+ if ":" in raw_params[0]:
+ # assume this is a mac address (=player_id)
+ player_id = raw_params[0]
+ command = raw_params[1]
+ command_params = raw_params[2:]
+ else:
+ player_id = ""
+ command = raw_params[0]
+ command_params = raw_params[1:]
+
+ args, kwargs = parse_args(command_params)
+
+ response: str = raw_request
+
+ # check if we have a handler for this command
+ # note that we only have support for very limited commands
+ # just enough for compatibility with players but not to be used as api
+ # with 3rd party tools!
+ json_rpc: JSONRPCApi = self.mass.get_provider("json_rpc")
+ assert json_rpc is not None
+ if handler := getattr(json_rpc, f"_handle_{command}", None):
+ self.logger.debug(
+ "Handling CLI-request (player: %s command: %s - args: %s - kwargs: %s)",
+ player_id,
+ command,
+ str(args),
+ str(kwargs),
+ )
+ cmd_result: list[str] = handler(player_id, *args, **kwargs)
+ if isinstance(cmd_result, dict):
+ result_parts = dict_to_strings(cmd_result)
+ result_str = " ".join(urllib.parse.quote(x) for x in result_parts)
+ elif not cmd_result:
+ result_str = ""
+ else:
+ result_str = str(cmd_result)
+ response += " " + result_str
+ else:
+ self.logger.warning(
+ "No handler for %s (player: %s - args: %s - kwargs: %s)",
+ command,
+ player_id,
+ str(args),
+ str(kwargs),
+ )
+ # echo back the request and the result (if any)
+ response += "\n"
+ writer.write(response.encode("utf-8"))
+ await writer.drain()
+ except Exception as err:
+ self.logger.debug("Error handling CLI command", exc_info=err)
+ finally:
+ self.logger.debug("Client disconnected from Telnet CLI")
+
+
+def dict_to_strings(source: dict) -> list[str]:
+ """Convert dict to key:value strings (used in slimproto cli)."""
+ result: list[str] = []
+
+ for key, value in source.items():
+ if value is None or value == "":
+ continue
+ if isinstance(value, list):
+ for subval in value:
+ if isinstance(subval, dict):
+ result += dict_to_strings(subval)
+ else:
+ result.append(str(subval))
+ elif isinstance(value, dict):
+ result += dict_to_strings(subval)
+ else:
+ result.append(f"{key}:{str(value)}")
+ return result
--- /dev/null
+{
+ "type": "player",
+ "domain": "slimproto",
+ "name": "Slimproto",
+ "description": "Support for slimproto based players (e.g. squeezebox, squeezelite). Music Assistant emulates a Logitech Media Server.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ ],
+ "requirements": ["aioslimproto==2.2.0"],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""Sample Player provider for Music Assistant."""
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+import xml.etree.ElementTree as ET # noqa: N817
+from contextlib import suppress
+from dataclasses import dataclass, field
+from typing import Any
+
+import soco
+from soco.events_base import Event as SonosEvent
+from soco.events_base import SubscriptionBase
+from soco.groups import ZoneGroup
+
+from music_assistant.common.models.enums import (
+ ContentType,
+ MediaType,
+ PlayerFeature,
+ PlayerState,
+ PlayerType,
+)
+from music_assistant.common.models.errors import PlayerUnavailableError, QueueEmpty
+from music_assistant.common.models.player import DeviceInfo, Player
+from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.server.helpers.didl_lite import create_didl_metadata
+from music_assistant.server.models.player_provider import PlayerProvider
+
+PLAYER_FEATURES = (
+ PlayerFeature.SET_MEMBERS,
+ PlayerFeature.SYNC,
+ PlayerFeature.VOLUME_MUTE,
+ PlayerFeature.VOLUME_SET,
+)
+PLAYER_CONFIG_ENTRIES = tuple() # we don't have any player config entries (for now)
+
+
+@dataclass
+class SonosPlayer:
+ """Wrapper around Sonos/SoCo with some additional attributes."""
+
+ player_id: str
+ soco: soco.SoCo
+ player: Player
+ is_stereo_pair: bool = False
+ next_item: str | None = None
+ elapsed_time: int = 0
+ current_item_id: str | None = None
+ radio_mode_started: float | None = None
+
+ subscriptions: list[SubscriptionBase] = field(default_factory=list)
+
+ transport_info: dict = field(default_factory=dict)
+ track_info: dict = field(default_factory=dict)
+ speaker_info: dict = field(default_factory=dict)
+ rendering_control_info: dict = field(default_factory=dict)
+ group_info: ZoneGroup | None = None
+
+ speaker_info_updated: float = 0.0
+ transport_info_updated: float = 0.0
+ track_info_updated: float = 0.0
+ rendering_control_info_updated: float = 0.0
+ group_info_updated: float = 0.0
+
+ def update_info(
+ self,
+ update_transport_info: bool = False,
+ update_track_info: bool = False,
+ update_speaker_info: bool = False,
+ update_rendering_control_info: bool = False,
+ update_group_info: bool = False,
+ ):
+ """Poll all info from player (must be run in executor thread)."""
+ # transport info
+ if update_transport_info:
+ transport_info = self.soco.get_current_transport_info()
+ if transport_info.get("current_transport_state") != "TRANSITIONING":
+ self.transport_info = transport_info
+ self.transport_info_updated = time.time()
+ # track info
+ if update_track_info:
+ self.track_info = self.soco.get_current_track_info()
+ self.track_info_updated = time.time()
+ self.elapsed_time = _timespan_secs(self.track_info["position"]) or 0
+
+ if track_metadata := self.track_info.get("metadata"):
+ # extract queue_item_id from metadata xml
+ try:
+ xml_root = ET.XML(track_metadata)
+ for match in xml_root.iter("{http://purl.org/dc/elements/1.1/}queueItemId"):
+ item_id = match.text
+ self.current_item_id = item_id
+ break
+ except (ET.ParseError, AttributeError):
+ self.current_item_id = None
+
+ if (
+ self.current_item_id is None
+ and "/stream/" in self.track_info["uri"]
+ and self.player_id in self.track_info["uri"]
+ ):
+ # try to extract the item id from the uri
+ self.current_item_id = self.track_info["uri"].rsplit("/")[-2]
+
+ # speaker info
+ if update_speaker_info:
+ self.speaker_info = self.soco.get_speaker_info()
+ self.speaker_info_updated = time.time()
+ # rendering control info
+ if update_rendering_control_info:
+ self.rendering_control_info["volume"] = self.soco.volume
+ self.rendering_control_info["mute"] = self.soco.mute
+ self.rendering_control_info_updated = time.time()
+ # group info
+ if update_group_info:
+ self.group_info = self.soco.group
+ self.group_info_updated = time.time()
+
+ def update_attributes(self):
+ """Update attributes of the MA Player from soco.SoCo state."""
+ # generic attributes (speaker_info)
+ self.player.name = self.speaker_info["zone_name"]
+ self.player.volume_level = int(self.rendering_control_info["volume"])
+ self.player.volume_muted = self.rendering_control_info["mute"]
+
+ # transport info (playback state)
+ current_transport_state = self.transport_info["current_transport_state"]
+ new_state = _convert_state(current_transport_state)
+ self.player.state = new_state
+
+ # media info (track info)
+ self.player.current_url = self.track_info["uri"]
+ self.player.current_item_id = self.current_item_id
+
+ if self.radio_mode_started is not None:
+ # sonos reports bullshit elapsed time while playing radio,
+ # trying to be "smart" and resetting the counter when new ICY metadata is detected
+ if new_state == PlayerState.PLAYING:
+ now = time.time()
+ self.player.elapsed_time = int(now - self.radio_mode_started + 0.5)
+ self.player.elapsed_time_last_updated = now
+ else:
+ self.player.elapsed_time = self.elapsed_time
+ self.player.elapsed_time_last_updated = self.track_info_updated
+
+ # zone topology (syncing/grouping) details
+ if self.group_info and self.group_info.coordinator.uid == self.player_id:
+ # this player is the sync leader
+ self.player.synced_to = None
+ self.player.group_childs = {
+ x.uid for x in self.group_info.members if x.uid != self.player_id
+ }
+ elif self.group_info and self.group_info.coordinator:
+ # player is synced to
+ self.player.synced_to = self.group_info.coordinator.uid
+
+ async def check_poll(self) -> None:
+ """Check if any of the endpoints needs to be polled for info."""
+ cur_time = time.time()
+ update_transport_info = (cur_time - self.transport_info_updated) > 30
+ update_track_info = self.transport_info.get("current_transport_state") == "PLAYING" or (
+ (cur_time - self.track_info_updated) > 300
+ )
+ update_speaker_info = (cur_time - self.speaker_info_updated) > 300
+ update_rendering_control_info = (cur_time - self.rendering_control_info_updated) > 30
+ update_group_info = (cur_time - self.group_info_updated) > 300
+
+ if not (
+ update_transport_info
+ or update_track_info
+ or update_speaker_info
+ or update_rendering_control_info
+ or update_group_info
+ ):
+ return
+
+ await asyncio.to_thread(
+ self.update_info,
+ update_transport_info,
+ update_track_info,
+ update_speaker_info,
+ update_rendering_control_info,
+ update_group_info,
+ )
+
+
+class SonosPlayerProvider(PlayerProvider):
+ """Sonos Player provider."""
+
+ sonosplayers: dict[str, SonosPlayer]
+ _discovery_running: bool
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self.sonosplayers = {}
+ self._discovery_running = False
+ # silence the soco logger a bit
+ logging.getLogger("soco").setLevel(logging.INFO)
+ logging.getLogger("urllib3.connectionpool").setLevel(logging.INFO)
+ self.mass.create_task(self._run_discovery())
+
+ async def close(self) -> None:
+ """Handle close/cleanup of the provider."""
+ if hasattr(self, "sonosplayers"):
+ for player in self.sonosplayers.values():
+ player.soco.end_direct_control_session
+
+ async def cmd_stop(self, player_id: str) -> None:
+ """Send STOP command to given player."""
+ sonos_player = self.sonosplayers[player_id]
+ if not sonos_player.soco.is_coordinator:
+ self.logger.debug(
+ "Ignore STOP command for %s: Player is synced to another player.",
+ player_id,
+ )
+ return
+ await asyncio.to_thread(sonos_player.soco.stop)
+ await asyncio.to_thread(sonos_player.soco.clear_queue)
+
+ async def cmd_play(self, player_id: str) -> None:
+ """Send PLAY command to given player."""
+ sonos_player = self.sonosplayers[player_id]
+ if not sonos_player.soco.is_coordinator:
+ self.logger.debug(
+ "Ignore PLAY command for %s: Player is synced to another player.",
+ player_id,
+ )
+ return
+ await asyncio.to_thread(sonos_player.soco.play)
+
+ async def cmd_play_media(
+ self,
+ player_id: str,
+ queue_item: QueueItem,
+ seek_position: int = 0,
+ fade_in: bool = False,
+ flow_mode: bool = False,
+ ) -> None:
+ """Send PLAY MEDIA command to given player."""
+ sonos_player = self.sonosplayers[player_id]
+ if not sonos_player.soco.is_coordinator:
+ self.logger.debug(
+ "Ignore PLAY_MEDIA command for %s: Player is synced to another player.",
+ player_id,
+ )
+ return
+ # always stop and clear queue first
+ sonos_player.next_item = None
+ await asyncio.to_thread(sonos_player.soco.stop)
+ await asyncio.to_thread(sonos_player.soco.clear_queue)
+
+ radio_mode = flow_mode or not queue_item.duration
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=queue_item,
+ player_id=sonos_player.player_id,
+ seek_position=seek_position,
+ fade_in=fade_in,
+ content_type=ContentType.MP3 if radio_mode else ContentType.FLAC,
+ flow_mode=flow_mode,
+ )
+ if radio_mode:
+ sonos_player.radio_mode_started = time.time()
+ url = url.replace("http", "x-rincon-mp3radio")
+ metadata = create_didl_metadata(url, queue_item, flow_mode)
+ # sonos does multiple get requests if no duration is known
+ # our stream engine does not like that, hence the workaround
+ self.mass.streams.workaround_players.add(sonos_player.player_id)
+ await asyncio.to_thread(sonos_player.soco.play_uri, url, meta=metadata)
+ else:
+ sonos_player.radio_mode_started = None
+ await self._enqueue_item(
+ sonos_player, queue_item=queue_item, url=url, flow_mode=flow_mode
+ )
+ await asyncio.to_thread(sonos_player.soco.play_from_queue, 0)
+
+ async def cmd_pause(self, player_id: str) -> None:
+ """Send PAUSE command to given player."""
+ sonos_player = self.sonosplayers[player_id]
+ if not sonos_player.soco.is_coordinator:
+ self.logger.debug(
+ "Ignore PLAY command for %s: Player is synced to another player.",
+ player_id,
+ )
+ return
+ await asyncio.to_thread(sonos_player.soco.pause)
+
+ async def cmd_volume_set(self, player_id: str, volume_level: int) -> None:
+ """Send VOLUME_SET command to given player."""
+
+ def set_volume_level(player_id: str, volume_level: int) -> None:
+ sonos_player = self.sonosplayers[player_id]
+ sonos_player.soco.volume = volume_level
+
+ await asyncio.to_thread(set_volume_level, player_id, volume_level)
+
+ async def cmd_volume_mute(self, player_id: str, muted: bool) -> None:
+ """Send VOLUME MUTE command to given player."""
+
+ def set_volume_mute(player_id: str, muted: bool) -> None:
+ sonos_player = self.sonosplayers[player_id]
+ sonos_player.soco.mute = muted
+
+ await asyncio.to_thread(set_volume_mute, player_id, muted)
+
+ async def cmd_sync(self, player_id: str, target_player: str) -> None:
+ """Handle SYNC command for given player.
+
+ Join/add the given player(id) to the given (master) player/sync group.
+
+ - player_id: player_id of the player to handle the command.
+ - target_player: player_id of the syncgroup master or group player.
+ """
+ sonos_player = self.sonosplayers[player_id]
+ await asyncio.to_thread(sonos_player.soco.join, self.sonosplayers[target_player].soco)
+
+ async def cmd_unsync(self, player_id: str) -> None:
+ """Handle UNSYNC command for given player.
+
+ Remove the given player from any syncgroups it currently is synced to.
+
+ - player_id: player_id of the player to handle the command.
+ """
+ sonos_player = self.sonosplayers[player_id]
+ await asyncio.to_thread(sonos_player.soco.unjoin)
+
+ async def poll_player(self, player_id: str) -> None:
+ """Poll player for state updates.
+
+ This is called by the Player Manager;
+ - every 360 seconds if the player if not powered
+ - every 30 seconds if the player is powered
+ - every 10 seconds if the player is playing
+
+ Use this method to request any info that is not automatically updated and/or
+ to detect if the player is still alive.
+ If this method raises the PlayerUnavailable exception,
+ the player is marked as unavailable until
+ the next successful poll or event where it becomes available again.
+ If the player does not need any polling, simply do not override this method.
+ """
+ sonos_player = self.sonosplayers[player_id]
+ try:
+ # the check_poll logic will work out what endpoints need polling now
+ # based on when we last received info from the device
+ await sonos_player.check_poll()
+ # always update the attributes
+ await self._update_player(sonos_player, signal_update=False)
+ except ConnectionResetError as err:
+ raise PlayerUnavailableError from err
+
+ async def _run_discovery(self) -> None:
+ """Discover Sonos players on the network."""
+ if self._discovery_running:
+ return
+ try:
+ self._discovery_running = True
+ self.logger.debug("Sonos discovery started...")
+ discovered_devices: set[soco.SoCo] = await asyncio.to_thread(soco.discover, 10)
+ if discovered_devices is None:
+ discovered_devices = set()
+ new_device_ids = {item.uid for item in discovered_devices}
+ cur_player_ids = set(self.sonosplayers.keys())
+ added_devices = new_device_ids.difference(cur_player_ids)
+ removed_devices = cur_player_ids.difference(new_device_ids)
+
+ # mark any disconnected players as unavailable...
+ for player_id in removed_devices:
+ if player := self.mass.players.get(player_id):
+ player.available = False
+ self.mass.players.update(player_id)
+
+ # process new players
+ for device in discovered_devices:
+ if device.uid not in added_devices:
+ continue
+ await self._device_discovered(device)
+
+ # handle groups
+ # if soco_player := next(iter(discovered_devices), None):
+ # self._process_groups(soco_player.all_groups)
+ # else:
+ # self._process_groups(set())
+
+ finally:
+ self._discovery_running = False
+
+ def reschedule():
+ self.mass.create_task(self._run_discovery())
+
+ # reschedule self once finished
+ self.mass.loop.call_later(300, reschedule)
+
+ async def _device_discovered(self, soco_device: soco.SoCo) -> None:
+ """Handle discovered Sonos player."""
+ player_id = soco_device.uid
+ speaker_info = await asyncio.to_thread(soco_device.get_speaker_info, True)
+ assert player_id not in self.sonosplayers
+
+ sonos_player = SonosPlayer(
+ player_id=player_id,
+ soco=soco_device,
+ player=Player(
+ player_id=soco_device.uid,
+ provider=self.domain,
+ type=PlayerType.PLAYER,
+ name=soco_device.player_name,
+ available=True,
+ powered=True,
+ supported_features=PLAYER_FEATURES,
+ device_info=DeviceInfo(
+ model=speaker_info["model_name"],
+ address=speaker_info["mac_address"],
+ manufacturer=self.name,
+ ),
+ max_sample_rate=48000,
+ ),
+ speaker_info=speaker_info,
+ speaker_info_updated=time.time(),
+ )
+ # poll all endpoints once and update attributes
+ await sonos_player.check_poll()
+ sonos_player.update_attributes()
+
+ # handle subscriptions to events
+ def subscribe(service, _callback):
+ queue = ProcessSonosEventQueue(sonos_player, _callback)
+ sub = service.subscribe(auto_renew=True, event_queue=queue)
+ sonos_player.subscriptions.append(sub)
+
+ subscribe(soco_device.avTransport, self._handle_av_transport_event)
+ subscribe(soco_device.renderingControl, self._handle_rendering_control_event)
+ subscribe(soco_device.zoneGroupTopology, self._handle_zone_group_topology_event)
+
+ self.sonosplayers[player_id] = sonos_player
+
+ self.mass.players.register(sonos_player.player)
+
+ def _handle_av_transport_event(self, sonos_player: SonosPlayer, event: SonosEvent):
+ """Handle a soco.SoCo AVTransport event."""
+ if self.mass.closing:
+ return
+ self.logger.debug("Received AVTransport event for Player %s", sonos_player.soco.player_name)
+
+ if "transport_state" in event.variables:
+ new_state = event.variables["transport_state"]
+ if new_state == "TRANSITIONING":
+ return
+ sonos_player.transport_info["current_transport_state"] = new_state
+
+ if "current_track_uri" in event.variables:
+ sonos_player.transport_info["uri"] = event.variables["current_track_uri"]
+
+ sonos_player.transport_info_updated = time.time()
+ asyncio.run_coroutine_threadsafe(self._update_player(sonos_player), self.mass.loop)
+
+ def _handle_rendering_control_event(self, sonos_player: SonosPlayer, event: SonosEvent):
+ """Handle a soco.SoCo RenderingControl event."""
+ if self.mass.closing:
+ return
+ self.logger.debug(
+ "Received RenderingControl event for Player %s",
+ sonos_player.soco.player_name,
+ )
+ if "volume" in event.variables:
+ sonos_player.rendering_control_info["volume"] = event.variables["volume"]["Master"]
+ if "mute" in event.variables:
+ sonos_player.rendering_control_info["mute"] = bool(event.variables["mute"]["Master"])
+ sonos_player.rendering_control_info_updated = time.time()
+ asyncio.run_coroutine_threadsafe(self._update_player(sonos_player), self.mass.loop)
+
+ def _handle_zone_group_topology_event(
+ self, sonos_player: SonosPlayer, event: SonosEvent # noqa: ARG002
+ ):
+ """Handle a soco.SoCo ZoneGroupTopology event."""
+ if self.mass.closing:
+ return
+ self.logger.debug(
+ "Received ZoneGroupTopology event for Player %s",
+ sonos_player.soco.player_name,
+ )
+ sonos_player.group_info = sonos_player.soco.group
+ sonos_player.group_info_updated = time.time()
+ asyncio.run_coroutine_threadsafe(self._update_player(sonos_player), self.mass.loop)
+
+ def _process_groups(self, sonos_groups: list[soco.SoCo]) -> None:
+ """Process all sonos groups."""
+ all_group_ids = set()
+ for sonos_player in sonos_groups:
+ all_group_ids.add(sonos_player.uid)
+ if sonos_player.uid not in self.sonosplayers:
+ # unknown player ?!
+ continue
+
+ # mass_player = self.mass.players.get(sonos_player.uid)
+ # sonos_player.is_coordinator
+ # # check members
+ # group_player.is_group_player = True
+ # group_player.name = group.label
+ # group_player.group_childs = [item.uid for item in group.members]
+ # create_task(self.mass.players.update_player(group_player))
+
+ async def _enqueue_next_track(
+ self, sonos_player: SonosPlayer, current_queue_item_id: str
+ ) -> None:
+ """Enqueue the next track of the MA queue on the CC queue."""
+ if not current_queue_item_id:
+ return # guard
+ if not self.mass.players.queues.get_item(sonos_player.player_id, current_queue_item_id):
+ return # guard
+ try:
+ next_item, crossfade = self.mass.players.queues.player_ready_for_next_track(
+ sonos_player.player_id, current_queue_item_id
+ )
+ except QueueEmpty:
+ return
+
+ if sonos_player.next_item == next_item.queue_item_id:
+ return # already set ?!
+ sonos_player.next_item = next_item.queue_item_id
+
+ # set crossfade according to queue mode
+ if sonos_player.soco.cross_fade != crossfade:
+
+ def set_crossfade():
+ with suppress(Exception):
+ sonos_player.soco.cross_fade = crossfade
+
+ await asyncio.to_thread(set_crossfade)
+
+ # send queue item to sonos queue
+ is_radio = next_item.media_type != MediaType.TRACK
+ url = await self.mass.streams.resolve_stream_url(
+ queue_item=next_item,
+ player_id=sonos_player.player_id,
+ content_type=ContentType.MP3 if is_radio else ContentType.FLAC,
+ # Sonos pre-caches pretty aggressively so do not yet start the runner
+ auto_start_runner=False,
+ )
+ await self._enqueue_item(sonos_player, queue_item=next_item, url=url)
+
+ async def _enqueue_item(
+ self,
+ sonos_player: SonosPlayer,
+ queue_item: QueueItem,
+ url: str,
+ flow_mode: bool = False,
+ ) -> None:
+ """Enqueue a queue item to the Sonos player Queue."""
+ metadata = create_didl_metadata(url, queue_item, flow_mode)
+ await asyncio.to_thread(
+ sonos_player.soco.avTransport.AddURIToQueue,
+ [
+ ("InstanceID", 0),
+ ("EnqueuedURI", url),
+ ("EnqueuedURIMetaData", metadata),
+ ("DesiredFirstTrackNumberEnqueued", 0),
+ ("EnqueueAsNext", 0),
+ ],
+ timeout=60,
+ )
+ if sonos_player.player_id in self.mass.streams.workaround_players:
+ self.mass.streams.workaround_players.remove(sonos_player.player_id)
+ self.logger.debug(
+ "Enqued track (%s) to player %s",
+ queue_item.name,
+ sonos_player.player.display_name,
+ )
+
+ async def _update_player(self, sonos_player: SonosPlayer, signal_update: bool = True) -> None:
+ """Update Sonos Player."""
+ prev_item_id = sonos_player.current_item_id
+ prev_url = sonos_player.player.current_url
+ prev_state = sonos_player.player.state
+ sonos_player.update_attributes()
+ sonos_player.player.can_sync_with = tuple(
+ x for x in self.sonosplayers if x != sonos_player.player_id
+ )
+ current_url = sonos_player.player.current_url
+ current_state = sonos_player.player.state
+
+ if (prev_url != current_url) or (prev_state != current_state):
+ # fetch track details on state or url change
+ await asyncio.to_thread(
+ sonos_player.update_info,
+ update_track_info=True,
+ )
+ sonos_player.update_attributes()
+
+ if signal_update:
+ # send update to the player manager right away only if we are triggered from an event
+ # when we're just updating from a manual poll, the player manager will
+ # update will detect changes to the player object itself
+ self.mass.players.update(sonos_player.player_id)
+
+ # enqueue next item if needed
+ if sonos_player.player.state == PlayerState.PLAYING and (
+ prev_item_id != sonos_player.current_item_id
+ or not sonos_player.next_item
+ or sonos_player.next_item == sonos_player.current_item_id
+ ):
+ self.mass.create_task(
+ self._enqueue_next_track(sonos_player, sonos_player.current_item_id)
+ )
+
+
+def _convert_state(sonos_state: str) -> PlayerState:
+ """Convert Sonos state to PlayerState."""
+ if sonos_state == "PLAYING":
+ return PlayerState.PLAYING
+ if sonos_state == "TRANSITIONING":
+ return PlayerState.PLAYING
+ if sonos_state == "PAUSED_PLAYBACK":
+ return PlayerState.PAUSED
+ return PlayerState.IDLE
+
+
+def _timespan_secs(timespan):
+ """Parse a time-span into number of seconds."""
+ if timespan in ("", "NOT_IMPLEMENTED", None):
+ return None
+ return sum(60 ** x[0] * int(x[1]) for x in enumerate(reversed(timespan.split(":"))))
+
+
+class ProcessSonosEventQueue:
+ """Queue like object for dispatching sonos events."""
+
+ def __init__(
+ self,
+ sonos_player: SonosPlayer,
+ callback_handler: callable[[SonosPlayer, dict], None],
+ ) -> None:
+ """Initialize Sonos event queue."""
+ self._callback_handler = callback_handler
+ self._sonos_player = sonos_player
+
+ def put(self, info: Any, block=True, timeout=None) -> None: # noqa: ARG002
+ """Process event."""
+ # noqa: ARG001
+ self._callback_handler(self._sonos_player, info)
--- /dev/null
+{
+ "type": "player",
+ "domain": "sonos",
+ "name": "SONOS",
+ "description": "SONOS Playerprovider for Music Assistant.",
+ "codeowners": ["@music-assistant"],
+ "config_entries": [
+ ],
+ "requirements": ["soco==0.29.1"],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": false,
+ "load_by_default": true
+}
--- /dev/null
+"""Spotify musicprovider support for MusicAssistant."""
+from __future__ import annotations
+
+import asyncio
+import json
+import os
+import platform
+import time
+from collections.abc import AsyncGenerator
+from json.decoder import JSONDecodeError
+from tempfile import gettempdir
+
+import aiohttp
+from asyncio_throttle import Throttler
+
+from music_assistant.common.helpers.util import parse_title_and_version
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.errors import LoginFailed, MediaNotFoundError
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ ContentType,
+ ImageType,
+ MediaItemImage,
+ MediaItemType,
+ MediaType,
+ Playlist,
+ ProviderMapping,
+ StreamDetails,
+ Track,
+)
+from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
+from music_assistant.server.helpers.app_vars import app_var
+from music_assistant.server.helpers.process import AsyncProcess
+from music_assistant.server.models.music_provider import MusicProvider
+
+CACHE_DIR = gettempdir()
+
+
+class SpotifyProvider(MusicProvider):
+ """Implementation of a Spotify MusicProvider."""
+
+ _auth_token: str | None = None
+ _sp_user: str | None = None
+ _librespot_bin: str | None = None
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self._throttler = Throttler(rate_limit=1, period=0.1)
+ self._cache_dir = CACHE_DIR
+ self._ap_workaround = False
+ self._attr_supported_features = (
+ ProviderFeature.LIBRARY_ARTISTS,
+ ProviderFeature.LIBRARY_ALBUMS,
+ ProviderFeature.LIBRARY_TRACKS,
+ ProviderFeature.LIBRARY_PLAYLISTS,
+ ProviderFeature.LIBRARY_ARTISTS_EDIT,
+ ProviderFeature.LIBRARY_ALBUMS_EDIT,
+ ProviderFeature.LIBRARY_PLAYLISTS_EDIT,
+ ProviderFeature.LIBRARY_TRACKS_EDIT,
+ ProviderFeature.PLAYLIST_TRACKS_EDIT,
+ ProviderFeature.BROWSE,
+ ProviderFeature.SEARCH,
+ ProviderFeature.ARTIST_ALBUMS,
+ ProviderFeature.ARTIST_TOPTRACKS,
+ ProviderFeature.SIMILAR_TRACKS,
+ )
+ # try to get a token, raise if that fails
+ self._cache_dir = os.path.join(CACHE_DIR, self.instance_id)
+ # try login which will raise if it fails
+ await self.login()
+
+ async def search(
+ self, search_query: str, media_types=list[MediaType] | None, limit: int = 5
+ ) -> list[MediaItemType]:
+ """Perform search on musicprovider.
+
+ :param search_query: Search query.
+ :param media_types: A list of media_types to include. All types if None.
+ :param limit: Number of items to return in the search (per type).
+ """
+ result = []
+ searchtypes = []
+ if MediaType.ARTIST in media_types:
+ searchtypes.append("artist")
+ if MediaType.ALBUM in media_types:
+ searchtypes.append("album")
+ if MediaType.TRACK in media_types:
+ searchtypes.append("track")
+ if MediaType.PLAYLIST in media_types:
+ searchtypes.append("playlist")
+ searchtype = ",".join(searchtypes)
+ search_query = search_query.replace("'", "")
+ if searchresult := await self._get_data(
+ "search", q=search_query, type=searchtype, limit=limit
+ ):
+ if "artists" in searchresult:
+ result += [
+ await self._parse_artist(item)
+ for item in searchresult["artists"]["items"]
+ if (item and item["id"])
+ ]
+ if "albums" in searchresult:
+ result += [
+ await self._parse_album(item)
+ for item in searchresult["albums"]["items"]
+ if (item and item["id"])
+ ]
+ if "tracks" in searchresult:
+ result += [
+ await self._parse_track(item)
+ for item in searchresult["tracks"]["items"]
+ if (item and item["id"])
+ ]
+ if "playlists" in searchresult:
+ result += [
+ await self._parse_playlist(item)
+ for item in searchresult["playlists"]["items"]
+ if (item and item["id"])
+ ]
+ return result
+
+ async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
+ """Retrieve library artists from spotify."""
+ endpoint = "me/following"
+ while True:
+ spotify_artists = await self._get_data(
+ endpoint,
+ type="artist",
+ limit=50,
+ )
+ for item in spotify_artists["artists"]["items"]:
+ if item and item["id"]:
+ yield await self._parse_artist(item)
+ if spotify_artists["artists"]["next"]:
+ endpoint = spotify_artists["artists"]["next"]
+ endpoint = endpoint.replace("https://api.spotify.com/v1/", "")
+ else:
+ break
+
+ async def get_library_albums(self) -> AsyncGenerator[Album, None]:
+ """Retrieve library albums from the provider."""
+ for item in await self._get_all_items("me/albums"):
+ if item["album"] and item["album"]["id"]:
+ yield await self._parse_album(item["album"])
+
+ async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
+ """Retrieve library tracks from the provider."""
+ for item in await self._get_all_items("me/tracks"):
+ if item and item["track"]["id"]:
+ yield await self._parse_track(item["track"])
+
+ async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
+ """Retrieve playlists from the provider."""
+ for item in await self._get_all_items("me/playlists"):
+ if item and item["id"]:
+ yield await self._parse_playlist(item)
+
+ async def get_artist(self, prov_artist_id) -> Artist:
+ """Get full artist details by id."""
+ artist_obj = await self._get_data(f"artists/{prov_artist_id}")
+ return await self._parse_artist(artist_obj) if artist_obj else None
+
+ async def get_album(self, prov_album_id) -> Album:
+ """Get full album details by id."""
+ album_obj = await self._get_data(f"albums/{prov_album_id}")
+ return await self._parse_album(album_obj) if album_obj else None
+
+ async def get_track(self, prov_track_id) -> Track:
+ """Get full track details by id."""
+ track_obj = await self._get_data(f"tracks/{prov_track_id}")
+ return await self._parse_track(track_obj) if track_obj else None
+
+ async def get_playlist(self, prov_playlist_id) -> Playlist:
+ """Get full playlist details by id."""
+ playlist_obj = await self._get_data(f"playlists/{prov_playlist_id}")
+ return await self._parse_playlist(playlist_obj) if playlist_obj else None
+
+ async def get_album_tracks(self, prov_album_id) -> list[Track]:
+ """Get all album tracks for given album id."""
+ return [
+ await self._parse_track(item)
+ for item in await self._get_all_items(f"albums/{prov_album_id}/tracks")
+ if (item and item["id"])
+ ]
+
+ async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ """Get all playlist tracks for given playlist id."""
+ count = 0
+ result = []
+ for item in await self._get_all_items(
+ f"playlists/{prov_playlist_id}/tracks",
+ ):
+ if not (item and item["track"] and item["track"]["id"]):
+ continue
+ track = await self._parse_track(item["track"])
+ # use count as position
+ track.position = count
+ result.append(track)
+ count += 1
+ return result
+
+ async def get_artist_albums(self, prov_artist_id) -> list[Album]:
+ """Get a list of all albums for the given artist."""
+ return [
+ await self._parse_album(item)
+ for item in await self._get_all_items(
+ f"artists/{prov_artist_id}/albums?include_groups=album,single,compilation"
+ )
+ if (item and item["id"])
+ ]
+
+ async def get_artist_toptracks(self, prov_artist_id) -> list[Track]:
+ """Get a list of 10 most popular tracks for the given artist."""
+ artist = await self.get_artist(prov_artist_id)
+ endpoint = f"artists/{prov_artist_id}/top-tracks"
+ items = await self._get_data(endpoint)
+ return [
+ await self._parse_track(item, artist=artist)
+ for item in items["tracks"]
+ if (item and item["id"])
+ ]
+
+ async def library_add(self, prov_item_id, media_type: MediaType):
+ """Add item to library."""
+ result = False
+ if media_type == MediaType.ARTIST:
+ result = await self._put_data("me/following", {"ids": prov_item_id, "type": "artist"})
+ elif media_type == MediaType.ALBUM:
+ result = await self._put_data("me/albums", {"ids": prov_item_id})
+ elif media_type == MediaType.TRACK:
+ result = await self._put_data("me/tracks", {"ids": prov_item_id})
+ elif media_type == MediaType.PLAYLIST:
+ result = await self._put_data(
+ f"playlists/{prov_item_id}/followers", data={"public": False}
+ )
+ return result
+
+ async def library_remove(self, prov_item_id, media_type: MediaType):
+ """Remove item from library."""
+ result = False
+ if media_type == MediaType.ARTIST:
+ result = await self._delete_data(
+ "me/following", {"ids": prov_item_id, "type": "artist"}
+ )
+ elif media_type == MediaType.ALBUM:
+ result = await self._delete_data("me/albums", {"ids": prov_item_id})
+ elif media_type == MediaType.TRACK:
+ result = await self._delete_data("me/tracks", {"ids": prov_item_id})
+ elif media_type == MediaType.PLAYLIST:
+ result = await self._delete_data(f"playlists/{prov_item_id}/followers")
+ return result
+
+ async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]):
+ """Add track(s) to playlist."""
+ track_uris = []
+ for track_id in prov_track_ids:
+ track_uris.append(f"spotify:track:{track_id}")
+ data = {"uris": track_uris}
+ return await self._post_data(f"playlists/{prov_playlist_id}/tracks", data=data)
+
+ async def remove_playlist_tracks(
+ self, prov_playlist_id: str, positions_to_remove: tuple[int]
+ ) -> None:
+ """Remove track(s) from playlist."""
+ track_uris = []
+ for track in await self.get_playlist_tracks(prov_playlist_id):
+ if track.position in positions_to_remove:
+ track_uris.append({"uri": f"spotify:track:{track.item_id}"})
+ if len(track_uris) == positions_to_remove:
+ break
+ data = {"tracks": track_uris}
+ return await self._delete_data(f"playlists/{prov_playlist_id}/tracks", data=data)
+
+ async def get_similar_tracks(self, prov_track_id, limit=25) -> list[Track]:
+ """Retrieve a dynamic list of tracks based on the provided item."""
+ endpoint = "recommendations"
+ items = await self._get_data(endpoint, seed_tracks=prov_track_id, limit=limit)
+ return [await self._parse_track(item) for item in items["tracks"] if (item and item["id"])]
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails:
+ """Return the content details for the given track when it will be streamed."""
+ # make sure a valid track is requested.
+ track = await self.get_track(item_id)
+ if not track:
+ raise MediaNotFoundError(f"track {item_id} not found")
+ # make sure that the token is still valid by just requesting it
+ await self.login()
+ return StreamDetails(
+ item_id=track.item_id,
+ provider=self.domain,
+ content_type=ContentType.OGG,
+ duration=track.duration,
+ )
+
+ async def get_audio_stream(
+ self, streamdetails: StreamDetails, seek_position: int = 0
+ ) -> AsyncGenerator[bytes, None]:
+ """Return the audio stream for the provider item."""
+ # make sure that the token is still valid by just requesting it
+ await self.login()
+ librespot = await self.get_librespot_binary()
+ args = [
+ librespot,
+ "-c",
+ self._cache_dir,
+ "--pass-through",
+ "-b",
+ "320",
+ "--single-track",
+ f"spotify://track:{streamdetails.item_id}",
+ ]
+ if seek_position:
+ args += ["--start-position", str(int(seek_position))]
+ if self._ap_workaround:
+ args += ["--ap-port", "12345"]
+ bytes_sent = 0
+ async with AsyncProcess(args) as librespot_proc:
+ async for chunk in librespot_proc.iter_any():
+ yield chunk
+ bytes_sent += len(chunk)
+
+ if bytes_sent == 0 and not self._ap_workaround:
+ # AP resolve failure
+ # https://github.com/librespot-org/librespot/issues/972
+ # retry with ap-port set to invalid value, which will force fallback
+ args += ["--ap-port", "12345"]
+ async with AsyncProcess(args) as librespot_proc:
+ async for chunk in librespot_proc.iter_any(64000):
+ yield chunk
+ self._ap_workaround = True
+
+ async def _parse_artist(self, artist_obj):
+ """Parse spotify artist object to generic layout."""
+ artist = Artist(item_id=artist_obj["id"], provider=self.domain, name=artist_obj["name"])
+ artist.add_provider_mapping(
+ ProviderMapping(
+ item_id=artist_obj["id"],
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ url=artist_obj["external_urls"]["spotify"],
+ )
+ )
+ if "genres" in artist_obj:
+ artist.metadata.genres = set(artist_obj["genres"])
+ if artist_obj.get("images"):
+ for img in artist_obj["images"]:
+ img_url = img["url"]
+ if "2a96cbd8b46e442fc41c2b86b821562f" not in img_url:
+ artist.metadata.images = [MediaItemImage(ImageType.THUMB, img_url)]
+ break
+ return artist
+
+ async def _parse_album(self, album_obj: dict):
+ """Parse spotify album object to generic layout."""
+ name, version = parse_title_and_version(album_obj["name"])
+ album = Album(item_id=album_obj["id"], provider=self.domain, name=name, version=version)
+ for artist_obj in album_obj["artists"]:
+ album.artists.append(await self._parse_artist(artist_obj))
+ if album_obj["album_type"] == "single":
+ album.album_type = AlbumType.SINGLE
+ elif album_obj["album_type"] == "compilation":
+ album.album_type = AlbumType.COMPILATION
+ elif album_obj["album_type"] == "album":
+ album.album_type = AlbumType.ALBUM
+ if "genres" in album_obj:
+ album.metadata.genre = set(album_obj["genres"])
+ if album_obj.get("images"):
+ album.metadata.images = [MediaItemImage(ImageType.THUMB, album_obj["images"][0]["url"])]
+ if "external_ids" in album_obj and album_obj["external_ids"].get("upc"):
+ album.upc = album_obj["external_ids"]["upc"]
+ if "label" in album_obj:
+ album.metadata.label = album_obj["label"]
+ if album_obj.get("release_date"):
+ album.year = int(album_obj["release_date"].split("-")[0])
+ if album_obj.get("copyrights"):
+ album.metadata.copyright = album_obj["copyrights"][0]["text"]
+ if album_obj.get("explicit"):
+ album.metadata.explicit = album_obj["explicit"]
+ album.add_provider_mapping(
+ ProviderMapping(
+ item_id=album_obj["id"],
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ content_type=ContentType.OGG,
+ bit_rate=320,
+ url=album_obj["external_urls"]["spotify"],
+ )
+ )
+ return album
+
+ async def _parse_track(self, track_obj, artist=None):
+ """Parse spotify track object to generic layout."""
+ name, version = parse_title_and_version(track_obj["name"])
+ track = Track(
+ item_id=track_obj["id"],
+ provider=self.domain,
+ name=name,
+ version=version,
+ duration=track_obj["duration_ms"] / 1000,
+ disc_number=track_obj["disc_number"],
+ track_number=track_obj["track_number"],
+ position=track_obj.get("position"),
+ )
+ if artist:
+ track.artists.append(artist)
+ for track_artist in track_obj.get("artists", []):
+ artist = await self._parse_artist(track_artist)
+ if artist and artist.item_id not in {x.item_id for x in track.artists}:
+ track.artists.append(artist)
+
+ track.metadata.explicit = track_obj["explicit"]
+ if "preview_url" in track_obj:
+ track.metadata.preview = track_obj["preview_url"]
+ if "external_ids" in track_obj and "isrc" in track_obj["external_ids"]:
+ track.isrc = track_obj["external_ids"]["isrc"]
+ if "album" in track_obj:
+ track.album = await self._parse_album(track_obj["album"])
+ if track_obj["album"].get("images"):
+ track.metadata.images = [
+ MediaItemImage(ImageType.THUMB, track_obj["album"]["images"][0]["url"])
+ ]
+ if track_obj.get("copyright"):
+ track.metadata.copyright = track_obj["copyright"]
+ if track_obj.get("explicit"):
+ track.metadata.explicit = True
+ if track_obj.get("popularity"):
+ track.metadata.popularity = track_obj["popularity"]
+ track.add_provider_mapping(
+ ProviderMapping(
+ item_id=track_obj["id"],
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ content_type=ContentType.OGG,
+ bit_rate=320,
+ url=track_obj["external_urls"]["spotify"],
+ available=not track_obj["is_local"] and track_obj["is_playable"],
+ )
+ )
+ return track
+
+ async def _parse_playlist(self, playlist_obj):
+ """Parse spotify playlist object to generic layout."""
+ playlist = Playlist(
+ item_id=playlist_obj["id"],
+ provider=self.domain,
+ name=playlist_obj["name"],
+ owner=playlist_obj["owner"]["display_name"],
+ )
+ playlist.add_provider_mapping(
+ ProviderMapping(
+ item_id=playlist_obj["id"],
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ url=playlist_obj["external_urls"]["spotify"],
+ )
+ )
+ playlist.is_editable = (
+ playlist_obj["owner"]["id"] == self._sp_user["id"] or playlist_obj["collaborative"]
+ )
+ if playlist_obj.get("images"):
+ playlist.metadata.images = [
+ MediaItemImage(ImageType.THUMB, playlist_obj["images"][0]["url"])
+ ]
+ playlist.metadata.checksum = str(playlist_obj["snapshot_id"])
+ return playlist
+
+ async def login(self) -> dict:
+ """Log-in Spotify and return tokeninfo."""
+ # return existing token if we have one in memory
+ if (
+ self._auth_token
+ and os.path.isdir(self._cache_dir)
+ and (self._auth_token["expiresAt"] > int(time.time()) + 20)
+ ):
+ return self._auth_token
+ tokeninfo, userinfo = None, self._sp_user
+ if not self.config.get_value(CONF_USERNAME) or not self.config.get_value(CONF_PASSWORD):
+ raise LoginFailed("Invalid login credentials")
+ # retrieve token with librespot
+ retries = 0
+ while retries < 20:
+ try:
+ retries += 1
+ if not tokeninfo:
+ tokeninfo = await asyncio.wait_for(self._get_token(), 5)
+ if tokeninfo and not userinfo:
+ userinfo = await asyncio.wait_for(self._get_data("me", tokeninfo=tokeninfo), 5)
+ if tokeninfo and userinfo:
+ # we have all info we need!
+ break
+ if retries > 2:
+ # switch to ap workaround after 2 retries
+ self._ap_workaround = True
+ except asyncio.exceptions.TimeoutError:
+ await asyncio.sleep(2)
+ if tokeninfo and userinfo:
+ self._auth_token = tokeninfo
+ self._sp_user = userinfo
+ self.mass.metadata.preferred_language = userinfo["country"]
+ self.logger.info("Successfully logged in to Spotify as %s", userinfo["id"])
+ self._auth_token = tokeninfo
+ return tokeninfo
+ if tokeninfo and not userinfo:
+ raise LoginFailed(
+ "Unable to retrieve userdetails from Spotify API - probably just a temporary error"
+ )
+ if self.config.get_value(CONF_USERNAME).isnumeric():
+ # a spotify free/basic account can be recognized when
+ # the username consists of numbers only - check that here
+ # an integer can be parsed of the username, this is a free account
+ raise LoginFailed("Only Spotify Premium accounts are supported")
+ raise LoginFailed(f"Login failed for user {self.config.get_value(CONF_USERNAME)}")
+
+ async def _get_token(self):
+ """Get spotify auth token with librespot bin."""
+ time_start = time.time()
+ # authorize with username and password (NOTE: this can also be Spotify Connect)
+ args = [
+ await self.get_librespot_binary(),
+ "-O",
+ "-c",
+ self._cache_dir,
+ "-a",
+ "-u",
+ self.config.get_value(CONF_USERNAME),
+ "-p",
+ self.config.get_value(CONF_PASSWORD),
+ ]
+ librespot = await asyncio.create_subprocess_exec(*args)
+ await librespot.wait()
+ # get token with (authorized) librespot
+ scopes = [
+ "user-read-playback-state",
+ "user-read-currently-playing",
+ "user-modify-playback-state",
+ "playlist-read-private",
+ "playlist-read-collaborative",
+ "playlist-modify-public",
+ "playlist-modify-private",
+ "user-follow-modify",
+ "user-follow-read",
+ "user-library-read",
+ "user-library-modify",
+ "user-read-private",
+ "user-read-email",
+ "user-read-birthdate",
+ "user-top-read",
+ ]
+ scope = ",".join(scopes)
+ args = [
+ await self.get_librespot_binary(),
+ "-O",
+ "-t",
+ "--client-id",
+ app_var(2),
+ "--scope",
+ scope,
+ "-c",
+ self._cache_dir,
+ ]
+ if self._ap_workaround:
+ args += ["--ap-port", "12345"]
+ librespot = await asyncio.create_subprocess_exec(
+ *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
+ )
+ stdout, _ = await librespot.communicate()
+ duration = round(time.time() - time_start, 2)
+ try:
+ result = json.loads(stdout)
+ except JSONDecodeError:
+ self.logger.warning(
+ "Error while retrieving Spotify token after %s seconds, details: %s",
+ duration,
+ stdout.decode(),
+ )
+ return None
+ self.logger.debug(
+ "Retrieved Spotify token using librespot in %s seconds",
+ duration,
+ )
+ # transform token info to spotipy compatible format
+ if result and "accessToken" in result:
+ tokeninfo = result
+ tokeninfo["expiresAt"] = tokeninfo["expiresIn"] + int(time.time())
+ return tokeninfo
+ return None
+
+ async def _get_all_items(self, endpoint, key="items", **kwargs) -> list[dict]:
+ """Get all items from a paged list."""
+ limit = 50
+ offset = 0
+ all_items = []
+ while True:
+ kwargs["limit"] = limit
+ kwargs["offset"] = offset
+ result = await self._get_data(endpoint, **kwargs)
+ offset += limit
+ if not result or key not in result or not result[key]:
+ break
+ for item in result[key]:
+ item["position"] = len(all_items) + 1
+ all_items.append(item)
+ if len(result[key]) < limit:
+ break
+ return all_items
+
+ async def _get_data(self, endpoint, tokeninfo: dict | None = None, **kwargs):
+ """Get data from api."""
+ url = f"https://api.spotify.com/v1/{endpoint}"
+ kwargs["market"] = "from_token"
+ kwargs["country"] = "from_token"
+ if tokeninfo is None:
+ tokeninfo = await self.login()
+ headers = {"Authorization": f'Bearer {tokeninfo["accessToken"]}'}
+ async with self._throttler:
+ time_start = time.time()
+ try:
+ async with self.mass.http_session.get(
+ url, headers=headers, params=kwargs, verify_ssl=False, timeout=120
+ ) as response:
+ result = await response.json()
+ if "error" in result or ("status" in result and "error" in result["status"]):
+ self.logger.error("%s - %s", endpoint, result)
+ return None
+ except (
+ aiohttp.ContentTypeError,
+ JSONDecodeError,
+ ) as err:
+ self.logger.error("%s - %s", endpoint, str(err))
+ return None
+ finally:
+ self.logger.debug(
+ "Processing GET/%s took %s seconds",
+ endpoint,
+ round(time.time() - time_start, 2),
+ )
+ return result
+
+ async def _delete_data(self, endpoint, data=None, **kwargs):
+ """Delete data from api."""
+ url = f"https://api.spotify.com/v1/{endpoint}"
+ token = await self.login()
+ if not token:
+ return None
+ headers = {"Authorization": f'Bearer {token["accessToken"]}'}
+ async with self.mass.http_session.delete(
+ url, headers=headers, params=kwargs, json=data, verify_ssl=False
+ ) as response:
+ return await response.text()
+
+ async def _put_data(self, endpoint, data=None, **kwargs):
+ """Put data on api."""
+ url = f"https://api.spotify.com/v1/{endpoint}"
+ token = await self.login()
+ if not token:
+ return None
+ headers = {"Authorization": f'Bearer {token["accessToken"]}'}
+ async with self.mass.http_session.put(
+ url, headers=headers, params=kwargs, json=data, verify_ssl=False
+ ) as response:
+ return await response.text()
+
+ async def _post_data(self, endpoint, data=None, **kwargs):
+ """Post data on api."""
+ url = f"https://api.spotify.com/v1/{endpoint}"
+ token = await self.login()
+ if not token:
+ return None
+ headers = {"Authorization": f'Bearer {token["accessToken"]}'}
+ async with self.mass.http_session.post(
+ url, headers=headers, params=kwargs, json=data, verify_ssl=False
+ ) as response:
+ return await response.text()
+
+ async def get_librespot_binary(self):
+ """Find the correct librespot binary belonging to the platform."""
+ # ruff: noqa: SIM102
+ if self._librespot_bin is not None:
+ return self._librespot_bin
+
+ async def check_librespot(librespot_path: str) -> str | None:
+ try:
+ librespot = await asyncio.create_subprocess_exec(
+ *[librespot_path, "--check"], stdout=asyncio.subprocess.PIPE
+ )
+ stdout, _ = await librespot.communicate()
+ if (
+ librespot.returncode == 0
+ and b"ok spotty" in stdout
+ and b"using librespot" in stdout
+ ):
+ self._librespot_bin = librespot_path
+ return librespot_path
+ except OSError:
+ return None
+
+ base_path = os.path.join(os.path.dirname(__file__), "librespot")
+ if platform.system() == "Windows" and (
+ librespot := await check_librespot(os.path.join(base_path, "windows", "librespot.exe"))
+ ):
+ return librespot
+ if platform.system() == "Darwin":
+ # macos binary is x86_64 intel
+ if librespot := await check_librespot(os.path.join(base_path, "osx", "librespot")):
+ return librespot
+
+ if platform.system() == "FreeBSD":
+ # FreeBSD binary is x86_64 intel
+ if librespot := await check_librespot(os.path.join(base_path, "freebsd", "librespot")):
+ return librespot
+
+ if platform.system() == "Linux":
+ architecture = platform.machine()
+ if architecture in ["AMD64", "x86_64"]:
+ # generic linux x86_64 binary
+ if librespot := await check_librespot(
+ os.path.join(
+ base_path,
+ "linux",
+ "librespot-x86_64",
+ )
+ ):
+ return librespot
+
+ # arm architecture... try all options one by one...
+ for arch in ["aarch64", "armv7", "armhf", "arm"]:
+ if librespot := await check_librespot(
+ os.path.join(
+ base_path,
+ "linux",
+ f"librespot-{arch}",
+ )
+ ):
+ return librespot
+
+ raise RuntimeError(
+ f"Unable to locate Libespot for {platform.system()} ({platform.machine()})"
+ )
--- /dev/null
+{
+ "type": "music",
+ "domain": "spotify",
+ "name": "Spotify",
+ "description": "Support for the Spotify streaming provider in Music Assistant.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ {
+ "key": "username",
+ "type": "string",
+ "label": "Username"
+ },
+ {
+ "key": "password",
+ "type": "password",
+ "label": "Password"
+ }
+ ],
+
+ "requirements": [],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/816",
+ "multi_instance": true
+}
--- /dev/null
+"""The AudioDB Metadata provider for Music Assistant."""
+from __future__ import annotations
+
+from json import JSONDecodeError
+from typing import TYPE_CHECKING, Any
+
+import aiohttp.client_exceptions
+from asyncio_throttle import Throttler
+
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ ImageType,
+ LinkType,
+ MediaItemImage,
+ MediaItemLink,
+ MediaItemMetadata,
+ Track,
+)
+from music_assistant.server.controllers.cache import use_cache
+from music_assistant.server.helpers.app_vars import app_var # pylint: disable=no-name-in-module
+from music_assistant.server.helpers.compare import compare_strings
+from music_assistant.server.models.metadata_provider import MetadataProvider
+
+if TYPE_CHECKING:
+ from collections.abc import Iterable
+
+IMG_MAPPING = {
+ "strArtistThumb": ImageType.THUMB,
+ "strArtistLogo": ImageType.LOGO,
+ "strArtistCutout": ImageType.CUTOUT,
+ "strArtistClearart": ImageType.CLEARART,
+ "strArtistWideThumb": ImageType.LANDSCAPE,
+ "strArtistFanart": ImageType.FANART,
+ "strArtistBanner": ImageType.BANNER,
+ "strAlbumThumb": ImageType.THUMB,
+ "strAlbumThumbHQ": ImageType.THUMB,
+ "strAlbumCDart": ImageType.DISCART,
+ "strAlbum3DCase": ImageType.OTHER,
+ "strAlbum3DFlat": ImageType.OTHER,
+ "strAlbum3DFace": ImageType.OTHER,
+ "strAlbum3DThumb": ImageType.OTHER,
+ "strTrackThumb": ImageType.THUMB,
+ "strTrack3DCase": ImageType.OTHER,
+}
+
+LINK_MAPPING = {
+ "strWebsite": LinkType.WEBSITE,
+ "strFacebook": LinkType.FACEBOOK,
+ "strTwitter": LinkType.TWITTER,
+ "strLastFMChart": LinkType.LASTFM,
+}
+
+ALBUMTYPE_MAPPING = {
+ "Single": AlbumType.SINGLE,
+ "Compilation": AlbumType.COMPILATION,
+ "Album": AlbumType.ALBUM,
+}
+
+
+class AudioDbMetadataProvider(MetadataProvider):
+ """The AudioDB Metadata provider."""
+
+ throttler: Throttler
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self.cache = self.mass.cache
+ self._attr_supported_features = (
+ ProviderFeature.ARTIST_METADATA,
+ ProviderFeature.ALBUM_METADATA,
+ ProviderFeature.TRACK_METADATA,
+ ProviderFeature.GET_ARTIST_MBID,
+ )
+ self.throttler = Throttler(rate_limit=2, period=1)
+
+ async def get_artist_metadata(self, artist: Artist) -> MediaItemMetadata | None:
+ """Retrieve metadata for artist on theaudiodb."""
+ if data := await self._get_data("artist-mb.php", i=artist.musicbrainz_id): # noqa: SIM102
+ if data.get("artists"):
+ return self.__parse_artist(data["artists"][0])
+ return None
+
+ async def get_album_metadata(self, album: Album) -> MediaItemMetadata | None:
+ """Retrieve metadata for album on theaudiodb."""
+ adb_album = None
+ if album.musicbrainz_id:
+ result = await self._get_data("album-mb.php", i=album.musicbrainz_id)
+ if result and result.get("album"):
+ adb_album = result["album"][0]
+ elif album.artist:
+ # lookup by name
+ result = await self._get_data("searchalbum.php", s=album.artist.name, a=album.name)
+ if result and result.get("album"):
+ for item in result["album"]:
+ assert isinstance(album.artist, Artist)
+ if album.artist.musicbrainz_id:
+ if album.artist.musicbrainz_id != item["strMusicBrainzArtistID"]:
+ continue
+ elif not compare_strings(album.artist.name, item["strArtistStripped"]):
+ continue
+ if compare_strings(album.name, item["strAlbumStripped"]):
+ adb_album = item
+ break
+ if adb_album:
+ if not album.year:
+ album.year = int(adb_album.get("intYearReleased", "0"))
+ if not album.musicbrainz_id:
+ album.musicbrainz_id = adb_album["strMusicBrainzID"]
+ assert isinstance(album.artist, Artist)
+ if album.artist and not album.artist.musicbrainz_id:
+ album.artist.musicbrainz_id = adb_album["strMusicBrainzArtistID"]
+ if album.album_type == AlbumType.UNKNOWN:
+ album.album_type = ALBUMTYPE_MAPPING.get(
+ adb_album.get("strReleaseFormat"), AlbumType.UNKNOWN
+ )
+ return self.__parse_album(adb_album)
+ return None
+
+ async def get_track_metadata(self, track: Track) -> MediaItemMetadata | None:
+ """Retrieve metadata for track on theaudiodb."""
+ adb_track = None
+ if track.musicbrainz_id:
+ result = await self._get_data("track-mb.php", i=track.musicbrainz_id)
+ if result and result.get("track"):
+ return self.__parse_track(result["track"][0])
+
+ # lookup by name
+ for track_artist in track.artists:
+ assert isinstance(track_artist, Artist)
+ result = await self._get_data("searchtrack.php?", s=track_artist.name, t=track.name)
+ if result and result.get("track"):
+ for item in result["track"]:
+ if track_artist.musicbrainz_id:
+ if track_artist.musicbrainz_id != item["strMusicBrainzArtistID"]:
+ continue
+ elif not compare_strings(track_artist.name, item["strArtist"]):
+ continue
+ if compare_strings(track.name, item["strTrack"]):
+ adb_track = item
+ break
+ if adb_track:
+ if not track.musicbrainz_id:
+ track.musicbrainz_id = adb_track["strMusicBrainzID"]
+ assert isinstance(track.album, Album)
+ if track.album and not track.album.musicbrainz_id:
+ track.album.musicbrainz_id = adb_track["strMusicBrainzAlbumID"]
+ if not track_artist.musicbrainz_id:
+ track_artist.musicbrainz_id = adb_track["strMusicBrainzArtistID"]
+
+ return self.__parse_track(adb_track)
+ return None
+
+ async def get_musicbrainz_artist_id(
+ self,
+ artist: Artist,
+ ref_albums: Iterable[Album],
+ ref_tracks: Iterable[Track], # noqa: ARG002
+ ) -> str | None:
+ """Discover MusicBrainzArtistId for an artist given some reference albums/tracks."""
+ musicbrainz_id = None
+ if data := await self._get_data("searchalbum.php", s=artist.name):
+ # NOTE: object is 'null' when no records found instead of empty array
+ albums = data.get("album") or []
+ for item in albums:
+ if not compare_strings(item["strArtistStripped"], artist.name):
+ continue
+ for ref_album in ref_albums:
+ if not compare_strings(item["strAlbumStripped"], ref_album.name):
+ continue
+ # found match - update album metadata too while we're here
+ if not ref_album.musicbrainz_id:
+ ref_album.metadata = self.__parse_album(item)
+ await self.mass.music.albums.add_db_item(ref_album)
+ musicbrainz_id = item["strMusicBrainzArtistID"]
+
+ return musicbrainz_id
+
+ def __parse_artist(self, artist_obj: dict[str, Any]) -> MediaItemMetadata:
+ """Parse audiodb artist object to MediaItemMetadata."""
+ metadata = MediaItemMetadata()
+ # generic data
+ metadata.label = artist_obj.get("strLabel")
+ metadata.style = artist_obj.get("strStyle")
+ if genre := artist_obj.get("strGenre"):
+ metadata.genres = {genre}
+ metadata.mood = artist_obj.get("strMood")
+ # links
+ metadata.links = set()
+ for key, link_type in LINK_MAPPING.items():
+ if link := artist_obj.get(key):
+ metadata.links.add(MediaItemLink(link_type, link))
+ # description/biography
+ if desc := artist_obj.get(f"strBiography{self.mass.metadata.preferred_language}"):
+ metadata.description = desc
+ else:
+ metadata.description = artist_obj.get("strBiographyEN")
+ # images
+ metadata.images = []
+ for key, img_type in IMG_MAPPING.items():
+ for postfix in ("", "2", "3", "4", "5", "6", "7", "8", "9", "10"):
+ if img := artist_obj.get(f"{key}{postfix}"):
+ metadata.images.append(MediaItemImage(img_type, img))
+ else:
+ break
+ return metadata
+
+ def __parse_album(self, album_obj: dict[str, Any]) -> MediaItemMetadata:
+ """Parse audiodb album object to MediaItemMetadata."""
+ metadata = MediaItemMetadata()
+ # generic data
+ metadata.label = album_obj.get("strLabel")
+ metadata.style = album_obj.get("strStyle")
+ if genre := album_obj.get("strGenre"):
+ metadata.genres = {genre}
+ metadata.mood = album_obj.get("strMood")
+ # links
+ metadata.links = set()
+ if link := album_obj.get("strWikipediaID"):
+ metadata.links.add(
+ MediaItemLink(LinkType.WIKIPEDIA, f"https://wikipedia.org/wiki/{link}")
+ )
+ if link := album_obj.get("strAllMusicID"):
+ metadata.links.add(
+ MediaItemLink(LinkType.ALLMUSIC, f"https://www.allmusic.com/album/{link}")
+ )
+
+ # description
+ if desc := album_obj.get(f"strDescription{self.mass.metadata.preferred_language}"):
+ metadata.description = desc
+ else:
+ metadata.description = album_obj.get("strDescriptionEN")
+ metadata.review = album_obj.get("strReview")
+ # images
+ metadata.images = []
+ for key, img_type in IMG_MAPPING.items():
+ for postfix in ("", "2", "3", "4", "5", "6", "7", "8", "9", "10"):
+ if img := album_obj.get(f"{key}{postfix}"):
+ metadata.images.append(MediaItemImage(img_type, img))
+ else:
+ break
+ return metadata
+
+ def __parse_track(self, track_obj: dict[str, Any]) -> MediaItemMetadata:
+ """Parse audiodb track object to MediaItemMetadata."""
+ metadata = MediaItemMetadata()
+ # generic data
+ metadata.lyrics = track_obj.get("strTrackLyrics")
+ metadata.style = track_obj.get("strStyle")
+ if genre := track_obj.get("strGenre"):
+ metadata.genres = {genre}
+ metadata.mood = track_obj.get("strMood")
+ # description
+ if desc := track_obj.get(f"strDescription{self.mass.metadata.preferred_language}"):
+ metadata.description = desc
+ else:
+ metadata.description = track_obj.get("strDescriptionEN")
+ # images
+ metadata.images = []
+ for key, img_type in IMG_MAPPING.items():
+ for postfix in ("", "2", "3", "4", "5", "6", "7", "8", "9", "10"):
+ if img := track_obj.get(f"{key}{postfix}"):
+ metadata.images.append(MediaItemImage(img_type, img))
+ else:
+ break
+ return metadata
+
+ @use_cache(86400 * 14)
+ async def _get_data(self, endpoint, **kwargs) -> dict | None:
+ """Get data from api."""
+ url = f"https://theaudiodb.com/api/v1/json/{app_var(3)}/{endpoint}"
+ async with self.throttler:
+ async with self.mass.http_session.get(url, params=kwargs, verify_ssl=False) as response:
+ try:
+ result = await response.json()
+ except (
+ aiohttp.client_exceptions.ContentTypeError,
+ JSONDecodeError,
+ ):
+ self.logger.error("Failed to retrieve %s", endpoint)
+ text_result = await response.text()
+ self.logger.debug(text_result)
+ return None
+ except (
+ aiohttp.client_exceptions.ClientConnectorError,
+ aiohttp.client_exceptions.ServerDisconnectedError,
+ ):
+ self.logger.warning("Failed to retrieve %s", endpoint)
+ return None
+ if "error" in result and "limit" in result["error"]:
+ self.logger.warning(result["error"])
+ return None
+ return result
--- /dev/null
+{
+ "type": "metadata",
+ "domain": "theaudiodb",
+ "name": "TheAudioDB Metadata provider",
+ "description": "TheAudioDB is a community Database of audio artwork and metadata with a JSON API.",
+ "codeowners": ["@music-assistant"],
+ "config_entries": [
+ ],
+ "requirements": [],
+ "documentation": "",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""Tune-In musicprovider support for MusicAssistant."""
+from __future__ import annotations
+
+from collections.abc import AsyncGenerator
+from time import time
+
+from asyncio_throttle import Throttler
+
+from music_assistant.common.helpers.util import create_sort_name
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.errors import LoginFailed, MediaNotFoundError
+from music_assistant.common.models.media_items import (
+ ContentType,
+ ImageType,
+ MediaItemImage,
+ MediaType,
+ ProviderMapping,
+ Radio,
+ StreamDetails,
+)
+from music_assistant.constants import CONF_USERNAME
+from music_assistant.server.helpers.audio import get_radio_stream
+from music_assistant.server.helpers.playlists import fetch_playlist
+from music_assistant.server.helpers.tags import parse_tags
+from music_assistant.server.models.music_provider import MusicProvider
+
+
+class TuneInProvider(MusicProvider):
+ """Provider implementation for Tune In."""
+
+ _throttler: Throttler
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider."""
+ self._throttler = Throttler(rate_limit=1, period=1)
+ self._attr_supported_features = (
+ ProviderFeature.LIBRARY_RADIOS,
+ ProviderFeature.BROWSE,
+ )
+ if not self.config.get_value(CONF_USERNAME):
+ raise LoginFailed("Username is invalid")
+ if "@" in self.config.get_value(CONF_USERNAME):
+ self.logger.warning(
+ "Emailadress detected instead of username, "
+ "it is advised to use the tunein username instead of email."
+ )
+
+ async def get_library_radios(self) -> AsyncGenerator[Radio, None]:
+ """Retrieve library/subscribed radio stations from the provider."""
+
+ async def parse_items(items: list[dict], folder: str = None) -> AsyncGenerator[Radio, None]:
+ for item in items:
+ item_type = item.get("type", "")
+ if item_type == "audio":
+ if "preset_id" not in item:
+ continue
+ # each radio station can have multiple streams add each one as different quality
+ stream_info = await self.__get_data("Tune.ashx", id=item["preset_id"])
+ for stream in stream_info["body"]:
+ yield await self._parse_radio(item, stream, folder)
+ elif item_type == "link" and item.get("item") == "url":
+ # custom url
+ yield await self._parse_radio(item)
+ elif item_type == "link":
+ # stations are in sublevel (new style)
+ if sublevel := await self.__get_data(item["URL"], render="json"):
+ async for subitem in parse_items(sublevel["body"], item["text"]):
+ yield subitem
+ elif item.get("children"):
+ # stations are in sublevel (old style ?)
+ async for subitem in parse_items(item["children"], item["text"]):
+ yield subitem
+
+ data = await self.__get_data("Browse.ashx", c="presets")
+ if data and "body" in data:
+ async for item in parse_items(data["body"]):
+ yield item
+
+ async def get_radio(self, prov_radio_id: str) -> Radio:
+ """Get radio station details."""
+ if not prov_radio_id.startswith("http"):
+ prov_radio_id, media_type = prov_radio_id.split("--", 1)
+ params = {"c": "composite", "detail": "listing", "id": prov_radio_id}
+ result = await self.__get_data("Describe.ashx", **params)
+ if result and result.get("body") and result["body"][0].get("children"):
+ item = result["body"][0]["children"][0]
+ stream_info = await self.__get_data("Tune.ashx", id=prov_radio_id)
+ for stream in stream_info["body"]:
+ if stream["media_type"] != media_type:
+ continue
+ return await self._parse_radio(item, stream)
+ # fallback - e.g. for handle custom urls ...
+ async for radio in self.get_library_radios():
+ if radio.item_id == prov_radio_id:
+ return radio
+ return None
+
+ async def _parse_radio(
+ self, details: dict, stream: dict | None = None, folder: str | None = None
+ ) -> Radio:
+ """Parse Radio object from json obj returned from api."""
+ if "name" in details:
+ name = details["name"]
+ else:
+ # parse name from text attr
+ name = details["text"]
+ if " | " in name:
+ name = name.split(" | ")[1]
+ name = name.split(" (")[0]
+
+ if stream is None:
+ # custom url (no stream object present)
+ url = details["URL"]
+ item_id = url
+ media_info = await parse_tags(url)
+ content_type = ContentType.try_parse(media_info.format)
+ bit_rate = media_info.bit_rate
+ else:
+ url = stream["url"]
+ item_id = f'{details["preset_id"]}--{stream["media_type"]}'
+ content_type = ContentType.try_parse(stream["media_type"])
+ bit_rate = stream.get("bitrate", 128) # TODO !
+
+ radio = Radio(item_id=item_id, provider=self.domain, name=name)
+ radio.add_provider_mapping(
+ ProviderMapping(
+ item_id=item_id,
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ content_type=content_type,
+ bit_rate=bit_rate,
+ details=url,
+ )
+ )
+ # preset number is used for sorting (not present at stream time)
+ preset_number = details.get("preset_number")
+ if preset_number and folder:
+ radio.sort_name = f'{folder}-{details["preset_number"]}'
+ elif preset_number:
+ radio.sort_name = details["preset_number"]
+ radio.sort_name += create_sort_name(name)
+ if "text" in details:
+ radio.metadata.description = details["text"]
+ # images
+ if img := details.get("image"):
+ radio.metadata.images = [MediaItemImage(ImageType.THUMB, img)]
+ if img := details.get("logo"):
+ radio.metadata.images = [MediaItemImage(ImageType.LOGO, img)]
+ return radio
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails:
+ """Get streamdetails for a radio station."""
+ if item_id.startswith("http"):
+ # custom url
+ return StreamDetails(
+ provider=self.domain,
+ item_id=item_id,
+ content_type=ContentType.UNKNOWN,
+ media_type=MediaType.RADIO,
+ data=item_id,
+ )
+ item_id, media_type = item_id.split("--", 1)
+ stream_info = await self.__get_data("Tune.ashx", id=item_id)
+ for stream in stream_info["body"]:
+ if stream["media_type"] != media_type:
+ continue
+ # check if the radio stream is not a playlist
+ url = stream["url"]
+ if url.endswith("m3u8") or url.endswith("m3u") or url.endswith("pls"):
+ playlist = await fetch_playlist(self.mass, url)
+ url = playlist[0]
+ return StreamDetails(
+ provider=self.domain,
+ item_id=item_id,
+ content_type=ContentType(stream["media_type"]),
+ media_type=MediaType.RADIO,
+ data=url,
+ expires=time() + 24 * 3600,
+ )
+ raise MediaNotFoundError(f"Unable to retrieve stream details for {item_id}")
+
+ async def get_audio_stream(
+ self, streamdetails: StreamDetails, seek_position: int = 0 # noqa: ARG002
+ ) -> AsyncGenerator[bytes, None]:
+ """Return the audio stream for the provider item."""
+ async for chunk in get_radio_stream(self.mass, streamdetails.data, streamdetails):
+ yield chunk
+
+ async def __get_data(self, endpoint: str, **kwargs):
+ """Get data from api."""
+ if endpoint.startswith("http"):
+ url = endpoint
+ else:
+ url = f"https://opml.radiotime.com/{endpoint}"
+ kwargs["formats"] = "ogg,aac,wma,mp3"
+ kwargs["username"] = self.config.get_value(CONF_USERNAME)
+ kwargs["partnerId"] = "1"
+ kwargs["render"] = "json"
+ async with self._throttler:
+ async with self.mass.http_session.get(url, params=kwargs, verify_ssl=False) as response:
+ result = await response.json()
+ if not result or "error" in result:
+ self.logger.error(url)
+ self.logger.error(kwargs)
+ result = None
+ return result
--- /dev/null
+{
+ "type": "music",
+ "domain": "tunein",
+ "name": "Tune-In Radio",
+ "description": "Play your favorite radio stations from Tune-In in Music Assistant.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ {
+ "key": "username",
+ "type": "string",
+ "label": "Username"
+ }
+ ],
+
+ "requirements": [],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/categories/music-providers",
+ "multi_instance": true
+}
--- /dev/null
+"""Basic provider allowing for external URL's to be streamed."""
+from __future__ import annotations
+
+import os
+from collections.abc import AsyncGenerator
+
+from music_assistant.common.models.enums import ContentType, ImageType, MediaType
+from music_assistant.common.models.media_items import (
+ Artist,
+ MediaItemImage,
+ MediaItemType,
+ ProviderMapping,
+ Radio,
+ StreamDetails,
+ Track,
+)
+from music_assistant.server.helpers.audio import get_file_stream, get_http_stream, get_radio_stream
+from music_assistant.server.helpers.playlists import fetch_playlist
+from music_assistant.server.helpers.tags import AudioTags, parse_tags
+from music_assistant.server.models.music_provider import MusicProvider
+
+
+class URLProvider(MusicProvider):
+ """Music Provider for manual URL's/files added to the queue."""
+
+ async def setup(self) -> None:
+ """Handle async initialization of the provider.
+
+ Called when provider is registered.
+ """
+ self._attr_available = True
+ self._full_url = {}
+
+ async def get_track(self, prov_track_id: str) -> Track:
+ """Get full track details by id."""
+ return await self.parse_item(prov_track_id)
+
+ async def get_radio(self, prov_radio_id: str) -> Radio:
+ """Get full radio details by id."""
+ return await self.parse_item(prov_radio_id)
+
+ async def get_artist(self, prov_artist_id: str) -> Track:
+ """Get full artist details by id."""
+ artist = prov_artist_id
+ # this is here for compatibility reasons only
+ return Artist(
+ artist,
+ self.domain,
+ artist,
+ provider_mappings={
+ ProviderMapping(artist, self.domain, self.instance_id, available=False)
+ },
+ )
+
+ async def get_item(self, media_type: MediaType, prov_item_id: str) -> MediaItemType:
+ """Get single MediaItem from provider."""
+ if media_type == MediaType.ARTIST:
+ return await self.get_artist(prov_item_id)
+ if media_type == MediaType.TRACK:
+ return await self.get_track(prov_item_id)
+ if media_type == MediaType.RADIO:
+ return await self.get_radio(prov_item_id)
+ if media_type == MediaType.UNKNOWN:
+ return await self.parse_item(prov_item_id)
+ raise NotImplementedError
+
+ async def parse_item(self, item_id_or_url: str, force_refresh: bool = False) -> Track | Radio:
+ """Parse plain URL to MediaItem of type Radio or Track."""
+ item_id, url, media_info = await self._get_media_info(item_id_or_url, force_refresh)
+ is_radio = media_info.get("icy-name") or not media_info.duration
+ if is_radio:
+ # treat as radio
+ media_item = Radio(
+ item_id=item_id,
+ provider=self.domain,
+ name=media_info.get("icy-name") or media_info.title,
+ )
+ else:
+ media_item = Track(
+ item_id=item_id,
+ provider=self.domain,
+ name=media_info.title,
+ duration=int(media_info.duration or 0),
+ artists=[await self.get_artist(artist) for artist in media_info.artists],
+ )
+
+ media_item.provider_mappings = {
+ ProviderMapping(
+ item_id=item_id,
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ content_type=ContentType.try_parse(media_info.format),
+ sample_rate=media_info.sample_rate,
+ bit_depth=media_info.bits_per_sample,
+ bit_rate=media_info.bit_rate,
+ )
+ }
+ if media_info.has_cover_image:
+ media_item.metadata.images = [MediaItemImage(ImageType.THUMB, url, True)]
+ return media_item
+
+ async def _get_media_info(
+ self, item_id_or_url: str, force_refresh: bool = False
+ ) -> tuple[str, str, AudioTags]:
+ """Retrieve (cached) mediainfo for url."""
+ # check if the radio stream is not a playlist
+ if (
+ item_id_or_url.endswith("m3u8")
+ or item_id_or_url.endswith("m3u")
+ or item_id_or_url.endswith("pls")
+ ):
+ playlist = await fetch_playlist(self.mass, item_id_or_url)
+ url = playlist[0]
+ item_id = item_id_or_url
+ self._full_url[item_id] = url
+ elif "?" in item_id_or_url or "&" in item_id_or_url:
+ # store the 'real' full url to be picked up later
+ # this makes sure that we're not storing any temporary data like auth keys etc
+ # a request for an url mediaitem always passes here first before streamdetails
+ url = item_id_or_url
+ item_id = item_id_or_url.split("?")[0].split("&")[0]
+ self._full_url[item_id] = url
+ else:
+ url = self._full_url.get(item_id_or_url, item_id_or_url)
+ item_id = item_id_or_url
+ cache_key = f"{self.instance_id}.media_info.{item_id}"
+ # do we have some cached info for this url ?
+ cached_info = await self.mass.cache.get(cache_key)
+ if cached_info and not force_refresh:
+ media_info = AudioTags.parse(cached_info)
+ else:
+ # parse info with ffprobe (and store in cache)
+ media_info = await parse_tags(url)
+ if "authSig" in url:
+ media_info.has_cover_image = False
+ await self.mass.cache.set(cache_key, media_info.raw)
+ return (item_id, url, media_info)
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails | None:
+ """Get streamdetails for a track/radio."""
+ item_id, url, media_info = await self._get_media_info(item_id)
+ is_radio = media_info.get("icy-name") or not media_info.duration
+ return StreamDetails(
+ provider=self.domain,
+ item_id=item_id,
+ content_type=ContentType.try_parse(media_info.format),
+ media_type=MediaType.RADIO if is_radio else MediaType.TRACK,
+ sample_rate=media_info.sample_rate,
+ bit_depth=media_info.bits_per_sample,
+ direct=None if is_radio else url,
+ data=url,
+ )
+
+ async def get_audio_stream(
+ self, streamdetails: StreamDetails, seek_position: int = 0
+ ) -> AsyncGenerator[bytes, None]:
+ """Return the audio stream for the provider item."""
+ if streamdetails.media_type == MediaType.RADIO:
+ # radio stream url
+ async for chunk in get_radio_stream(self.mass, streamdetails.data, streamdetails):
+ yield chunk
+ elif os.path.isfile(streamdetails.data):
+ # local file
+ async for chunk in get_file_stream(
+ self.mass, streamdetails.data, streamdetails, seek_position
+ ):
+ yield chunk
+ else:
+ # regular stream url (without icy meta)
+ async for chunk in get_http_stream(
+ self.mass, streamdetails.data, streamdetails, seek_position
+ ):
+ yield chunk
--- /dev/null
+{
+ "type": "music",
+ "domain": "url",
+ "name": "URL",
+ "description": "Built-in/generic provider to play music (or playlists) from a remote URL.",
+ "codeowners": ["@marcelveldt"],
+ "config_entries": [
+ ],
+
+ "requirements": [],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/categories/music-providers",
+ "multi_instance": false,
+ "builtin": true,
+ "load_by_default": true
+}
--- /dev/null
+"""Youtube Music support for MusicAssistant."""
+import asyncio
+import re
+from operator import itemgetter
+from time import time
+from typing import AsyncGenerator # noqa: UP035
+from urllib.parse import unquote
+
+import pytube
+import ytmusicapi
+
+from music_assistant.common.models.enums import ProviderFeature
+from music_assistant.common.models.errors import InvalidDataError, LoginFailed, MediaNotFoundError
+from music_assistant.common.models.media_items import (
+ Album,
+ AlbumType,
+ Artist,
+ ContentType,
+ ImageType,
+ MediaItemImage,
+ MediaItemType,
+ MediaType,
+ Playlist,
+ ProviderMapping,
+ StreamDetails,
+ Track,
+)
+from music_assistant.constants import CONF_USERNAME
+from music_assistant.server.models.music_provider import MusicProvider
+
+from .helpers import (
+ add_remove_playlist_tracks,
+ get_album,
+ get_artist,
+ get_library_albums,
+ get_library_artists,
+ get_library_playlists,
+ get_library_tracks,
+ get_playlist,
+ get_song_radio_tracks,
+ get_track,
+ library_add_remove_album,
+ library_add_remove_artist,
+ library_add_remove_playlist,
+ search,
+)
+
+# if TYPE_CHECKING:
+# from collections.abc import AsyncGenerator
+
+CONF_COOKIE = "cookie"
+
+YT_DOMAIN = "https://www.youtube.com"
+YTM_DOMAIN = "https://music.youtube.com"
+YTM_BASE_URL = f"{YTM_DOMAIN}/youtubei/v1/"
+
+# TODO: fix disabled tests
+# ruff: noqa: PLW2901, RET504
+
+
+class YoutubeMusicProvider(MusicProvider):
+ """Provider for Youtube Music."""
+
+ _headers = None
+ _context = None
+ _cookies = None
+ _signature_timestamp = 0
+ _cipher = None
+
+ async def setup(self) -> None:
+ """Set up the YTMusic provider."""
+ self._attr_supported_features = (
+ ProviderFeature.LIBRARY_ARTISTS,
+ ProviderFeature.LIBRARY_ALBUMS,
+ ProviderFeature.LIBRARY_TRACKS,
+ ProviderFeature.LIBRARY_PLAYLISTS,
+ ProviderFeature.BROWSE,
+ ProviderFeature.SEARCH,
+ ProviderFeature.ARTIST_ALBUMS,
+ ProviderFeature.ARTIST_TOPTRACKS,
+ ProviderFeature.SIMILAR_TRACKS,
+ )
+ if not self.config.get_value(CONF_USERNAME) or not self.config.get_value(CONF_COOKIE):
+ raise LoginFailed("Invalid login credentials")
+ await self._initialize_headers(cookie=self.config.get_value(CONF_COOKIE))
+ await self._initialize_context()
+ self._cookies = {"CONSENT": "YES+1"}
+ self._signature_timestamp = await self._get_signature_timestamp()
+
+ async def search(
+ self, search_query: str, media_types=list[MediaType] | None, limit: int = 5
+ ) -> list[MediaItemType]:
+ """Perform search on musicprovider.
+
+ :param search_query: Search query.
+ :param media_types: A list of media_types to include. All types if None.
+ :param limit: Number of items to return in the search (per type).
+ """
+ ytm_filter = None
+ if len(media_types) == 1:
+ # YTM does not support multiple searchtypes, falls back to all if no type given
+ if media_types[0] == MediaType.ARTIST:
+ ytm_filter = "artists"
+ if media_types[0] == MediaType.ALBUM:
+ ytm_filter = "albums"
+ if media_types[0] == MediaType.TRACK:
+ ytm_filter = "songs"
+ if media_types[0] == MediaType.PLAYLIST:
+ ytm_filter = "playlists"
+ results = await search(query=search_query, ytm_filter=ytm_filter, limit=limit)
+ parsed_results = []
+ for result in results:
+ try:
+ if result["resultType"] == "artist":
+ parsed_results.append(await self._parse_artist(result))
+ elif result["resultType"] == "album":
+ parsed_results.append(await self._parse_album(result))
+ elif result["resultType"] == "playlist":
+ parsed_results.append(await self._parse_playlist(result))
+ elif result["resultType"] == "song" and (track := await self._parse_track(result)):
+ parsed_results.append(track)
+ except InvalidDataError:
+ pass # ignore invalid item
+ return parsed_results
+
+ async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
+ """Retrieve all library artists from Youtube Music."""
+ artists_obj = await get_library_artists(
+ headers=self._headers, username=self.config.get_value(CONF_USERNAME)
+ )
+ for artist in artists_obj:
+ yield await self._parse_artist(artist)
+
+ async def get_library_albums(self) -> AsyncGenerator[Album, None]:
+ """Retrieve all library albums from Youtube Music."""
+ albums_obj = await get_library_albums(
+ headers=self._headers, username=self.config.get_value(CONF_USERNAME)
+ )
+ for album in albums_obj:
+ yield await self._parse_album(album, album["browseId"])
+
+ async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
+ """Retrieve all library playlists from the provider."""
+ playlists_obj = await get_library_playlists(
+ headers=self._headers, username=self.config.get_value(CONF_USERNAME)
+ )
+ for playlist in playlists_obj:
+ yield await self._parse_playlist(playlist)
+
+ async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
+ """Retrieve library tracks from Youtube Music."""
+ tracks_obj = await get_library_tracks(
+ headers=self._headers, username=self.config.get_value(CONF_USERNAME)
+ )
+ for track in tracks_obj:
+ # Library tracks sometimes do not have a valid artist id
+ # In that case, call the API for track details based on track id
+ try:
+ yield await self._parse_track(track)
+ except InvalidDataError:
+ track = await self.get_track(track["videoId"])
+ yield track
+
+ async def get_album(self, prov_album_id) -> Album:
+ """Get full album details by id."""
+ album_obj = await get_album(prov_album_id=prov_album_id)
+ return (
+ await self._parse_album(album_obj=album_obj, album_id=prov_album_id)
+ if album_obj
+ else None
+ )
+
+ async def get_album_tracks(self, prov_album_id: str) -> list[Track]:
+ """Get album tracks for given album id."""
+ album_obj = await get_album(prov_album_id=prov_album_id)
+ if not album_obj.get("tracks"):
+ return []
+ tracks = []
+ for idx, track_obj in enumerate(album_obj["tracks"], 1):
+ track = await self._parse_track(track_obj=track_obj)
+ track.disc_number = 0
+ track.track_number = idx
+ tracks.append(track)
+ return tracks
+
+ async def get_artist(self, prov_artist_id) -> Artist:
+ """Get full artist details by id."""
+ artist_obj = await get_artist(prov_artist_id=prov_artist_id)
+ return await self._parse_artist(artist_obj=artist_obj) if artist_obj else None
+
+ async def get_track(self, prov_track_id) -> Track:
+ """Get full track details by id."""
+ track_obj = await get_track(prov_track_id=prov_track_id)
+ return await self._parse_track(track_obj)
+
+ async def get_playlist(self, prov_playlist_id) -> Playlist:
+ """Get full playlist details by id."""
+ playlist_obj = await get_playlist(
+ prov_playlist_id=prov_playlist_id,
+ headers=self._headers,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ return await self._parse_playlist(playlist_obj)
+
+ async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ """Get all playlist tracks for given playlist id."""
+ playlist_obj = await get_playlist(
+ prov_playlist_id=prov_playlist_id,
+ headers=self._headers,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ if "tracks" not in playlist_obj:
+ return []
+ tracks = []
+ for index, track in enumerate(playlist_obj["tracks"]):
+ if track["isAvailable"]:
+ # Playlist tracks sometimes do not have a valid artist id
+ # In that case, call the API for track details based on track id
+ try:
+ track = await self._parse_track(track)
+ if track:
+ track.position = index
+ tracks.append(track)
+ except InvalidDataError:
+ track = await self.get_track(track["videoId"])
+ if track:
+ track.position = index
+ tracks.append(track)
+ return tracks
+
+ async def get_artist_albums(self, prov_artist_id) -> list[Album]:
+ """Get a list of albums for the given artist."""
+ artist_obj = await get_artist(prov_artist_id=prov_artist_id)
+ if "albums" in artist_obj and "results" in artist_obj["albums"]:
+ albums = []
+ for album_obj in artist_obj["albums"]["results"]:
+ if "artists" not in album_obj:
+ album_obj["artists"] = [
+ {"id": artist_obj["channelId"], "name": artist_obj["name"]}
+ ]
+ albums.append(await self._parse_album(album_obj, album_obj["browseId"]))
+ return albums
+ return []
+
+ async def get_artist_toptracks(self, prov_artist_id) -> list[Track]:
+ """Get a list of 25 most popular tracks for the given artist."""
+ artist_obj = await get_artist(prov_artist_id=prov_artist_id)
+ if artist_obj.get("songs") and artist_obj["songs"].get("browseId"):
+ prov_playlist_id = artist_obj["songs"]["browseId"]
+ playlist_tracks = await self.get_playlist_tracks(prov_playlist_id=prov_playlist_id)
+ return playlist_tracks[:25]
+ return []
+
+ async def library_add(self, prov_item_id, media_type: MediaType) -> None:
+ """Add an item to the library."""
+ result = False
+ if media_type == MediaType.ARTIST:
+ result = await library_add_remove_artist(
+ headers=self._headers,
+ prov_artist_id=prov_item_id,
+ add=True,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ elif media_type == MediaType.ALBUM:
+ result = await library_add_remove_album(
+ headers=self._headers,
+ prov_item_id=prov_item_id,
+ add=True,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ elif media_type == MediaType.PLAYLIST:
+ result = await library_add_remove_playlist(
+ headers=self._headers,
+ prov_item_id=prov_item_id,
+ add=True,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ elif media_type == MediaType.TRACK:
+ raise NotImplementedError
+ return result
+
+ async def library_remove(self, prov_item_id, media_type: MediaType):
+ """Remove an item from the library."""
+ result = False
+ if media_type == MediaType.ARTIST:
+ result = await library_add_remove_artist(
+ headers=self._headers,
+ prov_artist_id=prov_item_id,
+ add=False,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ elif media_type == MediaType.ALBUM:
+ result = await library_add_remove_album(
+ headers=self._headers,
+ prov_item_id=prov_item_id,
+ add=False,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ elif media_type == MediaType.PLAYLIST:
+ result = await library_add_remove_playlist(
+ headers=self._headers,
+ prov_item_id=prov_item_id,
+ add=False,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ elif media_type == MediaType.TRACK:
+ raise NotImplementedError
+ return result
+
+ async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None:
+ """Add track(s) to playlist."""
+ return await add_remove_playlist_tracks(
+ headers=self._headers,
+ prov_playlist_id=prov_playlist_id,
+ prov_track_ids=prov_track_ids,
+ add=True,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+
+ async def remove_playlist_tracks(
+ self, prov_playlist_id: str, positions_to_remove: tuple[int]
+ ) -> None:
+ """Remove track(s) from playlist."""
+ playlist_obj = await get_playlist(
+ prov_playlist_id=prov_playlist_id,
+ headers=self._headers,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+ if "tracks" not in playlist_obj:
+ return None
+ tracks_to_delete = []
+ for index, track in enumerate(playlist_obj["tracks"]):
+ if index in positions_to_remove:
+ # YT needs both the videoId and the setVideoId in order to remove
+ # the track. Thus, we need to obtain the playlist details and
+ # grab the info from there.
+ tracks_to_delete.append(
+ {"videoId": track["videoId"], "setVideoId": track["setVideoId"]}
+ )
+
+ return await add_remove_playlist_tracks(
+ headers=self._headers,
+ prov_playlist_id=prov_playlist_id,
+ prov_track_ids=tracks_to_delete,
+ add=False,
+ username=self.config.get_value(CONF_USERNAME),
+ )
+
+ async def get_similar_tracks(self, prov_track_id, limit=25) -> list[Track]:
+ """Retrieve a dynamic list of tracks based on the provided item."""
+ result = []
+ result = await get_song_radio_tracks(
+ headers=self._headers,
+ username=self.config.get_value(CONF_USERNAME),
+ prov_item_id=prov_track_id,
+ limit=limit,
+ )
+ if "tracks" in result:
+ tracks = []
+ for track in result["tracks"]:
+ # Playlist tracks sometimes do not have a valid artist id
+ # In that case, call the API for track details based on track id
+ try:
+ track = await self._parse_track(track)
+ if track:
+ tracks.append(track)
+ except InvalidDataError:
+ track = await self.get_track(track["videoId"])
+ if track:
+ tracks.append(track)
+ return tracks
+ return []
+
+ async def get_stream_details(self, item_id: str) -> StreamDetails:
+ """Return the content details for the given track when it will be streamed."""
+ data = {
+ "playbackContext": {
+ "contentPlaybackContext": {"signatureTimestamp": self._signature_timestamp}
+ },
+ "video_id": item_id,
+ }
+ track_obj = await self._post_data("player", data=data)
+ stream_format = await self._parse_stream_format(track_obj)
+ url = await self._parse_stream_url(stream_format=stream_format, item_id=item_id)
+ stream_details = StreamDetails(
+ provider=self.domain,
+ item_id=item_id,
+ content_type=ContentType.try_parse(stream_format["mimeType"]),
+ direct=url,
+ )
+ if (
+ track_obj["streamingData"].get("expiresInSeconds")
+ and track_obj["streamingData"].get("expiresInSeconds").isdigit()
+ ):
+ stream_details.expires = time() + int(
+ track_obj["streamingData"].get("expiresInSeconds")
+ )
+ if stream_format.get("audioChannels") and str(stream_format.get("audioChannels")).isdigit():
+ stream_details.channels = int(stream_format.get("audioChannels"))
+ if stream_format.get("audioSampleRate") and stream_format.get("audioSampleRate").isdigit():
+ stream_details.sample_rate = int(stream_format.get("audioSampleRate"))
+ return stream_details
+
+ async def _post_data(self, endpoint: str, data: dict[str, str], **kwargs): # noqa: ARG002
+ url = f"{YTM_BASE_URL}{endpoint}"
+ data.update(self._context)
+ async with self.mass.http_session.post(
+ url,
+ headers=self._headers,
+ json=data,
+ verify_ssl=False,
+ cookies=self._cookies,
+ ) as response:
+ return await response.json()
+
+ async def _get_data(self, url: str, params: dict = None):
+ async with self.mass.http_session.get(
+ url, headers=self._headers, params=params, cookies=self._cookies
+ ) as response:
+ return await response.text()
+
+ async def _initialize_headers(self, cookie: str) -> dict[str, str]:
+ """Return headers to include in the requests."""
+ headers = {
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:72.0) Gecko/20100101 Firefox/72.0", # noqa: E501
+ "Accept": "*/*",
+ "Accept-Language": "en-US,en;q=0.5",
+ "Content-Type": "application/json",
+ "X-Goog-AuthUser": "0",
+ "x-origin": "https://music.youtube.com",
+ "Cookie": cookie,
+ }
+ sapisid = ytmusicapi.helpers.sapisid_from_cookie(cookie)
+ origin = headers.get("origin", headers.get("x-origin"))
+ headers["Authorization"] = ytmusicapi.helpers.get_authorization(sapisid + " " + origin)
+ self._headers = headers
+
+ async def _initialize_context(self) -> dict[str, str]:
+ """Return a dict to use as a context in requests."""
+ self._context = {
+ "context": {
+ "client": {"clientName": "WEB_REMIX", "clientVersion": "0.1"},
+ "user": {},
+ }
+ }
+
+ async def _parse_album(self, album_obj: dict, album_id: str = None) -> Album:
+ """Parse a YT Album response to an Album model object."""
+ album_id = album_id or album_obj.get("id") or album_obj.get("browseId")
+ if "title" in album_obj:
+ name = album_obj["title"]
+ elif "name" in album_obj:
+ name = album_obj["name"]
+ album = Album(
+ item_id=album_id,
+ name=name,
+ provider=self.domain,
+ )
+ if album_obj.get("year") and album_obj["year"].isdigit():
+ album.year = album_obj["year"]
+ if "thumbnails" in album_obj:
+ album.metadata.images = await self._parse_thumbnails(album_obj["thumbnails"])
+ if "description" in album_obj:
+ album.metadata.description = unquote(album_obj["description"])
+ if "artists" in album_obj:
+ album.artists = [
+ await self._parse_artist(artist)
+ for artist in album_obj["artists"]
+ # artist object may be missing an id
+ # in that case its either a performer (like the composer) OR this
+ # is a Various artists compilation album...
+ if (artist.get("id") or artist["name"] == "Various Artists")
+ ]
+ if "type" in album_obj:
+ if album_obj["type"] == "Single":
+ album_type = AlbumType.SINGLE
+ elif album_obj["type"] == "EP":
+ album_type = AlbumType.EP
+ elif album_obj["type"] == "Album":
+ album_type = AlbumType.ALBUM
+ else:
+ album_type = AlbumType.UNKNOWN
+ album.album_type = album_type
+ album.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(album_id),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ )
+ )
+ return album
+
+ async def _parse_artist(self, artist_obj: dict) -> Artist:
+ """Parse a YT Artist response to Artist model object."""
+ artist_id = None
+ if "channelId" in artist_obj:
+ artist_id = artist_obj["channelId"]
+ elif "id" in artist_obj and artist_obj["id"]:
+ artist_id = artist_obj["id"]
+ elif artist_obj["name"] == "Various Artists":
+ artist_id = "UCUTXlgdcKU5vfzFqHOWIvkA"
+ if not artist_id:
+ raise InvalidDataError("Artist does not have a valid ID")
+ artist = Artist(item_id=artist_id, name=artist_obj["name"], provider=self.domain)
+ if "description" in artist_obj:
+ artist.metadata.description = artist_obj["description"]
+ if "thumbnails" in artist_obj and artist_obj["thumbnails"]:
+ artist.metadata.images = await self._parse_thumbnails(artist_obj["thumbnails"])
+ artist.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(artist_id),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ url=f"https://music.youtube.com/channel/{artist_id}",
+ )
+ )
+ return artist
+
+ async def _parse_playlist(self, playlist_obj: dict) -> Playlist:
+ """Parse a YT Playlist response to a Playlist object."""
+ playlist = Playlist(
+ item_id=playlist_obj["id"], provider=self.domain, name=playlist_obj["title"]
+ )
+ if "description" in playlist_obj:
+ playlist.metadata.description = playlist_obj["description"]
+ if "thumbnails" in playlist_obj and playlist_obj["thumbnails"]:
+ playlist.metadata.images = await self._parse_thumbnails(playlist_obj["thumbnails"])
+ is_editable = False
+ if playlist_obj.get("privacy") and playlist_obj.get("privacy") == "PRIVATE":
+ is_editable = True
+ playlist.is_editable = is_editable
+ playlist.add_provider_mapping(
+ ProviderMapping(
+ item_id=playlist_obj["id"],
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ )
+ )
+ playlist.metadata.checksum = playlist_obj.get("checksum")
+ return playlist
+
+ async def _parse_track(self, track_obj: dict) -> Track:
+ """Parse a YT Track response to a Track model object."""
+ track = Track(item_id=track_obj["videoId"], provider=self.domain, name=track_obj["title"])
+ if "artists" in track_obj:
+ track.artists = [
+ await self._parse_artist(artist)
+ for artist in track_obj["artists"]
+ if artist.get("id")
+ or artist.get("channelId")
+ or artist.get("name") == "Various Artists"
+ ]
+ # guard that track has valid artists
+ if not track.artists:
+ raise InvalidDataError("Track is missing artists")
+ if "thumbnails" in track_obj and track_obj["thumbnails"]:
+ track.metadata.images = await self._parse_thumbnails(track_obj["thumbnails"])
+ if (
+ track_obj.get("album")
+ and track_obj.get("artists")
+ and isinstance(track_obj.get("album"), dict)
+ and track_obj["album"].get("id")
+ ):
+ album = track_obj["album"]
+ album["artists"] = track_obj["artists"]
+ track.album = await self._parse_album(album, album["id"])
+ if "isExplicit" in track_obj:
+ track.metadata.explicit = track_obj["isExplicit"]
+ if "duration" in track_obj and str(track_obj["duration"]).isdigit():
+ track.duration = int(track_obj["duration"])
+ elif "duration_seconds" in track_obj and str(track_obj["duration_seconds"]).isdigit():
+ track.duration = int(track_obj["duration_seconds"])
+ available = True
+ if "isAvailable" in track_obj:
+ available = track_obj["isAvailable"]
+ track.add_provider_mapping(
+ ProviderMapping(
+ item_id=str(track_obj["videoId"]),
+ provider_domain=self.domain,
+ provider_instance=self.instance_id,
+ available=available,
+ content_type=ContentType.M4A,
+ )
+ )
+ return track
+
+ async def _get_signature_timestamp(self):
+ """Get a signature timestamp required to generate valid stream URLs."""
+ response = await self._get_data(url=YTM_DOMAIN)
+ match = re.search(r'jsUrl"\s*:\s*"([^"]+)"', response)
+ if match is None:
+ # retry with youtube domain
+ response = await self._get_data(url=YT_DOMAIN)
+ match = re.search(r'jsUrl"\s*:\s*"([^"]+)"', response)
+ if match is None:
+ raise Exception("Could not identify the URL for base.js player.")
+ url = YTM_DOMAIN + match.group(1)
+ response = await self._get_data(url=url)
+ match = re.search(r"signatureTimestamp[:=](\d+)", response)
+ if match is None:
+ raise Exception("Unable to identify the signatureTimestamp.")
+ return int(match.group(1))
+
+ async def _parse_stream_url(self, stream_format: dict, item_id: str) -> str:
+ """Figure out the stream URL to use based on the YT track object."""
+ url = None
+ if stream_format.get("signatureCipher"):
+ # Secured URL
+ cipher_parts = {}
+ for part in stream_format["signatureCipher"].split("&"):
+ key, val = part.split("=", maxsplit=1)
+ cipher_parts[key] = unquote(val)
+ signature = await self._decipher_signature(
+ ciphered_signature=cipher_parts["s"], item_id=item_id
+ )
+ url = cipher_parts["url"] + "&sig=" + signature
+ elif stream_format.get("url"):
+ # Non secured URL
+ url = stream_format.get("url")
+ return url
+
+ @classmethod
+ async def _parse_thumbnails(cls, thumbnails_obj: dict) -> list[MediaItemImage]:
+ """Parse and sort a list of thumbnails and return the highest quality."""
+ thumb = sorted(thumbnails_obj, key=itemgetter("width"), reverse=True)[0]
+ return [MediaItemImage(ImageType.THUMB, thumb["url"])]
+
+ @classmethod
+ async def _parse_stream_format(cls, track_obj: dict) -> dict:
+ """Grab the highest available audio stream from available streams."""
+ stream_format = {}
+ quality_mapper = {
+ "AUDIO_QUALITY_LOW": 1,
+ "AUDIO_QUALITY_MEDIUM": 2,
+ "AUDIO_QUALITY_HIGH": 3,
+ }
+ for adaptive_format in track_obj["streamingData"]["adaptiveFormats"]:
+ if adaptive_format["mimeType"].startswith("audio") and (
+ not stream_format
+ or quality_mapper.get(adaptive_format["audioQuality"], 0)
+ > quality_mapper.get(stream_format["audioQuality"], 0)
+ ):
+ stream_format = adaptive_format
+ if stream_format is None:
+ raise MediaNotFoundError("No stream found for this track")
+ return stream_format
+
+ async def _decipher_signature(self, ciphered_signature: str, item_id: str):
+ """Decipher the signature, required to build the Stream URL."""
+
+ def _decipher():
+ embed_url = f"https://www.youtube.com/embed/{item_id}"
+ embed_html = pytube.request.get(embed_url)
+ js_url = pytube.extract.js_url(embed_html)
+ ytm_js = pytube.request.get(js_url)
+ cipher = pytube.cipher.Cipher(js=ytm_js)
+ return cipher
+
+ if not self._cipher:
+ self._cipher = await asyncio.to_thread(_decipher)
+ return self._cipher.get_signature(ciphered_signature)
--- /dev/null
+"""Helper module for parsing the Youtube Music API.
+
+This helpers file is an async wrapper around the excellent ytmusicapi package.
+While the ytmusicapi package does an excellent job at parsing the Youtube Music results,
+it is unfortunately not async, which is required for Music Assistant to run smoothly.
+This also nicely separates the parsing logic from the Youtube Music provider logic.
+"""
+
+import asyncio
+import json
+from time import time
+
+import ytmusicapi
+
+
+async def get_artist(prov_artist_id: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_artist function."""
+
+ def _get_artist():
+ ytm = ytmusicapi.YTMusic()
+ try:
+ artist = ytm.get_artist(channelId=prov_artist_id)
+ # ChannelId can sometimes be different and original ID is not part of the response
+ artist["channelId"] = prov_artist_id
+ except KeyError:
+ user = ytm.get_user(channelId=prov_artist_id)
+ artist = {"channelId": prov_artist_id, "name": user["name"]}
+ return artist
+
+ return await asyncio.to_thread(_get_artist)
+
+
+async def get_album(prov_album_id: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_album function."""
+
+ def _get_album():
+ ytm = ytmusicapi.YTMusic()
+ return ytm.get_album(browseId=prov_album_id)
+
+ return await asyncio.to_thread(_get_album)
+
+
+async def get_playlist(
+ prov_playlist_id: str, headers: dict[str, str], username: str
+) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_playlist function."""
+
+ def _get_playlist():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ playlist = ytm.get_playlist(playlistId=prov_playlist_id)
+ playlist["checksum"] = get_playlist_checksum(playlist)
+ return playlist
+
+ return await asyncio.to_thread(_get_playlist)
+
+
+async def get_track(prov_track_id: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_playlist function."""
+
+ def _get_song():
+ ytm = ytmusicapi.YTMusic()
+ track_obj = ytm.get_song(videoId=prov_track_id)
+ track = {}
+ track["videoId"] = track_obj["videoDetails"]["videoId"]
+ track["title"] = track_obj["videoDetails"]["title"]
+ track["artists"] = [
+ {
+ "channelId": track_obj["videoDetails"]["channelId"],
+ "name": track_obj["videoDetails"]["author"],
+ }
+ ]
+ track["duration"] = track_obj["videoDetails"]["lengthSeconds"]
+ track["thumbnails"] = track_obj["microformat"]["microformatDataRenderer"]["thumbnail"][
+ "thumbnails"
+ ]
+ track["isAvailable"] = track_obj["playabilityStatus"]["status"] == "OK"
+ return track
+
+ return await asyncio.to_thread(_get_song)
+
+
+async def get_library_artists(headers: dict[str, str], username: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_library_artists function."""
+
+ def _get_library_artists():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ artists = ytm.get_library_subscriptions(limit=9999)
+ # Sync properties with uniformal artist object
+ for artist in artists:
+ artist["id"] = artist["browseId"]
+ artist["name"] = artist["artist"]
+ del artist["browseId"]
+ del artist["artist"]
+ return artists
+
+ return await asyncio.to_thread(_get_library_artists)
+
+
+async def get_library_albums(headers: dict[str, str], username: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_library_albums function."""
+
+ def _get_library_albums():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ return ytm.get_library_albums(limit=9999)
+
+ return await asyncio.to_thread(_get_library_albums)
+
+
+async def get_library_playlists(headers: dict[str, str], username: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_library_playlists function."""
+
+ def _get_library_playlists():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ playlists = ytm.get_library_playlists(limit=9999)
+ # Sync properties with uniformal playlist object
+ for playlist in playlists:
+ playlist["id"] = playlist["playlistId"]
+ del playlist["playlistId"]
+ playlist["checksum"] = get_playlist_checksum(playlist)
+ return playlists
+
+ return await asyncio.to_thread(_get_library_playlists)
+
+
+async def get_library_tracks(headers: dict[str, str], username: str) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi get_library_tracks function."""
+
+ def _get_library_tracks():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ tracks = ytm.get_library_songs(limit=9999)
+ return tracks
+
+ return await asyncio.to_thread(_get_library_tracks)
+
+
+async def library_add_remove_artist(
+ headers: dict[str, str], prov_artist_id: str, add: bool = True, username: str = None
+) -> bool:
+ """Add or remove an artist to the user's library."""
+
+ def _library_add_remove_artist():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ if add:
+ return "actions" in ytm.subscribe_artists(channelIds=[prov_artist_id])
+ if not add:
+ return "actions" in ytm.unsubscribe_artists(channelIds=[prov_artist_id])
+ return None
+
+ return await asyncio.to_thread(_library_add_remove_artist)
+
+
+async def library_add_remove_album(
+ headers: dict[str, str], prov_item_id: str, add: bool = True, username: str = None
+) -> bool:
+ """Add or remove an album or playlist to the user's library."""
+ album = await get_album(prov_album_id=prov_item_id)
+
+ def _library_add_remove_album():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ playlist_id = album["audioPlaylistId"]
+ if add:
+ return ytm.rate_playlist(playlist_id, "LIKE")
+ if not add:
+ return ytm.rate_playlist(playlist_id, "INDIFFERENT")
+ return None
+
+ return await asyncio.to_thread(_library_add_remove_album)
+
+
+async def library_add_remove_playlist(
+ headers: dict[str, str], prov_item_id: str, add: bool = True, username: str = None
+) -> bool:
+ """Add or remove an album or playlist to the user's library."""
+
+ def _library_add_remove_playlist():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ if add:
+ return "actions" in ytm.rate_playlist(prov_item_id, "LIKE")
+ if not add:
+ return "actions" in ytm.rate_playlist(prov_item_id, "INDIFFERENT")
+ return None
+
+ return await asyncio.to_thread(_library_add_remove_playlist)
+
+
+async def add_remove_playlist_tracks(
+ headers: dict[str, str],
+ prov_playlist_id: str,
+ prov_track_ids: list[str],
+ add: bool,
+ username: str = None,
+) -> bool:
+ """Async wrapper around adding/removing tracks to a playlist."""
+
+ def _add_playlist_tracks():
+ user = username if is_brand_account(username) else None
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ if add:
+ return ytm.add_playlist_items(playlistId=prov_playlist_id, videoIds=prov_track_ids)
+ if not add:
+ return ytm.remove_playlist_items(playlistId=prov_playlist_id, videos=prov_track_ids)
+ return None
+
+ return await asyncio.to_thread(_add_playlist_tracks)
+
+
+async def get_song_radio_tracks(
+ headers: dict[str, str], username: str, prov_item_id: str, limit=25
+) -> dict[str, str]:
+ """Async wrapper around the ytmusicapi radio function."""
+ user = username if is_brand_account(username) else None
+
+ def _get_song_radio_tracks():
+ ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user)
+ playlist_id = f"RDAMVM{prov_item_id}"
+ result = ytm.get_watch_playlist(videoId=prov_item_id, playlistId=playlist_id, limit=limit)
+ # Replace inconsistensies for easier parsing
+ for track in result["tracks"]:
+ if track.get("thumbnail"):
+ track["thumbnails"] = track["thumbnail"]
+ del track["thumbnail"]
+ if track.get("length"):
+ track["duration"] = get_sec(track["length"])
+ return result
+
+ return await asyncio.to_thread(_get_song_radio_tracks)
+
+
+async def search(query: str, ytm_filter: str = None, limit: int = 20) -> list[dict]:
+ """Async wrapper around the ytmusicapi search function."""
+
+ def _search():
+ ytm = ytmusicapi.YTMusic()
+ results = ytm.search(query=query, filter=ytm_filter, limit=limit)
+ # Sync result properties with uniformal objects
+ for result in results:
+ if result["resultType"] == "artist":
+ result["id"] = result["browseId"]
+ result["name"] = result["artist"]
+ del result["browseId"]
+ del result["artist"]
+ elif result["resultType"] == "playlist":
+ if "playlistId" in result:
+ result["id"] = result["playlistId"]
+ del result["playlistId"]
+ elif "browseId" in result:
+ result["id"] = result["browseId"]
+ del result["browseId"]
+ return results
+
+ return await asyncio.to_thread(_search)
+
+
+def get_playlist_checksum(playlist_obj: dict) -> str:
+ """Try to calculate a checksum so we can detect changes in a playlist."""
+ for key in ("duration_seconds", "trackCount"):
+ if key in playlist_obj:
+ return playlist_obj[key]
+ return str(int(time()))
+
+
+def is_brand_account(username: str) -> bool:
+ """Check if the provided username is a brand-account."""
+ return len(username) == 21 and username.isdigit()
+
+
+def get_sec(time_str):
+ """Get seconds from time."""
+ parts = time_str.split(":")
+ if len(parts) == 3:
+ return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])
+ if len(parts) == 2:
+ return int(parts[0]) * 60 + int(parts[1])
+ return 0
--- /dev/null
+{
+ "type": "music",
+ "domain": "ytmusic",
+ "name": "YouTube Music",
+ "description": "Support for the YouTube Music streaming provider in Music Assistant.",
+ "codeowners": ["@MarvinSchenkel"],
+ "config_entries": [
+ {
+ "key": "username",
+ "type": "string",
+ "label": "Username"
+ },
+ {
+ "key": "cookie",
+ "type": "string",
+ "label": "Cookie"
+ }
+ ],
+
+ "requirements": ["ytmusicapi==0.25.0", "pytube==12.1.2"],
+ "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/606",
+ "multi_instance": true
+}
--- /dev/null
+"""Main Music Assistant class."""
+from __future__ import annotations
+
+import asyncio
+import importlib
+import inspect
+import logging
+import os
+from collections.abc import Awaitable, Callable, Coroutine
+from typing import TYPE_CHECKING, Any
+
+from aiohttp import ClientSession, TCPConnector, web
+from zeroconf import InterfaceChoice, NonUniqueNameException, ServiceInfo, Zeroconf
+
+from music_assistant.common.helpers.util import get_ip, get_ip_pton, select_free_port
+from music_assistant.common.models.config_entries import ProviderConfig
+from music_assistant.common.models.enums import EventType, ProviderType
+from music_assistant.common.models.errors import (
+ MusicAssistantError,
+ ProviderUnavailableError,
+ SetupFailedError,
+)
+from music_assistant.common.models.event import MassEvent
+from music_assistant.common.models.provider import ProviderManifest
+from music_assistant.constants import CONF_SERVER_ID, CONF_WEB_IP, ROOT_LOGGER_NAME
+from music_assistant.server.controllers.cache import CacheController
+from music_assistant.server.controllers.config import ConfigController
+from music_assistant.server.controllers.metadata import MetaDataController
+from music_assistant.server.controllers.music import MusicController
+from music_assistant.server.controllers.players import PlayerController
+from music_assistant.server.controllers.streams import StreamsController
+from music_assistant.server.helpers.api import APICommandHandler, api_command, mount_websocket_api
+from music_assistant.server.helpers.util import install_package
+from music_assistant.server.models.plugin import PluginProvider
+
+from .models.metadata_provider import MetadataProvider
+from .models.music_provider import MusicProvider
+from .models.player_provider import PlayerProvider
+
+if TYPE_CHECKING:
+ from types import TracebackType
+
+ProviderInstanceType = MetadataProvider | MusicProvider | PlayerProvider
+EventCallBackType = Callable[[MassEvent], None]
+EventSubscriptionType = tuple[EventCallBackType, tuple[EventType] | None, tuple[str] | None]
+
+LOGGER = logging.getLogger(ROOT_LOGGER_NAME)
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")
+
+
+class MusicAssistant:
+ """Main MusicAssistant (Server) object."""
+
+ loop: asyncio.AbstractEventLoop
+ http_session: ClientSession
+ _web_apprunner: web.AppRunner
+ _web_tcp: web.TCPSite
+
+ def __init__(self, storage_path: str, port: int | None = None) -> None:
+ """Initialize the MusicAssistant Server."""
+ self.storage_path = storage_path
+ self.port = port
+ self.base_ip = get_ip()
+ # shared zeroconf instance
+ self.zeroconf = Zeroconf(interfaces=InterfaceChoice.All)
+ # we dynamically register command handlers
+ self.webapp = web.Application()
+ self.command_handlers: dict[str, APICommandHandler] = {}
+ self._subscribers: set[EventSubscriptionType] = set()
+ self._available_providers: dict[str, ProviderManifest] = {}
+ self._providers: dict[str, ProviderInstanceType] = {}
+
+ # init core controllers
+ self.config = ConfigController(self)
+ self.cache = CacheController(self)
+ self.metadata = MetaDataController(self)
+ self.music = MusicController(self)
+ self.players = PlayerController(self)
+ self.streams = StreamsController(self)
+ self._tracked_tasks: list[asyncio.Task] = []
+ self.closing = False
+ # register all api commands (methods with decorator)
+ self._register_api_commands()
+
+ async def start(self) -> None:
+ """Start running the Music Assistant server."""
+ self.loop = asyncio.get_running_loop()
+ # create shared aiohttp ClientSession
+ self.http_session = ClientSession(
+ loop=self.loop,
+ connector=TCPConnector(ssl=False),
+ )
+ # setup config controller first and fetch important config values
+ await self.config.setup()
+ if self.port is None:
+ # if port is None, we need to autoselect it
+ self.port = await select_free_port(8095, 9200)
+ # allow overriding of the base_ip if autodetect failed
+ self.base_ip = self.config.get(CONF_WEB_IP, self.base_ip)
+
+ # setup other core controllers
+ await self.cache.setup()
+ await self.music.setup()
+ await self.metadata.setup()
+ await self.players.setup()
+ await self.streams.setup()
+
+ # load providers
+ await self._load_providers()
+ # setup web server
+ mount_websocket_api(self, "/ws")
+ self._web_apprunner = web.AppRunner(self.webapp, access_log=None)
+ await self._web_apprunner.setup()
+ # set host to None to bind to all addresses on both IPv4 and IPv6
+ host = None
+ self._web_tcp = web.TCPSite(self._web_apprunner, host=host, port=self.port)
+ await self._web_tcp.start()
+ await self._setup_discovery()
+
+ async def stop(self) -> None:
+ """Stop running the music assistant server."""
+ LOGGER.info("Stop called, cleaning up...")
+ self.signal_event(EventType.SHUTDOWN)
+ self.closing = True
+ # cancel all running tasks
+ for task in self._tracked_tasks:
+ task.cancel()
+ # stop/clean streams controller
+ await self.streams.close()
+ # stop/clean webserver
+ await self._web_tcp.stop()
+ await self._web_apprunner.cleanup()
+ await self.webapp.shutdown()
+ await self.webapp.cleanup()
+ # stop core controllers
+ await self.metadata.close()
+ await self.music.close()
+ await self.players.close()
+ # cleanup all providers
+ for prov in self._providers.values():
+ await prov.close()
+ # cleanup cache and config
+ await self.config.close()
+ await self.cache.close()
+ # close/cleanup shared http session
+ if self.http_session:
+ await self.http_session.close()
+
+ @property
+ def base_url(self) -> str:
+ """Return the (web)server's base url."""
+ return f"http://{self.base_ip}:{self.port}"
+
+ @property
+ def server_id(self) -> str:
+ """Return unique ID of this server."""
+ if not self.config.initialized:
+ return ""
+ return self.config.get(CONF_SERVER_ID) # type: ignore[no-any-return]
+
+ @api_command("providers/available")
+ def get_available_providers(self) -> list[ProviderManifest]:
+ """Return all available Providers."""
+ return list(self._available_providers.values())
+
+ @api_command("providers")
+ def get_providers(
+ self, provider_type: ProviderType | None = None
+ ) -> list[ProviderInstanceType]:
+ """Return all loaded/running Providers (instances), optionally filtered by ProviderType."""
+ return [
+ x for x in self._providers.values() if provider_type is None or provider_type == x.type
+ ]
+
+ @property
+ def providers(self) -> list[ProviderInstanceType]:
+ """Return all loaded/running Providers (instances)."""
+ return list(self._providers.values())
+
+ def get_provider(self, provider_instance_or_domain: str) -> ProviderInstanceType:
+ """Return provider by instance id (or domain)."""
+ if prov := self._providers.get(provider_instance_or_domain):
+ return prov
+ for prov in self._providers.values():
+ if prov.domain == provider_instance_or_domain:
+ return prov
+ raise ProviderUnavailableError(f"Provider {provider_instance_or_domain} is not available")
+
+ def signal_event(
+ self,
+ event: EventType,
+ object_id: str | None = None,
+ data: Any = None,
+ ) -> None:
+ """Signal event to subscribers."""
+ if self.closing:
+ return
+
+ if LOGGER.isEnabledFor(logging.DEBUG) and event != EventType.QUEUE_TIME_UPDATED:
+ # do not log queue time updated events because that is too chatty
+ LOGGER.getChild("event").debug("%s %s", event.value, object_id or "")
+
+ event_obj = MassEvent(event=event, object_id=object_id, data=data)
+ for cb_func, event_filter, id_filter in self._subscribers:
+ if not (event_filter is None or event in event_filter):
+ continue
+ if not (id_filter is None or object_id in id_filter):
+ continue
+ if asyncio.iscoroutinefunction(cb_func):
+ asyncio.run_coroutine_threadsafe(cb_func(event_obj), self.loop)
+ else:
+ self.loop.call_soon_threadsafe(cb_func, event_obj)
+
+ def subscribe(
+ self,
+ cb_func: EventCallBackType,
+ event_filter: EventType | tuple[EventType] | None = None,
+ id_filter: str | tuple[str] | None = None,
+ ) -> Callable:
+ """Add callback to event listeners.
+
+ Returns function to remove the listener.
+ :param cb_func: callback function or coroutine
+ :param event_filter: Optionally only listen for these events
+ :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
+ """
+ if isinstance(event_filter, EventType):
+ event_filter = (event_filter,)
+ if isinstance(id_filter, str):
+ id_filter = (id_filter,)
+ listener = (cb_func, event_filter, id_filter)
+ self._subscribers.add(listener)
+
+ def remove_listener():
+ self._subscribers.remove(listener)
+
+ return remove_listener
+
+ def create_task(
+ self,
+ target: Coroutine | Awaitable | Callable | asyncio.Future,
+ *args: Any,
+ **kwargs: Any,
+ ) -> asyncio.Task | asyncio.Future:
+ """Create Task on (main) event loop from Coroutine(function).
+
+ Tasks created by this helper will be properly cancelled on stop.
+ """
+ if asyncio.iscoroutinefunction(target):
+ task = self.loop.create_task(target(*args, **kwargs))
+ elif isinstance(target, asyncio.Future):
+ task = target
+ elif asyncio.iscoroutine(target):
+ task = self.loop.create_task(target)
+ else:
+ # assume normal callable (non coroutine or awaitable)
+ task = self.loop.create_task(asyncio.to_thread(target, *args, **kwargs))
+
+ def task_done_callback(*args, **kwargs): # noqa: ARG001
+ self._tracked_tasks.remove(task)
+ if LOGGER.isEnabledFor(logging.DEBUG):
+ # print unhandled exceptions
+ task_name = getattr(task, "name", "")
+ if not task.cancelled() and task.exception():
+ task_name = task.get_name() if hasattr(task, "get_name") else task
+ LOGGER.exception(
+ "Exception in task %s",
+ task_name,
+ exc_info=task.exception(),
+ )
+
+ self._tracked_tasks.append(task)
+ task.add_done_callback(task_done_callback)
+ return task
+
+ def register_api_command(
+ self,
+ command: str,
+ handler: Callable,
+ ) -> None:
+ """Dynamically register a command on the API."""
+ assert command not in self.command_handlers, "Command already registered"
+ self.command_handlers[command] = APICommandHandler.parse(command, handler)
+
+ async def load_provider(self, conf: ProviderConfig) -> None: # noqa: C901
+ """Load (or reload) a provider."""
+ # if provider is already loaded, stop and unload it first
+ await self.unload_provider(conf.instance_id)
+
+ # abort if provider is disabled
+ if not conf.enabled:
+ LOGGER.debug(
+ "Not loading provider %s because it is disabled",
+ conf.name or conf.instance_id,
+ )
+ return
+
+ domain = conf.domain
+ prov_manifest = self._available_providers.get(domain)
+ # check for other instances of this provider
+ existing = next((x for x in self.providers if x.domain == domain), None)
+ if existing and not prov_manifest.multi_instance:
+ raise SetupFailedError(
+ f"Provider {domain} already loaded and only one instance allowed."
+ )
+
+ if not prov_manifest:
+ raise SetupFailedError(f"Provider {domain} manifest not found")
+
+ # try to load the module
+ try:
+ prov_mod = importlib.import_module(f".{domain}", "music_assistant.server.providers")
+ for name, obj in inspect.getmembers(prov_mod):
+ if not inspect.isclass(obj):
+ continue
+ # lookup class to initialize
+ if name == prov_manifest.init_class or (
+ not prov_manifest.init_class
+ and issubclass(
+ obj, MusicProvider | PlayerProvider | MetadataProvider | PluginProvider
+ )
+ and obj != MusicProvider
+ and obj != PlayerProvider
+ and obj != MetadataProvider
+ and obj != PluginProvider
+ ):
+ prov_cls = obj
+ break
+ else:
+ raise AttributeError("Unable to locate Provider class")
+ provider: ProviderInstanceType = prov_cls(self, prov_manifest, conf)
+ self._providers[provider.instance_id] = provider
+ try:
+ await provider.setup()
+ except MusicAssistantError as err:
+ provider.last_error = str(err)
+ provider.available = False
+ raise err
+
+ # mark provider as available once setup succeeded
+ provider.available = True
+ provider.last_error = None
+ # if this is a music provider, start sync
+ if provider.type == ProviderType.MUSIC:
+ await self.music.start_sync(providers=[provider.instance_id])
+ # pylint: disable=broad-except
+ except Exception as exc:
+ LOGGER.exception(
+ "Error loading provider(instance) %s: %s",
+ conf.name or conf.domain,
+ str(exc),
+ )
+ else:
+ LOGGER.debug(
+ "Successfully loaded provider %s",
+ conf.name or conf.domain,
+ )
+ # always signal event, regardless if the loading succeeded or not
+ self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
+
+ async def unload_provider(self, instance_id: str) -> None:
+ """Unload a provider."""
+ if provider := self._providers.get(instance_id):
+ # make sure to stop any running sync tasks first
+ for sync_task in self.music.in_progress_syncs:
+ if sync_task.provider_instance == instance_id:
+ sync_task.task.cancel()
+ await sync_task.task
+ await provider.close()
+ self._providers.pop(instance_id)
+ self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
+
+ def _register_api_commands(self) -> None:
+ """Register all methods decorated as api_command within a class(instance)."""
+ for cls in (
+ self,
+ self.config,
+ self.metadata,
+ self.music,
+ self.players,
+ self.players.queues,
+ ):
+ for attr_name in dir(cls):
+ if attr_name.startswith("__"):
+ continue
+ obj = getattr(cls, attr_name)
+ if hasattr(obj, "api_cmd"):
+ # method is decorated with our api decorator
+ self.register_api_command(obj.api_cmd, obj)
+
+ async def _load_providers(self) -> None:
+ """Load providers from config."""
+ # load all available providers from manifest files
+ await self.__load_available_providers()
+ loaded_providers = set()
+ async with asyncio.TaskGroup() as tg:
+ # we loop twice to solve any dependencies
+ for allow_depends_on in (False, True):
+ # load all configured (and enabled) providers
+ for prov_conf in self.config.get_provider_configs():
+ prov_manifest = self._available_providers[prov_conf.domain]
+ if prov_manifest.depends_on and not allow_depends_on:
+ continue
+ if prov_conf.instance_id in loaded_providers:
+ continue
+ loaded_providers.add(prov_conf.domain)
+ loaded_providers.add(prov_conf.instance_id)
+ tg.create_task(self.load_provider(prov_conf))
+ # create default config for any 'load_by_default' providers (e.g. URL provider)
+ # NOTE: this will auto load any not yet existing providers
+ for prov_manifest in self._available_providers.values():
+ if prov_manifest.domain in loaded_providers:
+ continue
+ if not prov_manifest.load_by_default:
+ continue
+ if prov_manifest.depends_on and not allow_depends_on:
+ continue
+ default_conf = self.config.create_provider_config(prov_manifest.domain)
+ self.config.set_provider_config(default_conf)
+
+ async def __load_available_providers(self) -> None:
+ """Preload all available provider manifest files."""
+ for dir_str in os.listdir(PROVIDERS_PATH):
+ dir_path = os.path.join(PROVIDERS_PATH, dir_str)
+ if not os.path.isdir(dir_path):
+ continue
+ # get files in subdirectory
+ for file_str in os.listdir(dir_path):
+ file_path = os.path.join(dir_path, file_str)
+ if not os.path.isfile(file_path):
+ continue
+ if file_str != "manifest.json":
+ continue
+ try:
+ provider_manifest = await ProviderManifest.parse(file_path)
+ self._available_providers[provider_manifest.domain] = provider_manifest
+ # install requirement/dependencies
+ for requirement in provider_manifest.requirements:
+ await install_package(requirement)
+ LOGGER.debug("Loaded manifest for provider %s", dir_str)
+ except Exception as exc: # pylint: disable=broad-except
+ LOGGER.exception(
+ "Error while loading manifest for provider %s",
+ dir_str,
+ exc_info=exc,
+ )
+
+ async def _setup_discovery(self) -> None:
+ """Make this Music Assistant instance discoverable on the network."""
+
+ def setup_discovery():
+ zeroconf_type = "_music-assistant._tcp.local."
+ server_id = "mass" # TODO ?
+
+ info = ServiceInfo(
+ zeroconf_type,
+ name=f"{server_id}.{zeroconf_type}",
+ addresses=[get_ip_pton()],
+ port=self.port,
+ properties={},
+ server=f"mass_{server_id}.local.",
+ )
+ LOGGER.debug("Starting Zeroconf broadcast...")
+ try:
+ existing = getattr(self, "mass_zc_service_set", None)
+ if existing:
+ self.zeroconf.update_service(info)
+ else:
+ self.zeroconf.register_service(info)
+ setattr(self, "mass_zc_service_set", True)
+ except NonUniqueNameException:
+ LOGGER.error(
+ "Music Assistant instance with identical name present in the local network!"
+ )
+
+ await asyncio.to_thread(setup_discovery)
+
+ async def __aenter__(self) -> MusicAssistant:
+ """Return Context manager."""
+ await self.start()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException],
+ exc_val: BaseException,
+ exc_tb: TracebackType,
+ ) -> bool | None:
+ """Exit context manager."""
+ await self.stop()
+ if exc_val:
+ raise exc_val
+ return exc_type
+++ /dev/null
-[MASTER]
-ignore=tests
-ignore-patterns=app_vars
-ignore-paths=.vscode
-# Use a conservative default here; 2 should speed up most setups and not hurt
-# any too bad. Override on command line as appropriate.
-jobs=2
-persistent=no
-suggestion-mode=yes
-extension-pkg-whitelist=taglib
-
-[BASIC]
-good-names=id,i,j,k,ex,Run,_,fp,T,ev,db,d
-
-[MESSAGES CONTROL]
-# Reasons disabled:
-# format - handled by black
-# locally-disabled - it spams too much
-# duplicate-code - unavoidable
-# cyclic-import - doesn't test if both import on load
-# unused-argument - generic callbacks and setup methods create a lot of warnings
-# too-many-* - are not enforced for the sake of readability
-# too-few-* - same as too-many-*
-# abstract-method - with intro of async there are always methods missing
-# inconsistent-return-statements - doesn't handle raise
-# too-many-ancestors - it's too strict.
-# wrong-import-order - isort guards this
-# fixme - project is in development phase
-disable=
- format,
- abstract-method,
- cyclic-import,
- duplicate-code,
- inconsistent-return-statements,
- locally-disabled,
- not-context-manager,
- too-few-public-methods,
- too-many-ancestors,
- too-many-arguments,
- too-many-branches,
- too-many-instance-attributes,
- too-many-lines,
- too-many-locals,
- too-many-public-methods,
- too-many-return-statements,
- too-many-statements,
- too-many-boolean-expressions,
- unused-argument,
- wrong-import-order,
- fixme
-# enable useless-suppression temporarily every now and then to clean them up
-enable=
- use-symbolic-message-instead
-
-[REPORTS]
-score=no
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=15
-
-[TYPECHECK]
-# For attrs
-ignored-classes=_CountingAttr
-
-[FORMAT]
-expected-line-ending-format=LF
--- /dev/null
+[build-system]
+requires = ["setuptools~=62.3", "wheel~=0.37.1"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "music_assistant"
+dynamic = ["version"]
+license = "Apache-2.0"
+description = "Music Assistant"
+readme = "README.md"
+requires-python = ">=3.11"
+author = "marcelveldt"
+author_email = "marcelveldt@users.noreply.github.com"
+classifiers = [
+ "Environment :: Console",
+ "Programming Language :: Python :: 3.11",
+]
+dependencies = [
+ "aiohttp",
+ "coloredlogs",
+ "orjson",
+ "mashumaro>=3.5"
+]
+
+[project.optional-dependencies]
+server = [
+ "aiohttp==3.8.4",
+ "asyncio-throttle==1.0.2",
+ "aiofiles==23.1.0",
+ "aiorun==2022.11.1",
+ "databases==0.7.0",
+ "aiosqlite==0.18.0",
+ "python-slugify==7.0.0",
+ "memory-tempfile==2.2.3",
+ "pillow==9.4.0",
+ "unidecode==1.3.6",
+ "mashumaro==3.5",
+ "xmltodict==0.13.0",
+ "orjson==3.8.6",
+ "shortuuid==1.0.11",
+ "zeroconf==0.47.3",
+ "cryptography==39.0.2"
+]
+test = [
+ "black==23.1.0",
+ "codespell==2.2.2",
+ "mypy==1.0.1",
+ "ruff==0.0.254",
+ "pytest==7.2.1",
+ "pytest-aiohttp==1.0.4",
+ "pytest-cov==4.0.0",
+ "pre-commit==2.20.0"
+]
+
+[project.scripts]
+mass = "music_assistant.server.__main__:main"
+
+[tool.setuptools.dynamic]
+version = {attr = "music_assistant.constants.__version__"}
+
+[tool.black]
+target-version = ['py311']
+line-length = 100
+
+[tool.codespell]
+ignore-words-list = "provid"
+
+[tool.mypy]
+python_version = "3.11"
+check_untyped_defs = true
+#disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_calls = false
+disallow_untyped_defs = true
+mypy_path = "music_assistant/"
+no_implicit_optional = true
+show_error_codes = true
+warn_incomplete_stub = true
+warn_redundant_casts = true
+warn_return_any = true
+warn_unreachable = true
+warn_unused_configs = true
+warn_unused_ignores = true
+
+[[tool.mypy.overrides]]
+ignore_missing_imports = true
+module = [
+ "aiorun",
+ "coloredlogs",
+]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+
+[tool.setuptools]
+platforms = ["any"]
+zip-safe = false
+include-package-data = true
+
+[tool.setuptools.package-data]
+music_assistant = ["py.typed"]
+
+[tool.setuptools.packages.find]
+include = ["music_assistant*"]
+
+[tool.ruff]
+fix = true
+show-fixes = true
+
+# enable later: "C90", "PTH", "TCH", "RET", "ANN"
+select = ["E", "F", "W", "I", "N", "D", "UP", "PL", "Q", "SIM", "TID", "ARG"]
+ignore = ["PLR2004", "N818"]
+extend-exclude = ["app_vars.py"]
+unfixable = ["F841"]
+line-length = 100
+target-version = "py311"
+
+[tool.ruff.flake8-annotations]
+allow-star-arg-any = true
+suppress-dummy-args = true
+
+[tool.ruff.flake8-builtins]
+builtins-ignorelist = ["id"]
+
+[tool.ruff.pydocstyle]
+# Use Google-style docstrings.
+convention = "pep257"
+
+[tool.ruff.pylint]
+
+max-branches=25
+max-returns=15
+max-args=10
+max-statements=50
+++ /dev/null
-async-timeout>=3.0,<=4.0.2
-aiohttp>=3.7.0,>=3.8.1
-asyncio-throttle>=1.0,<=1.0.2
-aiofiles>=0.7.0,<22.1.1
-databases>=0.6.3
-aiosqlite>=0.13,<=0.17
-python-slugify>=4.0,<7.0.1
-memory-tempfile<=2.2.3
-pillow>=8.0,<9.4.0
-unidecode>=1.0,<1.3.7
-mashumaro>=3.0,<=3.1
-xmltodict>=0.12.0,<=0.13.0
-ytmusicapi>=0.22.0,<0.25.0
-pytube>=12.1.0,<=12.2.0
-pysmb>=1.2.8,<=1.3.0
+++ /dev/null
--r requirements.txt
-black==22.10.0
-flake8==6.0.0
-mypy==0.982
-pydocstyle==6.1.1
-pylint==2.15.5
-pytest-aiohttp==1.0.4
-pytest-cov==4.0.0
-pytest-freezegun==0.4.2
-pytest-socket==0.5.1
-pytest-test-groups==1.0.3
-pytest-sugar==0.9.6
-pytest-timeout==2.1.0
-pytest-xdist==3.0.2
-pytest==7.2.0
-pre-commit==2.20.0
+++ /dev/null
-[flake8]
-exclude = .venv,.git,.tox,docs,venv,bin,lib,deps,build
-# To work with Black
-max-line-length = 100
-# E501: line too long
-# W503: Line break occurred before a binary operator
-# E203: Whitespace before ':'
-# D202 No blank lines allowed after function docstring
-# W504 line break after binary operator
-ignore =
- E501,
- W503,
- E203,
- D202,
- W504,
- E266
-
-[isort]
-profile = black
-multi_line_output = 3
-include_trailing_comma = True
-force_grid_wrap = 0
-use_parentheses = True
-line_length = 88
-
-[mypy]
-python_version = 3.9
-ignore_errors = true
-follow_imports = silent
-ignore_missing_imports = true
-warn_incomplete_stub = true
-warn_redundant_casts = true
-warn_unused_configs = true
-
-[pydocstyle]
-add-ignore = D202
-
-[pylint.master]
-ignore=tests
-ignore-patterns=app_vars
-# Use a conservative default here; 2 should speed up most setups and not hurt
-# any too bad. Override on command line as appropriate.
-jobs=2
-persistent=no
-suggestion-mode=yes
-extension-pkg-whitelist=taglib,orjson
-
-[pylint.basic]
-good-names=id,i,j,k,ex,Run,_,fp,T,ev,db,d
-
-[pylint.messages_control]
-# Reasons disabled:
-# format - handled by black
-# locally-disabled - it spams too much
-# duplicate-code - unavoidable
-# cyclic-import - doesn't test if both import on load
-# abstract-class-little-used - prevents from setting right foundation
-# unused-argument - generic callbacks and setup methods create a lot of warnings
-# too-many-* - are not enforced for the sake of readability
-# too-few-* - same as too-many-*
-# abstract-method - with intro of async there are always methods missing
-# inconsistent-return-statements - doesn't handle raise
-# too-many-ancestors - it's too strict.
-# wrong-import-order - isort guards this
-# fixme - project is in development phase
-# c-extension-no-member - it was giving me headaches
-disable=
- format,
- abstract-class-little-used,
- abstract-method,
- cyclic-import,
- duplicate-code,
- inconsistent-return-statements,
- locally-disabled,
- not-context-manager,
- too-few-public-methods,
- too-many-ancestors,
- too-many-arguments,
- too-many-branches,
- too-many-instance-attributes,
- too-many-lines,
- too-many-locals,
- too-many-public-methods,
- too-many-return-statements,
- too-many-statements,
- too-many-boolean-expressions,
- unused-argument,
- wrong-import-order,
- fixme,
- c-extension-no-member
-
-# enable useless-suppression temporarily every now and then to clean them up
-enable=
- use-symbolic-message-instead
-
-[pylint.reports]
-score=no
-
-[pylint.refactoring]
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=15
-
-[pylint.typecheck]
-# For attrs
-ignored-classes=_CountingAttr
-
-[pylint.format]
-expected-line-ending-format=LF
+++ /dev/null
-"""Music Assistant setup."""
-import os
-from pathlib import Path
-
-from setuptools import find_packages, setup
-
-PROJECT_NAME = "Music Assistant"
-PROJECT_PACKAGE_NAME = "music_assistant"
-PROJECT_VERSION = "1.8.8"
-PROJECT_REQ_PYTHON_VERSION = "3.9"
-PROJECT_LICENSE = "Apache License 2.0"
-PROJECT_AUTHOR = "Marcel van der Veldt"
-PROJECT_EMAIL = "marcelveldt@users.noreply.github.com"
-
-PROJECT_GITHUB_USERNAME = "music-assistant"
-PROJECT_GITHUB_REPOSITORY = "music-assistant"
-
-PYPI_URL = f"https://pypi.python.org/pypi/{PROJECT_PACKAGE_NAME}"
-GITHUB_PATH = f"{PROJECT_GITHUB_USERNAME}/{PROJECT_GITHUB_REPOSITORY}"
-GITHUB_URL = f"https://github.com/{GITHUB_PATH}"
-
-DOWNLOAD_URL = f"{GITHUB_URL}/archive/{PROJECT_VERSION}.zip"
-PROJECT_URLS = {
- "Bug Reports": f"{GITHUB_URL}/issues",
- "Website": GITHUB_URL,
- "Discord": "https://discord.gg/AmDBM6QCAs",
-}
-PROJECT_DIR = Path(__file__).parent.resolve()
-README_FILE = PROJECT_DIR / "README.md"
-REQUIREMENTS_FILE = PROJECT_DIR / "requirements.txt"
-PACKAGES = find_packages(exclude=["tests", "tests.*"])
-PACKAGE_FILES = []
-for (path, directories, filenames) in os.walk("music_assistant/"):
- for filename in filenames:
- PACKAGE_FILES.append(os.path.join("..", path, filename))
-
-setup(
- name=PROJECT_PACKAGE_NAME,
- version=PROJECT_VERSION,
- url=GITHUB_URL,
- download_url=DOWNLOAD_URL,
- project_urls=PROJECT_URLS,
- author=PROJECT_AUTHOR,
- author_email=PROJECT_EMAIL,
- long_description=README_FILE.read_text(encoding="utf-8"),
- long_description_content_type="text/markdown",
- packages=PACKAGES,
- include_package_data=True,
- zip_safe=False,
- install_requires=REQUIREMENTS_FILE.read_text(encoding="utf-8"),
- python_requires=f">={PROJECT_REQ_PYTHON_VERSION}",
- test_suite="tests",
- package_data={"music_assistant": PACKAGE_FILES},
-)
from pytest import raises
-from music_assistant.helpers import uri, util
-from music_assistant.models import media_items
-from music_assistant.models.enums import ProviderType
-from music_assistant.models.errors import MusicAssistantError
+from music_assistant.common.helpers import uri, util
+from music_assistant.common.models import media_items
+from music_assistant.common.models.errors import MusicAssistantError
def test_version_extract():
"""Test the extraction of version from title."""
-
test_str = "Bam Bam (feat. Ed Sheeran) - Karaoke Version"
title, version = util.parse_title_and_version(test_str)
assert title == "Bam Bam"
test_uri = "spotify://track/123456789"
media_type, provider, item_id = uri.parse_uri(test_uri)
assert media_type == media_items.MediaType.TRACK
- assert provider == ProviderType.SPOTIFY
+ assert provider == "spotify"
assert item_id == "123456789"
# test spotify uri
test_uri = "spotify:track:123456789"
media_type, provider, item_id = uri.parse_uri(test_uri)
assert media_type == media_items.MediaType.TRACK
- assert provider == ProviderType.SPOTIFY
+ assert provider == "spotify"
assert item_id == "123456789"
# test public play/open url
- test_uri = (
- "https://open.spotify.com/playlist/5lH9NjOeJvctAO92ZrKQNB?si=04a63c8234ac413e"
- )
+ test_uri = "https://open.spotify.com/playlist/5lH9NjOeJvctAO92ZrKQNB?si=04a63c8234ac413e"
media_type, provider, item_id = uri.parse_uri(test_uri)
assert media_type == media_items.MediaType.PLAYLIST
- assert provider == ProviderType.SPOTIFY
+ assert provider == "spotify"
assert item_id == "5lH9NjOeJvctAO92ZrKQNB"
# test filename with slashes as item_id
test_uri = "filesystem://track/Artist/Album/Track.flac"
media_type, provider, item_id = uri.parse_uri(test_uri)
assert media_type == media_items.MediaType.TRACK
- assert provider == ProviderType.FILESYSTEM_LOCAL
+ assert provider == "filesystem"
assert item_id == "Artist/Album/Track.flac"
# test invalid uri
with raises(MusicAssistantError):
import pathlib
-from music_assistant.helpers import tags
+from music_assistant.server.helpers import tags
RESOURCES_DIR = pathlib.Path(__file__).parent.resolve().joinpath("fixtures")