Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • alexander.pace/server
  • geoffrey.mo/gracedb-server
  • deep.chatterjee/gracedb-server
  • cody.messick/server
  • sushant.sharma-chaudhary/server
  • michael-coughlin/server
  • daniel.wysocki/gracedb-server
  • roberto.depietri/gracedb
  • philippe.grassia/gracedb
  • tri.nguyen/gracedb
  • jonah-kanner/gracedb
  • brandon.piotrzkowski/gracedb
  • joseph-areeda/gracedb
  • duncanmmacleod/gracedb
  • thomas.downes/gracedb
  • tanner.prestegard/gracedb
  • leo-singer/gracedb
  • computing/gracedb/server
18 results
Show changes
Showing
with 2128 additions and 0 deletions
from base64 import b64encode
try:
from unittest import mock
except ImportError: # python < 3
import mock
from django.conf import settings
from django.urls import reverse
from django.utils import timezone
from api.backends import (
GraceDbBasicAuthentication, GraceDbX509Authentication,
GraceDbAuthenticatedAuthentication,
)
from api.tests.utils import GraceDbApiTestBase
from api.utils import api_reverse
from ligoauth.models import X509Cert
# Make sure to test password expiration
class TestGraceDbBasicAuthentication(GraceDbApiTestBase):
"""Test basic auth backend for API in full auth cycle"""
@classmethod
def setUpTestData(cls):
super(TestGraceDbBasicAuthentication, cls).setUpTestData()
# Set up password for LV-EM user account
cls.password = 'passw0rd'
cls.lvem_user.set_password(cls.password)
cls.lvem_user.save()
def test_user_authenticate_to_api_with_password(self):
"""User can authenticate to API with correct password"""
# Set up and make request
url = api_reverse('api:root')
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password=self.password
).encode()
).decode("ascii")
headers = {
'HTTP_AUTHORIZATION': 'Basic {0}'.format(user_and_pass),
}
response = self.client.get(url, data=None, **headers)
# Check response
self.assertEqual(response.status_code, 200)
# Make sure user is authenticated properly by checking the
# renderer context
req = response.renderer_context['request']
self.assertEqual(req.user, self.lvem_user)
self.assertEqual(req.successful_authenticator.__class__,
GraceDbBasicAuthentication)
def test_user_authenticate_to_api_with_bad_password(self):
"""User can't authenticate with wrong password"""
# Set up and make request
url = api_reverse('api:root')
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password='b4d'
).encode()
).decode("ascii")
headers = {
'HTTP_AUTHORIZATION': 'Basic {0}'.format(user_and_pass),
}
response = self.client.get(url, data=None, **headers)
# Check response
self.assertContains(response, 'Invalid username/password',
status_code=403)
def test_user_authenticate_to_api_with_expired_password(self):
"""User can't authenticate with expired password"""
# Set password to be expired
self.lvem_user.date_joined = timezone.now() - \
2*settings.PASSWORD_EXPIRATION_TIME
self.lvem_user.save(update_fields=['date_joined'])
# Set up and make request
url = api_reverse('api:root')
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password=self.password
).encode()
).decode("ascii")
headers = {
'HTTP_AUTHORIZATION': 'Basic {0}'.format(user_and_pass),
}
response = self.client.get(url, data=None, **headers)
# Check response
self.assertContains(response, 'Your password has expired',
status_code=403)
class TestGraceDbX509Authentication(GraceDbApiTestBase):
"""Test X509 certificate auth backend for API in full auth cycle"""
def setUp(self):
super(TestGraceDbX509Authentication, self).setUp()
# Patch auth classes to make sure the right one is active
self.auth_patcher = mock.patch(
'rest_framework.views.APIView.get_authenticators',
return_value=[GraceDbX509Authentication(),])
self.auth_patcher.start()
def tearDown(self):
super(TestGraceDbX509Authentication, self).tearDown()
self.auth_patcher.stop()
@classmethod
def setUpTestData(cls):
super(TestGraceDbX509Authentication, cls).setUpTestData()
# Set up certificate for internal user account
cls.x509_subject = '/x509_subject'
cert = X509Cert.objects.create(subject=cls.x509_subject,
user=cls.internal_user)
def test_user_authenticate_to_api_with_x509_cert(self):
"""User can authenticate to API with valid X509 certificate"""
# Set up and make request
url = api_reverse('api:root')
headers = {
GraceDbX509Authentication.subject_dn_header: self.x509_subject,
}
response = self.client.get(url, data=None, **headers)
# Check response
self.assertEqual(response.status_code, 200)
# Make sure user is authenticated properly by checking the
# renderer context
req = response.renderer_context['request']
self.assertEqual(req.user, self.internal_user)
self.assertEqual(req.successful_authenticator.__class__,
GraceDbX509Authentication)
def test_user_authenticate_to_api_with_bad_x509_cert(self):
"""User can't authenticate with invalid X509 certificate subject"""
# Set up and make request
url = api_reverse('api:root')
headers = {
GraceDbX509Authentication.subject_dn_header: 'bad subject',
}
response = self.client.get(url, data=None, **headers)
# Check response
self.assertContains(response, 'Invalid certificate subject',
status_code=401)
def test_inactive_user_authenticate(self):
"""Inactive user can't authenticate"""
# Set internal user to inactive
self.internal_user.is_active = False
self.internal_user.save(update_fields=['is_active'])
# Set up and make request
url = api_reverse('api:root')
headers = {
GraceDbX509Authentication.subject_dn_header: self.x509_subject,
}
response = self.client.get(url, data=None, **headers)
# Check response
self.assertContains(response, 'User inactive or deleted',
status_code=401)
def test_authenticate_cert_with_proxy(self):
"""User can authenticate to API with proxied X509 certificate"""
# Set up request
#request = self.factory.get(api_reverse('api:root'))
#request.META[GraceDbX509Authentication.subject_dn_header] = \
# '/CN=123' + self.x509_subject
#request.META[GraceDbX509Authentication.issuer_dn_header] = \
# '/CN=123'
# Authentication attempt
#user, other = self.backend_instance.authenticate(request)
# Check authenticated user
#self.assertEqual(user, self.internal_user)
# Make sure user is authenticated properly by checking the
# renderer context
#req = response.renderer_context['request']
#self.assertEqual(req.user, self.internal_user)
#self.assertEqual(req.successful_authenticator.__class__,
# GraceDbAuthenticatedAuthentication)
class TestGraceDbAuthenticatedAuthentication(GraceDbApiTestBase):
"""Test "already-authenticated" auth backend for API"""
def test_user_authenticate_to_api(self):
"""User can authenticate if already authenticated"""
# Make request to post-login page to set up session
headers = {
settings.SHIB_USER_HEADER: self.internal_user.username,
settings.SHIB_GROUPS_HEADER: self.internal_group.name,
}
response = self.client.get(reverse('post-login'), data=None, **headers)
# Now make a request to the API root
response = self.client.get(api_reverse('api:root'))
# Make sure user is authenticated properly by checking the
# renderer context
req = response.renderer_context['request']
self.assertEqual(req.user, self.internal_user)
self.assertEqual(req.successful_authenticator.__class__,
GraceDbAuthenticatedAuthentication)
def test_user_not_authenticated_to_api(self):
"""User can't authenticate if not already authenticated"""
# Make a request to the API root
response = self.client.get(api_reverse('api:root'))
# Check response - should be 200 with AnonymousUser
self.assertEqual(response.status_code, 200)
# Make sure user is not authenticated
req = response.renderer_context['request']
self.assertTrue(req.user.is_anonymous)
self.assertFalse(req.user.is_authenticated)
self.assertEqual(req.successful_authenticator, None)
from base64 import b64encode
from django.conf import settings
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.urls import reverse
from django.utils import timezone
from rest_framework import exceptions
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
from user_sessions.middleware import SessionMiddleware
from api.backends import (
GraceDbBasicAuthentication, GraceDbX509Authentication,
GraceDbSciTokenAuthentication, GraceDbAuthenticatedAuthentication,
)
from api.tests.utils import GraceDbApiTestBase
from api.utils import api_reverse
from ligoauth.middleware import ShibbolethWebAuthMiddleware
from ligoauth.models import X509Cert
import mock
import scitokens
import time
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
from core.tests.utils import GraceDbTestBase
from django.test import override_settings
# Make sure to test password expiration
class TestGraceDbBasicAuthentication(GraceDbApiTestBase):
"""Test basic auth backend for API"""
@classmethod
def setUpClass(cls):
super(TestGraceDbBasicAuthentication, cls).setUpClass()
# Attach request factory to class
cls.backend_instance = GraceDbBasicAuthentication()
cls.factory = APIRequestFactory()
@classmethod
def setUpTestData(cls):
super(TestGraceDbBasicAuthentication, cls).setUpTestData()
# Set up password for LV-EM user account
cls.password = 'passw0rd'
cls.lvem_user.set_password(cls.password)
cls.lvem_user.save()
def test_user_authenticate_to_api_with_password(self):
"""User can authenticate to API with correct password"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password=self.password
).encode()
).decode("ascii")
request.META['HTTP_AUTHORIZATION'] = 'Basic {0}'.format(user_and_pass)
# Authentication attempt
user, other = self.backend_instance.authenticate(request)
# Check authenticated user
self.assertEqual(user, self.lvem_user)
def test_user_authenticate_to_api_with_bad_password(self):
"""User can't authenticate with wrong password"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password='b4d'
).encode()
).decode("ascii")
request.META['HTTP_AUTHORIZATION'] = 'Basic {0}'.format(user_and_pass)
# Authentication attempt should fail
with self.assertRaises(exceptions.AuthenticationFailed):
user, other = self.backend_instance.authenticate(request)
def test_user_authenticate_to_api_with_expired_password(self):
"""User can't authenticate with expired password"""
# Set user's password date (date_joined) so that it is expired
self.lvem_user.date_joined = timezone.now() - \
2*settings.PASSWORD_EXPIRATION_TIME
self.lvem_user.save(update_fields=['date_joined'])
# Set up request
request = self.factory.get(api_reverse('api:root'))
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password=self.password
).encode()
).decode("ascii")
request.META['HTTP_AUTHORIZATION'] = 'Basic {0}'.format(user_and_pass)
# Authentication attempt should fail
with self.assertRaisesRegex(exceptions.AuthenticationFailed,
'Your password has expired'):
user, other = self.backend_instance.authenticate(request)
def test_user_authenticate_non_api(self):
"""User can't authenticate to a non-API URL path"""
# Set up request
request = self.factory.get(reverse('home'))
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password=self.password
).encode()
).decode("ascii")
request.META['HTTP_AUTHORIZATION'] = 'Basic {0}'.format(user_and_pass)
# Try to authenticate
user_auth_tuple = self.backend_instance.authenticate(request)
self.assertEqual(user_auth_tuple, None)
def test_inactive_user_authenticate(self):
"""Inactive user can't authenticate"""
# Set LV-EM user to inactive
self.lvem_user.is_active = False
self.lvem_user.save(update_fields=['is_active'])
# Set up request
request = self.factory.get(api_reverse('api:root'))
user_and_pass = b64encode(
"{username}:{password}".format(
username=self.lvem_user.username,
password=self.password
).encode()
).decode("ascii")
request.META['HTTP_AUTHORIZATION'] = 'Basic {0}'.format(user_and_pass)
# Authentication attempt should fail
with self.assertRaises(exceptions.AuthenticationFailed):
user, other = self.backend_instance.authenticate(request)
class TestGraceDbSciTokenAuthentication(GraceDbTestBase):
"""Test SciToken auth backend for API"""
TEST_ISSUER = ['local', 'local2']
TEST_AUDIENCE = ["TEST"]
TEST_SCOPE = "gracedb.read"
@classmethod
def setUpClass(cls):
super(TestGraceDbSciTokenAuthentication, cls).setUpClass()
# Attach request factory to class
cls.backend_instance = GraceDbSciTokenAuthentication()
cls.factory = APIRequestFactory()
@classmethod
def setUpTestData(cls):
super(TestGraceDbSciTokenAuthentication, cls).setUpTestData()
def setUp(self):
self._private_key = generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
self._public_key = self._private_key.public_key()
self._public_pem = self._public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
keycache = scitokens.utils.keycache.KeyCache.getinstance()
keycache.addkeyinfo("local", "sample_key", self._private_key.public_key())
now = int(time.time())
self._token = scitokens.SciToken(key = self._private_key, key_id="sample_key")
self._token.update_claims({
"iss": self.TEST_ISSUER,
"aud": self.TEST_AUDIENCE,
"scope": self.TEST_SCOPE,
"sub": str(self.internal_user),
})
self._serialized_token = self._token.serialize(issuer = "local")
self._no_kid_token = scitokens.SciToken(key = self._private_key)
@override_settings(
SCITOKEN_ISSUER="local",
SCITOKEN_AUDIENCE=["TEST"],
)
def test_user_authenticate_to_api_with_scitoken(self):
"""User can authenticate to API with valid Scitoken"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
token_str = 'Bearer ' + self._serialized_token.decode()
request.headers = {'Authorization': token_str}
# Authentication attempt
user, other = self.backend_instance.authenticate(request, public_key=self._public_pem)
# Check authenticated user
self.assertEqual(user, self.internal_user)
@override_settings(
SCITOKEN_ISSUER="local",
SCITOKEN_AUDIENCE=["TEST"],
)
def test_user_authenticate_to_api_without_scitoken(self):
"""User can authenticate to API without valid Scitoken"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
# Authentication attempt
resp = self.backend_instance.authenticate(request, public_key=self._public_pem)
# Check authentication response
assert resp == None
@override_settings(
SCITOKEN_ISSUER="local",
SCITOKEN_AUDIENCE=["TEST"],
)
def test_user_authenticate_to_api_with_wrong_audience(self):
"""User can authenticate to API with invalid Scitoken audience"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
self._token["aud"] = "https://somethingelse.example.com"
serialized_token = self._token.serialize(issuer = "local")
token_str = 'Bearer ' + serialized_token.decode()
request.headers = {'Authorization': token_str}
# Authentication attempt
resp = self.backend_instance.authenticate(request, public_key=self._public_pem)
# Check authentication response
assert resp == None
@override_settings(
SCITOKEN_ISSUER="local",
SCITOKEN_AUDIENCE=["TEST"],
)
def test_user_authenticate_to_api_with_expired_scitoken(self):
"""User can authenticate to API with valid Scitoken"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
serialized_token = self._token.serialize(issuer = "local", lifetime=-1)
token_str = 'Bearer ' + serialized_token.decode()
request.headers = {'Authorization': token_str}
# Authentication attempt
resp = self.backend_instance.authenticate(request, public_key=self._public_pem)
# Check authentication response
assert resp == None
@override_settings(
SCITOKEN_ISSUER="local",
SCITOKEN_AUDIENCE=["TEST"],
)
def test_inactive_user_authenticate_to_api_with_scitoken(self):
"""Inactive user can't authenticate with valid Scitoken"""
# Set internal user to inactive
self.internal_user.is_active = False
self.internal_user.save(update_fields=['is_active'])
# Set up request
request = self.factory.get(api_reverse('api:root'))
token_str = 'Bearer ' + self._serialized_token.decode()
request.headers = {'Authorization': token_str}
# Authentication attempt should fail
with self.assertRaises(exceptions.AuthenticationFailed):
user, other = self.backend_instance.authenticate(request, public_key=self._public_pem)
class TestGraceDbX509Authentication(GraceDbApiTestBase):
"""Test X509 certificate auth backend for API"""
@classmethod
def setUpClass(cls):
super(TestGraceDbX509Authentication, cls).setUpClass()
# Attach request factory to class
cls.backend_instance = GraceDbX509Authentication()
cls.factory = APIRequestFactory()
@classmethod
def setUpTestData(cls):
super(TestGraceDbX509Authentication, cls).setUpTestData()
# Set up certificate for internal user account
cls.x509_subject = '/x509_subject'
cert = X509Cert.objects.create(subject=cls.x509_subject,
user=cls.internal_user)
def test_user_authenticate_to_api_with_x509_cert(self):
"""User can authenticate to API with valid X509 certificate"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
request.META[GraceDbX509Authentication.subject_dn_header] = \
self.x509_subject
# Authentication attempt
user, other = self.backend_instance.authenticate(request)
# Check authenticated user
self.assertEqual(user, self.internal_user)
def test_user_authenticate_to_api_with_bad_x509_cert(self):
"""User can't authenticate with invalid X509 certificate subject"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
request.META[GraceDbX509Authentication.subject_dn_header] = \
'bad subject'
# Authentication attempt should fail
with self.assertRaises(exceptions.AuthenticationFailed):
user, other = self.backend_instance.authenticate(request)
def test_user_authenticate_non_api(self):
"""User can't authenticate to a non-API URL path"""
# Set up request
request = self.factory.get(reverse('home'))
request.META[GraceDbX509Authentication.subject_dn_header] = \
self.x509_subject
# Try to authenticate
user_auth_tuple = self.backend_instance.authenticate(request)
self.assertEqual(user_auth_tuple, None)
def test_inactive_user_authenticate(self):
"""Inactive user can't authenticate"""
# Set internal user to inactive
self.internal_user.is_active = False
self.internal_user.save(update_fields=['is_active'])
# Set up request
request = self.factory.get(api_reverse('api:root'))
request.META[GraceDbX509Authentication.subject_dn_header] = \
self.x509_subject
# Authentication attempt should fail
with self.assertRaises(exceptions.AuthenticationFailed):
user, other = self.backend_instance.authenticate(request)
def test_authenticate_cert_with_proxy(self):
"""User can authenticate to API with proxied X509 certificate"""
# Set up request
request = self.factory.get(api_reverse('api:root'))
request.META[GraceDbX509Authentication.subject_dn_header] = \
self.x509_subject + '/CN=123456789'
request.META[GraceDbX509Authentication.issuer_dn_header] = \
self.x509_subject
# Authentication attempt
user, other = self.backend_instance.authenticate(request)
# Check authenticated user
self.assertEqual(user, self.internal_user)
def test_authenticate_cert_with_double_proxy(self):
"""User can authenticate to API with double-proxied X509 certificate"""
proxied_x509_subject = self.x509_subject + '/CN=123456789'
# Set up request
request = self.factory.get(api_reverse('api:root'))
request.META[GraceDbX509Authentication.subject_dn_header] = \
proxied_x509_subject + '/CN=987654321'
request.META[GraceDbX509Authentication.issuer_dn_header] = \
proxied_x509_subject
# Authentication attempt
user, other = self.backend_instance.authenticate(request)
# Check authenticated user
self.assertEqual(user, self.internal_user)
class TestGraceDbAuthenticatedAuthentication(GraceDbApiTestBase):
"""Test shibboleth auth backend for API"""
@classmethod
def setUpClass(cls):
super(TestGraceDbAuthenticatedAuthentication, cls).setUpClass()
# Attach request factory to class
cls.backend_instance = GraceDbAuthenticatedAuthentication()
cls.factory = APIRequestFactory()
cls.get_response = mock.MagicMock()
def test_user_authenticate_to_api(self):
"""User can authenticate if already authenticated"""
# Need to convert request to a rest_framework Request,
# as would be done in a view's initialize_request() method.
request = self.factory.get(api_reverse('api:root'))
request.user = self.internal_user
request = Request(request=request)
# Try to authenticate user
user, other = self.backend_instance.authenticate(request)
self.assertEqual(user, self.internal_user)
def test_user_not_authenticated_to_api(self):
"""User can't authenticate if not already authenticated"""
# Need to convert request to a rest_framework Request,
# as would be done in a view's initialize_request() method.
request = self.factory.get(api_reverse('api:root'))
# Preprocessing to set request.user to anonymous
SessionMiddleware(self.get_response).process_request(request)
AuthenticationMiddleware(self.get_response).process_request(request)
request = Request(request=request)
# Try to authenticate user
user_auth_tuple = self.backend_instance.authenticate(request)
self.assertEqual(user_auth_tuple, None)
def test_user_authenticate_to_non_api(self):
"""User can't authenticate to non-API URL path"""
# Need to convert request to a rest_framework Request,
# as would be done in a view's initialize_request() method.
request = self.factory.get(reverse('home'))
request = Request(request=request)
# Try to authenticate user
user_auth_tuple = self.backend_instance.authenticate(request)
self.assertEqual(user_auth_tuple, None)
try:
from unittest import mock
except ImportError: # python < 3
import mock
from django.conf import settings
from django.core.cache import caches
from django.test import override_settings
from django.urls import reverse
from api.tests.utils import GraceDbApiTestBase
class TestThrottling(GraceDbApiTestBase):
"""Test API throttles"""
@mock.patch('api.throttling.BurstAnonRateThrottle.get_rate',
return_value='1/hour'
)
def test_anon_burst_throttle(self, mock_get_rate):
"""Test anonymous user burst throttle"""
url = reverse('api:default:root')
# First request should be OK
response = self.request_as_user(url, "GET")
self.assertEqual(response.status_code, 200)
# Second response should get throttled
response = self.request_as_user(url, "GET")
self.assertContains(response, 'Request was throttled', status_code=429)
self.assertIn('Retry-After', response.headers)
from django.urls import reverse as django_reverse
from rest_framework.test import APIRequestFactory, APISimpleTestCase
from api.utils import api_reverse
class TestApiReverse(APISimpleTestCase):
"""Test behavior of custom api_reverse function"""
@classmethod
def setUpClass(cls):
super(TestApiReverse, cls).setUpClass()
cls.api_version = 'v1'
cls.api_url = django_reverse('api:{version}:root'.format(
version=cls.api_version))
cls.non_api_url = django_reverse('home')
def setUp(self):
super(TestApiReverse, self).setUp()
self.factory = APIRequestFactory()
# Create requests
self.api_request = self.factory.get(self.api_url)
self.non_api_request = self.factory.get(self.non_api_url)
# Simulate version checking that is done in an API view's
# initial() method
self.api_request.version = self.api_version
def test_full_viewname_with_request_to_api(self):
"""
Reverse a fully namespaced viewname, including a request to the API
"""
url = api_reverse('api:v1:root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_full_viewname_with_request_to_api_different_version(self):
"""
Reverse a fully namespaced viewname with a different version than
the corresponding request to the API
"""
url = api_reverse('api:v2:root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v2:root'))
def test_full_viewname_with_request_to_non_api(self):
"""
Reverse a fully namespaced viewname, including a request to a non-API
page
"""
url = api_reverse('api:v1:root', request=self.non_api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_full_viewname_with_no_request(self):
"""
Reverse a fully namespaced viewname, with no associated request
"""
url = api_reverse('api:v1:root', absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_versioned_viewname_with_request_to_api(self):
"""
Reverse a versioned viewname, including a request to the API
"""
url = api_reverse('v1:root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_versioned_viewname_with_request_to_api_different_version(self):
"""
Reverse a versioned viewname with a different version than
the corresponding request to the API
"""
url = api_reverse('v2:root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v2:root'))
def test_versioned_viewname_with_request_to_non_api(self):
"""
Reverse a versioned viewname, including a request to a non-API page
"""
url = api_reverse('v1:root', request=self.non_api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_versioned_viewname_with_no_request(self):
"""
Reverse a versioned viewname, with no associated request
"""
url = api_reverse('v1:root', absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_api_unversioned_viewname_with_request_to_api(self):
"""
Reverse an api-namespaced but unversioned viewname, including a request
to the API
"""
url = api_reverse('api:root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_api_unversioned_viewname_with_request_to_non_api(self):
"""
Reverse an api-namespaced but unversioned viewname, including a request
to a non-API page
"""
url = api_reverse('api:root', request=self.non_api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:default:root'))
def test_api_unversioned_viewname_with_no_request(self):
"""
Reverse an api-namespaced bu unversioned viewname, with no associated
request
"""
url = api_reverse('api:root', absolute_path=False)
self.assertEqual(url, django_reverse('api:default:root'))
def test_relative_viewname_with_request_to_api(self):
url = api_reverse('root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_relative_viewname_with_request_to_non_api(self):
url = api_reverse('root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_relative_viewname_with_no_request(self):
url = api_reverse('root', request=self.api_request,
absolute_path=False)
self.assertEqual(url, django_reverse('api:v1:root'))
def test_reverse_non_api_url(self):
pass
from copy import deepcopy
try:
from functools import reduce
except ImportError: # python < 3
pass
try:
from unittest import mock
except ImportError: # python < 3
import mock
from django.conf import settings
from django.core.cache import caches
from django.test import override_settings
from rest_framework.test import APIClient
from core.tests.utils import GraceDbTestBase
def fix_settings(key, value):
"""
Dynamically override settings for testing. But, it will only
ever be useful if rest_framework fixes the way that their
settings work so that override_settings actually works.
'key' should be x.y.z for nested dictionaries.
"""
api_settings = deepcopy(settings.REST_FRAMEWORK)
key_list = key.split('.')
new_dict = reduce(dict.get, key_list[:-1], api_settings)
new_dict[key_list[-1]] = value
return api_settings
@override_settings(
ALLOW_BLANK_USER_AGENT_TO_API=True,
)
class GraceDbApiTestBase(GraceDbTestBase):
client_class = APIClient
def setUp(self):
super(GraceDbApiTestBase, self).setUp()
# Patch throttle and start patcher
self.patcher = mock.patch('api.throttling.BurstAnonRateThrottle.get_rate',
return_value='1000/second')
self.patcher.start()
def tearDown(self):
super(GraceDbApiTestBase, self).tearDown()
# Clear throttle cache
caches['throttles'].clear()
# Stop patcher
self.patcher.stop()
from django.core.cache import caches
from rest_framework.throttling import AnonRateThrottle, UserRateThrottle
# NOTE: we have to use database-backed throttles to have a centralized location
# where multiple workers (like in the production instance) can access and
# update the same throttling information.
###############################################################################
# Base throttle classes #######################################################
###############################################################################
class DbCachedThrottleMixin(object):
"""Uses a non-default (database-backed) cache"""
cache = caches['throttles']
###############################################################################
# Throttles for unauthenticated users #########################################
###############################################################################
class BurstAnonRateThrottle(DbCachedThrottleMixin, AnonRateThrottle):
scope = 'anon_burst'
class SustainedAnonRateThrottle(DbCachedThrottleMixin, AnonRateThrottle):
scope = 'anon_sustained'
###############################################################################
# Throttles for authenticated users #########################################
###############################################################################
class PostOrPutUserRateThrottle(DbCachedThrottleMixin, UserRateThrottle):
def allow_request(self, request, view):
"""
This is mostly copied from the Rest Framework's SimpleRateThrottle
except we now pass the request to throttle_success
"""
# Don't throttle superusers - causes problems with client
# integration tests
if request.user.is_superuser:
return True
# We don't want to throttle any safe methods
if request.method not in ['POST', 'PUT']:
return True
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success(request)
def throttle_success(self, request):
"""
Inserts the current request's timestamp along with the key
into the cache. Except we only do this if the request is a
writing method (POST or PUT). That's why we needed the request.
"""
if request.method in ['POST', 'PUT']:
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
from django.urls import re_path, include
from .v1 import urls as v1_urls
from .v2 import urls as v2_urls
app_name = 'api'
urlpatterns = [
re_path(r'^', include((v1_urls, 'default'))),
re_path(r'^v1/', include((v1_urls, 'v1'))),
re_path(r'^v2/', include((v2_urls, 'v2'))),
]
import logging
from django.urls import resolve, reverse as django_reverse
from rest_framework.settings import api_settings
from rest_framework.response import Response
from core.urls import build_absolute_uri
# Set up logger
logger = logging.getLogger(__name__)
# Some default values
API_NAMESPACE = 'api'
def api_reverse(viewname, args=None, kwargs=None, request=None, format=None,
absolute_path=True, **extra):
"""
Usage:
# No request and no version in viewname uses default version
# Same for case when request points to a non-API url
api_reverse('api:events:event-list')
api_reverse('events:event-list')
api_reverse('api:events:event-list', request=request)
api_reverse('events:event-list', request=request)
/api/events/
# No request but with version in viewname uses the specified version
# Same for case when request points to a non-API url
api_reverse('api:v1:events:event-list')
api_reverse('api:v1:events:event-list', request=request)
/api/v1/events/
api_reverse('v2:events:event-list')
api_reverse('v2:events:event-list', request=request)
/api/v2/events/
# Request pointing to an API URL uses the specified version in the
# viewname. If a version is not specified in the viewname, the version
# is determined from the request.
api_reverse('api:v1:events:event-list, request=request)
api_reverse('v1:events:event-list, request=request)
/api/v1/events/ (request.path is like /api/(any version)/*)
api_reverse('api:events:event-list, request=request)
api_reverse('events:event-list, request=request)
/api/v2/events (request.path is like /api/v2/*)
"""
# Prepend 'api:' if viewname doesn't start with it.
if not viewname.startswith(API_NAMESPACE + ':'):
viewname = API_NAMESPACE + ':' + viewname
# Handle versioning. Nothing is done by the versioning_class if the
# viewname already has a version namespace.
versioning_class = api_settings.DEFAULT_VERSIONING_CLASS()
viewname = versioning_class.get_versioned_viewname(viewname, request)
# Get URL
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if absolute_path:
url = build_absolute_uri(url)
return url
def is_api_request(request_path):
"""
Returns True/False based on whether the request is directed to the API
"""
# This is hard-coded because things break if we try to import it from .urls
api_app_name = 'api'
resolver_match = resolve(request_path)
if (resolver_match.app_names and
resolver_match.app_names[0] == api_app_name):
return True
return False
class ResponseThenRun(Response):
"""
A response class that will do something after the response is sent
"""
def __init__(self, data, callback, callback_kwargs, **kwargs):
super(ResponseThenRun, self).__init__(data, **kwargs)
self.callback = callback
self.callback_kwargs = callback_kwargs
def close(self):
super().close()
self.callback(**self.callback_kwargs)
from __future__ import absolute_import
import logging
import six
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from events.models import Event
from ..fields import GenericField
# Set up logger
logger = logging.getLogger(__name__)
class EventGraceidField(GenericField, serializers.RelatedField):
default_error_messages = {
'invalid': _('Event graceid must be a string.'),
'bad_graceid': _('Not a valid graceid.'),
}
trim_whitespace = True
to_repr = 'graceid'
lookup_field = 'pk'
model = Event
queryset = Event.objects.all()
def _validate_graceid(self, data):
# data should be a string at this point
prefix = data[0]
suffix = data[1:]
if not prefix in 'GEHMTD':
self.fail('bad_graceid')
try:
suffix = int(suffix)
except ValueError:
self.fail('bad_graceid')
def to_internal_value(self, data):
# Add string validation
if not isinstance(data, six.string_types):
self.fail('invalid')
value = six.text_type(data)
if self.trim_whitespace:
value = value.strip()
# Convert to uppercase (allows g1234 to work)
value = value.upper()
# Graceid validation
self._validate_graceid(value)
return super(EventGraceidField, self).to_internal_value(value)
def get_model_dict(self, data):
return {self.lookup_field: data[1:]}
def get_does_not_exist_error(self, graceid):
err_msg = "Event with graceid {graceid} does not exist.".format(
graceid=graceid)
return err_msg
from collections import OrderedDict
from django.utils.functional import cached_property
from django.utils.http import urlencode
from rest_framework import pagination
from rest_framework.response import Response
class CustomEventPagination(pagination.LimitOffsetPagination):
default_limit = 100
limit_query_param = 'count'
offset_query_param = 'start'
# Override the built-in counting method from here:
# https://github.com/encode/django-rest-framework/blob/3.4.7/rest_framework/pagination.py#L47
# And see if it can be cached. Like suggested here:
# https://stackoverflow.com/a/47357445
# update: caching works but I noticed in the browser that sometimes the numRows
# field retains its value from pervious queries. I don't know yet if it affects
# data fetching in the API, but i'm going to leave it commented for now.
# Another update: it would seem that fetching just the integer row-id for query
# results and then counting is faster than fetching the entire queryset. I didn't
# observe any change when doing large-ish counts on dev1 (~33000 events), but
# maybe postgres gets clever enough when doing larger queries. I'll leave it in
# and see what happens: https://stackoverflow.com/a/47357445
# @cached_property
@property
def _get_count(self):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return self.queryset.values('id').count()
except (AttributeError, TypeError):
return len(self.queryset)
def paginate_queryset(self, queryset, request, view=None):
self.limit = self.get_limit(request)
self.queryset = queryset
if self.limit is None:
return None
self.offset = self.get_offset(request)
self.count = self._get_count
self.request = request
if self.count > self.limit and self.template is not None:
self.display_page_controls = True
if self.count == 0 or self.offset > self.count:
return []
return list(self.queryset[self.offset:self.offset + self.limit])
def get_paginated_response(self, data):
numRows = self.count
# Get base URI
base_uri = self.request.build_absolute_uri(self.request.path)
# Construct custom link for "last" page
last = max(0, (numRows / self.limit)) * self.limit
param_dict = {
'start': last,
self.limit_query_param: self.limit,
}
last_uri = base_uri + '?' + urlencode(param_dict)
output = OrderedDict([
('numRows', numRows),
('events', data),
('links',
OrderedDict([
('self', self.request.build_absolute_uri()),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('first', base_uri),
('last', last_uri),
])),
])
return Response(output)
import logging
from rest_framework import permissions
# Set up logger
logger = logging.getLogger(__name__)
class CanUpdateGrbEvent(permissions.BasePermission):
def has_permission(self, request, view):
return request.user.has_perm('events.t90_grbevent')
from django.conf import settings
from rest_framework import serializers
from rest_framework.fields import CurrentUserDefault
from events.models import Event, EventLog, GrbEvent, NeutrinoEvent, \
CoincInspiralEvent, MLyBurstEvent, MultiBurstEvent, LalInferenceBurstEvent, \
SingleInspiral, SimInspiralEvent
from api.utils import api_reverse
from superevents.models import Superevent
# Fields in the siminspiral table to expose to public:
snglinsp_public_fields = ['ifo','end_time','end_time_ns']
multiburst_public_fields = ['ifos', 'single_ifo_times']
# define a function that returns a tuple of the superevent window range:
def superevent_window_range(gpstime):
return (gpstime - settings.EVENT_SUPEREVENT_WINDOW_BEFORE, \
gpstime + settings.EVENT_SUPEREVENT_WINDOW_AFTER)
class GRBEventSerializer(serializers.ModelSerializer):
T90 = serializers.SerializerMethodField()
class Meta:
model = GrbEvent
fields =('author_ivorn', 'dec', 'designation', 'redshift', \
'how_description', 'coord_system', 'trigger_id', 'error_radius', \
'how_reference_url', 'ra', 'ivorn', 'trigger_duration', \
'author_shortname', 'T90', 'observatory_location_id')
def get_T90(self, obj):
return obj.t90
class NeutrinoEventSerializer(serializers.ModelSerializer):
class Meta:
model = NeutrinoEvent
fields =('ivorn', 'coord_system', 'ra', 'dec', 'error_radius', \
'far_ne', 'far_unit', 'signalness', 'energy', 'src_error_90', \
'src_error_50', 'amon_id', 'run_id', 'event_id', 'stream')
class CoincInspiralEventSerializer(serializers.ModelSerializer):
class Meta:
model = CoincInspiralEvent
fields = ('ifos', 'end_time', 'end_time_ns', 'mass', 'mchirp',
'minimum_duration', 'snr', 'false_alarm_rate', 'combined_far')
class MLyBurstEventSerializer(serializers.ModelSerializer):
scores = serializers.SerializerMethodField()
SNR = serializers.SerializerMethodField()
class Meta:
model = MLyBurstEvent
fields =('ifos', 'central_freq', 'bandwidth', 'duration', 'central_time', \
'detection_statistic', 'SNR', 'bbh', 'sglf', 'sghf', \
'background', 'glitch', 'freq_correlation', 'channels', 'scores')
def to_representation(self, obj):
channels_out = None
ret = super().to_representation(obj)
if obj.channels:
channels_out = obj.channels.split(',')
ret['channels'] = channels_out
return ret
def get_scores(self, obj):
return {'coherency': obj.score_coher,
'coincidence': obj.score_coinc,
'combined': obj.score_comb}
def get_SNR(self, obj):
return obj.snr
class MultiBurstEventSerializer(serializers.ModelSerializer):
class Meta:
model = MultiBurstEvent
fields =('ifos', 'start_time', 'start_time_ns', 'duration', 'strain', \
'peak_time', 'peak_time_ns', 'central_freq', 'bandwidth', \
'amplitude', 'mchirp', 'snr', 'confidence', 'false_alarm_rate', \
'ligo_axis_ra', 'ligo_axis_dec', 'ligo_angle', 'ligo_angle_sig', \
'single_ifo_times', 'hoft', 'code')
def __init__(self, *args, is_external=False, **kwargs):
super().__init__(*args, **kwargs)
if is_external:
self.fields = {field_name: self.fields[field_name] \
for field_name in multiburst_public_fields}
class LalInferenceBurstEventSerializer(serializers.ModelSerializer):
class Meta:
model = LalInferenceBurstEvent
fields =('bci', 'quality_mean', 'quality_median', 'bsn', \
'omicron_snr_network', 'omicron_snr_H1', 'omicron_snr_L1', \
'omicron_snr_V1', 'hrss_mean', 'hrss_median', 'frequency_mean', \
'frequency_median')
class SimInspiralEventSerializer(serializers.ModelSerializer):
class Meta:
model = SimInspiralEvent
fields =('mass1', 'mass2', 'eta', 'coa_phase', 'mchirp', 'spin1x', \
'spin1y', 'spin1z', 'spin2x', 'spin2y', 'spin2z', 'end_time_gmst', \
'f_lower', 'f_final', 'distance', 'latitude', 'longitude', \
'polarization', 'inclination', 'theta0', 'phi0', 'alpha', 'beta', \
'psi0', 'psi3', 'alpha1', 'alpha2', 'alpha3', 'alpha4', 'alpha5', \
'alpha6', 'eff_dist_g', 'eff_dist_h', 'eff_dist_l', 'eff_dist_t', \
'eff_dist_v', 'amplitude', 'tau', 'phi', 'freq', 'amp_order', \
'geocent_end_time', 'geocent_end_time_ns', 'numrel_mode_min', \
'numrel_mode_max', 'bandpass', 'g_end_time', 'g_end_time_ns', \
'h_end_time', 'h_end_time_ns', 'l_end_time', 'l_end_time_ns', \
't_end_time', 't_end_time_ns', 'v_end_time', 'v_end_time_ns', \
'waveform', 'numrel_data', 'source', 'taper', 'source_channel', \
'destination_channel')
class SingleInspiralSerializer(serializers.ModelSerializer):
class Meta:
model = SingleInspiral
fields = ('ifo', 'search', 'channel', 'end_time', 'end_time_ns', \
'end_time_gmst', 'impulse_time', 'impulse_time_ns', \
'template_duration', 'event_duration', 'amplitude', \
'eff_distance', 'coa_phase', 'mass1', 'mass2', 'mchirp', \
'mtotal', 'eta', 'kappa', 'chi', 'tau0', 'tau2', 'tau3', \
'tau4', 'tau5', 'ttotal', 'psi0', 'psi3', 'alpha', \
'alpha1', 'alpha2', 'alpha3', 'alpha4', 'alpha5', 'alpha6', \
'beta', 'f_final', 'snr', 'chisq', 'chisq_dof', 'bank_chisq', \
'bank_chisq_dof', 'cont_chisq', 'cont_chisq_dof', 'sigmasq', \
'rsqveto_duration', 'Gamma0', 'Gamma1', 'Gamma2', 'Gamma3', \
'Gamma4', 'Gamma5', 'Gamma6', 'Gamma7', 'Gamma8', 'Gamma9', \
'spin1x', 'spin1y', 'spin1z', 'spin2x', 'spin2y', 'spin2z')
def __init__(self, *args, is_external=False, **kwargs):
super().__init__(*args, **kwargs)
if is_external:
self.fields = {field_name: self.fields[field_name] \
for field_name in snglinsp_public_fields}
# This dict maps the key in the extra_attributes dict to its corresponding
# event subclass name. SingleInspiral events are a little different, so do
# that separately for now.
EVENT_ATTRIBUTE_MAP = {
GrbEvent: ('GRB', 'grbevent', GRBEventSerializer),
NeutrinoEvent: ('NeutrinoEvent', 'neutrinoevent', NeutrinoEventSerializer),
CoincInspiralEvent: ('CoincInspiral', 'coincinspiralevent', CoincInspiralEventSerializer),
MLyBurstEvent: ('MLyBurst', 'mlyburstevent', MLyBurstEventSerializer),
MultiBurstEvent: ('MultiBurst', 'multiburstevent', MultiBurstEventSerializer),
LalInferenceBurstEvent: ('LalInferenceBurst', 'lalinferenceburstevent', LalInferenceBurstEventSerializer),
SimInspiralEvent: ('SimInspiral', 'siminspiralevent', SimInspiralEventSerializer),
}
class EventSerializer(serializers.ModelSerializer):
# Fields.
group = serializers.CharField(source="group.name")
pipeline = serializers.CharField(source="pipeline.name")
search = serializers.CharField(source="search.name", allow_null=True)
submitter = serializers.SlugRelatedField(slug_field="username",
read_only=True)
# New fields.
labels = serializers.SerializerMethodField('get_labels')
created = serializers.DateTimeField(format=settings.GRACE_STRFTIME_FORMAT,
read_only=True)
far_is_upper_limit = serializers.SerializerMethodField()
extra_attributes = serializers.SerializerMethodField(allow_null=True)
links = serializers.SerializerMethodField('get_links')
superevent_neighbours = serializers.SerializerMethodField()
def __init__(self, *args, **kwargs):
super(EventSerializer, self).__init__(*args, **kwargs)
self.request = self.context.get('request', None)
self.external = self.context.get('request_is_external', False)
# don't include neighboring superevents if this is nested inside itself
self.is_nested = self.context.get('is_nested', False)
# don't include neighboring superevents for serializations that aren't
# alerts.
self.is_alert = self.context.get('is_alert', False)
if not self.is_alert or self.is_nested:
self.fields.pop('superevent_neighbours')
class Meta:
model = Event
fields = ('submitter', 'created', 'group',
'pipeline', 'graceid', 'gpstime', 'reporting_latency',
'instruments', 'nevents', 'offline',
'search', 'far', 'far_is_upper_limit', 'likelihood',
'labels', 'extra_attributes', 'superevent', 'links',
'superevent_neighbours')
def to_representation(self, obj):
display_far, far_is_upper_limit = self.display_far_and_limit(obj)
ret = super().to_representation(obj)
ret['far'] = display_far
ret['superevent'] = obj.superevent.superevent_id if obj.superevent else None
ret['far_is_upper_limit'] = far_is_upper_limit
return ret
# modify the display far and limit parameter. TODO this same code is
# duplicated in at least (?) two other places, it should really get
# combined into one function. really though, it's probably not even
# necessary since the VOEVENT_FAR_FLOOR is zero (and so not used...)
# but who knows, it might come back at some point.
def display_far_and_limit(self, obj):
far_is_upper_limit = False
display_far = obj.far
if obj.far and self.external and obj.far < settings.VOEVENT_FAR_FLOOR:
display_far = settings.VOEVENT_FAR_FLOOR
far_is_upper_limit = True
return display_far, far_is_upper_limit
def get_labels(self, obj):
return [label.name for label in obj.labels.all()]
def display_far_floor(self, obj):
if obj.far and self.external and obj.far < settings.VOEVENT_FAR_FLOOR:
return True
else:
return False
def get_links(self, obj):
graceid = obj.graceid
return {
"neighbors" : api_reverse("events:neighbors", args=[graceid], request=self.request),
"log" : api_reverse("events:eventlog-list", args=[graceid], request=self.request),
"emobservations" : api_reverse("events:emobservation-list", args=[graceid], request=self.request),
"files" : api_reverse("events:files", args=[graceid], request=self.request),
"labels" : api_reverse("events:labels", args=[graceid], request=self.request),
"self" : api_reverse("events:event-detail", args=[graceid], request=self.request),
"tags" : api_reverse("events:eventtag-list", args=[graceid], request=self.request),
}
# This is a dummy placeholder since it gets overwritten by to_representation
def get_far_is_upper_limit(self, obj):
return False
def get_extra_attributes(self, obj):
extra_attrs_dict = {}
# Include this section for requests by internal users and
# FIXME alert contents:
if (self.request is not None and not self.external) or self.is_alert or self.is_nested:
for subevent_type, subevent_vals in EVENT_ATTRIBUTE_MAP.items():
if isinstance(obj, subevent_type):
extra_attrs_dict.update({subevent_vals[0]:
subevent_vals[2](getattr(obj, subevent_vals[1])).data})
# For CoincInspiral events, append the SingleInspiral data, after
# checking to see if the tables are actually there, just to be sure.
if isinstance(obj, CoincInspiralEvent):
if obj.singleinspiral_set.exists():
extra_attrs_dict.update({'SingleInspiral':
[SingleInspiralSerializer(e, is_external=self.external).data for e in obj.singleinspiral_set.all()]
})
# For the special inexplicable case where the user is external, then
# return specific fields from SingleInspiral or multiburst events
elif (self.request is not None and self.external):
if isinstance(obj, CoincInspiralEvent):
extra_attrs_dict.update({'SingleInspiral':
[SingleInspiralSerializer(e, is_external=self.external).data \
for e in obj.singleinspiral_set.all()]})
elif isinstance(obj, MultiBurstEvent):
extra_attrs_dict.update({'MultiBurst':
MultiBurstEventSerializer(obj, is_external=self.external).data})
return extra_attrs_dict
def get_superevent_neighbours(self, obj):
return_dict = {}
# Look for neighbors only if the event has a valid gpstime. This will
# probably trigger if event.gpstime=0.0, but that event isn't valid anyway.
if obj.gpstime:
for s in self.superevent_neighbours_set(obj):
sn_dict = {}
# Add superevent_id:
sn_dict.update({
'superevent_id': s.superevent_id,
})
# Add gw_events:
sn_dict.update({
'gw_events': [g.graceid for g in s.get_internal_events()],
})
# Add basic superevent info:
sn_dict.update({
'far': s.far,
't_start': s.t_start,
't_0': s.t_0,
't_end': s.t_end,
})
# Add labels:
sn_dict.update({
'labels': [l.name for l in s.labels.all()],
})
# Add preferred_event:
sn_dict.update({
'preferred_event': s.preferred_event.graceid,
})
# Add preferred_event_data:
sn_dict.update({
'preferred_event_data':
EventSerializer(s.preferred_event.get_subclass(),
context={'is_nested': True}).data,
})
# Add this info to the response
return_dict.update(
{s.superevent_id: sn_dict})
return return_dict
def superevent_category(self, obj):
# return the corresponding superevent category of the event
if obj.is_production():
return Superevent.SUPEREVENT_CATEGORY_PRODUCTION
elif obj.is_mdc():
return Superevent.SUPEREVENT_CATEGORY_MDC
elif obj.is_test():
return Superevent.SUPEREVENT_CATEGORY_TEST
def superevent_neighbours_set(self, obj):
# query the database for nearby superevents
return Superevent.objects.filter(
t_0__range=superevent_window_range(obj.gpstime),
category=self.superevent_category(obj)) \
.prefetch_related('events', 'labels') \
.select_related('preferred_event', 'preferred_event__pipeline',
'preferred_event__group',
'preferred_event__search',
'preferred_event__submitter',
'preferred_event__superevent',
'preferred_event__grbevent',
'preferred_event__neutrinoevent',
'preferred_event__coincinspiralevent',
'preferred_event__mlyburstevent',
'preferred_event__multiburstevent',
'preferred_event__lalinferenceburstevent',
'preferred_event__siminspiralevent',
)
class EventLogSerializer(serializers.ModelSerializer):
"""docstring for EventLogSerializer"""
comment = serializers.CharField(required=True, max_length=200)
class Meta:
model = EventLog
fields = ('comment', 'issuer', 'created')
from __future__ import absolute_import
from django.conf import settings
from django.urls import reverse
from api.tests.utils import GraceDbApiTestBase
from events.tests.mixins import EventSetup
from ...settings import API_VERSION
def v_reverse(viewname, *args, **kwargs):
"""Easily customizable versioned API reverse for testing"""
viewname = 'api:{version}:'.format(version=API_VERSION) + viewname
return reverse(viewname, *args, **kwargs)
class TestPublicAccess(EventSetup, GraceDbApiTestBase):
def test_event_list(self):
"""Unauthenticated user can't access event list"""
url = v_reverse('events:event-list')
methods = ["GET", "POST"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_detail(self):
"""Unauthenticated user can't access event detail"""
url = v_reverse('events:event-detail', args=['G123456'])
methods = ["GET", "PUT"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_log_list(self):
"""Unauthenticated user can't access event log list"""
url = v_reverse('events:eventlog-list', args=['G123456'])
methods = ["GET", "POST"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_log_detail(self):
"""Unauthenticated user can't access event log detail"""
url = v_reverse('events:eventlog-detail', args=['G123456', 1])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_voevent_list(self):
"""Unauthenticated user can't access event VOEvent list"""
url = v_reverse('events:voevent-list', args=['G123456'])
methods = ["GET", "POST"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_voevent_detail(self):
"""Unauthenticated user can't access event VOEvent detail"""
url = v_reverse('events:voevent-detail', args=['G123456', 1])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_embbeventlog_list(self):
"""Unauthenticated user can't access event EMBBEventLog list"""
url = v_reverse('events:embbeventlog-list', args=['G123456'])
methods = ["GET", "POST"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_embbeventlog_detail(self):
"""Unauthenticated user can't access event EMBBEventLog detail"""
url = v_reverse('events:embbeventlog-detail', args=['G123456', 1])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_emobservation_list(self):
"""Unauthenticated user can't access event EMObservation list"""
url = v_reverse('events:emobservation-list', args=['G123456'])
methods = ["GET", "POST"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_emobservation_detail(self):
"""Unauthenticated user can't access event EMObservation detail"""
url = v_reverse('events:emobservation-detail', args=['G123456', 1])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_tag_list(self):
"""Unauthenticated user can't access event tag list"""
url = v_reverse('events:eventtag-list', args=['G123456'])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_tag_detail(self):
"""Unauthenticated user can't access event tag detail"""
url = v_reverse('events:eventtag-detail', args=['G123456', 'tagname'])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_log_tag_list(self):
"""Unauthenticated user can't access event log tag list"""
url = v_reverse('events:eventlogtag-list', args=['G123456', 1])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_log_tag_detail(self):
"""Unauthenticated user can't access event log tag detail"""
url = v_reverse('events:eventlogtag-detail', args=['G123456',
1, 'tagname'])
methods = ["GET", "PUT", "DELETE"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_permission_list(self):
"""Unauthenticated user can't access event permission list"""
url = v_reverse('events:eventpermission-list', args=['G123456'])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_group_permission_list(self):
"""Unauthenticated user can't access event group permission list"""
url = v_reverse('events:groupeventpermission-list', args=['G123456',
'group_name'])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_group_permission_detail(self):
"""Unauthenticated user can't access event group permission list"""
url = v_reverse('events:groupeventpermission-detail', args=['G123456',
'group_name', 'perm_name'])
methods = ["GET", "PUT", "DELETE"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_files(self):
"""Unauthenticated user can't access event files (list or detail)"""
url = v_reverse('events:files', args=['G123456', 'file_name'])
methods = ["GET", "PUT"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_labels(self):
"""Unauthenticated user can't access event labels (list or detail)"""
url = v_reverse('events:labels', args=['G123456', 'label_name'])
methods = ["GET", "PUT", "DELETE"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_neighbors(self):
"""Unauthenticated user can't access event neighbors list"""
url = v_reverse('events:labels', args=['G123456'])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
def test_event_signoff_list(self):
"""Unauthenticated user can't access event signoff list"""
url = v_reverse('events:labels', args=['G123456'])
methods = ["GET"]
for http_method in methods:
response = self.request_as_user(url, http_method)
self.assertContains(
response,
'Authentication credentials were not provided',
status_code=403
)
try:
from unittest import mock
except ImportError: # python < 3
import mock
import pytest
from rest_framework.exceptions import ValidationError
from ..fields import EventGraceidField
@pytest.mark.parametrize(
"graceid",
[1234, 1.234, (), [], None, True, lambda x: x]
)
def test_bad_types(graceid):
field = EventGraceidField()
err_msg = 'Event graceid must be a string.'
with pytest.raises(ValidationError, match=err_msg):
field.to_internal_value(graceid)
@pytest.mark.parametrize(
"graceid",
['GG', '1234G', 'G.1234', 'G1234z', 'Q1234', 'GH12']
)
def test_graceid_bad_format(graceid):
field = EventGraceidField()
err_msg = 'Not a valid graceid.'
with pytest.raises(ValidationError, match=err_msg):
field.to_internal_value(graceid)
@pytest.mark.parametrize(
"graceid",
['G1234', 'E0001', 'H12', 'M352345', 'T2323', ' T123', 'T123 ', ' T123 ',
'g4567', 't456 ', ' m2398 ', ' e8732']
)
def test_valid_graceids(graceid):
field = EventGraceidField()
# WHY do we have to mock this as 'gracedb.api...'
# instead of just 'api...'??
super_tiv = 'gracedb.api.v1.fields.GenericField.to_internal_value'
with mock.patch(super_tiv) as mock_super_tiv:
field.to_internal_value(graceid)
call_args, _ = mock_super_tiv.call_args
assert mock_super_tiv.call_count == 1
assert len(call_args) == 1
assert call_args[0] == graceid.upper().strip()
try:
from unittest import mock
except ImportError: # python < 3
import mock
import pytest
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Permission
from django.contrib.contenttypes.models import ContentType
from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APIRequestFactory as rf
from events.models import Event, GrbEvent, Group, Pipeline, Search
from ..views import GrbEventPatchView
from ...settings import API_VERSION
UserModel = get_user_model()
###############################################################################
# UTILITIES ###################################################################
###############################################################################
def v_reverse(viewname, *args, **kwargs):
"""Easily customizable versioned API reverse for testing"""
viewname = 'api:{version}:'.format(version=API_VERSION) + viewname
return reverse(viewname, *args, **kwargs)
def create_grbevent(internal_group):
user = UserModel.objects.create(username='grbevent.creator')
grb_search, _ = Search.objects.get_or_create(name='GRB')
grbevent = GrbEvent.objects.create(
submitter=user,
group=Group.objects.create(name='External'),
pipeline=Pipeline.objects.create(name=settings.GRB_PIPELINES[0]),
search=grb_search
)
grbevent.save()
p, _ = Permission.objects.get_or_create(
content_type=ContentType.objects.get_for_model(GrbEvent),
codename='change_grbevent'
)
assign_perm(p, internal_group, grbevent)
return grbevent
###############################################################################
# FIXTURES ####################################################################
###############################################################################
###############################################################################
# TESTS #######################################################################
###############################################################################
@pytest.mark.django_db
def test_access(internal_user, internal_group, standard_plus_grb_user):
# NOTE: standard_plus_grb_user is a parametrized fixture (basically a
# list of three users), so this test will run three times.
# Create a GrbEvent
grbevent = create_grbevent(internal_group)
# Get URL and set up request and view
url = v_reverse("events:update-grbevent", args=[grbevent.graceid])
data = {'redshift': 2}
request = rf().patch(url, data=data)
request.user = standard_plus_grb_user
view = GrbEventPatchView.as_view()
with mock.patch('gracedb.api.v1.events.views.EventAlertIssuer'):
# Process request
response = view(request, grbevent.graceid)
response.render()
# Update grbevent in memory from database
grbevent.refresh_from_db()
if standard_plus_grb_user.username != 'grb.user':
assert response.status_code == 403
assert grbevent.redshift is None
else:
assert response.status_code == 200
assert grbevent.redshift == 2
@pytest.mark.parametrize("data",
[
{'redshift': 2, 't90': 12, 'designation': 'good'},
{'ra': 1, 'dec': 2, 'error_radius': 3},
# FAR should not be updated
{'far': 123, 't90': 15},
]
)
@pytest.mark.django_db
def test_parameter_updates(grb_user, internal_group, data):
grbevent = create_grbevent(internal_group)
grbevent.far = 321
grbevent.save(update_fields=['far'])
# Get URL and set up request and view
url = v_reverse("events:update-grbevent", args=[grbevent.graceid])
request = rf().patch(url, data=data)
request.user = grb_user
view = GrbEventPatchView.as_view()
with mock.patch('gracedb.api.v1.events.views.EventAlertIssuer'):
# Process request
response = view(request, grbevent.graceid)
response.render()
# Update grbevent in memory from database
grbevent.refresh_from_db()
# Check response
assert response.status_code == 200
# Compare parameters
for attr in GrbEventPatchView.updatable_attributes:
grbevent_attr = getattr(grbevent, attr)
if attr in data:
assert grbevent_attr == data.get(attr)
else:
assert grbevent_attr is None
# FAR should not be updated even by requests which include FAR
assert grbevent.far == 321
@pytest.mark.parametrize("data", [{}, {'redshift': 2}])
@pytest.mark.django_db
def test_update_with_no_new_data(grb_user, internal_group, data):
grbevent = create_grbevent(internal_group)
grbevent.redshift = 2
grbevent.save(update_fields=['redshift'])
# Get URL and set up request and view
url = v_reverse("events:update-grbevent", args=[grbevent.graceid])
request = rf().patch(url, data=data)
request.user = grb_user
view = GrbEventPatchView.as_view()
with mock.patch('gracedb.api.v1.events.views.EventAlertIssuer'):
# Process request
response = view(request, grbevent.graceid)
response.render()
# Check response
assert response.status_code == 400
assert 'Request would not modify the GRB event' \
in response.content.decode()
@pytest.mark.parametrize("data",
[
{'redshift': 'random string'},
{'t90': 'random string'},
{'ra': 'random string'},
{'dec': 'random string'},
{'error_radius': 'random string'},
]
)
@pytest.mark.django_db
def test_update_with_bad_data(grb_user, internal_group, data):
grbevent = create_grbevent(internal_group)
# Get URL and set up request and view
url = v_reverse("events:update-grbevent", args=[grbevent.graceid])
request = rf().patch(url, data=data)
request.user = grb_user
view = GrbEventPatchView.as_view()
with mock.patch('gracedb.api.v1.events.views.EventAlertIssuer'):
# Process request
response = view(request, grbevent.graceid)
response.render()
# Check response
assert response.status_code == 400
assert 'must be a float' in response.content.decode()
@pytest.mark.django_db
def test_update_non_grbevent(grb_user, internal_group):
event = Event.objects.create(
submitter=grb_user,
group=Group.objects.create(name='External'),
pipeline=Pipeline.objects.create(name='other_pipeline'),
)
event.save()
p, _ = Permission.objects.get_or_create(
content_type=ContentType.objects.get_for_model(Event),
codename='change_event'
)
assign_perm(p, internal_group, event)
# Get URL and set up request and view
url = v_reverse("events:update-grbevent", args=[event.graceid])
request = rf().patch(url, data={'redshift': 2})
request.user = grb_user
view = GrbEventPatchView.as_view()
with mock.patch('gracedb.api.v1.events.views.EventAlertIssuer'):
# Process request
response = view(request, event.graceid)
response.render()
# Check response
assert response.status_code == 400
assert 'Cannot update GRB event parameters for non-GRB event' \
in response.content.decode()
from api.throttling import PostOrPutUserRateThrottle
class EventCreationThrottle(PostOrPutUserRateThrottle):
scope = 'event_creation'
class AnnotationThrottle(PostOrPutUserRateThrottle):
scope = 'annotation'