Create a testing environment in production for ListenBrainz recommendation engine (troi-recommendation-playground)
This commit is contained in:
parent
cc0f8f395c
commit
f821dcbbc2
19 changed files with 1124 additions and 11 deletions
135
api/funkwhale_api/radios/lb_recommendations.py
Normal file
135
api/funkwhale_api/radios/lb_recommendations.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
import troi
|
||||
import troi.core
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models import Q
|
||||
from requests.exceptions import ConnectTimeout
|
||||
|
||||
from funkwhale_api.music import models as music_models
|
||||
from funkwhale_api.typesense import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
patches = troi.utils.discover_patches()
|
||||
|
||||
SUPPORTED_PATCHES = patches.keys()
|
||||
|
||||
|
||||
def run(config, **kwargs):
|
||||
"""Validate the received config and run the queryset generation"""
|
||||
candidates = kwargs.pop("candidates", music_models.Track.objects.all())
|
||||
validate(config)
|
||||
return TroiPatch().get_queryset(config, candidates)
|
||||
|
||||
|
||||
def validate(config):
|
||||
patch = config.get("patch")
|
||||
if patch not in SUPPORTED_PATCHES:
|
||||
raise ValidationError(
|
||||
'Invalid patch "{}". Supported patches: {}'.format(
|
||||
config["patch"], SUPPORTED_PATCHES
|
||||
)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def build_radio_queryset(patch, config, radio_qs):
|
||||
"""Take a troi patch and its arg, match the missing mbid and then build a radio queryset"""
|
||||
|
||||
logger.info("Config used for troi radio generation is " + str(config))
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
recommendations = troi.core.generate_playlist(patch, config)
|
||||
except ConnectTimeout:
|
||||
raise ValueError(
|
||||
"Timed out while connecting to ListenBrainz. No candidates could be retrieved for the radio."
|
||||
)
|
||||
end_time_rec = time.time()
|
||||
logger.info("Troi fetch took :" + str(end_time_rec - start_time))
|
||||
|
||||
if not recommendations:
|
||||
raise ValueError("No candidates found by troi")
|
||||
|
||||
recommended_recording_mbids = [
|
||||
recommended_recording.mbid
|
||||
for recommended_recording in recommendations.playlists[0].recordings
|
||||
]
|
||||
|
||||
logger.info("Searching for MusicBrainz ID in Funkwhale database")
|
||||
|
||||
qs_mbid = music_models.Track.objects.all().filter(
|
||||
mbid__in=recommended_recording_mbids
|
||||
)
|
||||
mbids_found = [str(i.mbid) for i in qs_mbid]
|
||||
|
||||
recommended_recording_mbids_not_found = [
|
||||
mbid for mbid in recommended_recording_mbids if mbid not in mbids_found
|
||||
]
|
||||
cached_mbid_match = cache.get_many(recommended_recording_mbids_not_found)
|
||||
|
||||
if qs_mbid and cached_mbid_match:
|
||||
logger.info("MusicBrainz IDs found in Funkwhale database and redis")
|
||||
mbids_found = [str(i.mbid) for i in qs_mbid]
|
||||
mbids_found.extend([i for i in cached_mbid_match.keys()])
|
||||
elif qs_mbid and not cached_mbid_match:
|
||||
logger.info("MusicBrainz IDs found in Funkwhale database")
|
||||
mbids_found = mbids_found
|
||||
elif not qs_mbid and cached_mbid_match:
|
||||
logger.info("MusicBrainz IDs found in redis cache")
|
||||
mbids_found = [i for i in cached_mbid_match.keys()]
|
||||
else:
|
||||
logger.info(
|
||||
"Couldn't find any matches in Funkwhale database. Trying to match all"
|
||||
)
|
||||
mbids_found = []
|
||||
|
||||
recommended_recordings_not_found = [
|
||||
i for i in recommendations.playlists[0].recordings if i.mbid not in mbids_found
|
||||
]
|
||||
|
||||
logger.info("Matching missing MusicBrainz ID to Funkwhale track")
|
||||
|
||||
start_time_resolv = time.time()
|
||||
utils.resolve_recordings_to_fw_track(recommended_recordings_not_found)
|
||||
end_time_resolv = time.time()
|
||||
|
||||
logger.info(
|
||||
"Resolving "
|
||||
+ str(len(recommended_recordings_not_found))
|
||||
+ " tracks in "
|
||||
+ str(end_time_resolv - start_time_resolv)
|
||||
)
|
||||
|
||||
cached_mbid_match = cache.get_many(recommended_recording_mbids_not_found)
|
||||
|
||||
if not qs_mbid and not cached_mbid_match:
|
||||
raise ValueError("No candidates found for troi radio")
|
||||
|
||||
logger.info("Radio generation with troi took " + str(end_time_resolv - start_time))
|
||||
logger.info("qs_mbid is " + str(mbids_found))
|
||||
|
||||
if qs_mbid and cached_mbid_match:
|
||||
return radio_qs.filter(
|
||||
Q(mbid__in=mbids_found) | Q(pk__in=cached_mbid_match.values())
|
||||
)
|
||||
if qs_mbid and not cached_mbid_match:
|
||||
return radio_qs.filter(mbid__in=mbids_found)
|
||||
|
||||
if not qs_mbid and cached_mbid_match:
|
||||
return radio_qs.filter(pk__in=cached_mbid_match.values())
|
||||
|
||||
|
||||
class TroiPatch:
|
||||
code = "troi-patch"
|
||||
label = "Troi Patch"
|
||||
|
||||
def get_queryset(self, config, qs):
|
||||
patch_string = config.pop("patch")
|
||||
patch = patches[patch_string]
|
||||
return build_radio_queryset(patch(), config, qs)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
|
@ -12,6 +13,7 @@ from funkwhale_api.federation import fields as federation_fields
|
|||
from funkwhale_api.federation import models as federation_models
|
||||
from funkwhale_api.moderation import filters as moderation_filters
|
||||
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
||||
from funkwhale_api.radios import lb_recommendations
|
||||
from funkwhale_api.tags.models import Tag
|
||||
|
||||
from . import filters, models
|
||||
|
|
@ -189,9 +191,7 @@ class CustomMultiple(SessionRadio):
|
|||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
try:
|
||||
data["config"] is not None
|
||||
except KeyError:
|
||||
if data.get("config") is None:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
|
|
@ -405,3 +405,58 @@ class RecentlyAdded(SessionRadio):
|
|||
Q(artist__content_category="music"),
|
||||
Q(creation_date__gt=date),
|
||||
)
|
||||
|
||||
|
||||
# Use this to experiment on the custom multiple radio with troi
|
||||
@registry.register(name="troi")
|
||||
class Troi(SessionRadio):
|
||||
"""
|
||||
Receive a vuejs generated config and use it to launch a troi radio session.
|
||||
The config data should follow :
|
||||
{"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...}
|
||||
Validation of the config (args) is done by troi during track fetch.
|
||||
Funkwhale only checks if the patch is implemented
|
||||
"""
|
||||
|
||||
config = serializers.JSONField(required=True)
|
||||
|
||||
def append_lb_config(self, data):
|
||||
if self.session.user.settings is None:
|
||||
logger.warning(
|
||||
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||
)
|
||||
return data
|
||||
elif self.session.user.settings.get("lb_user_name") is None:
|
||||
logger.warning(
|
||||
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||
)
|
||||
else:
|
||||
data["user_name"] = self.session.user.settings["lb_user_name"]
|
||||
|
||||
if self.session.user.settings.get("lb_user_token") is None:
|
||||
logger.warning(
|
||||
"No lb_user_token set in user settings. Some troi patch will fail"
|
||||
)
|
||||
else:
|
||||
data["user_token"] = self.session.user.settings["lb_user_token"]
|
||||
|
||||
return data
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["config"] = self.session.config
|
||||
return kwargs
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
if data.get("config") is None:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
return data
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
config = self.append_lb_config(json.loads(kwargs["config"]))
|
||||
|
||||
return lb_recommendations.run(config, candidates=qs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue