From f24b0879bb6a1f095eda5a03fa0876c8442f6ba9 Mon Sep 17 00:00:00 2001
From: Tanner Prestegard <tanner.prestegard@ligo.org>
Date: Thu, 23 Aug 2018 09:59:19 -0500
Subject: [PATCH] Abstracting/generalizing from events app

Moving a few base class throttles and fields from the events app up
one level so that they can be used elsewhere.
---
 config/settings/base.py            |  2 +-
 gracedb/api/{v1 => }/backends.py   |  0
 gracedb/api/v1/events/fields.py    | 45 +---------------------------
 gracedb/api/v1/events/throttles.py | 41 +------------------------
 gracedb/api/v1/events/views.py     |  2 +-
 gracedb/api/v1/fields.py           | 48 +++++++++++++++++++++++++++---
 gracedb/api/v1/main/views.py       |  2 +-
 gracedb/api/v1/throttles.py        | 42 ++++++++++++++++++++++++++
 8 files changed, 91 insertions(+), 91 deletions(-)
 rename gracedb/api/{v1 => }/backends.py (100%)
 create mode 100644 gracedb/api/v1/throttles.py

diff --git a/config/settings/base.py b/config/settings/base.py
index a2b76f94e..7abf0f105 100644
--- a/config/settings/base.py
+++ b/config/settings/base.py
@@ -345,7 +345,7 @@ REST_FRAMEWORK = {
         'annotation'    : '10/second',
     },
     'DEFAULT_AUTHENTICATION_CLASSES': (
-        'api.v1.backends.LigoAuthentication',
+        'api.backends.LigoAuthentication',
     ),
     'COERCE_DECIMAL_TO_STRING': False,
     'EXCEPTION_HANDLER':
diff --git a/gracedb/api/v1/backends.py b/gracedb/api/backends.py
similarity index 100%
rename from gracedb/api/v1/backends.py
rename to gracedb/api/backends.py
diff --git a/gracedb/api/v1/events/fields.py b/gracedb/api/v1/events/fields.py
index 9129160fa..e766763b0 100644
--- a/gracedb/api/v1/events/fields.py
+++ b/gracedb/api/v1/events/fields.py
@@ -1,58 +1,15 @@
 from __future__ import absolute_import
 import logging
 
-from django.contrib.auth import get_user_model
-
 from rest_framework import serializers
 
 from events.models import Event
-
-# Set up user model
-UserModel = get_user_model()
+from ..fields import GenericField
 
 # Set up logger
 logger = logging.getLogger(__name__)
 
 
-class GenericField(serializers.Field):
-    # Field, property, or callable of the object which will be used to
-    # generate the representation of the object.
-    to_repr = None
-    lookup_field = 'id'
-    model = None
-
-    def __init__(self, *args, **kwargs):
-        self.to_repr = kwargs.pop('to_repr', self.to_repr)
-        self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
-        self.model = kwargs.pop('model', self.model)
-
-        assert self.to_repr is not None, ('Must specify to_repr')
-        assert self.model is not None, ('Must specify model')
-        super(GenericField, self).__init__(*args, **kwargs)
-
-    def to_representation(self, obj):
-        value = getattr(obj, self.to_repr)
-
-        # Handle case where we are given a function instead of
-        # a model field or a property
-        if callable(value):
-            value = value()
-        return value
-
-    def to_internal_value(self, data):
-        model_dict = self.get_model_dict(data)
-        try:
-            return self.model.objects.get(**model_dict)
-        except self.model.DoesNotExist:
-            error_msg = '{model} with {lf}={data} does not exist' \
-                .format(model=self.model.__name__, lf=model_dict.keys()[0],
-                data=model_dict.values()[0])
-            raise serializers.ValidationError(error_msg)
-
-    def get_model_dict(self, data):
-        return {self.lookup_field: data}
-
-
 class EventGraceidField(GenericField, serializers.RelatedField):
     to_repr = 'graceid'
     lookup_field = 'id'
