Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lscsoft
bilby
Commits
f67b4961
Commit
f67b4961
authored
5 years ago
by
Moritz
Browse files
Options
Downloads
Patches
Plain Diff
Some modifications to ConditionalPriorDict after testing
parent
4adcd871
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!332
Resolve "Introduce conditional prior sets"
Pipeline
#84110
failed
5 years ago
Stage: test
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bilby/core/prior.py
+11
-9
11 additions, 9 deletions
bilby/core/prior.py
test/prior_test.py
+144
-4
144 additions, 4 deletions
test/prior_test.py
with
155 additions
and
13 deletions
bilby/core/prior.py
+
11
−
9
View file @
f67b4961
...
...
@@ -501,10 +501,10 @@ class ConditionalPriorDict(PriorDict):
self
.
_rescale_indexes
=
[]
self
.
_least_recently_rescaled_keys
=
[]
super
(
ConditionalPriorDict
,
self
).
__init__
(
dictionary
=
dictionary
,
filename
=
filename
)
self
.
resolved
=
False
self
.
_
resolved
=
False
self
.
_resolve_conditions
()
def
_resolve_conditions
(
self
,
disable_log
=
False
):
def
_resolve_conditions
(
self
):
"""
Resolves how variables depend on each other and automatically sorts them into the right order
"""
conditioned_keys_unsorted
=
[
key
for
key
in
self
.
keys
()
if
hasattr
(
self
[
key
],
'
condition_func
'
)]
self
.
_unconditional_keys
=
[
key
for
key
in
self
.
keys
()
if
not
hasattr
(
self
[
key
],
'
condition_func
'
)]
...
...
@@ -519,12 +519,10 @@ class ConditionalPriorDict(PriorDict):
self
.
_sorted_keys
=
self
.
_unconditional_keys
.
copy
()
self
.
_sorted_keys
.
extend
(
self
.
conditional_keys
)
self
.
resolved
=
True
self
.
_
resolved
=
True
if
len
(
conditioned_keys_unsorted
)
!=
0
:
self
.
resolved
=
False
if
not
disable_log
:
logger
.
warning
(
'
This set contains unresolvable conditions
'
)
self
.
_resolved
=
False
def
_check_conditions_resolved
(
self
,
key
,
sampled_keys
):
"""
Checks if all required variables have already been sampled so we can sample this key
"""
...
...
@@ -537,7 +535,7 @@ class ConditionalPriorDict(PriorDict):
def
sample_subset
(
self
,
keys
=
iter
([]),
size
=
None
):
self
.
convert_floats_to_delta_functions
()
subset_dict
=
ConditionalPriorDict
({
key
:
self
[
key
]
for
key
in
keys
})
if
not
subset_dict
.
resolved
:
if
not
subset_dict
.
_
resolved
:
raise
IllegalConditionsException
(
"
The current set of priors contains unresolveable conditions.
"
)
res
=
dict
()
for
key
in
subset_dict
.
sorted_keys
:
...
...
@@ -634,7 +632,7 @@ class ConditionalPriorDict(PriorDict):
self
.
_rescale_indexes
=
np
.
append
(
unconditional_idxs
,
conditional_idxs
)
def
_check_resolved
(
self
):
if
not
self
.
resolved
:
if
not
self
.
_
resolved
:
raise
IllegalConditionsException
(
"
The current set of priors contains unresolveable conditions.
"
)
@property
...
...
@@ -651,7 +649,11 @@ class ConditionalPriorDict(PriorDict):
def
__setitem__
(
self
,
key
,
value
):
super
(
ConditionalPriorDict
,
self
).
__setitem__
(
key
,
value
)
self
.
_resolve_conditions
(
disable_log
=
True
)
self
.
_resolve_conditions
()
def
__delitem__
(
self
,
key
):
super
(
ConditionalPriorDict
,
self
).
__delitem__
(
key
)
self
.
_resolve_conditions
()
def
create_default_prior
(
name
,
default_priors_file
=
None
):
...
...
This diff is collapsed.
Click to expand it.
test/prior_test.py
+
144
−
4
View file @
f67b4961
...
...
@@ -294,6 +294,11 @@ class TestPriorClasses(unittest.TestCase):
outside_domain
=
np
.
linspace
(
prior
.
minimum
-
1e4
,
prior
.
minimum
-
1
,
1000
)
self
.
assertTrue
(
all
(
prior
.
prob
(
outside_domain
)
==
0
))
def
test_least_recently_sampled
(
self
):
for
prior
in
self
.
priors
:
lrs
=
prior
.
sample
()
self
.
assertEqual
(
lrs
,
prior
.
least_recently_sampled
)
def
test_prob_and_ln_prob
(
self
):
for
prior
in
self
.
priors
:
sample
=
prior
.
sample
()
...
...
@@ -852,7 +857,7 @@ class TestConditionalPrior(unittest.TestCase):
def
setUp
(
self
):
self
.
condition_func_call_counter
=
0
def
condition_func
(
reference_parameters
,
test_
p
ar
ameter
_1
,
test_
p
ar
ameter
_2
):
def
condition_func
(
reference_parameters
,
test_
v
ar
iable
_1
,
test_
v
ar
iable
_2
):
self
.
condition_func_call_counter
+=
1
return
{
key
:
value
+
1
for
key
,
value
in
reference_parameters
.
items
()}
self
.
condition_func
=
condition_func
...
...
@@ -860,14 +865,17 @@ class TestConditionalPrior(unittest.TestCase):
self
.
maximum
=
5
self
.
test_parameter_1
=
0
self
.
test_parameter_2
=
1
self
.
prior
=
bilby
.
core
.
prior
.
ConditionalPrior
(
condition_func
=
condition_func
,
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
self
.
prior
=
bilby
.
core
.
prior
.
Conditional
Base
Prior
(
condition_func
=
condition_func
,
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
def
tearDown
(
self
):
del
self
.
condition_func
del
self
.
condition_func_call_counter
del
self
.
minimum
del
self
.
maximum
del
self
.
test_parameter_1
del
self
.
test_parameter_2
del
self
.
prior
def
test_reference_params
(
self
):
...
...
@@ -876,6 +884,11 @@ class TestConditionalPrior(unittest.TestCase):
def
test_required_variables
(
self
):
self
.
assertListEqual
([
'
test_parameter_1
'
,
'
test_parameter_2
'
],
sorted
(
self
.
prior
.
required_variables
))
def
test_required_variables_no_condition_func
(
self
):
self
.
prior
=
bilby
.
core
.
prior
.
ConditionalBasePrior
(
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
self
.
assertListEqual
([],
self
.
prior
.
required_variables
)
def
test_get_instantiation_dict
(
self
):
expected
=
dict
(
minimum
=
0
,
maximum
=
5
,
name
=
None
,
latex_label
=
None
,
unit
=
None
,
boundary
=
None
,
condition_func
=
self
.
condition_func
)
...
...
@@ -937,6 +950,133 @@ class TestConditionalPrior(unittest.TestCase):
self
.
assertEqual
(
self
.
prior
.
reference_params
[
'
minimum
'
],
self
.
prior
.
minimum
)
self
.
assertEqual
(
self
.
prior
.
reference_params
[
'
maximum
'
],
self
.
prior
.
maximum
)
def
test_cond_prior_instantiation_no_boundary_prior
(
self
):
prior
=
bilby
.
core
.
prior
.
ConditionalFermiDirac
(
sigma
=
1
)
self
.
assertIsNone
(
prior
.
boundary
)
class
TestConditionalPriorDict
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
condition_func_1
(
reference_parameters
,
var_0
):
return
reference_parameters
def
condition_func_2
(
reference_parameters
,
var_0
,
var_1
):
return
reference_parameters
def
condition_func_3
(
reference_parameters
,
var_1
,
var_2
):
return
reference_parameters
self
.
minimum
=
0
self
.
maximum
=
1
self
.
prior_0
=
bilby
.
core
.
prior
.
Uniform
(
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
self
.
prior_1
=
bilby
.
core
.
prior
.
ConditionalUniform
(
condition_func
=
condition_func_1
,
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
self
.
prior_2
=
bilby
.
core
.
prior
.
ConditionalUniform
(
condition_func
=
condition_func_2
,
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
self
.
prior_3
=
bilby
.
core
.
prior
.
ConditionalUniform
(
condition_func
=
condition_func_3
,
minimum
=
self
.
minimum
,
maximum
=
self
.
maximum
)
self
.
conditional_priors
=
bilby
.
core
.
prior
.
ConditionalPriorDict
(
dict
(
var_3
=
self
.
prior_3
,
var_2
=
self
.
prior_2
,
var_0
=
self
.
prior_0
,
var_1
=
self
.
prior_1
))
self
.
conditional_priors_manually_set_items
=
bilby
.
core
.
prior
.
ConditionalPriorDict
()
self
.
test_sample
=
dict
(
var_0
=
0.3
,
var_1
=
0.4
,
var_2
=
0.5
,
var_3
=
0.4
)
for
key
,
value
in
dict
(
var_0
=
self
.
prior_0
,
var_1
=
self
.
prior_1
,
var_2
=
self
.
prior_2
,
var_3
=
self
.
prior_3
).
items
():
self
.
conditional_priors_manually_set_items
[
key
]
=
value
def
tearDown
(
self
):
del
self
.
minimum
del
self
.
maximum
del
self
.
prior_0
del
self
.
prior_1
del
self
.
prior_2
del
self
.
prior_3
del
self
.
conditional_priors
del
self
.
conditional_priors_manually_set_items
del
self
.
test_sample
def
test_conditions_resolved_upon_instantiation
(
self
):
self
.
assertListEqual
([
'
var_0
'
,
'
var_1
'
,
'
var_2
'
,
'
var_3
'
],
self
.
conditional_priors
.
sorted_keys
)
def
test_conditions_resolved_setting_items
(
self
):
self
.
assertListEqual
([
'
var_0
'
,
'
var_1
'
,
'
var_2
'
,
'
var_3
'
],
self
.
conditional_priors_manually_set_items
.
sorted_keys
)
def
test_unconditional_keys_upon_instantiation
(
self
):
self
.
assertListEqual
([
'
var_0
'
],
self
.
conditional_priors
.
unconditional_keys
)
def
test_unconditional_keys_setting_items
(
self
):
self
.
assertListEqual
([
'
var_0
'
],
self
.
conditional_priors_manually_set_items
.
unconditional_keys
)
def
test_conditional_keys_upon_instantiation
(
self
):
self
.
assertListEqual
([
'
var_1
'
,
'
var_2
'
,
'
var_3
'
],
self
.
conditional_priors
.
conditional_keys
)
def
test_conditional_keys_setting_items
(
self
):
self
.
assertListEqual
([
'
var_1
'
,
'
var_2
'
,
'
var_3
'
],
self
.
conditional_priors_manually_set_items
.
conditional_keys
)
def
test_prob
(
self
):
self
.
assertEqual
(
1
,
self
.
conditional_priors
.
prob
(
sample
=
self
.
test_sample
))
def
test_prob_illegal_conditions
(
self
):
del
self
.
conditional_priors
[
'
var_0
'
]
with
self
.
assertRaises
(
bilby
.
core
.
prior
.
IllegalConditionsException
):
self
.
conditional_priors
.
prob
(
sample
=
self
.
test_sample
)
def
test_ln_prob
(
self
):
self
.
assertEqual
(
0
,
self
.
conditional_priors
.
ln_prob
(
sample
=
self
.
test_sample
))
def
test_ln_prob_illegal_conditions
(
self
):
del
self
.
conditional_priors
[
'
var_0
'
]
with
self
.
assertRaises
(
bilby
.
core
.
prior
.
IllegalConditionsException
):
self
.
conditional_priors
.
ln_prob
(
sample
=
self
.
test_sample
)
def
test_sample_subset_all_keys
(
self
):
with
mock
.
patch
(
"
numpy.random.uniform
"
)
as
m
:
m
.
return_value
=
0.5
self
.
assertDictEqual
(
dict
(
var_0
=
0.5
,
var_1
=
0.5
,
var_2
=
0.5
,
var_3
=
0.5
),
self
.
conditional_priors
.
sample_subset
(
keys
=
[
'
var_0
'
,
'
var_1
'
,
'
var_2
'
,
'
var_3
'
]))
def
test_sample_illegal_subset
(
self
):
with
mock
.
patch
(
"
numpy.random.uniform
"
)
as
m
:
m
.
return_value
=
0.5
with
self
.
assertRaises
(
bilby
.
core
.
prior
.
IllegalConditionsException
):
self
.
conditional_priors
.
sample_subset
(
keys
=
[
'
var_1
'
])
def
test_rescale
(
self
):
def
condition_func_1_rescale
(
reference_parameters
,
var_0
):
if
var_0
==
0.5
:
return
dict
(
minimum
=
reference_parameters
[
'
minimum
'
],
maximum
=
1
)
return
reference_parameters
def
condition_func_2_rescale
(
reference_parameters
,
var_0
,
var_1
):
if
var_0
==
0.5
and
var_1
==
0.5
:
return
dict
(
minimum
=
reference_parameters
[
'
minimum
'
],
maximum
=
1
)
return
reference_parameters
def
condition_func_3_rescale
(
reference_parameters
,
var_1
,
var_2
):
if
var_1
==
0.5
and
var_2
==
0.5
:
return
dict
(
minimum
=
reference_parameters
[
'
minimum
'
],
maximum
=
1
)
return
reference_parameters
self
.
prior_0
=
bilby
.
core
.
prior
.
Uniform
(
minimum
=
self
.
minimum
,
maximum
=
1
)
self
.
prior_1
=
bilby
.
core
.
prior
.
ConditionalUniform
(
condition_func
=
condition_func_1_rescale
,
minimum
=
self
.
minimum
,
maximum
=
2
)
self
.
prior_2
=
bilby
.
core
.
prior
.
ConditionalUniform
(
condition_func
=
condition_func_2_rescale
,
minimum
=
self
.
minimum
,
maximum
=
2
)
self
.
prior_3
=
bilby
.
core
.
prior
.
ConditionalUniform
(
condition_func
=
condition_func_3_rescale
,
minimum
=
self
.
minimum
,
maximum
=
2
)
self
.
conditional_priors
=
bilby
.
core
.
prior
.
ConditionalPriorDict
(
dict
(
var_3
=
self
.
prior_3
,
var_2
=
self
.
prior_2
,
var_0
=
self
.
prior_0
,
var_1
=
self
.
prior_1
))
ref_variables
=
[
0.5
,
0.5
,
0.5
,
0.5
]
res
=
self
.
conditional_priors
.
rescale
(
keys
=
list
(
self
.
test_sample
.
keys
()),
theta
=
ref_variables
)
self
.
assertListEqual
(
ref_variables
,
res
)
def
test_rescale_illegal_conditions
(
self
):
del
self
.
conditional_priors
[
'
var_0
'
]
with
self
.
assertRaises
(
bilby
.
core
.
prior
.
IllegalConditionsException
):
self
.
conditional_priors
.
rescale
(
keys
=
list
(
self
.
test_sample
.
keys
()),
theta
=
list
(
self
.
test_sample
.
values
()))
class
TestJsonIO
(
unittest
.
TestCase
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment