Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lscsoft
bilby
Commits
4af071b1
Commit
4af071b1
authored
3 years ago
by
Michael Williams
Committed by
Colm Talbot
3 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Update for nessai v0.4.0
parent
6cf9a859
No related branches found
No related tags found
1 merge request
!1042
Update for nessai v0.4.0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bilby/core/sampler/nessai.py
+38
-50
38 additions, 50 deletions
bilby/core/sampler/nessai.py
test/core/sampler/nessai_test.py
+9
-54
9 additions, 54 deletions
test/core/sampler/nessai_test.py
with
47 additions
and
104 deletions
bilby/core/sampler/nessai.py
+
38
−
50
View file @
4af071b1
import
numpy
as
np
import
os
from
pandas
import
DataFrame
from
.base_sampler
import
NestedSampler
...
...
@@ -15,55 +16,42 @@ class Nessai(NestedSampler):
Documentation: https://nessai.readthedocs.io/
"""
default_kwargs
=
dict
(
output
=
None
,
nlive
=
1000
,
stopping
=
0.1
,
resume
=
True
,
max_iteration
=
None
,
checkpointing
=
True
,
seed
=
1234
,
acceptance_threshold
=
0.01
,
analytic_priors
=
False
,
maximum_uninformed
=
1000
,
uninformed_proposal
=
None
,
uninformed_proposal_kwargs
=
None
,
flow_class
=
None
,
flow_config
=
None
,
training_frequency
=
None
,
reset_weights
=
False
,
reset_permutations
=
False
,
reset_acceptance
=
False
,
train_on_empty
=
True
,
cooldown
=
100
,
memory
=
False
,
poolsize
=
None
,
drawsize
=
None
,
max_poolsize_scale
=
10
,
update_poolsize
=
False
,
latent_prior
=
'
truncated_gaussian
'
,
draw_latent_kwargs
=
None
,
compute_radius_with_all
=
False
,
min_radius
=
False
,
max_radius
=
50
,
check_acceptance
=
False
,
fuzz
=
1.0
,
expansion_fraction
=
1.0
,
rescale_parameters
=
True
,
rescale_bounds
=
[
-
1
,
1
],
update_bounds
=
False
,
boundary_inversion
=
False
,
inversion_type
=
'
split
'
,
detect_edges
=
False
,
detect_edges_kwargs
=
None
,
reparameterisations
=
None
,
n_pool
=
None
,
max_threads
=
1
,
pytorch_threads
=
None
,
plot
=
None
,
proposal_plots
=
False
)
_default_kwargs
=
None
seed_equiv_kwargs
=
[
'
sampling_seed
'
]
@property
def
default_kwargs
(
self
):
"""
Default kwargs for nessai.
Retrieves default values from nessai directly and then includes any
bilby specific defaults. This avoids the need to update bilby when the
defaults change or new kwargs are added to nessai.
"""
if
not
self
.
_default_kwargs
:
from
inspect
import
signature
from
nessai.flowsampler
import
FlowSampler
from
nessai.nestedsampler
import
NestedSampler
from
nessai.proposal
import
AugmentedFlowProposal
,
FlowProposal
kwargs
=
{}
classes
=
[
AugmentedFlowProposal
,
FlowProposal
,
NestedSampler
,
FlowSampler
,
]
for
c
in
classes
:
kwargs
.
update
(
{
k
:
v
.
default
for
k
,
v
in
signature
(
c
).
parameters
.
items
()
if
v
.
default
is
not
v
.
empty
}
)
# Defaults for bilby that will override nessai defaults
bilby_defaults
=
dict
(
output
=
None
,
)
kwargs
.
update
(
bilby_defaults
)
self
.
_default_kwargs
=
kwargs
return
self
.
_default_kwargs
def
log_prior
(
self
,
theta
):
"""
...
...
@@ -194,9 +182,9 @@ class Nessai(NestedSampler):
self
.
kwargs
[
'
n_pool
'
]
=
None
if
not
self
.
kwargs
[
'
output
'
]:
self
.
kwargs
[
'
output
'
]
=
self
.
outdir
+
'
/{}_nessai/
'
.
format
(
self
.
label
)
if
self
.
kwargs
[
'
output
'
].
endswith
(
'
/
'
)
is
False
:
self
.
kwargs
[
'
output
'
]
=
'
{}/
'
.
format
(
self
.
kwargs
[
'
output
'
]
)
self
.
kwargs
[
'
output
'
]
=
os
.
path
.
join
(
self
.
outdir
,
'
{}_nessai
'
.
format
(
self
.
label
),
''
)
check_directory_exists_and_if_not_mkdir
(
self
.
kwargs
[
'
output
'
])
NestedSampler
.
_verify_kwargs_against_default_kwargs
(
self
)
This diff is collapsed.
Click to expand it.
test/core/sampler/nessai_test.py
+
9
−
54
View file @
4af071b1
...
...
@@ -22,53 +22,8 @@ class TestNessai(unittest.TestCase):
plot
=
False
,
skip_import_verification
=
True
,
)
self
.
expected
=
dict
(
output
=
"
outdir/label_nessai/
"
,
nlive
=
1000
,
stopping
=
0.1
,
resume
=
True
,
max_iteration
=
None
,
checkpointing
=
True
,
seed
=
1234
,
acceptance_threshold
=
0.01
,
analytic_priors
=
False
,
maximum_uninformed
=
1000
,
uninformed_proposal
=
None
,
uninformed_proposal_kwargs
=
None
,
flow_class
=
None
,
flow_config
=
None
,
training_frequency
=
None
,
reset_weights
=
False
,
reset_permutations
=
False
,
reset_acceptance
=
False
,
train_on_empty
=
True
,
cooldown
=
100
,
memory
=
False
,
poolsize
=
None
,
drawsize
=
None
,
max_poolsize_scale
=
10
,
update_poolsize
=
False
,
latent_prior
=
'
truncated_gaussian
'
,
draw_latent_kwargs
=
None
,
compute_radius_with_all
=
False
,
min_radius
=
False
,
max_radius
=
50
,
check_acceptance
=
False
,
fuzz
=
1.0
,
expansion_fraction
=
1.0
,
rescale_parameters
=
True
,
rescale_bounds
=
[
-
1
,
1
],
update_bounds
=
False
,
boundary_inversion
=
False
,
inversion_type
=
'
split
'
,
detect_edges
=
False
,
detect_edges_kwargs
=
None
,
reparameterisations
=
None
,
n_pool
=
None
,
max_threads
=
1
,
pytorch_threads
=
None
,
plot
=
False
,
proposal_plots
=
False
)
self
.
expected
=
self
.
sampler
.
default_kwargs
self
.
expected
[
'
output
'
]
=
'
outdir/label_nessai/
'
def
tearDown
(
self
):
del
self
.
likelihood
...
...
@@ -76,16 +31,16 @@ class TestNessai(unittest.TestCase):
del
self
.
sampler
del
self
.
expected
def
test_default_kwargs
(
self
):
expected
=
self
.
expected
.
copy
()
self
.
assertDictEqual
(
expected
,
self
.
sampler
.
kwargs
)
def
test_translate_kwargs_nlive
(
self
):
expected
=
self
.
expected
.
copy
()
# nlive in the default kwargs is not a fixed but depends on the
# version of nessai, so get the value here and use it when setting
# the equivalent kwargs.
nlive
=
expected
[
"
nlive
"
]
for
equiv
in
bilby
.
core
.
sampler
.
base_sampler
.
NestedSampler
.
npoints_equiv_kwargs
:
new_kwargs
=
self
.
sampler
.
kwargs
.
copy
()
del
new_kwargs
[
"
nlive
"
]
new_kwargs
[
equiv
]
=
1000
new_kwargs
[
equiv
]
=
nlive
self
.
sampler
.
kwargs
=
new_kwargs
self
.
assertDictEqual
(
expected
,
self
.
sampler
.
kwargs
)
...
...
@@ -117,10 +72,10 @@ class TestNessai(unittest.TestCase):
self
.
sampler
.
kwargs
=
new_kwargs
self
.
assertDictEqual
(
expected
,
self
.
sampler
.
kwargs
)
@patch
(
"
builtins.open
"
,
mock_open
(
read_data
=
'
{
"
nlive
"
:
2
000}
'
))
@patch
(
"
builtins.open
"
,
mock_open
(
read_data
=
'
{
"
nlive
"
:
4
000}
'
))
def
test_update_from_config_file
(
self
):
expected
=
self
.
expected
.
copy
()
expected
[
"
nlive
"
]
=
2
000
expected
[
"
nlive
"
]
=
4
000
new_kwargs
=
self
.
expected
.
copy
()
new_kwargs
[
"
config_file
"
]
=
"
config_file.json
"
self
.
sampler
.
kwargs
=
new_kwargs
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment