diff --git a/graph_net/fault_locator/bi_search.py b/graph_net/fault_locator/bi_search.py index 3d2c94167..0b00e7024 100644 --- a/graph_net/fault_locator/bi_search.py +++ b/graph_net/fault_locator/bi_search.py @@ -6,9 +6,9 @@ def bi_search( predicator, # Signature: (ES, tolerance) -> bool stoper, # Signature: (history_list) -> bool tolerance=0, -) -> list[(int, bool)]: +) -> tuple[list[tuple[int, bool]], int, str]: """ - Binary Search Algorithm for Automatic Fault Location. + Binary Search Algorithm for Automatic Fault Location with Faulty Operator Detection. This algorithm locates the first faulty operation in a computational graph by iteratively narrowing the search range through graph truncation and @@ -24,9 +24,13 @@ def bi_search( tolerance (int): Numerical threshold for fault detection. Returns: - list: Search history as a list of (split_point, is_fault) tuples. + tuple: (search_history, faulty_operator_index, faulty_model_path) + - search_history: list of (split_point, is_fault) tuples + - faulty_operator_index: index of the first faulty operator, or -1 if no fault found + - faulty_model_path: truncated model path for the first faulty operator, or "" if no fault found """ search_history = [] + faulty_operator_index = -1 # Initialize as -1 meaning no fault found # Initialize boundaries. # 'high' usually represents the total number of operators in the graph. 
@@ -73,7 +76,19 @@ def bi_search( if not any(h[0] == low for h in search_history): truncated_model_path = truncator(relative_model_path, low) final_es = es_scores_calculator(evaluator(truncated_model_path)) - search_history.append((low, predicator(final_es, tolerance))) + final_is_fault = predicator(final_es, tolerance) + search_history.append((low, final_is_fault)) + + if final_is_fault: + faulty_operator_index = low break - return search_history + faulty_positions = [pos for pos, is_fault in search_history if is_fault] + if faulty_positions: + faulty_operator_index = min(faulty_positions) + faulty_model_path = truncator(relative_model_path, faulty_operator_index) + else: + faulty_operator_index = -1 + faulty_model_path = "" + + return search_history, faulty_operator_index, faulty_model_path diff --git a/graph_net/fault_locator/terminator.py b/graph_net/fault_locator/terminator.py index d585cbd86..ce73fdbdf 100644 --- a/graph_net/fault_locator/terminator.py +++ b/graph_net/fault_locator/terminator.py @@ -6,12 +6,16 @@ def __call__(self, history: list[(int, float)], high: int): from pprint import pprint pprint(history) - print(f"{high=}") return bi_search_terminator(history, high) def bi_search_terminator(history: list[(int, float)], high: int): """Stops when the search interval converges (range is 0 or 1).""" + last_idx, is_broken = history[-1] + if last_idx == 1 and is_broken: + return True + if last_idx == high and not is_broken: + return True if len(history) == 1 and history[0][0] == high and not history[0][1]: return True if len(history) < 2: diff --git a/graph_net/fault_locator/torch/device_evaluator.py b/graph_net/fault_locator/torch/device_evaluator.py new file mode 100644 index 000000000..944c4b6d4 --- /dev/null +++ b/graph_net/fault_locator/torch/device_evaluator.py @@ -0,0 +1,107 @@ +import sys +import subprocess +import time +from pathlib import Path +from graph_net.declare_config_mixin import DeclareConfigMixin + + +class 
DeviceEvaluator(DeclareConfigMixin): + """ + Evaluator responsible for comparing model performance and accuracy between + a reference device (e.g., CPU) and a target device (e.g., CUDA). + Uses 'default' as the operator library for all target executions. + """ + + def __init__(self, config=None): + self.init_config(config) + + def declare_config( + self, + model_path_prefix: str, + output_dir: str, + ref_device: str = "cpu", + target_device: str = "cuda", + compiler: str = "nope", + ): + """ + Configuration schema for cross-device benchmarking. + """ + pass + + def __call__(self, rel_model_path: str) -> str: + """ + Orchestrates the evaluation pipeline: + 1. Generates ground truth data on the reference device. + 2. Validates performance/accuracy on the target device. + """ + output_path = Path(self.config["output_dir"]) + full_model_path = Path(self.config["model_path_prefix"]) / rel_model_path + + # Define specific workspace for target device logs + workspace = output_path / self.config["target_device"] / rel_model_path + workspace.mkdir(parents=True, exist_ok=True) + + # Directory for sharing ground truth data between runs + reference_dir = output_path / "reference_data" + reference_dir.mkdir(parents=True, exist_ok=True) + + log_file = workspace / "validation.log" + + # Step 1: Execute reference test to establish baseline + print(f"Generating reference data on: {self.config['ref_device']}") + self._run_reference_test(full_model_path, reference_dir) + + # Step 2: Execute target test and return captured logs + print(f"Running target evaluation on: {self.config['target_device']}") + return self._run_target_test(full_model_path, reference_dir, log_file) + + def _run_reference_test(self, full_model_path: Path, reference_dir: Path): + """ + Invokes the reference module to generate expected outputs (Ground Truth). 
+ """ + cmd = [ + sys.executable, + "-m", + "graph_net.torch.test_reference_device", + "--model-path", + str(full_model_path), + "--reference-dir", + str(reference_dir), + "--compiler", + self.config["compiler"], + "--device", + self.config["ref_device"], + ] + # Reference runs are silent; errors will raise a CalledProcessError + subprocess.run(cmd, check=True, capture_output=True, text=True) + + def _run_target_test( + self, full_model_path: Path, reference_dir: Path, log_file: Path + ) -> str: + """ + Executes the model on the target device using 'default' op_lib + and captures the full output log. + """ + cmd = [ + sys.executable, + "-m", + "graph_net.torch.test_target_device", + "--model-path", + str(full_model_path), + "--reference-dir", + str(reference_dir), + "--device", + self.config["target_device"], + "--op-lib", + "default", + ] + + print(" ".join(cmd)) + # Redirect all output to the log file for persistence and analysis + with log_file.open("w") as f: + start_time = time.perf_counter() + subprocess.run(cmd, stdout=f, stderr=subprocess.STDOUT, check=True) + end_time = time.perf_counter() + print(f"Target execution completed in {end_time - start_time:.4f} seconds") + + return log_file.read_text() diff --git a/graph_net/sample_pass/auto_fault_bisearcher.py b/graph_net/sample_pass/auto_fault_bisearcher.py index 62b716ff6..bab0a0d8a 100644 --- a/graph_net/sample_pass/auto_fault_bisearcher.py +++ b/graph_net/sample_pass/auto_fault_bisearcher.py @@ -1,5 +1,6 @@ import os import graph_net +import shutil from pathlib import Path from typing import List, Tuple from graph_net.sample_pass.sample_pass import SamplePass @@ -81,7 +82,10 @@ def __call__(self, rel_model_path: str): """ # 2. 
Invoke the core binary search algorithm # history type: list[tuple[int, bool]] - history: List[Tuple[int, bool]] = bi_search( + history: List[Tuple[int, bool]] + faulty_operator_index: int + faulty_model_path: str + history, faulty_operator_index, faulty_model_path = bi_search( relative_model_path=rel_model_path, truncator=self.truncator, evaluator=self.evaluator, @@ -100,13 +104,22 @@ def __call__(self, rel_model_path: str): output_base.mkdir(parents=True, exist_ok=True) result_file = output_base / file_name + test_file = ( + Path(self.config["truncator_config"]["output_dir"]) / faulty_model_path + ) # Write history entries in the format: {truncate_size} {has_fault} with result_file.open("w", encoding="utf-8") as f: for trunc_size, has_fault in history: f.write(f"{trunc_size} {has_fault}\n") + if faulty_model_path: # "" means no fault found; guard so we don't copytree the whole output dir + save_base = Path(self.config["output_dir"]) / "faulty_test" + save_base.mkdir(parents=True, exist_ok=True) + shutil.copytree(test_file, save_base / test_file.name, dirs_exist_ok=True) print( f"[AutoFault] Search history for {rel_model_path} saved to: {result_file}" ) + print(f"First faulty operator index: {faulty_operator_index}") + print(f"Faulty operator model path: {test_file}") return history diff --git a/graph_net/test/bi_search_test.py b/graph_net/test/bi_search_test.py index e76eca8ea..f57301b39 100644 --- a/graph_net/test/bi_search_test.py +++ b/graph_net/test/bi_search_test.py @@ -27,7 +27,7 @@ def mock_evaluator(self, sub_model_id): def mock_predicator(self, es_scores, tolerance): return any(score > tolerance for score in es_scores) - def mock_stoper(self, history): + def mock_stoper(self, history, high=None): """Stops when the search interval converges (range is 0 or 1).""" if len(history) < 2: return False @@ -40,26 +40,21 @@ def test_bi_search_finds_correct_index(self): truncator = self.mock_truncator() setattr(truncator, "total_steps", 9) - history = bi_search( - model_path=self.model_path, + history, faulty_operator_index, faulty_model_path = bi_search( + 
relative_model_path=self.model_path, truncator=truncator, evaluator=self.mock_evaluator, + es_scores_calculator=lambda x: x, # Mock ES calculator predicator=self.mock_predicator, stoper=self.mock_stoper, tolerance=0.5, ) print(f"\nFault Test History: {history}") + print(f"Detected faulty operator index: {faulty_operator_index}") - # Filter history for all occurrences where a fault was detected - faulty_steps = [step for step in history if step[1] is True] - - # The result of the fault localization is the minimum index with is_fault=True - if faulty_steps: - # Sort by index to find the first occurrence - actual_fault_index = min(faulty_steps, key=lambda x: x[0])[0] - else: - actual_fault_index = None + # The result of the fault localization is directly provided by the function + actual_fault_index = faulty_operator_index print(f"\nIdentified Fault Index: {actual_fault_index}") self.assertEqual(actual_fault_index, self.fault_index) @@ -76,18 +71,23 @@ def clean_truncator(path, split_point): def healthy_evaluator(sub_model_id): return [0.01] - history = bi_search( - model_path=self.model_path, + history, faulty_operator_index, faulty_model_path = bi_search( + relative_model_path=self.model_path, truncator=clean_truncator, evaluator=healthy_evaluator, + es_scores_calculator=lambda x: x, # Mock ES calculator predicator=self.mock_predicator, stoper=self.mock_stoper, tolerance=0.5, ) print(f"No-Fault Test History: {history}") + print(f"Detected faulty operator index: {faulty_operator_index}") + + # No fault should be detected final_status = history[-1][1] self.assertFalse(final_status) + self.assertEqual(faulty_operator_index, -1) # -1 indicates no fault found if __name__ == "__main__": diff --git a/graph_net/test/device_fault_bisearcher_test.sh b/graph_net/test/device_fault_bisearcher_test.sh new file mode 100644 index 000000000..711f79f6a --- /dev/null +++ b/graph_net/test/device_fault_bisearcher_test.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Resolve the root directory of 
the project +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") + +# Test Environment Setup +MODEL_LIST="$GRAPH_NET_ROOT/graph_net/test/small10_torch_samples_list.txt" +MODEL_PREFIX="$GRAPH_NET_ROOT" +OUTPUT_DIR="/tmp/workspace_auto_fault_bisearcher" + +# Execute the SamplePass via the standard CLI entry point +python3 -m graph_net.apply_sample_pass \ + --model-path-list "$MODEL_LIST" \ + --sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/sample_pass/auto_fault_bisearcher.py" \ + --sample-pass-class-name AutoFaultBisearcher \ + --sample-pass-config $(base64 -w 0 <