From d9bc1d888713478ced2ed81dbb35f0626fbc2153 Mon Sep 17 00:00:00 2001
From: Branson Stephens <branson.stephens@ligo.org>
Date: Thu, 3 Mar 2016 10:58:11 -0600
Subject: [PATCH] Changes required for upgrade to Rest Framework 3.3

---
 gracedb/api.py        | 64 +++++++++++++++++++++++--------------------
 gracedb/view_logic.py |  4 +--
 2 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/gracedb/api.py b/gracedb/api.py
index 67fbf995b..7a998ff9c 100644
--- a/gracedb/api.py
+++ b/gracedb/api.py
@@ -152,10 +152,10 @@ def event_and_auth_required(view):
 def pipeline_auth_required(view):
     @wraps(view)
     def inner(self, request, *args, **kwargs):
-        group_name = request.POST.get('group', None)
+        group_name = request.data.get('group', None)
         if not group_name=='Test':
             try:
-                pipeline = Pipeline.objects.get(name=request.POST['pipeline'])
+                pipeline = Pipeline.objects.get(name=request.data['pipeline'])
             except:
                 return Response({'error': "Please provide a valid pipeline."}, 
                     status = status.HTTP_400_BAD_REQUEST)
@@ -399,11 +399,11 @@ class EventList(APIView):
     def get(self, request, *args, **kwargs):
 
         """I am the GET docstring for EventList"""
-        query = request.QUERY_PARAMS.get("query")
-        count = request.QUERY_PARAMS.get("count", PAGINATE_BY)
-        start = request.QUERY_PARAMS.get("start", 0)
-        sort = request.QUERY_PARAMS.get("sort", "-created")
-        columns = request.QUERY_PARAMS.get("columns", "")
+        query = request.query_params.get("query")
+        count = request.query_params.get("count", PAGINATE_BY)
+        start = request.query_params.get("start", 0)
+        sort = request.query_params.get("sort", "-created")
+        columns = request.query_params.get("columns", "")
 
         events = Event.objects
         if query:
@@ -495,24 +495,28 @@ class EventList(APIView):
         # Eventually, we will want to get rid of this check and just let it fail.
         rv = {}
         rv['warnings'] = []
-        if 'type' in request.POST:
+        if 'type' in request.data:
             request = fix_old_creation_request(request)
             rv['warnings'] += ['It looks like you are using the old GraceDB client (v<=1.14). ' + \
                              'Please update! This will eventually stop working.']
 
         # Check user authorization for pipeline. 
         # XXX This is a temporary hack until they roll out the new client.
-        group_name = request.POST.get('group', None) 
+        group_name = request.data.get('group', None) 
         if not group_name=='Test':
             try:
-                pipeline = Pipeline.objects.get(name=request.POST['pipeline'])
+                pipeline = Pipeline.objects.get(name=request.data['pipeline'])
             except:
                 return Response({'error': "Please provide a valid pipeline."},
                     status = status.HTTP_400_BAD_REQUEST)
             if not user_has_perm(request.user, "populate", pipeline):
                 return HttpResponseForbidden("You don't have permission on this pipeline.")
         
-        form = CreateEventForm(request.POST, request.FILES)
+        # The following looks a bit funny but it is actually necessary. The 
+        # django form expects a dict containing the POST data as the first
+        # arg, and a dict containing the FILE data as the second. In the 
+        # django-restframework, however, both are in request.data
+        form = CreateEventForm(request.data, request.data)
         if form.is_valid():
             event, warnings = _createEventFromForm(request, form)
             if event:
@@ -610,22 +614,22 @@ class EventDetail(APIView):
             return Response(str(e))
 
 #       messages = []
-#       if event.group.name != request.DATA['group']:
+#       if event.group.name != request.data['group']:
 #           messages += [
 #                   "Existing event group ({0}) does not match "
 #                   "replacement event group ({1})".format(
-#                       event.group.name, request.DATA['group'])]
-#       if event.analysisType != request.DATA['type']:
+#                       event.group.name, request.data['group'])]
+#       if event.analysisType != request.data['type']:
 #           messages += [
 #                   "Existing event type ({0}) does not match "
 #                   "replacement event type ({1})".format(
-#                       event.analysisType, request.DATA['type'])]
+#                       event.analysisType, request.data['type'])]
 #       if messages:
 #           return Response("\n".join(messages),
 #                   status=status.HTTP_400_BAD_REQUEST)
 
         # XXX handle duplicate file names.
-        f = request.FILES['eventFile']
+        f = request.data['eventFile']
         uploadDestination = os.path.join(event.datadir(), f.name)
         fdest = VersionedFile(uploadDestination, 'w')
         #for chunk in f.chunks():
