From 14e0a302c569c23a42b51c851463dfc7bd93b8aa Mon Sep 17 00:00:00 2001 From: SeanvdMeer <18538762+minisean@users.noreply.github.com> Date: Mon, 23 Jun 2025 00:51:06 +0200 Subject: [PATCH 1/2] Reverted change to _calculate_intersection as it introduced a bug. --- .../intrf_state_classification.py | 69 +++++-------------- 1 file changed, 17 insertions(+), 52 deletions(-) diff --git a/src/qce_interp/interface_definitions/intrf_state_classification.py b/src/qce_interp/interface_definitions/intrf_state_classification.py index 1471222..302c339 100644 --- a/src/qce_interp/interface_definitions/intrf_state_classification.py +++ b/src/qce_interp/interface_definitions/intrf_state_classification.py @@ -266,58 +266,23 @@ def _calculate_intersection(coef1: Vec2D, intercept1: float, coef2: Vec2D, inter """ :return: Intersection point of two linear equations defined by coefficients and intercepts. """ - """ - Compute the (x, y) coordinates where the two lines - - coef1.x * x + coef1.y * y + intercept1 = 0 - coef2.x * x + coef2.y * y + intercept2 = 0 - - intersect. - - Parameters - ---------- - coef1, coef2 : Vec2D - Line coefficients (a, b), i.e. normal-vector components. - intercept1, intercept2 : float - Line intercepts *c* (constant terms). - - Returns - ------- - Vec2D - Intersection point. - - Raises - ------ - ValueError - If the lines are parallel (det ≈ 0) or coincident, so no unique - intersection exists. - """ - # Unpack coefficients - a1, b1 = coef1.x, coef1.y - a2, b2 = coef2.x, coef2.y - - # Determinant of the 2×2 system - det: float = a1 * b2 - a2 * b1 - tol: float = 1e-12 - - if abs(det) < tol: - # Parallel or coincident — inspect intercepts for distinction - if abs(a1 * intercept2 - a2 * intercept1) < tol and \ - abs(b1 * intercept2 - b2 * intercept1) < tol: - # raise ValueError("Lines are coincident; infinite intersections.") - return Vec2D(x=0.0, y=0.0) - # raise ValueError("Lines are parallel; no unique intersection.") - return Vec2D(x=0.0, y=0.0) - - # Solve A x = b (A = [[a1, b1], [a2, b2]], b = [-c1, -c2]) - A = np.array([ - [a1, b1], - [a2, b2], - ], dtype=float) - b = np.array([-intercept1, -intercept2], dtype=float) - x, y = np.linalg.solve(A, b) - - return Vec2D.from_vector(np.array([x, y])) + denominator: float = (-coef1.x / coef1.y + coef2.x / coef2.y) + numerator: float = (-intercept2 / coef2.y + intercept1 / coef1.y) + # Deal with possible zero-division error + if denominator != 0: + _x: float = numerator / denominator + elif denominator == 0 and numerator == 0: + _x: float = 1.0 + else: + warn(f"[ZeroDivisionError] During intersect calculation of {coef1} and {coef2}.") + denominator = 1e-6 + _x: float = numerator / denominator + + _y: float = -coef1.x / coef1.y * _x - intercept1 / coef1.y + return Vec2D( + x=_x, + y=_y, + ) @staticmethod def _calculate_intersection_binary_case(coef1: Vec2D, intercept1: float): From bda54e68010990cb760bbfeb90ccc2574a8179b9 Mon Sep 17 00:00:00 2001 From: SeanvdMeer <18538762+minisean@users.noreply.github.com> Date: Fri, 27 Jun 2025 16:30:13 +0200 Subject: [PATCH 2/2] Added option to enforce binary classification plotting --- .../visualization/plot_state_classification.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/qce_interp/visualization/plot_state_classification.py b/src/qce_interp/visualization/plot_state_classification.py index 8033d47..3e71fbc 100644 --- a/src/qce_interp/visualization/plot_state_classification.py +++ b/src/qce_interp/visualization/plot_state_classification.py @@ -8,6 +8,7 @@ from qce_interp.utilities.geometric_definitions import Vec2D, Polygon, euclidean_distance from qce_interp.interface_definitions.intrf_state_classification import ( IStateAcquisitionContainer, + StateAcquisitionContainer, StateBoundaryKey, DirectedStateBoundaryKey, DecisionBoundaries, @@ -524,7 +525,7 @@ def plot_decision_region(state_classifier: IStateAcquisitionContainer, **kwargs) return fig, ax -def plot_state_classification(state_classifier: IStateAcquisitionContainer, **kwargs) -> IFigureAxesPair: +def plot_state_classification(state_classifier: IStateAcquisitionContainer, use_binary_classification: bool = False, **kwargs) -> IFigureAxesPair: """ Creates a plot visualizing state classification and decision boundaries. @@ -532,6 +533,14 @@ def plot_state_classification(state_classifier: IStateAcquisitionContainer, **kw :param kwargs: Additional keyword arguments for plot customization. :return: Tuple containing the figure and axes of the plot. """ + if use_binary_classification: + state_classifier = StateAcquisitionContainer.from_state_acquisitions( + acquisitions=[ + state_classifier.get_state_acquisition(state=StateKey.STATE_0), + state_classifier.get_state_acquisition(state=StateKey.STATE_1), + ] + ) + decision_boundaries: DecisionBoundaries = state_classifier.classification_boundaries kwargs[SubplotKeywordEnum.AXES_FORMAT.value] = IQAxesFormat() kwargs[SubplotKeywordEnum.LABEL_FORMAT.value] = LabelFormat(