Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • alexander.pace/server
  • geoffrey.mo/gracedb-server
  • deep.chatterjee/gracedb-server
  • cody.messick/server
  • sushant.sharma-chaudhary/server
  • michael-coughlin/server
  • daniel.wysocki/gracedb-server
  • roberto.depietri/gracedb
  • philippe.grassia/gracedb
  • tri.nguyen/gracedb
  • jonah-kanner/gracedb
  • brandon.piotrzkowski/gracedb
  • joseph-areeda/gracedb
  • duncanmmacleod/gracedb
  • thomas.downes/gracedb
  • tanner.prestegard/gracedb
  • leo-singer/gracedb
  • computing/gracedb/server
18 results
Show changes
Showing
with 2927 additions and 0 deletions
import logging
from django.conf import settings
from django.contrib import messages
from django.core.mail import EmailMessage
from django.http import HttpResponseRedirect
from django.shortcuts import render
from django.urls import reverse, reverse_lazy
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.views.generic.edit import DeleteView, UpdateView
from django.views.generic.detail import DetailView
from django_twilio.client import twilio_client
from core.views import MultipleFormView
from ligoauth.decorators import internal_user_required
from .forms import (
PhoneContactForm, EmailContactForm, VerifyContactForm,
EventNotificationForm, SupereventNotificationForm,
)
from . import egad
from .models import Contact, Notification
from .phone import get_twilio_from
# Set up logger
logger = logging.getLogger(__name__)
###############################################################################
# Generic views ###############################################################
###############################################################################
@internal_user_required
def index(request):
    """Landing page: list the requesting user's notifications and contacts."""
    return render(request, 'alerts/index.html', context={
        'notifications': request.user.notification_set.all(),
        'contacts': request.user.contact_set.all(),
    })
###############################################################################
# Notification views ##########################################################
###############################################################################
@method_decorator(internal_user_required, name='dispatch')
class CreateNotificationView(MultipleFormView):
    """Create a notification"""
    template_name = 'alerts/create_notification.html'
    success_url = reverse_lazy('alerts:index')
    # One form per notification category; the submitted 'key_field' value
    # identifies which form the user posted.
    form_classes = [SupereventNotificationForm, EventNotificationForm]

    def get_context_data(self, **kwargs):
        # Expose the index of the active form so the template can re-display
        # the form the user submitted; default to the first form on GET.
        kwargs['idx'] = 0
        if (self.request.method in ('POST', 'PUT')):
            form_keys = [f.key for f in self.form_classes]
            idx = form_keys.index(self.request.POST['key_field'])
            kwargs['idx'] = idx
        return kwargs

    def get_form_kwargs(self, *args, **kwargs):
        # Forms receive the requesting user (presumably to restrict contact
        # choices to that user's contacts - confirm against the form classes).
        kw = super(CreateNotificationView, self).get_form_kwargs(
            *args, **kwargs)
        kw['user'] = self.request.user
        return kw

    def form_valid(self, form):
        # 'key_field' only selects the submitted form; it is not model data.
        form.cleaned_data.pop('key_field', None)
        # Add user (from request) and category (stored on form class) to
        # the form instance, then save
        form.instance.user = self.request.user
        form.instance.category = form.category
        form.save()
        # Add message and return
        messages.info(self.request, 'Created notification: {n}.'.format(
            n=form.instance.description))
        return super(CreateNotificationView, self).form_valid(form)

    def get(self, request, *args, **kwargs):
        # Make sure user has at least one verified contact; if not, redirect
        # and display an error message
        user_has_verified_contact = request.user.contact_set.filter(
            verified=True).exists()
        if not user_has_verified_contact:
            messages.error(request, ('Error: you have no verified contacts. '
                'Create and verify a contact before creating a notification.'))
            return HttpResponseRedirect(reverse('alerts:index'))
        return super(CreateNotificationView, self).get(request, *args,
            **kwargs)

    # MultipleFormView dispatches to per-form "<key>_form_valid" handlers;
    # both notification categories share the same logic.
    superevent_form_valid = event_form_valid = form_valid
@method_decorator(internal_user_required, name='dispatch')
class EditNotificationView(UpdateView):
    """Edit a notification"""
    template_name = 'alerts/edit_notification.html'
    # Have to provide form_class, but it will be dynamically selected below in
    # get_form()
    form_class = SupereventNotificationForm
    success_url = reverse_lazy('alerts:index')

    def get_form_class(self):
        # Pick the form that matches the notification's category.
        if self.object.category == Notification.NOTIFICATION_CATEGORY_EVENT:
            return EventNotificationForm
        else:
            return SupereventNotificationForm

    def get_form_kwargs(self, *args, **kwargs):
        kw = super(EditNotificationView, self).get_form_kwargs(
            *args, **kwargs)
        # Forms receive the requesting user (presumably to restrict contact
        # choices - confirm against the form classes).
        kw['user'] = self.request.user
        # Cases that have a label query actually have labels in the database.
        # But we don't want to include those in the form because
        # a) it's confusing and b) it breaks the form
        if self.object.label_query and self.object.labels.exists():
            kw['initial']['labels'] = None
        return kw

    def get_queryset(self):
        # Users may only edit their own notifications.
        return self.request.user.notification_set.all()
@method_decorator(internal_user_required, name='dispatch')
class DeleteNotificationView(DeleteView):
    """Delete a notification (no confirmation page; GET deletes directly)."""
    success_url = reverse_lazy('alerts:index')

    def get(self, request, *args, **kwargs):
        # Override this so that we don't require a confirmation page
        # for deletion
        return self.delete(request, *args, **kwargs)

    def form_valid(self, form, request, *args, **kwargs):
        # Delete exactly once and reuse that response. (The previous version
        # called super().delete() a second time on the return line, which
        # re-fetched and re-deleted the already-deleted object.)
        response = super(DeleteNotificationView, self).delete(request, *args,
            **kwargs)
        messages.info(request, 'Notification "{n}" has been deleted.'.format(
            n=self.object.description))
        return response

    def get_queryset(self):
        # Queryset should only contain the user's notifications
        return self.request.user.notification_set.all()
###############################################################################
# Contact views ###############################################################
###############################################################################
@method_decorator(internal_user_required, name='dispatch')
class CreateContactView(MultipleFormView):
    """Create a contact"""
    template_name = 'alerts/create_contact.html'
    success_url = reverse_lazy('alerts:index')
    form_classes = [PhoneContactForm, EmailContactForm]

    def get_context_data(self, **kwargs):
        # Default to the first form; on a submission, select the form whose
        # key matches the posted 'key_field' value.
        active = 0
        if self.request.method in ('POST', 'PUT'):
            posted_key = self.request.POST['key_field']
            active = [cls.key for cls in self.form_classes].index(posted_key)
        kwargs['idx'] = active
        return kwargs

    def form_valid(self, form):
        # 'key_field' only identifies the submitted form; drop it, attach
        # the owner, and persist the new contact.
        form.cleaned_data.pop('key_field', None)
        form.instance.user = self.request.user
        form.save()
        messages.info(self.request, 'Created contact "{cname}".'.format(
            cname=form.instance.description))
        return super(CreateContactView, self).form_valid(form)

    # Both contact types are handled identically by MultipleFormView's
    # per-form "<key>_form_valid" dispatch.
    email_form_valid = phone_form_valid = form_valid
@method_decorator(internal_user_required, name='dispatch')
class EditContactView(UpdateView):
    """
    Edit a contact. Users shouldn't be able to edit the actual email address
    or phone number since that would allow them to circumvent the verification
    process.
    """
    template_name = 'alerts/edit_contact.html'
    # Have to provide form_class, but it will be dynamically selected below in
    # get_form_class()
    form_class = PhoneContactForm
    success_url = reverse_lazy('alerts:index')

    def get_form_class(self):
        # Select the form matching the contact type. (Removed an unreachable
        # 'return self.form_class' that followed this if/else - both branches
        # already return.)
        if self.object.phone is not None:
            return PhoneContactForm
        else:
            return EmailContactForm

    def get_form(self, form_class=None):
        # Disable the address/number field so a verified value cannot be
        # edited (which would bypass the verification process).
        form = super(EditContactView, self).get_form(form_class)
        if isinstance(form, PhoneContactForm):
            form.fields['phone'].disabled = True
        elif isinstance(form, EmailContactForm):
            form.fields['email'].disabled = True
        return form

    def get_queryset(self):
        # Users may only edit their own contacts.
        return self.request.user.contact_set.all()
@method_decorator(internal_user_required, name='dispatch')
class DeleteContactView(DeleteView):
    """Delete a contact"""
    success_url = reverse_lazy('alerts:index')

    def get(self, request, *args, **kwargs):
        # Override this so that we don't require a confirmation page
        # for deletion
        return self.delete(request, *args, **kwargs)

    def form_valid(self, form, request, *args, **kwargs):
        # Get contact
        self.object = self.get_object()
        # Contacts can only be deleted if they aren't part of a notification -
        # this will prevent cases where a user creates a notification, deletes
        # the related contact(s), and then wonders why they aren't getting
        # any notifications.
        if self.object.notification_set.exists():
            messages.error(request, ('Contact "{cname}" cannot be deleted '
                'because it is part of a notification. Remove it from the '
                'notification or delete the notification first.').format(
                cname=self.object.description))
            return HttpResponseRedirect(reverse('alerts:index'))
        # Otherwise, delete the contact and show a corresponding message.
        # NOTE(review): is_deleted is set on the instance immediately before a
        # hard delete() - presumably consumed by a pre-delete signal or an
        # overridden Contact.delete(); confirm against the model.
        self.object.is_deleted = True
        self.object.delete()
        messages.info(request, 'Contact "{cname}" has been deleted.'.format(
            cname=self.object.description))
        return HttpResponseRedirect(self.get_success_url())

    def get_queryset(self):
        # Queryset should only contain the user's contacts
        return self.request.user.contact_set.all()
@method_decorator(internal_user_required, name='dispatch')
class TestContactView(DetailView):
    """Test a contact (must be verified already)"""
    # Send alerts to all contact methods
    success_url = reverse_lazy('alerts:index')

    def get_queryset(self):
        # Only the requesting user's own contacts can be tested.
        return self.request.user.contact_set.all()

    def get(self, request, *args, **kwargs):
        # Sends a test message to every delivery method (email and/or phone)
        # configured on the contact, then redirects back to the index.
        self.object = self.get_object()
        # Handle case where contact is not verified
        if not self.object.verified:
            msg = ('Contact "{desc}" must be verified before it can be '
                'tested.').format(desc=self.object.description)
            messages.info(request, msg)
            return HttpResponseRedirect(self.success_url)
        # Send test notifications
        msg = 'This is a test of contact "{desc}" from {host}.'.format(
            desc=self.object.description, host=settings.LIGO_FQDN)
        if self.object.email:
            subject = 'Test of contact "{desc}" from {host}'.format(
                desc=self.object.description, host=settings.LIGO_FQDN)
            if settings.ENABLE_EGAD_EMAIL:
                # Route the test email through the EGAD alert service.
                payload = {
                    "recipients": [self.object.email],
                    "subject": subject,
                    "body": msg,
                }
                egad.send_alert("email", payload)
            else:
                # Send directly via Django's email backend.
                email = EmailMessage(subject, msg,
                    from_email=settings.ALERT_EMAIL_FROM, to=[self.object.email])
                email.send()
        if self.object.phone:
            # Construct URL of TwiML bin
            twiml_url = '{base}{twiml_bin}'.format(
                base=settings.TWIML_BASE_URL,
                twiml_bin=settings.TWIML_BIN['test'])
            if settings.ENABLE_EGAD_PHONE:
                # Route the test call/text through the EGAD alert service.
                payload = {
                    "contacts": [{
                        "phone_method": self.object.phone_method,
                        "phone_number": self.object.phone,
                    }],
                    "message": msg,
                    "twiml_url": twiml_url,
                }
                egad.send_alert("phone", payload)
            else:
                # Get "from" phone number.
                from_ = get_twilio_from()
                # Send test call
                if (self.object.phone_method == Contact.CONTACT_PHONE_CALL or
                        self.object.phone_method == Contact.CONTACT_PHONE_BOTH):
                    # Make call
                    twilio_client.calls.create(to=self.object.phone, from_=from_,
                        url=twiml_url, method='GET')
                # "BOTH" contacts receive a call and a text message.
                if (self.object.phone_method == Contact.CONTACT_PHONE_TEXT or
                        self.object.phone_method == Contact.CONTACT_PHONE_BOTH):
                    twilio_client.messages.create(to=self.object.phone,
                        from_=from_, body=msg)
        # Message for web view
        messages.info(request, 'Testing contact "{desc}".'.format(
            desc=self.object.description))
        return HttpResponseRedirect(self.success_url)
@method_decorator(internal_user_required, name='dispatch')
class VerifyContactView(UpdateView):
    """Request a verification code or verify a contact"""
    template_name = 'alerts/verify_contact.html'
    form_class = VerifyContactForm
    success_url = reverse_lazy('alerts:index')

    def get_queryset(self):
        # Restrict to contacts owned by the requesting user.
        return self.request.user.contact_set.all()

    def form_valid(self, form):
        # Mark the contact verified and flash a confirmation.
        self.object.verify()
        messages.info(
            self.request,
            'Contact "{cname}" successfully verified.'.format(
                cname=self.object.description))
        return super(VerifyContactView, self).form_valid(form)

    def get_context_data(self, **kwargs):
        context = super(VerifyContactView, self).get_context_data(**kwargs)
        # Flag an expired verification code so the template can say so.
        code = self.object.verification_code
        if (code is not None
                and timezone.now() > self.object.verification_expiration):
            context['code_expired'] = True
        return context
@method_decorator(internal_user_required, name='dispatch')
class RequestVerificationCodeView(DetailView):
    """Redirect view for requesting a contact verification code"""

    def get_queryset(self):
        # Only the requesting user's own contacts are reachable.
        return self.request.user.contact_set.all()

    def get(self, request, *args, **kwargs):
        self.object = self.get_object()
        # Nothing to do for a contact that is already verified.
        if self.object.verified:
            messages.info(
                request,
                'Contact "{desc}" is already verified.'.format(
                    desc=self.object.description))
            return HttpResponseRedirect(reverse('alerts:index'))
        # Generate a fresh code and deliver it to the contact.
        self.object.generate_verification_code()
        self.object.send_verification_code()
        messages.info(request, "Verification code sent.")
        return HttpResponseRedirect(
            reverse('alerts:verify-contact', args=[self.object.pk]))
from __future__ import absolute_import
import copy
import logging
import os
import simplejson
import socket
import sys
from django.core.mail import EmailMessage
from django.conf import settings
from xml.sax.saxutils import escape
from datetime import datetime, timezone
from hashlib import sha1
from core.time_utils import gpsToUtc
from events.permission_utils import is_external
from events.shortcuts import is_event
from superevents.shortcuts import is_superevent
from .lvalert import send_with_lvalert_overseer, send_with_kafka_client
from . import egad
# Set up logger
logger = logging.getLogger(__name__)
def get_xmpp_node_names(event_or_superevent):
    """
    Utility function for determining the names of nodes to which XMPP
    notifications should be sent. Accepts an event or superevent object as the
    sole argument; raises TypeError for anything else.
    """
    if is_superevent(event_or_superevent):
        # Superevents map to a single node by production/MDC/test status.
        s = event_or_superevent
        if s.is_production():
            return ['superevent']
        if s.is_mdc():
            return ['mdc_superevent']
        return ['test_superevent']

    if is_event(event_or_superevent):
        # Node name format is group_pipeline or group_pipeline_search
        # (lowercased). NOTE: for test events, group=Test.
        event = event_or_superevent
        base = '{group}_{pipeline}'.format(
            group=event.group.name, pipeline=event.pipeline.name).lower()
        nodes = [base]
        # If a search is set (and search topics are enabled), alerts also go
        # to the search-specific node.
        if event.search and settings.SEND_TO_SEARCH_TOPICS:
            nodes.append(base + '_' + event.search.name.lower())
        return nodes

    # Neither an event nor a superevent: log and raise.
    error_msg = ('Object is of {0} type; should be an event '
                 'or superevent').format(type(event_or_superevent))
    logger.error(error_msg)
    raise TypeError(error_msg)
def issue_xmpp_alerts_local(event_or_superevent, alert_type, serialized_object,
        serialized_parent=None):
    """
    Send an igwn-alert/XMPP notification for an event or superevent via
    LVAlert Overseer, falling back to the direct igwn-alert client.

    serialized_object should be a dict; serialized_parent, if given, is
    attached under the 'object' key of the alert payload.
    """
    # Check settings switch for turning off XMPP alerts
    if not settings.SEND_XMPP_ALERTS:
        return
    # FIXME: quarantine detchar and hardwareinjection events for now
    if (is_event(event_or_superevent) and
            (event_or_superevent.group.name == 'Detchar' or
            event_or_superevent.pipeline.name == 'HardwareInjection')):
        return
    # Determine LVAlert node names
    node_names = get_xmpp_node_names(event_or_superevent)
    # Get uid
    uid = event_or_superevent.graceid
    # Create the output dictionary and serialize as JSON.
    lva_data = {
        'uid': uid,
        'alert_type': alert_type,
        'dispatched': f'{datetime.now(timezone.utc):%Y-%m-%d %H:%M:%S %Z}',
        'data': serialized_object,
    }
    # Add serialized "parent" object
    if serialized_parent is not None:
        lva_data['object'] = serialized_parent
    # Dump to JSON format:
    # simplejson.dumps is needed to properly handle Decimal fields
    msg = simplejson.dumps(lva_data)
    # Try 'escaping' the message:
    msg = escape(msg)
    # Log message for debugging
    logger.info("issue_xmpp_alerts: sending message {msg} for {uid}" \
        .format(msg=msg, uid=uid))
    # Loop over LVAlert servers and nodes, issuing the alert to each.
    # NOTE(review): instances are iterated in reverse ([::-1]); the reason is
    # not apparent here - confirm whether ordering matters.
    for overseer_instance in settings.LVALERT_OVERSEER_INSTANCES[::-1]:
        server = overseer_instance.get('lvalert_server')
        port = overseer_instance.get('listen_port')
        for node_name in node_names:
            # Calculate unique message_id and log
            message_id = sha1((node_name + msg).encode()).hexdigest()
            # Log message (NOTE(review): prefix says "issue_kafka_alerts"
            # although this is the xmpp path - kept as-is)
            logger.info(("issue_kafka_alerts: sending alert type {alert_type} "
                "with message {msg_id} for {uid} to {node} on {server}").format(
                alert_type=alert_type, msg_id=message_id, uid=uid,
                node=node_name, server=server))
            # Try to send with LVAlert Overseer (if enabled)
            success = False
            if settings.USE_LVALERT_OVERSEER:
                # Send with LVAlert Overseer
                success = send_with_lvalert_overseer(node_name, msg, port)
                # If not success, we need to do this the old way.
                if not success:
                    logger.critical(("issue_kafka_alerts: sending message with "
                        "Overseer failed, trying igwn-alert client code"))
            # If not using Overseer or if sending with overseer failed,
            # use basic igwn-alert client send
            if (not settings.USE_LVALERT_OVERSEER) or (not success):
                try:
                    # Make a settings dictionary and then change some names:
                    lvalert_settings_dict = copy.deepcopy(overseer_instance)
                    port = lvalert_settings_dict.pop('listen_port')
                    server = lvalert_settings_dict.pop('lvalert_server')
                    lvalert_settings_dict['group'] = lvalert_settings_dict.pop('igwn_alert_group')
                    send_with_kafka_client(node_name, msg, server,
                        **lvalert_settings_dict)
                except Exception as e:
                    # Best-effort send: log the failure and continue with the
                    # remaining nodes/instances.
                    logger.critical(("issue_kafka_alerts: error sending "
                        "message with igwn-alert client: {e}").format(e=e))
def issue_xmpp_alerts_egad(event_or_superevent, alert_type, serialized_object,
        serialized_parent=None):
    """
    Send an alert for an event or superevent by handing the serialized
    payload to the EGAD service over its "kafka" transport.

    serialized_object should be a dict; serialized_parent, if given, is
    attached under the 'object' key of the alert payload.
    """
    # Check settings switch for turning off XMPP alerts
    if not settings.SEND_XMPP_ALERTS:
        return
    # NOTE(review): unlike issue_xmpp_alerts_local, this path does not
    # quarantine Detchar/HardwareInjection events - confirm intentional.
    # Determine LVAlert node names
    node_names = get_xmpp_node_names(event_or_superevent)
    # Get uid
    uid = event_or_superevent.graceid
    # Create the output dictionary and serialize as JSON.
    lva_data = {
        'uid': uid,
        'alert_type': alert_type,
        'dispatched': f'{datetime.now(timezone.utc):%Y-%m-%d %H:%M:%S %Z}',
        'data': serialized_object,
    }
    # Add serialized "parent" object
    if serialized_parent is not None:
        lva_data['object'] = serialized_parent
    # Dump to JSON format:
    # simplejson.dumps is needed to properly handle Decimal fields
    msg = simplejson.dumps(lva_data)
    # Try 'escaping' the message:
    msg = escape(msg)
    # Log message for debugging
    logger.info("issue_xmpp_alerts: sending message {msg} for {uid}" \
        .format(msg=msg, uid=uid))
    # EGAD publishes one message to all topics at once.
    payload = {
        "topics": node_names,
        "message": msg,
    }
    egad.send_alert("kafka", payload)
# Select the alert dispatch implementation at import time: route through the
# EGAD service when Kafka support is enabled, otherwise send directly via
# LVAlert Overseer / the igwn-alert client.
if settings.ENABLE_EGAD_KAFKA:
    issue_xmpp_alerts = issue_xmpp_alerts_egad
else:
    issue_xmpp_alerts = issue_xmpp_alerts_local
# See the VOEvent specification for details
# http://www.ivoa.net/Documents/latest/VOEvent.html
import datetime
import logging
import os
from scipy.constants import c, G, pi
import voeventparse as vp
from django.conf import settings
from django.urls import reverse
from core.time_utils import gpsToUtc
from core.urls import build_absolute_uri
from events.models import VOEventBase, Event
from events.models import CoincInspiralEvent, MultiBurstEvent, \
LalInferenceBurstEvent, MLyBurstEvent
from superevents.shortcuts import is_superevent
# Set up logger
logger = logging.getLogger(__name__)
###############################################################################
# SETUP #######################################################################
###############################################################################
# Dict of VOEvent type abbreviations and full strings
VOEVENT_TYPE_DICT = dict(VOEventBase.VOEVENT_TYPE_CHOICES)
# Used to create the Packet_Type parameter block
# Note: order matters. The order of this dict is the
# same as VOEVENT_TYPE_DICT.
# Maps each VOEvent type to its (GCN packet type number, GCN notice label).
PACKET_TYPES = {
    VOEventBase.VOEVENT_TYPE_PRELIMINARY: (150, 'LVC_PRELIMINARY'),
    VOEventBase.VOEVENT_TYPE_INITIAL: (151, 'LVC_INITIAL'),
    VOEventBase.VOEVENT_TYPE_UPDATE: (152, 'LVC_UPDATE'),
    VOEventBase.VOEVENT_TYPE_RETRACTION: (164, 'LVC_RETRACTION'),
    VOEventBase.VOEVENT_TYPE_EARLYWARNING: (163, 'LVC_EARLY_WARNING'),
}
# Description strings
DEFAULT_DESCRIPTION = \
    "Candidate gravitational wave event identified by low-latency analysis"
# Human-readable descriptions keyed by interferometer prefix; used in the
# VOEvent "How" section for the instruments involved in the analysis.
INSTRUMENT_DESCRIPTIONS = {
    "H1": "H1: LIGO Hanford 4 km gravitational wave detector",
    "L1": "L1: LIGO Livingston 4 km gravitational wave detector",
    "V1": "V1: Virgo 3 km gravitational wave detector",
    "K1": "K1: KAGRA 3 km gravitational wave detector"
}
###############################################################################
# MAIN ########################################################################
###############################################################################
def construct_voevent_file(obj, voevent, request=None):
# Setup ###################################################################
## Determine event or superevent
obj_is_superevent = False
if is_superevent(obj):
obj_is_superevent = True
event = obj.preferred_event
graceid = obj.default_superevent_id
obj_view_name = "superevents:view"
fits_view_name = "api:default:superevents:superevent-file-detail"
else:
event = obj
graceid = obj.graceid
obj_view_name = "view"
fits_view_name = "api:default:events:files"
# Get the event subclass (CoincInspiralEvent, MultiBurstEvent, etc.) and
# set that as the event
event = event.get_subclass_or_self()
## Let's convert that voevent_type to something nicer looking
voevent_type = VOEVENT_TYPE_DICT[voevent.voevent_type]
## Now build the IVORN.
if voevent_type == 'earlywarning':
type_string = 'EarlyWarning'
else:
type_string = voevent_type.capitalize()
voevent_id = '{gid}-{N}-{type_str}'.format(type_str=type_string,
gid=graceid, N=voevent.N)
## Determine role
if event.is_mdc() or event.is_test():
role = vp.definitions.roles.test
else:
role = vp.definitions.roles.observation
## Instantiate VOEvent
v = vp.Voevent(settings.VOEVENT_STREAM, voevent_id, role)
## Set root Description
if voevent_type != 'retraction':
v.Description = "Report of a candidate gravitational wave event"
# Overwrite the description for early warning events:
if voevent_type == 'earlywarning':
v.Description = "Early warning report of a candidate gravitational wave event"
# Who #####################################################################
## Remove Who.Description
v.Who.remove(v.Who.Description)
## Set Who.Date
vp.set_who(
v,
date=datetime.datetime.utcnow()
)
v.Who.Date += 'Z'
## Set Who.Author
vp.set_author(
v,
contactName="LIGO Scientific Collaboration, Virgo Collaboration, and KAGRA Collaboration"
)
# How #####################################################################
if voevent_type != 'retraction':
descriptions = [DEFAULT_DESCRIPTION]
# Add instrument descriptions
instruments = event.instruments.split(',')
for inst in INSTRUMENT_DESCRIPTIONS:
if inst in instruments:
descriptions.append(INSTRUMENT_DESCRIPTIONS[inst])
if voevent.coinc_comment:
descriptions.append("A gravitational wave trigger identified a "
"possible counterpart GRB")
vp.add_how(v, descriptions=descriptions)
# What ####################################################################
# UCD = Unified Content Descriptors
# http://monet.uni-sw.gwdg.de/twiki/bin/view/VOEvent/UnifiedContentDescriptors
# OR -- (from VOTable document, [21] below)
# http://www.ivoa.net/twiki/bin/view/IVOA/IvoaUCD
# http://cds.u-strasbg.fr/doc/UCD.htx
#
# which somehow gets you to:
# http://www.ivoa.net/Documents/REC/UCD/UCDlist-20070402.html
# where you might find some actual information.
# Unit / Section 4.3 of [21] which relies on [25]
# [21] http://www.ivoa.net/Documents/latest/VOT.html
# [25] http://vizier.u-strasbg.fr/doc/catstd-3.2.htx
#
# Basically, a string that makes sense to humans about what units a value
# is. eg. "m/s"
## Packet_Type param
p_packet_type = vp.Param(
"Packet_Type",
value=PACKET_TYPES[voevent.voevent_type][0],
ac=True
)
p_packet_type.Description = ("The Notice Type number is assigned/used "
"within GCN, eg type={typenum} is an {typedesc} notice").format(
typenum=PACKET_TYPES[voevent.voevent_type][0],
typedesc=PACKET_TYPES[voevent.voevent_type][1]
)
v.What.append(p_packet_type)
# Internal param
p_internal = vp.Param(
"internal",
value=int(voevent.internal),
ac=True
)
p_internal.Description = ("Indicates whether this event should be "
"distributed to LSC/Virgo/KAGRA members only")
v.What.append(p_internal)
## Packet serial number
p_serial_num = vp.Param(
"Pkt_Ser_Num",
value=voevent.N,
ac=True
)
p_serial_num.Description = ("A number that increments by 1 each time a "
"new revision is issued for this event")
v.What.append(p_serial_num)
## Event graceid or superevent ID
p_gid = vp.Param(
"GraceID",
value=graceid,
ucd="meta.id",
dataType="string"
)
p_gid.Description = "Identifier in GraceDB"
v.What.append(p_gid)
## Alert type parameter
if voevent_type == 'earlywarning':
voevent_at = 'EarlyWarning'
else:
voevent_at = voevent_type.capitalize()
p_alert_type = vp.Param(
"AlertType",
value = voevent_at,
ucd="meta.version",
dataType="string"
)
p_alert_type.Description = "VOEvent alert type"
v.What.append(p_alert_type)
## Whether the event is a hardware injection or not
p_hardware_inj = vp.Param(
"HardwareInj",
value=int(voevent.hardware_inj),
ucd="meta.number",
ac=True
)
p_hardware_inj.Description = ("Indicates that this event is a hardware "
"injection if 1, no if 0")
v.What.append(p_hardware_inj)
## Open alert parameter
p_open_alert = vp.Param(
"OpenAlert",
value=int(voevent.open_alert),
ucd="meta.number",
ac=True
)
p_open_alert.Description = ("Indicates that this event is an open alert "
"if 1, no if 0")
v.What.append(p_open_alert)
## Superevent page
p_detail_url = vp.Param(
"EventPage",
value=build_absolute_uri(
reverse(obj_view_name, args=[graceid]),
request
),
ucd="meta.ref.url",
dataType="string"
)
p_detail_url.Description = ("Web page for evolving status of this GW "
"candidate")
v.What.append(p_detail_url)
## Only for non-retractions
if voevent_type != 'retraction':
## Instruments
p_instruments = vp.Param(
"Instruments",
value=event.instruments,
ucd="meta.code",
dataType="string"
)
p_instruments.Description = ("List of instruments used in analysis to "
"identify this event")
v.What.append(p_instruments)
## False alarm rate
if event.far:
p_far = vp.Param(
"FAR",
value=float(max(event.far, settings.VOEVENT_FAR_FLOOR)),
ucd="arith.rate;stat.falsealarm",
unit="Hz",
ac=True
)
p_far.Description = ("False alarm rate for GW candidates with "
"this strength or greater")
v.What.append(p_far)
## Whether this is a significant candidate or not
p_significant = vp.Param(
"Significant",
value=int(voevent.significant),
ucd="meta.number",
ac=True
)
p_significant.Description = ("Indicates that this event is significant if "
"1, no if 0")
v.What.append(p_significant)
## Analysis group
## Special case: BURST-CWB-BBH search is a CBC group alert.
if (event.group.name == "Burst" and
(event.search and (event.pipeline.name == 'CWB' and
event.search.name == 'BBH'))
):
p_group = vp.Param(
"Group",
value="CBC",
ucd="meta.code",
dataType="string"
)
else:
p_group = vp.Param(
"Group",
value=event.group.name,
ucd="meta.code",
dataType="string"
)
p_group.Description = "Data analysis working group"
v.What.append(p_group)
## Analysis pipeline
p_pipeline = vp.Param(
"Pipeline",
value=event.pipeline.name,
ucd="meta.code",
dataType="string"
)
p_pipeline.Description = "Low-latency data analysis pipeline"
v.What.append(p_pipeline)
## Search type
if event.search:
p_search = vp.Param(
"Search",
value=event.search.name,
ucd="meta.code",
dataType="string"
)
p_search.Description = "Specific low-latency search"
v.What.append(p_search)
## RAVEN specific entries
if (is_superevent(obj) and voevent.raven_coinc):
ext_id = obj.em_type
ext_event = Event.getByGraceid(ext_id)
emcoinc_params = []
## External GCN ID
if ext_event.trigger_id:
p_extid = vp.Param(
"External_GCN_Notice_Id",
value=ext_event.trigger_id,
ucd="meta.id",
dataType="string"
)
p_extid.Description = ("GCN trigger ID of external event")
emcoinc_params.append(p_extid)
## External IVORN
if ext_event.ivorn:
p_extivorn = vp.Param(
"External_Ivorn",
value=ext_event.ivorn,
ucd="meta.id",
dataType="string"
)
p_extivorn.Description = ("IVORN of external event")
emcoinc_params.append(p_extivorn)
## External Pipeline
if ext_event.pipeline:
p_extpipeline = vp.Param(
"External_Observatory",
value=ext_event.pipeline.name,
ucd="meta.code",
dataType="string"
)
p_extpipeline.Description = ("External Observatory")
emcoinc_params.append(p_extpipeline)
## External Search
if ext_event.search:
p_extsearch = vp.Param(
"External_Search",
value=ext_event.search.name,
ucd="meta.code",
dataType="string"
)
p_extsearch.Description = ("External astrophysical search")
emcoinc_params.append(p_extsearch)
## Time Difference
if ext_event.gpstime and obj.t_0:
deltat = round(ext_event.gpstime - obj.t_0, 2)
p_deltat = vp.Param(
"Time_Difference",
value=float(deltat),
ucd="meta.code",
ac=True,
)
p_deltat.Description = ("Time difference between GW candidate "
"and external event, centered on the "
"GW candidate")
emcoinc_params.append(p_deltat)
## Temporal Coinc FAR
if obj.time_coinc_far:
p_coincfar = vp.Param(
"Time_Coincidence_FAR",
value=obj.time_coinc_far,
ucd="arith.rate;stat.falsealarm",
ac=True,
unit="Hz"
)
p_coincfar.Description = ("Estimated coincidence false alarm "
"rate in Hz using timing")
emcoinc_params.append(p_coincfar)
## Spatial-Temporal Coinc FAR
if obj.space_coinc_far:
p_coincfar_space = vp.Param(
"Time_Sky_Position_Coincidence_FAR",
value=obj.space_coinc_far,
ucd="arith.rate;stat.falsealarm",
ac=True,
unit="Hz"
)
p_coincfar_space.Description = ("Estimated coincidence false alarm "
"rate in Hz using timing and sky "
"position")
emcoinc_params.append(p_coincfar_space)
## RAVEN combined sky map
if voevent.combined_skymap_filename:
## Skymap group
### fits skymap URL
fits_skymap_url_comb = build_absolute_uri(
reverse(fits_view_name, args=[graceid,
voevent.combined_skymap_filename]),
request
)
p_fits_url_comb = vp.Param(
"joint_skymap_fits",
value=fits_skymap_url_comb,
ucd="meta.ref.url",
dataType="string"
)
p_fits_url_comb.Description = "Combined GW-External Sky Map FITS"
emcoinc_params.append(p_fits_url_comb)
## Create EMCOINC group
emcoinc_group = vp.Group(
emcoinc_params,
name='External Coincidence',
type='External Coincidence' # keep this only for backwards compatibility
)
emcoinc_group.Description = \
("Properties of joint coincidence found by RAVEN")
v.What.append(emcoinc_group)
# initial and update VOEvents must have a skymap.
# new feature (10/24/2016): preliminary VOEvents can have a skymap,
# but they don't have to.
if (voevent_type in ["initial", "update"] or
(voevent_type in ["preliminary", "earlywarning"] and voevent.skymap_filename != None)):
## Skymap group
### fits skymap URL
fits_skymap_url = build_absolute_uri(
reverse(fits_view_name, args=[graceid, voevent.skymap_filename]),
request
)
p_fits_url = vp.Param(
"skymap_fits",
value=fits_skymap_url,
ucd="meta.ref.url",
dataType="string"
)
p_fits_url.Description = "Sky Map FITS"
### Create skymap group with params
skymap_group = vp.Group(
[p_fits_url],
name="GW_SKYMAP",
type="GW_SKYMAP",
)
### Add to What
v.What.append(skymap_group)
## Analysis specific attributes
if voevent_type != 'retraction':
### Classification group (EM-Bright params; CBC only)
### In should also be present in case of cWB-BBH
em_bright_params = []
source_properties_params = []
if ( (isinstance(event, CoincInspiralEvent) or
(event.search and (event.pipeline.name == 'CWB' and
event.search.name == 'BBH'))
) and voevent_type != 'retraction'):
# EM-Bright mass classifier information for CBC event candidates
if voevent.prob_bns is not None:
p_pbns = vp.Param(
"BNS",
value=voevent.prob_bns,
ucd="stat.probability",
ac=True
)
p_pbns.Description = \
("Probability that the source is a binary neutron star "
"merger (both objects lighter than 3 solar masses)")
em_bright_params.append(p_pbns)
if voevent.prob_nsbh is not None:
p_pnsbh = vp.Param(
"NSBH",
value=voevent.prob_nsbh,
ucd="stat.probability",
ac=True
)
p_pnsbh.Description = \
("Probability that the source is a neutron star-black "
"merger (secondary lighter than 3 solar masses)")
em_bright_params.append(p_pnsbh)
if voevent.prob_bbh is not None:
p_pbbh = vp.Param(
"BBH",
value=voevent.prob_bbh,
ucd="stat.probability",
ac=True
)
p_pbbh.Description = ("Probability that the source is a "
"binary black hole merger (both objects "
"heavier than 3 solar masses)")
em_bright_params.append(p_pbbh)
#if voevent.prob_mass_gap is not None:
# p_pmassgap = vp.Param(
# "MassGap",
# value=voevent.prob_mass_gap,
# ucd="stat.probability",
# ac=True
# )
# p_pmassgap.Description = ("Probability that the source has at "
# "least one object between 3 and 5 "
# "solar masses")
# em_bright_params.append(p_pmassgap)
if voevent.prob_terrestrial is not None:
p_pterr = vp.Param(
"Terrestrial",
value=voevent.prob_terrestrial,
ucd="stat.probability",
ac=True
)
p_pterr.Description = ("Probability that the source is "
"terrestrial (i.e., a background noise "
"fluctuation or a glitch)")
em_bright_params.append(p_pterr)
# Add to source properties group
if voevent.prob_has_ns is not None:
p_phasns = vp.Param(
name="HasNS",
value=voevent.prob_has_ns,
ucd="stat.probability",
ac=True
)
p_phasns.Description = ("Probability that at least one object "
"in the binary has a mass that is "
"less than 3 solar masses")
source_properties_params.append(p_phasns)
if voevent.prob_has_remnant is not None:
p_phasremnant = vp.Param(
"HasRemnant",
value=voevent.prob_has_remnant,
ucd="stat.probability",
ac=True
)
p_phasremnant.Description = ("Probability that a nonzero mass "
"was ejected outside the central "
"remnant object")
source_properties_params.append(p_phasremnant)
if voevent.prob_has_mass_gap is not None:
p_pmassgap = vp.Param(
"HasMassGap",
value=voevent.prob_has_mass_gap,
ucd="stat.probability",
ac=True
)
p_pmassgap.Description = ("Probability that the source has at "
"least one object between 3 and 5 "
"solar masses")
source_properties_params.append(p_pmassgap)
if voevent.prob_has_ssm is not None:
p_phasssm = vp.Param(
"HasSSM",
value=voevent.prob_has_ssm,
ucd="stat.probability",
ac=True
)
p_phasssm.Description = ("Probability that the source has at "
"least one object less than 1 "
"solar mass")
source_properties_params.append(p_phasssm)
elif isinstance(event, MultiBurstEvent):
### Central frequency
p_central_freq = vp.Param(
"CentralFreq",
value=float(event.central_freq),
ucd="gw.frequency",
unit="Hz",
ac=True,
)
p_central_freq.Description = \
"Central frequency of GW burst signal"
v.What.append(p_central_freq)
### Duration
p_duration = vp.Param(
"Duration",
value=float(event.duration),
unit="s",
ucd="time.duration",
ac=True,
)
p_duration.Description = "Measured duration of GW burst signal"
v.What.append(p_duration)
elif isinstance(event, LalInferenceBurstEvent):
p_freq = vp.Param(
"CentralFreq",
value=float(event.frequency_mean),
ucd="gw.frequency",
unit="Hz",
ac=True,
)
p_freq.Description = "Central frequency of GW burst signal"
v.What.append(p_freq)
duration = event.quality_mean / (2 * pi * event.frequency_mean)
p_duration = vp.Param(
"Duration",
value=float(duration),
unit="s",
ucd="time.duration",
ac=True,
)
p_duration.Description = "Measured duration of GW burst signal"
v.What.append(p_duration)
elif isinstance(event, MLyBurstEvent):
p_central_freq = vp.Param(
"CentralFreq",
value=float(event.central_freq),
ucd="gw.frequency",
unit="Hz",
ac=True,
)
p_central_freq.Description = \
"Central frequency of GW burst signal"
v.What.append(p_central_freq)
p_duration = vp.Param(
"Duration",
value=float(event.duration),
unit="s",
ucd="time.duration",
ac=True,
)
p_duration.Description = "Measured duration of GW burst signal"
v.What.append(p_duration)
## Create classification group
classification_group = vp.Group(
em_bright_params,
name='Classification',
type='Classification' # keep this only for backwards compatibility
)
classification_group.Description = \
("Source classification: binary neutron star (BNS), neutron star-"
"black hole (NSBH), binary black hole (BBH), or "
"terrestrial (noise)")
v.What.append(classification_group)
## Create properties group
properties_group = vp.Group(
source_properties_params,
name='Properties',
type='Properties' # keep this only for backwards compatibility
)
properties_group.Description = \
("Qualitative properties of the source, conditioned on the "
"assumption that the signal is an astrophysical compact binary "
"merger")
v.What.append(properties_group)
# WhereWhen ###############################################################
# NOTE: we use a fake ra, dec, err, and units for creating the coords
# object. We are required to provide them by the voeventparse code, but
# our "format" for VOEvents didn't have a Position2D entry. So to make
# the code work but maintain the same format, we add fake information here,
# then remove it later.
coords = vp.Position2D(
ra=1, dec=2, err=3, units='degrees',
system=vp.definitions.sky_coord_system.utc_fk5_geo
)
observatory_id = 'LIGO Virgo'
vp.add_where_when(
v,
coords,
gpsToUtc(event.gpstime),
observatory_id
)
v.WhereWhen.ObsDataLocation.ObservationLocation.AstroCoords.Time.TimeInstant.ISOTime += 'Z'
# NOTE: now remove position 2D so the fake ra, dec, err, and units
# don't show up.
ol = v.WhereWhen.ObsDataLocation.ObservationLocation
ol.AstroCoords.remove(ol.AstroCoords.Position2D)
# Citations ###############################################################
if obj.voevent_set.count() > 1:
## Loop over previous VOEvents for this event or superevent and
## add them to citations
event_ivorns_list = []
for ve in obj.voevent_set.all():
# Oh, actually we need to exclude *this* voevent.
if ve.N == voevent.N:
continue
# Get cite type
if voevent_type == 'retraction':
cite_type = vp.definitions.cite_types.retraction
else:
cite_type = vp.definitions.cite_types.supersedes
# Set up event ivorn
ei = vp.EventIvorn(ve.ivorn, cite_type)
# Add event ivorn
event_ivorns_list.append(ei)
# Add citations
vp.add_citations(
v,
event_ivorns_list
)
# Get description for citation
desc = None
if voevent_type == 'preliminary':
desc = 'Initial localization is now available (preliminary)'
elif voevent_type == 'initial':
desc = 'Initial localization is now available'
elif voevent_type == 'update':
desc = 'Updated localization is now available'
elif voevent_type == 'retraction':
desc = 'Determined to not be a viable GW event candidate'
elif voevent_type == 'earlywarning':
desc = 'Early warning localization is now available'
if desc is not None:
v.Citations.Description = desc
# Return the document as a string, along with the IVORN ###################
xml = vp.dumps(v, pretty_print=True)
return xml, v.get('ivorn')
import base64
import logging
import OpenSSL.crypto
import OpenSSL.SSL
import re
from django.contrib.auth import get_user_model, authenticate
from django.conf import settings
from django.contrib.auth.models import User
from django.http import HttpResponseForbidden
from django.utils import timezone
from django.utils.http import unquote
from django.utils.translation import gettext_lazy as _
from django.urls import resolve
from rest_framework import authentication, exceptions
from ligoauth.models import X509Cert
from .utils import is_api_request
import scitokens
from jwt import InvalidTokenError
from scitokens.utils.errors import SciTokensException
from urllib.parse import unquote_plus
# Set up logger
logger = logging.getLogger(__name__)
class GraceDbBasicAuthentication(authentication.BasicAuthentication):
    """HTTP Basic (username/password) authentication, restricted to the API.

    Differs from the DRF base class in two ways: requests outside the
    API URL space are ignored, and passwords are treated as expiring.
    """
    # Refuse AJAX requests so API calls issued by the web views cannot
    # authenticate through this backend.
    allow_ajax = False
    # Only handle requests directed at the API URL space.
    api_only = True

    def authenticate(self, request, *args, **kwargs):
        """
        Same as base class, except we require the request to be directed
        toward the basic auth API.
        """
        # Ignore anything that is not an API request.
        if self.api_only and not is_api_request(request.path):
            return None
        # Refuse AJAX requests (see allow_ajax above).
        is_ajax = request.headers.get('x-requested-with') == 'XMLHttpRequest'
        if is_ajax and not self.allow_ajax:
            return None
        # Delegate the actual credential check to DRF's implementation.
        return super(GraceDbBasicAuthentication, self).authenticate(request)

    def authenticate_credentials(self, userid, password, request=None):
        """
        Add a hacky password expiration check to the inherited method.
        """
        auth_tuple = super(GraceDbBasicAuthentication, self) \
            .authenticate_credentials(userid, password, request)
        user = auth_tuple[0]
        # NOTE: this is *super* hacky because date_joined actually stores
        # the time the password was set (see managePassword() in
        # userprofile.views), so it doubles as the start of the
        # expiration window.
        expires_at = user.date_joined + settings.PASSWORD_EXPIRATION_TIME
        if timezone.now() > expires_at:
            msg = ('Your password has expired. Please log in to the web '
                   'interface and request another.')
            raise exceptions.AuthenticationFailed(_(msg))
        return auth_tuple
class GraceDbSciTokenAuthentication(authentication.BasicAuthentication):
    """SciToken (bearer token) authentication for the API.

    Deserializes a SciToken from the Authorization header, enforces
    issuer/audience/scope from Django settings, and maps the token's
    'sub' claim onto a local user account. Returns None (rather than
    raising) for anything that is not a usable SciToken, so the next
    configured authentication backend gets a chance at the request.
    """

    class MultiIssuerEnforcer(scitokens.Enforcer):
        """scitokens.Enforcer variant that accepts any of several issuers."""

        def __init__(self, issuer, **kwargs):
            # Normalize a single issuer into a list so issuer validation
            # can always use a membership test.
            if not isinstance(issuer, (tuple, list)):
                issuer = [issuer]
            super().__init__(issuer, **kwargs)

        def _validate_iss(self, value):
            return value in self._issuer

    def authenticate(self, request, public_key=None):
        """Return (user, None) for a valid SciToken, else None.

        Raises AuthenticationFailed only for a valid token that maps to
        an inactive user.
        """
        # Get token from header
        try:
            bearer = request.headers["Authorization"]
        except KeyError:
            return None
        # Robustness fix: a malformed header (e.g. just "Bearer", or
        # extra whitespace-separated pieces) used to raise ValueError
        # from tuple unpacking; treat it as "no credentials" instead.
        parts = bearer.split()
        if len(parts) != 2:
            return None
        auth_type, serialized_token = parts
        if auth_type != "Bearer":
            return None
        # Deserialize token
        try:
            token = scitokens.SciToken.deserialize(
                serialized_token,
                # deserialize all tokens, enforce audience later
                audience={"ANY"} | set(settings.SCITOKEN_AUDIENCE),
                public_key=public_key,
            )
        except (InvalidTokenError, SciTokensException):
            return None
        # Enforce scitoken logic
        enforcer = self.MultiIssuerEnforcer(
            settings.SCITOKEN_ISSUER,
            audience=settings.SCITOKEN_AUDIENCE,
        )
        # SCITOKEN_SCOPE may be "authz" or "authz:path".
        try:
            authz, path = settings.SCITOKEN_SCOPE.split(":", 1)
        except ValueError:
            authz = settings.SCITOKEN_SCOPE
            path = None
        if not enforcer.test(token, authz, path):
            return None
        # Get username from token 'Subject' claim.
        try:
            user = User.objects.get(username=token['sub'].lower())
        except User.DoesNotExist:
            try:
                # Catch Kagra and robot accounts that don't have
                # @ligo.org usernames
                user = User.objects.get(
                    username=token['sub'].split('@')[0].lower())
            except User.DoesNotExist:
                return None
        if not user.is_active:
            raise exceptions.AuthenticationFailed(
                _('User inactive or deleted'))
        return (user, None)
class GraceDbX509Authentication(authentication.BaseAuthentication):
    """
    Authentication based on X509 certificate subject.
    Certificate should be verified by Apache already.

    This backend only maps the forwarded subject DN onto a local user
    via the X509Cert table; no cryptographic verification happens here.
    """
    # Refuse AJAX requests so browser-held certificates cannot silently
    # authenticate API calls made by the web views (see authenticate()).
    allow_ajax = False
    # Only handle requests directed at the API URL space.
    api_only = True
    www_authenticate_realm = 'api'
    # request.META keys where the SSL terminator places the subject and
    # issuer DNs; overridable via Django settings.
    subject_dn_header = getattr(settings, 'X509_SUBJECT_DN_HEADER',
        'SSL_CLIENT_S_DN')
    issuer_dn_header = getattr(settings, 'X509_ISSUER_DN_HEADER',
        'SSL_CLIENT_I_DN')
    # Captures a subject DN minus any trailing "/CN=<digits>" components
    # appended by impersonation proxies (see
    # extract_subject_from_proxied_cert).
    proxy_pattern = re.compile(r'^(.*?)(/CN=\d+)*$')

    def authenticate(self, request):
        """Return (user, None) for a recognized cert subject, else None."""
        # Make sure this request is directed to the API
        if self.api_only and not is_api_request(request.path):
            return None
        # Don't allow this auth type for AJAX requests - this is because
        # users with certificates in their browser can still authenticate via
        # this mechanism in the web view (since it makes API queries), even
        # when they are not logged in.
        #if request.is_ajax() and not self.allow_ajax:
        if request.headers.get('x-requested-with') == 'XMLHttpRequest' and not self.allow_ajax:
            return None
        # Try to get credentials from request headers.
        user_cert_dn = self.get_cert_dn_from_request(request)
        # If no user dn is found, pass on to the next auth method
        if not user_cert_dn:
            return None
        return self.authenticate_credentials(user_cert_dn)

    @classmethod
    def authenticate_header(cls, request):
        # Value used for the WWW-Authenticate header on 401 responses.
        return 'X509 realm="{0}"'.format(cls.www_authenticate_realm)

    @classmethod
    def get_cert_dn_from_request(cls, request):
        """Get SSL headers and return DN for user"""
        # Get subject and issuer DN from SSL headers
        certdn = request.META.get(cls.subject_dn_header, None)
        issuer = request.META.get(cls.issuer_dn_header, '')
        # Handled proxied certificates
        certdn = cls.extract_subject_from_proxied_cert(certdn, issuer)
        return certdn

    @classmethod
    def extract_subject_from_proxied_cert(cls, subject, issuer):
        """
        Handles the case of "impersonation proxies", where /CN=[0-9]+ is
        appended to the end of the certificate subject. This occurs when you
        generate a certificate and it "follows" you to another machine - you
        effectively self-sign a copy of the certificate to use on the other
        machine.

        Example:
        Albert generates a certificate with ligo-proxy-init on his laptop.
        Subject and issuer when he pings the GraceDB server from his laptop:
            /DC=org/DC=cilogon/C=US/O=LIGO/CN=Albert Einstein albert.einstein@ligo.org
            /DC=org/DC=cilogon/C=US/O=CILogon/CN=CILogon Basic CA 1
        Subject and issuer when he gsisshs to an LDG cluster and then pings the
        GraceDB server from there:
            /DC=org/DC=cilogon/C=US/O=LIGO/CN=Albert Einstein albert.einstein@ligo.org/CN=1492637212
            /DC=org/DC=cilogon/C=US/O=LIGO/CN=Albert Einstein albert.einstein@ligo.org
        If he then gsisshs to *another* machine from there and repeats this,
        he would get:
            /DC=org/DC=cilogon/C=US/O=LIGO/CN=Albert Einstein albert.einstein@ligo.org/CN=1492637212/CN=28732493
            /DC=org/DC=cilogon/C=US/O=LIGO/CN=Albert Einstein albert.einstein@ligo.org/CN=1492637212
        """
        # A proxied cert's subject starts with its issuer's full DN (see
        # examples above); only then strip the trailing /CN=<digits> parts.
        if subject and issuer and subject.startswith(issuer):
            # If we get here, we have an impersonation proxy, so we extract
            # the proxy /CN=12345... part from the subject. Could also
            # do it from the issuer (see above examples)
            subject = cls.proxy_pattern.match(subject).group(1)
        return subject

    def authenticate_credentials(self, user_cert_dn):
        """Map a certificate subject DN onto an active local user.

        Raises AuthenticationFailed for an unknown subject or an
        inactive user.
        """
        certs = X509Cert.objects.filter(subject=user_cert_dn)
        if not certs.exists():
            raise exceptions.AuthenticationFailed(_('Invalid certificate '
                'subject'))
        cert = certs.first()
        # Check if user is active
        user = cert.user
        if not user.is_active:
            raise exceptions.AuthenticationFailed(
                _('User inactive or deleted'))
        return (user, None)
class GraceDbX509CertInfosAuthentication(GraceDbX509Authentication):
    """
    Authentication based on X509 "infos" header.
    Certificate should be verified by Traefik already.

    Only the credential-extraction step differs from the parent class;
    the subject-to-user mapping is inherited.
    """
    allow_ajax = False
    api_only = True
    # request.META key where Traefik forwards the certificate info.
    infos_header = getattr(settings, 'X509_INFOS_HEADER',
        'HTTP_X_FORWARDED_TLS_CLIENT_CERT_INFOS')
    # Pulls the quoted Subject and Issuer strings out of the infos header.
    infos_pattern = re.compile(r'Subject="(.*?)".*Issuer="(.*?)"')

    @classmethod
    def get_cert_dn_from_request(cls, request):
        """Get SSL headers and return subject for user.

        Returns None when the header is absent or unparseable so the
        parent authenticate() falls through to the next backend.
        """
        # Get infos from request headers
        infos = request.META.get(cls.infos_header, None)
        # Bug fix: a missing header used to reach unquote_plus(None) and
        # raise TypeError; treat it as "no credentials" instead.
        if infos is None:
            return None
        # Unquote (handle pluses -> spaces)
        infos_unquoted = unquote_plus(infos)
        # Extract subject and issuer; bug fix: an unparseable header used
        # to raise AttributeError on the failed regex match.
        match = cls.infos_pattern.search(infos_unquoted)
        if match is None:
            return None
        subject, issuer = match.groups()
        # Convert formats (comma-separated -> slash-separated DNs)
        subject = cls.convert_format(subject)
        issuer = cls.convert_format(issuer)
        # Handled proxied certificates
        subject = cls.extract_subject_from_proxied_cert(subject, issuer)
        return subject

    @staticmethod
    def convert_format(s):
        # Convert subject or issuer strings from comma to slash format
        s = s.replace(',', '/')
        if not s.startswith('/'):
            s = '/' + s
        return s
class GraceDbX509FullCertAuthentication(GraceDbX509Authentication):
    """
    Authentication based on a full X509 certificate. We verify the
    certificate here.

    Unlike the parent class, the raw certificate is forwarded in a
    header and verified against the local CA store before its subject
    is mapped onto a user.
    """
    allow_ajax = False
    api_only = True
    www_authenticate_realm = 'api'
    # request.META key carrying the URL-quoted, base64 DER certificate.
    cert_header = getattr(settings, 'X509_CERT_HEADER',
        'HTTP_X_FORWARDED_TLS_CLIENT_CERT')

    def authenticate(self, request):
        """Verify a forwarded certificate and map its subject to a user."""
        # Make sure this request is directed to the API
        if self.api_only and not is_api_request(request.path):
            return None
        # Don't allow this auth type for AJAX requests - this is because
        # users with certificates in their browser can still authenticate via
        # this mechanism in the web view (since it makes API queries), even
        # when they are not logged in.
        #if request.is_ajax() and not self.allow_ajax:
        if request.headers.get('x-requested-with') == 'XMLHttpRequest' and not self.allow_ajax:
            return None
        # Try to get certificate from request headers
        cert_data = self.get_certificate_data_from_request(request)
        # If no certificate is found, abort
        if not cert_data:
            return None
        # Verify certificate
        try:
            certificate = self.verify_certificate_chain(cert_data)
        except exceptions.AuthenticationFailed as e:
            # Preserve specific failures (e.g. expired certificate).
            raise
        except Exception as e:
            # Any other verification error is reported generically.
            raise exceptions.AuthenticationFailed(_('Certificate could not be '
                'verified'))
        return self.authenticate_credentials(certificate)

    @classmethod
    def get_certificate_data_from_request(cls, request):
        """Get certificate data from request.

        Returns the DER-encoded certificate bytes, or None if the
        header is absent.
        """
        cert_quoted = request.META.get(cls.cert_header, None)
        if cert_quoted is None:
            return None
        # Process the certificate a bit: URL-unquote, then base64-decode
        # to raw DER bytes.
        cert_b64 = unquote(cert_quoted)
        cert_der = base64.b64decode(cert_b64)
        return cert_der

    def verify_certificate_chain(self, cert_data, capath=settings.CAPATH):
        """Verify DER certificate bytes against the CAs under capath.

        Returns the loaded certificate object; raises on verification
        failure or expiration.
        """
        # Load certificate data
        certificate = OpenSSL.crypto.load_certificate(
            OpenSSL.crypto.FILETYPE_ASN1, cert_data)
        # Set up context and get certificate store.
        # NOTE(review): TLSv1_METHOD is deprecated in recent pyOpenSSL
        # releases -- confirm against the pinned pyOpenSSL version. The
        # context is used only to load a cert store, not to make
        # connections.
        ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
        ctx.load_verify_locations(None, capath=capath)
        store = ctx.get_cert_store()
        # Verify certificate against the store; raises on failure.
        store_ctx = OpenSSL.crypto.X509StoreContext(store, certificate)
        store_ctx.verify_certificate()
        # Check if expired
        if certificate.has_expired():
            raise exceptions.AuthenticationFailed(_('Certificate has expired'))
        return certificate

    def authenticate_credentials(self, certificate):
        """Extract the subject DN and defer to the parent's DN lookup."""
        # Get subject and issuer
        subject = self.get_certificate_subject_string(certificate)
        issuer = self.get_certificate_issuer_string(certificate)
        # Handled proxied certificates
        subject = self.extract_subject_from_proxied_cert(subject, issuer)
        # Authenticate credentials
        return super(GraceDbX509FullCertAuthentication, self) \
            .authenticate_credentials(subject)

    @staticmethod
    def get_certificate_subject_string(certificate):
        # Render the subject as a slash-separated DN string
        # ("/DC=org/CN=...").
        subject = certificate.get_subject()
        subject_decoded = [[word.decode("utf8") for word in sets]
            for sets in subject.get_components()]
        subject_string = '/' + "/".join(["=".join(c) for c in
            subject_decoded])
        return subject_string

    @staticmethod
    def get_certificate_issuer_string(certificate):
        # Render the issuer as a slash-separated DN string, same format
        # as the subject above.
        issuer = certificate.get_issuer()
        issuer_decoded = [[word.decode("utf8") for word in sets]
            for sets in issuer.get_components()]
        issuer_string = '/' + "/".join(["=".join(c) for c in
            issuer_decoded])
        return issuer_string
class GraceDbAuthenticatedAuthentication(authentication.BaseAuthentication):
    """
    If user is already authenticated by the main Django middleware,
    don't make them authenticate again.

    This is mostly (only?) used for access to the web-browsable API when
    the user is already authenticated via Shibboleth.
    """
    # Only handle requests directed at the API URL space.
    api_only = True

    def authenticate(self, request):
        # Ignore anything that is not an API request.
        if self.api_only and not is_api_request(request.path):
            return None
        # Dig out the underlying Django request and its user, if present.
        inner_request = getattr(request, '_request', None)
        user = getattr(inner_request, 'user', None)
        # Pass through only users the middleware already authenticated.
        if getattr(user, 'is_authenticated', False):
            return (user, None)
        return None
import logging
from rest_framework.views import exception_handler
import logging
# Set up logger
logger = logging.getLogger(__name__)
def gracedb_exception_handler(exc, context):
    """Custom DRF exception handler.

    Starts from REST framework's default handler, then reshapes the
    payload: a 'detail' entry is replaced by the raw exception args,
    and dict-style validation errors are flattened into a list (or a
    bare value when there is exactly one error).
    """
    # Call REST framework's default exception handler first,
    # to get the standard error response.
    response = exception_handler(exc, context)
    # Now add the HTTP status code to the response.
    if response is not None:
        # Bug fix: dict.has_key() was removed in Python 3; membership
        # must be tested with the 'in' operator.
        if 'detail' in response.data:
            # Expose the exception's own arguments instead of the
            # generic detail message.
            response.data['detail'] = list(exc.args)
        if hasattr(exc, 'detail') and hasattr(exc.detail, 'values'):
            # Combine values into one list
            exc_out = [item for sublist in list(exc.detail.values())
                       for item in sublist]
            # For only one exception, just print it rather than the list
            if len(exc_out) == 1:
                exc_out = exc_out[0]
            # Update response data
            response.data = exc_out
    return response
