Commit 77cce1b1 authored by Florent Chehab's avatar Florent Chehab
Browse files

feat(backend): refactor/cleaned/ infer get_serializer

* Cleaned all init files
* Infer the serializer from the model instead of having it in the models
* Updated the doc accordingly
* Fixed typos

Fixes #93
Fixes #85
parent cb86531b
...@@ -17,10 +17,6 @@ class ForTestingVersioning(VersionedEssentialModule): ...@@ -17,10 +17,6 @@ class ForTestingVersioning(VersionedEssentialModule):
bbb = models.CharField(max_length=100) bbb = models.CharField(max_length=100)
@classmethod
def get_serializer(cls):
return ForTestingVersioningSerializer
class ForTestingVersioningSerializer(VersionedEssentialModuleSerializer): class ForTestingVersioningSerializer(VersionedEssentialModuleSerializer):
""" """
......
...@@ -6,7 +6,7 @@ from backend_app.models.abstract.base import ( ...@@ -6,7 +6,7 @@ from backend_app.models.abstract.base import (
BaseModelSerializer, BaseModelSerializer,
BaseModelViewSet, BaseModelViewSet,
) )
from backend_app.models import SEMESTER_OPTIONS from backend_app.models.shared import SEMESTER_OPTIONS
class Offer(BaseModel): class Offer(BaseModel):
......
...@@ -8,7 +8,7 @@ from backend_app.models.abstract.essentialModule import ( ...@@ -8,7 +8,7 @@ from backend_app.models.abstract.essentialModule import (
EssentialModuleSerializer, EssentialModuleSerializer,
EssentialModuleViewSet, EssentialModuleViewSet,
) )
from backend_app.models import SEMESTER_OPTIONS from backend_app.models.shared import SEMESTER_OPTIONS
class PreviousDeparture(EssentialModule): class PreviousDeparture(EssentialModule):
......
# This file is not a model. It is file to hold shared things across models.
SEMESTER_OPTIONS = (("a", "autumn"), ("p", "spring"))
...@@ -6,7 +6,7 @@ from backend_app.models.abstract.essentialModule import ( ...@@ -6,7 +6,7 @@ from backend_app.models.abstract.essentialModule import (
) )
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from rest_framework.validators import ValidationError as RFValidationError from rest_framework.validators import ValidationError as RFValidationError
from backend_app.validators.tag.url import validate_extension from backend_app.validators.tags import validate_extension
from django.conf import settings from django.conf import settings
......
...@@ -7,10 +7,6 @@ class UniversityDri(Module): ...@@ -7,10 +7,6 @@ class UniversityDri(Module):
universities = models.ManyToManyField(University, related_name="university_dri") universities = models.ManyToManyField(University, related_name="university_dri")
@classmethod
def get_serializer(cls):
return UniversityDriSerializer
class UniversityDriSerializer(ModuleSerializer): class UniversityDriSerializer(ModuleSerializer):
class Meta: class Meta:
......
...@@ -25,10 +25,6 @@ class UniversityInfo(Module): ...@@ -25,10 +25,6 @@ class UniversityInfo(Module):
costs_currency = models.ForeignKey(Currency, on_delete=models.PROTECT, null=True) costs_currency = models.ForeignKey(Currency, on_delete=models.PROTECT, null=True)
@classmethod
def get_serializer(cls):
return UniversityInfoSerializer
class UniversityInfoSerializer(ModuleSerializer): class UniversityInfoSerializer(ModuleSerializer):
class Meta: class Meta:
......
...@@ -13,10 +13,6 @@ class UniversityScholarship(Scholarship): ...@@ -13,10 +13,6 @@ class UniversityScholarship(Scholarship):
University, related_name="university_scholarships" University, related_name="university_scholarships"
) )
@classmethod
def get_serializer(cls):
return UniversityScholarshipSerializer
class UniversityScholarshipSerializer(ScholarshipSerializer): class UniversityScholarshipSerializer(ScholarshipSerializer):
class Meta: class Meta:
......
...@@ -25,10 +25,6 @@ class UniversitySemestersDates(Module): ...@@ -25,10 +25,6 @@ class UniversitySemestersDates(Module):
autumn_begin = models.DateField(null=True, blank=True) autumn_begin = models.DateField(null=True, blank=True)
autumn_end = models.DateField(null=True, blank=True) autumn_end = models.DateField(null=True, blank=True)
@classmethod
def get_serializer(cls):
return UniversitySemestersDatesSerializer
class UniversitySemestersDatesSerializer(ModuleSerializer): class UniversitySemestersDatesSerializer(ModuleSerializer):
def validate(self, attrs): def validate(self, attrs):
......
...@@ -13,10 +13,6 @@ class UniversityTaggedItem(TaggedItem): ...@@ -13,10 +13,6 @@ class UniversityTaggedItem(TaggedItem):
University, on_delete=models.PROTECT, related_name="university_tagged_items" University, on_delete=models.PROTECT, related_name="university_tagged_items"
) )
@classmethod
def get_serializer(cls):
return UniversityTaggedItemSerializer
class Meta: class Meta:
unique_together = ("university", "tag", "importance_level") unique_together = ("university", "tag", "importance_level")
......
...@@ -14,6 +14,39 @@ class VersionSerializer(BaseModelSerializer): ...@@ -14,6 +14,39 @@ class VersionSerializer(BaseModelSerializer):
data = serializers.SerializerMethodField() data = serializers.SerializerMethodField()
serializers_mapping = None
@classmethod
def get_serializers_mapping(cls) -> dict:
"""
Function that returns a mapping from model name to the serializer
class that should be used to return the versioned data.
"""
if cls.serializers_mapping is None:
# Prevent cyclic imports
from backend_app.config.viewsets import get_viewsets_info
# A little bit of optimization to easily find the serializer class associated with a model
cls.serializers_mapping = dict()
viewsets = map(
lambda v: v.Viewset,
get_viewsets_info(requires_testing="smart", is_api_view=False),
)
for viewset in viewsets:
serializer = viewset().get_serializer_class()
model = serializer.Meta.model
cls.serializers_mapping[model.__name__] = serializer
# Override if models has a get_serializer method
for viewset in viewsets:
model = viewset().get_serializer_class().Meta.model
try:
cls.serializers_mapping[model.__name__] = model.get_serializer()
except AttributeError:
pass
return cls.serializers_mapping
def get_data(self, obj): def get_data(self, obj):
""" """
Serilizer for the data field Serilizer for the data field
...@@ -26,7 +59,8 @@ class VersionSerializer(BaseModelSerializer): ...@@ -26,7 +59,8 @@ class VersionSerializer(BaseModelSerializer):
djangoSerializers.deserialize(obj.format, data, ignorenonexistent=True) djangoSerializers.deserialize(obj.format, data, ignorenonexistent=True)
)[0] )[0]
# Version is valid, # Version is valid,
obj_serializer = tmp.object.get_serializer() print(self.get_serializers_mapping())
obj_serializer = self.get_serializers_mapping()[type(tmp.object).__name__]
new_context = dict(self.context) new_context = dict(self.context)
new_context["view"].action = "list" new_context["view"].action = "list"
return obj_serializer(tmp.object, context=new_context).data return obj_serializer(tmp.object, context=new_context).data
......
...@@ -4,7 +4,7 @@ from reversion.models import Version ...@@ -4,7 +4,7 @@ from reversion.models import Version
def squash_revision_by_user(sender, obj, **kwargs): def squash_revision_by_user(sender, obj, **kwargs):
""" """
It should also work with moderation as obj will be a versionned object It should also work with moderation as obj will be a versioned object
""" """
versions = ( versions = (
Version.objects.get_for_object(obj) Version.objects.get_for_object(obj)
......
from django.test import TestCase from django.test import TestCase
from backend_app.load_data import load_all from backend_app.load_data.load_all import load_all
class ModerationTestCase(TestCase): class ModerationTestCase(TestCase):
......
...@@ -2,7 +2,7 @@ from django.test import TestCase ...@@ -2,7 +2,7 @@ from django.test import TestCase
import pytest import pytest
from rest_framework.validators import ValidationError as RFValidationError from rest_framework.validators import ValidationError as RFValidationError
from django.core.validators import ValidationError as DJValidationError from django.core.validators import ValidationError as DJValidationError
from backend_app.validators.tag.url import validate_extension, validate_url from backend_app.validators.tags import validate_extension, validate_url
class ValidationUrlTestCase(TestCase): class ValidationUrlTestCase(TestCase):
......
...@@ -4,6 +4,7 @@ from django.conf import settings ...@@ -4,6 +4,7 @@ from django.conf import settings
from reversion.models import Version from reversion.models import Version
from backend_app.signals.squash_revisions import new_revision_saved from backend_app.signals.squash_revisions import new_revision_saved
from django.test import override_settings from django.test import override_settings
from django.contrib.contenttypes.models import ContentType
class VersioningTestCase(WithUserTestCase): class VersioningTestCase(WithUserTestCase):
...@@ -69,3 +70,7 @@ class VersioningTestCase(WithUserTestCase): ...@@ -69,3 +70,7 @@ class VersioningTestCase(WithUserTestCase):
self.assertEqual(len(versions), 2) self.assertEqual(len(versions), 2)
self.assertEqual(len(versions), instance.nb_versions) self.assertEqual(len(versions), instance.nb_versions)
self.assertTrue(self.signal_was_called) self.assertTrue(self.signal_was_called)
# Final test: we query the version viewset itself
ct = ContentType.objects.get_for_model(instance).id
self.authenticated_client.get("/api/versions/{}/{}/".format(ct, instance.pk))
from django.conf.urls import include, url from django.conf.urls import include, url
from backend_app.config.viewsets import get_viewsets_info from backend_app.config.viewsets import get_viewsets_info
from backend_app.checks import check_viewsets
from rest_framework import routers from rest_framework import routers
from rest_framework.documentation import include_docs_urls from rest_framework.documentation import include_docs_urls
STANDARD_VIEWSETS = get_viewsets_info(requires_testing="smart", is_api_view=False) STANDARD_VIEWSETS = get_viewsets_info(requires_testing="smart", is_api_view=False)
API_VIEW_VIEWSETS = get_viewsets_info(requires_testing="smart", is_api_view=True) API_VIEW_VIEWSETS = get_viewsets_info(requires_testing="smart", is_api_view=True)
check_viewsets(map(lambda v: v.Viewset, STANDARD_VIEWSETS))
####### #######
# Building the API routing # Building the API routing
......
from .url import validate_url
from .text import validate_text
from .tagged_item_validation import (
validate_content_against_config,
tagged_item_validation,
)
__all__ = [
"validate_url",
"validate_text",
"validate_content_against_config",
"tagged_item_validation",
]
from rest_framework.validators import ValidationError
def missing_field(field):
return ValidationError("{} : this field is required".format(field))
def check_required(config, content):
for field in config:
if config[field]["required"]:
try:
val = content[field]
if type(val) is str:
if len(val) == 0:
raise missing_field(field)
if val is None:
raise missing_field(field)
except KeyError:
raise missing_field(field)
from .photos import PHOTOS_TAG_CONFIG
from .useful_links import USEFULL_LINKS_CONFIG
__all__ = ["PHOTOS_TAG_CONFIG", "USEFULL_LINKS_CONFIG"]
from rest_framework.validators import ValidationError
def validate_text(config, string):
string = str(string) # might cause error with number ?
try:
validators = config["validators"]
for validator in validators:
validator_content = validators[validator]
if validator == "max_length":
if len(string) > validator_content:
raise ValidationError("Your text is too long !")
else:
raise Exception("Dev, you have implement something here...")
except KeyError:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment