From f24b0879bb6a1f095eda5a03fa0876c8442f6ba9 Mon Sep 17 00:00:00 2001 From: Tanner Prestegard <tanner.prestegard@ligo.org> Date: Thu, 23 Aug 2018 09:59:19 -0500 Subject: [PATCH] Abstracting/generalizing from events app Moving a few base class throttles and fields from the events app up one level so that they can be used elsewhere. --- config/settings/base.py | 2 +- gracedb/api/{v1 => }/backends.py | 0 gracedb/api/v1/events/fields.py | 45 +--------------------------- gracedb/api/v1/events/throttles.py | 41 +------------------------ gracedb/api/v1/events/views.py | 2 +- gracedb/api/v1/fields.py | 48 +++++++++++++++++++++++++++--- gracedb/api/v1/main/views.py | 2 +- gracedb/api/v1/throttles.py | 42 ++++++++++++++++++++++++++ 8 files changed, 91 insertions(+), 91 deletions(-) rename gracedb/api/{v1 => }/backends.py (100%) create mode 100644 gracedb/api/v1/throttles.py diff --git a/config/settings/base.py b/config/settings/base.py index a2b76f94e..7abf0f105 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -345,7 +345,7 @@ REST_FRAMEWORK = { 'annotation' : '10/second', }, 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'api.v1.backends.LigoAuthentication', + 'api.backends.LigoAuthentication', ), 'COERCE_DECIMAL_TO_STRING': False, 'EXCEPTION_HANDLER': diff --git a/gracedb/api/v1/backends.py b/gracedb/api/backends.py similarity index 100% rename from gracedb/api/v1/backends.py rename to gracedb/api/backends.py diff --git a/gracedb/api/v1/events/fields.py b/gracedb/api/v1/events/fields.py index 9129160fa..e766763b0 100644 --- a/gracedb/api/v1/events/fields.py +++ b/gracedb/api/v1/events/fields.py @@ -1,58 +1,15 @@ from __future__ import absolute_import import logging -from django.contrib.auth import get_user_model - from rest_framework import serializers from events.models import Event - -# Set up user model -UserModel = get_user_model() +from ..fields import GenericField # Set up logger logger = logging.getLogger(__name__) -class GenericField(serializers.Field): - # Field, property, or callable of the object which will be used to - # generate the representation of the object. - to_repr = None - lookup_field = 'id' - model = None - - def __init__(self, *args, **kwargs): - self.to_repr = kwargs.pop('to_repr', self.to_repr) - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) - self.model = kwargs.pop('model', self.model) - - assert self.to_repr is not None, ('Must specify to_repr') - assert self.model is not None, ('Must specify model') - super(GenericField, self).__init__(*args, **kwargs) - - def to_representation(self, obj): - value = getattr(obj, self.to_repr) - - # Handle case where we are given a function instead of - # a model field or a property - if callable(value): - value = value() - return value - - def to_internal_value(self, data): - model_dict = self.get_model_dict(data) - try: - return self.model.objects.get(**model_dict) - except self.model.DoesNotExist: - error_msg = '{model} with {lf}={data} does not exist' \ - .format(model=self.model.__name__, lf=model_dict.keys()[0], - data=model_dict.values()[0]) - raise serializers.ValidationError(error_msg) - - def get_model_dict(self, data): - return {self.lookup_field: data} - - class EventGraceidField(GenericField, serializers.RelatedField): to_repr = 'graceid' lookup_field = 'id' diff --git a/gracedb/api/v1/events/throttles.py b/gracedb/api/v1/events/throttles.py index 02a1c2290..3e40b639c 100644 --- a/gracedb/api/v1/events/throttles.py +++ b/gracedb/api/v1/events/throttles.py @@ -1,44 +1,5 @@ -from rest_framework.throttling import UserRateThrottle +from ..throttles import PostOrPutUserRateThrottle -class PostOrPutUserRateThrottle(UserRateThrottle): - - def allow_request(self, request, view): - """ - This is mostly copied from the Rest Framework's SimpleRateThrottle - except we now pass the request to throttle_success - """ - # We don't want to throttle any safe methods - if request.method not in ['POST', 'PUT']: - return True - - if self.rate is None: - return True - - self.key = self.get_cache_key(request, view) - if self.key is None: - return True - - self.history = self.cache.get(self.key, []) - self.now = self.timer() - - # Drop any requests from the history which have now passed the - # throttle duration - while self.history and self.history[-1] <= self.now - self.duration: - self.history.pop() - if len(self.history) >= self.num_requests: - return self.throttle_failure() - return self.throttle_success(request) - - def throttle_success(self, request): - """ - Inserts the current request's timestamp along with the key - into the cache. Except we only do this if the request is a - writing method (POST or PUT). That's why we needed the request. - """ - if request.method in ['POST', 'PUT']: - self.history.insert(0, self.now) - self.cache.set(self.key, self.history, self.duration) - return True class EventCreationThrottle(PostOrPutUserRateThrottle): scope = 'event_creation' diff --git a/gracedb/api/v1/events/views.py b/gracedb/api/v1/events/views.py index 61850f768..04ada23df 100644 --- a/gracedb/api/v1/events/views.py +++ b/gracedb/api/v1/events/views.py @@ -33,6 +33,7 @@ from rest_framework.response import Response from rest_framework.views import APIView from alerts.old_alert import issueAlertForUpdate +from api.backends import LigoAuthentication from core.http import check_and_serve_file from core.vfile import VersionedFile from events.buildVOEvent import buildVOEvent, VOEventBuilderException @@ -50,7 +51,6 @@ from events.view_utils import eventToDict, eventLogToDict, labelToDict, \ skymapViewerEMObservationToDict, BadFARRange, check_query_far_range from superevents.models import Superevent from .throttles import EventCreationThrottle, AnnotationThrottle -from ..backends import LigoAuthentication from ...utils import api_reverse # Set up logger diff --git a/gracedb/api/v1/fields.py b/gracedb/api/v1/fields.py index fb378b958..562f0d6b0 100644 --- a/gracedb/api/v1/fields.py +++ b/gracedb/api/v1/fields.py @@ -1,13 +1,14 @@ +from __future__ import absolute_import import logging import six -from rest_framework import fields +from rest_framework import serializers # Set up logger logger = logging.getLogger(__name__) -class CustomHiddenDefault(fields.CurrentUserDefault): +class CustomHiddenDefault(serializers.CurrentUserDefault): context_key = None def __init__(self, *args, **kwargs): @@ -45,7 +46,7 @@ class ParentObjectDefault(CustomHiddenDefault): return value -class CommaSeparatedOrListField(fields.ListField): +class CommaSeparatedOrListField(serializers.ListField): default_style = {'base_template': 'input.html'} def __init__(self, *args, **kwargs): @@ -66,7 +67,7 @@ class CommaSeparatedOrListField(fields.ListField): return super(CommaSeparatedOrListField, self).to_internal_value(data) -class ChoiceDisplayField(fields.ChoiceField): +class ChoiceDisplayField(serializers.ChoiceField): """ Same as standard choice field, but return a choice's display_value instead of the key when serializing the field. @@ -74,3 +75,42 @@ class ChoiceDisplayField(fields.ChoiceField): def to_representation(self, value): return self._choices[value] + + +class GenericField(serializers.Field): + # Field, property, or callable of the object which will be used to + # generate the representation of the object. + to_repr = None + lookup_field = 'id' + model = None + + def __init__(self, *args, **kwargs): + self.to_repr = kwargs.pop('to_repr', self.to_repr) + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.model = kwargs.pop('model', self.model) + + assert self.to_repr is not None, ('Must specify to_repr') + assert self.model is not None, ('Must specify model') + super(GenericField, self).__init__(*args, **kwargs) + + def to_representation(self, obj): + value = getattr(obj, self.to_repr) + + # Handle case where we are given a function instead of + # a model field or a property + if callable(value): + value = value() + return value + + def to_internal_value(self, data): + model_dict = self.get_model_dict(data) + try: + return self.model.objects.get(**model_dict) + except self.model.DoesNotExist: + error_msg = '{model} with {lf}={data} does not exist' \ + .format(model=self.model.__name__, lf=model_dict.keys()[0], + data=model_dict.values()[0]) + raise serializers.ValidationError(error_msg) + + def get_model_dict(self, data): + return {self.lookup_field: data} diff --git a/gracedb/api/v1/main/views.py b/gracedb/api/v1/main/views.py index 2a5a54ab6..cf7b120a2 100644 --- a/gracedb/api/v1/main/views.py +++ b/gracedb/api/v1/main/views.py @@ -12,12 +12,12 @@ from rest_framework.response import Response from rest_framework.reverse import reverse as drf_reverse from rest_framework.views import APIView +from api.backends import LigoAuthentication from api.utils import api_reverse from events.models import Group, Pipeline, Search, Tag, Label, EMGroup, \ VOEvent, EMBBEventLog, EMSPECTRUM from events.view_logic import get_performance_info from superevents.models import Superevent -from ..backends import LigoAuthentication from ..superevents.url_templates import construct_url_templates diff --git a/gracedb/api/v1/throttles.py b/gracedb/api/v1/throttles.py new file mode 100644 index 000000000..1c7a5e159 --- /dev/null +++ b/gracedb/api/v1/throttles.py @@ -0,0 +1,42 @@ +from rest_framework.throttling import UserRateThrottle + + +class PostOrPutUserRateThrottle(UserRateThrottle): + + def allow_request(self, request, view): + """ + This is mostly copied from the Rest Framework's SimpleRateThrottle + except we now pass the request to throttle_success + """ + # We don't want to throttle any safe methods + if request.method not in ['POST', 'PUT']: + return True + + if self.rate is None: + return True + + self.key = self.get_cache_key(request, view) + if self.key is None: + return True + + self.history = self.cache.get(self.key, []) + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success(request) + + def throttle_success(self, request): + """ + Inserts the current request's timestamp along with the key + into the cache. Except we only do this if the request is a + writing method (POST or PUT). That's why we needed the request. + """ + if request.method in ['POST', 'PUT']: + self.history.insert(0, self.now) + self.cache.set(self.key, self.history, self.duration) + return True -- GitLab