From 9a34533d018c4f7e6270ab3b7bc7683472e44cf5 Mon Sep 17 00:00:00 2001
From: Matthew Pitkin <matthew.pitkin@ligo.org>
Date: Thu, 22 Aug 2024 15:19:55 +0000
Subject: [PATCH] Resolve "patch to speed the MCMC"

---
 AUTHORS.md                  |  1 +
 bilby/core/sampler/emcee.py | 16 +++++++++-------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/AUTHORS.md b/AUTHORS.md
index 8c227ed6f..17b09070f 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -16,6 +16,7 @@ Bruce Edelman
 Carl-Johan Haster
 Cecilio Garcia-Quiros
 Charlie Hoy
+Chentao Yang
 Christopher Philip Luke Berry
 Christos Karathanasis
 Colm Talbot
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 76e4dd1eb..2c12ee354 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -1,7 +1,6 @@
 import os
 import shutil
 from collections import namedtuple
-from shutil import copyfile
 
 import numpy as np
 from packaging import version
@@ -333,16 +332,19 @@ class Emcee(MCMCSampler):
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
         temp_chain_file = chain_file + ".temp"
-        if os.path.isfile(chain_file):
-            copyfile(chain_file, temp_chain_file)
         if self.prerelease:
             points = np.hstack([sample.coords, sample.blobs])
         else:
             points = np.hstack([sample[0], np.array(sample[3])])
-        with open(temp_chain_file, "a") as ff:
-            for ii, point in enumerate(points):
-                ff.write(self.checkpoint_info.chain_template.format(ii, *point))
-        shutil.move(temp_chain_file, chain_file)
+        data_to_write = "\n".join(
+            self.checkpoint_info.chain_template.format(ii, *point)
+            for ii, point in enumerate(points)
+        )
+        with open(temp_chain_file, "w") as ff:
+            ff.write(data_to_write)
+        with open(temp_chain_file, "rb") as ftemp, open(chain_file, "ab") as fchain:
+            shutil.copyfileobj(ftemp, fchain)
+        os.remove(temp_chain_file)
 
     @property
     def _previous_iterations(self):
-- 
GitLab