import logging
from django.conf import settings
from rest_framework import permissions
# Set up logger
logger = logging.getLogger(__name__)
class IsPriorityUser(permissions.BasePermission):
    """Grant access only to members of the priority users group."""
    message = 'You are not authorized to use this API.'

    def has_permission(self, request, view):
        # Membership in the configured priority group is the sole gate.
        priority_members = request.user.groups.filter(
            name=settings.PRIORITY_USERS_GROUP)
        return priority_members.exists()
from base64 import b64encode
try:
from unittest import mock
except ImportError: # python < 3
import mock
from django.conf import settings
from django.urls import reverse
from django.utils import timezone
from api.backends import (
GraceDbBasicAuthentication, GraceDbX509Authentication,
GraceDbAuthenticatedAuthentication,
)
from api.tests.utils import GraceDbApiTestBase
from api.utils import api_reverse
from ligoauth.models import X509Cert
# Make sure to test password expiration
class TestGraceDbBasicAuthentication(GraceDbApiTestBase):
    """Test basic auth backend for API in full auth cycle"""

    @classmethod
    def setUpTestData(cls):
        super(TestGraceDbBasicAuthentication, cls).setUpTestData()
        # Give the LV-EM user a known password to authenticate with
        cls.password = 'passw0rd'
        cls.lvem_user.set_password(cls.password)
        cls.lvem_user.save()

    def _basic_auth_headers(self, username, password):
        """Build an HTTP Basic Authorization header for the given creds."""
        credentials = b64encode(
            "{0}:{1}".format(username, password).encode()
        ).decode("ascii")
        return {'HTTP_AUTHORIZATION': 'Basic {0}'.format(credentials)}

    def test_user_authenticate_to_api_with_password(self):
        """User can authenticate to API with correct password"""
        url = api_reverse('api:root')
        headers = self._basic_auth_headers(self.lvem_user.username,
                                           self.password)
        response = self.client.get(url, data=None, **headers)
        self.assertEqual(response.status_code, 200)
        # Confirm both the authenticated user and the backend that
        # authenticated them via the renderer context
        req = response.renderer_context['request']
        self.assertEqual(req.user, self.lvem_user)
        self.assertEqual(req.successful_authenticator.__class__,
                         GraceDbBasicAuthentication)

    def test_user_authenticate_to_api_with_bad_password(self):
        """User can't authenticate with wrong password"""
        url = api_reverse('api:root')
        headers = self._basic_auth_headers(self.lvem_user.username, 'b4d')
        response = self.client.get(url, data=None, **headers)
        self.assertContains(response, 'Invalid username/password',
                            status_code=403)

    def test_user_authenticate_to_api_with_expired_password(self):
        """User can't authenticate with expired password"""
        # Age the password past its expiration window (date_joined holds
        # the time the password was set)
        self.lvem_user.date_joined = timezone.now() - \
            2*settings.PASSWORD_EXPIRATION_TIME
        self.lvem_user.save(update_fields=['date_joined'])
        url = api_reverse('api:root')
        headers = self._basic_auth_headers(self.lvem_user.username,
                                           self.password)
        response = self.client.get(url, data=None, **headers)
        self.assertContains(response, 'Your password has expired',
                            status_code=403)
class TestGraceDbX509Authentication(GraceDbApiTestBase):
    """Test X509 certificate auth backend for API in full auth cycle"""

    def setUp(self):
        super(TestGraceDbX509Authentication, self).setUp()
        # Patch auth classes to make sure the right one is active
        self.auth_patcher = mock.patch(
            'rest_framework.views.APIView.get_authenticators',
            return_value=[GraceDbX509Authentication(),])
        self.auth_patcher.start()

    def tearDown(self):
        super(TestGraceDbX509Authentication, self).tearDown()
        # Undo the authenticator patch installed in setUp()
        self.auth_patcher.stop()

    @classmethod
    def setUpTestData(cls):
        super(TestGraceDbX509Authentication, cls).setUpTestData()
        # Set up certificate for internal user account
        cls.x509_subject = '/x509_subject'
        cert = X509Cert.objects.create(subject=cls.x509_subject,
            user=cls.internal_user)

    def test_user_authenticate_to_api_with_x509_cert(self):
        """User can authenticate to API with valid X509 certificate"""
        # Set up and make request
        url = api_reverse('api:root')
        headers = {
            GraceDbX509Authentication.subject_dn_header: self.x509_subject,
        }
        response = self.client.get(url, data=None, **headers)
        # Check response
        self.assertEqual(response.status_code, 200)
        # Make sure user is authenticated properly by checking the
        # renderer context
        req = response.renderer_context['request']
        self.assertEqual(req.user, self.internal_user)
        self.assertEqual(req.successful_authenticator.__class__,
            GraceDbX509Authentication)

    def test_user_authenticate_to_api_with_bad_x509_cert(self):
        """User can't authenticate with invalid X509 certificate subject"""
        # Set up and make request
        url = api_reverse('api:root')
        headers = {
            GraceDbX509Authentication.subject_dn_header: 'bad subject',
        }
        response = self.client.get(url, data=None, **headers)
        # Check response
        self.assertContains(response, 'Invalid certificate subject',
            status_code=401)

    def test_inactive_user_authenticate(self):
        """Inactive user can't authenticate"""
        # Set internal user to inactive
        self.internal_user.is_active = False
        self.internal_user.save(update_fields=['is_active'])
        # Set up and make request
        url = api_reverse('api:root')
        headers = {
            GraceDbX509Authentication.subject_dn_header: self.x509_subject,
        }
        response = self.client.get(url, data=None, **headers)
        # Check response
        self.assertContains(response, 'User inactive or deleted',
            status_code=401)

    def test_authenticate_cert_with_proxy(self):
        """User can authenticate to API with proxied X509 certificate"""
        # NOTE(review): the executable body of this test is entirely
        # commented out, so it currently passes vacuously -- restore the
        # assertions below or remove the test.
        # Set up request
        #request = self.factory.get(api_reverse('api:root'))
        #request.META[GraceDbX509Authentication.subject_dn_header] = \
        #    '/CN=123' + self.x509_subject
        #request.META[GraceDbX509Authentication.issuer_dn_header] = \
        #    '/CN=123'
        # Authentication attempt
        #user, other = self.backend_instance.authenticate(request)
        # Check authenticated user
        #self.assertEqual(user, self.internal_user)
        # Make sure user is authenticated properly by checking the
        # renderer context
        #req = response.renderer_context['request']
        #self.assertEqual(req.user, self.internal_user)
        #self.assertEqual(req.successful_authenticator.__class__,
        #    GraceDbAuthenticatedAuthentication)
class TestGraceDbAuthenticatedAuthentication(GraceDbApiTestBase):
    """Test "already-authenticated" auth backend for API"""

    def test_user_authenticate_to_api(self):
        """User can authenticate if already authenticated"""
        # Establish a session by hitting the post-login page with
        # Shibboleth-style headers
        shib_headers = {
            settings.SHIB_USER_HEADER: self.internal_user.username,
            settings.SHIB_GROUPS_HEADER: self.internal_group.name,
        }
        response = self.client.get(reverse('post-login'), data=None,
                                   **shib_headers)
        # The API root should now see the session's authenticated user
        response = self.client.get(api_reverse('api:root'))
        req = response.renderer_context['request']
        self.assertEqual(req.user, self.internal_user)
        self.assertEqual(req.successful_authenticator.__class__,
                         GraceDbAuthenticatedAuthentication)

    def test_user_not_authenticated_to_api(self):
        """User can't authenticate if not already authenticated"""
        response = self.client.get(api_reverse('api:root'))
        # Anonymous access still yields a 200 at the API root
        self.assertEqual(response.status_code, 200)
        # But the request carries an anonymous, unauthenticated user and
        # no successful authenticator
        req = response.renderer_context['request']
        self.assertTrue(req.user.is_anonymous)
        self.assertFalse(req.user.is_authenticated)
        self.assertEqual(req.successful_authenticator, None)
from base64 import b64encode
from django.conf import settings
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.urls import reverse
from django.utils import timezone
from rest_framework import exceptions
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
from user_sessions.middleware import SessionMiddleware
from api.backends import (
GraceDbBasicAuthentication, GraceDbX509Authentication,
GraceDbSciTokenAuthentication, GraceDbAuthenticatedAuthentication,
)
from api.tests.utils import GraceDbApiTestBase
from api.utils import api_reverse
from ligoauth.middleware import ShibbolethWebAuthMiddleware
from ligoauth.models import X509Cert
import mock
import scitokens
import time
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
from core.tests.utils import GraceDbTestBase
from django.test import override_settings
# Make sure to test password expiration
class TestGraceDbBasicAuthentication(GraceDbApiTestBase):
    """Test basic auth backend for API"""

    @classmethod
    def setUpClass(cls):
        super(TestGraceDbBasicAuthentication, cls).setUpClass()
        # Backend under test plus a request factory, shared by all tests
        cls.backend_instance = GraceDbBasicAuthentication()
        cls.factory = APIRequestFactory()

    @classmethod
    def setUpTestData(cls):
        super(TestGraceDbBasicAuthentication, cls).setUpTestData()
        # Give the LV-EM user a known password to authenticate with
        cls.password = 'passw0rd'
        cls.lvem_user.set_password(cls.password)
        cls.lvem_user.save()

    def _request_with_basic_auth(self, path, username, password):
        """Build a GET request carrying an HTTP Basic Authorization header."""
        request = self.factory.get(path)
        credentials = b64encode(
            "{0}:{1}".format(username, password).encode()
        ).decode("ascii")
        request.META['HTTP_AUTHORIZATION'] = 'Basic {0}'.format(credentials)
        return request

    def test_user_authenticate_to_api_with_password(self):
        """User can authenticate to API with correct password"""
        request = self._request_with_basic_auth(
            api_reverse('api:root'), self.lvem_user.username, self.password)
        user, other = self.backend_instance.authenticate(request)
        self.assertEqual(user, self.lvem_user)

    def test_user_authenticate_to_api_with_bad_password(self):
        """User can't authenticate with wrong password"""
        request = self._request_with_basic_auth(
            api_reverse('api:root'), self.lvem_user.username, 'b4d')
        with self.assertRaises(exceptions.AuthenticationFailed):
            user, other = self.backend_instance.authenticate(request)

    def test_user_authenticate_to_api_with_expired_password(self):
        """User can't authenticate with expired password"""
        # Age the password (stored in date_joined) past its expiration
        self.lvem_user.date_joined = timezone.now() - \
            2*settings.PASSWORD_EXPIRATION_TIME
        self.lvem_user.save(update_fields=['date_joined'])
        request = self._request_with_basic_auth(
            api_reverse('api:root'), self.lvem_user.username, self.password)
        with self.assertRaisesRegex(exceptions.AuthenticationFailed,
                                    'Your password has expired'):
            user, other = self.backend_instance.authenticate(request)

    def test_user_authenticate_non_api(self):
        """User can't authenticate to a non-API URL path"""
        request = self._request_with_basic_auth(
            reverse('home'), self.lvem_user.username, self.password)
        # Non-API paths are ignored by this backend
        user_auth_tuple = self.backend_instance.authenticate(request)
        self.assertEqual(user_auth_tuple, None)

    def test_inactive_user_authenticate(self):
        """Inactive user can't authenticate"""
        self.lvem_user.is_active = False
        self.lvem_user.save(update_fields=['is_active'])
        request = self._request_with_basic_auth(
            api_reverse('api:root'), self.lvem_user.username, self.password)
        with self.assertRaises(exceptions.AuthenticationFailed):
            user, other = self.backend_instance.authenticate(request)
