Skip to content
Snippets Groups Projects
Commit f24b0879 authored by Tanner Prestegard's avatar Tanner Prestegard Committed by GraceDB
Browse files

Abstracting/generalizing from events app

Moving a few base class throttles and fields from the events app up
one level so that they can be used elsewhere.
parent 94bf4a03
No related branches found
No related tags found
No related merge requests found
......@@ -345,7 +345,7 @@ REST_FRAMEWORK = {
'annotation' : '10/second',
},
'DEFAULT_AUTHENTICATION_CLASSES': (
'api.v1.backends.LigoAuthentication',
'api.backends.LigoAuthentication',
),
'COERCE_DECIMAL_TO_STRING': False,
'EXCEPTION_HANDLER':
......
File moved
from __future__ import absolute_import
import logging
from django.contrib.auth import get_user_model
from rest_framework import serializers
from events.models import Event
# Set up user model
UserModel = get_user_model()
from ..fields import GenericField
# Set up logger
logger = logging.getLogger(__name__)
class GenericField(serializers.Field):
# Field, property, or callable of the object which will be used to
# generate the representation of the object.
to_repr = None
lookup_field = 'id'
model = None
def __init__(self, *args, **kwargs):
self.to_repr = kwargs.pop('to_repr', self.to_repr)
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.model = kwargs.pop('model', self.model)
assert self.to_repr is not None, ('Must specify to_repr')
assert self.model is not None, ('Must specify model')
super(GenericField, self).__init__(*args, **kwargs)
def to_representation(self, obj):
value = getattr(obj, self.to_repr)
# Handle case where we are given a function instead of
# a model field or a property
if callable(value):
value = value()
return value
def to_internal_value(self, data):
model_dict = self.get_model_dict(data)
try:
return self.model.objects.get(**model_dict)
except self.model.DoesNotExist:
error_msg = '{model} with {lf}={data} does not exist' \
.format(model=self.model.__name__, lf=model_dict.keys()[0],
data=model_dict.values()[0])
raise serializers.ValidationError(error_msg)
def get_model_dict(self, data):
return {self.lookup_field: data}
class EventGraceidField(GenericField, serializers.RelatedField):
to_repr = 'graceid'
lookup_field = 'id'
......
from rest_framework.throttling import UserRateThrottle
from ..throttles import PostOrPutUserRateThrottle
class PostOrPutUserRateThrottle(UserRateThrottle):
def allow_request(self, request, view):
"""
This is mostly copied from the Rest Framework's SimpleRateThrottle
except we now pass the request to throttle_success
"""
# We don't want to throttle any safe methods
if request.method not in ['POST', 'PUT']:
return True
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success(request)
def throttle_success(self, request):
"""
Inserts the current request's timestamp along with the key
into the cache. Except we only do this if the request is a
writing method (POST or PUT). That's why we needed the request.
"""
if request.method in ['POST', 'PUT']:
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
class EventCreationThrottle(PostOrPutUserRateThrottle):
scope = 'event_creation'
......
......@@ -33,6 +33,7 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from alerts.old_alert import issueAlertForUpdate
from api.backends import LigoAuthentication
from core.http import check_and_serve_file
from core.vfile import VersionedFile
from events.buildVOEvent import buildVOEvent, VOEventBuilderException
......@@ -50,7 +51,6 @@ from events.view_utils import eventToDict, eventLogToDict, labelToDict, \
skymapViewerEMObservationToDict, BadFARRange, check_query_far_range
from superevents.models import Superevent
from .throttles import EventCreationThrottle, AnnotationThrottle
from ..backends import LigoAuthentication
from ...utils import api_reverse
# Set up logger
......
from __future__ import absolute_import
import logging
import six
from rest_framework import fields
from rest_framework import serializers
# Set up logger
logger = logging.getLogger(__name__)
class CustomHiddenDefault(fields.CurrentUserDefault):
class CustomHiddenDefault(serializers.CurrentUserDefault):
context_key = None
def __init__(self, *args, **kwargs):
......@@ -45,7 +46,7 @@ class ParentObjectDefault(CustomHiddenDefault):
return value
class CommaSeparatedOrListField(fields.ListField):
class CommaSeparatedOrListField(serializers.ListField):
default_style = {'base_template': 'input.html'}
def __init__(self, *args, **kwargs):
......@@ -66,7 +67,7 @@ class CommaSeparatedOrListField(fields.ListField):
return super(CommaSeparatedOrListField, self).to_internal_value(data)
class ChoiceDisplayField(fields.ChoiceField):
class ChoiceDisplayField(serializers.ChoiceField):
"""
Same as standard choice field, but return a choice's display_value
instead of the key when serializing the field.
......@@ -74,3 +75,42 @@ class ChoiceDisplayField(fields.ChoiceField):
def to_representation(self, value):
return self._choices[value]
class GenericField(serializers.Field):
# Field, property, or callable of the object which will be used to
# generate the representation of the object.
to_repr = None
lookup_field = 'id'
model = None
def __init__(self, *args, **kwargs):
self.to_repr = kwargs.pop('to_repr', self.to_repr)
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.model = kwargs.pop('model', self.model)
assert self.to_repr is not None, ('Must specify to_repr')
assert self.model is not None, ('Must specify model')
super(GenericField, self).__init__(*args, **kwargs)
def to_representation(self, obj):
value = getattr(obj, self.to_repr)
# Handle case where we are given a function instead of
# a model field or a property
if callable(value):
value = value()
return value
def to_internal_value(self, data):
model_dict = self.get_model_dict(data)
try:
return self.model.objects.get(**model_dict)
except self.model.DoesNotExist:
error_msg = '{model} with {lf}={data} does not exist' \
.format(model=self.model.__name__, lf=model_dict.keys()[0],
data=model_dict.values()[0])
raise serializers.ValidationError(error_msg)
def get_model_dict(self, data):
return {self.lookup_field: data}
......@@ -12,12 +12,12 @@ from rest_framework.response import Response
from rest_framework.reverse import reverse as drf_reverse
from rest_framework.views import APIView
from api.backends import LigoAuthentication
from api.utils import api_reverse
from events.models import Group, Pipeline, Search, Tag, Label, EMGroup, \
VOEvent, EMBBEventLog, EMSPECTRUM
from events.view_logic import get_performance_info
from superevents.models import Superevent
from ..backends import LigoAuthentication
from ..superevents.url_templates import construct_url_templates
......
from rest_framework.throttling import UserRateThrottle
class PostOrPutUserRateThrottle(UserRateThrottle):
def allow_request(self, request, view):
"""
This is mostly copied from the Rest Framework's SimpleRateThrottle
except we now pass the request to throttle_success
"""
# We don't want to throttle any safe methods
if request.method not in ['POST', 'PUT']:
return True
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success(request)
def throttle_success(self, request):
"""
Inserts the current request's timestamp along with the key
into the cache. Except we only do this if the request is a
writing method (POST or PUT). That's why we needed the request.
"""
if request.method in ['POST', 'PUT']:
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment