diff --git a/src/qce_interp/interface_definitions/intrf_error_identifier.py b/src/qce_interp/interface_definitions/intrf_error_identifier.py index eef11ba..77dd0d9 100644 --- a/src/qce_interp/interface_definitions/intrf_error_identifier.py +++ b/src/qce_interp/interface_definitions/intrf_error_identifier.py @@ -179,6 +179,14 @@ def copy_with_post_selection(self, use_heralded_post_selection: bool = False, us """ raise InterfaceMethodException + @abstractmethod + def copy_with_involved_qubit_ids(self, involved_qubit_ids: List[IQubitID]) -> 'IErrorDetectionIdentifier': + """ + :param involved_qubit_ids: Array-like of involved qubit-ID's to select sub-set data. + :return: Newly constructed instance inheriting IErrorDetectionIdentifier interface based on sub-set involved qubit-IDs. + """ + raise InterfaceMethodException + @abstractmethod def get_post_selection_mask(self, cycle_stabilizer_count: int) -> NDArray[np.bool_]: """ @@ -650,7 +658,7 @@ def get_ternary_projected_classification(self, cycle_stabilizer_count: int) -> N result = result.transpose((1, 2, 0)) return result - def copy_with_post_selection(self, use_heralded_post_selection: bool = False, use_projected_leakage_post_selection: bool = False, use_all_projected_leakage_post_selection: bool = False, use_stabilizer_leakage_post_selection: bool = False, post_selection_qubits: Optional[List[IQubitID]] = None) -> 'IErrorDetectionIdentifier': + def copy_with_post_selection(self, use_heralded_post_selection: bool = False, use_projected_leakage_post_selection: bool = False, use_all_projected_leakage_post_selection: bool = False, use_stabilizer_leakage_post_selection: bool = False, post_selection_qubits: Optional[List[IQubitID]] = None) -> 'ErrorDetectionIdentifier': """ :param use_heralded_post_selection: Use post-selection on heralded initialization. :param use_projected_leakage_post_selection: Use post-selection on leakage events during (final) (data) qubit measurement projections. @@ -675,6 +683,25 @@ def copy_with_post_selection(self, use_heralded_post_selection: bool = False, us post_selection_qubits=post_selection_qubits, use_computational_parity=self._use_computational_parity, ) + + def copy_with_involved_qubit_ids(self, involved_qubit_ids: List[IQubitID]) -> 'ErrorDetectionIdentifier': + """ + :param involved_qubit_ids: Array-like of involved qubit-ID's to select sub-set data. + :return: Newly constructed instance inheriting IErrorDetectionIdentifier interface based on sub-set involved qubit-IDs. + """ + return ErrorDetectionIdentifier( + classifier_lookup=self._classifier_lookup, + index_kernel=self._index_kernel, + involved_qubit_ids=involved_qubit_ids, + device_layout=self._device_layout, + qec_rounds=self._qec_rounds, + use_heralded_post_selection=self._use_post_selection, + use_projected_leakage_post_selection=self._use_projected_leakage_post_selection, + use_all_projected_leakage_post_selection=self._use_all_projected_leakage_post_selection, + use_stabilizer_leakage_post_selection=self._use_stabilizer_leakage_post_selection, + post_selection_qubits=self._post_selection_qubits, + use_computational_parity=self._use_computational_parity, + ) # endregion # region Class Methods @@ -1412,7 +1439,7 @@ def get_labeled_ternary_projected_classification(self, cycle_stabilizer_count: i return data_array - def copy_with_post_selection(self, use_heralded_post_selection: bool = False, use_projected_leakage_post_selection: bool = False, use_all_projected_leakage_post_selection: bool = False, use_stabilizer_leakage_post_selection: bool = False, post_selection_qubits: Optional[List[IQubitID]] = None) -> 'IErrorDetectionIdentifier': + def copy_with_post_selection(self, use_heralded_post_selection: bool = False, use_projected_leakage_post_selection: bool = False, use_all_projected_leakage_post_selection: bool = False, use_stabilizer_leakage_post_selection: bool = False, post_selection_qubits: Optional[List[IQubitID]] = None) -> 'LabeledErrorDetectionIdentifier': """ :param use_heralded_post_selection: Use post-selection on heralded initialization. :param use_projected_leakage_post_selection: Use post-selection on leakage events during (final) (data) qubit measurement projections. @@ -1431,6 +1458,17 @@ def copy_with_post_selection(self, use_heralded_post_selection: bool = False, us ) ) + def copy_with_involved_qubit_ids(self, involved_qubit_ids: List[IQubitID]) -> 'LabeledErrorDetectionIdentifier': + """ + :param involved_qubit_ids: Array-like of involved qubit-ID's to select sub-set data. + :return: Newly constructed instance inheriting IErrorDetectionIdentifier interface based on sub-set involved qubit-IDs. + """ + return LabeledErrorDetectionIdentifier( + error_detection_identifier=self._error_detection_identifier.copy_with_involved_qubit_ids( + involved_qubit_ids=involved_qubit_ids, + ) + ) + def get_post_selection_mask(self, cycle_stabilizer_count: int) -> NDArray[np.bool_]: """ Output shape: (N,) diff --git a/src/qce_interp/utilities/custom_context_manager.py b/src/qce_interp/utilities/custom_context_manager.py new file mode 100644 index 0000000..9000350 --- /dev/null +++ b/src/qce_interp/utilities/custom_context_manager.py @@ -0,0 +1,46 @@ +# ------------------------------------------- +# Customized context managers for better maintainability +# ------------------------------------------- +import warnings + + +class WhileLoopSafetyExceededWarning(Warning): + """ + Raised when while-loop safety counter exceeds the allowed number of iterations. + """ + + # region Class Methods + @classmethod + def warning_format(cls, max_iter: int) -> dict: + return dict( + message=f"Max iterations reached ({max_iter}/{max_iter}), exiting loop.", + category=cls, + ) + # endregion + + +class WhileLoopSafety: + """ + Context manager class, + """ + + # region Class Constructor + def __init__(self, max_iterations: int = 10): + self.counter = 0 + self.max_iterations = max_iterations + # endregion + + # region Class Methods + def safety_condition(self): + if self.counter >= self.max_iterations: + warnings.warn(**WhileLoopSafetyExceededWarning.warning_format(max_iter=self.max_iterations)) + return False + self.counter += 1 + return True + + def __enter__(self) -> 'WhileLoopSafety': + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + # endregion diff --git a/src/qce_interp/utilities/initial_state_manager.py b/src/qce_interp/utilities/initial_state_manager.py new file mode 100644 index 0000000..c395edd --- /dev/null +++ b/src/qce_interp/utilities/initial_state_manager.py @@ -0,0 +1,251 @@ +# ------------------------------------------- +# Module containing functionality for determining initial state from code-layout and included qubit-IDs. +# ------------------------------------------- +import collections +from typing import Dict, Any, Union, List, Optional, Set, Tuple +import numpy as np +from qce_circuit.connectivity.intrf_connectivity_surface_code import ISurfaceCodeLayer +from qce_circuit.connectivity.intrf_channel_identifier import IQubitID +from qce_interp.interface_definitions.intrf_error_identifier import ErrorDetectionIdentifier +from qce_interp.utilities.custom_context_manager import WhileLoopSafety + + +class InitialStateManager: + """ + Behaviour class, contains functionality for determining (odd) initial state distribution. + Deterministic outcome based on order of input data-qubit IDs. + """ + + # region Static Class Methods + @staticmethod + def _find_initial_state( + qubit_mapping: Dict[Any, Union[np.ndarray, List[int]]] + ) -> Optional[Dict[int, int]]: + """ + Determines an initial state for data qubits to ensure even state distribution + for each measurement qubit. + + This function solves a constraint satisfaction problem where data qubits must be + assigned a state (0 or 1) such that for each measurement qubit, the set of + data qubits it references has an equal number of 0s and 1s. + + It uses a backtracking algorithm with constraint propagation to find a valid + solution. It can handle complex dependencies, including shared data qubits. + + :param qubit_mapping: A dictionary where keys are measurement qubit identifiers + and values are lists or numpy arrays of the data qubit + indices they measure. + :return: A dictionary mapping each data qubit index to its determined initial + state (0 or 1), or None if no solution is found. + """ + + # Parse input and build data structures for the solver + all_data_qubits: Set[int] = set() + for data_qubits in qubit_mapping.values(): + for qubit_idx in data_qubits: + all_data_qubits.add(qubit_idx) + + sorted_qubits: List[int] = sorted(list(all_data_qubits)) + initial_states: Dict[int, Optional[int]] = {q: None for q in sorted_qubits} + + constraints: List[Dict[str, Union[Tuple[int, ...], int]]] = [] + qubit_to_constraint_indices: Dict[int, List[int]] = collections.defaultdict(list) + + for meas_qubit, data_qubits in qubit_mapping.items(): + num_data_qubits = len(data_qubits) + if num_data_qubits % 2 != 0: + raise ValueError( + f"Measurement qubit '{meas_qubit}' references an odd number of data " + f"qubits ({num_data_qubits}), making an even distribution of " + "0s and 1s impossible." + ) + + constraint = { + 'qubits': tuple(sorted(data_qubits)), + 'target_sum': num_data_qubits // 2 + } + constraints.append(constraint) + constraint_idx = len(constraints) - 1 + for q_idx in data_qubits: + qubit_to_constraint_indices[q_idx].append(constraint_idx) + + def solve(current_states: Dict[int, Optional[int]]) -> Optional[Dict[int, int]]: + """ + Inner recursive function to perform the backtracking search. + """ + # Find the next unassigned data qubit + try: + next_qubit = next(q for q in sorted_qubits if current_states[q] is None) + except StopIteration: + # Base case: all qubits have been successfully assigned a state + return current_states + + # Try assigning both possible states (0 and 1) + for state_to_try in [0, 1]: + # Create a copy of the current states to allow for backtracking + new_states = current_states.copy() + new_states[next_qubit] = state_to_try + + # Propagate constraints + is_consistent, propagated_states = propagate(new_states, next_qubit) + + if is_consistent: + # If propagation is successful, continue solving recursively + solution = solve(propagated_states) + if solution: + return solution + + # If neither 0 nor 1 leads to a solution, backtrack + return None + + def propagate( + states: Dict[int, Optional[int]], + initial_qubit: int + ) -> Tuple[bool, Dict[int, Optional[int]]]: + """ + Propagates the consequences of a qubit's state assignment. + """ + queue = collections.deque([initial_qubit]) + + while queue: + qubit = queue.popleft() + for const_idx in qubit_to_constraint_indices[qubit]: + constraint = constraints[const_idx] + target_sum = constraint['target_sum'] + involved_qubits = constraint['qubits'] + + known_states_sum = 0 + unknown_qubits: List[int] = [] + + for q_idx in involved_qubits: + if states[q_idx] is not None: + known_states_sum += states[q_idx] + else: + unknown_qubits.append(q_idx) + + if not unknown_qubits: # All qubit states in this constraint are known + if known_states_sum != target_sum: + return False, states # Conflict detected + elif len(unknown_qubits) == 1: + # Can determine the state of the single unknown qubit + unknown_q = unknown_qubits[0] + required_state = target_sum - known_states_sum + + if required_state not in [0, 1]: + # Required state is not binary, so this path is invalid + return False, states + + # If the state was already set by another constraint, check for conflict + if states[unknown_q] is not None and states[unknown_q] != required_state: + return False, states + + if states[unknown_q] is None: + states[unknown_q] = required_state + queue.append(unknown_q) + + return True, states + + # Start the recursive search + return solve(initial_states) + + @staticmethod + def construct_odd_initial_state(code_layout: ISurfaceCodeLayer, involved_data_qubit_ids: Optional[List[IQubitID]] = None) -> Dict[IQubitID, int]: + # Data allocation + result: Dict[IQubitID, int] = {} + + if not involved_data_qubit_ids: + involved_data_qubit_ids = code_layout.data_qubit_ids + + parity_index_lookup: Dict[IQubitID, np.ndarray] = ErrorDetectionIdentifier.get_parity_index_lookup( + parity_layout=code_layout, + involved_data_qubit_ids=involved_data_qubit_ids, + involved_ancilla_qubit_ids=code_layout.ancilla_qubit_ids, + ) + for qubit_index, state_id in InitialStateManager._find_initial_state(qubit_mapping=parity_index_lookup).items(): + result[involved_data_qubit_ids[qubit_index]] = state_id + + return result + + @staticmethod + def construct_qubit_chain(code_layout: ISurfaceCodeLayer, involved_data_qubit_ids: List[IQubitID]) -> List[IQubitID]: + """ + Constructs a 1D chain of alternating data and ancilla qubits from a given + set of involved data qubits. + + This method models the qubit layout as a graph and assumes that the provided + data qubits and their connecting ancillas form a simple, unbranched 1D chain. + It traverses the graph structure to reconstruct the chain sequence. + + :param code_layout: The surface code layout object, containing information about all data and ancilla qubits. + :param involved_data_qubit_ids: An unordered list of (data) qubit IDs that are known to form the chain. + :return: A list of qubit IDs representing the alternating 1D chain, + e.g., [data1, ancilla1, data2, ancilla2, ...]. Returns an empty + list if no valid chain can be formed. + """ + if not involved_data_qubit_ids: + return [] + + # 1. Build a complete connectivity map (ID-based) for the entire layout. + all_data_qubits_in_layout = code_layout.data_qubit_ids + full_index_lookup = ErrorDetectionIdentifier.get_parity_index_lookup( + parity_layout=code_layout, + involved_data_qubit_ids=all_data_qubits_in_layout, + involved_ancilla_qubit_ids=code_layout.ancilla_qubit_ids, + ) + ancilla_to_data_ids_map: Dict[IQubitID, List[IQubitID]] = { + ancilla_id: [all_data_qubits_in_layout[idx] for idx in data_indices] + for ancilla_id, data_indices in full_index_lookup.items() + } + + # 2. Build a filtered adjacency list for the subgraph of the chain. + involved_data_set = set(involved_data_qubit_ids) + adj = collections.defaultdict(list) + for ancilla_id, connected_data_ids in ancilla_to_data_ids_map.items(): + # As per the requirement, we only consider ancillas connecting two qubits. + if len(connected_data_ids) != 2: + continue + + d1, d2 = connected_data_ids + # An ancilla is part of the chain if it connects two data qubits + # that are both in our set of interest. + if d1 in involved_data_set and d2 in involved_data_set: + adj[d1].append(ancilla_id) + adj[ancilla_id].append(d1) + adj[d2].append(ancilla_id) + adj[ancilla_id].append(d2) + + if not adj: + return [involved_data_qubit_ids[0]] if involved_data_qubit_ids else [] + + # 3. Find an endpoint of the chain to start the traversal. + # An endpoint in a 1D chain is a node with only one connection in the subgraph. + start_node = None + for data_qubit_id in involved_data_qubit_ids: + if data_qubit_id in adj and len(adj[data_qubit_id]) == 1: + start_node = data_qubit_id + break + + # If no endpoint is found (e.g., a cycle), start with any node. + if start_node is None: + start_node = involved_data_qubit_ids[0] + + # 4. Walk along the chain from the start node to construct the ordered list. + chain = [] + visited = set() + current_node = start_node + with WhileLoopSafety(max_iterations=len(code_layout.qubit_ids)) as loop: + # Execute while loop in safety environment + while (current_node is not None and current_node not in visited) and loop.safety_condition(): + chain.append(current_node) + visited.add(current_node) + + # Find the next unvisited neighbor to continue the chain. + next_node = None + for neighbor in adj[current_node]: + if neighbor not in visited: + next_node = neighbor + break + current_node = next_node + + return chain + # endregion diff --git a/src/qce_interp/utilities/serialize_error_identifier.py b/src/qce_interp/utilities/serialize_error_identifier.py new file mode 100644 index 0000000..cf8abd2 --- /dev/null +++ b/src/qce_interp/utilities/serialize_error_identifier.py @@ -0,0 +1,169 @@ +# ------------------------------------------- +# Functions for serializing error-detection identifier +# ------------------------------------------- +from typing import List, Tuple, TypeVar, Union +import xarray as xr +import numpy as np +from numpy.typing import NDArray +from tqdm import tqdm +from qce_circuit.language.intrf_declarative_circuit import ( + InitialStateContainer, + InitialStateEnum, +) +from qce_circuit.connectivity.intrf_channel_identifier import IQubitID +from qce_circuit.connectivity.intrf_connectivity_surface_code import ISurfaceCodeLayer +from qce_interp.utilities.custom_exceptions import ZeroClassifierShotsException +from qce_interp.interface_definitions.intrf_error_identifier import ( + ErrorDetectionIdentifier, + ILabeledErrorDetectionIdentifier, + LabeledErrorDetectionIdentifier, + DataArrayLabels, +) +from qce_interp.decoder_examples.mwpm_decoders import MWPMDecoderFast +from qce_interp.decoder_examples.majority_voting import MajorityVotingDecoder +from qce_interp.utilities.initial_state_manager import InitialStateManager + + +__all__ = [ + "construct_processed_dataset", +] + +T = TypeVar("T") + + +def construct_processed_dataset(error_identifier: ErrorDetectionIdentifier, initial_state: InitialStateContainer, qec_rounds: List[int], code_layout: ISurfaceCodeLayer) -> xr.Dataset: + + processed_dataset = xr.Dataset() + decoder_set: List[Tuple[MWPMDecoderFast, MajorityVotingDecoder, InitialStateContainer]] = construct_sub_error_identifiers( + error_identifier=error_identifier, + initial_state=initial_state, + code_layout=code_layout, + ) + # Add defect rates + processed_dataset = update_defect_rates( + dataset=processed_dataset, + labeled_error_identifier=LabeledErrorDetectionIdentifier(error_identifier), + qec_round=qec_rounds[-1], + ) + # Add logical fidelities + processed_dataset = update_logical_fidelity( + dataset=processed_dataset, + decoder_set=decoder_set, + qec_rounds=qec_rounds, + ) + + return processed_dataset + + +def get_odd_subarrays(full_array: List[T], skip: int = 1) -> List[List[T]]: + """ + Generate all sub-arrays of all possible odd lengths (>= 3) from a given 1D array, + with an option to skip every second odd length. + + :param full_array: The input 1D array with arbitrary element types. + :param skip: Determines step size for odd lengths (1 = every odd, 2 = every second odd). + :return: List of sub-arrays. + """ + n = len(full_array) + lengths = [length for length in range(n, 2, -2 * skip)] # Step controls skipping behavior + sub_arrays = [] + + for length in lengths: + for i in range(0, n - length + 1, skip): # Skip also affects starting index + sub_arrays.append(full_array[i:i + length]) + + return sub_arrays + + +def construct_sub_error_identifiers(error_identifier: ErrorDetectionIdentifier, initial_state: InitialStateContainer, code_layout: ISurfaceCodeLayer) -> List[Tuple[MWPMDecoderFast, MajorityVotingDecoder, InitialStateContainer]]: + ordered_involved_qubit_ids: List[IQubitID] = InitialStateManager.construct_qubit_chain( + code_layout=code_layout, + involved_data_qubit_ids=error_identifier.involved_qubit_ids, + ) + + initial_state_arrays = get_odd_subarrays(full_array=initial_state.as_array, skip=1) + involved_qubit_arrays = get_odd_subarrays(full_array=ordered_involved_qubit_ids, skip=2) + + result: List[Tuple[MWPMDecoderFast, MajorityVotingDecoder, InitialStateContainer]] = [] + for _initial_state, _involved_qubits in zip(initial_state_arrays, involved_qubit_arrays): + initial_state_container: InitialStateContainer = InitialStateContainer.from_ordered_list([ + InitialStateEnum.ZERO if state == 0 else InitialStateEnum.ONE + for state in _initial_state + ]) + + _error_identifier: ErrorDetectionIdentifier = error_identifier.copy_with_involved_qubit_ids( + involved_qubit_ids=_involved_qubits, + ) + decoder_mwpm = MWPMDecoderFast( + error_identifier=_error_identifier, + qec_rounds=_error_identifier.qec_rounds, + initial_state_container=initial_state_container, + max_optimization_shots=2000, + optimize=False, + optimized_round=_error_identifier.qec_rounds[-1] + ) + decoder_mv = MajorityVotingDecoder( + error_identifier=_error_identifier, + ) + result.append((decoder_mwpm, decoder_mv, initial_state_container)) + return result + + +def update_defect_rates(dataset: xr.Dataset, labeled_error_identifier: ILabeledErrorDetectionIdentifier, qec_round: int) -> xr.Dataset: + labeled_error_identifier_post_selected: ILabeledErrorDetectionIdentifier = labeled_error_identifier.copy_with_post_selection( + use_heralded_post_selection=labeled_error_identifier.include_heralded_post_selection, + use_projected_leakage_post_selection=False, + use_stabilizer_leakage_post_selection=True, + ) + + for qubit_id in labeled_error_identifier.involved_stabilizer_qubit_ids: + data_array: xr.DataArray = labeled_error_identifier.get_labeled_defect_stabilizer_lookup( + cycle_stabilizer_count=qec_round, + )[qubit_id] + # Calculate the mean across 'measurement_repetition' + dataset[f"defect_rates_{qubit_id.id}"] = data_array.mean(dim=DataArrayLabels.MEASUREMENT.value) + + try: + data_array_post_selected: xr.DataArray = labeled_error_identifier_post_selected.get_labeled_defect_stabilizer_lookup( + cycle_stabilizer_count=qec_round, + )[qubit_id] + dataset[f"defect_rates_post_selected_{qubit_id.id}"] = data_array_post_selected.mean(dim=DataArrayLabels.MEASUREMENT.value) + except ZeroClassifierShotsException as e: + pass + return dataset + + +def update_logical_fidelity(dataset: xr.Dataset, decoder_set: List[Tuple[MWPMDecoderFast, MajorityVotingDecoder, InitialStateContainer]], qec_rounds: Union[NDArray[np.int_], List[int]]) -> xr.Dataset: + x_array: np.ndarray = np.asarray(qec_rounds) + + for decoder_index, (decoder_mwpm, decoder_mv, initial_state) in enumerate(decoder_set): + # MWPM Decoder + mwpm_y_array: np.ndarray = np.full_like(x_array, np.nan, dtype=np.float64) + for i, x in tqdm(enumerate(x_array), desc=f"Processing {decoder_mwpm.__class__.__name__} Decoder (d {len(initial_state.as_array)})", total=len(x_array)): + try: + value: float = decoder_mwpm.get_fidelity(x, target_state=initial_state.as_array) + except ZeroClassifierShotsException: + value = np.nan + mwpm_y_array[i] = value + dataset[f"logical_fidelity_mwpm_d{len(initial_state.as_array)}_{decoder_index}"] = xr.DataArray( + mwpm_y_array, + coords={"qec_cycles": x_array}, + dims=["qec_cycles"], + name="logical_fidelity", + ) + # MV Decoder + mv_y_array: np.ndarray = np.full_like(x_array, np.nan, dtype=np.float64) + for i, x in tqdm(enumerate(x_array), desc=f"Processing {decoder_mv.__class__.__name__} Decoder (d {len(initial_state.as_array)})", total=len(x_array)): + try: + value: float = decoder_mv.get_fidelity(x, target_state=initial_state.as_array) + except ZeroClassifierShotsException: + value = np.nan + mv_y_array[i] = value + dataset[f"logical_fidelity_mv_d{len(initial_state.as_array)}_{decoder_index}"] = xr.DataArray( + mv_y_array, + coords={"qec_cycles": x_array}, + dims=["qec_cycles"], + name="logical_fidelity", + ) + + return dataset