diff --git a/src/qce_circuit/addon_stim/noise_factories/factory_pauli_noise.py b/src/qce_circuit/addon_stim/noise_factories/factory_pauli_noise.py index ed8643e..3e51cb8 100644 --- a/src/qce_circuit/addon_stim/noise_factories/factory_pauli_noise.py +++ b/src/qce_circuit/addon_stim/noise_factories/factory_pauli_noise.py @@ -1,5 +1,5 @@ # ------------------------------------------- -# Module containing Stim noise-dresser factories for measurement operations. +# Module containing Stim noise-dresser factories for additive pauli noise. # ------------------------------------------- from typing import List, Tuple, Iterator import numpy as np diff --git a/src/qce_circuit/structure/acquisition_indexing/intrf_stabilizer_index_kernel.py b/src/qce_circuit/structure/acquisition_indexing/intrf_stabilizer_index_kernel.py index 27b653f..992fb48 100644 --- a/src/qce_circuit/structure/acquisition_indexing/intrf_stabilizer_index_kernel.py +++ b/src/qce_circuit/structure/acquisition_indexing/intrf_stabilizer_index_kernel.py @@ -16,6 +16,13 @@ class StateKey(Enum): STATE_1 = 1 STATE_2 = 2 + # region Class Methods + def __lt__(self, other): # Add comparison for sorting if needed + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + # endregion + class IStabilizerIndexingKernel(IIndexingKernel, metaclass=ABCMeta): """ diff --git a/src/qce_circuit/structure/circuit_operations.py b/src/qce_circuit/structure/circuit_operations.py index 4240fa8..f534f19 100644 --- a/src/qce_circuit/structure/circuit_operations.py +++ b/src/qce_circuit/structure/circuit_operations.py @@ -1048,7 +1048,7 @@ class VirtualOptional(ICircuitOperation): Data class, containing single-qubit operation. Acts as a visualization wrapper. """ - operation: SingleQubitOperation + operation: ICircuitOperation # region Interface Properties @property @@ -1064,12 +1064,12 @@ def nr_of_repetitions(self) -> int: @property def relation_link(self) -> IRelationLink[ICircuitOperation]: """:return: Description of relation to other circuit node.""" - return self.operation.relation + return self.operation.relation_link @relation_link.setter def relation_link(self, link: IRelationLink[ICircuitOperation]): """:sets: Description of relation to other circuit node.""" - self.operation.relation = link + self.operation.relation_link = link @property def start_time(self) -> float: diff --git a/src/qce_circuit/visualization/visualize_circuit/display_circuit.py b/src/qce_circuit/visualization/visualize_circuit/display_circuit.py index 387aae0..d672a62 100644 --- a/src/qce_circuit/visualization/visualize_circuit/display_circuit.py +++ b/src/qce_circuit/visualization/visualize_circuit/display_circuit.py @@ -199,6 +199,7 @@ def get_transform_constructor(self) -> TransformConstructor: ) def get_operation_draw_components(self) -> List[IDrawComponent]: + minimalist: bool = True individual_component_factory = DrawComponentFactoryManager( default_factory=DefaultFactory(), factory_lookup={ @@ -206,15 +207,15 @@ def get_operation_draw_components(self) -> List[IDrawComponent]: DispersiveMeasure: MeasureFactory(), Reset: ResetFactory(), Wait: WaitFactory(), - Rx180: Rx180Factory(), - Rx90: Rx90Factory(), - Rxm90: Rxm90Factory(), + Rx180: Rx180Factory(minimalist=minimalist), + Rx90: Rx90Factory(minimalist=minimalist), + Rxm90: Rxm90Factory(minimalist=minimalist), RxTheta: RxThetaFactory(), - Ry180: Ry180Factory(), - Ry90: Ry90Factory(), - Rym90: Rym90Factory(), + Ry180: Ry180Factory(minimalist=minimalist), + Ry90: Ry90Factory(minimalist=minimalist), + Rym90: Rym90Factory(minimalist=minimalist), RyTheta: RyThetaFactory(), - Rx180ef: Rx180efFactory(), + Rx180ef: Rx180efFactory(minimalist=minimalist), Rphi90: Rphi90Factory(), VirtualPhase: ZPhaseFactory(), Identity: IdentityFactory(), @@ -338,7 +339,7 @@ def construct_visual_description(circuit: IDeclarativeCircuit, custom_channel_or end_time = operation.end_time return VisualCircuitDescription( - channel_width=end_time + 1.0, + channel_width=end_time, channel_height=1.0, channel_indices=channel_indices, channel_label_map=custom_channel_map, diff --git a/src/qce_circuit/visualization/visualize_circuit/draw_components/annotation_components.py b/src/qce_circuit/visualization/visualize_circuit/draw_components/annotation_components.py index 742682a..fca6e9d 100644 --- a/src/qce_circuit/visualization/visualize_circuit/draw_components/annotation_components.py +++ b/src/qce_circuit/visualization/visualize_circuit/draw_components/annotation_components.py @@ -30,7 +30,7 @@ class HorizontalVariableIndicator(IRectTransformComponent, IDrawComponent): width: float height: float alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: IndicatorStyleSettings = field(default=StyleManager.read_config().indicator_style) + style_settings: IndicatorStyleSettings = field(default_factory=lambda: StyleManager.read_config().indicator_style) text_string: str = field(default=r'$\mathtt{{\delta}}$') # region Interface Properties @@ -140,7 +140,7 @@ class RoundedRectangleHighlight(IRectTransformComponent, IDrawComponent): width: float height: float alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: HighlightStyleSettings = field(default=StyleManager.read_config().highlight_style) + style_settings: HighlightStyleSettings = field(default_factory=lambda: StyleManager.read_config().highlight_style) text_string: str = field(default='x1') # region Interface Properties diff --git a/src/qce_circuit/visualization/visualize_circuit/draw_components/channel_components.py b/src/qce_circuit/visualization/visualize_circuit/draw_components/channel_components.py index 7f11283..ff60df3 100644 --- a/src/qce_circuit/visualization/visualize_circuit/draw_components/channel_components.py +++ b/src/qce_circuit/visualization/visualize_circuit/draw_components/channel_components.py @@ -74,15 +74,20 @@ class ChannelHeader(IRectTransformComponent, IDrawComponent): height: ILengthStrategy channel_name: str state_description: str - style_settings: ChannelStyleSettings = field(default=StyleManager.read_config().channel_style) + style_settings: ChannelStyleSettings = field(default_factory=lambda: StyleManager.read_config().channel_style) # region Interface Properties @property def rectilinear_transform(self) -> IRectTransform: """:return: 'Hard' rectilinear transform boundary. Should be treated as 'personal zone'.""" + width: float = 0.0 + if self.style_settings.enable_label_description: + width += self.channel_name_width + self.divider_width + if self.style_settings.enable_state_description: + width += self.state_description_width + self.divider_width return RectTransform( _pivot_strategy=DynamicPivot(lambda: self.pivot), - _width_strategy=DynamicLength(lambda: self.channel_name_width + self.state_description_width + 2 * self.divider_width), + _width_strategy=DynamicLength(lambda: width), _height_strategy=self.height, _parent_alignment=TransformAlignment.MID_RIGHT, ) @@ -91,7 +96,7 @@ def rectilinear_transform(self) -> IRectTransform: # region Class Properties @property def divider_width(self) -> float: - return 0.4 + return self.style_settings.divider_width @property def channel_name_width(self) -> float: @@ -99,7 +104,7 @@ def channel_name_width(self) -> float: @property def state_description_width(self) -> float: - return 0.7 + return self.style_settings.state_description_width @property def channel_name_pivot(self) -> Vec2D: @@ -129,31 +134,35 @@ def right_divider(self) -> Line2D: # region Interface Methods def draw(self, axes: plt.Axes) -> plt.Axes: """Method used for drawing component on Axes.""" - axes.text( - x=self.channel_name_pivot.x, - y=self.channel_name_pivot.y, - s=self.channel_name, - fontsize=self.style_settings.font_size, - color=self.style_settings.text_color, - ha='left', - va='center', - ) - axes.text( - x=self.state_description_pivot.x, - y=self.state_description_pivot.y, - s=self.state_description, - fontsize=self.style_settings.font_size, - color=self.style_settings.text_color, - ha='center', - va='center', - ) - axes.plot( - [self.center_divider.start.x, self.center_divider.end.x], - [self.center_divider.start.y, self.center_divider.end.y], - linestyle='-', - linewidth=self.style_settings.line_width, - color=self.style_settings.line_color, - ) + if self.style_settings.enable_label_description: + axes.text( + x=self.channel_name_pivot.x, + y=self.channel_name_pivot.y, + s=self.channel_name, + fontsize=self.style_settings.font_size, + color=self.style_settings.text_color, + ha='right', + va='center', + ) + axes.plot( + [self.center_divider.start.x, self.center_divider.end.x], + [self.center_divider.start.y, self.center_divider.end.y], + linestyle='-', + linewidth=self.style_settings.line_width, + color=self.style_settings.line_color, + ) + + if self.style_settings.enable_state_description: + axes.text( + x=self.state_description_pivot.x, + y=self.state_description_pivot.y, + s=self.state_description, + fontsize=self.style_settings.font_size, + color=self.style_settings.text_color, + ha='center', + va='center', + ) + axes.plot( [self.right_divider.start.x, self.right_divider.end.x], [self.right_divider.start.y, self.right_divider.end.y], @@ -174,7 +183,7 @@ class ChannelBar(IRectTransformComponent, IDrawComponent): width: float height: float alignment: TransformAlignment = field(init=False, default=TransformAlignment.MID_LEFT) - style_settings: ChannelStyleSettings = field(default=StyleManager.read_config().channel_style) + style_settings: ChannelStyleSettings = field(default_factory=lambda: StyleManager.read_config().channel_style) # region Class Properties @property diff --git a/src/qce_circuit/visualization/visualize_circuit/draw_components/factory_draw_components.py b/src/qce_circuit/visualization/visualize_circuit/draw_components/factory_draw_components.py index 3f1e0e0..1e93512 100644 --- a/src/qce_circuit/visualization/visualize_circuit/draw_components/factory_draw_components.py +++ b/src/qce_circuit/visualization/visualize_circuit/draw_components/factory_draw_components.py @@ -1,6 +1,7 @@ # ------------------------------------------- # Module containing functionality for constructing draw components from operation class types. # ------------------------------------------- +from dataclasses import dataclass, field from typing import List from qce_circuit.structure.intrf_circuit_operation import ( ICircuitOperation, @@ -109,7 +110,9 @@ def construct(self, operation: Reset, transform_constructor: ITransformConstruct # endregion +@dataclass(frozen=True) class Rx180Factory(IOperationDrawComponentFactory[Rx180, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Rx180, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -118,6 +121,14 @@ def construct(self, operation: Rx180, transform_constructor: ITransformConstruct identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"{RotationAxis.X.value}", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, @@ -128,7 +139,9 @@ def construct(self, operation: Rx180, transform_constructor: ITransformConstruct # endregion +@dataclass(frozen=True) class Rx90Factory(IOperationDrawComponentFactory[Rx90, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Rx90, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -137,6 +150,14 @@ def construct(self, operation: Rx90, transform_constructor: ITransformConstructo identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"+{RotationAxis.X.value}/2", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, @@ -147,7 +168,9 @@ def construct(self, operation: Rx90, transform_constructor: ITransformConstructo # endregion +@dataclass(frozen=True) class Rxm90Factory(IOperationDrawComponentFactory[Rxm90, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Rxm90, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -156,6 +179,14 @@ def construct(self, operation: Rxm90, transform_constructor: ITransformConstruct identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"-{RotationAxis.X.value}/2", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, @@ -185,7 +216,9 @@ def construct(self, operation: RxTheta, transform_constructor: ITransformConstru # endregion +@dataclass(frozen=True) class Ry180Factory(IOperationDrawComponentFactory[Ry180, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Ry180, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -194,6 +227,14 @@ def construct(self, operation: Ry180, transform_constructor: ITransformConstruct identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"{RotationAxis.Y.value}", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, @@ -204,7 +245,9 @@ def construct(self, operation: Ry180, transform_constructor: ITransformConstruct # endregion +@dataclass(frozen=True) class Ry90Factory(IOperationDrawComponentFactory[Ry90, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Ry90, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -213,6 +256,14 @@ def construct(self, operation: Ry90, transform_constructor: ITransformConstructo identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"+{RotationAxis.Y.value}/2", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, @@ -223,7 +274,9 @@ def construct(self, operation: Ry90, transform_constructor: ITransformConstructo # endregion +@dataclass(frozen=True) class Rym90Factory(IOperationDrawComponentFactory[Rym90, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Rym90, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -232,6 +285,14 @@ def construct(self, operation: Rym90, transform_constructor: ITransformConstruct identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"-{RotationAxis.Y.value}/2", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, @@ -261,7 +322,9 @@ def construct(self, operation: RyTheta, transform_constructor: ITransformConstru # endregion +@dataclass(frozen=True) class Rx180efFactory(IOperationDrawComponentFactory[Rx180, IDrawComponent]): + minimalist: bool = field(default=False) # region Interface Methods def construct(self, operation: Rx180, transform_constructor: ITransformConstructor) -> IDrawComponent: @@ -270,6 +333,14 @@ def construct(self, operation: Rx180, transform_constructor: ITransformConstruct identifier=operation.channel_identifiers[0], time_component=operation, ) + if self.minimalist: + return BlockHeaderBody( + pivot=transform.pivot, + height=transform.height, + alignment=transform.parent_alignment, + header_text=f"${RotationAxis.X.value}_{{12}}$", + ) + return BlockRotation( pivot=transform.pivot, height=transform.height, diff --git a/src/qce_circuit/visualization/visualize_circuit/draw_components/icon_components.py b/src/qce_circuit/visualization/visualize_circuit/draw_components/icon_components.py index 6c169c6..4ff9783 100644 --- a/src/qce_circuit/visualization/visualize_circuit/draw_components/icon_components.py +++ b/src/qce_circuit/visualization/visualize_circuit/draw_components/icon_components.py @@ -19,7 +19,7 @@ class IconMeasure(IDrawComponent): """ center: Vec2D radius: float - style_settings: IconStyleSettings = field(default=StyleManager.read_config().icon_style) + style_settings: IconStyleSettings = field(default_factory=lambda: StyleManager.read_config().icon_style) # region Class Properties @property @@ -89,11 +89,13 @@ def draw(self, axes: plt.Axes) -> plt.Axes: tail_width=self.arrow_thickness, ), color=self.style_settings.icon_color, + linewidth=self.circle_thickness, ) arrow_base = patches.Circle( xy=self.circle_center.to_tuple(), radius=self.arrow_base_radius, color=self.style_settings.icon_color, + linewidth=self.circle_thickness, ) # Apply patches axes.add_patch(arc) diff --git a/src/qce_circuit/visualization/visualize_circuit/draw_components/multi_pivot_components.py b/src/qce_circuit/visualization/visualize_circuit/draw_components/multi_pivot_components.py index 03132bc..36f76f3 100644 --- a/src/qce_circuit/visualization/visualize_circuit/draw_components/multi_pivot_components.py +++ b/src/qce_circuit/visualization/visualize_circuit/draw_components/multi_pivot_components.py @@ -35,7 +35,7 @@ @dataclass(frozen=True) class DotComponent(IDrawComponent): base_transform: IRectTransform - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().operation_style) # region Interface Methods def draw(self, axes: plt.Axes) -> plt.Axes: @@ -54,7 +54,7 @@ def draw(self, axes: plt.Axes) -> plt.Axes: @dataclass(frozen=True) class CrossComponent(IDrawComponent): base_transform: IRectTransform - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().operation_style) # region Class Properties @property @@ -117,7 +117,7 @@ class BlockRotationComponent(IDrawComponent): """ base_transform: IRectTransform rotation_angle: RotationAngle = field(default=RotationAngle.THETA) - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().operation_style) # region Class Properties @property @@ -150,7 +150,7 @@ class BlockTwoQubitGate(IRectTransformComponent, IDrawComponent): single_block_height: ILengthStrategy single_block_width: ILengthStrategy alignment: TransformAlignmentSubset = field(default=TransformAlignment.MID_LEFT) - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().operation_style) # region Interface Properties @property @@ -244,7 +244,7 @@ class BlockTwoQubitVacant(BlockTwoQubitGate, IRectTransformComponent, IDrawCompo Data class, containing information to draw a two-qubit gate block that uses two pivots to comply with vertical alignment. """ - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().vacant_operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().vacant_operation_style) # region Interface Methods def draw(self, axes: plt.Axes) -> plt.Axes: @@ -276,7 +276,7 @@ class BlockVerticalBarrier(IRectTransformComponent, IDrawComponent): that uses multiple pivots to comply with vertical alignment. """ multiple_transforms: List[IRectTransform] - style_settings: IndicatorStyleSettings = field(default=StyleManager.read_config().indicator_style) + style_settings: IndicatorStyleSettings = field(default_factory=lambda: StyleManager.read_config().indicator_style) # region Interface Properties @property diff --git a/src/qce_circuit/visualization/visualize_circuit/draw_components/operation_components.py b/src/qce_circuit/visualization/visualize_circuit/draw_components/operation_components.py index 9a997c0..0cfae89 100644 --- a/src/qce_circuit/visualization/visualize_circuit/draw_components/operation_components.py +++ b/src/qce_circuit/visualization/visualize_circuit/draw_components/operation_components.py @@ -137,7 +137,7 @@ class RectangleVacantBlock(IRectTransformComponent, IDrawComponent): width: float height: float alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().vacant_operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().vacant_operation_style) # region Interface Properties @property @@ -225,7 +225,7 @@ class BlockMeasure(IRectTransformComponent, IDrawComponent): width: float height: float alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: OperationStyleSettings = field(default=StyleManager.read_config().operation_style) + style_settings: OperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().operation_style) _base_block: RectangleBlock = field(init=False) # region Interface Properties @@ -356,7 +356,7 @@ class SquareParkBlock(IRectTransformComponent, IDrawComponent): height: float width: float alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: ChannelStyleSettings = field(default=StyleManager.read_config().channel_style) + style_settings: ChannelStyleSettings = field(default_factory=lambda: StyleManager.read_config().channel_style) # region Interface Properties @property @@ -436,7 +436,7 @@ class SquareNetZeroParkBlock(IRectTransformComponent, IDrawComponent): height: float width: float alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: ChannelStyleSettings = field(default=StyleManager.read_config().channel_style) + style_settings: ChannelStyleSettings = field(default_factory=lambda: StyleManager.read_config().channel_style) # region Interface Properties @property diff --git a/src/qce_circuit/visualization/visualize_circuit/style_manager.py b/src/qce_circuit/visualization/visualize_circuit/style_manager.py index 20cb22a..a7b4a66 100644 --- a/src/qce_circuit/visualization/visualize_circuit/style_manager.py +++ b/src/qce_circuit/visualization/visualize_circuit/style_manager.py @@ -22,6 +22,10 @@ class ChannelStyleSettings: text_color: str line_width: float font_size: float + divider_width: float + state_description_width: float + enable_state_description: bool + enable_label_description: bool @dataclass(frozen=True) @@ -94,6 +98,8 @@ class StyleSettings: width_line_small: float = field(default=1.0) width_line_icon: float = field(default=6.0) width_border: float = field(default=2.0) + width_divider: float = field(default=0.4) + width_state_description: float = field(default=0.7) # Radius radius_dot: float = field(default=0.1) @@ -106,7 +112,11 @@ class StyleSettings: line_style_border: str = field(default='-') # Spacing - rectilinear_margin: float = field(default=0.0) + rectilinear_margin: float = field(default=0.1) + + # Header + enable_state_description: bool = field(default=True) + enable_label_description: bool = field(default=True) # region Class Properties @property @@ -116,6 +126,10 @@ def channel_style(self) -> ChannelStyleSettings: text_color=self.color_text, line_width=self.width_line, font_size=self.font_size, + divider_width=self.width_divider, + state_description_width=self.width_state_description, + enable_state_description=self.enable_state_description, + enable_label_description=self.enable_label_description, ) @property @@ -207,14 +221,18 @@ def _default_config_object(cls) -> dict: @classmethod def read_config(cls) -> StyleSettings: - """:return: File-manager config file or overridden settings.""" - # Check for temporary override - if hasattr(cls._override_stack, "current_override") and cls._override_stack.current_override: - return cls._override_stack.current_override + """ + Reads the configuration settings, applying any temporary overrides if present. + It checks if there is a stack of overrides; if so, returns the most recent override. + + :return: The effective StyleSettings, either from file or from temporary overrides. + """ + if hasattr(cls._override_stack, "stack") and cls._override_stack.stack: + # Return the latest override if present + return cls._override_stack.stack[-1] path = get_yaml_file_path(filename=cls.CONFIG_NAME) if not os.path.exists(path): - # Construct config dict default_dict: dict = cls._default_config_object() write_yaml( filename=cls.CONFIG_NAME, @@ -227,18 +245,28 @@ def read_config(cls) -> StyleSettings: @contextmanager def temporary_override(cls, **overrides): """ - Temporarily override specific style settings in memory. - Ensures that changes do not persist beyond the 'with' block. + Temporarily override specific style settings in memory, supporting nested overrides. + The new override is based on the current configuration, which may already have been overridden. + + :param overrides: Keyword arguments for the configuration values to override. + :return: A context manager that yields control with the overridden configuration. """ + # Get the current configuration, which is either the top override or the default configuration. + current_config = cls.read_config() + new_config_dict = current_config.__dict__.copy() + new_config_dict.update(overrides) # Apply new overrides + + new_override = StyleSettings(**new_config_dict) + + # Initialize the override stack if it doesn't exist. + if not hasattr(cls._override_stack, "stack"): + cls._override_stack.stack = [] + # Push the new override onto the stack. + cls._override_stack.stack.append(new_override) + try: - # Store the original settings - original_config = cls.read_config() - new_config_dict = original_config.__dict__.copy() - new_config_dict.update(overrides) # Apply overrides - - # Set thread-local override - cls._override_stack.current_override = StyleSettings(**new_config_dict) - yield # Execute block within overridden context + yield finally: - cls._override_stack.current_override = None # Restore original settings + # Pop the override from the stack to restore the previous configuration. + cls._override_stack.stack.pop() # endregion diff --git a/src/qce_circuit/visualization/visualize_layout/element_components.py b/src/qce_circuit/visualization/visualize_layout/element_components.py index b14ddb5..ca80cca 100644 --- a/src/qce_circuit/visualization/visualize_layout/element_components.py +++ b/src/qce_circuit/visualization/visualize_layout/element_components.py @@ -30,7 +30,7 @@ class DotComponent(IRectTransformComponent, IDrawComponent): """ pivot: Vec2D alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: ElementStyleSettings = field(default=StyleManager.read_config().dot_style) + style_settings: ElementStyleSettings = field(default_factory=lambda: StyleManager.read_config().dot_style) # region Interface Properties @property @@ -67,7 +67,7 @@ class HexagonComponent(IRectTransformComponent, IDrawComponent): pivot: Vec2D rotation: float = field(default=0) alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: ElementStyleSettings = field(default=StyleManager.read_config().hexagon_style) + style_settings: ElementStyleSettings = field(default_factory=lambda: StyleManager.read_config().hexagon_style) # region Interface Properties @property @@ -124,7 +124,7 @@ class ParkingComponent(IRectTransformComponent, IDrawComponent): """ pivot: Vec2D alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: ParkOperationStyleSettings = field(default=StyleManager.read_config().park_operation_style) + style_settings: ParkOperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().park_operation_style) # region Interface Properties @property @@ -163,9 +163,9 @@ class TextComponent(IRectTransformComponent, IDrawComponent): """ pivot: Vec2D text: str - color: str = field(default=StyleManager.read_config().element_text_style.font_color) + color: str = field(default_factory=lambda: StyleManager.read_config().element_text_style.font_color) alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: ElementTextStyleSettings = field(default=StyleManager.read_config().element_text_style) + style_settings: ElementTextStyleSettings = field(default_factory=lambda: StyleManager.read_config().element_text_style) # region Interface Properties @property diff --git a/src/qce_circuit/visualization/visualize_layout/plaquette_components.py b/src/qce_circuit/visualization/visualize_layout/plaquette_components.py index 9777ea3..f461a18 100644 --- a/src/qce_circuit/visualization/visualize_layout/plaquette_components.py +++ b/src/qce_circuit/visualization/visualize_layout/plaquette_components.py @@ -39,7 +39,7 @@ class RectanglePlaquette(IRectTransformComponent, IDrawComponent): rotation: float = field(default=0) background_type: BackgroundType = field(default=BackgroundType.X) alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: PlaquetteStyleSettings = field(default=StyleManager.read_config().plaquette_style_x) + style_settings: PlaquetteStyleSettings = field(default_factory=lambda: StyleManager.read_config().plaquette_style_x) # region Interface Properties @property @@ -82,7 +82,7 @@ class TrianglePlaquette(IRectTransformComponent, IDrawComponent): rotation: float = field(default=0) background_type: BackgroundType = field(default=BackgroundType.X) alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: PlaquetteStyleSettings = field(default=StyleManager.read_config().plaquette_style_x) + style_settings: PlaquetteStyleSettings = field(default_factory=lambda: StyleManager.read_config().plaquette_style_x) # region Interface Properties @property diff --git a/src/qce_circuit/visualization/visualize_layout/polygon_component.py b/src/qce_circuit/visualization/visualize_layout/polygon_component.py index cea8dfb..8c6fafb 100644 --- a/src/qce_circuit/visualization/visualize_layout/polygon_component.py +++ b/src/qce_circuit/visualization/visualize_layout/polygon_component.py @@ -27,7 +27,7 @@ class PolylineComponent(IDrawComponent): """ vertices: List[Vec2D] alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: LineSettings = field(default=StyleManager.read_config().line_style) + style_settings: LineSettings = field(default_factory=lambda: StyleManager.read_config().line_style) # region Interface Methods def draw(self, axes: plt.Axes) -> plt.Axes: @@ -54,7 +54,7 @@ class GateOperationComponent(IDrawComponent): pivot0: Vec2D pivot1: Vec2D alignment: TransformAlignment = field(default=TransformAlignment.MID_LEFT) - style_settings: GateOperationStyleSettings = field(default=StyleManager.read_config().gate_operation_style) + style_settings: GateOperationStyleSettings = field(default_factory=lambda: StyleManager.read_config().gate_operation_style) # region Class Properties @property diff --git a/src/qce_circuit/visualization/visualize_layout/style_manager.py b/src/qce_circuit/visualization/visualize_layout/style_manager.py index a802fc1..81bf29d 100644 --- a/src/qce_circuit/visualization/visualize_layout/style_manager.py +++ b/src/qce_circuit/visualization/visualize_layout/style_manager.py @@ -3,6 +3,8 @@ # ------------------------------------------- import os from dataclasses import dataclass, field +import threading +from contextlib import contextmanager from qce_circuit.utilities.singleton_base import Singleton from qce_circuit.utilities.readwrite_yaml import ( get_yaml_file_path, @@ -201,6 +203,7 @@ class StyleManager(metaclass=Singleton): Behaviour Class, manages import of (device) layout-visualization style file. """ CONFIG_NAME: str = 'config_layout_style.yaml' + _override_stack = threading.local() # Thread-safe storage for overrides # region Class Methods @classmethod @@ -210,10 +213,18 @@ def _default_config_object(cls) -> dict: @classmethod def read_config(cls) -> StyleSettings: - """:return: File-manager config file.""" + """ + Reads the configuration settings, applying any temporary overrides if present. + It checks if there is a stack of overrides; if so, returns the most recent override. + + :return: The effective StyleSettings, either from file or from temporary overrides. + """ + if hasattr(cls._override_stack, "stack") and cls._override_stack.stack: + # Return the latest override if present + return cls._override_stack.stack[-1] + path = get_yaml_file_path(filename=cls.CONFIG_NAME) if not os.path.exists(path): - # Construct config dict default_dict: dict = cls._default_config_object() write_yaml( filename=cls.CONFIG_NAME, @@ -221,4 +232,33 @@ def read_config(cls) -> StyleSettings: make_file=True, ) return StyleSettings(**read_yaml(filename=cls.CONFIG_NAME)) + + @classmethod + @contextmanager + def temporary_override(cls, **overrides): + """ + Temporarily override specific style settings in memory, supporting nested overrides. + The new override is based on the current configuration, which may already have been overridden. + + :param overrides: Keyword arguments for the configuration values to override. + :return: A context manager that yields control with the overridden configuration. + """ + # Get the current configuration, which is either the top override or the default configuration. + current_config = cls.read_config() + new_config_dict = current_config.__dict__.copy() + new_config_dict.update(overrides) # Apply new overrides + + new_override = StyleSettings(**new_config_dict) + + # Initialize the override stack if it doesn't exist. + if not hasattr(cls._override_stack, "stack"): + cls._override_stack.stack = [] + # Push the new override onto the stack. + cls._override_stack.stack.append(new_override) + + try: + yield + finally: + # Pop the override from the stack to restore the previous configuration. + cls._override_stack.stack.pop() # endregion