actions.py 20.6 KB
Newer Older
1 2 3
"""
Actions
"""
4 5

from finesse.exceptions import ParameterLocked
6 7 8
from finesse.solutions import BaseSolution
from finesse.analysis.runners import run_axes_scan
from finesse.solutions import ArraySolution
9
from finesse.parameter import Parameter
10

11
import regex as re
12 13
import logging
import textwrap
14
import numpy as np
15 16 17 18 19 20 21

LOGGER = logging.getLogger(__name__)
import weakref
import finesse
from copy import deepcopy


def get_param(model, attr):
    """Resolve a dotted ``"element.parameter"`` path to a model parameter.

    Parameters
    ----------
    model
        Model whose ``elements`` mapping contains the named element.
    attr : str
        Path made of exactly one element name and one attribute name joined
        by a single dot.

    Returns
    -------
    The attribute named by the second part of ``attr`` on
    ``model.elements[<first part>]``.
    """
    element_name, param_name = attr.split(".")
    element = model.elements[element_name]
    return getattr(element, param_name)


def get_sweep_array(start: float, stop: float, steps: int, mode="lin"):
    """Return the sample points of a parameter sweep.

    Parameters
    ----------
    start, stop : float
        First and last values of the sweep; both are included.
    steps : int
        Number of intervals; the returned array holds ``steps + 1`` points.
    mode : str, optional
        ``"lin"`` for linearly spaced points. Any other value selects
        logarithmically spaced points (``"log"`` is the documented choice).

    Returns
    -------
    numpy.ndarray
        ``steps + 1`` values running from ``start`` to ``stop``.

    Raises
    ------
    ValueError
        If ``steps`` is not positive, or if logarithmic spacing is requested
        with a non-positive ``start`` or ``stop``.
    """
    start = float(start)
    stop = float(stop)
    steps = int(steps)
    if steps <= 0:
        # ValueError subclasses Exception, so callers catching Exception still work.
        raise ValueError("Steps must be greater than 0")

    if mode == "lin":
        return np.linspace(start, stop, steps + 1)

    # log10 of a non-positive bound is -inf/NaN, which would silently yield an
    # invalid sweep full of NaNs instead of an error.
    if start <= 0 or stop <= 0:
        raise ValueError(
            "Logarithmic sweeps require positive start and stop values, "
            f"got start={start}, stop={stop}"
        )
    return np.logspace(np.log10(start), np.log10(stop), steps + 1)


class ActionWorkspace:
    """Per-run state shared between an action's ``setup`` and its ``do`` call.

    Attributes
    ----------
    s_prev
        Solution handed to the action's ``setup``.
    sim
        Simulation object the action runs against.
    model
        ``sim.model``, cached for convenience.
    fn_do : callable or None
        Called with this workspace to execute the action; filled in by the
        owning action's ``setup``.
    """

    def __init__(self, s_prev, sim):
        self.s_prev, self.sim = s_prev, sim
        self.model = sim.model
        self.fn_do = None


class AnalysisStepInfo(finesse.solutions.base.ParameterChangingTreeNode):
    """Node of the tree describing an analysis and the parameters it changes.

    Parameters
    ----------
    action : Action
        Action this node describes; only a weak reference to it is kept.
    makes_solution : bool, optional
        Whether the action produces a solution. Such actions must be named.
    parameters_changing : tuple of str, optional
        ``"element.parameter"`` names of the parameters the action varies.

    Raises
    ------
    Exception
        If ``makes_solution`` is ``True`` but ``action.name`` is ``None``.
    """

    def __init__(self, action, makes_solution=False, parameters_changing=tuple()):
        if action.name is None and makes_solution is True:
            raise Exception("A name must be supplied if an action produces a solution")
        super().__init__(action.name)
        # Weak reference: this node must not keep its action alive (the action
        # itself holds its info node).
        self.action = weakref.ref(action)
        # NOTE(review): several actions in this module assign
        # ``info.makes_solution`` after construction, which is a different
        # attribute from ``makes_a_solution`` set here -- confirm which name
        # consumers read.
        self.makes_a_solution = makes_solution
        # NOTE(review): ``empty`` mirrors ``makes_solution`` here, while RunLocks
        # sets ``makes_a_solution=False`` with ``empty=True`` -- confirm intent.
        self.empty = makes_solution
        self.parameters_changing = parameters_changing

    def __str__(self):
        def fn_name(child):
            act = child.action()
            # Show the action's class next to the node name, unless the class is
            # the plain ``Action`` base or the names already coincide. Compare
            # class names by value: ``is`` on strings relies on interning.
            if (
                act is not None
                and act.__class__.__name__ != child.name
                and act.__class__.__name__ != Action.__name__
            ):
                return f"{child.name} - {act.__class__.__name__}"
            return child.name

        return self.draw_tree(fn_name, title="Analysis Info Tree")

    def __repr__(self):
        return f"<{self.__class__.__name__} of {self.get_path()} @ {hex(id(self))} children={len(self.children)}>"