diff --git a/gracedb/api/v1/events/throttles.py b/gracedb/api/v1/events/throttles.py
index 02a1c2290..3e40b639c 100644
--- a/gracedb/api/v1/events/throttles.py
+++ b/gracedb/api/v1/events/throttles.py
@@ -1,44 +1,5 @@
-from rest_framework.throttling import UserRateThrottle
+from ..throttles import PostOrPutUserRateThrottle
 
-class PostOrPutUserRateThrottle(UserRateThrottle):
-
-    def allow_request(self, request, view):
-        """
-        This is mostly copied from the Rest Framework's SimpleRateThrottle
-        except we now pass the request to throttle_success
-        """
-        # We don't want to throttle any safe methods
-        if request.method not in ['POST', 'PUT']:
-            return True
-
-        if self.rate is None:
-            return True
-
-        self.key = self.get_cache_key(request, view)
-        if self.key is None:
-            return True
-
-        self.history = self.cache.get(self.key, [])
-        self.now = self.timer()
-
-        # Drop any requests from the history which have now passed the
-        # throttle duration
-        while self.history and self.history[-1] <= self.now - self.duration:
-            self.history.pop()
-        if len(self.history) >= self.num_requests:
-            return self.throttle_failure()
-        return self.throttle_success(request)
-
-    def throttle_success(self, request):
-        """
-        Inserts the current request's timestamp along with the key
-        into the cache. Except we only do this if the request is a
-        writing method (POST or PUT). That's why we needed the request.
-        """
-        if request.method in ['POST', 'PUT']:
-            self.history.insert(0, self.now)
-        self.cache.set(self.key, self.history, self.duration)
-        return True
 
 class EventCreationThrottle(PostOrPutUserRateThrottle):
     scope = 'event_creation'
diff --git a/gracedb/api/v1/events/views.py b/gracedb/api/v1/events/views.py
index 61850f768..04ada23df 100644
--- a/gracedb/api/v1/events/views.py
+++ b/gracedb/api/v1/events/views.py
@@ -33,6 +33,7 @@ from rest_framework.response import Response
 from rest_framework.views import APIView
 
 from alerts.old_alert import issueAlertForUpdate
+from api.backends import LigoAuthentication
 from core.http import check_and_serve_file
 from core.vfile import VersionedFile
 from events.buildVOEvent import buildVOEvent, VOEventBuilderException
@@ -50,7 +51,6 @@ from events.view_utils import eventToDict, eventLogToDict, labelToDict, \
     skymapViewerEMObservationToDict, BadFARRange, check_query_far_range
 from superevents.models import Superevent
 from .throttles import EventCreationThrottle, AnnotationThrottle
-from ..backends import LigoAuthentication
 from ...utils import api_reverse
 
 # Set up logger
diff --git a/gracedb/api/v1/fields.py b/gracedb/api/v1/fields.py
index fb378b958..562f0d6b0 100644
--- a/gracedb/api/v1/fields.py
+++ b/gracedb/api/v1/fields.py
@@ -1,13 +1,14 @@
+from __future__ import absolute_import
 import logging
 import six
 
-from rest_framework import fields
+from rest_framework import serializers
 
 # Set up logger
 logger = logging.getLogger(__name__)
 
 
-class CustomHiddenDefault(fields.CurrentUserDefault):
+class CustomHiddenDefault(serializers.CurrentUserDefault):
     context_key = None
 
     def __init__(self, *args, **kwargs):
@@ -45,7 +46,7 @@ class ParentObjectDefault(CustomHiddenDefault):
         return value
 
 
-class CommaSeparatedOrListField(fields.ListField):
+class CommaSeparatedOrListField(serializers.ListField):
     default_style = {'base_template': 'input.html'}
 
     def __init__(self, *args, **kwargs):
@@ -66,7 +67,7 @@ class CommaSeparatedOrListField(fields.ListField):
         return super(CommaSeparatedOrListField, self).to_internal_value(data)
 
 
