waveform_generator.py 5.31 KB
Newer Older
1
2
import inspect

3
from . import utils
4

5

6
class WaveformGenerator(object):
    """ A waveform generator

    Evaluates a user-supplied source model for the current parameters in
    either the frequency or the time domain. When only one kind of model is
    given, the other domain is obtained with ``utils.nfft`` / ``utils.infft``.

    Parameters
    ----------
    frequency_domain_source_model: func, optional
        A python function returning the frequency domain strain. Its first
        argument must be the frequencies at which to compute the strain; the
        remaining arguments are the model parameters.
    time_domain_source_model: func, optional
        A python function returning the time domain strain. Its first
        argument must be the times at which to compute the strain; the
        remaining arguments are the model parameters.
    sampling_frequency: float, optional
        The sampling frequency to sample at (default 4096)
    time_duration: float, optional
        Time duration of data (default 1)
    parameters: dict or list, optional
        Initial model parameters. A dict is stored as-is. For a list, the
        first entry (the frequency/time argument name) is skipped and the
        remaining entries become keys initialised to `None`. If `None`, the
        parameter names are read from the signature of the source model
        (the frequency-domain model takes precedence), excluding its first
        argument, and initialised to `None`.

    Note: when converting between domains, only the 'plus' and 'cross'
    entries of the source model's return value are transformed.

    """

    def __init__(self, frequency_domain_source_model=None, time_domain_source_model=None, sampling_frequency=4096, time_duration=1,
                 parameters=None):
        # These setters also reset the lazily-built array caches.
        self.time_duration = time_duration
        self.sampling_frequency = sampling_frequency
        # The models must be assigned before `parameters`, because the
        # parameters setter inspects them when `parameters` is None.
        self.frequency_domain_source_model = frequency_domain_source_model
        self.time_domain_source_model = time_domain_source_model
        self.parameters = parameters

    @staticmethod
    def _source_model_parameter_names(source_model):
        """Return the positional argument names of `source_model`, minus the first.

        The first argument is the frequency/time array, not a model parameter.
        `inspect.getfullargspec` is preferred because `inspect.getargspec`
        was removed in Python 3.11; the latter is only a fallback for
        interpreters that lack `getfullargspec`.
        """
        getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        return getargspec(source_model).args[1:]

    def frequency_domain_strain(self):
        """ Return the frequency domain strain for the current parameters.

        With a frequency domain source model, it is evaluated directly on
        `frequency_array`. Otherwise the time domain source model is
        evaluated once on `time_array`, and its 'plus' and 'cross' entries are
        transformed with `utils.nfft`; the frequencies returned for 'cross'
        are stored as `frequency_array`.

        Raises
        ------
        RuntimeError
            If neither source model was given.
        """
        if self.frequency_domain_source_model is not None:
            return self.frequency_domain_source_model(self.frequency_array, **self.parameters)
        if self.time_domain_source_model is not None:
            # Evaluate the model once and transform each polarisation.
            strain = self.time_domain_source_model(self.time_array, **self.parameters)
            fft_data = dict()
            fft_data['cross'], self.frequency_array = utils.nfft(strain['cross'], self.sampling_frequency)
            fft_data['plus'], _ = utils.nfft(strain['plus'], self.sampling_frequency)
            return fft_data
        raise RuntimeError("No source model given")

    def time_domain_strain(self):
        """ Return the time domain strain for the current parameters.

        With a time domain source model, it is evaluated directly on
        `time_array`. Otherwise the frequency domain source model is
        evaluated once on `frequency_array`, and its 'plus' and 'cross'
        entries are transformed with `utils.infft`.

        Raises
        ------
        RuntimeError
            If neither source model was given.
        """
        if self.time_domain_source_model is not None:
            return self.time_domain_source_model(self.time_array, **self.parameters)
        if self.frequency_domain_source_model is not None:
            # Evaluate the model once and transform each polarisation.
            strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
            ifft_data = dict()
            ifft_data['cross'] = utils.infft(strain['cross'], self.sampling_frequency)
            ifft_data['plus'] = utils.infft(strain['plus'], self.sampling_frequency)
            return ifft_data
        raise RuntimeError("No source model given")

    @property
    def frequency_array(self):
        """ Frequencies at which the frequency domain strain is evaluated.

        Built lazily with `utils.create_fequency_series` from
        `sampling_frequency` and `time_duration`, and rebuilt after either of
        those changes (unless an array was assigned explicitly since).
        """
        if not self.__frequency_array_updated:
            self.__frequency_array = utils.create_fequency_series(
                                        self.sampling_frequency,
                                        self.time_duration)
            self.__frequency_array_updated = True
        return self.__frequency_array

    @frequency_array.setter
    def frequency_array(self, frequency_array):
        # Mark the stored array as current so the getter does not silently
        # overwrite it; changing sampling_frequency/time_duration resets this.
        self.__frequency_array = frequency_array
        self.__frequency_array_updated = True

    @property
    def time_array(self):
        """ Times at which the time domain strain is evaluated.

        Built lazily with `utils.create_time_series` from
        `sampling_frequency` and `time_duration`, and rebuilt after either of
        those changes (unless an array was assigned explicitly since).
        """
        if not self.__time_array_updated:
            self.__time_array = utils.create_time_series(
                                        self.sampling_frequency,
                                        self.time_duration)
            self.__time_array_updated = True
        return self.__time_array

    @time_array.setter
    def time_array(self, time_array):
        # Mirrors the frequency_array setter: the assigned array is used
        # until sampling_frequency/time_duration change.
        self.__time_array = time_array
        self.__time_array_updated = True

    @property
    def parameters(self):
        """ Dictionary of keyword arguments passed to the source model. """
        return self.__parameters

    @parameters.setter
    def parameters(self, parameters):
        if parameters is None:
            if self.frequency_domain_source_model is not None:
                names = self._source_model_parameter_names(self.frequency_domain_source_model)
            elif self.time_domain_source_model is not None:
                names = self._source_model_parameter_names(self.time_domain_source_model)
            else:
                # No model to inspect: start with no parameters rather than
                # leaving the attribute unset.
                names = []
            self.__parameters = dict.fromkeys(names)
        elif isinstance(parameters, list):
            # Skip the first entry (the frequency/time argument name) without
            # mutating the caller's list.
            self.__parameters = dict.fromkeys(parameters[1:])
        elif isinstance(parameters, dict):
            # Stored by reference, so later changes to the caller's dict are
            # visible here.
            self.__parameters = parameters
        else:
            raise TypeError('Parameters must either be set as a list of keys or'
                            ' a dictionary of key-value pairs.')

    @property
    def time_duration(self):
        """ Time duration of data; changing it invalidates cached arrays. """
        return self.__time_duration

    @time_duration.setter
    def time_duration(self, time_duration):
        self.__time_duration = time_duration
        self.__frequency_array_updated = False
        self.__time_array_updated = False

    @property
    def sampling_frequency(self):
        """ Sampling frequency; changing it invalidates cached arrays. """
        return self.__sampling_frequency

    @sampling_frequency.setter
    def sampling_frequency(self, sampling_frequency):
        self.__sampling_frequency = sampling_frequency
        self.__frequency_array_updated = False
        self.__time_array_updated = False