Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lscsoft
bilby
Commits
09ea2cb3
Commit
09ea2cb3
authored
3 years ago
by
Colm Talbot
Browse files
Options
Downloads
Patches
Plain Diff
Allow user to provide `variable_names` in hyper Model
parent
48d7685e
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!1069
Allow user to provide `variable_names` in hyper Model
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bilby/hyper/model.py
+38
-9
38 additions, 9 deletions
bilby/hyper/model.py
test/hyper/hyper_pe_test.py
+31
-0
31 additions, 0 deletions
test/hyper/hyper_pe_test.py
with
69 additions
and
9 deletions
bilby/hyper/model.py
+
38
−
9
View file @
09ea2cb3
...
...
@@ -2,10 +2,14 @@ from ..core.utils import infer_args_from_function_except_n_args
class
Model
(
object
):
"""
Population model
r
"""
Population model
that combines a set of factorizable models.
This should take population parameters and return the probability.
.. math::
p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
"""
def
__init__
(
self
,
model_functions
=
None
):
...
...
@@ -13,7 +17,11 @@ class Model(object):
Parameters
==========
model_functions: list
List of functions to compute.
List of callables to compute the probability.
If this includes classes, the `__call__` method should return the
probability.
The requires variables are chosen at run time based on either
inspection or querying a :code:`variable_names` attribute.
"""
self
.
models
=
model_functions
self
.
_cached_parameters
=
{
model
:
None
for
model
in
self
.
models
}
...
...
@@ -22,6 +30,21 @@ class Model(object):
self
.
parameters
=
dict
()
def
prob
(
self
,
data
,
**
kwargs
):
"""
Compute the total population probability for the provided data given
the keyword arguments.
Parameters
==========
data: dict
Dictionary containing the points at which to evaluate the
population model.
kwargs: dict
The population parameters. These cannot include any of
:code:`[
"
dataset
"
,
"
data
"
,
"
self
"
,
"
cls
"
]` unless the
:code:`variable_names` attribute is available for the relevant
model.
"""
probability
=
1.0
for
ii
,
function
in
enumerate
(
self
.
models
):
function_parameters
=
self
.
_get_function_parameters
(
function
)
...
...
@@ -37,11 +60,17 @@ class Model(object):
return
probability
def
_get_function_parameters
(
self
,
func
):
"""
If the function is a class method we need to remove more arguments
"""
param_keys
=
infer_args_from_function_except_n_args
(
func
,
n
=
0
)
ignore
=
[
'
dataset
'
,
'
self
'
,
'
cls
'
]
for
key
in
ignore
:
if
key
in
param_keys
:
del
param_keys
[
param_keys
.
index
(
key
)]
"""
If the function is a class method we need to remove more arguments or
have the variable names provided in the class.
"""
if
hasattr
(
func
,
"
variable_names
"
):
param_keys
=
func
.
variable_names
else
:
param_keys
=
infer_args_from_function_except_n_args
(
func
,
n
=
0
)
ignore
=
[
"
dataset
"
,
"
data
"
,
"
self
"
,
"
cls
"
]
for
key
in
ignore
:
if
key
in
param_keys
:
del
param_keys
[
param_keys
.
index
(
key
)]
parameters
=
{
key
:
self
.
parameters
[
key
]
for
key
in
param_keys
}
return
parameters
This diff is collapsed.
Click to expand it.
test/hyper/hyper_pe_test.py
+
31
−
0
View file @
09ea2cb3
import
unittest
import
numpy
as
np
import
pandas
as
pd
from
parameterized
import
parameterized
import
bilby.hyper
as
hyp
def
_toy_function
(
data
,
dataset
,
self
,
cls
,
a
,
b
,
c
):
return
a
class
_ToyClassNoVariableNames
:
def
__call__
(
self
,
a
,
b
,
c
):
return
a
class
_ToyClassVariableNames
:
variable_names
=
[
"
a
"
,
"
b
"
,
"
c
"
]
def
__call__
(
self
,
**
kwargs
):
return
kwargs
.
get
(
"
a
"
,
1
)
class
TestHyperLikelihood
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
keys
=
[
"
a
"
,
"
b
"
,
"
c
"
]
...
...
@@ -38,6 +57,18 @@ class TestHyperLikelihood(unittest.TestCase):
)
self
.
assertTrue
(
np
.
isnan
(
like
.
evidence_factor
))
@parameterized.expand
([
(
"
func
"
,
_toy_function
),
(
"
class_no_names
"
,
_ToyClassNoVariableNames
()),
(
"
class_with_names
"
,
_ToyClassVariableNames
()),
])
def
test_get_function_parameters
(
self
,
_
,
model
):
expected
=
dict
(
a
=
1
,
b
=
2
,
c
=
3
)
model
=
hyp
.
model
.
Model
([
model
])
model
.
parameters
.
update
(
expected
)
result
=
model
.
_get_function_parameters
(
model
.
models
[
0
])
self
.
assertDictEqual
(
expected
,
result
)
def
test_len_samples_with_max_samples
(
self
):
like
=
hyp
.
likelihood
.
HyperparameterLikelihood
(
self
.
posteriors
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment