From 570d3e1802f739a1648d77dac65c14fdb5903772 Mon Sep 17 00:00:00 2001 From: SeanvdMeer <18538762+minisean@users.noreply.github.com> Date: Tue, 1 Jul 2025 18:05:25 +0200 Subject: [PATCH 1/2] Updated state classification boundary visualization. More reliable and directly based on linear discriminator --- .../plot_state_classification.py | 209 ++++++++++++++---- 1 file changed, 160 insertions(+), 49 deletions(-) diff --git a/src/qce_interp/visualization/plot_state_classification.py b/src/qce_interp/visualization/plot_state_classification.py index 3e71fbc..fcc7949 100644 --- a/src/qce_interp/visualization/plot_state_classification.py +++ b/src/qce_interp/visualization/plot_state_classification.py @@ -3,6 +3,7 @@ # ------------------------------------------- import numpy as np from numpy.typing import NDArray +from itertools import combinations from typing import List, Dict, Tuple from enum import Enum, unique, auto from qce_interp.utilities.geometric_definitions import Vec2D, Polygon, euclidean_distance @@ -27,6 +28,7 @@ from matplotlib.colors import ListedColormap, PowerNorm, Colormap from matplotlib import colormaps from matplotlib import patches +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis STATE_COLORMAP: Dict[StateKey, Colormap] = { @@ -393,62 +395,172 @@ def rotation_point_180_degrees(point: Vec2D, center: Vec2D) -> Vec2D: return result_point -def plot_decision_boundary(decision_boundaries: DecisionBoundaries, **kwargs) -> IFigureAxesPair: +def _get_line_box_intersections( + coefficients: NDArray[np.float64], + intercept: float, + x_limits: Tuple[float, float], + y_limits: Tuple[float, float] +) -> List[NDArray[np.float64]]: """ - Plots decision boundaries for state classification. + Finds the intersection points of a decision line with the plot's bounding box. - :param decision_boundaries: Decision boundaries of states. - :param kwargs: Additional keyword arguments for plot customization. - :return: Tuple containing the figure and axes of the plot. + :param coefficients: The weight vector (w) of the line equation (w*x + b = 0). + :param intercept: The intercept (b) of the line equation. + :param x_limits: A tuple containing the minimum and maximum x-axis limits. + :param y_limits: A tuple containing the minimum and maximum y-axis limits. + :return: A list of unique intersection points as numpy arrays. """ # Data allocation - center: Vec2D = decision_boundaries.mean - boundary_keys: List[StateBoundaryKey] = list(decision_boundaries.boundary_lookup.keys()) + intersection_points: List[NDArray[np.float64]] = [] + x_min, x_max = x_limits + y_min, y_max = y_limits + + # Check intersections with vertical boundaries (x = x_min, x = x_max) + # The line equation is w[0]*x + w[1]*y + b = 0. 
Solved for y: y = -(w[0]*x + b) / w[1] + if abs(coefficients[1]) > 1e-9: # Avoid division by zero for horizontal lines + for x_value in [x_min, x_max]: + y_value = -(coefficients[0] * x_value + intercept) / coefficients[1] + if y_min <= y_value <= y_max: + intersection_points.append(np.array([x_value, y_value])) + + # Check intersections with horizontal boundaries (y = y_min, y = y_max) + # Solved for x: x = -(w[1]*y + b) / w[0] + if abs(coefficients[0]) > 1e-9: # Avoid division by zero for vertical lines + for y_value in [y_min, y_max]: + x_value = -(coefficients[1] * y_value + intercept) / coefficients[0] + if x_min <= x_value <= x_max: + intersection_points.append(np.array([x_value, y_value])) + + # Remove duplicate points if the line passes through a corner + unique_points: List[NDArray[np.float64]] = [] + for point in intersection_points: + is_duplicate = any(np.allclose(point, unique_point) for unique_point in unique_points) + if not is_duplicate: + unique_points.append(point) + + return unique_points - # Figures and Axes - fig, ax = construct_subplot(**kwargs) - boundary_intersections: Dict[DirectedStateBoundaryKey, Vec2D] = get_axes_intersection_lookup(decision_boundaries=decision_boundaries, ax=ax) - # Store the current limits - original_xlim = ax.get_xlim() - original_ylim = ax.get_ylim() - intersection_points: List[Vec2D] = [] - two_state_classification: bool = len(boundary_keys) == 1 - if two_state_classification: - for boundary_key in boundary_keys: - intersection_points.extend([ - boundary_intersections[DirectedStateBoundaryKey(boundary_key.state_a, boundary_key.state_b)], - boundary_intersections[DirectedStateBoundaryKey(boundary_key.state_b, boundary_key.state_a)] - ]) - else: - for boundary_key in boundary_keys: - intersection_points.extend([ - boundary_intersections[DirectedStateBoundaryKey(boundary_key.state_a, boundary_key.state_b)], - ]) - - # Clip intersection points - for i, intersection_point in enumerate(intersection_points): - _, clipped_intersection_point = clip_line_with_bounds( - line_point1=center, - line_point2=intersection_point, - min_x=original_xlim[0], - max_x=original_xlim[1], - min_y=original_ylim[0], - max_y=original_ylim[1], - ) - intersection_points[i] = clipped_intersection_point # Update intersection points +def plot_decision_boundary(decision_boundaries: DecisionBoundaries, **kwargs) -> IFigureAxesPair: + """ + Plots decision boundaries and regions for an LDA model. - for intersection_point in intersection_points: - ax.plot( - [center.x, intersection_point.x], - [center.y, intersection_point.y], - linestyle='--', - color='k', - linewidth=1, + This function visualizes the classification results by drawing the decision + regions for each class and the linear boundaries between each pair of classes. + It uses the existing limits of the provided axes to define the plot area. + For 3-class problems, it draws rays from the central intersection point outwards. + + :param decision_boundaries: A dataclass instance containing the trained LDA model. + :param axes: A matplotlib Axes object to plot on. + :return: A tuple containing the matplotlib Figure and Axes objects. 
+ """ + # Data allocation + discriminator: LinearDiscriminantAnalysis = decision_boundaries._discriminator + fig, ax = construct_subplot(**kwargs) + number_of_classes: int = len(discriminator.classes_) + line_width: float = 1.0 + meshgrid_samples: int = 101 + x_limits: Tuple[float, float] = ax.get_xlim() + y_limits: Tuple[float, float] = ax.get_ylim() + + if discriminator.n_features_in_ != 2: + raise ValueError( + "This plotting function only supports 2D feature spaces. " + f"The provided LDA model was trained on {discriminator.n_features_in_} features." ) - # Restore the original limits - ax.set_xlim(original_xlim) - ax.set_ylim(original_ylim) + + # Define custom colormaps based on number of classes + if number_of_classes == 2: + colors = [STATE_COLORMAP[StateKey.STATE_0](1.0), STATE_COLORMAP[StateKey.STATE_1](1.0)] + background_cmap = ListedColormap(colors) + else: # n_classes >= 3 + colors = [_color_map(1.0) for _color_map in STATE_COLORMAP.values()] + background_cmap = ListedColormap(colors) + + # Create mesh grid for background plotting + x_grid, y_grid = np.meshgrid( + np.linspace(x_limits[0], x_limits[1], meshgrid_samples), + np.linspace(y_limits[0], y_limits[1], meshgrid_samples) + ) + grid_points = np.c_[x_grid.ravel(), y_grid.ravel()] + + # Plot decision regions (background) + class_grid: NDArray[np.int_] = discriminator.predict(grid_points) + class_grid = class_grid.reshape(x_grid.shape) + ax.contourf(x_grid, y_grid, class_grid, cmap=background_cmap, alpha=0.2, zorder=-10) + + # Plot decision boundary lines + if number_of_classes == 2: + coefficients: NDArray[np.float64] = discriminator.coef_[0] + intercept: float = discriminator.intercept_[0] + line_x_values = np.array(x_limits) + # Handle vertical line case + if abs(coefficients[1]) > 1e-9: + line_y_values = -(coefficients[0] * line_x_values + intercept) / coefficients[1] + ax.plot(line_x_values, line_y_values, 'k--', lw=line_width) + else: + vertical_line_x = float(-intercept / coefficients[0]) + ax.axvline(x=vertical_line_x, color='k', linestyle='--', lw=line_width) + + elif number_of_classes == 3: + # For 3 classes, find the central intersection point of the three boundaries + w01 = discriminator.coef_[0, :] - discriminator.coef_[1, :] + b01 = discriminator.intercept_[0] - discriminator.intercept_[1] + w12 = discriminator.coef_[1, :] - discriminator.coef_[2, :] + b12 = discriminator.intercept_[1] - discriminator.intercept_[2] + + # Solve the system of linear equations to find the intersection + system_matrix = np.array([w01, w12]) + system_vector = np.array([-b01, -b12]) + try: + center_point: NDArray[np.float64] = np.linalg.solve(system_matrix, system_vector) + except np.linalg.LinAlgError: + center_point = None # Fallback for parallel lines + + if center_point is not None: + # Draw rays from the center point outwards for each boundary + for i, j in combinations(range(number_of_classes), 2): + k = next(c for c in range(number_of_classes) if c not in (i, j)) + coefficients = discriminator.coef_[i, :] - discriminator.coef_[j, :] + intercept = discriminator.intercept_[i] - discriminator.intercept_[j] + + # Find where the full line intersects the plot edges + edge_points = _get_line_box_intersections(coefficients, intercept, x_limits, y_limits) + if len(edge_points) != 2: + continue # Should not happen for non-degenerate lines + + # Determine which of the two line segments (center -> edge) is the correct ray + mid_point = (center_point + edge_points[0]) / 2.0 + predicted_class_at_midpoint = 
discriminator.predict([mid_point])[0] + + if predicted_class_at_midpoint == discriminator.classes_[k]: + ray_end_point = edge_points[1] + else: + ray_end_point = edge_points[0] + + ax.plot( + [center_point[0], ray_end_point[0]], + [center_point[1], ray_end_point[1]], + 'k--', + lw=line_width, + ) + + elif number_of_classes > 3: + # For >3 classes, draw full lines as there's no single intersection point + for i, j in combinations(range(number_of_classes), 2): + coefficients = discriminator.coef_[i, :] - discriminator.coef_[j, :] + intercept = discriminator.intercept_[i] - discriminator.intercept_[j] + line_x_values = np.array(x_limits) + if abs(coefficients[1]) > 1e-9: + line_y_values = -(coefficients[0] * line_x_values + intercept) / coefficients[1] + ax.plot(line_x_values, line_y_values, 'k--', lw=line_width) + else: + vertical_line_x = float(-intercept / coefficients[0]) + ax.axvline(x=vertical_line_x, color='k', linestyle='--', lw=line_width) + + # Final axes formatting + ax.set_xlim(x_limits) + ax.set_ylim(y_limits) return fig, ax @@ -552,6 +664,5 @@ def plot_state_classification(state_classifier: IStateAcquisitionContainer, use_ kwargs[SubplotKeywordEnum.HOST_AXES.value] = (fig, ax) plot_state_shots(state_classifier=state_classifier, **kwargs) plot_decision_boundary(decision_boundaries=decision_boundaries, **kwargs) - fig, ax = plot_decision_region(state_classifier=state_classifier, **kwargs) ax.legend(frameon=False) return fig, ax From e76bb5de6d7a4a984909a42c7d9be9851ad65db4 Mon Sep 17 00:00:00 2001 From: SeanvdMeer <18538762+minisean@users.noreply.github.com> Date: Tue, 12 Aug 2025 12:01:52 +0200 Subject: [PATCH 2/2] Added decision boundary interface and new implementation based on Gaussian distance or mahalanobis distance) --- .../intrf_state_classification.py | 326 +++++++++++++++++- 1 file changed, 311 insertions(+), 15 deletions(-) diff --git a/src/qce_interp/interface_definitions/intrf_state_classification.py b/src/qce_interp/interface_definitions/intrf_state_classification.py index 302c339..af38abb 100644 --- a/src/qce_interp/interface_definitions/intrf_state_classification.py +++ b/src/qce_interp/interface_definitions/intrf_state_classification.py @@ -89,17 +89,99 @@ def __hash__(self): # endregion +class IDecisionBoundaries(ABC): + """ + Interface class, describing exposed methods and properties related to qubit-state decision boundaries. + """ + + # region Interface Properties + @property + @abstractmethod + def mean(self) -> Vec2D: + """:return: Mean IQ-vector based on state boundaries.""" + raise InterfaceMethodException + + @property + @abstractmethod + def state_prediction_index_lookup(self) -> Dict[StateKey, int]: + """Lookup dictionary that maps state key to discriminator prediction index.""" + raise InterfaceMethodException + + @property + @abstractmethod + def prediction_index_to_state_lookup(self) -> Dict[int, StateKey]: + """:return: Lookup dictionary that maps discriminator prediction index to state key.""" + raise InterfaceMethodException + + @property + @abstractmethod + def boundary_lookup(self) -> Dict[StateBoundaryKey, Vec2D]: + """:return: Lookup dictionary pairing state boundary key to IQ vector coordinates.""" + raise InterfaceMethodException + # endregion + + # region Interface Methods + @abstractmethod + def get_boundary(self, key: StateBoundaryKey) -> Optional[Vec2D]: + """ + :return: Boundary point (2D) between state A and B. + If state A == B or if state-boundary is not known, return None. 
+ """ + raise InterfaceMethodException + + @abstractmethod + def get_boundary_between(self, state_a: StateKey, state_b: StateKey) -> Optional[Vec2D]: + """ + :return: Boundary point (2D) between state A and B. + If state A == B or if state-boundary is not known, return None. + """ + raise InterfaceMethodException + + @abstractmethod + def get_binary_predictions(self, shots: NDArray[np.complex64]) -> NDArray[np.int_]: + """ + NOTE: Forces classification of element in group 1 or 2, disregarding other groups. + NOTE: Returns integer prediction value, can be mapped to state-enum using self.prediction_index_to_state_lookup. + :return: Array-like of State key predictions based on shots discrimination. + """ + raise InterfaceMethodException + + @abstractmethod + def get_predictions(self, shots: NDArray[np.complex64]) -> NDArray[np.int_]: + """ + NOTE: Returns integer prediction value, can be mapped to state-enum using self.prediction_index_to_state_lookup. + :return: Array-like of State key predictions based on shots discrimination. + """ + raise InterfaceMethodException + + @abstractmethod + def get_prediction(self, shot: np.complex64) -> StateKey: + """:return: State key prediction based on shot discrimination.""" + raise InterfaceMethodException + + @abstractmethod + def get_fidelity(self, shots: NDArray[np.complex64], assigned_state: StateKey) -> float: + """:return: Assignment fidelity defined as the probability of shots being part of assigned state.""" + raise InterfaceMethodException + + @abstractmethod + def post_select_on(self, shots_to_filter: NDArray[np.complex64], conditional_shots: NDArray[np.complex64], conditional_state: StateKey) -> NDArray[np.complex64]: + """:return: Filtered shots based on conditional shots (of same length) and conditional state.""" + raise InterfaceMethodException + # endregion + + @dataclass(frozen=True) -class DecisionBoundaries: +class DecisionBoundaries(IDecisionBoundaries): """Data class, containing decision boundaries based on states.""" - boundary_lookup: Dict[StateBoundaryKey, Vec2D] + _boundary_lookup: Dict[StateBoundaryKey, Vec2D] _discriminator: LinearDiscriminantAnalysis _state_lookup: Dict[StateKey, int] """Lookup dictionary that maps state key to discriminator prediction index.""" _mean: Optional[Vec2D] = field(default=None) """Explicit specification of boundary means, necessary when handling 2-state classification.""" - # region Class Properties + # region Interface Properties @property def mean(self) -> Vec2D: """:return: Mean IQ-vector based on state boundaries.""" @@ -119,9 +201,13 @@ def prediction_index_to_state_lookup(self) -> Dict[int, StateKey]: """:return: Lookup dictionary that maps discriminator prediction index to state key.""" return {index: state for state, index in self.state_prediction_index_lookup.items()} + @property + def boundary_lookup(self) -> Dict[StateBoundaryKey, Vec2D]: + """:return: Lookup dictionary pairing state boundary key to IQ vector coordinates.""" + return self._boundary_lookup # endregion - # region Class Methods + # region Interface Methods def get_boundary(self, key: StateBoundaryKey) -> Optional[Vec2D]: """ :return: Boundary point (2D) between state A and B. @@ -149,13 +235,13 @@ def get_binary_predictions(self, shots: NDArray[np.complex64]) -> NDArray[np.int :return: Array-like of State key predictions based on shots discrimination. 
""" shot_reshaped: NDArray[np.float32] = StateAcquisitionContainer.complex_to_real_imag(shots) - # Step 1: Predict probabilities + # Predict probabilities probabilities: np.ndarray = self._discriminator.predict_proba(shot_reshaped) - # Step 2: Compare probabilities for groups 1 and 2 + # Compare probabilities for groups 1 and 2 # Assuming classes are labeled as 0, 1, 2 for groups 1, 2, 3 respectively prob_group_1: np.ndarray = probabilities[:, 0] prob_group_2: np.ndarray = probabilities[:, 1] - # Step 3: Classify based on higher probability + # Classify based on higher probability # Assign to group 1 if prob_group_1 > prob_group_2, else assign to group 2 state_indices: NDArray[np.int_] = np.where(prob_group_1 > prob_group_2, 0, 1) # 0 for group 1, 1 for group 2 return state_indices @@ -194,7 +280,9 @@ def post_select_on(self, shots_to_filter: NDArray[np.complex64], conditional_sho mask: NDArray[np.int_] = np.array( [1 if state_index == conditional_index else np.nan for state_index in state_indices]) return shots_to_filter[~np.isnan(mask)] + # endregion + # region Class Methods @classmethod def from_acquisition_container(cls, container: 'StateAcquisitionContainer') -> 'DecisionBoundaries': """ @@ -238,7 +326,7 @@ def from_acquisition_container(cls, container: 'StateAcquisitionContainer') -> ' ) center: Vec2D = 0.5 * (container.state_acquisition_lookup[state_a].center + container.state_acquisition_lookup[state_b].center) return DecisionBoundaries( - boundary_lookup=intersection_lookup, + _boundary_lookup=intersection_lookup, _discriminator=discriminator, _state_lookup=state_lookup, _mean=center, @@ -253,11 +341,10 @@ def from_acquisition_container(cls, container: 'StateAcquisitionContainer') -> ' intercept2=intercept_lookup[state_b], ) return DecisionBoundaries( - boundary_lookup=intersection_lookup, + _boundary_lookup=intersection_lookup, _discriminator=discriminator, _state_lookup=state_lookup, ) - # endregion # region Static Class Methods @@ -295,6 +382,213 @@ def _calculate_intersection_binary_case(coef1: Vec2D, intercept1: float): # endregion +@dataclass(frozen=True) +class GaussianDecisionBoundaries(IDecisionBoundaries): + """ + Data class, containing decision boundaries based on 2D Gaussian distributions for 0- and 1-states. + + This class provides an alternative to the LinearDiscriminantAnalysis by modeling the + 0- and 1-states as 2D Gaussian distributions in the IQ-plane. + + Classification is performed based on the Mahalanobis distance of a point to each + Gaussian center. A point is classified as state 0 or 1 if it falls within a + specified sigma-threshold of the respective distribution. If it falls within both, + it is assigned to the closer one. Points outside both distributions are classified as state 2. 
+    """
+    _mean_0: Vec2D
+    _mean_1: Vec2D
+    _inv_cov_0: NDArray[np.float64]
+    _inv_cov_1: NDArray[np.float64]
+    _sigma_threshold: float
+    _linear_boundaries: DecisionBoundaries  # Internal instance for delegation
+    _state_lookup: Dict[StateKey, int] = field(default_factory=lambda: {StateKey.STATE_0: 0, StateKey.STATE_1: 1, StateKey.STATE_2: 2})
+
+    # region Interface Properties
+    @property
+    def mean(self) -> Vec2D:
+        """:return: Mean IQ-vector, calculated as the midpoint between the two Gaussian centers."""
+        return 0.5 * (self._mean_0 + self._mean_1)
+
+    @property
+    def state_prediction_index_lookup(self) -> Dict[StateKey, int]:
+        """Lookup dictionary that maps state key to discriminator prediction index."""
+        return self._state_lookup
+
+    @property
+    def prediction_index_to_state_lookup(self) -> Dict[int, StateKey]:
+        """:return: Lookup dictionary that maps discriminator prediction index to state key."""
+        return {index: state for state, index in self.state_prediction_index_lookup.items()}
+
+    @property
+    def boundary_lookup(self) -> Dict[StateBoundaryKey, Vec2D]:
+        """:return: Lookup dictionary pairing state boundary key to IQ vector coordinates, delegated to the internal linear decision boundaries."""
+        return self._linear_boundaries.boundary_lookup
+    # endregion
+
+    # region Interface Methods
+    def get_boundary(self, key: StateBoundaryKey) -> Optional[Vec2D]:
+        """:return: Delegated linear boundary point."""
+        return self.get_boundary_between(key.state_a, key.state_b)
+
+    def get_boundary_between(self, state_a: StateKey, state_b: StateKey) -> Optional[Vec2D]:
+        """:return: Delegated linear boundary point."""
+        return self._linear_boundaries.get_boundary_between(state_a, state_b)
+
+    def get_binary_predictions(self, shots: NDArray[np.complex64]) -> NDArray[np.int_]:
+        """
+        Forces classification of each shot into state 0 or 1, disregarding state 2.
+        Classification is based on the smaller Mahalanobis distance, ignoring the sigma threshold.
+
+        :param shots: Array of complex-valued IQ shots.
+        :return: Array of binary state predictions (0 or 1).
+        """
+        if shots.size == 0:
+            return np.array([], dtype=int)
+
+        points: NDArray[np.float64] = StateAcquisitionContainer.complex_to_real_imag(shots)
+        mu_0: NDArray[np.float64] = self._mean_0.to_vector()
+        mu_1: NDArray[np.float64] = self._mean_1.to_vector()
+
+        delta_0: NDArray[np.float64] = points - mu_0
+        delta_1: NDArray[np.float64] = points - mu_1
+        mahalanobis_sq_0: NDArray[np.float64] = np.sum(np.dot(delta_0, self._inv_cov_0) * delta_0, axis=1)
+        mahalanobis_sq_1: NDArray[np.float64] = np.sum(np.dot(delta_1, self._inv_cov_1) * delta_1, axis=1)
+
+        # Assign to 0 if Mahalanobis distance to 0 is smaller, else 1
+        state_indices: NDArray[np.int_] = np.where(mahalanobis_sq_0 < mahalanobis_sq_1,
+                                                   self._state_lookup[StateKey.STATE_0],
+                                                   self._state_lookup[StateKey.STATE_1])
+        return state_indices
+
+    def get_predictions(self, shots: NDArray[np.complex64]) -> NDArray[np.int_]:
+        """
+        Classifies shots based on Mahalanobis distance to Gaussian centers.
+
+        :param shots: Array of complex-valued IQ shots to be classified.
+        :return: Array of integer state predictions (0, 1, or 2).
+ """ + if shots.size == 0: + return np.array([], dtype=int) + + points: NDArray[np.float64] = StateAcquisitionContainer.complex_to_real_imag(shots) + mu_0: NDArray[np.float64] = self._mean_0.to_vector() + mu_1: NDArray[np.float64] = self._mean_1.to_vector() + + # Calculate Mahalanobis distance squared for each point to each Gaussian + delta_0: NDArray[np.float64] = points - mu_0 + delta_1: NDArray[np.float64] = points - mu_1 + # (v-mu)^T @ inv_cov @ (v-mu) is equivalent to sum( (v-mu)@inv_cov * (v-mu), axis=1) + mahalanobis_sq_0: NDArray[np.float64] = np.sum(np.dot(delta_0, self._inv_cov_0) * delta_0, axis=1) + mahalanobis_sq_1: NDArray[np.float64] = np.sum(np.dot(delta_1, self._inv_cov_1) * delta_1, axis=1) + + # Determine if points are within the n-sigma threshold + threshold_sq: float = self._sigma_threshold ** 2 + is_in_0: NDArray[np.bool_] = mahalanobis_sq_0 < threshold_sq + is_in_1: NDArray[np.bool_] = mahalanobis_sq_1 < threshold_sq + + # Apply classification logic + # Default to state 2 (outside both ellipses) + predictions: NDArray[np.int_] = np.full(shots.shape, self._state_lookup[StateKey.STATE_2], dtype=np.int_) + + # Case 1: Inside only the 0-state ellipse + predictions[is_in_0 & ~is_in_1] = self._state_lookup[StateKey.STATE_0] + # Case 2: Inside only the 1-state ellipse + predictions[~is_in_0 & is_in_1] = self._state_lookup[StateKey.STATE_1] + + # Case 3: Inside both ellipses (intersection) + both_mask: NDArray[np.bool_] = is_in_0 & is_in_1 + if np.any(both_mask): + # Assign to the state with the smaller Mahalanobis distance + closer_to_0_mask: NDArray[np.bool_] = mahalanobis_sq_0[both_mask] < mahalanobis_sq_1[both_mask] + + # Get indices of the points that are in the intersection + both_indices = np.where(both_mask)[0] + + # Update predictions for points closer to 0 + predictions[both_indices[closer_to_0_mask]] = self._state_lookup[StateKey.STATE_0] + # Update predictions for points closer to 1 + predictions[both_indices[~closer_to_0_mask]] = self._state_lookup[StateKey.STATE_1] + + return predictions + + def get_prediction(self, shot: np.complex64) -> StateKey: + """:return: State key prediction based on a single shot.""" + prediction_index: NDArray[np.int_] = self.get_predictions(shots=np.asarray([shot])) + return self.prediction_index_to_state_lookup[prediction_index[0]] + + def get_fidelity(self, shots: NDArray[np.complex64], assigned_state: StateKey) -> float: + """:return: Assignment fidelity defined as the fraction of shots classified as the assigned state.""" + predictions: NDArray[np.int_] = self.get_predictions(shots) + assigned_index: int = self.state_prediction_index_lookup[assigned_state] + return float(np.mean(predictions == assigned_index)) + + def post_select_on(self, shots_to_filter: NDArray[np.complex64], conditional_shots: NDArray[np.complex64], conditional_state: StateKey) -> NDArray[np.complex64]: + """:return: Filtered shots based on conditional shots (of same length) and conditional state.""" + if len(conditional_shots) == 0: + return shots_to_filter + + predictions: NDArray[np.int_] = self.get_predictions(conditional_shots) + conditional_index: int = self._state_lookup[conditional_state] + mask: NDArray[np.bool_] = (predictions == conditional_index) + return shots_to_filter[mask] + # endregion + + # region Class Methods + @classmethod + def from_acquisition_container(cls, container: 'StateAcquisitionContainer', sigma_threshold: float = 3.0) -> 'GaussianDecisionBoundaries': + """ + Factory method to construct the class from a 
StateAcquisitionContainer. + + :param container: The container holding the acquisition data for states 0 and 1. + :param sigma_threshold: The number of standard deviations (sigma) to use as the classification boundary. + :return: An instance of GaussianDecisionBoundaries. + """ + # Extract StateAcquisition for states 0 and 1 + try: + acq_0: StateAcquisition = container.get_state_acquisition(StateKey.STATE_0) + acq_1: StateAcquisition = container.get_state_acquisition(StateKey.STATE_1) + except KeyError as e: + raise ValueError( + f"StateAcquisitionContainer must contain both STATE_0 and STATE_1 for GaussianDecisionBoundaries. Missing: {e}") + + # Convert complex shots to 2D real vectors + shots_0_real_imag: NDArray[np.float64] = StateAcquisitionContainer.complex_to_real_imag(acq_0.shots) + shots_1_real_imag: NDArray[np.float64] = StateAcquisitionContainer.complex_to_real_imag(acq_1.shots) + + if shots_0_real_imag.shape[0] < 2 or shots_1_real_imag.shape[0] < 2: + raise ValueError("At least 2 shots are required for each state to calculate a covariance matrix.") + + # Get means (centers of the point clouds) + mean_0: Vec2D = acq_0.center + mean_1: Vec2D = acq_1.center + + # Calculate the 2x2 covariance matrix for each state + # rowvar=False because each column is a variable (I and Q) + cov_0: NDArray[np.float64] = np.cov(shots_0_real_imag, rowvar=False) + cov_1: NDArray[np.float64] = np.cov(shots_1_real_imag, rowvar=False) + + # Calculate the inverse of the covariance matrices + try: + inv_cov_0: NDArray[np.float64] = np.linalg.inv(cov_0) + inv_cov_1: NDArray[np.float64] = np.linalg.inv(cov_1) + except np.linalg.LinAlgError as e: + raise RuntimeError(f"Could not invert covariance matrix. The data might be collinear. Error: {e}") + + # Create the linear boundaries instance for delegation + linear_boundaries = DecisionBoundaries.from_acquisition_container(container) + + # Create and return the class instance + return cls( + _mean_0=mean_0, + _mean_1=mean_1, + _inv_cov_0=inv_cov_0, + _inv_cov_1=inv_cov_1, + _linear_boundaries=linear_boundaries, + _sigma_threshold=sigma_threshold, + ) + # endregion + + class IStateAcquisitionContainer(ABC): """ Interface class, describing state acquisition and classification for state 0, 1 (and 2). @@ -334,7 +628,7 @@ class StateAcquisitionContainer(IStateAcquisitionContainer): Data class, containing raw acquisition shots for state 0, 1 and 2. 
""" state_acquisition_lookup: Dict[StateKey, StateAcquisition] - decision_boundaries: DecisionBoundaries = field(init=False) + decision_boundaries: IDecisionBoundaries = field(init=True, default=None) # region Interface Properties @property @@ -379,16 +673,18 @@ def get_state_acquisition(self, state: StateKey) -> StateAcquisition: # region Class Methods def __post_init__(self): - object.__setattr__(self, 'decision_boundaries', DecisionBoundaries.from_acquisition_container(self)) + if object.__getattribute__(self, 'decision_boundaries') is None: + object.__setattr__(self, 'decision_boundaries', DecisionBoundaries.from_acquisition_container(self)) @classmethod - def from_state_acquisitions(cls, acquisitions: List[StateAcquisition]) -> 'StateAcquisitionContainer': + def from_state_acquisitions(cls, acquisitions: List[StateAcquisition], decision_boundaries: Optional[DecisionBoundaries] = None) -> 'StateAcquisitionContainer': """:return: Class method constructor based on array-like of (state) acquisitions.""" return StateAcquisitionContainer( state_acquisition_lookup={ acquisition.state: acquisition for acquisition in acquisitions - } + }, + decision_boundaries=decision_boundaries, ) # endregion @@ -930,7 +1226,7 @@ def reshape(cls, container: TStateClassifierContainer, index_slices: NDArray[np. class ShotsClassifierContainer(IStateClassifierContainer): """Data class, containing classified states based on (complex) acquisition and decision boundaries.""" shots: NDArray[np.complex64] - decision_boundaries: DecisionBoundaries + decision_boundaries: IDecisionBoundaries _expected_parity: ParityType = field(default=ParityType.EVEN) _stabilizer_reset: bool = field(default=False) _odd_weight_and_refocusing: bool = field(default=False)