@@ -668,8 +672,8 @@ class EventNeighbors(APIView):
     # and TSV renderers.
     @event_and_auth_required
     def get(self, request, event):
-        if request.QUERY_PARAMS.has_key('neighborhood'):
-            delta = request.QUERY_PARAMS['neighborhood']
+        if request.query_params.has_key('neighborhood'):
+            delta = request.query_params['neighborhood']
             try:
                 if delta.find(',') < 0:
                     neighborhood = (int(delta), int(delta))
@@ -780,8 +784,8 @@ class EventLogList(APIView):
 
     @event_and_auth_required
     def post(self, request, event):
-        message = request.DATA.get('message')
-        tagnames = request.DATA.get('tagname', None)
+        message = request.data.get('message')
+        tagnames = request.data.get('tagname', None)
         # Convert tagnames from comma separated list.
         if tagnames:
             tagnames = tagnames.split(',')
@@ -789,7 +793,7 @@ class EventLogList(APIView):
             tagnames = []
 
         try:
-            uploadedFile = request.FILES['upload'] 
+            uploadedFile = request.data['upload'] 
             self.check_object_permissions(self.request, event)
         except:
             uploadedFile = None
@@ -913,7 +917,7 @@ class EMBBEventLogList(APIView):
     @event_and_auth_required
     def post(self, request, event):
         try:
-            eel = create_eel(request.DATA, event, request.user)
+            eel = create_eel(request.data, event, request.user)
         except ValueError, e:
             return Response("%s" % str(e), status=status.HTTP_400_BAD_REQUEST)
         except IntegrityError, e:
@@ -969,7 +973,7 @@ class EMObservationList(APIView):
         # XXX Note the following hack.
         # If this JSON information is requested for skymapViewer, use a different
         # representation for backwards compatibility.
-        if 'skymapViewer' in request.QUERY_PARAMS.keys():
+        if 'skymapViewer' in request.query_params.keys():
             emo = [ skymapViewerEMObservationToDict(emo, request)
                     for emo in emo_set.iterator() ]
 
@@ -1218,7 +1222,7 @@ class EventLogTagDetail(APIView):
             try:
                 tag = Tag.objects.filter(name=tagname)[0]
             except:
-                displayName = request.DATA.get('displayName')
+                displayName = request.data.get('displayName')
                 tag = Tag(name=tagname, displayName=displayName)
                 tag.save()
 
@@ -1733,7 +1737,7 @@ class Files(APIView):
         try:
             # Open / Write the file.
             fdest = VersionedFile(filepath, 'w')
-            f = request.FILES['upload']
+            f = request.data['upload']
             for chunk in f.chunks(): 
                 fdest.write(chunk)
             fdest.close()
@@ -1851,16 +1855,16 @@ class VOEventList(APIView):
 
     @event_and_auth_required
     def post(self, request, event):
-        voevent_type = request.DATA.get('voevent_type', None)
+        voevent_type = request.data.get('voevent_type', None)
         if not voevent_type:
             msg = "You must provide a valid voevent_type."
             return Response({'error': msg}, status = status.HTTP_400_BAD_REQUEST)
 
-        internal = request.DATA.get('internal', 1)
+        internal = request.data.get('internal', 1)
             
-        skymap_type = request.DATA.get('skymap_type', None)
-        skymap_filename = request.DATA.get('skymap_filename', None)
-        skymap_image_filename = request.DATA.get('skymap_image_filename', None)
+        skymap_type = request.data.get('skymap_type', None)
+        skymap_filename = request.data.get('skymap_filename', None)
+        skymap_image_filename = request.data.get('skymap_image_filename', None)
 
         if (skymap_filename and not skymap_type) or (skymap_type and not skymap_filename):
             msg = "Both or neither of skymap_time and skymap_filename must be specified."
diff --git a/gracedb/view_logic.py b/gracedb/view_logic.py
index a6c97562e..b069df54e 100644
--- a/gracedb/view_logic.py
+++ b/gracedb/view_logic.py
@@ -414,10 +414,10 @@ def create_eel(d, event, user):
 # Create an EMBB Observaton Record
 #
 def create_emobservation(request, event):    
-    d = getattr(request, 'DATA', None)
+    d = getattr(request, 'data', None)
     if not d:
         d = getattr(request, 'POST', None)
-    # Still haven't got the d?
+    # Still haven't got the data?
     if not d:
         raise ValueError('create_emobservation: got no post data from the request.')
 
-- 
GitLab