class TestGraceDbSciTokenAuthentication(GraceDbTestBase):
    """Test SciToken auth backend for API.

    Each test builds a locally-signed SciToken; the signing key is
    registered with the scitokens key cache so verification does not
    require a network fetch.
    """
    # Claims used when constructing test tokens
    TEST_ISSUER = ['local', 'local2']
    TEST_AUDIENCE = ["TEST"]
    TEST_SCOPE = "gracedb.read"

    @classmethod
    def setUpClass(cls):
        super(TestGraceDbSciTokenAuthentication, cls).setUpClass()
        # Attach backend instance and request factory to class
        cls.backend_instance = GraceDbSciTokenAuthentication()
        cls.factory = APIRequestFactory()

    @classmethod
    def setUpTestData(cls):
        super(TestGraceDbSciTokenAuthentication, cls).setUpTestData()

    def setUp(self):
        # Generate a fresh RSA keypair for signing test tokens
        self._private_key = generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )
        self._public_key = self._private_key.public_key()
        self._public_pem = self._public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        # Register the public key with the scitokens key cache so tokens
        # issued by 'local' can be verified locally
        keycache = scitokens.utils.keycache.KeyCache.getinstance()
        keycache.addkeyinfo("local", "sample_key",
                            self._private_key.public_key())
        # Build a valid token for the internal user
        self._token = scitokens.SciToken(key=self._private_key,
                                         key_id="sample_key")
        self._token.update_claims({
            "iss": self.TEST_ISSUER,
            "aud": self.TEST_AUDIENCE,
            "scope": self.TEST_SCOPE,
            "sub": str(self.internal_user),
        })
        self._serialized_token = self._token.serialize(issuer="local")
        # Token without a key id, for tests that exercise that path
        self._no_kid_token = scitokens.SciToken(key=self._private_key)

    @override_settings(
        SCITOKEN_ISSUER="local",
        SCITOKEN_AUDIENCE=["TEST"],
    )
    def test_user_authenticate_to_api_with_scitoken(self):
        """User can authenticate to API with valid SciToken"""
        # Set up request
        request = self.factory.get(api_reverse('api:root'))
        token_str = 'Bearer ' + self._serialized_token.decode()
        request.headers = {'Authorization': token_str}
        # Authentication attempt
        user, other = self.backend_instance.authenticate(
            request, public_key=self._public_pem)
        # Check authenticated user
        self.assertEqual(user, self.internal_user)

    @override_settings(
        SCITOKEN_ISSUER="local",
        SCITOKEN_AUDIENCE=["TEST"],
    )
    def test_user_authenticate_to_api_without_scitoken(self):
        """Backend returns None when no SciToken is presented"""
        # Set up request with no Authorization header
        request = self.factory.get(api_reverse('api:root'))
        # Authentication attempt
        resp = self.backend_instance.authenticate(
            request, public_key=self._public_pem)
        # Returning None lets other auth backends handle the request
        self.assertIsNone(resp)

    @override_settings(
        SCITOKEN_ISSUER="local",
        SCITOKEN_AUDIENCE=["TEST"],
    )
    def test_user_authenticate_to_api_with_wrong_audience(self):
        """User can't authenticate with a SciToken for another audience"""
        # Set up request with a token bearing the wrong 'aud' claim
        request = self.factory.get(api_reverse('api:root'))
        self._token["aud"] = "https://somethingelse.example.com"
        serialized_token = self._token.serialize(issuer="local")
        token_str = 'Bearer ' + serialized_token.decode()
        request.headers = {'Authorization': token_str}
        # Authentication attempt
        resp = self.backend_instance.authenticate(
            request, public_key=self._public_pem)
        # Token is rejected; backend falls through with None
        self.assertIsNone(resp)

    @override_settings(
        SCITOKEN_ISSUER="local",
        SCITOKEN_AUDIENCE=["TEST"],
    )
    def test_user_authenticate_to_api_with_expired_scitoken(self):
        """User can't authenticate with an expired SciToken"""
        # Set up request; a negative lifetime yields an already-expired token
        request = self.factory.get(api_reverse('api:root'))
        serialized_token = self._token.serialize(issuer="local", lifetime=-1)
        token_str = 'Bearer ' + serialized_token.decode()
        request.headers = {'Authorization': token_str}
        # Authentication attempt
        resp = self.backend_instance.authenticate(
            request, public_key=self._public_pem)
        # Token is rejected; backend falls through with None
        self.assertIsNone(resp)

    @override_settings(
        SCITOKEN_ISSUER="local",
        SCITOKEN_AUDIENCE=["TEST"],
    )
    def test_inactive_user_authenticate_to_api_with_scitoken(self):
        """Inactive user can't authenticate with valid SciToken"""
        # Set internal user to inactive
        self.internal_user.is_active = False
        self.internal_user.save(update_fields=['is_active'])
        # Set up request
        request = self.factory.get(api_reverse('api:root'))
        token_str = 'Bearer ' + self._serialized_token.decode()
        request.headers = {'Authorization': token_str}
        # Authentication attempt should fail
        with self.assertRaises(exceptions.AuthenticationFailed):
            self.backend_instance.authenticate(
                request, public_key=self._public_pem)
class TestGraceDbX509Authentication(GraceDbApiTestBase):
    """Test X509 certificate auth backend for API"""

    @classmethod
    def setUpClass(cls):
        super(TestGraceDbX509Authentication, cls).setUpClass()
        # Attach backend instance and request factory to class
        cls.backend_instance = GraceDbX509Authentication()
        cls.factory = APIRequestFactory()

    @classmethod
    def setUpTestData(cls):
        super(TestGraceDbX509Authentication, cls).setUpTestData()
        # Register a certificate subject for the internal user account
        cls.x509_subject = '/x509_subject'
        X509Cert.objects.create(subject=cls.x509_subject,
                                user=cls.internal_user)

    def _cert_request(self, subject, issuer=None, path=None):
        """Build a request carrying X509 subject (and optional issuer)
        DN headers; defaults to the API root path."""
        request = self.factory.get(path if path is not None
                                   else api_reverse('api:root'))
        request.META[GraceDbX509Authentication.subject_dn_header] = subject
        if issuer is not None:
            request.META[GraceDbX509Authentication.issuer_dn_header] = issuer
        return request

    def test_user_authenticate_to_api_with_x509_cert(self):
        """User can authenticate to API with valid X509 certificate"""
        request = self._cert_request(self.x509_subject)
        user, other = self.backend_instance.authenticate(request)
        self.assertEqual(user, self.internal_user)

    def test_user_authenticate_to_api_with_bad_x509_cert(self):
        """User can't authenticate with invalid X509 certificate subject"""
        request = self._cert_request('bad subject')
        with self.assertRaises(exceptions.AuthenticationFailed):
            self.backend_instance.authenticate(request)

    def test_user_authenticate_non_api(self):
        """User can't authenticate to a non-API URL path"""
        request = self._cert_request(self.x509_subject,
                                     path=reverse('home'))
        # Backend declines non-API requests entirely
        self.assertIsNone(self.backend_instance.authenticate(request))

    def test_inactive_user_authenticate(self):
        """Inactive user can't authenticate"""
        self.internal_user.is_active = False
        self.internal_user.save(update_fields=['is_active'])
        request = self._cert_request(self.x509_subject)
        with self.assertRaises(exceptions.AuthenticationFailed):
            self.backend_instance.authenticate(request)

    def test_authenticate_cert_with_proxy(self):
        """User can authenticate to API with proxied X509 certificate"""
        request = self._cert_request(self.x509_subject + '/CN=123456789',
                                     issuer=self.x509_subject)
        user, other = self.backend_instance.authenticate(request)
        self.assertEqual(user, self.internal_user)

    def test_authenticate_cert_with_double_proxy(self):
        """User can authenticate to API with double-proxied X509 certificate"""
        proxied_subject = self.x509_subject + '/CN=123456789'
        request = self._cert_request(proxied_subject + '/CN=987654321',
                                     issuer=proxied_subject)
        user, other = self.backend_instance.authenticate(request)
        self.assertEqual(user, self.internal_user)
class TestGraceDbAuthenticatedAuthentication(GraceDbApiTestBase):
    """Test shibboleth auth backend for API"""

    @classmethod
    def setUpClass(cls):
        super(TestGraceDbAuthenticatedAuthentication, cls).setUpClass()
        # Attach backend instance, request factory, and a dummy
        # middleware get_response callable to the class
        cls.backend_instance = GraceDbAuthenticatedAuthentication()
        cls.factory = APIRequestFactory()
        cls.get_response = mock.MagicMock()

    def test_user_authenticate_to_api(self):
        """User can authenticate if already authenticated"""
        # Wrap in a rest_framework Request, mirroring what a view's
        # initialize_request() method does
        raw_request = self.factory.get(api_reverse('api:root'))
        raw_request.user = self.internal_user
        drf_request = Request(request=raw_request)
        # Authenticate and check the returned user
        user, other = self.backend_instance.authenticate(drf_request)
        self.assertEqual(user, self.internal_user)

    def test_user_not_authenticated_to_api(self):
        """User can't authenticate if not already authenticated"""
        raw_request = self.factory.get(api_reverse('api:root'))
        # Run session/auth middleware so request.user becomes anonymous
        SessionMiddleware(self.get_response).process_request(raw_request)
        AuthenticationMiddleware(self.get_response).process_request(raw_request)
        drf_request = Request(request=raw_request)
        # Backend declines anonymous requests
        self.assertIsNone(self.backend_instance.authenticate(drf_request))

    def test_user_authenticate_to_non_api(self):
        """User can't authenticate to non-API URL path"""
        drf_request = Request(request=self.factory.get(reverse('home')))
        # Backend declines non-API requests entirely
        self.assertIsNone(self.backend_instance.authenticate(drf_request))
try:
from unittest import mock
except ImportError: # python < 3
import mock
from django.conf import settings
from django.core.cache import caches
from django.test import override_settings
from django.urls import reverse
from api.tests.utils import GraceDbApiTestBase
class TestThrottling(GraceDbApiTestBase):
    """Test API throttles"""

    @mock.patch('api.throttling.BurstAnonRateThrottle.get_rate',
        return_value='1/hour'
    )
    def test_anon_burst_throttle(self, mock_get_rate):
        """Test anonymous user burst throttle"""
        url = reverse('api:default:root')
        # With the rate patched to 1/hour, only the first request succeeds
        first_response = self.request_as_user(url, "GET")
        self.assertEqual(first_response.status_code, 200)
        # The follow-up request is rejected with 429 and a Retry-After header
        second_response = self.request_as_user(url, "GET")
        self.assertContains(second_response, 'Request was throttled',
                            status_code=429)
        self.assertIn('Retry-After', second_response.headers)
from django.urls import reverse as django_reverse
from rest_framework.test import APIRequestFactory, APISimpleTestCase
from api.utils import api_reverse
class TestApiReverse(APISimpleTestCase):
    """Test behavior of custom api_reverse function.

    Covers the interaction between the viewname's namespace/version and
    an (optional) request: an explicit version in the viewname always
    wins; otherwise an API-directed request supplies its version, and a
    missing/non-API request falls back to the 'default' instance.
    """

    @classmethod
    def setUpClass(cls):
        super(TestApiReverse, cls).setUpClass()
        cls.api_version = 'v1'
        cls.api_url = django_reverse('api:{version}:root'.format(
            version=cls.api_version))
        cls.non_api_url = django_reverse('home')

    def setUp(self):
        super(TestApiReverse, self).setUp()
        self.factory = APIRequestFactory()
        # Create requests
        self.api_request = self.factory.get(self.api_url)
        self.non_api_request = self.factory.get(self.non_api_url)
        # Simulate version checking that is done in an API view's
        # initial() method
        self.api_request.version = self.api_version

    def test_full_viewname_with_request_to_api(self):
        """
        Reverse a fully namespaced viewname, including a request to the API
        """
        url = api_reverse('api:v1:root', request=self.api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_full_viewname_with_request_to_api_different_version(self):
        """
        Reverse a fully namespaced viewname with a different version than
        the corresponding request to the API
        """
        url = api_reverse('api:v2:root', request=self.api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v2:root'))

    def test_full_viewname_with_request_to_non_api(self):
        """
        Reverse a fully namespaced viewname, including a request to a non-API
        page
        """
        url = api_reverse('api:v1:root', request=self.non_api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_full_viewname_with_no_request(self):
        """
        Reverse a fully namespaced viewname, with no associated request
        """
        url = api_reverse('api:v1:root', absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_versioned_viewname_with_request_to_api(self):
        """
        Reverse a versioned viewname, including a request to the API
        """
        url = api_reverse('v1:root', request=self.api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_versioned_viewname_with_request_to_api_different_version(self):
        """
        Reverse a versioned viewname with a different version than
        the corresponding request to the API
        """
        url = api_reverse('v2:root', request=self.api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v2:root'))

    def test_versioned_viewname_with_request_to_non_api(self):
        """
        Reverse a versioned viewname, including a request to a non-API page
        """
        url = api_reverse('v1:root', request=self.non_api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_versioned_viewname_with_no_request(self):
        """
        Reverse a versioned viewname, with no associated request
        """
        url = api_reverse('v1:root', absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_api_unversioned_viewname_with_request_to_api(self):
        """
        Reverse an api-namespaced but unversioned viewname, including a request
        to the API
        """
        url = api_reverse('api:root', request=self.api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_api_unversioned_viewname_with_request_to_non_api(self):
        """
        Reverse an api-namespaced but unversioned viewname, including a request
        to a non-API page
        """
        url = api_reverse('api:root', request=self.non_api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:default:root'))

    def test_api_unversioned_viewname_with_no_request(self):
        """
        Reverse an api-namespaced but unversioned viewname, with no associated
        request
        """
        url = api_reverse('api:root', absolute_path=False)
        self.assertEqual(url, django_reverse('api:default:root'))

    def test_relative_viewname_with_request_to_api(self):
        """
        Reverse a relative viewname, including a request to the API
        """
        url = api_reverse('root', request=self.api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:v1:root'))

    def test_relative_viewname_with_request_to_non_api(self):
        """
        Reverse a relative viewname, including a request to a non-API page
        """
        # BUG FIX: this previously passed self.api_request, duplicating
        # test_relative_viewname_with_request_to_api. A non-API request
        # should fall back to the default version (mirrors
        # test_api_unversioned_viewname_with_request_to_non_api).
        url = api_reverse('root', request=self.non_api_request,
            absolute_path=False)
        self.assertEqual(url, django_reverse('api:default:root'))

    def test_relative_viewname_with_no_request(self):
        """
        Reverse a relative viewname, with no associated request
        """
        # BUG FIX: this previously passed a request despite its name;
        # with no request, the default version should be used (mirrors
        # test_api_unversioned_viewname_with_no_request).
        url = api_reverse('root', absolute_path=False)
        self.assertEqual(url, django_reverse('api:default:root'))

    def test_reverse_non_api_url(self):
        # TODO: not yet implemented — should assert behavior of
        # api_reverse for a viewname outside the api namespace
        pass
from copy import deepcopy
try:
from functools import reduce
except ImportError: # python < 3
pass
try:
from unittest import mock
except ImportError: # python < 3
import mock
from django.conf import settings
from django.core.cache import caches
from django.test import override_settings
from rest_framework.test import APIClient
from core.tests.utils import GraceDbTestBase
def fix_settings(key, value):
    """
    Dynamically override settings for testing. But, it will only
    ever be useful if rest_framework fixes the way that their
    settings work so that override_settings actually works.

    'key' should be x.y.z for nested dictionaries.
    """
    modified_settings = deepcopy(settings.REST_FRAMEWORK)
    *parent_keys, final_key = key.split('.')
    # Walk down to the dict that holds the final key (mirrors the
    # reduce(dict.get, ...) traversal of the original implementation)
    target = modified_settings
    for part in parent_keys:
        target = target.get(part)
    target[final_key] = value
    return modified_settings
@override_settings(
    ALLOW_BLANK_USER_AGENT_TO_API=True,
)
class GraceDbApiTestBase(GraceDbTestBase):
    """Base class for API test cases.

    Uses DRF's APIClient as the test client, patches the anonymous burst
    throttle to an effectively-unlimited rate for the duration of each
    test, and clears the throttle cache afterward so tests do not
    interfere with one another.
    """
    # DRF test client (supports credentials(), format='json', etc.)
    client_class = APIClient
    def setUp(self):
        super(GraceDbApiTestBase, self).setUp()
        # Patch throttle and start patcher; stopped in tearDown so each
        # test gets its own patch lifetime
        self.patcher = mock.patch('api.throttling.BurstAnonRateThrottle.get_rate',
            return_value='1000/second')
        self.patcher.start()
    def tearDown(self):
        super(GraceDbApiTestBase, self).tearDown()
        # Clear throttle cache so request counts don't leak between tests
        caches['throttles'].clear()
        # Stop patcher
        self.patcher.stop()
from rest_framework.throttling import UserRateThrottle
from django.core.cache import caches
class PostOrPutUserRateThrottle(UserRateThrottle):
from rest_framework.throttling import AnonRateThrottle, UserRateThrottle
# NOTE: we have to use database-backed throttles to have a centralized location
# where multiple workers (like in the production instance) can access and
# update the same throttling information.
###############################################################################
# Base throttle classes #######################################################
###############################################################################
class DbCachedThrottleMixin(object):
    """Uses a non-default (database-backed) cache.

    Mixed into throttle classes so that throttle state lives in the
    shared 'throttles' cache, where multiple workers can read and
    update the same counters (see module note above).
    """
    cache = caches['throttles']
###############################################################################
# Throttles for unauthenticated users #########################################
###############################################################################
class BurstAnonRateThrottle(DbCachedThrottleMixin, AnonRateThrottle):
    """Short-window (burst) rate limit for unauthenticated requests."""
    # DRF looks up the rate for this throttle by its scope key
    scope = 'anon_burst'
class SustainedAnonRateThrottle(DbCachedThrottleMixin, AnonRateThrottle):
    """Long-window (sustained) rate limit for unauthenticated requests."""
    # DRF looks up the rate for this throttle by its scope key
    scope = 'anon_sustained'
###############################################################################
# Throttles for authenticated users #########################################
###############################################################################
class PostOrPutUserRateThrottle(DbCachedThrottleMixin, UserRateThrottle):
def allow_request(self, request, view):
"""
This is mostly copied from the Rest Framework's SimpleRateThrottle
except we now pass the request to throttle_success
"""
# Don't throttle superusers - causes problems with client
# integration tests
if request.user.is_superuser:
return True
# We don't want to throttle any safe methods
if request.method not in ['POST', 'PUT']:
return True
......@@ -39,10 +75,3 @@ class PostOrPutUserRateThrottle(UserRateThrottle):
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
class EventCreationThrottle(PostOrPutUserRateThrottle):
    """Per-user POST/PUT rate limit under the 'event_creation' scope."""
    scope = 'event_creation'
class AnnotationThrottle(PostOrPutUserRateThrottle):
    """Per-user POST/PUT rate limit under the 'annotation' scope."""
    scope = 'annotation'
from django.urls import re_path, include
from .v1 import urls as v1_urls
from .v2 import urls as v2_urls
# Application namespace: all API views reverse as 'api:...'
app_name = 'api'

# The bare prefix serves v1 under the 'default' instance namespace, so
# /api/... and /api/v1/... resolve to the same v1 views; v2 is only
# reachable under its explicit prefix.
urlpatterns = [
    re_path(r'^', include((v1_urls, 'default'))),
    re_path(r'^v1/', include((v1_urls, 'v1'))),
    re_path(r'^v2/', include((v2_urls, 'v2'))),
]
import logging
from django.urls import resolve, reverse as django_reverse
from rest_framework.settings import api_settings
from rest_framework.response import Response
from core.urls import build_absolute_uri
# Set up logger
logger = logging.getLogger(__name__)
# Some default values
API_NAMESPACE = 'api'
def api_reverse(viewname, args=None, kwargs=None, request=None, format=None,
        absolute_path=True, **extra):
    """
    Reverse an API viewname, resolving its version namespace.

    Usage:

    # No request and no version in viewname uses the default version.
    # Same for the case when the request points to a non-API url:
    api_reverse('api:events:event-list')
    api_reverse('events:event-list')
    api_reverse('api:events:event-list', request=request)
    api_reverse('events:event-list', request=request)
        -> /api/events/

    # A version in the viewname always wins, with or without a request:
    api_reverse('api:v1:events:event-list')
    api_reverse('api:v1:events:event-list', request=request)
        -> /api/v1/events/
    api_reverse('v2:events:event-list')
    api_reverse('v2:events:event-list', request=request)
        -> /api/v2/events/

    # A request pointing to an API URL supplies the version when the
    # viewname does not include one:
    api_reverse('api:events:event-list', request=request)
    api_reverse('events:event-list', request=request)
        -> /api/v2/events/  (request.path is like /api/v2/*)
    """
    # Ensure the viewname is rooted in the 'api' namespace
    qualified_name = viewname
    if not qualified_name.startswith(API_NAMESPACE + ':'):
        qualified_name = API_NAMESPACE + ':' + qualified_name
    # Insert a version namespace; the versioning class leaves the name
    # untouched when it already carries one
    versioning = api_settings.DEFAULT_VERSIONING_CLASS()
    qualified_name = versioning.get_versioned_viewname(qualified_name, request)
    # Resolve the URL, optionally as an absolute URI
    url = django_reverse(qualified_name, args=args, kwargs=kwargs, **extra)
    return build_absolute_uri(url) if absolute_path else url
def is_api_request(request_path):
    """
    Returns True/False based on whether the request is directed to the API
    """
    # Hard-coded here because things break if we try to import it
    # from .urls
    api_app_name = 'api'
    resolved = resolve(request_path)
    return bool(resolved.app_names and
                resolved.app_names[0] == api_app_name)
class ResponseThenRun(Response):
    """
    A response class that runs a callback after the response is sent.

    close() is invoked once the response body has been written out, so
    the callback executes after the client has received its reply —
    useful for deferring slow follow-up work.
    """
    def __init__(self, data, callback, callback_kwargs, **kwargs):
        """Store the callback and its kwargs alongside the response data.

        Args:
            data: response payload passed through to Response.
            callback: callable invoked after the response is closed.
            callback_kwargs: dict of keyword args for the callback.
        """
        # Consistent modern super() usage (close() already used it)
        super().__init__(data, **kwargs)
        self.callback = callback
        self.callback_kwargs = callback_kwargs

    def close(self):
        # Finish the normal close, then fire the deferred callback
        super().close()
        self.callback(**self.callback_kwargs)
from __future__ import absolute_import
import logging
import six
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from events.models import Event
from ..fields import GenericField
# Set up logger
logger = logging.getLogger(__name__)
class EventGraceidField(GenericField, serializers.RelatedField):
default_error_messages = {
'invalid': _('Event graceid must be a string.'),
'bad_graceid': _('Not a valid graceid.'),
}
trim_whitespace = True
to_repr = 'graceid'
lookup_field = 'pk'
model = Event
queryset = Event.objects.all()
def _validate_graceid(self, data):
    """Validate a graceid string: one prefix letter from 'GEHMTD'
    followed by an integer suffix.

    Calls ``self.fail('bad_graceid')`` for any malformed value.
    """
    # data should be an uppercase string at this point. An empty
    # string has no prefix character; reject it explicitly instead of
    # raising IndexError on data[0].
    if not data or data[0] not in 'GEHMTD':
        self.fail('bad_graceid')
    # The suffix must parse as an integer (int('') also raises here)
    try:
        int(data[1:])
    except ValueError:
        self.fail('bad_graceid')
def to_internal_value(self, data):
    """Validate and normalize raw graceid input.

    Accepts any-case graceid strings (e.g. 'g1234'), optionally trims
    surrounding whitespace, validates the graceid format, then
    delegates to the parent implementation for the model lookup.
    """
    # Only string input is acceptable. (six compat shims dropped:
    # this codebase is Python-3-only, so str is sufficient.)
    if not isinstance(data, str):
        self.fail('invalid')
    value = str(data)
    if self.trim_whitespace:
        value = value.strip()
    # Convert to uppercase (allows g1234 to work)
    value = value.upper()
    # Graceid validation
    self._validate_graceid(value)
    return super(EventGraceidField, self).to_internal_value(value)
def get_model_dict(self, data):
    # Map the lookup field to the graceid with its one-letter prefix
    # stripped off (the remainder is used for the Event lookup).
    return {self.lookup_field: data[1:]}
def get_does_not_exist_error(self, graceid):
    """Return the error message for a graceid with no matching event."""
    message_template = "Event with graceid {graceid} does not exist."
    return message_template.format(graceid=graceid)