Merge branch '248-invite' into 'develop'
Resolve "Invite system" Closes #248 See merge request funkwhale/funkwhale!263
This commit is contained in:
commit
afe9ad2c91
27 changed files with 729 additions and 66 deletions
|
|
@ -461,3 +461,7 @@ MUSIC_DIRECTORY_PATH = env("MUSIC_DIRECTORY_PATH", default=None)
|
|||
MUSIC_DIRECTORY_SERVE_PATH = env(
|
||||
"MUSIC_DIRECTORY_SERVE_PATH", default=MUSIC_DIRECTORY_PATH
|
||||
)
|
||||
|
||||
USERS_INVITATION_EXPIRATION_DAYS = env.int(
|
||||
"USERS_INVITATION_EXPIRATION_DAYS", default=14
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,16 @@
|
|||
from rest_framework import serializers
|
||||
|
||||
|
||||
class Action(object):
|
||||
def __init__(self, name, allow_all=False, filters=None):
|
||||
self.name = name
|
||||
self.allow_all = allow_all
|
||||
self.filters = filters or {}
|
||||
|
||||
def __repr__(self):
|
||||
return "<Action {}>".format(self.name)
|
||||
|
||||
|
||||
class ActionSerializer(serializers.Serializer):
|
||||
"""
|
||||
A special serializer that can operate on a list of objects
|
||||
|
|
@ -11,19 +21,16 @@ class ActionSerializer(serializers.Serializer):
|
|||
objects = serializers.JSONField(required=True)
|
||||
filters = serializers.DictField(required=False)
|
||||
actions = None
|
||||
filterset_class = None
|
||||
# those are actions identifier where we don't want to allow the "all"
|
||||
# selector because it's to dangerous. Like object deletion.
|
||||
dangerous_actions = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.actions_by_name = {a.name: a for a in self.actions}
|
||||
self.queryset = kwargs.pop("queryset")
|
||||
if self.actions is None:
|
||||
raise ValueError(
|
||||
"You must declare a list of actions on " "the serializer class"
|
||||
)
|
||||
|
||||
for action in self.actions:
|
||||
for action in self.actions_by_name.keys():
|
||||
handler_name = "handle_{}".format(action)
|
||||
assert hasattr(self, handler_name), "{} miss a {} method".format(
|
||||
self.__class__.__name__, handler_name
|
||||
|
|
@ -31,13 +38,14 @@ class ActionSerializer(serializers.Serializer):
|
|||
super().__init__(self, *args, **kwargs)
|
||||
|
||||
def validate_action(self, value):
|
||||
if value not in self.actions:
|
||||
try:
|
||||
return self.actions_by_name[value]
|
||||
except KeyError:
|
||||
raise serializers.ValidationError(
|
||||
"{} is not a valid action. Pick one of {}.".format(
|
||||
value, ", ".join(self.actions)
|
||||
value, ", ".join(self.actions_by_name.keys())
|
||||
)
|
||||
)
|
||||
return value
|
||||
|
||||
def validate_objects(self, value):
|
||||
if value == "all":
|
||||
|
|
@ -51,15 +59,15 @@ class ActionSerializer(serializers.Serializer):
|
|||
)
|
||||
|
||||
def validate(self, data):
|
||||
dangerous = data["action"] in self.dangerous_actions
|
||||
if dangerous and self.initial_data["objects"] == "all":
|
||||
allow_all = data["action"].allow_all
|
||||
if not allow_all and self.initial_data["objects"] == "all":
|
||||
raise serializers.ValidationError(
|
||||
"This action is to dangerous to be applied to all objects"
|
||||
)
|
||||
if self.filterset_class and "filters" in data:
|
||||
qs_filterset = self.filterset_class(
|
||||
data["filters"], queryset=data["objects"]
|
||||
"You cannot apply this action on all objects"
|
||||
)
|
||||
final_filters = data.get("filters", {}) or {}
|
||||
final_filters.update(data["action"].filters)
|
||||
if self.filterset_class and final_filters:
|
||||
qs_filterset = self.filterset_class(final_filters, queryset=data["objects"])
|
||||
try:
|
||||
assert qs_filterset.form.is_valid()
|
||||
except (AssertionError, TypeError):
|
||||
|
|
@ -72,12 +80,12 @@ class ActionSerializer(serializers.Serializer):
|
|||
return data
|
||||
|
||||
def save(self):
|
||||
handler_name = "handle_{}".format(self.validated_data["action"])
|
||||
handler_name = "handle_{}".format(self.validated_data["action"].name)
|
||||
handler = getattr(self, handler_name)
|
||||
result = handler(self.validated_data["objects"])
|
||||
payload = {
|
||||
"updated": self.validated_data["count"],
|
||||
"action": self.validated_data["action"],
|
||||
"action": self.validated_data["action"].name,
|
||||
"result": result,
|
||||
}
|
||||
return payload
|
||||
|
|
|
|||
|
|
@ -769,7 +769,7 @@ class CollectionSerializer(serializers.Serializer):
|
|||
|
||||
|
||||
class LibraryTrackActionSerializer(common_serializers.ActionSerializer):
|
||||
actions = ["import"]
|
||||
actions = [common_serializers.Action("import", allow_all=True)]
|
||||
filterset_class = filters.LibraryTrackFilter
|
||||
|
||||
@transaction.atomic
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from django_filters import rest_framework as filters
|
||||
|
||||
from funkwhale_api.common import fields
|
||||
|
|
@ -37,3 +36,17 @@ class ManageUserFilterSet(filters.FilterSet):
|
|||
"permission_settings",
|
||||
"permission_federation",
|
||||
]
|
||||
|
||||
|
||||
class ManageInvitationFilterSet(filters.FilterSet):
|
||||
q = fields.SearchFilter(search_fields=["owner__username", "code", "owner__email"])
|
||||
is_open = filters.BooleanFilter(method="filter_is_open")
|
||||
|
||||
class Meta:
|
||||
model = users_models.Invitation
|
||||
fields = ["q", "is_open"]
|
||||
|
||||
def filter_is_open(self, queryset, field_name, value):
|
||||
if value is None:
|
||||
return queryset
|
||||
return queryset.open(value)
|
||||
|
|
|
|||
|
|
@ -61,8 +61,7 @@ class ManageTrackFileSerializer(serializers.ModelSerializer):
|
|||
|
||||
|
||||
class ManageTrackFileActionSerializer(common_serializers.ActionSerializer):
|
||||
actions = ["delete"]
|
||||
dangerous_actions = ["delete"]
|
||||
actions = [common_serializers.Action("delete", allow_all=False)]
|
||||
filterset_class = filters.ManageTrackFileFilterSet
|
||||
|
||||
@transaction.atomic
|
||||
|
|
@ -78,6 +77,23 @@ class PermissionsSerializer(serializers.Serializer):
|
|||
return {"permissions": o}
|
||||
|
||||
|
||||
class ManageUserSimpleSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = users_models.User
|
||||
fields = (
|
||||
"id",
|
||||
"username",
|
||||
"email",
|
||||
"name",
|
||||
"is_active",
|
||||
"is_staff",
|
||||
"is_superuser",
|
||||
"date_joined",
|
||||
"last_activity",
|
||||
"privacy_level",
|
||||
)
|
||||
|
||||
|
||||
class ManageUserSerializer(serializers.ModelSerializer):
|
||||
permissions = PermissionsSerializer(source="*")
|
||||
|
||||
|
|
@ -115,3 +131,32 @@ class ManageUserSerializer(serializers.ModelSerializer):
|
|||
update_fields=["permission_{}".format(p) for p in permissions.keys()]
|
||||
)
|
||||
return instance
|
||||
|
||||
|
||||
class ManageInvitationSerializer(serializers.ModelSerializer):
|
||||
users = ManageUserSimpleSerializer(many=True, required=False)
|
||||
owner = ManageUserSimpleSerializer(required=False)
|
||||
code = serializers.CharField(required=False, allow_null=True)
|
||||
|
||||
class Meta:
|
||||
model = users_models.Invitation
|
||||
fields = ("id", "owner", "code", "expiration_date", "creation_date", "users")
|
||||
read_only_fields = ["id", "expiration_date", "owner", "creation_date", "users"]
|
||||
|
||||
def validate_code(self, value):
|
||||
if not value:
|
||||
return value
|
||||
if users_models.Invitation.objects.filter(code__iexact=value).exists():
|
||||
raise serializers.ValidationError(
|
||||
"An invitation with this code already exists"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class ManageInvitationActionSerializer(common_serializers.ActionSerializer):
|
||||
actions = [common_serializers.Action("delete", allow_all=False)]
|
||||
filterset_class = filters.ManageInvitationFilterSet
|
||||
|
||||
@transaction.atomic
|
||||
def handle_delete(self, objects):
|
||||
return objects.delete()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ library_router = routers.SimpleRouter()
|
|||
library_router.register(r"track-files", views.ManageTrackFileViewSet, "track-files")
|
||||
users_router = routers.SimpleRouter()
|
||||
users_router.register(r"users", views.ManageUserViewSet, "users")
|
||||
users_router.register(r"invitations", views.ManageInvitationViewSet, "invitations")
|
||||
|
||||
urlpatterns = [
|
||||
url(r"^library/", include((library_router.urls, "instance"), namespace="library")),
|
||||
|
|
|
|||
|
|
@ -62,3 +62,37 @@ class ManageUserViewSet(
|
|||
context = super().get_serializer_context()
|
||||
context["default_permissions"] = preferences.get("users__default_permissions")
|
||||
return context
|
||||
|
||||
|
||||
class ManageInvitationViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
mixins.ListModelMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.UpdateModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
viewsets.GenericViewSet,
|
||||
):
|
||||
queryset = (
|
||||
users_models.Invitation.objects.all()
|
||||
.order_by("-id")
|
||||
.prefetch_related("users")
|
||||
.select_related("owner")
|
||||
)
|
||||
serializer_class = serializers.ManageInvitationSerializer
|
||||
filter_class = filters.ManageInvitationFilterSet
|
||||
permission_classes = (HasUserPermission,)
|
||||
required_permissions = ["settings"]
|
||||
ordering_fields = ["creation_date", "expiration_date"]
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(owner=self.request.user)
|
||||
|
||||
@list_route(methods=["post"])
|
||||
def action(self, request, *args, **kwargs):
|
||||
queryset = self.get_queryset()
|
||||
serializer = serializers.ManageInvitationActionSerializer(
|
||||
request.data, queryset=queryset
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
result = serializer.save()
|
||||
return response.Response(result, status=200)
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ from django.contrib.auth.admin import UserAdmin as AuthUserAdmin
|
|||
from django.contrib.auth.forms import UserChangeForm, UserCreationForm
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from .models import User
|
||||
from . import models
|
||||
|
||||
|
||||
class MyUserChangeForm(UserChangeForm):
|
||||
class Meta(UserChangeForm.Meta):
|
||||
model = User
|
||||
model = models.User
|
||||
|
||||
|
||||
class MyUserCreationForm(UserCreationForm):
|
||||
|
|
@ -22,18 +22,18 @@ class MyUserCreationForm(UserCreationForm):
|
|||
)
|
||||
|
||||
class Meta(UserCreationForm.Meta):
|
||||
model = User
|
||||
model = models.User
|
||||
|
||||
def clean_username(self):
|
||||
username = self.cleaned_data["username"]
|
||||
try:
|
||||
User.objects.get(username=username)
|
||||
except User.DoesNotExist:
|
||||
models.User.objects.get(username=username)
|
||||
except models.User.DoesNotExist:
|
||||
return username
|
||||
raise forms.ValidationError(self.error_messages["duplicate_username"])
|
||||
|
||||
|
||||
@admin.register(User)
|
||||
@admin.register(models.User)
|
||||
class UserAdmin(AuthUserAdmin):
|
||||
form = MyUserChangeForm
|
||||
add_form = MyUserCreationForm
|
||||
|
|
@ -74,3 +74,11 @@ class UserAdmin(AuthUserAdmin):
|
|||
(_("Important dates"), {"fields": ("last_login", "date_joined")}),
|
||||
(_("Useless fields"), {"fields": ("user_permissions", "groups")}),
|
||||
)
|
||||
|
||||
|
||||
@admin.register(models.Invitation)
|
||||
class InvitationAdmin(admin.ModelAdmin):
|
||||
list_select_related = True
|
||||
list_display = ["owner", "code", "creation_date", "expiration_date"]
|
||||
search_fields = ["owner__username", "code"]
|
||||
readonly_fields = ["expiration_date", "code"]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import factory
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.utils import timezone
|
||||
|
||||
from funkwhale_api.factories import ManyToManyFromList, registry
|
||||
|
||||
|
|
@ -28,6 +29,17 @@ class GroupFactory(factory.django.DjangoModelFactory):
|
|||
self.permissions.add(*perms)
|
||||
|
||||
|
||||
@registry.register
|
||||
class InvitationFactory(factory.django.DjangoModelFactory):
|
||||
owner = factory.LazyFunction(lambda: UserFactory())
|
||||
|
||||
class Meta:
|
||||
model = "users.Invitation"
|
||||
|
||||
class Params:
|
||||
expired = factory.Trait(expiration_date=factory.LazyFunction(timezone.now))
|
||||
|
||||
|
||||
@registry.register
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
username = factory.Sequence(lambda n: "user-{0}".format(n))
|
||||
|
|
@ -40,6 +52,9 @@ class UserFactory(factory.django.DjangoModelFactory):
|
|||
model = "users.User"
|
||||
django_get_or_create = ("username",)
|
||||
|
||||
class Params:
|
||||
invited = factory.Trait(invitation=factory.SubFactory(InvitationFactory))
|
||||
|
||||
@factory.post_generation
|
||||
def perms(self, create, extracted, **kwargs):
|
||||
if not create:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
# Generated by Django 2.0.6 on 2018-06-19 20:24
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
import django.utils.timezone
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('users', '0008_auto_20180617_1531'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='Invitation',
|
||||
fields=[
|
||||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('creation_date', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('expiration_date', models.DateTimeField()),
|
||||
('code', models.CharField(max_length=50, unique=True)),
|
||||
('owner', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='user',
|
||||
name='invitation',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='users', to='users.Invitation'),
|
||||
),
|
||||
]
|
||||
|
|
@ -4,6 +4,8 @@ from __future__ import absolute_import, unicode_literals
|
|||
import binascii
|
||||
import datetime
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import uuid
|
||||
|
||||
from django.conf import settings
|
||||
|
|
@ -79,6 +81,14 @@ class User(AbstractUser):
|
|||
|
||||
last_activity = models.DateTimeField(default=None, null=True, blank=True)
|
||||
|
||||
invitation = models.ForeignKey(
|
||||
"Invitation",
|
||||
related_name="users",
|
||||
null=True,
|
||||
blank=True,
|
||||
on_delete=models.SET_NULL,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.username
|
||||
|
||||
|
|
@ -138,3 +148,40 @@ class User(AbstractUser):
|
|||
if current is None or current < now - datetime.timedelta(seconds=delay):
|
||||
self.last_activity = now
|
||||
self.save(update_fields=["last_activity"])
|
||||
|
||||
|
||||
def generate_code(length=10):
|
||||
return "".join(
|
||||
random.SystemRandom().choice(string.ascii_uppercase) for _ in range(length)
|
||||
)
|
||||
|
||||
|
||||
class InvitationQuerySet(models.QuerySet):
|
||||
def open(self, include=True):
|
||||
now = timezone.now()
|
||||
qs = self.annotate(_users=models.Count("users"))
|
||||
query = models.Q(_users=0, expiration_date__gt=now)
|
||||
if include:
|
||||
return qs.filter(query)
|
||||
return qs.exclude(query)
|
||||
|
||||
|
||||
class Invitation(models.Model):
|
||||
creation_date = models.DateTimeField(default=timezone.now)
|
||||
expiration_date = models.DateTimeField()
|
||||
owner = models.ForeignKey(
|
||||
User, related_name="invitations", on_delete=models.CASCADE
|
||||
)
|
||||
code = models.CharField(max_length=50, unique=True)
|
||||
|
||||
objects = InvitationQuerySet.as_manager()
|
||||
|
||||
def save(self, **kwargs):
|
||||
if not self.code:
|
||||
self.code = generate_code()
|
||||
if not self.expiration_date:
|
||||
self.expiration_date = self.creation_date + datetime.timedelta(
|
||||
days=settings.USERS_INVITATION_EXPIRATION_DAYS
|
||||
)
|
||||
|
||||
return super().save(**kwargs)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from django.conf import settings
|
||||
from rest_auth.serializers import PasswordResetSerializer as PRS
|
||||
from rest_auth.registration.serializers import RegisterSerializer as RS
|
||||
from rest_framework import serializers
|
||||
|
||||
from funkwhale_api.activity import serializers as activity_serializers
|
||||
|
|
@ -7,6 +8,28 @@ from funkwhale_api.activity import serializers as activity_serializers
|
|||
from . import models
|
||||
|
||||
|
||||
class RegisterSerializer(RS):
|
||||
invitation = serializers.CharField(
|
||||
required=False, allow_null=True, allow_blank=True
|
||||
)
|
||||
|
||||
def validate_invitation(self, value):
|
||||
if not value:
|
||||
return
|
||||
|
||||
try:
|
||||
return models.Invitation.objects.open().get(code__iexact=value)
|
||||
except models.Invitation.DoesNotExist:
|
||||
raise serializers.ValidationError("Invalid invitation code")
|
||||
|
||||
def save(self, request):
|
||||
user = super().save(request)
|
||||
if self.validated_data.get("invitation"):
|
||||
user.invitation = self.validated_data.get("invitation")
|
||||
user.save(update_fields=["invitation"])
|
||||
return user
|
||||
|
||||
|
||||
class UserActivitySerializer(activity_serializers.ModelSerializer):
|
||||
type = serializers.SerializerMethodField()
|
||||
name = serializers.CharField(source="username")
|
||||
|
|
|
|||
|
|
@ -10,8 +10,11 @@ from . import models, serializers
|
|||
|
||||
|
||||
class RegisterView(BaseRegisterView):
|
||||
serializer_class = serializers.RegisterSerializer
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
if not self.is_open_for_signup(request):
|
||||
invitation_code = request.data.get("invitation")
|
||||
if not invitation_code and not self.is_open_for_signup(request):
|
||||
r = {"detail": "Registration has been disabled"}
|
||||
return Response(r, status=403)
|
||||
return super().create(request, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class TestActionFilterSet(django_filters.FilterSet):
|
|||
|
||||
|
||||
class TestSerializer(serializers.ActionSerializer):
|
||||
actions = ["test"]
|
||||
actions = [serializers.Action("test", allow_all=True)]
|
||||
filterset_class = TestActionFilterSet
|
||||
|
||||
def handle_test(self, objects):
|
||||
|
|
@ -19,8 +19,10 @@ class TestSerializer(serializers.ActionSerializer):
|
|||
|
||||
|
||||
class TestDangerousSerializer(serializers.ActionSerializer):
|
||||
actions = ["test", "test_dangerous"]
|
||||
dangerous_actions = ["test_dangerous"]
|
||||
actions = [
|
||||
serializers.Action("test", allow_all=True),
|
||||
serializers.Action("test_dangerous"),
|
||||
]
|
||||
|
||||
def handle_test(self, objects):
|
||||
pass
|
||||
|
|
@ -29,6 +31,14 @@ class TestDangerousSerializer(serializers.ActionSerializer):
|
|||
pass
|
||||
|
||||
|
||||
class TestDeleteOnlyInactiveSerializer(serializers.ActionSerializer):
|
||||
actions = [serializers.Action("test", allow_all=True, filters={"is_active": False})]
|
||||
filterset_class = TestActionFilterSet
|
||||
|
||||
def handle_test(self, objects):
|
||||
pass
|
||||
|
||||
|
||||
def test_action_serializer_validates_action():
|
||||
data = {"objects": "all", "action": "nope"}
|
||||
serializer = TestSerializer(data, queryset=models.User.objects.none())
|
||||
|
|
@ -52,7 +62,7 @@ def test_action_serializers_objects_clean_ids(factories):
|
|||
data = {"objects": [user1.pk], "action": "test"}
|
||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||
|
||||
assert serializer.is_valid() is True
|
||||
assert serializer.is_valid(raise_exception=True) is True
|
||||
assert list(serializer.validated_data["objects"]) == [user1]
|
||||
|
||||
|
||||
|
|
@ -63,7 +73,7 @@ def test_action_serializers_objects_clean_all(factories):
|
|||
data = {"objects": "all", "action": "test"}
|
||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||
|
||||
assert serializer.is_valid() is True
|
||||
assert serializer.is_valid(raise_exception=True) is True
|
||||
assert list(serializer.validated_data["objects"]) == [user1, user2]
|
||||
|
||||
|
||||
|
|
@ -75,7 +85,7 @@ def test_action_serializers_save(factories, mocker):
|
|||
data = {"objects": "all", "action": "test"}
|
||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||
|
||||
assert serializer.is_valid() is True
|
||||
assert serializer.is_valid(raise_exception=True) is True
|
||||
result = serializer.save()
|
||||
assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}}
|
||||
handler.assert_called_once()
|
||||
|
|
@ -88,7 +98,7 @@ def test_action_serializers_filterset(factories):
|
|||
data = {"objects": "all", "action": "test", "filters": {"is_active": True}}
|
||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||
|
||||
assert serializer.is_valid() is True
|
||||
assert serializer.is_valid(raise_exception=True) is True
|
||||
assert list(serializer.validated_data["objects"]) == [user2]
|
||||
|
||||
|
||||
|
|
@ -109,9 +119,14 @@ def test_dangerous_actions_refuses_all(factories):
|
|||
assert "non_field_errors" in serializer.errors
|
||||
|
||||
|
||||
def test_dangerous_actions_refuses_not_listed(factories):
|
||||
factories["users.User"]()
|
||||
data = {"objects": "all", "action": "test"}
|
||||
serializer = TestDangerousSerializer(data, queryset=models.User.objects.all())
|
||||
def test_action_serializers_can_require_filter(factories):
|
||||
user1 = factories["users.User"](is_active=False)
|
||||
factories["users.User"](is_active=True)
|
||||
|
||||
assert serializer.is_valid() is True
|
||||
data = {"objects": "all", "action": "test"}
|
||||
serializer = TestDeleteOnlyInactiveSerializer(
|
||||
data, queryset=models.User.objects.all()
|
||||
)
|
||||
|
||||
assert serializer.is_valid(raise_exception=True) is True
|
||||
assert list(serializer.validated_data["objects"]) == [user1]
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from funkwhale_api.manage import serializers, views
|
|||
[
|
||||
(views.ManageTrackFileViewSet, ["library"], "and"),
|
||||
(views.ManageUserViewSet, ["settings"], "and"),
|
||||
(views.ManageInvitationViewSet, ["settings"], "and"),
|
||||
],
|
||||
)
|
||||
def test_permissions(assert_user_permission, view, permissions, operator):
|
||||
|
|
@ -42,3 +43,23 @@ def test_user_view(factories, superuser_api_client, mocker):
|
|||
|
||||
assert response.data["count"] == len(users)
|
||||
assert response.data["results"] == expected
|
||||
|
||||
|
||||
def test_invitation_view(factories, superuser_api_client, mocker):
|
||||
invitations = factories["users.Invitation"].create_batch(size=5)
|
||||
qs = invitations[0].__class__.objects.order_by("-id")
|
||||
url = reverse("api:v1:manage:users:invitations-list")
|
||||
|
||||
response = superuser_api_client.get(url, {"sort": "-id"})
|
||||
expected = serializers.ManageInvitationSerializer(qs, many=True).data
|
||||
|
||||
assert response.data["count"] == len(invitations)
|
||||
assert response.data["results"] == expected
|
||||
|
||||
|
||||
def test_invitation_view_create(factories, superuser_api_client, mocker):
|
||||
url = reverse("api:v1:manage:users:invitations-list")
|
||||
response = superuser_api_client.post(url)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert superuser_api_client.user.invitations.latest("id") is not None
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import datetime
|
||||
import pytest
|
||||
|
||||
from funkwhale_api.users import models
|
||||
|
|
@ -95,3 +96,34 @@ def test_record_activity_does_nothing_if_already(factories, now, mocker):
|
|||
user.record_activity()
|
||||
|
||||
save.assert_not_called()
|
||||
|
||||
|
||||
def test_invitation_generates_random_code_on_save(factories):
|
||||
invitation = factories["users.Invitation"]()
|
||||
assert len(invitation.code) >= 6
|
||||
|
||||
|
||||
def test_invitation_expires_after_delay(factories, settings):
|
||||
delay = settings.USERS_INVITATION_EXPIRATION_DAYS
|
||||
invitation = factories["users.Invitation"]()
|
||||
assert invitation.expiration_date == (
|
||||
invitation.creation_date + datetime.timedelta(days=delay)
|
||||
)
|
||||
|
||||
|
||||
def test_can_filter_open_invitations(factories):
|
||||
okay = factories["users.Invitation"]()
|
||||
factories["users.Invitation"](expired=True)
|
||||
factories["users.User"](invited=True)
|
||||
|
||||
assert models.Invitation.objects.count() == 3
|
||||
assert list(models.Invitation.objects.open()) == [okay]
|
||||
|
||||
|
||||
def test_can_filter_closed_invitations(factories):
|
||||
factories["users.Invitation"]()
|
||||
expired = factories["users.Invitation"](expired=True)
|
||||
used = factories["users.User"](invited=True).invitation
|
||||
|
||||
assert models.Invitation.objects.count() == 3
|
||||
assert list(models.Invitation.objects.open(False)) == [expired, used]
|
||||
|
|
|
|||
|
|
@ -50,6 +50,39 @@ def test_can_disable_registration_view(preferences, api_client, db):
|
|||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_can_signup_with_invitation(preferences, factories, api_client):
|
||||
url = reverse("rest_register")
|
||||
invitation = factories["users.Invitation"](code="Hello")
|
||||
data = {
|
||||
"username": "test1",
|
||||
"email": "test1@test.com",
|
||||
"password1": "testtest",
|
||||
"password2": "testtest",
|
||||
"invitation": "hello",
|
||||
}
|
||||
preferences["users__registration_enabled"] = False
|
||||
response = api_client.post(url, data)
|
||||
assert response.status_code == 201
|
||||
u = User.objects.get(email="test1@test.com")
|
||||
assert u.username == "test1"
|
||||
assert u.invitation == invitation
|
||||
|
||||
|
||||
def test_can_signup_with_invitation_invalid(preferences, factories, api_client):
|
||||
url = reverse("rest_register")
|
||||
factories["users.Invitation"](code="hello")
|
||||
data = {
|
||||
"username": "test1",
|
||||
"email": "test1@test.com",
|
||||
"password1": "testtest",
|
||||
"password2": "testtest",
|
||||
"invitation": "nope",
|
||||
}
|
||||
response = api_client.post(url, data)
|
||||
assert response.status_code == 400
|
||||
assert "invitation" in response.data
|
||||
|
||||
|
||||
def test_can_fetch_data_from_api(api_client, factories):
|
||||
url = reverse("api:v1:users:users-me")
|
||||
response = api_client.get(url)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue