Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Sylvia Biscoveanu
bilby
Commits
ad8a0d68
Commit
ad8a0d68
authored
4 years ago
by
Matthew David Pitkin
Committed by
Gregory Ashton
4 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Apply similar changes to those in !804 to help file transfer on
Condor for ultranest
parent
ed8bef95
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
bilby/core/sampler/ultranest.py
+108
-37
108 additions, 37 deletions
bilby/core/sampler/ultranest.py
test/sampler_test.py
+6
-4
6 additions, 4 deletions
test/sampler_test.py
with
114 additions
and
41 deletions
bilby/core/sampler/ultranest.py
+
108
−
37
View file @
ad8a0d68
from
__future__
import
absolute_import
import
datetime
import
distutils.dir_util
import
inspect
import
os
import
shutil
import
signal
import
tempfile
import
time
import
numpy
as
np
from
pandas
import
DataFrame
...
...
@@ -59,7 +63,7 @@ class Ultranest(NestedSampler):
dlogz
=
None
,
max_iters
=
None
,
update_interval_iter_fraction
=
0.2
,
viz_callback
=
"
auto
"
,
viz_callback
=
None
,
dKL
=
0.5
,
frac_remain
=
0.01
,
Lepsilon
=
0.001
,
...
...
@@ -81,6 +85,8 @@ class Ultranest(NestedSampler):
plot
=
False
,
exit_code
=
77
,
skip_import_verification
=
False
,
temporary_directory
=
True
,
callback_interval
=
10
,
**
kwargs
,
):
super
(
Ultranest
,
self
).
__init__
(
...
...
@@ -95,6 +101,12 @@ class Ultranest(NestedSampler):
**
kwargs
,
)
self
.
_apply_ultranest_boundaries
()
self
.
use_temporary_directory
=
temporary_directory
if
self
.
use_temporary_directory
:
# set callback interval, so copying of results does not thrash the
# disk (ultranest will call viz_callback quite a lot)
self
.
callback_interval
=
callback_interval
signal
.
signal
(
signal
.
SIGTERM
,
self
.
write_current_state_and_exit
)
signal
.
signal
(
signal
.
SIGINT
,
self
.
write_current_state_and_exit
)
...
...
@@ -113,9 +125,18 @@ class Ultranest(NestedSampler):
"""
Check the kwargs
"""
self
.
outputfiles_basename
=
self
.
kwargs
.
pop
(
"
log_dir
"
,
None
)
if
self
.
kwargs
[
"
viz_callback
"
]
is
None
:
self
.
kwargs
[
"
viz_callback
"
]
=
self
.
_viz_callback
NestedSampler
.
_verify_kwargs_against_default_kwargs
(
self
)
def
_viz_callback
(
self
,
*
args
,
**
kwargs
):
if
self
.
use_temporary_directory
:
if
not
(
self
.
_viz_callback_counter
%
self
.
callback_interval
):
self
.
_copy_temporary_directory_contents_to_proper_path
()
self
.
_calculate_and_save_sampling_time
()
self
.
_viz_callback_counter
+=
1
def
_apply_ultranest_boundaries
(
self
):
if
(
self
.
kwargs
[
"
wrapped_params
"
]
is
None
...
...
@@ -136,9 +157,11 @@ class Ultranest(NestedSampler):
@outputfiles_basename.setter
def
outputfiles_basename
(
self
,
outputfiles_basename
):
if
outputfiles_basename
is
None
:
outputfiles_basename
=
"
{}/ultra_{}
"
.
format
(
self
.
outdir
,
self
.
label
)
if
outputfiles_basename
.
endswith
(
"
/
"
)
is
True
:
outputfiles_basename
=
outputfiles_basename
.
rstrip
(
"
/
"
)
outputfiles_basename
=
os
.
path
.
join
(
self
.
outdir
,
"
ultra_{}/
"
.
format
(
self
.
label
)
)
if
not
outputfiles_basename
.
endswith
(
"
/
"
):
outputfiles_basename
+=
"
/
"
check_directory_exists_and_if_not_mkdir
(
self
.
outdir
)
self
.
_outputfiles_basename
=
outputfiles_basename
...
...
@@ -148,7 +171,7 @@ class Ultranest(NestedSampler):
@temporary_outputfiles_basename.setter
def
temporary_outputfiles_basename
(
self
,
temporary_outputfiles_basename
):
if
temporary_outputfiles_basename
.
endswith
(
"
/
"
)
is
False
:
if
not
temporary_outputfiles_basename
.
endswith
(
"
/
"
):
temporary_outputfiles_basename
=
"
{}/
"
.
format
(
temporary_outputfiles_basename
)
...
...
@@ -157,10 +180,6 @@ class Ultranest(NestedSampler):
shutil
.
copytree
(
self
.
outputfiles_basename
,
self
.
temporary_outputfiles_basename
)
if
os
.
path
.
islink
(
self
.
outputfiles_basename
):
os
.
unlink
(
self
.
outputfiles_basename
)
else
:
shutil
.
rmtree
(
self
.
outputfiles_basename
)
def
write_current_state_and_exit
(
self
,
signum
=
None
,
frame
=
None
):
"""
Write current state and exit on exit_code
"""
...
...
@@ -169,24 +188,38 @@ class Ultranest(NestedSampler):
signum
,
self
.
exit_code
)
)
# self.copy_temporary_directory_to_proper_path()
self
.
_calculate_and_save_sampling_time
()
if
self
.
use_temporary_directory
:
self
.
_move_temporary_directory_to_proper_path
()
os
.
_exit
(
self
.
exit_code
)
def
copy_temporary_directory_to_proper_path
(
self
):
logger
.
info
(
"
Overwriting {} with {}
"
.
format
(
self
.
outputfiles_basename
,
self
.
temporary_outputfiles_basename
def
_copy_temporary_directory_contents_to_proper_path
(
self
):
"""
Copy the temporary back to the proper path.
Do not delete the temporary directory.
"""
if
inspect
.
stack
()[
1
].
function
!=
"
_viz_callback
"
:
logger
.
info
(
"
Overwriting {} with {}
"
.
format
(
self
.
outputfiles_basename
,
self
.
temporary_outputfiles_basename
)
)
if
self
.
outputfiles_basename
.
endswith
(
"
/
"
):
outputfiles_basename_stripped
=
self
.
outputfiles_basename
[:
-
1
]
else
:
outputfiles_basename_stripped
=
self
.
outputfiles_basename
distutils
.
dir_util
.
copy_tree
(
self
.
temporary_outputfiles_basename
,
outputfiles_basename_stripped
)
# First remove anything in the outputfiles_basename for overwriting
if
os
.
path
.
exists
(
self
.
outputfiles_basename
):
if
os
.
path
.
islink
(
self
.
outputfiles_basename
):
os
.
unlink
(
self
.
outputfiles_basename
)
else
:
shutil
.
rmtree
(
self
.
outputfiles_basename
,
ignore_errors
=
True
)
def
_move_temporary_directory_to_proper_path
(
self
):
"""
Move the temporary back to the proper path
shutil
.
copytree
(
self
.
temporary_outputfiles_basename
,
self
.
outputfiles_basename
)
Anything in the proper path at this point is removed including links
"""
self
.
_copy_temporary_directory_contents_to_proper_path
()
shutil
.
rmtree
(
self
.
temporary_outputfiles_basename
)
@property
def
sampler_function_kwargs
(
self
):
...
...
@@ -253,19 +286,8 @@ class Ultranest(NestedSampler):
stepsampler
=
self
.
kwargs
.
pop
(
"
step_sampler
"
,
None
)
temporary_outputfiles_basename
=
tempfile
.
TemporaryDirectory
().
name
self
.
temporary_outputfiles_basename
=
temporary_outputfiles_basename
logger
.
info
(
"
Using temporary file {}
"
.
format
(
temporary_outputfiles_basename
))
check_directory_exists_and_if_not_mkdir
(
temporary_outputfiles_basename
)
self
.
kwargs
[
"
log_dir
"
]
=
self
.
temporary_outputfiles_basename
# Symlink the temporary directory with the target directory: ensures data is stored on exit
os
.
symlink
(
os
.
path
.
abspath
(
self
.
temporary_outputfiles_basename
),
os
.
path
.
abspath
(
self
.
outputfiles_basename
),
target_is_directory
=
True
,
)
self
.
_setup_run_directory
()
self
.
_check_and_load_sampling_time_file
()
# use reactive nested sampler when no live points are given
if
self
.
kwargs
.
get
(
"
num_live_points
"
,
None
)
is
not
None
:
...
...
@@ -289,18 +311,66 @@ class Ultranest(NestedSampler):
"
The default step sampling will be used instead.
"
)
results
=
sampler
.
run
(
**
self
.
sampler_function_kwargs
)
if
self
.
use_temporary_directory
:
self
.
_viz_callback_counter
=
1
self
.
copy_temporary_directory_to_proper_path
()
self
.
start_time
=
time
.
time
()
results
=
sampler
.
run
(
**
self
.
sampler_function_kwargs
)
self
.
_calculate_and_save_sampling_time
()
# Clean up
s
hutil
.
rmtree
(
temporary_outputfiles_basename
)
s
elf
.
_clean_up_run_directory
(
)
self
.
_generate_result
(
results
)
self
.
calc_likelihood_count
()
return
self
.
result
def
_setup_run_directory
(
self
):
"""
If using a temporary directory, the output directory is moved to the
temporary directory and symlinked back.
"""
if
self
.
use_temporary_directory
:
temporary_outputfiles_basename
=
tempfile
.
TemporaryDirectory
().
name
self
.
temporary_outputfiles_basename
=
temporary_outputfiles_basename
if
os
.
path
.
exists
(
self
.
outputfiles_basename
):
distutils
.
dir_util
.
copy_tree
(
self
.
outputfiles_basename
,
self
.
temporary_outputfiles_basename
)
check_directory_exists_and_if_not_mkdir
(
temporary_outputfiles_basename
)
self
.
kwargs
[
"
log_dir
"
]
=
self
.
temporary_outputfiles_basename
logger
.
info
(
"
Using temporary file {}
"
.
format
(
temporary_outputfiles_basename
)
)
else
:
check_directory_exists_and_if_not_mkdir
(
self
.
outputfiles_basename
)
self
.
kwargs
[
"
log_dir
"
]
=
self
.
outputfiles_basename
logger
.
info
(
"
Using output file {}
"
.
format
(
self
.
outputfiles_basename
))
def
_clean_up_run_directory
(
self
):
if
self
.
use_temporary_directory
:
self
.
_move_temporary_directory_to_proper_path
()
self
.
kwargs
[
"
log_dir
"
]
=
self
.
outputfiles_basename
def
_check_and_load_sampling_time_file
(
self
):
self
.
time_file_path
=
os
.
path
.
join
(
self
.
kwargs
[
"
log_dir
"
],
"
sampling_time.dat
"
)
if
os
.
path
.
exists
(
self
.
time_file_path
):
with
open
(
self
.
time_file_path
,
"
r
"
)
as
time_file
:
self
.
total_sampling_time
=
float
(
time_file
.
readline
())
else
:
self
.
total_sampling_time
=
0
def
_calculate_and_save_sampling_time
(
self
):
current_time
=
time
.
time
()
new_sampling_time
=
current_time
-
self
.
start_time
self
.
total_sampling_time
+=
new_sampling_time
with
open
(
self
.
time_file_path
,
"
w
"
)
as
time_file
:
time_file
.
write
(
str
(
self
.
total_sampling_time
))
self
.
start_time
=
current_time
def
_generate_result
(
self
,
out
):
# extract results (samples stored in "v" will change to "points",
# weights stored in "w" will change to "weights")
...
...
@@ -325,3 +395,4 @@ class Ultranest(NestedSampler):
self
.
result
.
log_evidence_err
=
out
[
"
logzerr
"
]
self
.
result
.
outputfiles_basename
=
self
.
outputfiles_basename
self
.
result
.
sampling_time
=
datetime
.
timedelta
(
seconds
=
self
.
total_sampling_time
)
This diff is collapsed.
Click to expand it.
test/sampler_test.py
+
6
−
4
View file @
ad8a0d68
...
...
@@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase):
dlogz
=
None
,
max_iters
=
None
,
update_interval_iter_fraction
=
0.2
,
viz_callback
=
"
auto
"
,
viz_callback
=
None
,
dKL
=
0.5
,
frac_remain
=
0.01
,
Lepsilon
=
0.001
,
...
...
@@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase):
self
.
assertListEqual
([
1
,
0
],
self
.
sampler
.
kwargs
[
"
wrapped_params
"
])
# Check this separately
self
.
sampler
.
kwargs
[
"
wrapped_params
"
]
=
None
# The dict comparison can't handle lists
self
.
sampler
.
kwargs
[
"
derived_param_names
"
]
=
None
self
.
sampler
.
kwargs
[
"
viz_callback
"
]
=
None
self
.
assertDictEqual
(
expected
,
self
.
sampler
.
kwargs
)
def
test_translate_kwargs
(
self
):
...
...
@@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase):
dlogz
=
None
,
max_iters
=
None
,
update_interval_iter_fraction
=
0.2
,
viz_callback
=
"
auto
"
,
viz_callback
=
None
,
dKL
=
0.5
,
frac_remain
=
0.01
,
Lepsilon
=
0.001
,
...
...
@@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase):
for
equiv
in
bilby
.
core
.
sampler
.
base_sampler
.
NestedSampler
.
npoints_equiv_kwargs
:
new_kwargs
=
self
.
sampler
.
kwargs
.
copy
()
del
new_kwargs
[
'
num_live_points
'
]
new_kwargs
[
'
wrapped_params
'
]
=
None
# The dict comparison can't handle lists
new_kwargs
[
"
derived_param_names
"
]
=
None
new_kwargs
[
equiv
]
=
123
self
.
sampler
.
kwargs
=
new_kwargs
self
.
sampler
.
kwargs
[
"
wrapped_params
"
]
=
None
self
.
sampler
.
kwargs
[
"
derived_param_names
"
]
=
None
self
.
sampler
.
kwargs
[
"
viz_callback
"
]
=
None
self
.
assertDictEqual
(
expected
,
self
.
sampler
.
kwargs
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment