Commit 598ebcef authored by Sean Leavey's avatar Sean Leavey

Add component replacement method and refactor other circuit getters/setters

parent 0bd279d9
Pipeline #72791 passed with stage
in 25 minutes and 46 seconds
......@@ -32,3 +32,19 @@ You can print the circuit to retrieve a list of its constituents:
Circuits are only useful once you add components. This is achieved using the various ``add_``
methods, such as :meth:`.add_resistor`, :meth:`.add_capacitor`, :meth:`.add_inductor` and
Circuit manipulation
Circuits can be modified before and after applying :ref:`analyses <analyses/index:Analyses>`.
Circuit components can be removed with :meth:`.remove_component` or replaced with
When a component is removed, any connected nodes shared by other components are preserved.
When a component is replaced with another one, its nodes are copied to the new component and the new
component's nodes are overwritten. The components being swapped must be compatible: the number of
nodes in the current and replacement component must be the same, meaning that :ref:`passive
components <components/passive-components:Passive components>` can only be swapped for other passive
components, and :ref:`op-amps <components/op-amps:Op-amps>` can only be swapped for other op-amps.
......@@ -4,7 +4,8 @@ import abc
from unittest import TestCase
import numpy as np
from zero.components import OpAmp, Resistor, Node, OpAmpVoltageNoise, OpAmpCurrentNoise
from zero.components import (Resistor, Capacitor, Inductor, OpAmp, Node, OpAmpVoltageNoise,
from zero.solution import Solution
from import Series, Response, NoiseDensity, MultiNoiseDensity
......@@ -18,6 +19,8 @@ class ZeroDataTestCase(TestCase, metaclass=abc.ABCMeta):
super().__init__(*args, **kwargs)
self._last_node_num = 0
self._last_resistor_num = 0
self._last_capacitor_num = 0
self._last_inductor_num = 0
self._last_opamp_num = 0
def _unique_node_name(self):
......@@ -28,6 +31,14 @@ class ZeroDataTestCase(TestCase, metaclass=abc.ABCMeta):
self._last_resistor_num += 1
return f"r{self._last_resistor_num}"
def _unique_capacitor_name(self):
self._last_capacitor_num += 1
return f"c{self._last_capacitor_num}"
def _unique_inductor_name(self):
self._last_inductor_num += 1
return f"l{self._last_inductor_num}"
def _unique_opamp_name(self):
self._last_opamp_num += 1
return f"op{self._last_opamp_num}"
......@@ -49,7 +60,13 @@ class ZeroDataTestCase(TestCase, metaclass=abc.ABCMeta):
def _node(self):
return Node(self._unique_node_name())
def _opamp(self, node1, node2, node3, model=None):
def _opamp(self, node1=None, node2=None, node3=None, model=None):
if node1 is None:
node1 = self._node()
if node2 is None:
node2 = self._node()
if node3 is None:
node3 = self._node()
if model is None:
model = "OP00"
return OpAmp(name=self._unique_opamp_name(), model=model, node1=node1, node2=node2,
......@@ -64,6 +81,24 @@ class ZeroDataTestCase(TestCase, metaclass=abc.ABCMeta):
value = "1k"
return Resistor(name=self._unique_resistor_name(), node1=node1, node2=node2, value=value)
def _capacitor(self, node1=None, node2=None, value=None):
if node1 is None:
node1 = self._node()
if node2 is None:
node2 = self._node()
if value is None:
value = "1u"
return Capacitor(name=self._unique_capacitor_name(), node1=node1, node2=node2, value=value)
def _inductor(self, node1=None, node2=None, value=None):
if node1 is None:
node1 = self._node()
if node2 is None:
node2 = self._node()
if value is None:
value = "1u"
return Inductor(name=self._unique_inductor_name(), node1=node1, node2=node2, value=value)
def _voltage_noise(self, component=None):
if component is None:
component = self._resistor()
"""Circuit tests"""
from unittest import TestCase
from itertools import permutations
from zero import Circuit
from zero.components import Resistor, Capacitor, Inductor, OpAmp, Node
from import ZeroDataTestCase
class CircuitTestCase(TestCase):
......@@ -154,7 +156,6 @@ class CircuitTestCase(TestCase):
self.assertRaisesRegex(ValueError, r"element with name 'n1' already in circuit",
self.circuit.add_component, op1)
def test_cannot_add_node_with_same_name_as_component(self):
"""Test node with same name as existing component cannot be added"""
# first component
......@@ -175,3 +176,60 @@ class CircuitTestCase(TestCase):
self.assertRaisesRegex(ValueError, r"node 'r1' is the same as existing circuit component",
self.circuit.add_component, op1)
class TestComponentReplacement(ZeroDataTestCase):
def setUp(self):
self.passives = [self._resistor(), self._resistor(),
self._capacitor(), self._capacitor(),
self._inductor(), self._inductor()]
def test_replace_passive_passive(self):
"""Test passive component replacement."""
for cmp1, cmp2 in permutations(self.passives, 2):
with self.subTest((cmp1, cmp2)):
circuit = Circuit()
circuit.replace_component(cmp1, cmp2)
# Test circuit composition.
# Nodes in cmp2 should have been copied from cmp1.
self.assertEqual(cmp1.nodes, cmp2.nodes)
def test_replace_opamp_opamp(self):
"""Test op-amp replacement."""
op1 = self._opamp()
op2 = self._opamp()
circuit = Circuit()
circuit.replace_component(op1, op2)
# Test circuit composition.
# Nodes in cmp2 should have been copied from cmp1.
self.assertEqual(op1.nodes, op2.nodes)
def test_cannot_replace_passive_opamp_or_opamp_passive(self):
"""Test passive components cannot be replaced with op-amps (and vice versa)."""
opamp = self._opamp()
for passive in self.passives:
# Test replacement of op-amp.
with self.subTest((opamp, passive)):
circuit = Circuit()
self.assertRaises(ValueError, circuit.replace_component, opamp, passive)
# Test that the circuit still has the op-amp.
# Test replacement of passive.
with self.subTest((passive, opamp)):
circuit = Circuit()
self.assertRaises(ValueError, circuit.replace_component, passive, opamp)
# Test that the circuit still has the passive.
......@@ -28,7 +28,7 @@ class Circuit:
nodes : :class:`set` of :class:`.Node`
The circuit nodes.
# disallowed component names
# Aisallowed component names.
RESERVED_NAMES = ["all", "allop", "allr", "sum"]
def __init__(self):
......@@ -108,24 +108,19 @@ class Circuit:
if component is None:
raise ValueError("component cannot be None")
if is None:
# assign name
# Assign name.
if in self:
raise ValueError(f"element with name '{}' already in circuit")
elif in self.RESERVED_NAMES:
raise ValueError(f"component name '{}' is reserved")
# add component to end of list
# Add component to end of list.
# add nodes
# Add nodes.
for node in component.nodes:
if in self.component_names:
raise ValueError(f"node '{}' is the same as existing circuit component")
def add_resistor(self, *args, **kwargs):
......@@ -169,15 +164,19 @@ class Circuit:
component : :class:`str` or :class:`.Component`
The component to remove.
If the component is not found.
if isinstance(component, str):
# get component by name
# Get component by name.
component = self.get_component(component)
# remove component
elif component not in self.components:
raise ComponentNotFoundError(component)
# implicitly remove orphaned nodes by regenerating node set from components
# Implicitly remove orphaned nodes by regenerating node set from components.
def get_component(self, component_name):
......@@ -185,7 +184,7 @@ class Circuit:
component_name : :class:`str`
component_name : :class:`str` or :class:`.Component`
The name of the component to fetch.
......@@ -195,17 +194,51 @@ class Circuit:
If the component is not found.
# Get the component name from the object, if appropriate.
component_name = getattr(component_name, "name", component_name)
name = component_name.lower()
for component in self.components:
if name ==
return component
raise ComponentNotFoundError(component_name)
def replace_component(self, current_component, new_component):
"""Replace circuit component with a new one.
This can be used to replace components of the same type, but can also replace components of
different types as long as they have the same number of nodes.
current_component : :class:`.Component`
The component to replace.
new_component : :class:`str` or :class:`.Component`
The new component.
If the current component is not in the circuit.
If the new component is already in the circuit, or if the nodes of the new component are
incompatible with those of the current component.
current_component = self.get_component(current_component)
if new_component in self.components:
raise ValueError(f"{new_component} is already in the circuit")
if len(current_component.nodes) != len(new_component.nodes):
raise ValueError(f"{current_component} and {new_component} nodes are incompatible")
# Copy the nodes.
nodes = current_component.nodes
# Do the replacement.
LOGGER.debug(f"Overwriting {new_component}'s nodes with those from {current_component}'")
new_component.nodes = nodes
def has_component(self, component_name):
"""Check if component is present in circuit.
......@@ -219,14 +252,18 @@ class Circuit:
True if component exists, False otherwise.
return component_name.lower() in [name.lower() for name in self.component_names]
except ComponentNotFoundError:
return False
return True
def get_node(self, node_name):
"""Get circuit node by name.
node_name : :class:`str`
node_name : :class:`str` or :class:`.Node`
The name of the node to fetch.
......@@ -239,12 +276,12 @@ class Circuit:
If the node is not found.
# Get the node name from the object, if appropriate.
node_name = getattr(node_name, "name", node_name)
name = node_name.lower()
for node in self.nodes:
if name ==
return node
raise NodeNotFoundError(name)
def has_node(self, node_name):
......@@ -260,7 +297,11 @@ class Circuit:
True if node exists, False otherwise.
return node_name.lower() in [name.lower() for name in self.node_names]
except NodeNotFoundError:
return False
return True
def get_element(self, element_name):
"""Get circuit element (component or node) by name.
......@@ -281,17 +322,14 @@ class Circuit:
If the element is not found.
name = element_name.lower()
return self.get_component(name)
except ComponentNotFoundError:
return self.get_node(name)
except NodeNotFoundError:
raise ElementNotFoundError(element_name)
def has_element(self, element_name):
......@@ -307,7 +345,7 @@ class Circuit:
True if element exists, False otherwise.
return element_name.lower() in [name.lower() for name in self.element_names]
return self.has_component(element_name) or self.has_node(element_name)
def get_noise(self, noise_name):
"""Get noise by component or node name.
......@@ -60,12 +60,12 @@ class Component(BaseElement, metaclass=abc.ABCMeta):
def nodes(self, nodes):
for node in list(nodes):
nodes = list(nodes)
for index, node in enumerate(nodes):
if not isinstance(node, Node):
# parse node name
node = Node(str(node))
# Parse as a node name.
nodes[index] = Node(str(node))
self._nodes = nodes
def add_noise(self, noise):
"""Add a noise source to the component.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment