Commit 77cce1b1 authored by Florent Chehab's avatar Florent Chehab

feat(backend): refactor/cleaned/ infer get_serializer

* Cleaned all init files
* Infer the serializer from the model instead of having it in the models
* Updated the doc accordingly
* Fixed typos

Fixes #93
Fixes #85
parent cb86531b
Pipeline #37131 passed with stages
in 5 minutes and 39 seconds
......@@ -2,20 +2,12 @@ from django.contrib import admin
from reversion_compare.admin import CompareVersionAdmin
from backend_app.config.models import get_models
from backend_app.checks import check_classic_models, check_versionned_models
# We need to register testing models, otherwise we won't be able to test properly,
# Since no migrations would privide those models.
# So don't put requires_testing=True
VERSIONED_MODELS = get_models(versionned=True) # , requires_testing=False)
CLASSIC_MODELS = get_models(versionned=False) # , requires_testing=False)
#######
# Perform some dynamic checks
#######
check_classic_models(CLASSIC_MODELS)
check_versionned_models(VERSIONED_MODELS)
VERSIONED_MODELS = get_models(versioned=True) # , requires_testing=False)
CLASSIC_MODELS = get_models(versioned=False) # , requires_testing=False)
#######
# Register the models
......
def check_classic_models(classic_models):
from collections import Counter
def check_viewsets(viewsets):
"""
Check that all "classic" models don't have a `get_serializer` method:
they don't need it.
See doc for more information:
Check that if 2 serializers are registered for the same model. Then that model
has a get_serializer method to point to the serializer to use to deserialize it.
There should be only one of serializer being used per model. Otherwise extra
configuration is required.
http://localhost:5000/#/Application/Backend/models_serializers_viewsets
"""
for Model in classic_models:
try:
# Check that it doesn't have the get_serializer method
Model.get_serializer()
raise Exception(
"A 'CLASSIC MODEL' SHOULDN'T have the "
"get_serializer method, {}".format(Model)
# Prevent cyclic imports
from backend_app.models.abstract.versionedEssentialModule import (
VersionedEssentialModule,
)
except AttributeError:
pass
serializers = list()
models = []
for viewset in viewsets:
serializer = viewset().get_serializer_class()
model = serializer.Meta.model
def check_versionned_models(versionned_models):
"""
Check that all "versionned" models have a `get_serializer` method.
See doc for more information:
http://localhost:5000/#/Application/Backend/models_serializers_viewsets
"""
for Model in versionned_models:
if issubclass(model, VersionedEssentialModule):
if serializer not in serializers:
serializers.append(serializer)
models.append(model)
models = dict(Counter(models))
for model, n in models.items():
if n > 1:
try:
# Check that the model has a get_serializer method
model.get_serializer()
# Check that it has a get_serializer method
if Model.get_serializer().Meta.model != Model:
raise Exception("Get_serializer configuration incorrect in", str(Model))
except AttributeError:
raise Exception(
"The model {} has multiple serializers pointing to it. "
"In such case, you must define the get_serializer method inside the model. "
"Have a look at the documentation.".format(model)
)
......@@ -10,7 +10,7 @@ from .utils import load_viewsets_config
def get_models(
versionned: Optional[bool] = None,
versioned: Optional[bool] = None,
requires_testing: Union[None, bool, "smart"] = None,
) -> List[object]:
"""
......@@ -51,10 +51,10 @@ def get_models(
continue
Model = Viewset.serializer_class.Meta.model
if versionned is not None:
if versionned and not issubclass(Model, VersionedEssentialModule):
if versioned is not None:
if versioned and not issubclass(Model, VersionedEssentialModule):
continue
if not versionned and issubclass(Model, VersionedEssentialModule):
if not versioned and issubclass(Model, VersionedEssentialModule):
continue
out.append(Model)
......
# We firest need to define ASSETS_PATH to prevent cyclic imports
from os import path
ASSETS_PATH = path.join(path.realpath(__file__), "../assets/") # noqa: E402
from .load_all import load_all # noqa: E402
__all__ = ["load_all", "ASSETS_PATH"]
import reversion
from .loading_scripts import (
LoadAdminUser,
LoadCountries,
LoadCurrencies,
LoadGroups,
LoadTags,
LoadUniversities,
LoadUniversityEx,
)
from backend_app.load_data.loading_scripts.loadAdminUser import LoadAdminUser
from backend_app.load_data.loading_scripts.loadCountries import LoadCountries
from backend_app.load_data.loading_scripts.loadCurrencies import LoadCurrencies
from backend_app.load_data.loading_scripts.loadGroups import LoadGroups
from backend_app.load_data.loading_scripts.loadTags import LoadTags
from backend_app.load_data.loading_scripts.loadUniversities import LoadUniversities
from backend_app.load_data.loading_scripts.loadUniversityEx import LoadUniversityEx
def load_all():
"""Function to load all the initial data in the app
"""
Function to load all the initial data in the app
"""
with reversion.create_revision():
......
from .loadGroups import LoadGroups
from .loadAdminUser import LoadAdminUser
from .loadCountries import LoadCountries
from .loadUniversities import LoadUniversities
from .loadTags import LoadTags
from .loadCurrencies import LoadCurrencies
from .loadUniversityEx import LoadUniversityEx
__all__ = [
"LoadGroups",
"LoadAdminUser",
"LoadCountries",
"LoadUniversities",
"LoadTags",
"LoadCurrencies",
"LoadUniversityEx",
]
from os.path import abspath, join
import pandas as pd
from backend_app.load_data import ASSETS_PATH
from backend_app.load_data.shared import ASSETS_PATH
from backend_app.models.country import Country
from base_app.models import User
......
......@@ -2,7 +2,7 @@ import csv
from decimal import Decimal
from os.path import abspath, join
from backend_app.load_data import ASSETS_PATH
from backend_app.load_data.shared import ASSETS_PATH
from backend_app.models.currency import Currency
from base_app.models import User
......
from os.path import abspath, join
import pandas as pd
from backend_app.load_data import ASSETS_PATH
from backend_app.load_data.shared import ASSETS_PATH
from backend_app.models.campus import Campus
from backend_app.models.city import City
from backend_app.models.country import Country
......
from os import path
ASSETS_PATH = path.join(path.realpath(__file__), "../assets/") # noqa: E402
SEMESTER_OPTIONS = (("a", "autumn"), ("p", "spring"))
__all__ = ["SEMESTER_OPTIONS"]
......@@ -6,8 +6,8 @@ from backend_app.models.abstract.versionedEssentialModule import (
VersionedEssentialModuleSerializer,
VersionedEssentialModuleViewSet,
)
from backend_app.validators.tag import validate_content_against_config
from backend_app.validators.tag.tags_config import USEFULL_LINKS_CONFIG
from backend_app.validators.tags import validate_content_against_config
from backend_app.validators.tags_config.useful_links import USEFULL_LINKS_CONFIG
IMPORTANCE_LEVEL = (("-", "normal"), ("+", "important"), ("++", "IMPORTANT"))
......@@ -19,7 +19,7 @@ class Module(VersionedEssentialModule):
Those field will be inherited.
All Basic modules are also "versionned" modules
All Basic modules are also "versioned" modules
"""
title = models.CharField(default="", blank=True, max_length=150)
......
......@@ -3,7 +3,7 @@ from django.db import models
from backend_app.fields import JSONField
from backend_app.models.abstract.module import Module, ModuleSerializer, ModuleViewSet
from backend_app.models.tag import Tag
from backend_app.validators.tag import tagged_item_validation
from backend_app.validators.tags import tagged_item_validation
class TaggedItem(Module):
......
......@@ -13,21 +13,12 @@ from reversion.models import Version
class VersionedEssentialModule(EssentialModule):
"""
Custom EssentialModule that will be versionned in the app
Custom EssentialModule that will be versioned in the app
"""
# We store the current number of versions for better performance
nb_versions = models.PositiveIntegerField(default=0)
@classmethod
def get_serializer(cls):
"""
This function is required for handling
versioning easily.
You have to put the correct value in each subclass
"""
raise Exception("Get_serializer must be redefined in subclass")
def delete(self, using=None, keep_parents=False):
"""
Override the default delete behavior to make sure
......@@ -44,7 +35,7 @@ class VersionedEssentialModule(EssentialModule):
class VersionedEssentialModuleSerializer(EssentialModuleSerializer):
"""
Serializer for versionned models
Serializer for versioned models
"""
# Add a nb_versions field
......@@ -64,7 +55,7 @@ class VersionedEssentialModuleSerializer(EssentialModuleSerializer):
class VersionedEssentialModuleViewSet(EssentialModuleViewSet):
"""
Viewset for the versionned models
Viewset for the versioned models
"""
serializer_class = VersionedEssentialModuleSerializer
......@@ -32,10 +32,6 @@ class Campus(Module):
def location(self):
return {"lat": self.lat, "lon": self.lon}
@classmethod
def get_serializer(cls):
return CampusSerializer
class Meta:
unique_together = ("is_main_campus", "university")
......
......@@ -13,10 +13,6 @@ class CampusTaggedItem(TaggedItem):
Campus, on_delete=models.PROTECT, related_name="campus_tagged_items"
)
@classmethod
def get_serializer(cls):
return CampusTaggedItemSerializer
class Meta:
unique_together = ("campus", "tag", "importance_level")
......
......@@ -13,10 +13,6 @@ class CityTaggedItem(TaggedItem):
City, on_delete=models.PROTECT, related_name="city_tagged_items"
)
@classmethod
def get_serializer(cls):
return CityTaggedItemSerializer
class Meta:
unique_together = ("city", "tag", "importance_level")
......
......@@ -7,10 +7,6 @@ class CountryDri(Module):
countries = models.ManyToManyField(Country, related_name="country_dri")
@classmethod
def get_serializer(cls):
return CountryDriSerializer
class CountryDriSerializer(ModuleSerializer):
class Meta:
......
......@@ -11,10 +11,6 @@ class CountryScholarship(Scholarship):
countries = models.ManyToManyField(Country, related_name="country_scholarships")
@classmethod
def get_serializer(cls):
return CountryScholarshipSerializer
class CountryScholarshipSerializer(ScholarshipSerializer):
class Meta:
......
......@@ -13,10 +13,6 @@ class CountryTaggedItem(TaggedItem):
Country, on_delete=models.PROTECT, related_name="country_tagged_items"
)
@classmethod
def get_serializer(cls):
return CountryTaggedItemSerializer
class Meta:
unique_together = ("country", "tag", "importance_level")
......
......@@ -17,10 +17,6 @@ class ForTestingVersioning(VersionedEssentialModule):
bbb = models.CharField(max_length=100)
@classmethod
def get_serializer(cls):
return ForTestingVersioningSerializer
class ForTestingVersioningSerializer(VersionedEssentialModuleSerializer):
"""
......
......@@ -6,7 +6,7 @@ from backend_app.models.abstract.base import (
BaseModelSerializer,
BaseModelViewSet,
)
from backend_app.models import SEMESTER_OPTIONS
from backend_app.models.shared import SEMESTER_OPTIONS
class Offer(BaseModel):
......
......@@ -8,7 +8,7 @@ from backend_app.models.abstract.essentialModule import (
EssentialModuleSerializer,
EssentialModuleViewSet,
)
from backend_app.models import SEMESTER_OPTIONS
from backend_app.models.shared import SEMESTER_OPTIONS
class PreviousDeparture(EssentialModule):
......
# This file is not a model. It is file to hold shared things across models.
SEMESTER_OPTIONS = (("a", "autumn"), ("p", "spring"))
......@@ -6,7 +6,7 @@ from backend_app.models.abstract.essentialModule import (
)
from django.core.exceptions import ValidationError
from rest_framework.validators import ValidationError as RFValidationError
from backend_app.validators.tag.url import validate_extension
from backend_app.validators.tags import validate_extension
from django.conf import settings
......
......@@ -7,10 +7,6 @@ class UniversityDri(Module):
universities = models.ManyToManyField(University, related_name="university_dri")
@classmethod
def get_serializer(cls):
return UniversityDriSerializer
class UniversityDriSerializer(ModuleSerializer):
class Meta:
......
......@@ -25,10 +25,6 @@ class UniversityInfo(Module):
costs_currency = models.ForeignKey(Currency, on_delete=models.PROTECT, null=True)
@classmethod
def get_serializer(cls):
return UniversityInfoSerializer
class UniversityInfoSerializer(ModuleSerializer):
class Meta:
......
......@@ -13,10 +13,6 @@ class UniversityScholarship(Scholarship):
University, related_name="university_scholarships"
)
@classmethod
def get_serializer(cls):
return UniversityScholarshipSerializer
class UniversityScholarshipSerializer(ScholarshipSerializer):
class Meta:
......
......@@ -25,10 +25,6 @@ class UniversitySemestersDates(Module):
autumn_begin = models.DateField(null=True, blank=True)
autumn_end = models.DateField(null=True, blank=True)
@classmethod
def get_serializer(cls):
return UniversitySemestersDatesSerializer
class UniversitySemestersDatesSerializer(ModuleSerializer):
def validate(self, attrs):
......
......@@ -13,10 +13,6 @@ class UniversityTaggedItem(TaggedItem):
University, on_delete=models.PROTECT, related_name="university_tagged_items"
)
@classmethod
def get_serializer(cls):
return UniversityTaggedItemSerializer
class Meta:
unique_together = ("university", "tag", "importance_level")
......
......@@ -14,6 +14,39 @@ class VersionSerializer(BaseModelSerializer):
data = serializers.SerializerMethodField()
serializers_mapping = None
@classmethod
def get_serializers_mapping(cls) -> dict:
"""
Function that returns a mapping from model name to the serializer
class that should be used to return the versioned data.
"""
if cls.serializers_mapping is None:
# Prevent cyclic imports
from backend_app.config.viewsets import get_viewsets_info
# A little bit of optimization to easily find the serializer class associated with a model
cls.serializers_mapping = dict()
viewsets = map(
lambda v: v.Viewset,
get_viewsets_info(requires_testing="smart", is_api_view=False),
)
for viewset in viewsets:
serializer = viewset().get_serializer_class()
model = serializer.Meta.model
cls.serializers_mapping[model.__name__] = serializer
# Override if models has a get_serializer method
for viewset in viewsets:
model = viewset().get_serializer_class().Meta.model
try:
cls.serializers_mapping[model.__name__] = model.get_serializer()
except AttributeError:
pass
return cls.serializers_mapping
def get_data(self, obj):
"""
Serilizer for the data field
......@@ -26,7 +59,8 @@ class VersionSerializer(BaseModelSerializer):
djangoSerializers.deserialize(obj.format, data, ignorenonexistent=True)
)[0]
# Version is valid,
obj_serializer = tmp.object.get_serializer()
print(self.get_serializers_mapping())
obj_serializer = self.get_serializers_mapping()[type(tmp.object).__name__]
new_context = dict(self.context)
new_context["view"].action = "list"
return obj_serializer(tmp.object, context=new_context).data
......
......@@ -4,7 +4,7 @@ from reversion.models import Version
def squash_revision_by_user(sender, obj, **kwargs):
"""
It should also work with moderation as obj will be a versionned object
It should also work with moderation as obj will be a versioned object
"""
versions = (
Version.objects.get_for_object(obj)
......
from django.test import TestCase
from backend_app.load_data import load_all
from backend_app.load_data.load_all import load_all
class ModerationTestCase(TestCase):
......
......@@ -2,7 +2,7 @@ from django.test import TestCase
import pytest
from rest_framework.validators import ValidationError as RFValidationError
from django.core.validators import ValidationError as DJValidationError
from backend_app.validators.tag.url import validate_extension, validate_url
from backend_app.validators.tags import validate_extension, validate_url
class ValidationUrlTestCase(TestCase):
......
......@@ -4,6 +4,7 @@ from django.conf import settings
from reversion.models import Version
from backend_app.signals.squash_revisions import new_revision_saved
from django.test import override_settings
from django.contrib.contenttypes.models import ContentType
class VersioningTestCase(WithUserTestCase):
......@@ -69,3 +70,7 @@ class VersioningTestCase(WithUserTestCase):
self.assertEqual(len(versions), 2)
self.assertEqual(len(versions), instance.nb_versions)
self.assertTrue(self.signal_was_called)
# Final test: we query the version viewset itself
ct = ContentType.objects.get_for_model(instance).id
self.authenticated_client.get("/api/versions/{}/{}/".format(ct, instance.pk))
from django.conf.urls import include, url
from backend_app.config.viewsets import get_viewsets_info
from backend_app.checks import check_viewsets
from rest_framework import routers
from rest_framework.documentation import include_docs_urls
STANDARD_VIEWSETS = get_viewsets_info(requires_testing="smart", is_api_view=False)
API_VIEW_VIEWSETS = get_viewsets_info(requires_testing="smart", is_api_view=True)
check_viewsets(map(lambda v: v.Viewset, STANDARD_VIEWSETS))
#######
# Building the API routing
......
from .url import validate_url
from .text import validate_text
from .tagged_item_validation import (
validate_content_against_config,
tagged_item_validation,
)
__all__ = [
"validate_url",
"validate_text",
"validate_content_against_config",
"tagged_item_validation",
]
from rest_framework.validators import ValidationError
def missing_field(field):
return ValidationError("{} : this field is required".format(field))
def check_required(config, content):
for field in config:
if config[field]["required"]:
try:
val = content[field]
if type(val) is str:
if len(val) == 0:
raise missing_field(field)
if val is None:
raise missing_field(field)
except KeyError:
raise missing_field(field)
from .photos import PHOTOS_TAG_CONFIG
from .useful_links import USEFULL_LINKS_CONFIG
__all__ = ["PHOTOS_TAG_CONFIG", "USEFULL_LINKS_CONFIG"]
from rest_framework.validators import ValidationError
def validate_text(config, string):
string = str(string) # might cause error with number ?
try:
validators = config["validators"]
for validator in validators:
validator_content = validators[validator]
if validator == "max_length":
if len(string) > validator_content:
raise ValidationError("Your text is too long !")
else:
raise Exception("Dev, you have implement something here...")
except KeyError:
pass
from django.core.validators import URLValidator
from rest_framework.validators import ValidationError
def validate_extension(allowed_extensions, string):
allowed_extensions = [
allowed_extension.lower() for allowed_extension in allowed_extensions
]
try:
if string.split(".")[-1].lower() not in allowed_extensions:
raise ValidationError(
"The file you submitted has an unauthorized extension"
)
except KeyError:
raise ValidationError("File extension not recognized")
def validate_url(config, string):
string = str(string)
validate = URLValidator(schemes=("http", "https", "ftp", "ftps"))
validate(string)
try:
validators = config["validators"]
for validator in validators:
validator_content = validators[validator]
if validator == "extension":
validate_extension(validator_content, string)
elif validator == "max_length":
if len(string) > validator_content:
raise ValidationError("Your url is too long !")
else:
raise Exception("Dev, you have implement something here...")
except KeyError:
pass
from .checks import check_required
from .url import validate_url
from .text import validate_text
from rest_framework.validators import ValidationError
from .tags_config import PHOTOS_TAG_CONFIG
from .tags_config import USEFULL_LINKS_CONFIG
from django.core.validators import URLValidator
from rest_framework.exceptions import ValidationError
from backend_app.validators.tags_config.photos import PHOTOS_TAG_CONFIG
from backend_app.validators.tags_config.useful_links import USEFULL_LINKS_CONFIG
def missing_field(field):
return ValidationError("{} : this field is required".format(field))
def check_required(config, content):
for field in config:
if config[field]["required"]:
try:
val = content[field]
if type(val) is str:
if len(val) == 0:
raise missing_field(field)
if val is None:
raise missing_field(field)
except KeyError:
raise missing_field(field)
def tagged_item_validation(attrs):
tag_config = attrs["tag"].config
try:
sumbitted_content = attrs["custom_content"]
submitted_content = attrs["custom_content"]
except KeyError:
assert len(tag_config.keys()) == 0
return True
validate_content_against_config(tag_config, sumbitted_content)
validate_content_against_config(tag_config, submitted_content)
def validate_content_against_config(config, content):
......@@ -49,3 +66,52 @@ def validate_content_against_config(config, content):