class Action:
    """Base class for a step of an analysis.

    Subclasses normally implement ``do(ws)``. The default :meth:`setup` wraps
    that method in an :class:`ActionWorkspace`. :meth:`run` is the top-level
    entry point: it builds the model and executes the action tree on it.

    Parameters
    ----------
    name : str, optional
        Name of the action (and of any solution it produces).
    """

    def __init__(self, name=None):
        self.__name = name
        self._info = AnalysisStepInfo(self)

    def copy_info(self):
        """Return a deep copy of this action's info node."""
        return deepcopy(self._info)

    @property
    def name(self):
        """Name given at construction, or ``None``."""
        return self.__name

    def fill_info(self, p_info):
        """Attach a copy of this action's info node under ``p_info``."""
        p_info.add(self.copy_info())

    def setup(self, s_prev, sim):
        """Create the workspace used to run this action against ``sim``."""
        # By default simple actions should just run their 'do' method
        ws = ActionWorkspace(s_prev, sim)
        ws.fn_do = self.do
        return ws

    def do(self, ws):
        """Execute the action; subclasses must override this."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement do()")

    def run(self, model, reset_params=False):
        """Build ``model`` and execute this action tree on it.

        Parameters
        ----------
        model
            Model to run on; it must not already be built.
        reset_params : bool, optional
            If true, every parameter changed by the action tree is restored to
            its pre-run value afterwards.

        Returns
        -------
        BaseSolution
            Root solution to which the actions attach their results.

        Raises
        ------
        Exception
            If ``model`` is already built.
        """
        if model.is_built:
            raise Exception("Model is currently built")

        # Create a top level information object and pass that along to all the
        # actions so we can find out which parameters will be changing.
        info = AnalysisStepInfo(Action("Start"))
        self.fill_info(info)

        params = [get_param(model, pstr) for pstr in info.get_all_parameters_changing()]

        initial_param_values = {}
        for p in params:
            p.is_tunable = True
            initial_param_values[p] = p.value

        def restore():
            for p in params:
                p.is_tunable = False
                if reset_params:  # Reset scanned parameters to initial values
                    # NOTE (sjr) Important to do this after switching off tunable flag
                    #            so that p.is_changing is now False -> means set_value
                    #            of GeometricParameter then updates the associated ABCD
                    #            matrices (back to initial state)
                    p.value = initial_param_values[p]

        restored = False
        try:
            with model.built() as sim:
                try:
                    s = BaseSolution(None, None)
                    ws = self.setup(s, sim)
                    ws.fn_do(ws)
                except StopIteration as err:
                    raise Exception("Should we reach this point? Probably unexpected") from err
                finally:
                    restored = True
                    restore()
        finally:
            # If building the model itself raised, the inner ``finally`` never ran;
            # still switch the tunable flags back off so the model is not left
            # with parameters stuck in the tunable state.
            if not restored:
                restore()

        return s

    def __str__(self):
        info = AnalysisStepInfo(Action("Start"))
        self.fill_info(info)
        return str(info)


class BeamTrace(Action):
    """Action for tracing the beam throughout an entire model.

    Every keyword argument is passed straight through to ``model.beam_trace``;
    the trace result is added to the previous solution under this action's name.
    """

    def __init__(self, name, **kwargs):
        super().__init__(name)
        self.kwargs = kwargs
        self._info.makes_solution = True

    def do(self, ws):
        trace = ws.model.beam_trace(solution_name=self.name, **self.kwargs)
        ws.s_prev.add(trace)


class Scale(Action):
    """Action for scaling simulation outputs by some fixed amount.
    Included for compatibility with legacy Finesse code. New users
    should apply any desired scalings manually from Python.

    Parameters
    ----------
    name : str
        Name of this action.
    scales : dict
        A dictionary of `detector name: scaling factor` mappings.
    **kwargs
        Stored on the action as ``kwargs``; not otherwise used.
    """

    def __init__(self, name, scales, **kwargs):
        super().__init__(name)
        self.kwargs = kwargs
        self._info.makes_solution = False
        self.scales = scales

    def do(self, ws):
        # Scale the outputs held by ``ws.s_prev[-1]``.
        sol = ws.s_prev[-1]
        for det, fac in self.scales.items():
            # ``[()]`` indexes the whole stored output, so the product is written
            # back in place (assumes the outputs are NumPy arrays -- confirm).
            sol._outputs[det][()] *= fac


class ABCD(Action):
    """Action to compute a composite ABCD matrix over a given path of a model.

    Keyword arguments are forwarded to ``model.ABCD``; the result is added to
    the previous solution under this action's name.
    """

    def __init__(self, name, **kwargs):
        super().__init__(name)
        self.kwargs = kwargs
        self._info.makes_solution = True

    def do(self, ws):
        abcd_solution = ws.model.ABCD(solution_name=self.name, **self.kwargs)
        ws.s_prev.add(abcd_solution)


class StepParamNDWorkspace(ActionWorkspace):
    """Workspace for :class:`StepParamND`; extra fields are attached by its setup/do."""


class StepParamND(Action):
    """Scan one or more parameters over an N-dimensional grid of values.

    Parameters
    ----------
    name : str
        Name of the :class:`ArraySolution` produced by the scan.
    *args
        Triplets of ``parameter, values, offset``. ``parameter`` is either a
        :class:`Parameter` or an ``"element.parameter"`` string; ``values`` is
        a scalar or array of points for that axis; ``offset`` is a float that
        is passed to ``run_axes_scan`` together with the axes.
    pre_step, post_step : Action, optional
        Each is wrapped in a :class:`Folder` set up against the scan's
        solution and handed to ``run_axes_scan`` as its pre-/post-step workspace.
    on_complete : Action, optional
        Wrapped in a :class:`Folder` and run once after ``run_axes_scan`` returns.

    Raises
    ------
    Exception
        If the number of positional arguments is not a multiple of three.
    """

    def __init__(self, name, *args, pre_step=None, post_step=None, on_complete=None):
        super().__init__(name)

        if len(args) % 3:
            raise Exception(
                "Arguments must be triplets of parameter, array of "
                "values to scan over, and offset to array values."
            )

        self.args = args
        scanned, values, offsets = args[0::3], args[1::3], args[2::3]

        self._info.parameters_changing = tuple(
            f"{p.component.name}.{p.name}" if isinstance(p, Parameter) else p
            for p in scanned
        )
        self.axes = tuple(np.atleast_1d(axis_values) for axis_values in values)
        self.offsets = np.array(offsets, dtype=np.float64)
        self.out_shape = tuple(np.size(axis) for axis in self.axes)
        self._info.makes_solution = True

        LOGGER.info("Scanning parameters %s", list(self._info.parameters_changing))

        self.pre_step = pre_step
        self.post_step = post_step
        self.on_complete = on_complete

    def fill_info(self, p_info):
        node = self.copy_info()
        p_info.add(node)

        hooks = (
            ("pre_step", self.pre_step),
            ("post_step", self.post_step),
            ("on_complete", self.on_complete),
        )
        for label, hook in hooks:
            if hook:
                Folder(label, hook).fill_info(node)

    def setup(self, s_prev, sim):
        ws = StepParamNDWorkspace(s_prev, sim)
        ws.info = self.copy_info()
        ws.fn_do = self.do

        ws.params = tuple(get_param(ws.model, path) for path in self._info.parameters_changing)
        # Scanned parameters have to be flagged tunable before the model is built.
        locked = next((p for p in ws.params if not p.is_tunable), None)
        if locked is not None:
            raise ParameterLocked(
                f"{locked!r} must set as tunable before building the simulation"
            )
        return ws

    def do(self, ws: StepParamNDWorkspace):
        ws.sol = ArraySolution(
            self.name, ws.s_prev, ws.sim.detector_workspaces, self.out_shape, self.axes, ws.params
        )

        # Step hooks get their own Folder workspaces set up against the scan's solution.
        ws.pre_step = (
            Folder("pre_step", self.pre_step).setup(ws.sol, ws.sim) if self.pre_step else None
        )
        ws.post_step = (
            Folder("post_step", self.post_step).setup(ws.sol, ws.sim) if self.post_step else None
        )

        # Now we loop over the actual simulation and run each point
        run_axes_scan(
            ws.sim,
            self.axes,
            ws.params,
            self.offsets,
            self.out_shape,
            ws.sol,
            ws.pre_step,
            ws.post_step,
        )

        if self.on_complete:
            done_ws = Folder("on_complete", self.on_complete).setup(ws.sol, ws.sim)
            done_ws.fn_do(done_ws)


class Noxaxis(StepParamND):
    """A :class:`StepParamND` with no scanned parameters (zero axes).

    Parameters
    ----------
    name : str, optional
        Solution name; defaults to ``"noxaxis"``.
    **kwargs
        ``pre_step``, ``post_step`` and ``on_complete``, forwarded unchanged.
    """

    def __init__(self, *, name="noxaxis", **kwargs):
        super().__init__(name, **kwargs)


class XNaxis(StepParamND):
    """Scan N parameters, each over a linear or logarithmic range.

    Positional ``args`` come in groups of six:
    ``parameter, mode, start, stop, steps, offset``. Each group is converted
    into a ``(parameter, values, offset)`` triplet for :class:`StepParamND`
    using :func:`get_sweep_array`. The original arguments stay readable as
    attributes with an optional 1-based axis suffix, e.g. ``start``/``start1``
    or ``steps2``.

    Raises
    ------
    Exception
        If the arguments are not a non-empty multiple of six.
    """

    # Names of the six per-axis arguments, in positional order.
    _AXIS_FIELDS = ("parameter", "mode", "start", "stop", "steps", "offset")
    _AXIS_ATTR_RE = re.compile(r"(parameter|mode|start|stop|steps|offset)([0-9]*)")

    def __init__(
        self, name, *args, pre_step=None, post_step=None, on_complete=None,
    ):
        if len(args) % 6 != 0:
            raise Exception(
                "XNaxis arguments must come in groups of six: parameter, mode, start, stop, steps, offset"
            )

        self.N = len(args) // 6

        if self.N == 0:
            raise Exception("XNaxis requires at least one axis to be specified")

        self.__set_args = args
        new_args = []

        for i in range(0, len(args), 6):
            param, mode, start, stop, steps, offset = args[i : i + 6]
            new_args.extend((param, get_sweep_array(start, stop, steps, mode), offset))

        super().__init__(
            name,
            *new_args,
            pre_step=pre_step,
            post_step=post_step,
            on_complete=on_complete,
        )

    def setup(self, s_prev, sim):
        # If the model has locks, set them up to happen on the
        # pre-step of StepParamND action.
        # NOTE(review): this replaces any user-supplied ``pre_step`` and the
        # assignment persists on the action after the run -- confirm intent.
        if len(sim.model.locks) > 0:
            self.pre_step = RunLocks(*sim.model.locks)

        return super().setup(s_prev, sim)

    def do(self, ws: StepParamNDWorkspace):
        model = ws.sim.model.deepcopy()

        # A tuple rather than a generator: the generator was exhausted by the
        # first loop, leaving the ``finally`` clause with nothing to reset.
        params = tuple(
            get_param(model, pstr) for pstr in ws.info.get_all_parameters_changing()
        )

        try:
            for p in params:
                p.is_tunable = True

            # NOTE(review): ``super().do`` runs on ``ws.sim``, not on this
            # copied and rebuilt model -- confirm the copy is still needed.
            with model.built():
                super().do(ws)
        finally:
            for p in params:
                p.is_tunable = False

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails. Match the whole name:
        # a prefix match made names such as ``model`` or ``offset_x`` resolve
        # to axis values instead of raising AttributeError.
        res = self._AXIS_ATTR_RE.fullmatch(key)
        if res is None:
            return super().__getattribute__(key)

        field, number = res.groups()
        N = 1 if number == "" else int(number)
        if N == 0:
            raise Exception("Specify an axes greater than 0")
        if N > self.N:
            raise Exception(f"This xaxis does not have {N} axes")
        return self.__set_args[6 * (N - 1) + self._AXIS_FIELDS.index(field)]

364 365

class Xaxis(XNaxis):
    """Scans a parameter between two points for a number of steps.

    Parameters
    ----------
    parameter : :class:`.Parameter`
        Parameter of component to scan.
    mode : str
        'lin' or 'log' for linear or logarithmic step sizes.
    start, stop : float
        Start and end values of the scan.
    steps : int
        Number of steps between start and end (``steps + 1`` points).
    offset : float, optional
        Offset to scanned values. For a given xaxis point, `parameter` will be
        set to `x[i] + offset`. Defaults to no offset.
    pre_step, post_step, on_complete : Action, optional
        Actions forwarded to :class:`StepParamND`.
    name : str, optional
        Name used for storage of the results; defaults to "xaxis".
    """

    def __init__(
        self,
        parameter,
        mode,
        start,
        stop,
        steps,
        *,
        offset=False,
        pre_step=None,
        post_step=None,
        on_complete=None,
        name="xaxis",
    ):
        super().__init__(
            name,
            parameter,
            mode,
            start,
            stop,
            steps,
            offset,
            pre_step=pre_step,
            post_step=post_step,
            on_complete=on_complete,
        )


class X2axis(XNaxis):
    """Scans two parameters over a 2D grid.

    Each axis takes ``parameterN, modeN, startN, stopN, stepsN`` with the same
    meaning as in :class:`Xaxis`, plus a keyword-only ``offsetN`` (default: no
    offset). ``pre_step``, ``post_step`` and ``on_complete`` are forwarded to
    :class:`StepParamND`; ``name`` defaults to "x2axis".
    """

    def __init__(
        self,
        parameter1,
        mode1,
        start1,
        stop1,
        steps1,
        parameter2,
        mode2,
        start2,
        stop2,
        steps2,
        *,
        offset1=False,
        offset2=False,
        pre_step=None,
        post_step=None,
        on_complete=None,
        name="x2axis",
    ):
        axis1 = (parameter1, mode1, start1, stop1, steps1, offset1)
        axis2 = (parameter2, mode2, start2, stop2, steps2, offset2)
        super().__init__(
            name,
            *axis1,
            *axis2,
            pre_step=pre_step,
            post_step=post_step,
            on_complete=on_complete,
        )


class X3axis(XNaxis):
    """Scans three parameters over a 3D grid.

    Each axis takes ``parameterN, modeN, startN, stopN, stepsN`` with the same
    meaning as in :class:`Xaxis`, plus a keyword-only ``offsetN`` (default: no
    offset). ``pre_step``, ``post_step`` and ``on_complete`` are forwarded to
    :class:`StepParamND`; ``name`` defaults to "x3axis".
    """

    def __init__(
        self,
        parameter1,
        mode1,
        start1,
        stop1,
        steps1,
        parameter2,
        mode2,
        start2,
        stop2,
        steps2,
        parameter3,
        mode3,
        start3,
        stop3,
        steps3,
        *,
        offset1=False,
        offset2=False,
        offset3=False,
        pre_step=None,
        post_step=None,
        on_complete=None,
        name="x3axis",
    ):
        axis1 = (parameter1, mode1, start1, stop1, steps1, offset1)
        axis2 = (parameter2, mode2, start2, stop2, steps2, offset2)
        axis3 = (parameter3, mode3, start3, stop3, steps3, offset3)
        super().__init__(
            name,
            *axis1,
            *axis2,
            *axis3,
            pre_step=pre_step,
            post_step=post_step,
            on_complete=on_complete,
        )


class Serial(Action):
    """Run the given actions one after another.

    All child workspaces are created in :meth:`setup`, then executed in order
    by :meth:`do`.
    """

    def __init__(self, *args):
        super().__init__(self.__class__.__name__)
        self.args = args

    def fill_info(self, p_info):
        node = self.copy_info()
        p_info.add(node)
        for action in self.args:
            action.fill_info(node)
            if len(node.children) == 0:
                raise Exception("Analysis information object should have been made")

    def setup(self, s_prev, sim):
        ws = ActionWorkspace(s_prev, sim)
        ws.fn_do = self.do
        ws.wss = []

        # Here we get workspaces for each of the actions we need to run.
        root = ws.s_prev
        baseline = len(root.children)
        for action in self.args:
            ws.wss.append(action.setup(s_prev, sim))
            # If setup added a child solution, the most recent one becomes the
            # next solution in the serial chain.
            # NOTE(review): children are counted on the original ``s_prev`` only,
            # so a solution added under an intermediate one does not advance the
            # chain -- confirm this is intended.
            if len(root.children) > baseline:
                s_prev = root.children[-1]
        return ws

    def do(self, ws):
        for child_ws in ws.wss:
            child_ws.fn_do(child_ws)


class Folder(Action):
    """Run an action beneath a named :class:`BaseSolution` folder.

    :meth:`setup` creates ``BaseSolution(name, s_prev)`` and sets the wrapped
    action up against that folder instead of ``s_prev``.
    """

    def __init__(self, name, action):
        super().__init__(name)
        self.action = action
        self.folder = None

    def fill_info(self, p_info):
        node = self.copy_info()
        p_info.add(node)
        self.action.fill_info(node)
        self.folder = None

    def setup(self, s_prev, sim):
        ws = ActionWorkspace(s_prev, sim)
        ws.folder = BaseSolution(self.name, ws.s_prev)
        ws.action_ws = self.action.setup(ws.folder, sim)
        ws.fn_do = self.do
        return ws

    def do(self, ws):
        inner = ws.action_ws
        inner.fn_do(inner)


class Plot(Action):
    """Plot the nearest solution above the plain :class:`BaseSolution` folders.

    Positional and keyword arguments given to the constructor are forwarded to
    that solution's ``plot`` method.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(self.__class__.__name__)
        self.args = args
        self.kwargs = kwargs

    def do(self, s_prev, model=None):
        """Plot the nearest non-folder solution, or print a notice if none can.

        Parameters
        ----------
        s_prev : ActionWorkspace or solution
            ``Action.setup`` calls ``fn_do(ws)`` with a single workspace; the
            legacy ``do(s_prev, model)`` form passing a solution still works.
        model : optional
            Only accepted for the legacy calling form; unused.
        """
        if isinstance(s_prev, ActionWorkspace):
            s_prev = s_prev.s_prev

        # Skip upwards over plain folder solutions.
        while type(s_prev) is BaseSolution:
            s_prev = s_prev.parent

        if s_prev is not None and hasattr(s_prev, "plot"):
            s_prev.plot(*self.args, **self.kwargs)
        else:
            print(f"No plot method found in {s_prev}")


class Printer(Action):
    """Print the previous solution followed by the model."""

    def __init__(self):
        super().__init__(self.__class__.__name__)

    def do(self, ws):
        print(ws.s_prev, ws.model)


class PrintModel(Action):
    """Print the model the action is running on."""

    def __init__(self):
        super().__init__(self.__class__.__name__)

    def do(self, ws):
        model = ws.model
        print(model)


class PrintSolution(Action):
    """Print the solution handed to this action."""

    def __init__(self):
        super().__init__(self.__class__.__name__)

    def do(self, ws):
        solution = ws.s_prev
        print(solution)


class PrintAttr(Action):
    """Print ``name=value`` for each attribute path, resolved by ``model.reduce_get_attr``."""

    def __init__(self, *args):
        super().__init__(self.__class__.__name__)
        self.args = args

    def do(self, ws):
        entries = [f"{attr}={ws.model.reduce_get_attr(attr)}" for attr in self.args]
        print(*entries)


class ReprAttr(Action):
    """Like :class:`PrintAttr`, but prints the ``repr`` of each resolved value."""

    def __init__(self, *args):
        super().__init__(self.__class__.__name__)
        self.args = args

    def do(self, ws):
        entries = [f"{attr}={ws.model.reduce_get_attr(attr)!r}" for attr in self.args]
        print(*entries)


class Parallel(Action):
    """Run each given action on its own deep copy of the model.

    :meth:`fill_info` records which parameters each branch changes. :meth:`do`
    marks exactly those parameters tunable on the branch's model copy while it
    is built.
    """

    def __init__(self, *args):
        super().__init__(self.__class__.__name__)
        self.args = args

    def fill_info(self, p_info):
        info = self.copy_info()
        p_info.add(info)

        for arg in self.args:
            arg.fill_info(info)

        # Probe each parallel path and get what parameters will be changing.
        # Stored as tuples (not a one-shot generator) so ``do`` can run again.
        self.params_changing = tuple(
            tuple(child.get_all_parameters_changing()) for child in info.children
        )

    def do(self, s_prev, MODEL):
        # NOTE(review): this uses the legacy ``do(s_prev, model)`` convention,
        # whereas ``Action.setup`` calls ``fn_do(ws)`` with one workspace --
        # confirm how Parallel is meant to be driven.
        for arg, pc in zip(self.args, self.params_changing):
            model = deepcopy(MODEL)

            # A list, not a generator: the clean-up loop must see the same
            # parameters, otherwise the tunable flags are never switched off.
            params = [get_param(model, pstr) for pstr in pc]

            try:
                for p in params:
                    p.is_tunable = True

                with model.built():
                    arg.do(s_prev, model)
            finally:
                for p in params:
                    p.is_tunable = False


class Debug(Action):
    """Pause the analysis and open an embedded IPython shell.

    Parameters
    ----------
    name : str, optional
        Shown in the shell banner; defaults to ``"Debug"``.

    Setting ``cancel`` to ``True`` (e.g. ``self.cancel = True`` from inside the
    shell) skips all later stops of this instance.
    """

    def __init__(self, name="Debug"):
        super().__init__(name)
        self.cancel = False

    def do(self, s_prev, model=None):
        """Open the debugging shell unless ``cancel`` is set.

        Parameters
        ----------
        s_prev : ActionWorkspace or solution
            ``Action.setup`` calls ``fn_do(ws)`` with a single workspace; the
            legacy ``do(s_prev, model)`` form passing a solution still works.
        model : optional
            Model for the legacy calling form; taken from the workspace otherwise.
        """
        if not self.cancel:
            # Bind ``s_prev`` and ``model`` as locals so the names advertised in
            # the banner exist when the embedded shell inspects this frame.
            if isinstance(s_prev, ActionWorkspace):
                s_prev, model = s_prev.s_prev, s_prev.model

            from IPython.terminal.embed import InteractiveShellEmbed

            # NOTE(review): the banner lists ``carrier`` and ``signal``, which are
            # not bound anywhere in this method -- confirm.
            banner = textwrap.dedent(
                f"""
            ---- Finesse Debugging
            Instance          : {self.name}
            Previous solution : s_prev
            Current model     : model
            Current carrier   : carrier
            Current signal    : signal

            To stop future debug calls set : self.cancel = True
            To continue analyis            : exit
            """
            )
            self.shell = InteractiveShellEmbed(banner1=banner)
            self.shell()


class RunLocksWorkspace(ActionWorkspace):
    """Workspace for :class:`RunLocks`; fields are attached by its ``setup``."""


class RunLocks(Action):
    """Action that runs :func:`do_lock` for a set of locks.

    Parameters
    ----------
    *locks
        Lock elements. Only their names are stored; :meth:`setup` looks them up
        again in the simulated model.
    name : str, optional
        Action name; defaults to ``"run_locks"``.
    max_iterations : int, optional
        Upper bound on carrier re-runs used by :func:`do_lock`.
    """

    def __init__(self, *locks, name="run_locks", max_iterations=10000):
        super().__init__(name)
        # Here we setup what this type of action will do
        self.locks = tuple(lock.name for lock in locks)
        self.num_locks = len(self.locks)
        self.max_iterations = max_iterations
        self._info.makes_a_solution = False
        self._info.empty = True
        self._info.parameters_changing = tuple(lock.feedback.full_name for lock in locks)

    def fill_info(self, p_info):
        p_info.add(self.copy_info())

    def setup(self, s_prev, sim):
        ws = RunLocksWorkspace(s_prev, sim)
        ws.locks = tuple(sim.model.elements[lock_name] for lock_name in self.locks)
        ws.det_ws = [None] * self.num_locks

        # Pair each lock with the workspace of the detector that produces its
        # error signal (assumes ``sim.detector_workspaces`` is ordered like
        # ``sim.model.detectors`` -- confirm).
        for lock_idx, lock in enumerate(ws.locks):
            for det_idx, detector in enumerate(sim.model.detectors):
                if lock.error_signal.name == detector.name:
                    ws.det_ws[lock_idx] = sim.detector_workspaces[det_idx]

        if any(det_ws is None for det_ws in ws.det_ws):
            raise Exception("Could not find detector workspaces for all locks")

        ws.info = self.copy_info()
        ws.s_prev = s_prev
        ws.fn_do = do_lock
        ws.params = tuple(get_param(sim.model, path) for path in self._info.parameters_changing)
        ws.max_iterations = self.max_iterations

        for param in ws.params:
            if not param.is_tunable:
                raise ParameterLocked(
                    f"{param!r} must set as tunable before building the simulation"
                )
        return ws


def do_lock(ws: RunLocksWorkspace):
    """Iterate the locks until every error signal is within its accuracy.

    Each pass reruns the carrier simulation and reads every lock's error signal.
    For each signal outside ``[-accuracy, accuracy]``, ``gain * error`` is added
    to that lock's feedback parameter. Passes stop once a pass needs no
    correction, or after ``ws.max_iterations`` passes.

    Raises
    ------
    Exception
        If corrections were still being applied when the pass limit was hit.
    """
    passes = 0
    needs_rerun = True
    while needs_rerun and passes < ws.max_iterations:
        passes += 1
        ws.sim.run_carrier()
        needs_rerun = False
        for idx, det_ws in enumerate(ws.det_ws):
            lock = ws.locks[idx]
            tolerance = lock.accuracy
            error = det_ws.get_output()

            # Written as a negated range check so a NaN error counts as out of range.
            in_range = -tolerance <= error <= tolerance
            if not in_range:
                # We'll need to recompute the carrier sim
                needs_rerun = True
                lock.feedback.value += lock.gain * error

    if needs_rerun:
        raise Exception("Locks failed")