-class ChoiceDisplayField(fields.ChoiceField):
+class ChoiceDisplayField(serializers.ChoiceField):
     """
     Same as standard choice field, but return a choice's display_value
     instead of the key when serializing the field.
@@ -74,3 +75,42 @@ class ChoiceDisplayField(fields.ChoiceField):
 
     def to_representation(self, value):
         return self._choices[value]
+
+
+class GenericField(serializers.Field):
+    # Field, property, or callable of the object which will be used to
+    # generate the representation of the object.
+    to_repr = None
+    lookup_field = 'id'
+    model = None
+
+    def __init__(self, *args, **kwargs):
+        self.to_repr = kwargs.pop('to_repr', self.to_repr)
+        self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+        self.model = kwargs.pop('model', self.model)
+
+        assert self.to_repr is not None, ('Must specify to_repr')
+        assert self.model is not None, ('Must specify model')
+        super(GenericField, self).__init__(*args, **kwargs)
+
+    def to_representation(self, obj):
+        value = getattr(obj, self.to_repr)
+
+        # Handle case where we are given a function instead of
+        # a model field or a property
+        if callable(value):
+            value = value()
+        return value
+
+    def to_internal_value(self, data):
+        model_dict = self.get_model_dict(data)
+        try:
+            return self.model.objects.get(**model_dict)
+        except self.model.DoesNotExist:
+            error_msg = '{model} with {lf}={data} does not exist' \
+                .format(model=self.model.__name__, lf=model_dict.keys()[0],
+                data=model_dict.values()[0])
+            raise serializers.ValidationError(error_msg)
+
+    def get_model_dict(self, data):
+        return {self.lookup_field: data}
diff --git a/gracedb/api/v1/main/views.py b/gracedb/api/v1/main/views.py
index 2a5a54ab6..cf7b120a2 100644
--- a/gracedb/api/v1/main/views.py
+++ b/gracedb/api/v1/main/views.py
@@ -12,12 +12,12 @@ from rest_framework.response import Response
 from rest_framework.reverse import reverse as drf_reverse
 from rest_framework.views import APIView
 
+from api.backends import LigoAuthentication
 from api.utils import api_reverse
 from events.models import Group, Pipeline, Search, Tag, Label, EMGroup, \
     VOEvent, EMBBEventLog, EMSPECTRUM
 from events.view_logic import get_performance_info
 from superevents.models import Superevent
-from ..backends import LigoAuthentication
 from ..superevents.url_templates import construct_url_templates
 
 
diff --git a/gracedb/api/v1/throttles.py b/gracedb/api/v1/throttles.py
new file mode 100644
index 000000000..1c7a5e159
--- /dev/null
+++ b/gracedb/api/v1/throttles.py
@@ -0,0 +1,42 @@
+from rest_framework.throttling import UserRateThrottle
+
+
+class PostOrPutUserRateThrottle(UserRateThrottle):
+
+    def allow_request(self, request, view):
+        """
+        This is mostly copied from the Rest Framework's SimpleRateThrottle
+        except we now pass the request to throttle_success
+        """
+        # We don't want to throttle any safe methods
+        if request.method not in ['POST', 'PUT']:
+            return True
+
+        if self.rate is None:
+            return True
+
+        self.key = self.get_cache_key(request, view)
+        if self.key is None:
+            return True
+
+        self.history = self.cache.get(self.key, [])
+        self.now = self.timer()
+
+        # Drop any requests from the history which have now passed the
+        # throttle duration
+        while self.history and self.history[-1] <= self.now - self.duration:
+            self.history.pop()
+        if len(self.history) >= self.num_requests:
+            return self.throttle_failure()
+        return self.throttle_success(request)
+
+    def throttle_success(self, request):
+        """
+        Inserts the current request's timestamp along with the key
+        into the cache. Except we only do this if the request is a
+        writing method (POST or PUT). That's why we needed the request.
+        """
+        if request.method in ['POST', 'PUT']:
+            self.history.insert(0, self.now)
+        self.cache.set(self.key, self.history, self.duration)
+        return True
-- 
GitLab