Sylvia Biscoveanu / bilby / Commits / be3a886a

Commit be3a886a
authored 6 years ago by Colm Talbot, committed by Gregory Ashton 6 years ago
basic checkpointing and resuming
parent 747ebded
Showing 2 changed files with 71 additions and 18 deletions:

CHANGELOG.md                   +1  −0
bilby/core/sampler/emcee.py   +70 −18
CHANGELOG.md  +1 −0

@@ -3,6 +3,7 @@
 ## Unreleased
 ### Added
+- `emcee` now writes all progress to disk and can resume from a previous run.
 -
 ### Changed
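As context for this entry, a minimal usage sketch of the new behaviour. The likelihood, prior, and data set-up are hypothetical, and forwarding `resume` through `bilby.run_sampler` assumes the standard bilby keyword-forwarding of this era; only the `resume` option itself comes from this commit.

import numpy as np
import bilby


def model(x, m, c):
    # Toy linear model; purely illustrative.
    return m * x + c


x = np.linspace(0, 1, 100)
y = model(x, 2.0, 1.0) + np.random.normal(0, 0.1, len(x))

likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=0.1)
priors = dict(
    m=bilby.core.prior.Uniform(0, 5, 'm'),
    c=bilby.core.prior.Uniform(-2, 2, 'c'))

# The emcee sampler now writes every step to outdir/emcee_<label>/chain.dat.
# If this call is interrupted, re-running it with resume=True (the new
# default) reloads that file and continues from the walkers' last positions.
result = bilby.run_sampler(
    likelihood=likelihood, priors=priors, sampler='emcee',
    nwalkers=100, iterations=500, resume=True,
    outdir='outdir', label='label')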
bilby/core/sampler/emcee.py  +70 −18

 from __future__ import absolute_import, print_function
 import os
 import numpy as np
 from pandas import DataFrame
 from distutils.version import LooseVersion
-from ..utils import logger, get_progress_bar
+from ..utils import (
+    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
 from .base_sampler import MCMCSampler, SamplerError
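The new import pulls in `check_directory_exists_and_if_not_mkdir`, which `run_sampler` uses below to create the per-run checkpoint directory. A minimal local stand-in sketch for a helper with that name (the actual implementation lives in bilby's `..utils` module and may differ, for example by logging the creation):

import os


def check_directory_exists_and_if_not_mkdir(directory):
    # Create the directory (and any parents) only if it does not exist yet.
    if not os.path.isdir(directory):
        os.makedirs(directory)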
@@ -41,19 +44,23 @@ class Emcee(MCMCSampler):

     default_kwargs = dict(
         nwalkers=500, a=2, args=[], kwargs={}, postargs=None, pool=None,
         live_dangerously=False, runtime_sortingfn=None, lnprob0=None,
         rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True,
         mh_proposal=None)

-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, nburn=None, burn_in_fraction=0.25,
+    def __init__(self, likelihood, priors, outdir='outdir', label='label',
+                 use_ratio=False, plot=False, skip_import_verification=False,
+                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
                  burn_in_act=3, **kwargs):
         MCMCSampler.__init__(
             self, likelihood=likelihood, priors=priors, outdir=outdir,
             label=label, use_ratio=use_ratio, plot=plot,
             skip_import_verification=skip_import_verification, **kwargs)
+        self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
+        self._old_chain = None

     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
@@ -168,23 +175,54 @@ class Emcee(MCMCSampler):

         import emcee
         tqdm = get_progress_bar()
         sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
-        self._set_pos0()
-        for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
-                      total=self.nsteps):
-            pass
+        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+        out_file = os.path.join(out_dir, 'chain.dat')
+
+        if self.resume:
+            self.load_old_chain(out_file)
+        else:
+            self._set_pos0()
+
+        check_directory_exists_and_if_not_mkdir(out_dir)
+        if not os.path.isfile(out_file):
+            with open(out_file, "w") as ff:
+                ff.write('walker\t{}\tlog_l'.format(
+                    '\t'.join(self.search_parameter_keys)))
+        template = '{:d}' + '\t{:.9e}' * (
+            len(self.search_parameter_keys) + 2) + '\n'
+
+        for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
+                           total=self.nsteps):
+            points = np.hstack([sample[0], np.array(sample[3])])
+            # import IPython; IPython.embed()
+            with open(out_file, "a") as ff:
+                for ii, point in enumerate(points):
+                    ff.write(template.format(ii, *point))
+
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(
-            sampler.chain.reshape((-1, self.ndim)))
+        blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
+        log_likelihoods, log_priors = blobs_flat.T
+
+        if self._old_chain is not None:
+            chain = np.vstack([
+                self._old_chain[:, :-2],
+                sampler.chain.reshape((-1, self.ndim))])
+            log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
+            log_ps = np.hstack([self._old_chain[:, -1], log_priors])
+            self.nsteps = chain.shape[0] // self.nwalkers
+        else:
+            chain = sampler.chain.reshape((-1, self.ndim))
+            log_ls = log_likelihoods
+            log_ps = log_priors
+
+        self.calculate_autocorrelation(chain)
         self.print_nburn_logging_info()
         self.result.nburn = self.nburn
+        n_samples = self.nwalkers * self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn < nsteps`. Try increasing the number of steps.")
-        self.result.samples = sampler.chain[:, self.nburn:, :].reshape(
-            (-1, self.ndim))
-        blobs_flat = np.array(sampler.blobs)[self.nburn:, :, :].reshape(
-            (-1, 2))
-        log_likelihoods, log_priors = blobs_flat.T
-        self.result.log_likelihood_evaluations = log_likelihoods
-        self.result.log_prior_evaluations = log_priors
+        self.result.samples = chain[n_samples:, :]
+        self.result.log_likelihood_evaluations = log_ls[n_samples:]
+        self.result.log_prior_evaluations = log_ps[n_samples:]
         self.result.walkers = sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
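The loop above appends one tab-separated row per walker per iteration to chain.dat: a walker index, the sampled parameters, and two log-probability columns taken from the emcee blobs (log-likelihood and log-prior). A hedged sketch of reading the file back outside bilby, with hypothetical outdir/label values; the column layout follows the `template` string and the slicing used in `load_old_chain` below:

import numpy as np

# Hypothetical path; mirrors os.path.join(outdir, 'emcee_{}'.format(label), 'chain.dat').
chain_file = 'outdir/emcee_label/chain.dat'

# Skip the header row ('walker', the parameter names, 'log_l'), exactly as
# load_old_chain does, then unpack the columns.
data = np.genfromtxt(chain_file, skip_header=1)

walker_index = data[:, 0].astype(int)  # column 0: walker id within each step
parameters = data[:, 1:-2]             # middle columns: sampled parameters
log_likelihoods = data[:, -2]          # second-to-last column
log_priors = data[:, -1]               # last column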
@@ -209,6 +247,20 @@ class Emcee(MCMCSampler):

             self.pos0 = [self.get_random_draw_from_prior()
                          for _ in range(self.nwalkers)]

+    def load_old_chain(self, file_name=None):
+        if file_name is None:
+            out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+            file_name = os.path.join(out_dir, 'chain.dat')
+        if os.path.isfile(file_name):
+            old_chain = np.genfromtxt(file_name, skip_header=1)
+            self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
+                         for ii in range(self.nwalkers)]
+            self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
+            logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+        else:
+            logger.warning('Failed to resume. {} not found.'.format(file_name))
+            self._set_pos0()
+
     def lnpostfn(self, theta):
         log_prior = self.log_prior(theta)
         if np.isinf(log_prior):
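To make the slicing in `load_old_chain` concrete, a small toy example (illustrative numbers, not bilby code) showing how the last `nwalkers` rows of the stored chain become the restart positions `pos0`, with the walker index and the two log columns stripped off:

import numpy as np

nwalkers = 3

# Toy chain: columns are [walker, param_1, param_2, log_l, log_p],
# mimicking two recorded iterations of three walkers in chain.dat.
old_chain = np.array([
    [0, 0.10, 1.0, -5.0, -1.0],
    [1, 0.20, 1.1, -4.8, -1.0],
    [2, 0.30, 0.9, -5.2, -1.0],
    [0, 0.12, 1.0, -4.9, -1.0],
    [1, 0.21, 1.1, -4.7, -1.0],
    [2, 0.29, 0.9, -5.1, -1.0],
])

# Same slicing as load_old_chain: the last nwalkers rows give each walker its
# restart position, dropping column 0 (walker id) and the last two (log) columns.
pos0 = [np.squeeze(old_chain[-(nwalkers - ii), 1:-2])
        for ii in range(nwalkers)]
# pos0 == [array([0.12, 1.0]), array([0.21, 1.1]), array([0.29, 0.9])]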