diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py
index 88e7234a20820d1eee151118b717a5d0c6f91cb8..c53ee9ff8063eafd53bb60f9585a7167a976d4d4 100644
--- a/bilby/gw/likelihood/multiband.py
+++ b/bilby/gw/likelihood/multiband.py
@@ -1,12 +1,15 @@
 
 import math
+import numbers
 
 import numpy as np
 
 from .base import GravitationalWaveTransient
 from ...core.utils import (
     logger, speed_of_light, solar_mass, radius_of_earth,
-    gravitational_constant, round_up_to_power_of_two
+    gravitational_constant, round_up_to_power_of_two,
+    recursively_load_dict_contents_from_group,
+    recursively_save_dict_contents_to_group
 )
 from ..prior import CBCPriorDict
 
@@ -42,6 +45,8 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         A maximum frequency for multi-banding. If specified, the low-frequency limit of a band does not exceed it.
     minimum_banding_duration: float, optional
         A minimum duration for multi-banding. If specified, the duration of a band is not smaller than it.
+    weights: str or dict, optional
+        Pre-computed multiband weights for calculating inner products.
     distance_marginalization: bool, optional
         If true, marginalize over distance in the likelihood. This uses a look up table calculated at run time. The
         distance prior is set to be a delta function at the minimum distance allowed in the prior being marginalised
@@ -76,9 +81,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
     def __init__(
             self, interferometers, waveform_generator, reference_chirp_mass=None, highest_mode=2,
             linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None,
-            maximum_banding_frequency=None, minimum_banding_duration=0., distance_marginalization=False,
-            phase_marginalization=False, priors=None, distance_marginalization_lookup_table=None,
-            reference_frame="sky", time_reference="geocenter"
+            maximum_banding_frequency=None, minimum_banding_duration=0., weights=None,
+            distance_marginalization=False, phase_marginalization=False, priors=None,
+            distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter"
     ):
         super(MBGravitationalWaveTransient, self).__init__(
             interferometers=interferometers, waveform_generator=waveform_generator, priors=priors,
@@ -86,17 +91,23 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             time_marginalization=False, distance_marginalization_lookup_table=distance_marginalization_lookup_table,
             jitter_time=False, reference_frame=reference_frame, time_reference=time_reference
         )
-        self.reference_chirp_mass = reference_chirp_mass
-        self.highest_mode = highest_mode
-        self.linear_interpolation = linear_interpolation
-        self.accuracy_factor = accuracy_factor
-        self.time_offset = time_offset
-        self.delta_f_end = delta_f_end
-        self.minimum_frequency = np.min([i.minimum_frequency for i in self.interferometers])
-        self.maximum_frequency = np.max([i.maximum_frequency for i in self.interferometers])
-        self.maximum_banding_frequency = maximum_banding_frequency
-        self.minimum_banding_duration = minimum_banding_duration
-        self.setup_multibanding()
+        if weights is None:
+            self.reference_chirp_mass = reference_chirp_mass
+            self.highest_mode = highest_mode
+            self.linear_interpolation = linear_interpolation
+            self.accuracy_factor = accuracy_factor
+            self.time_offset = time_offset
+            self.delta_f_end = delta_f_end
+            self.maximum_banding_frequency = maximum_banding_frequency
+            self.minimum_banding_duration = minimum_banding_duration
+            self.setup_multibanding()
+        else:
+            if isinstance(weights, str):
+                import h5py
+                logger.info(f"Loading multiband weights from {weights}.")
+                with h5py.File(weights, 'r') as f:
+                    weights = recursively_load_dict_contents_from_group(f, '/')
+            self.setup_multibanding_from_weights(weights)
 
     @property
     def reference_chirp_mass(self):
@@ -108,7 +119,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     @reference_chirp_mass.setter
     def reference_chirp_mass(self, reference_chirp_mass):
-        if isinstance(reference_chirp_mass, int) or isinstance(reference_chirp_mass, float):
+        if isinstance(reference_chirp_mass, numbers.Number):
             self._reference_chirp_mass = reference_chirp_mass
         else:
             logger.info(
@@ -136,7 +147,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     @highest_mode.setter
     def highest_mode(self, highest_mode):
-        if isinstance(highest_mode, int) or isinstance(highest_mode, float):
+        if isinstance(highest_mode, numbers.Number):
             self._highest_mode = highest_mode
         else:
             raise TypeError("highest_mode must be a number")
@@ -147,7 +158,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     @linear_interpolation.setter
     def linear_interpolation(self, linear_interpolation):
-        if isinstance(linear_interpolation, bool):
+        if isinstance(linear_interpolation, bool) or isinstance(linear_interpolation, np.bool_):
             self._linear_interpolation = linear_interpolation
         else:
             raise TypeError("linear_interpolation must be a bool")
@@ -158,7 +169,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     @accuracy_factor.setter
     def accuracy_factor(self, accuracy_factor):
-        if isinstance(accuracy_factor, int) or isinstance(accuracy_factor, float):
+        if isinstance(accuracy_factor, numbers.Number):
             self._accuracy_factor = accuracy_factor
         else:
             raise TypeError("accuracy_factor must be a number")
@@ -182,7 +193,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         else:
             safety = 2 * radius_of_earth / speed_of_light
         if time_offset is not None:
-            if isinstance(time_offset, int) or isinstance(time_offset, float):
+            if isinstance(time_offset, numbers.Number):
                 self._time_offset = time_offset
             else:
                 raise TypeError("time_offset must be a number")
@@ -216,7 +227,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         else:
             safety = 2 * radius_of_earth / speed_of_light
         if delta_f_end is not None:
-            if isinstance(delta_f_end, int) or isinstance(delta_f_end, float):
+            if isinstance(delta_f_end, numbers.Number):
                 self._delta_f_end = delta_f_end
             else:
                 raise TypeError("delta_f_end must be a number")
@@ -248,7 +259,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             / self.reference_chirp_mass_in_second
         )
         if maximum_banding_frequency is not None:
-            if isinstance(maximum_banding_frequency, int) or isinstance(maximum_banding_frequency, float):
+            if isinstance(maximum_banding_frequency, numbers.Number):
                 if maximum_banding_frequency < fmax_tmp:
                     fmax_tmp = maximum_banding_frequency
                 else:
@@ -264,11 +275,23 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     @minimum_banding_duration.setter
     def minimum_banding_duration(self, minimum_banding_duration):
-        if isinstance(minimum_banding_duration, int) or isinstance(minimum_banding_duration, float):
+        if isinstance(minimum_banding_duration, numbers.Number):
             self._minimum_banding_duration = minimum_banding_duration
         else:
             raise TypeError("minimum_banding_duration must be a number")
 
+    @property
+    def minimum_frequency(self):
+        return np.min([i.minimum_frequency for i in self.interferometers])
+
+    @property
+    def maximum_frequency(self):
+        return np.max([i.maximum_frequency for i in self.interferometers])
+
+    @property
+    def number_of_bands(self):
+        return len(self.durations)
+
     def setup_multibanding(self):
         """Set up frequency bands and coefficients needed for likelihood evaluations"""
         self._setup_frequency_bands()
@@ -365,25 +388,26 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         original duration. This sets the following instance variables.
 
         durations: durations of bands (T^(b) in the paper)
-        fb_dfb: the list of tuples, which contain starting frequencies (f^(b) in the paper) and frequency scales for
+        fb_dfb: 2-dimensional ndarray, which contain starting frequencies (f^(b) in the paper) and frequency scales for
         smoothing waveforms (\Delta f^(b) in the paper) of bands
 
         """
-        self.durations = [self.interferometers.duration]
-        self.fb_dfb = [(self.minimum_frequency, 0.)]
+        self.durations = np.array([self.interferometers.duration])
+        self.fb_dfb = [[self.minimum_frequency, 0.]]
         dnext = self.interferometers.duration / 2
         while dnext > max(self.time_offset, self.minimum_banding_duration):
             fnow, _ = self.fb_dfb[-1]
             fnext, dfnext = self._find_starting_frequency(dnext, fnow)
             if fnext is not None and fnext < min(self.maximum_frequency, self.maximum_banding_frequency):
-                self.durations.append(dnext)
-                self.fb_dfb.append((fnext, dfnext))
+                self.durations = np.append(self.durations, dnext)
+                self.fb_dfb.append([fnext, dfnext])
                 dnext /= 2
             else:
                 break
-        self.fb_dfb.append((self.maximum_frequency + self.delta_f_end, self.delta_f_end))
+        self.fb_dfb.append([self.maximum_frequency + self.delta_f_end, self.delta_f_end])
+        self.fb_dfb = np.array(self.fb_dfb)
         logger.info("The total frequency range is divided into {} bands with frequency intervals of {}.".format(
-            len(self.durations), ", ".join(["1/{} Hz".format(d) for d in self.durations])))
+            self.number_of_bands, ", ".join(["1/{} Hz".format(d) for d in self.durations])))
 
     def _setup_integers(self):
         """Set up integers needed for likelihood evaluations. This sets the following instance variables.
@@ -393,17 +417,18 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         Ks_Ke: start and end frequency indices of bands (K^(b)_s and K^(b)_e in the paper)
 
         """
-        self.Nbs = []
-        self.Mbs = []
+        self.Nbs = np.array([], dtype=int)
+        self.Mbs = np.array([], dtype=int)
         self.Ks_Ke = []
-        for b in range(len(self.durations)):
+        for b in range(self.number_of_bands):
             dnow = self.durations[b]
             fnow, dfnow = self.fb_dfb[b]
             fnext, _ = self.fb_dfb[b + 1]
             Nb = max(round_up_to_power_of_two(2. * (fnext * self.interferometers.duration + 1.)), 2**b)
-            self.Nbs.append(Nb)
-            self.Mbs.append(Nb // 2**b)
-            self.Ks_Ke.append((math.ceil((fnow - dfnow) * dnow), math.floor(fnext * dnow)))
+            self.Nbs = np.append(self.Nbs, Nb)
+            self.Mbs = np.append(self.Mbs, Nb // 2**b)
+            self.Ks_Ke.append([math.ceil((fnow - dfnow) * dnow), math.floor(fnext * dnow)])
+        self.Ks_Ke = np.array(self.Ks_Ke)
 
     def _setup_waveform_frequency_points(self):
         """Set up frequency points where waveforms are evaluated. Frequency points are reordered because some waveform
@@ -419,13 +444,14 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         self.banded_frequency_points = np.array([])
         self.start_end_idxs = []
         start_idx = 0
-        for i in range(len(self.fb_dfb) - 1):
+        for i in range(self.number_of_bands):
             d = self.durations[i]
             Ks, Ke = self.Ks_Ke[i]
             self.banded_frequency_points = np.append(self.banded_frequency_points, np.arange(Ks, Ke + 1) / d)
             end_idx = start_idx + Ke - Ks
-            self.start_end_idxs.append((start_idx, end_idx))
+            self.start_end_idxs.append([start_idx, end_idx])
             start_idx = end_idx + 1
+        self.start_end_idxs = np.array(self.start_end_idxs)
         unique_frequencies, idxs = np.unique(self.banded_frequency_points, return_inverse=True)
         self.waveform_generator.waveform_arguments['frequencies'] = unique_frequencies
         self.unique_to_original_frequencies = idxs
@@ -435,35 +461,54 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             (self.maximum_frequency - self.minimum_frequency) * self.interferometers.duration /
             len(unique_frequencies)))
 
-    def _window(self, f, b):
-        """Compute window function in the b-th band
+    def _get_window_sequence(self, delta_f, start_idx, length, b):
+        """Compute window function on frequencies with a fixed frequency interval
 
         Parameters
         ----------
-        f: float or ndarray
-            frequency at which the window function is computed
+        delta_f: float
+            frequency interval
+        start_idx: int
+            starting frequency per delta_f
+        length: int
+            number of frequencies
         b: int
+            band number
 
         Returns
         -------
-        window: float
-            window function at f
+        window_sequence: array
+
         """
         fnow, dfnow = self.fb_dfb[b]
         fnext, dfnext = self.fb_dfb[b + 1]
 
-        @np.vectorize
-        def _vectorized_window(f):
-            if fnow - dfnow < f < fnow:
-                return (1. + np.cos(np.pi * (f - fnow) / dfnow)) / 2.
-            elif fnow <= f <= fnext - dfnext:
-                return 1.
-            elif fnext - dfnext < f < fnext:
-                return (1. - np.cos(np.pi * (f - fnext) / dfnext)) / 2.
-            else:
-                return 0.
+        window_sequence = np.zeros(length)
+        increase_start = np.clip(
+            math.floor((fnow - dfnow) / delta_f) - start_idx + 1, 0, length
+        )
+        unity_start = np.clip(math.ceil(fnow / delta_f) - start_idx, 0, length)
+        decrease_start = np.clip(
+            math.floor((fnext - dfnext) / delta_f) - start_idx + 1, 0, length
+        )
+        decrease_stop = np.clip(math.ceil(fnext / delta_f) - start_idx, 0, length)
+
+        window_sequence[unity_start:decrease_start] = 1.
+
+        # this if statement avoids overflow caused by vanishing dfnow
+        if increase_start < unity_start:
+            frequencies = (np.arange(increase_start, unity_start) + start_idx) * delta_f
+            window_sequence[increase_start:unity_start] = (
+                1. + np.cos(np.pi * (frequencies - fnow) / dfnow)
+            ) / 2.
 
-        return _vectorized_window(f)
+        if decrease_start < decrease_stop:
+            frequencies = (np.arange(decrease_start, decrease_stop) + start_idx) * delta_f
+            window_sequence[decrease_start:decrease_stop] = (
+                1. - np.cos(np.pi * (frequencies - fnext) / dfnext)
+            ) / 2.
+
+        return window_sequence
 
     def _setup_linear_coefficients(self):
         """Set up coefficients by which waveforms are multiplied to compute (d, h)"""
@@ -474,9 +519,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             fddata = np.zeros(N // 2 + 1, dtype=complex)
             fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask] += \
                 ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]
-            for b in range(len(self.fb_dfb) - 1):
-                start_idx, end_idx = self.start_end_idxs[b]
-                windows = self._window(self.banded_frequency_points[start_idx:end_idx + 1], b)
+            for b in range(self.number_of_bands):
+                Ks, Ke = self.Ks_Ke[b]
+                windows = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b)
                 fddata_in_ith_band = np.copy(fddata[:int(self.Nbs[b] / 2 + 1)])
                 fddata_in_ith_band[-1] = 0.  # zeroing data at the Nyquist frequency
                 tddata = np.fft.irfft(fddata_in_ith_band)[-self.Mbs[b]:]
@@ -490,36 +535,60 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         linear-interpolation algorithm"""
         logger.info("Linear-interpolation algorithm is used for (h, h).")
         self.quadratic_coeffs = dict((ifo.name, np.array([])) for ifo in self.interferometers)
-        N = self.Nbs[-1]
-        for ifo in self.interferometers:
-            logger.info("Pre-computing quadratic coefficients for {}".format(ifo.name))
-            full_frequencies = np.arange(N // 2 + 1) / ifo.duration
-            full_inv_psds = np.zeros(N // 2 + 1)
-            full_inv_psds[:len(ifo.power_spectral_density_array)][ifo.frequency_mask] = \
-                1. / ifo.power_spectral_density_array[ifo.frequency_mask]
-            for i in range(len(self.fb_dfb) - 1):
-                start_idx, end_idx = self.start_end_idxs[i]
-                banded_frequencies = self.banded_frequency_points[start_idx:end_idx + 1]
+        original_duration = self.interferometers.duration
+
+        for b in range(self.number_of_bands):
+            logger.info(f"Pre-computing quadratic coefficients for the {b}-th band")
+            _start, _end = self.start_end_idxs[b]
+            banded_frequencies = self.banded_frequency_points[_start:_end + 1]
+            prefactor = 4 * self.durations[b] / original_duration
+
+            # precompute window values
+            _fnow, _dfnow = self.fb_dfb[b]
+            _fnext, _ = self.fb_dfb[b + 1]
+            start_idx_in_band = math.ceil((_fnow - _dfnow) * original_duration)
+            window_sequence = self._get_window_sequence(
+                1 / original_duration,
+                start_idx_in_band,
+                math.floor(_fnext * original_duration) - start_idx_in_band + 1,
+                b
+            )
+
+            for ifo in self.interferometers:
+                end_idx_in_band = min(
+                    start_idx_in_band + len(window_sequence) - 1,
+                    len(ifo.power_spectral_density_array) - 1
+                )
+                _frequency_mask = ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1]
+                window_over_psd = np.zeros(end_idx_in_band + 1 - start_idx_in_band)
+                window_over_psd[_frequency_mask] = \
+                    1. / ifo.power_spectral_density_array[start_idx_in_band:end_idx_in_band + 1][_frequency_mask]
+                window_over_psd *= window_sequence[:len(window_over_psd)]
+
                 coeffs = np.zeros(len(banded_frequencies))
                 for k in range(len(coeffs) - 1):
                     if k == 0:
-                        start_idx_in_sum = 0
+                        start_idx_in_sum = start_idx_in_band
                     else:
-                        start_idx_in_sum = math.ceil(ifo.duration * banded_frequencies[k])
+                        start_idx_in_sum = max(
+                            start_idx_in_band,
+                            math.ceil(original_duration * banded_frequencies[k])
+                        )
                     if k == len(coeffs) - 2:
-                        end_idx_in_sum = len(full_frequencies) - 1
+                        end_idx_in_sum = end_idx_in_band
                     else:
-                        end_idx_in_sum = math.ceil(ifo.duration * banded_frequencies[k + 1]) - 1
-                    window_over_psd = (
-                        full_inv_psds[start_idx_in_sum:end_idx_in_sum + 1]
-                        * self._window(full_frequencies[start_idx_in_sum:end_idx_in_sum + 1], i)
-                    )
-                    frequencies_in_sum = full_frequencies[start_idx_in_sum:end_idx_in_sum + 1]
-                    coeffs[k] += 4 * self.durations[i] / ifo.duration * np.sum(
-                        (banded_frequencies[k + 1] - frequencies_in_sum) * window_over_psd
+                        end_idx_in_sum = min(
+                            end_idx_in_band,
+                            math.ceil(original_duration * banded_frequencies[k + 1]) - 1
+                        )
+                    frequencies_in_sum = np.arange(start_idx_in_sum, end_idx_in_sum + 1) / original_duration
+                    coeffs[k] += prefactor * np.sum(
+                        (banded_frequencies[k + 1] - frequencies_in_sum) *
+                        window_over_psd[start_idx_in_sum - start_idx_in_band:end_idx_in_sum - start_idx_in_band + 1]
                     )
-                    coeffs[k + 1] += 4 * self.durations[i] / ifo.duration * np.sum(
-                        (frequencies_in_sum - banded_frequencies[k]) * window_over_psd
+                    coeffs[k + 1] += prefactor * np.sum(
+                        (frequencies_in_sum - banded_frequencies[k]) *
+                        window_over_psd[start_idx_in_sum - start_idx_in_band:end_idx_in_sum - start_idx_in_band + 1]
                     )
                 self.quadratic_coeffs[ifo.name] = np.append(self.quadratic_coeffs[ifo.name], coeffs)
 
@@ -540,7 +609,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             full_inv_psds[:len(ifo.power_spectral_density_array)][ifo.frequency_mask] = (
                 1 / ifo.power_spectral_density_array[ifo.frequency_mask]
             )
-            for b in range(len(self.fb_dfb) - 1):
+            for b in range(self.number_of_bands):
                 Imb = np.fft.irfft(full_inv_psds[:self.Nbs[b] // 2 + 1])
                 half_length = Nhatbs[b] // 2
                 Imbc = np.append(Imb[:half_length + 1], Imb[-(Nhatbs[b] - half_length - 1):])
@@ -551,12 +620,79 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         # precompute windows and their squares
         self.windows = np.array([])
         self.square_root_windows = np.array([])
-        for b in range(len(self.fb_dfb) - 1):
-            start, end = self.start_end_idxs[b]
-            ws = self._window(self.banded_frequency_points[start:end + 1], b)
+        for b in range(self.number_of_bands):
+            Ks, Ke = self.Ks_Ke[b]
+            ws = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b)
             self.windows = np.append(self.windows, ws)
             self.square_root_windows = np.append(self.square_root_windows, np.sqrt(ws))
 
+    @property
+    def weights(self):
+        _weights = {}
+        for key in [
+            "reference_chirp_mass", "highest_mode", "linear_interpolation",
+            "accuracy_factor", "time_offset", "delta_f_end",
+            "maximum_banding_frequency", "minimum_banding_duration",
+            "durations", "fb_dfb", "Nbs", "Mbs", "Ks_Ke",
+            "banded_frequency_points", "start_end_idxs",
+            "unique_to_original_frequencies", "linear_coeffs"
+        ]:
+            _weights[key] = getattr(self, key)
+        _weights["waveform_frequencies"] = \
+            self.waveform_generator.waveform_arguments['frequencies']
+        if self.linear_interpolation:
+            _weights["quadratic_coeffs"] = self.quadratic_coeffs
+        else:
+            for key in ["Tbhats", "windows", "square_root_windows"]:
+                _weights[key] = getattr(self, key)
+            for key in ["wths", "hbcs", "Ibcs"]:
+                _weights[key] = {}
+                value = getattr(self, key)
+                for ifo_name, data in value.items():
+                    _weights[key][ifo_name] = dict((str(b), v) for b, v in enumerate(data))
+        return _weights
+
+    def save_weights(self, filename):
+        """
+        Save multiband weights into a .hdf5 file.
+
+        Parameters
+        ==========
+        filename : str
+
+        """
+        import h5py
+        if not filename.endswith(".hdf5"):
+            filename += ".hdf5"
+        logger.info(f"Saving multiband weights to {filename}")
+        with h5py.File(filename, 'w') as f:
+            recursively_save_dict_contents_to_group(f, '/', self.weights)
+
+    def setup_multibanding_from_weights(self, weights):
+        """
+        Set multiband weights from dictionary-like weights
+
+        Parameters
+        ==========
+        weights : dict
+
+        """
+        keys = list(weights.keys())
+        # reference_chirp_mass needs to be set first as it is required for the setter of maximum_banding_frequency
+        self.reference_chirp_mass = weights["reference_chirp_mass"]
+        keys.remove("reference_chirp_mass")
+        for key in keys:
+            value = weights[key]
+            if key in ["wths", "hbcs", "Ibcs"]:
+                to_set = {}
+                for ifo_name, data in value.items():
+                    to_set[ifo_name] = [data[str(b)] for b in range(len(data.keys()))]
+                setattr(self, key, to_set)
+            elif key == "waveform_frequencies":
+                self.waveform_generator.waveform_arguments['frequencies'] = weights["waveform_frequencies"]
+            else:
+                setattr(self, key, value)
+
     def calculate_snrs(self, waveform_polarizations, interferometer):
         """
         Compute the snrs for multi-banding
@@ -601,7 +737,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             )
         else:
             optimal_snr_squared = 0.
-            for b in range(len(self.fb_dfb) - 1):
+            for b in range(self.number_of_bands):
                 Ks, Ke = self.Ks_Ke[b]
                 start_idx, end_idx = self.start_end_idxs[b]
                 Mb = self.Mbs[b]
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index e7e707c287aeb473cf3bb0528d9c78e876a39dd4..47a4a6aca3b92be3bc841d60f571f4f8029c7cab 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -2,6 +2,7 @@ import itertools
 import os
 import pytest
 import unittest
+import tempfile
 from copy import deepcopy
 from itertools import product
 from parameterized import parameterized
@@ -1794,6 +1795,113 @@ class TestMBLikelihood(unittest.TestCase):
                 interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors
             )
 
+    @parameterized.expand([(True, ), (False, )])
+    def test_inout_weights(self, linear_interpolation):
+        """
+        Check if multiband weights can be saved as a file, and a likelihood object constructed from the weights file
+        produces the same likelihood value.
+        """
+        approximant = "IMRPhenomD"
+        wfg = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
+        )
+        self.ifos.inject_signal(
+            parameters=self.test_parameters, waveform_generator=wfg
+        )
+
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
+        )
+        likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            reference_chirp_mass=self.test_parameters['chirp_mass'],
+            linear_interpolation=linear_interpolation,
+        )
+        likelihood_mb.parameters.update(self.test_parameters)
+        llr = likelihood_mb.log_likelihood_ratio()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # check if weights can be saved as a file
+            filepath = os.path.join(tmpdirname, "weights.hdf5")
+            likelihood_mb.save_weights(filepath)
+            self.assertTrue(os.path.exists(filepath))
+
+            # reset waveform generator to check if likelihood recovered from the weights file properly adds banded
+            # frequency points to waveform arguments
+            wfg_mb = bilby.gw.WaveformGenerator(
+                duration=self.duration, sampling_frequency=self.sampling_frequency,
+                frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+                waveform_arguments=dict(
+                    reference_frequency=self.fmin, approximant=approximant
+                )
+            )
+            likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=wfg_mb, weights=filepath
+            )
+
+        likelihood_mb_from_weights.parameters.update(self.test_parameters)
+        llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio()
+
+        self.assertAlmostEqual(llr, llr_from_weights)
+
+    @parameterized.expand([(True, ), (False, )])
+    def test_from_dict_weights(self, linear_interpolation):
+        """
+        Check if a likelihood object constructed from dictionary-like weights produce the same likelihood value
+        """
+        approximant = "IMRPhenomD"
+        wfg = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
+        )
+        self.ifos.inject_signal(
+            parameters=self.test_parameters, waveform_generator=wfg
+        )
+
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
+        )
+        likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            reference_chirp_mass=self.test_parameters['chirp_mass'],
+            linear_interpolation=linear_interpolation,
+        )
+        likelihood_mb.parameters.update(self.test_parameters)
+        llr = likelihood_mb.log_likelihood_ratio()
+
+        # reset waveform generator to check if likelihood recovered from the weights properly adds banded
+        # frequency points to waveform arguments
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
+        )
+        weights = likelihood_mb.weights
+        likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb, weights=weights
+        )
+        likelihood_mb_from_weights.parameters.update(self.test_parameters)
+        llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio()
+
+        self.assertAlmostEqual(llr, llr_from_weights)
+
 
 if __name__ == "__main__":
     unittest.main()