From 0ac57e3db253414ea6625d583f8cc57e77df1b68 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 8 Oct 2018 02:01:46 -0500
Subject: [PATCH] update cpnest syntax

---
 bilby/core/sampler/cpnest.py | 42 +++++++++++++++++++++++++-----------
 sampler_requirements.txt     |  2 +-
 test/sampler_test.py         | 12 ++++++-----
 3 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index 91cf7bc6..0fcdba59 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 import numpy as np
 from pandas import DataFrame
-from ..utils import logger
+from ..utils import logger, check_directory_exists_and_if_not_mkdir
 from .base_sampler import NestedSampler
 
 
@@ -15,27 +15,33 @@ class Cpnest(NestedSampler):
 
     Keyword Arguments
     -----------------
-    npoints: int
+    nlive: int
         The number of live points, note this can also equivalently be given as
-        one of [nlive, nlives, n_live_points]
+        one of [npoints, nlives, n_live_points]
     seed: int (1234)
         Initialised random seed
-    Nthreads: int, (1)
+    nthreads: int, (1)
         Number of threads to use
     maxmcmc: int (1000)
         The maximum number of MCMC steps to take
-    verbose: Bool
+    verbose: Bool (True)
         If true, print information information about the convergence during
+    resume: Bool (False)
+        Whether or not to resume from a previous run
+    output: str
+        Where to write the CPNest, by default this is
+        {self.outdir}/cpnest_{self.label}/
 
     """
-    default_kwargs = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000,
-                          Poolsize=100, seed=None, balance_samplers=True)
+    default_kwargs = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
+                          seed=None, poolsize=100, nhamiltonian=0, resume=False,
+                          output=None)
 
     def _translate_kwargs(self, kwargs):
-        if 'Nlive' not in kwargs:
+        if 'nlive' not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['Nlive'] = kwargs.pop(equiv)
+                    kwargs['nlive'] = kwargs.pop(equiv)
         if 'seed' not in kwargs:
             logger.warning('No seed provided, cpnest will use 1234.')
 
@@ -69,15 +75,25 @@ class Cpnest(NestedSampler):
         bounds = [[self.priors[key].minimum, self.priors[key].maximum]
                   for key in self.search_parameter_keys]
         model = Model(self.search_parameter_keys, bounds)
-        out = CPNest(model, output=self.outdir, **self.kwargs)
+        out = CPNest(model, **self.kwargs)
         out.run()
 
         if self.plot:
             out.plot()
 
-        # Since the output is not just samples, but log_likelihood as well,
-        # we turn this into a dataframe here. The index [0] here may be wrong
-        self.result.posterior = DataFrame(out.posterior_samples[0])
+        self.result.posterior = DataFrame(out.posterior_samples)
         self.result.log_evidence = out.NS.state.logZ
         self.result.log_evidence_err = np.nan
         return self.result
+
+    def _verify_kwargs_against_default_kwargs(self):
+        """
+        Set the directory where the output will be written.
+        """
+        if not self.kwargs['output']:
+            self.kwargs['output'] = \
+                '{}/cpnest_{}/'.format(self.outdir, self.label)
+        if self.kwargs['output'].endswith('/') is False:
+            self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
+        check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
+        NestedSampler._verify_kwargs_against_default_kwargs(self)
diff --git a/sampler_requirements.txt b/sampler_requirements.txt
index fbb2c32f..e81411d3 100644
--- a/sampler_requirements.txt
+++ b/sampler_requirements.txt
@@ -1,4 +1,4 @@
-cpnest
+cpnest>=0.9.4
 dynesty
 emcee
 nestle
diff --git a/test/sampler_test.py b/test/sampler_test.py
index 56409d4a..963813f0 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -121,16 +121,18 @@ class TestCPNest(unittest.TestCase):
         del self.sampler
 
     def test_default_kwargs(self):
-        expected = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000,
-                        Poolsize=100, seed=None, balance_samplers=True)
+        expected = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
+                        seed=None, poolsize=100, nhamiltonian=0, resume=False,
+                        output='outdir/cpnest_label/')
         self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs(self):
-        expected = dict(verbose=1, Nthreads=1, Nlive=250, maxmcmc=1000,
-                        Poolsize=100, seed=None, balance_samplers=True)
+        expected = dict(verbose=1, nthreads=1, nlive=250, maxmcmc=1000,
+                        seed=None, poolsize=100, nhamiltonian=0, resume=False,
+                        output='outdir/cpnest_label/')
         for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()
-            del new_kwargs['Nlive']
+            del new_kwargs['nlive']
             new_kwargs[equiv] = 250
             self.sampler.kwargs = new_kwargs
             self.assertDictEqual(expected, self.sampler.kwargs)
-- 
GitLab