diff --git a/gstlal-inspiral/python/cbc_template_fir.py b/gstlal-inspiral/python/cbc_template_fir.py
index 8b3951f3c15546c64407dfac92e192320e590bc9..75bab8f1350e0c53b3a63744d8e9eaa3ffe69e13 100644
--- a/gstlal-inspiral/python/cbc_template_fir.py
+++ b/gstlal-inspiral/python/cbc_template_fir.py
@@ -94,6 +94,57 @@ except KeyError:
 # =============================================================================
 #
 
+def create_FIR_whitener_kernel(length, duration, sample_rate, psd):
+	assert psd
+	#
+	# Add another COMPLEX16TimeSeries and COMPLEX16FrequencySeries for kernel's FFT (Leo)
+	#
+
+	# Add another FFT plan for kernel FFT (Leo)
+	fwdplan_kernel = lal.CreateForwardCOMPLEX16FFTPlan(length, 1)
+	kernel_tseries = lal.CreateCOMPLEX16TimeSeries(
+		name = "timeseries of whitening kernel",
+		epoch = LIGOTimeGPS(0.),
+		f0 = 0.,
+		deltaT = 1.0 / sample_rate,
+		length = length,
+		sampleUnits = lal.Unit("strain")
+	)
+	kernel_fseries = lal.CreateCOMPLEX16FrequencySeries(
+		name = "freqseries of whitening kernel",
+		epoch = LIGOTimeGPS(0),
+		f0 = 0.0,
+		deltaF = 1.0 / duration,
+		length = length,
+		sampleUnits = lal.Unit("strain s")
+	)
+
+	#
+	# Obtain a kernel of zero-latency whitening filter and
+	# adjust its length (Leo)
+	#
+
+	psd_fir_kernel = reference_psd.PSDFirKernel()
+	(kernel, latency, fir_rate) = psd_fir_kernel.psd_to_linear_phase_whitening_fir_kernel(psd, nyquist = sample_rate / 2.0)
+	(kernel, theta) = psd_fir_kernel.linear_phase_fir_kernel_to_minimum_phase_whitening_fir_kernel(kernel, fir_rate)
+	kernel = kernel[-1::-1]
+	# FIXME this is off by one sample, but shouldn't be. Look at the miminum phase function
+	# assert len(kernel) == length
+	if len(kernel) < length:
+		kernel = numpy.append(kernel, numpy.zeros(length - len(kernel)))
+	else:
+		kernel = kernel[:length]
+
+	kernel_tseries.data.data = kernel
+
+	#
+	# FFT of the kernel
+	#
+
+	lal.COMPLEX16TimeFreqFFT(kernel_fseries, kernel_tseries, fwdplan_kernel) #FIXME
+
+	return kernel_fseries
+
 
 def tukeywindow(data, samps = 200.):
 	assert (len(data) >= 2 * samps) # make sure that the user is requesting something sane
@@ -286,211 +337,149 @@ def condition_psd(psd, newdeltaF, minfs = (35.0, 40.0), maxfs = (1800., 2048.),
 	return psd
 
 
-def generate_templates(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, fhigh = None, verbose = False):
-	"""!
-	Generate a bank of templates, which are
-	1. broken up into time slice,
-	2. down-sampled in each time slice and
-	3. whitened with the given psd.
-	"""
-	sample_rate_max = max(time_slices['rate'])
-	duration = max(time_slices['end'])
-	length_max = int(round(duration * sample_rate_max))
-	if fhigh is None:
-		fhigh = sample_rate_max/2.
-	# Some input checking to avoid incomprehensible error messages
-	if not template_table:
-		raise ValueError("template list is empty")
-	if f_low < 0.:
-		raise ValueError("f_low must be >= 0.: %s" % repr(f_low))
-
-	# working f_low to actually use for generating the waveform.  pick
-	# template with lowest chirp mass, compute its duration starting
-	# from f_low;  the extra time is 10% of this plus 3 cycles (3 /
-	# f_low);  invert to obtain f_low corresponding to desired padding.
-	# NOTE:  because SimInspiralChirpStartFrequencyBound() does not
-	# account for spin, we set the spins to 0 in the call to
-	# SimInspiralChirpTimeBound() regardless of the component's spins.
-	template = min(template_table, key = lambda row: row.mchirp)
-	tchirp = lalsim.SimInspiralChirpTimeBound(f_low, template.mass1 * lal.MSUN_SI, template.mass2 * lal.MSUN_SI, 0., 0.)
-	working_f_low = lalsim.SimInspiralChirpStartFrequencyBound(1.1 * tchirp + 3. / f_low, template.mass1 * lal.MSUN_SI, template.mass2 * lal.MSUN_SI)
-
-	# Add duration of PSD to template length for PSD ringing, round up to power of 2 count of samples
-	working_length = templates.ceil_pow_2(length_max + round(1./psd.deltaF * sample_rate_max))
-	working_duration = float(working_length) / sample_rate_max
-
-	# Smooth the PSD and interpolate to required resolution
-	if not FIR_WHITENER and psd is not None:
-		psd = condition_psd(psd, 1.0 / working_duration, minfs = (working_f_low, f_low), maxfs = (sample_rate_max / 2.0 * 0.90, sample_rate_max / 2.0))
-	else:
-		psd = reference_psd.interpolate_psd(psd, 1.0 / working_duration)
-	revplan = lal.CreateReverseCOMPLEX16FFTPlan(working_length, 1)
-	fwdplan = lal.CreateForwardREAL8FFTPlan(working_length, 1)
-	tseries = lal.CreateCOMPLEX16TimeSeries(
-		name = "timeseries",
-		epoch = LIGOTimeGPS(0.),
-		f0 = 0.,
-		deltaT = 1.0 / sample_rate_max,
-		length = working_length,
-		sampleUnits = lal.Unit("strain")
-	)
-	fworkspace = lal.CreateCOMPLEX16FrequencySeries(
-		name = "template",
-		epoch = LIGOTimeGPS(0),
-		f0 = 0.0,
-		deltaF = 1.0 / working_duration,
-		length = working_length // 2 + 1,
-		sampleUnits = lal.Unit("strain s")
-	)
-
-	if FIR_WHITENER:
-		assert psd
-		#
-		# Add another COMPLEX16TimeSeries and COMPLEX16FrequencySeries for kernel's FFT (Leo)
-		#
-
-		# Add another FFT plan for kernel FFT (Leo)
-		fwdplan_kernel = lal.CreateForwardCOMPLEX16FFTPlan(working_length, 1)
-		kernel_tseries = lal.CreateCOMPLEX16TimeSeries(
-			name = "timeseries of whitening kernel",
+class templates_workspace(object):
+	def __init__(self, template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, fhigh = None):
+		self.template_table = template_table
+		self.approximant = approximant
+		self.f_low = f_low
+		self.time_slices = time_slices
+		self.autocorrelation_length = autocorrelation_length
+		self.fhigh = fhigh
+		self.sample_rate_max = max(time_slices["rate"])
+		self.duration = max(time_slices["end"])
+		self.length_max = int(round(self.duration * self.sample_rate_max))
+
+		if self.fhigh is None:
+			self.fhigh = self.sample_rate_max / 2.
+		# Some input checking to avoid incomprehensible error messages
+		if not self.template_table:
+			raise ValueError("template list is empty")
+		if self.f_low < 0.:
+			raise ValueError("f_low must be >= 0. %s" % repr(self.f_low))
+
+		# working f_low to actually use for generating the waveform.  pick
+		# template with lowest chirp mass, compute its duration starting
+		# from f_low;  the extra time is 10% of this plus 3 cycles (3 /
+		# f_low);  invert to obtain f_low corresponding to desired padding.
+		# NOTE:  because SimInspiralChirpStartFrequencyBound() does not
+		# account for spin, we set the spins to 0 in the call to
+		# SimInspiralChirpTimeBound() regardless of the component's spins.
+		template = min(self.template_table, key = lambda row: row.mchirp)
+		tchirp = lalsim.SimInspiralChirpTimeBound(self.f_low, template.mass1 * lal.MSUN_SI, template.mass2 * lal.MSUN_SI, 0., 0.)
+		working_f_low = lalsim.SimInspiralChirpStartFrequencyBound(1.1 * tchirp + 3. / self.f_low, template.mass1 * lal.MSUN_SI, template.mass2 * lal.MSUN_SI)
+
+		# Add duration of PSD to template length for PSD ringing, round up to power of 2 count of samples
+		self.working_length = templates.ceil_pow_2(self.length_max + round(1./psd.deltaF * self.sample_rate_max))
+		self.working_duration = float(self.working_length) / self.sample_rate_max
+
+		# Smooth the PSD and interpolate to required resolution
+		if not FIR_WHITENER and psd is not None:
+			self.psd = condition_psd(psd, 1.0 / self.working_duration, minfs = (working_f_low, self.f_low), maxfs = (self.sample_rate_max / 2.0 * 0.90, self.sample_rate_max / 2.0))
+		else:
+			self.psd = reference_psd.interpolate_psd(psd, 1.0 / self.working_duration)
+		self.revplan = lal.CreateReverseCOMPLEX16FFTPlan(self.working_length, 1)
+		self.fwdplan = lal.CreateForwardREAL8FFTPlan(self.working_length, 1)
+		self.tseries = lal.CreateCOMPLEX16TimeSeries(
+			name = "timeseries",
 			epoch = LIGOTimeGPS(0.),
 			f0 = 0.,
-			deltaT = 1.0 / sample_rate_max,
-			length = working_length,
+			deltaT = 1.0 / self.sample_rate_max,
+			length = self.working_length,
 			sampleUnits = lal.Unit("strain")
 		)
-		kernel_fseries = lal.CreateCOMPLEX16FrequencySeries(
-			name = "freqseries of whitening kernel",
+		self.fworkspace = lal.CreateCOMPLEX16FrequencySeries(
+			name = "template",
 			epoch = LIGOTimeGPS(0),
 			f0 = 0.0,
-			deltaF = 1.0 / working_duration,
-			length = working_length,
+			deltaF = 1.0 / self.working_duration,
+			length = self.working_length // 2 + 1,
 			sampleUnits = lal.Unit("strain s")
 		)
 
-		#
-		# Obtain a kernel of zero-latency whitening filter and
-		# adjust its length (Leo)
-		#
+		if FIR_WHITENER:
+			self.kernel_fseries = create_FIR_whitener_kernel(self.working_length, self.working_duration, self.sample_rate_max, self.psd)
 
-		psd_fir_kernel = reference_psd.PSDFirKernel()
-		(kernel, latency, fir_rate) = psd_fir_kernel.psd_to_linear_phase_whitening_fir_kernel(psd, nyquist = sample_rate_max / 2.0)
-		(kernel, theta) = psd_fir_kernel.linear_phase_fir_kernel_to_minimum_phase_whitening_fir_kernel(kernel, fir_rate)
-		kernel = kernel[-1::-1]
-		# FIXME this is off by one sample, but shouldn't be. Look at the miminum phase function
-		# assert len(kernel) == working_length
-		if len(kernel) < working_length:
-			kernel = numpy.append(kernel, numpy.zeros(working_length - len(kernel)))
+		# Calculate the maximum ring down time or maximum shift time
+		if approximant in templates.gstlal_IMR_approximants:
+			self.max_ringtime = max([chirptime.ringtime(row.mass1*lal.MSUN_SI + row.mass2*lal.MSUN_SI, chirptime.overestimate_j_from_chi(max(row.spin1z, row.spin2z))) for row in self.template_table])
 		else:
-			kernel = kernel[:working_length]
-
-		kernel_tseries.data.data = kernel
-
-		#
-		# FFT of the kernel
-		#
-
-		lal.COMPLEX16TimeFreqFFT(kernel_fseries, kernel_tseries, fwdplan_kernel) #FIXME
-
-	# Check parity of autocorrelation length
-	if autocorrelation_length is not None:
-		if not (autocorrelation_length % 2):
-			raise ValueError, "autocorrelation_length must be odd (got %d)" % autocorrelation_length
-		autocorrelation_bank = numpy.zeros((len(template_table), autocorrelation_length), dtype = "cdouble")
-		autocorrelation_mask = compute_autocorrelation_mask( autocorrelation_bank )
-	else:
-		autocorrelation_bank = None
-		autocorrelation_mask = None
+			if self.sample_rate_max > 2. * self.fhigh:
+			# Calculate the maximum time we need to shift the early warning
+			# waveforms forward by, calculated by the 3.5 approximation from
+			# fhigh to ISCO.
+				self.max_shift_time = max([spawaveform.chirptime(row.mass1, row.mass2, 7, fhigh, 0., spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z)) for row in self.template_table])
 
-	# Multiply by 2 * length of the number of sngl_inspiral rows to get the sine/cosine phases.
-	template_bank = [numpy.zeros((2 * len(template_table), int(round(rate*(end-begin)))), dtype = "double") for rate,begin,end in time_slices]
-
-	# Store the original normalization of the waveform.  After
-	# whitening, the waveforms are normalized.  Use the sigmasq factors
-	# to get back the original waveform.
-	sigmasq = []
-
-	if approximant in templates.gstlal_IMR_approximants:
-		max_ringtime = max([chirptime.ringtime(row.mass1*lal.MSUN_SI + row.mass2*lal.MSUN_SI, chirptime.overestimate_j_from_chi(max(row.spin1z, row.spin2z))) for row in template_table])
+			#
+			# Generate each template, downsampling as we go to save memory
+			# generate "cosine" component of frequency-domain template.
+			# waveform is generated for a canonical distance of 1 Mpc.
+			#
 
-	else:
-		if sample_rate_max>2.*fhigh:
-		# Calculate the maximum time we need to shift the early warning
-		# waveforms forward by, calculated by the 3.5 approximation from
-		# fhigh to ISCO.
-			max_shift_time = max([spawaveform.chirptime(row.mass1, row.mass2, 7, fhigh, 0., spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z)) for row in template_table])
+	def make_whitened_template(self, template_table_row):
+		# FIXME: This is won't work
+		#assert template_table_row in self.template_table, "The input Sngl_Inspiral:Table is not found in the workspace."
 
-		#
-		# Generate each template, downsampling as we go to save memory
-		# generate "cosine" component of frequency-domain template.
-		# waveform is generated for a canonical distance of 1 Mpc.
-		#
-
-	for i, row in enumerate(template_table):
-		if verbose:
-			print >>sys.stderr, "generating template %d/%d:  m1 = %g, m2 = %g, s1x = %g, s1y = %g, s1z = %g, s2x = %g, s2y = %g, s2z = %g" % (i + 1, len(template_table), row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z)
-
-		fseries = generate_template(row, approximant, sample_rate_max, working_duration, f_low, fhigh, fwdplan = fwdplan, fworkspace = fworkspace)
+		# Create template
+		fseries = generate_template(template_table_row, self.approximant, self.sample_rate_max, self.working_duration, self.f_low, self.fhigh, fwdplan = self.fwdplan, fworkspace = self.fworkspace)
 
 		if FIR_WHITENER:
 			#
 			# Compute a product of freq series of the whitening kernel and the template (convolution in time domain) then add quadrature phase(Leo)
 			#
-
-			assert (len(kernel_fseries.data.data) // 2 + 1) == len(fseries.data.data), "the size of whitening kernel freq series does not match with a given format of COMPLEX16FrequencySeries."
-			fseries.data.data *= kernel_fseries.data.data[len(kernel_fseries.data.data) // 2 - 1:]
-			fseries = templates.QuadraturePhase.add_quadrature_phase(fseries, working_length)
+			assert (len(self.kernel_fseries.data.data) // 2 + 1) == len(fseries.data.data), "the size of whitening kernel freq series does not match with a given format of COMPLEX16FrequencySeries."
+			fseries.data.data *= self.kernel_fseries.data.data[len(self.kernel_fseries.data.data) // 2 - 1:]
+			fseries = templates.QuadraturePhase.add_quadrature_phase(fseries, self.working_length)
 		else:
 			#
 			# whiten and add quadrature phase ("sine" component)
 			#
 
-			if psd is not None:
-				lal.WhitenCOMPLEX16FrequencySeries(fseries, psd)
-				fseries = templates.QuadraturePhase.add_quadrature_phase(fseries, working_length)
-
+			if self.psd is not None:
+				lal.WhitenCOMPLEX16FrequencySeries(fseries, self.psd)
+				fseries = templates.QuadraturePhase.add_quadrature_phase(fseries, self.working_length)
 
 		#
 		# compute time-domain autocorrelation function
 		#
 
-		if autocorrelation_bank is not None:
-			autocorrelation = templates.normalized_autocorrelation(fseries, revplan).data.data
-			autocorrelation_bank[i, ::-1] = numpy.concatenate((autocorrelation[-(autocorrelation_length // 2):], autocorrelation[:(autocorrelation_length // 2  + 1)]))
+		if self.autocorrelation_length is not None:
+			autocorrelation = templates.normalized_autocorrelation(fseries, self.revplan).data.data
+		else:
+			autocorrelation = None
 
 		#
 		# transform template to time domain
 		#
 
-		lal.COMPLEX16FreqTimeFFT(tseries, fseries, revplan)
+		lal.COMPLEX16FreqTimeFFT(self.tseries, fseries, self.revplan)
 
-		data = tseries.data.data
+		data = self.tseries.data.data
 		epoch_time = fseries.epoch.gpsSeconds + fseries.epoch.gpsNanoSeconds*1.e-9
+
 		#
 		# extract the portion to be used for filtering
 		#
 
-
 		#
 		# condition the template if necessary (e.g. line up IMR
 		# waveforms by peak amplitude)
 		#
 
-		if approximant in templates.gstlal_IMR_approximants:
-			data, target_index = condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_ringtime)
+		if self.approximant in templates.gstlal_IMR_approximants:
+			data, target_index = condition_imr_template(self.approximant, data, epoch_time, self.sample_rate_max, self.max_ringtime)
 			# record the new end times for the waveforms (since we performed the shifts)
-			row.end = LIGOTimeGPS(float(target_index-(len(data) - 1.))/sample_rate_max)
-
+			template_table_row.end = LIGOTimeGPS(float(target_index-(len(data) - 1.))/self.sample_rate_max)
 		else:
-			if sample_rate_max > fhigh*2.:
-				data, target_index = condition_ear_warn_template(approximant, data, epoch_time, sample_rate_max, max_shift_time)
+			if self.sample_rate_max > self.fhigh*2.:
+				data, target_index = condition_ear_warn_template(self.approximant, data, epoch_time, self.sample_rate_max, self.max_shift_time)
 				data *= tukeywindow(data, samps = 32)
 				# record the new end times for the waveforms (since we performed the shifts)
-				row.end = LIGOTimeGPS(float(target_index)/sample_rate_max)
+				template_table_row.end = LIGOTimeGPS(float(target_index)/self.sample_rate_max)
 			else:
 				data *= tukeywindow(data, samps = 32)
 
-		data = data[-length_max:]
+		data = data[-self.length_max:]
+
 		#
 		# normalize so that inner product of template with itself
 		# is 2
@@ -527,7 +516,43 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
 		# sigmasq = norm * N * (\Delta t)^2
 		#
 
-		sigmasq.append(norm * len(data) / sample_rate_max**2.)
+		sigmasq = norm * len(data) / self.sample_rate_max**2.
+
+		return data, autocorrelation, sigmasq
+
+
+def generate_templates(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, fhigh = None, verbose = False):
+	# Create workspace for making template bank
+	workspace = templates_workspace(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = autocorrelation_length, fhigh = fhigh)
+
+	# Check parity of autocorrelation length
+	if autocorrelation_length is not None:
+		if not (autocorrelation_length % 2):
+			raise ValueError, "autocorrelation_length must be odd (got %d)" % autocorrelation_length
+		autocorrelation_bank = numpy.zeros((len(template_table), autocorrelation_length), dtype = "cdouble")
+		autocorrelation_mask = compute_autocorrelation_mask( autocorrelation_bank )
+	else:
+		autocorrelation_bank = None
+		autocorrelation_mask = None
+
+	# Multiply by 2 * length of the number of sngl_inspiral rows to get the sine/cosine phases.
+	template_bank = [numpy.zeros((2 * len(template_table), int(round(rate*(end-begin)))), dtype = "double") for rate,begin,end in time_slices]
+
+	# Store the original normalization of the waveform.  After
+	# whitening, the waveforms are normalized.  Use the sigmasq factors
+	# to get back the original waveform.
+	sigmasq = []
+
+	for i, row in enumerate(template_table):
+		if verbose:
+			print >>sys.stderr, "generating template %d/%d:  m1 = %g, m2 = %g, s1x = %g, s1y = %g, s1z = %g, s2x = %g, s2y = %g, s2z = %g" % (i + 1, len(template_table), row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z)
+		# FIXME: ensure the row is in template_table
+		template, autocorrelation, this_sigmasq = workspace.make_whitened_template(row)
+
+		sigmasq.append(this_sigmasq)
+
+		if autocorrelation is not None:
+			autocorrelation_bank[i, ::-1] = numpy.concatenate((autocorrelation[-(autocorrelation_length // 2):], autocorrelation[:(autocorrelation_length // 2  + 1)]))
 
 		#
 		# copy real and imaginary parts into adjacent (real-valued)
@@ -543,11 +568,12 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
 			# but probaby has something to do with the reversal
 			# of the open/closed boundary conditions through
 			# all of this (argh!  Chad!)
-			stride = int(round(sample_rate_max / time_slice['rate']))
-			begin_index = length_max - int(round(time_slice['begin'] * sample_rate_max)) + stride - 1
-			end_index = length_max - int(round(time_slice['end'] * sample_rate_max)) + stride - 1
+
+			stride = int(round(workspace.sample_rate_max / time_slice['rate']))
+			begin_index = workspace.length_max - int(round(time_slice['begin'] * workspace.sample_rate_max)) + stride - 1
+			end_index = workspace.length_max - int(round(time_slice['end'] * workspace.sample_rate_max)) + stride - 1
 			# make sure the rates are commensurate
-			assert stride * time_slice['rate'] == sample_rate_max
+			assert stride * time_slice['rate'] == workspace.sample_rate_max
 
 			# extract every stride-th sample.  we multiply by
 			# \sqrt{stride} to maintain inner product
@@ -558,10 +584,10 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
 			# normalization of the basis vectors used for
 			# filtering but it ensures that the chifacs values
 			# have the correct relative normalization.
-			template_bank[j][(2*i+0),:] = data.real[end_index:begin_index:stride] * math.sqrt(stride)
-			template_bank[j][(2*i+1),:] = data.imag[end_index:begin_index:stride] * math.sqrt(stride)
+			template_bank[j][(2*i+0),:] = template.real[end_index:begin_index:stride] * math.sqrt(stride)
+			template_bank[j][(2*i+1),:] = template.imag[end_index:begin_index:stride] * math.sqrt(stride)
 
-	return template_bank, autocorrelation_bank, autocorrelation_mask, sigmasq, psd
+	return template_bank, autocorrelation_bank, autocorrelation_mask, sigmasq, workspace.psd
 
 
 def decompose_templates(template_bank, tolerance, identity = False):