From 6c8678ff5e41d66baab35d99f4e36dee7b8f3a0a Mon Sep 17 00:00:00 2001
From: Michael Williams <michael.williams@ligo.org>
Date: Mon, 13 May 2024 13:28:39 +0000
Subject: [PATCH] ENH: add `get_expected_outputs` to sampler classes

---
 bilby/bilby_mcmc/sampler.py            | 23 +++++++++++++++
 bilby/core/sampler/base_sampler.py     | 39 ++++++++++++++++++++++++++
 bilby/core/sampler/cpnest.py           |  1 +
 bilby/core/sampler/dnest4.py           |  1 +
 bilby/core/sampler/dynamic_dynesty.py  |  1 +
 bilby/core/sampler/dynesty.py          | 27 ++++++++++++++++++
 bilby/core/sampler/emcee.py            |  1 +
 bilby/core/sampler/fake_sampler.py     |  2 ++
 bilby/core/sampler/kombine.py          |  1 +
 bilby/core/sampler/nessai.py           | 26 +++++++++++++++++
 bilby/core/sampler/nestle.py           |  1 +
 bilby/core/sampler/polychord.py        | 25 +++++++++++++++++
 bilby/core/sampler/ptemcee.py          | 24 ++++++++++++++++
 bilby/core/sampler/ptmcmc.py           |  2 ++
 bilby/core/sampler/pymc.py             |  1 +
 bilby/core/sampler/pymultinest.py      |  2 ++
 bilby/core/sampler/ultranest.py        |  2 ++
 bilby/core/sampler/zeus.py             |  1 +
 test/bilby_mcmc/test_sampler.py        | 11 ++++++++
 test/core/sampler/base_sampler_test.py | 24 ++++++++++++++++
 test/core/sampler/dynesty_test.py      | 13 +++++++++
 test/core/sampler/nessai_test.py       | 15 ++++++++++
 test/core/sampler/ptemcee_test.py      | 12 ++++++++
 23 files changed, 255 insertions(+)

diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index fcd4c9c5a..a3732add0 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -547,6 +547,29 @@ class Bilby_MCMC(MCMCSampler):
                         all_samples=ptsampler.samples,
                     )
 
+    @classmethod
+    def get_expected_outputs(cls, outdir=None, label=None):
+        """Get lists of the expected output files and directories.
+
+        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
+
+        Parameters
+        ----------
+        outdir : str
+            The output directory. Required in practice: :code:`None` raises TypeError.
+        label : str
+            The label for the run; it prefixes the resume file name.
+
+        Returns
+        -------
+        list
+            List of file names: the single :code:`{label}_resume.pickle` in outdir.
+        list
+            List of directory names. Will always be empty for bilby_mcmc.
+        """
+        filenames = [os.path.join(outdir, f"{label}_resume.pickle")]
+        return filenames, []
+
 
 class BilbyPTMCMCSampler(object):
     def __init__(
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 3f73cb795..3c573afe2 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -173,6 +173,13 @@ class Sampler(object):
         Whether the implemented sampler exits hard (:code:`os._exit` rather
         than :code:`sys.exit`). The latter can be escaped as :code:`SystemExit`.
         The former cannot.
+    sampler_name : str
+        Name of the sampler. This is used when creating the output directory for
+        the sampler.
+    abbreviation : str
+        Abbreviated name of the sampler. Does not have to be specified in child
+        classes. If set to a value other than :code:`None`, this will be used
+        instead of :code:`sampler_name` when creating the output directory.
 
     Raises
     ======
@@ -187,6 +194,8 @@ class Sampler(object):
 
     """
 
+    sampler_name = "sampler"
+    abbreviation = None
     default_kwargs = dict()
     npool_equiv_kwargs = [
         "npool",
@@ -779,8 +788,37 @@ class Sampler(object):
     def write_current_state(self):
         raise NotImplementedError()
 
+    @classmethod
+    def get_expected_outputs(cls, outdir=None, label=None):
+        """Get lists of the expected output files and directories.
+
+        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
+        Both can be empty. Defaults to a single directory:
+        :code:`"{outdir}/{name}_{label}/"`, where :code:`name`
+        is :code:`abbreviation` if the sampler class sets it to a non-empty value,
+        otherwise it defaults to :code:`sampler_name`.
+
+        Parameters
+        ----------
+        outdir : str
+            The output directory. Required in practice: :code:`None` raises TypeError.
+        label : str
+            The label for the run.
+
+        Returns
+        -------
+        list
+            List of file names. Always empty in this default implementation.
+        list
+            List of directory names; the single entry ends with a path separator.
+        """
+        name = cls.abbreviation or cls.sampler_name
+        dirname = os.path.join(outdir, f"{name}_{label}", "")
+        return [], [dirname]
+
 
 class NestedSampler(Sampler):
+    sampler_name = "nested_sampler"
     npoints_equiv_kwargs = [
         "nlive",
         "nlives",
@@ -854,6 +892,7 @@ class NestedSampler(Sampler):
 
 
 class MCMCSampler(Sampler):
+    sampler_name = "mcmc_sampler"
     nwalkers_equiv_kwargs = ["nwalker", "nwalkers", "draws", "Niter"]
     nburn_equiv_kwargs = ["burn", "nburn"]
 
diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index e1f3ae19e..e777ebc67 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -40,6 +40,7 @@ class Cpnest(NestedSampler):
 
     """
 
+    sampler_name = "cpnest"
     default_kwargs = dict(
         verbose=3,
         nthreads=1,
diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py
index a767ef89d..87717f6fd 100644
--- a/bilby/core/sampler/dnest4.py
+++ b/bilby/core/sampler/dnest4.py
@@ -99,6 +99,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
         If True, prints information during run
     """
 
+    sampler_name = "d4nest"
     default_kwargs = dict(
         max_num_levels=20,
         num_steps=500,
diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py
index 294d8fd6d..8c7f2966e 100644
--- a/bilby/core/sampler/dynamic_dynesty.py
+++ b/bilby/core/sampler/dynamic_dynesty.py
@@ -14,6 +14,7 @@ class DynamicDynesty(Dynesty):
     """
 
     external_sampler_name = "dynesty"
+    sampler_name = "dynamic_dynesty"
 
     @property
     def nlive(self):
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index faebfc6bf..852fb88c1 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -151,6 +151,7 @@ class Dynesty(NestedSampler):
         specified.
     """
 
+    sampler_name = "dynesty"
     sampling_seed_key = "seed"
 
     @property
@@ -299,6 +300,32 @@ class Dynesty(NestedSampler):
                 )
         Sampler._verify_kwargs_against_default_kwargs(self)
 
+    @classmethod
+    def get_expected_outputs(cls, outdir=None, label=None):
+        """Get lists of the expected output files and directories.
+
+        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
+
+        Parameters
+        ----------
+        outdir : str
+            The output directory. Required in practice: :code:`None` raises TypeError.
+        label : str
+            The label for the run; it prefixes both pickle file names.
+
+        Returns
+        -------
+        list
+            List of file names: the :code:`resume` and :code:`dynesty` pickles.
+        list
+            List of directory names. Will always be empty for dynesty.
+        """
+        filenames = []
+        for kind in ["resume", "dynesty"]:
+            filename = os.path.join(outdir, f"{label}_{kind}.pickle")
+            filenames.append(filename)
+        return filenames, []
+
     def _print_func(
         self,
         results,
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 7253a0fa4..db88ee5a2 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -45,6 +45,7 @@ class Emcee(MCMCSampler):
 
     """
 
+    sampler_name = "emcee"
     default_kwargs = dict(
         nwalkers=500,
         a=2,
diff --git a/bilby/core/sampler/fake_sampler.py b/bilby/core/sampler/fake_sampler.py
index 5f375fdba..9795631fb 100644
--- a/bilby/core/sampler/fake_sampler.py
+++ b/bilby/core/sampler/fake_sampler.py
@@ -17,6 +17,8 @@ class FakeSampler(Sampler):
         A string pointing to the posterior data file to be loaded.
     """
 
+    sampler_name = "fake_sampler"
+
     default_kwargs = dict(
         verbose=True, logl_args=None, logl_kwargs=None, print_progress=True
     )
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
index 1f09387cc..bda7c6d4f 100644
--- a/bilby/core/sampler/kombine.py
+++ b/bilby/core/sampler/kombine.py
@@ -39,6 +39,7 @@ class Kombine(Emcee):
 
     """
 
+    sampler_name = "kombine"
     default_kwargs = dict(
         nwalkers=500,
         args=[],
diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py
index b6f40f8aa..65a650efd 100644
--- a/bilby/core/sampler/nessai.py
+++ b/bilby/core/sampler/nessai.py
@@ -20,6 +20,7 @@ class Nessai(NestedSampler):
     Documentation: https://nessai.readthedocs.io/
     """
 
+    sampler_name = "nessai"
     _default_kwargs = None
     _run_kwargs_list = None
     sampling_seed_key = "seed"
@@ -300,5 +301,30 @@ class Nessai(NestedSampler):
         self._log_interruption(signum=signum)
         sys.exit(self.exit_code)
 
+    @classmethod
+    def get_expected_outputs(cls, outdir=None, label=None):
+        """Get lists of the expected output files and directories.
+
+        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
+
+        Parameters
+        ----------
+        outdir : str
+            The output directory. Required in practice: :code:`None` raises TypeError.
+        label : str
+            The label for the run; it comes first in the directory name.
+
+        Returns
+        -------
+        list
+            List of file names. This will be empty for nessai.
+        list
+            List of directory names: the run directory, then two subdirectories.
+        """
+        dirs = [os.path.join(outdir, f"{label}_{cls.sampler_name}", "")]
+        dirs += [os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"]]
+        filenames = []
+        return filenames, dirs
+
     def _setup_pool(self):
         pass
diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py
index ebd955376..75d93bf69 100644
--- a/bilby/core/sampler/nestle.py
+++ b/bilby/core/sampler/nestle.py
@@ -24,6 +24,7 @@ class Nestle(NestedSampler):
 
     """
 
+    sampler_name = "nestle"
     default_kwargs = dict(
         verbose=True,
         method="multi",
diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py
index e43c5d50b..9391dd202 100644
--- a/bilby/core/sampler/polychord.py
+++ b/bilby/core/sampler/polychord.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 
 from .base_sampler import NestedSampler, signal_wrapper
@@ -20,6 +22,7 @@ class PyPolyChord(NestedSampler):
     To see what the keyword arguments are for, see the docstring of PyPolyChordSettings
     """
 
+    sampler_name = "pypolychord"
     default_kwargs = dict(
         use_polychord_defaults=False,
         nlive=None,
@@ -130,6 +133,28 @@ class PyPolyChord(NestedSampler):
         physical_parameters = samples[:, -self.ndim :]
         return log_likelihoods, physical_parameters
 
+    @classmethod
+    def get_expected_outputs(cls, outdir=None, label=None):
+        """Get lists of the expected output files and directories.
+
+        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
+
+        Parameters
+        ----------
+        outdir : str
+            The output directory. Required in practice: :code:`None` raises TypeError.
+        label : str
+            Ignored for pypolychord.
+
+        Returns
+        -------
+        list
+            List of file names. This will always be empty for pypolychord.
+        list
+            List of directory names: only the :code:`chains` directory in outdir.
+        """
+        return [], [os.path.join(outdir, "chains", "")]
+
     @property
     def _sample_file_directory(self):
         return self.outdir + "/chains"
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 531fb102a..fd927235d 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -128,6 +128,7 @@ class Ptemcee(MCMCSampler):
 
     """
 
+    sampler_name = "ptemcee"
     # Arguments used by ptemcee
     default_kwargs = dict(
         ntemps=10,
@@ -710,6 +711,29 @@ class Ptemcee(MCMCSampler):
             except Exception as e:
                 logger.info(f"mean_logl plot failed with exception {e}")
 
+    @classmethod
+    def get_expected_outputs(cls, outdir=None, label=None):
+        """Get lists of the expected output files and directories.
+
+        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
+
+        Parameters
+        ----------
+        outdir : str
+            The output directory. Required in practice: :code:`None` raises TypeError.
+        label : str
+            The label for the run; it prefixes the checkpoint file name.
+
+        Returns
+        -------
+        list
+            List of file names: the single checkpoint resume pickle in outdir.
+        list
+            List of directory names. Will always be empty for ptemcee.
+        """
+        import os  # function-scope import keeps this method self-contained
+        return [os.path.join(outdir, f"{label}_checkpoint_resume.pickle")], []
+
 
 def get_minimum_stable_itertion(mean_array, frac, nsteps_min=10):
     nsteps = mean_array.shape[1]
diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py
index 42279e018..f2a771cb0 100644
--- a/bilby/core/sampler/ptmcmc.py
+++ b/bilby/core/sampler/ptmcmc.py
@@ -41,6 +41,8 @@ class PTMCMCSampler(MCMCSampler):
 
     """
 
+    sampler_name = "ptmcmcsampler"
+    abbreviation = "ptmcmc_temp"
     default_kwargs = {
         "p0": None,
         "Niter": 2 * 10**4 + 1,
diff --git a/bilby/core/sampler/pymc.py b/bilby/core/sampler/pymc.py
index e72aace49..fd138b985 100644
--- a/bilby/core/sampler/pymc.py
+++ b/bilby/core/sampler/pymc.py
@@ -52,6 +52,7 @@ class Pymc(MCMCSampler):
 
     """
 
+    sampler_name = "pymc"
     default_kwargs = dict(
         draws=500,
         step=None,
diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py
index 303acb705..0a9bb0aaf 100644
--- a/bilby/core/sampler/pymultinest.py
+++ b/bilby/core/sampler/pymultinest.py
@@ -34,6 +34,8 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
 
     """
 
+    sampler_name = "pymultinest"
+    abbreviation = "pm"
     default_kwargs = dict(
         importance_nested_sampling=False,
         resume=True,
diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py
index 542f86246..6aacfa999 100644
--- a/bilby/core/sampler/ultranest.py
+++ b/bilby/core/sampler/ultranest.py
@@ -38,6 +38,8 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
         stepping behaviour is used.
     """
 
+    sampler_name = "ultranest"
+    abbreviation = "ultra"
     default_kwargs = dict(
         resume=True,
         show_status=True,
diff --git a/bilby/core/sampler/zeus.py b/bilby/core/sampler/zeus.py
index c7ae40da2..ad6e7edb8 100644
--- a/bilby/core/sampler/zeus.py
+++ b/bilby/core/sampler/zeus.py
@@ -38,6 +38,7 @@ class Zeus(Emcee):
 
     """
 
+    sampler_name = "zeus"
     default_kwargs = dict(
         nwalkers=500,
         args=[],
diff --git a/test/bilby_mcmc/test_sampler.py b/test/bilby_mcmc/test_sampler.py
index 746eb1a9e..7e636e1ab 100644
--- a/test/bilby_mcmc/test_sampler.py
+++ b/test/bilby_mcmc/test_sampler.py
@@ -85,5 +85,16 @@ class TestBilbyMCMCSampler(unittest.TestCase):
         self.assertTrue(isinstance(sampler.samples, pd.DataFrame))
 
 
+def test_get_expected_outputs():  # NOTE(review): assumes os is imported above — confirm
+    label = "par0"
+    outdir = os.path.join("some", "bilby_pipe", "dir")  # not created on disk
+    filenames, directories = Bilby_MCMC.get_expected_outputs(
+        outdir=outdir, label=label
+    )
+    assert len(filenames) == 1  # just the resume pickle
+    assert len(directories) == 0  # bilby_mcmc reports no directories
+    assert os.path.join(outdir, f"{label}_resume.pickle") in filenames
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py
index 1250fa0d6..d20ee978a 100644
--- a/test/core/sampler/base_sampler_test.py
+++ b/test/core/sampler/base_sampler_test.py
@@ -148,6 +148,30 @@ class TestSampler(unittest.TestCase):
         )
 
 
+def test_get_expected_outputs():  # NOTE(review): assumes os is imported above — confirm
+    outdir = os.path.join("some", "bilby_pipe", "dir")  # not created on disk
+    label = "par0"
+    filenames, directories = bilby.core.sampler.Sampler.get_expected_outputs(
+        outdir=outdir, label=label
+    )
+    assert len(filenames) == 0  # the base class lists no individual files
+    assert len(directories) == 1  # one directory named from the default "sampler"
+    assert directories[0] == os.path.join(outdir, f"sampler_{label}", "")
+
+
+def test_get_expected_outputs_abbreviation():
+    outdir = os.path.join("some", "bilby_pipe", "dir")  # not created on disk
+    label = "par0"
+    sampler_cls = bilby.core.sampler.Sampler
+    sampler_cls.abbreviation = "abbr"  # temporarily override the class-level default
+    try:  # restore even if the call raises, so the override cannot leak into other tests
+        files, dirs = sampler_cls.get_expected_outputs(outdir=outdir, label=label)
+    finally:
+        sampler_cls.abbreviation = None  # back to the class default
+    assert len(files) == 0
+    assert dirs == [os.path.join(outdir, f"abbr_{label}", "")]  # exactly one directory
+
+
 samplers = [
     "bilby_mcmc",
     "dynamic_dynesty",
diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py
index 39ac6f231..d33cc2e23 100644
--- a/test/core/sampler/dynesty_test.py
+++ b/test/core/sampler/dynesty_test.py
@@ -9,6 +9,7 @@ import bilby.core.sampler.dynesty
 from bilby.core.sampler import dynesty_utils
 from scipy.stats import gamma, ks_1samp, uniform, powerlaw
 import shutil
+import os
 
 
 @define
@@ -101,6 +102,18 @@ class TestDynesty(unittest.TestCase):
         self.sampler._run_test()
 
 
+def test_get_expected_outputs():
+    label = "par0"
+    outdir = os.path.join("some", "bilby_pipe", "dir")  # not created on disk
+    filenames, directories = bilby.core.sampler.dynesty.Dynesty.get_expected_outputs(
+        outdir=outdir, label=label
+    )
+    assert len(filenames) == 2  # the resume and dynesty pickles
+    assert len(directories) == 0  # dynesty reports no directories
+    assert os.path.join(outdir, f"{label}_resume.pickle") in filenames
+    assert os.path.join(outdir, f"{label}_dynesty.pickle") in filenames
+
+
 class ProposalsTest(unittest.TestCase):
 
     def test_boundaries(self):
diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py
index cca5d22b0..3246c74e7 100644
--- a/test/core/sampler/nessai_test.py
+++ b/test/core/sampler/nessai_test.py
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch, mock_open
 
 import bilby
 import bilby.core.sampler.nessai
+import os
 
 
 class TestNessai(unittest.TestCase):
@@ -84,5 +85,19 @@ class TestNessai(unittest.TestCase):
         self.assertDictEqual(expected, self.sampler.kwargs)
 
 
+def test_get_expected_outputs():
+    label = "par0"
+    outdir = os.path.join("some", "bilby_pipe", "dir")  # not created on disk
+    filenames, directories = bilby.core.sampler.nessai.Nessai.get_expected_outputs(
+        outdir=outdir, label=label
+    )
+    assert len(filenames) == 0  # nessai reports directories only
+    assert len(directories) == 3  # run directory plus its two subdirectories
+    base_dir = os.path.join(outdir, f"{label}_nessai", "")  # label comes first here
+    assert base_dir in directories
+    assert os.path.join(base_dir, "proposal", "") in directories
+    assert os.path.join(base_dir, "diagnostics", "") in directories
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/core/sampler/ptemcee_test.py b/test/core/sampler/ptemcee_test.py
index ec135eeef..4708a12b0 100644
--- a/test/core/sampler/ptemcee_test.py
+++ b/test/core/sampler/ptemcee_test.py
@@ -5,6 +5,7 @@ from bilby.core.prior import Uniform, PriorDict
 from bilby.core.sampler.ptemcee import Ptemcee
 from bilby.core.sampler.base_sampler import MCMCSampler
 import numpy as np
+import os
 
 
 class TestPTEmcee(unittest.TestCase):
@@ -89,5 +90,16 @@ class TestPTEmcee(unittest.TestCase):
         self.assertEqual(old, new)
 
 
+def test_get_expected_outputs():
+    label = "par0"
+    outdir = os.path.join("some", "bilby_pipe", "dir")  # not created on disk
+    filenames, directories = Ptemcee.get_expected_outputs(
+        outdir=outdir, label=label
+    )
+    assert len(filenames) == 1  # just the checkpoint resume pickle
+    assert len(directories) == 0  # ptemcee reports no directories
+    assert os.path.join(outdir, f"{label}_checkpoint_resume.pickle") in filenames
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab