# waveform_generator.py
1
2
import inspect

3
from . import utils
4
import numpy as np
5

6
class WaveformGenerator(object):
    """ A waveform generator

    Evaluates a frequency-domain or time-domain source model on a grid derived
    from `sampling_frequency` and `time_duration`. When only a model of the
    other domain is available, its output is converted with `utils.nfft` /
    `utils.infft`.

    Parameters
    ----------
    frequency_domain_source_model: func, optional
        A python function taking some arguments and returning the frequency
        domain strain (a numpy array, or a dict mapping keys to arrays).
        Note the first argument must be the frequencies at which to compute
        the strain.
    time_domain_source_model: func, optional
        As above, but returning the time domain strain; the first argument
        must be the times at which to compute the strain.
    sampling_frequency: float
        The sampling frequency to sample at
    time_duration: float
        Time duration of data
    parameters: dict or list, optional
        A dict of keyword arguments passed to the source model, or a list of
        argument names whose first entry (the frequency/time argument) is
        skipped and whose remaining entries are initialised to `None`.

    Note: if `parameters` is None, the arguments of the source model (except
    the first, which is the frequencies/times at which to compute the strain)
    are used as parameter names and initialised to `None`. The frequency
    domain model is inspected first when both models are given.

    """

    def __init__(self, frequency_domain_source_model=None, time_domain_source_model=None, sampling_frequency=4096,
                 time_duration=1, parameters=None):
        self.time_duration = time_duration
        self.sampling_frequency = sampling_frequency
        # The models must be assigned before `parameters`: the parameters
        # setter inspects them when no explicit parameters are supplied.
        self.frequency_domain_source_model = frequency_domain_source_model
        self.time_domain_source_model = time_domain_source_model
        self.parameters = parameters
        self.__frequency_array_updated = False
        self.__time_array_updated = False

    def frequency_domain_strain(self):
        """ Evaluate the source model in the frequency domain.

        Returns
        -------
        array_like or dict
            The frequency domain model evaluated on `frequency_array` if that
            model is set; otherwise the FFT (via `utils.nfft`) of the time
            domain model output, transformed entry-by-entry for dict output.
            In the FFT case `frequency_array` is replaced by the frequencies
            returned by `utils.nfft`.

        Raises
        ------
        RuntimeError
            If neither source model is set.
        """
        if self.frequency_domain_source_model is not None:
            return self.frequency_domain_source_model(self.frequency_array, **self.parameters)
        if self.time_domain_source_model is None:
            raise RuntimeError("No source model given")

        time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters)
        if isinstance(time_domain_strain, np.ndarray):
            # utils.nfft returns (fft, frequencies); keep the frequency grid in
            # sync with the transformed data.
            fft_data, self.frequency_array = utils.nfft(time_domain_strain, self.sampling_frequency)
            return fft_data
        fft_data = dict()
        for key in time_domain_strain:
            fft_data[key], self.frequency_array = utils.nfft(time_domain_strain[key], self.sampling_frequency)
        return fft_data

    def time_domain_strain(self):
        """ Evaluate the source model in the time domain.

        Returns
        -------
        array_like or dict
            The time domain model evaluated on `time_array` if that model is
            set; otherwise the inverse FFT (via `utils.infft`) of the
            frequency domain model output, transformed entry-by-entry for
            dict output.

        Raises
        ------
        RuntimeError
            If neither source model is set.
        """
        if self.time_domain_source_model is not None:
            return self.time_domain_source_model(self.time_array, **self.parameters)
        if self.frequency_domain_source_model is None:
            raise RuntimeError("No source model given")

        frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
        # NOTE(review): presumably utils.infft returns the time series for a
        # single frequency series — confirm against utils.
        if isinstance(frequency_domain_strain, np.ndarray):
            return utils.infft(frequency_domain_strain, self.sampling_frequency)
        return {key: utils.infft(frequency_domain_strain[key], self.sampling_frequency)
                for key in frequency_domain_strain}

    @property
    def frequency_array(self):
        """ Frequencies at which the strain is evaluated.

        Lazily built from `sampling_frequency` and `time_duration` and cached
        until either of those changes or the array is set explicitly.
        """
        if self.__frequency_array_updated is False:
            self.__frequency_array = utils.create_fequency_series(
                                        self.sampling_frequency,
                                        self.time_duration)
            self.__frequency_array_updated = True
        return self.__frequency_array

    @frequency_array.setter
    def frequency_array(self, frequency_array):
        self.__frequency_array = frequency_array
        # Mark the cache valid so the getter does not overwrite the value
        # that was just set.
        self.__frequency_array_updated = True

    @property
    def time_array(self):
        """ Times at which the strain is evaluated (lazily built and cached). """
        if self.__time_array_updated is False:
            self.__time_array = utils.create_time_series(
                                        self.sampling_frequency,
                                        self.time_duration)
            self.__time_array_updated = True
        return self.__time_array

    @property
    def parameters(self):
        """ Dict of keyword arguments passed to the source model. """
        return self.__parameters

    @parameters.setter
    def parameters(self, parameters):
        if parameters is None:
            model = self.frequency_domain_source_model
            if model is None:
                model = self.time_domain_source_model
            if model is None:
                # Nothing to inspect: start empty rather than leaving the
                # attribute unset (which made the getter raise).
                self.__parameters = dict()
            else:
                # Skip the first argument: the frequencies/times.
                self.__parameters = dict.fromkeys(self._model_argument_names(model)[1:])
        elif isinstance(parameters, list):
            # Skip the first name without mutating the caller's list.
            self.__parameters = dict.fromkeys(parameters[1:])
        elif isinstance(parameters, dict):
            self.__parameters = parameters
        else:
            raise TypeError('Parameters must either be set as a list of keys or'
                            ' a dictionary of key-value pairs.')

    @staticmethod
    def _model_argument_names(model):
        """ Return the positional argument names of `model`, in order. """
        # inspect.getargspec was removed in Python 3.11; prefer
        # getfullargspec and fall back only where it does not exist.
        argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        return list(argspec(model).args)

    @property
    def time_duration(self):
        """ Duration of the data; changing it invalidates the cached grids. """
        return self.__time_duration

    @time_duration.setter
    def time_duration(self, time_duration):
        self.__time_duration = time_duration
        self.__frequency_array_updated = False
        self.__time_array_updated = False

    @property
    def sampling_frequency(self):
        """ Sampling frequency; changing it invalidates the cached grids. """
        return self.__sampling_frequency

    @sampling_frequency.setter
    def sampling_frequency(self, sampling_frequency):
        self.__sampling_frequency = sampling_frequency
        self.__frequency_array_updated = False
        self.__time_array_updated = False