Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lscsoft
bilby
Commits
61c10675
Commit
61c10675
authored
4 years ago
by
Gregory Ashton
Browse files
Options
Downloads
Plain Diff
Merge branch 'master' into 492-minimum-python-version-should-be-3-6
parents
93771200
4893cd83
No related branches found
No related tags found
1 merge request
!811
Revert f-strings to keep python 3.5 compatibility
Pipeline
#135009
passed
4 years ago
Stage: test
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bilby/core/sampler/pymc3.py
+59
-9
59 additions, 9 deletions
bilby/core/sampler/pymc3.py
test/sampler_test.py
+8
-4
8 additions, 4 deletions
test/sampler_test.py
with
67 additions
and
13 deletions
bilby/core/sampler/pymc3.py
+
59
−
9
View file @
61c10675
from
__future__
import
absolute_import
,
print_function
from
collections
import
OrderedDict
from
distutils.version
import
StrictVersion
import
numpy
as
np
...
...
@@ -45,8 +46,6 @@ class Pymc3(MCMCSampler):
'
CategoricalGibbsMetropolis
'
. Note: you cannot provide a PyMC3 step
method function itself here as it is outside of the model context
manager.
nuts_kwargs: dict
Keyword arguments for the NUTS sampler.
step_kwargs: dict
Options for steps methods other than NUTS. The dictionary is keyed on
lowercase step method names with values being dictionaries of keywords
...
...
@@ -56,13 +55,27 @@ class Pymc3(MCMCSampler):
default_kwargs
=
dict
(
draws
=
500
,
step
=
None
,
init
=
'
auto
'
,
n_init
=
200000
,
start
=
None
,
trace
=
None
,
chain_idx
=
0
,
chains
=
2
,
cores
=
1
,
tune
=
500
,
nuts_kwargs
=
None
,
step_kwargs
=
None
,
progressbar
=
True
,
model
=
None
,
random_seed
=
None
,
discard_tuned_samples
=
True
,
compute_convergence_checks
=
True
)
chains
=
2
,
cores
=
1
,
tune
=
500
,
progressbar
=
True
,
model
=
None
,
random_seed
=
None
,
discard_tuned_samples
=
True
,
compute_convergence_checks
=
True
,
nuts_kwargs
=
None
,
step_kwargs
=
None
,
)
default_nuts_kwargs
=
dict
(
target_accept
=
None
,
max_treedepth
=
None
,
step_scale
=
None
,
Emax
=
None
,
gamma
=
None
,
k
=
None
,
t0
=
None
,
adapt_step_size
=
None
,
early_max_treedepth
=
None
,
scaling
=
None
,
is_cov
=
None
,
potential
=
None
,
)
default_kwargs
.
update
(
default_nuts_kwargs
)
def
__init__
(
self
,
likelihood
,
priors
,
outdir
=
'
outdir
'
,
label
=
'
label
'
,
use_ratio
=
False
,
plot
=
False
,
skip_import_verification
=
False
,
**
kwargs
):
# add default step kwargs
_
,
STEP_METHODS
,
_
=
self
.
_import_external_sampler
()
self
.
default_step_kwargs
=
{
m
.
__name__
.
lower
():
None
for
m
in
STEP_METHODS
}
self
.
default_kwargs
.
update
(
self
.
default_step_kwargs
)
super
(
Pymc3
,
self
).
__init__
(
likelihood
=
likelihood
,
priors
=
priors
,
outdir
=
outdir
,
label
=
label
,
use_ratio
=
use_ratio
,
plot
=
plot
,
skip_import_verification
=
skip_import_verification
,
**
kwargs
)
...
...
@@ -454,8 +467,35 @@ class Pymc3(MCMCSampler):
self
.
set_likelihood
()
# get the step method keyword arguments
step_kwargs
=
self
.
kwargs
.
pop
(
'
step_kwargs
'
)
nuts_kwargs
=
self
.
kwargs
.
pop
(
'
nuts_kwargs
'
)
step_kwargs
=
self
.
kwargs
.
pop
(
"
step_kwargs
"
)
if
step_kwargs
is
not
None
:
# remove all individual default step kwargs if passed together using
# step_kwargs keywords
for
key
in
self
.
default_step_kwargs
:
self
.
kwargs
.
pop
(
key
)
else
:
# remove any None default step keywords and place others in step_kwargs
step_kwargs
=
{}
for
key
in
self
.
default_step_kwargs
:
if
self
.
kwargs
[
key
]
is
None
:
self
.
kwargs
.
pop
(
key
)
else
:
step_kwargs
[
key
]
=
self
.
kwargs
.
pop
(
key
)
nuts_kwargs
=
self
.
kwargs
.
pop
(
"
nuts_kwargs
"
)
if
nuts_kwargs
is
not
None
:
# remove all individual default nuts kwargs if passed together using
# nuts_kwargs keywords
for
key
in
self
.
default_nuts_kwargs
:
self
.
kwargs
.
pop
(
key
)
else
:
# remove any None default nuts keywords and place others in nut_kwargs
nuts_kwargs
=
{}
for
key
in
self
.
default_nuts_kwargs
:
if
self
.
kwargs
[
key
]
is
None
:
self
.
kwargs
.
pop
(
key
)
else
:
nuts_kwargs
[
key
]
=
self
.
kwargs
.
pop
(
key
)
methodslist
=
[]
# set the step method
...
...
@@ -496,13 +536,19 @@ class Pymc3(MCMCSampler):
self
.
kwargs
[
'
step
'
]
=
pymc3
.
__dict__
[
step_methods
[
curmethod
]](
**
args
)
else
:
# re-add step_kwargs if no step methods are set
self
.
kwargs
[
'
step_kwargs
'
]
=
step_kwargs
if
len
(
step_kwargs
)
>
0
and
StrictVersion
(
pymc3
.
__version__
)
<
StrictVersion
(
"
3.7
"
):
self
.
kwargs
[
'
step_kwargs
'
]
=
step_kwargs
# check whether only NUTS step method has been assigned
if
np
.
all
([
sm
.
lower
()
==
'
nuts
'
for
sm
in
methodslist
]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self
.
kwargs
[
'
step
'
]
=
None
self
.
kwargs
[
'
nuts_kwargs
'
]
=
nuts_kwargs
if
len
(
nuts_kwargs
)
>
0
and
StrictVersion
(
pymc3
.
__version__
)
<
StrictVersion
(
"
3.7
"
):
self
.
kwargs
[
'
nuts_kwargs
'
]
=
nuts_kwargs
elif
len
(
nuts_kwargs
)
>
0
:
# add NUTS kwargs to standard kwargs
self
.
kwargs
.
update
(
nuts_kwargs
)
with
self
.
pymc3_model
:
# perform the sampling
...
...
@@ -561,6 +607,10 @@ class Pymc3(MCMCSampler):
args
=
{}
return
args
,
nuts_kwargs
def
_pymc3_version
(
self
):
pymc3
,
_
,
_
=
self
.
_import_external_sampler
()
return
pymc3
.
__version__
def
set_prior
(
self
):
"""
Set the PyMC3 prior distributions.
...
...
This diff is collapsed.
Click to expand it.
test/sampler_test.py
+
8
−
4
View file @
61c10675
...
...
@@ -695,14 +695,16 @@ class TestPyMC3(unittest.TestCase):
chains
=
2
,
cores
=
1
,
tune
=
500
,
nuts_kwargs
=
None
,
step_kwargs
=
None
,
progressbar
=
True
,
model
=
None
,
nuts_kwargs
=
None
,
step_kwargs
=
None
,
random_seed
=
None
,
discard_tuned_samples
=
True
,
compute_convergence_checks
=
True
,
)
expected
.
update
(
self
.
sampler
.
default_nuts_kwargs
)
expected
.
update
(
self
.
sampler
.
default_step_kwargs
)
self
.
assertDictEqual
(
expected
,
self
.
sampler
.
kwargs
)
def
test_translate_kwargs
(
self
):
...
...
@@ -717,14 +719,16 @@ class TestPyMC3(unittest.TestCase):
chains
=
2
,
cores
=
1
,
tune
=
500
,
nuts_kwargs
=
None
,
step_kwargs
=
None
,
progressbar
=
True
,
model
=
None
,
nuts_kwargs
=
None
,
step_kwargs
=
None
,
random_seed
=
None
,
discard_tuned_samples
=
True
,
compute_convergence_checks
=
True
,
)
expected
.
update
(
self
.
sampler
.
default_nuts_kwargs
)
expected
.
update
(
self
.
sampler
.
default_step_kwargs
)
self
.
sampler
.
kwargs
[
"
draws
"
]
=
123
for
equiv
in
bilby
.
core
.
sampler
.
base_sampler
.
NestedSampler
.
npoints_equiv_kwargs
:
new_kwargs
=
self
.
sampler
.
kwargs
.
copy
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment