From a58e6974567a2b94d3150e3aa2e0b07c1a3311d2 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Fri, 10 Oct 2025 21:54:45 +0000 Subject: [PATCH 1/6] Benchmarking on hf_model --- torchprime/experimental/benchmark/hf_model.py | 81 +++++++++++++ .../benchmark/hf_models_forward.py | 108 +++++++++++++++++ .../benchmark/hf_models_forward_eager.py | 109 ++++++++++++++++++ 3 files changed, 298 insertions(+) create mode 100644 torchprime/experimental/benchmark/hf_model.py create mode 100644 torchprime/experimental/benchmark/hf_models_forward.py create mode 100644 torchprime/experimental/benchmark/hf_models_forward_eager.py diff --git a/torchprime/experimental/benchmark/hf_model.py b/torchprime/experimental/benchmark/hf_model.py new file mode 100644 index 00000000..d8c01996 --- /dev/null +++ b/torchprime/experimental/benchmark/hf_model.py @@ -0,0 +1,81 @@ +from typing import Any + +import torch +from transformers.models.llama import modeling_llama +from transformers.models.qwen2 import modeling_qwen2 + + +def get_llama3_model(torch_dtype: torch.dtype): + """Returns the Llama3.2 1B model.""" + config = modeling_llama.LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=128000, + eos_token_id=128001, + head_dim=64, + hidden_act="silu", + hidden_size=2048, + initializer_range=0.02, + intermediate_size=8192, + max_position_embeddings=131072, + mlp_bias=False, + num_attention_heads=32, + num_hidden_layers=16, + num_key_value_heads=8, + rms_norm_eps=1e-05, + rope_scaling={ + "factor": 32.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + rope_theta=500000.0, + tie_word_embeddings=True, + use_cache=True, + vocab_size=128256, + _attn_implementation="eager", + ) + model = modeling_llama.LlamaForCausalLM(config).to(torch_dtype) + return model + + +def get_qwen2_model(torch_dtype: torch.dtype): + """Returns the Qwen2 1.7B model.""" + config = modeling_qwen2.Qwen2Config( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151645, + head_dim=128, + hidden_act="silu", + hidden_size=2048, + initializer_range=0.02, + intermediate_size=6144, + max_position_embeddings=40960, + max_window_layers=28, + num_attention_heads=16, + num_hidden_layers=28, + num_key_value_heads=8, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=1000000, + sliding_window=None, + tie_word_embeddings=True, + use_cache=True, + use_sliding_window=False, + vocab_size=151936, + _attn_implementation="eager", + ) + model = modeling_qwen2.Qwen2ForCausalLM(config).to(torch_dtype) + return model + + +def get_model(model_name: str, dtype: torch.dtype) -> Any: + match model_name: + case "llama3.2-1B": + return get_llama3_model(dtype) + case "qwen2-1.7B": + return get_qwen2_model(dtype) + case _: + raise ValueError(f"Unsupported model: {model_name}") \ No newline at end of file diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py new file mode 100644 index 00000000..50403559 --- /dev/null +++ b/torchprime/experimental/benchmark/hf_models_forward.py @@ -0,0 +1,108 @@ +import argparse +import os +import time +from typing import Any + +import numpy as np +import torch +import torch_xla + +from torchprime.experimental.benchmark.hf_model import get_model + + +def main(args): + # --- Configuration --- + dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} + torch_dtype = dtype_map[args.dtype] + + # It's good practice to define the device first. + device = torch_xla.device() + + # Create the model on CPU first + model_cpu = get_model(args.model_name, torch_dtype) + config = model_cpu.config + model_cpu.eval() # Set to evaluation mode + + # Move model to the XLA device. + model_tpu = model_cpu.to(device) + + # Create dummy input_ids and move to the XLA device. + input_ids = torch.randint( + 0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long + ) + # Move inputs to the XLA device as well. + input_ids = input_ids.to(device) + + # Preheat the cache. + print("Preheating...") + preheat_start_time = time.perf_counter() + with torch.no_grad(): + _ = model_tpu(input_ids).logits + torch_xla.sync() + preheat_end_time = time.perf_counter() + preheat_time = preheat_end_time - preheat_start_time + print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") + + # Initial run (warm-up) to trigger XLA compilation + print("Warming up...") + warmup_start_time = time.perf_counter() + with torch.no_grad(): + _ = model_tpu(input_ids).logits + torch_xla.sync() + warmup_end_time = time.perf_counter() + warmup_time = warmup_end_time - warmup_start_time + + # Subsequent runs for measurement + print(f"Starting benchmark for {args.num_runs} runs...") + times = [] + for i in range(args.num_runs): + start_time = time.perf_counter() + with torch.no_grad(): + # The model forward pass is intentionally not assigned to a variable + # to measure only the execution time. + model_tpu(input_ids) + + torch_xla.sync() + end_time = time.perf_counter() + times.append(end_time - start_time) + print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms") + + # Print final performance results + print("\n--- Benchmark Results (Lazy Mode) ---") + print(f"Model: {args.model_name}, DType: {args.dtype}") + print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") + print(f"Preheat time: {preheat_time * 1000:.2f} ms") + print(f"Warm-up time: {warmup_time * 1000:.2f} ms (includes compilation)") + print(f"Number of runs: {len(times)}") + print(f"Average latency: {np.mean(times) * 1000:.2f} ms") + print(f"Median latency: {np.median(times) * 1000:.2f} ms") + print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms") + print(f"Min latency: {np.min(times) * 1000:.2f} ms") + print(f"Max latency: {np.max(times) * 1000:.2f} ms") + + # Add this line to wait for the TPU to finish and ensure a clean exit + torch_xla.sync() + print("Script finished and exited cleanly.") + os._exit(0) # <-- Use os._exit() instead of sys.exit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Lazy Mode).") + parser.add_argument( + "--model_name", + type=str, + default="llama3.2-1B", + choices=["llama3.2-1B", "qwen2-1.7B"], + help="Model to benchmark (must match a config file name).", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float32"], + help="Data type for the model.", + ) + parser.add_argument("--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.") + parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.") + main(parser.parse_args()) \ No newline at end of file diff --git a/torchprime/experimental/benchmark/hf_models_forward_eager.py b/torchprime/experimental/benchmark/hf_models_forward_eager.py new file mode 100644 index 00000000..eace0a7b --- /dev/null +++ b/torchprime/experimental/benchmark/hf_models_forward_eager.py @@ -0,0 +1,109 @@ +import argparse +import os +import time +from typing import Any + +import numpy as np +import torch +import torch_xla + +from torchprime.experimental.benchmark.hf_model import get_model + + +def main(args): + # --- Configuration --- + print("Running in PyTorch/XLA experimental eager mode.") + torch_xla.experimental.eager_mode(True) + + dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} + torch_dtype = dtype_map[args.dtype] + + # It's good practice to define the device first. + device = torch_xla.device() + + # Create the model on CPU first + model_cpu = get_model(args.model_name, torch_dtype) + config = model_cpu.config + model_cpu.eval() # Set to evaluation mode + + # Move model to the XLA device. + model_tpu = model_cpu.to(device) + + # Create dummy input_ids and move to the XLA device. + input_ids = torch.randint( + 0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long + ) + # Move inputs to the XLA device as well. + input_ids = input_ids.to(device) + + # Preheat the cache. + print("Preheating...") + preheat_start_time = time.perf_counter() + with torch.no_grad(): + _ = model_tpu(input_ids).logits + preheat_end_time = time.perf_counter() + preheat_time = preheat_end_time - preheat_start_time + print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") + + # Initial run (warm-up) + print("Warming up...") + warmup_start_time = time.perf_counter() + with torch.no_grad(): + _ = model_tpu(input_ids).logits + warmup_end_time = time.perf_counter() + warmup_time = warmup_end_time - warmup_start_time + + # Subsequent runs for measurement + print(f"Starting benchmark for {args.num_runs} runs...") + times = [] + for i in range(args.num_runs): + start_time = time.perf_counter() + with torch.no_grad(): + # The model forward pass is intentionally not assigned to a variable + # to measure only the execution time. + model_tpu(input_ids) + + # Do we need this??? + torch_xla.sync() + + end_time = time.perf_counter() + times.append(end_time - start_time) + print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms") + + # Print final performance results + print("\n--- Benchmark Results (Eager Mode) ---") + print(f"Model: {args.model_name}, DType: {args.dtype}") + print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") + print(f"Preheat time: {preheat_time * 1000:.2f} ms") + print(f"Warm-up time: {warmup_time * 1000:.2f} ms") + print(f"Number of runs: {len(times)}") + print(f"Average latency: {np.mean(times) * 1000:.2f} ms") + print(f"Median latency: {np.median(times) * 1000:.2f} ms") + print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms") + print(f"Min latency: {np.min(times) * 1000:.2f} ms") + print(f"Max latency: {np.max(times) * 1000:.2f} ms") + + print("Script finished and exited cleanly.") + os._exit(0) # <-- Use os._exit() instead of sys.exit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Eager Mode).") + parser.add_argument( + "--model_name", + type=str, + default="llama3.2-1B", + choices=["llama3.2-1B", "qwen2-1.7B"], + help="Model to benchmark (must match a config file name).", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float32"], + help="Data type for the model.", + ) + parser.add_argument("--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.") + parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.") + main(parser.parse_args()) \ No newline at end of file From f3d0e3828e7c1db8672067693f8e6b9b23dbc91d Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Mon, 13 Oct 2025 18:29:53 +0000 Subject: [PATCH 2/6] some config fixes --- torchprime/experimental/benchmark/hf_model.py | 14 +++++++------- .../experimental/benchmark/hf_models_forward.py | 2 +- .../benchmark/hf_models_forward_eager.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchprime/experimental/benchmark/hf_model.py b/torchprime/experimental/benchmark/hf_model.py index d8c01996..95776b0c 100644 --- a/torchprime/experimental/benchmark/hf_model.py +++ b/torchprime/experimental/benchmark/hf_model.py @@ -2,7 +2,7 @@ import torch from transformers.models.llama import modeling_llama -from transformers.models.qwen2 import modeling_qwen2 +from transformers.models.qwen3 import modeling_qwen3 def get_llama3_model(torch_dtype: torch.dtype): @@ -40,9 +40,9 @@ def get_llama3_model(torch_dtype: torch.dtype): return model -def get_qwen2_model(torch_dtype: torch.dtype): - """Returns the Qwen2 1.7B model.""" - config = modeling_qwen2.Qwen2Config( +def get_qwen3_model(torch_dtype: torch.dtype): + """Returns the Qwen3 1.7B model.""" + config = modeling_qwen3.Qwen3Config( attention_bias=False, attention_dropout=0.0, bos_token_id=151643, @@ -67,7 +67,7 @@ def get_qwen2_model(torch_dtype: torch.dtype): vocab_size=151936, _attn_implementation="eager", ) - model = modeling_qwen2.Qwen2ForCausalLM(config).to(torch_dtype) + model = modeling_qwen3.Qwen3ForCausalLM(config).to(torch_dtype) return model @@ -75,7 +75,7 @@ def get_model(model_name: str, dtype: torch.dtype) -> Any: match model_name: case "llama3.2-1B": return get_llama3_model(dtype) - case "qwen2-1.7B": - return get_qwen2_model(dtype) + case "qwen3-1.7B": + return get_qwen3_model(dtype) case _: raise ValueError(f"Unsupported model: {model_name}") \ No newline at end of file diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py index 50403559..b0e6561b 100644 --- a/torchprime/experimental/benchmark/hf_models_forward.py +++ b/torchprime/experimental/benchmark/hf_models_forward.py @@ -92,7 +92,7 @@ def main(args): "--model_name", type=str, default="llama3.2-1B", - choices=["llama3.2-1B", "qwen2-1.7B"], + choices=["llama3.2-1B", "qwen3-1.7B"], help="Model to benchmark (must match a config file name).", ) parser.add_argument( diff --git a/torchprime/experimental/benchmark/hf_models_forward_eager.py b/torchprime/experimental/benchmark/hf_models_forward_eager.py index eace0a7b..048c6846 100644 --- a/torchprime/experimental/benchmark/hf_models_forward_eager.py +++ b/torchprime/experimental/benchmark/hf_models_forward_eager.py @@ -93,7 +93,7 @@ def main(args): "--model_name", type=str, default="llama3.2-1B", - choices=["llama3.2-1B", "qwen2-1.7B"], + choices=["llama3.2-1B", "qwen3-1.7B"], help="Model to benchmark (must match a config file name).", ) parser.add_argument( From a1a538ad7ae969145895edf07f5d437254e6fd5c Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Mon, 13 Oct 2025 22:44:46 +0000 Subject: [PATCH 3/6] Assign variable for logits --- .../experimental/benchmark/hf_models_forward.py | 12 ++++++------ .../benchmark/hf_models_forward_eager.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py index b0e6561b..3aa967f4 100644 --- a/torchprime/experimental/benchmark/hf_models_forward.py +++ b/torchprime/experimental/benchmark/hf_models_forward.py @@ -37,7 +37,8 @@ def main(args): print("Preheating...") preheat_start_time = time.perf_counter() with torch.no_grad(): - _ = model_tpu(input_ids).logits + # Assign to a variable to prevent garbage collection before sync. + logits = model_tpu(input_ids).logits torch_xla.sync() preheat_end_time = time.perf_counter() preheat_time = preheat_end_time - preheat_start_time @@ -47,7 +48,7 @@ def main(args): print("Warming up...") warmup_start_time = time.perf_counter() with torch.no_grad(): - _ = model_tpu(input_ids).logits + logits = model_tpu(input_ids).logits torch_xla.sync() warmup_end_time = time.perf_counter() warmup_time = warmup_end_time - warmup_start_time @@ -58,11 +59,10 @@ def main(args): for i in range(args.num_runs): start_time = time.perf_counter() with torch.no_grad(): - # The model forward pass is intentionally not assigned to a variable - # to measure only the execution time. - model_tpu(input_ids) + # Assign to a variable to prevent garbage collection before sync. + logits = model_tpu(input_ids).logits - torch_xla.sync() + torch_xla.sync() # Wait for the computation to complete. end_time = time.perf_counter() times.append(end_time - start_time) print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms") diff --git a/torchprime/experimental/benchmark/hf_models_forward_eager.py b/torchprime/experimental/benchmark/hf_models_forward_eager.py index 048c6846..e5ced77d 100644 --- a/torchprime/experimental/benchmark/hf_models_forward_eager.py +++ b/torchprime/experimental/benchmark/hf_models_forward_eager.py @@ -40,7 +40,8 @@ def main(args): print("Preheating...") preheat_start_time = time.perf_counter() with torch.no_grad(): - _ = model_tpu(input_ids).logits + # Assign to a variable to prevent garbage collection before sync. + logits = model_tpu(input_ids).logits preheat_end_time = time.perf_counter() preheat_time = preheat_end_time - preheat_start_time print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") @@ -49,7 +50,7 @@ def main(args): print("Warming up...") warmup_start_time = time.perf_counter() with torch.no_grad(): - _ = model_tpu(input_ids).logits + logits = model_tpu(input_ids).logits warmup_end_time = time.perf_counter() warmup_time = warmup_end_time - warmup_start_time @@ -59,12 +60,11 @@ def main(args): for i in range(args.num_runs): start_time = time.perf_counter() with torch.no_grad(): - # The model forward pass is intentionally not assigned to a variable - # to measure only the execution time. - model_tpu(input_ids) + # Assign to a variable to prevent garbage collection before sync. + logits = model_tpu(input_ids).logits - # Do we need this??? - torch_xla.sync() + # This is critical for accurate timing. XLA operations are asynchronous. + torch_xla.sync() # Wait for the computation to complete. end_time = time.perf_counter() times.append(end_time - start_time) From 5bf4f3e5b1a18ab80a8f17c9d445d2d88da269d5 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 15 Oct 2025 16:03:34 +0000 Subject: [PATCH 4/6] Use wait_device_ops instead of sync --- .../benchmark/hf_models_forward.py | 23 ++++++------------- .../benchmark/hf_models_forward_eager.py | 8 ++++--- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py index 3aa967f4..0f92489a 100644 --- a/torchprime/experimental/benchmark/hf_models_forward.py +++ b/torchprime/experimental/benchmark/hf_models_forward.py @@ -6,6 +6,7 @@ import numpy as np import torch import torch_xla +import torch_xla.core.xla_model as xm from torchprime.experimental.benchmark.hf_model import get_model @@ -33,23 +34,14 @@ def main(args): # Move inputs to the XLA device as well. input_ids = input_ids.to(device) - # Preheat the cache. - print("Preheating...") - preheat_start_time = time.perf_counter() - with torch.no_grad(): - # Assign to a variable to prevent garbage collection before sync. - logits = model_tpu(input_ids).logits - torch_xla.sync() - preheat_end_time = time.perf_counter() - preheat_time = preheat_end_time - preheat_start_time - print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") - # Initial run (warm-up) to trigger XLA compilation - print("Warming up...") + print("Warming up (includes XLA graph compilation)...") warmup_start_time = time.perf_counter() with torch.no_grad(): + # The first run triggers compilation, which is a one-time cost. + # Subsequent runs will be much faster as they hit the compilation cache. logits = model_tpu(input_ids).logits - torch_xla.sync() + xm.wait_device_ops() # Block until the graph compilation and execution is complete. warmup_end_time = time.perf_counter() warmup_time = warmup_end_time - warmup_start_time @@ -62,7 +54,7 @@ def main(args): # Assign to a variable to prevent garbage collection before sync. logits = model_tpu(input_ids).logits - torch_xla.sync() # Wait for the computation to complete. + xm.wait_device_ops() # Block until the step's computation is complete for accurate timing. end_time = time.perf_counter() times.append(end_time - start_time) print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms") @@ -71,7 +63,6 @@ def main(args): print("\n--- Benchmark Results (Lazy Mode) ---") print(f"Model: {args.model_name}, DType: {args.dtype}") print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") - print(f"Preheat time: {preheat_time * 1000:.2f} ms") print(f"Warm-up time: {warmup_time * 1000:.2f} ms (includes compilation)") print(f"Number of runs: {len(times)}") print(f"Average latency: {np.mean(times) * 1000:.2f} ms") @@ -81,7 +72,7 @@ def main(args): print(f"Max latency: {np.max(times) * 1000:.2f} ms") # Add this line to wait for the TPU to finish and ensure a clean exit - torch_xla.sync() + xm.wait_device_ops() # Final sync to ensure all pending operations are done. print("Script finished and exited cleanly.") os._exit(0) # <-- Use os._exit() instead of sys.exit() diff --git a/torchprime/experimental/benchmark/hf_models_forward_eager.py b/torchprime/experimental/benchmark/hf_models_forward_eager.py index e5ced77d..ab0c644c 100644 --- a/torchprime/experimental/benchmark/hf_models_forward_eager.py +++ b/torchprime/experimental/benchmark/hf_models_forward_eager.py @@ -6,6 +6,7 @@ import numpy as np import torch import torch_xla +import torch_xla.core.xla_model as xm from torchprime.experimental.benchmark.hf_model import get_model @@ -42,6 +43,7 @@ def main(args): with torch.no_grad(): # Assign to a variable to prevent garbage collection before sync. logits = model_tpu(input_ids).logits + xm.wait_device_ops() preheat_end_time = time.perf_counter() preheat_time = preheat_end_time - preheat_start_time print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") @@ -51,6 +53,8 @@ def main(args): warmup_start_time = time.perf_counter() with torch.no_grad(): logits = model_tpu(input_ids).logits + # Block until the operation is complete. + xm.wait_device_ops() warmup_end_time = time.perf_counter() warmup_time = warmup_end_time - warmup_start_time @@ -62,9 +66,7 @@ def main(args): with torch.no_grad(): # Assign to a variable to prevent garbage collection before sync. logits = model_tpu(input_ids).logits - - # This is critical for accurate timing. XLA operations are asynchronous. - torch_xla.sync() # Wait for the computation to complete. + xm.wait_device_ops() end_time = time.perf_counter() times.append(end_time - start_time) From 83da39e63905c8e8742d428f2227d1902dbbb4d3 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 15 Oct 2025 16:12:02 +0000 Subject: [PATCH 5/6] Fix on forward --- .../experimental/benchmark/hf_models_forward.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py index 0f92489a..d90d97fb 100644 --- a/torchprime/experimental/benchmark/hf_models_forward.py +++ b/torchprime/experimental/benchmark/hf_models_forward.py @@ -34,12 +34,21 @@ def main(args): # Move inputs to the XLA device as well. input_ids = input_ids.to(device) + # Preheat the cache. + print("Preheating...") + preheat_start_time = time.perf_counter() + with torch.no_grad(): + # Assign to a variable to prevent garbage collection before sync. + logits = model_tpu(input_ids).logits + xm.wait_device_ops() + preheat_end_time = time.perf_counter() + preheat_time = preheat_end_time - preheat_start_time + print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") + # Initial run (warm-up) to trigger XLA compilation print("Warming up (includes XLA graph compilation)...") warmup_start_time = time.perf_counter() with torch.no_grad(): - # The first run triggers compilation, which is a one-time cost. - # Subsequent runs will be much faster as they hit the compilation cache. logits = model_tpu(input_ids).logits xm.wait_device_ops() # Block until the graph compilation and execution is complete. warmup_end_time = time.perf_counter() @@ -63,6 +72,7 @@ def main(args): print("\n--- Benchmark Results (Lazy Mode) ---") print(f"Model: {args.model_name}, DType: {args.dtype}") print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") + print(f"Preheat time: {preheat_time * 1000:.2f} ms") print(f"Warm-up time: {warmup_time * 1000:.2f} ms (includes compilation)") print(f"Number of runs: {len(times)}") print(f"Average latency: {np.mean(times) * 1000:.2f} ms") From a387f2e405a14a6fab4e170763549d0a2ae566a3 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 15 Oct 2025 16:38:21 +0000 Subject: [PATCH 6/6] Use torch_xla sync in preheat --- torchprime/experimental/benchmark/hf_models_forward.py | 5 +++-- torchprime/experimental/benchmark/hf_models_forward_eager.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torchprime/experimental/benchmark/hf_models_forward.py b/torchprime/experimental/benchmark/hf_models_forward.py index d90d97fb..c987ba41 100644 --- a/torchprime/experimental/benchmark/hf_models_forward.py +++ b/torchprime/experimental/benchmark/hf_models_forward.py @@ -40,7 +40,8 @@ def main(args): with torch.no_grad(): # Assign to a variable to prevent garbage collection before sync. logits = model_tpu(input_ids).logits - xm.wait_device_ops() + torch_xla.sync() + # xm.wait_device_ops() preheat_end_time = time.perf_counter() preheat_time = preheat_end_time - preheat_start_time print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") @@ -73,7 +74,7 @@ def main(args): print(f"Model: {args.model_name}, DType: {args.dtype}") print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") print(f"Preheat time: {preheat_time * 1000:.2f} ms") - print(f"Warm-up time: {warmup_time * 1000:.2f} ms (includes compilation)") + print(f"Warm-up time: {warmup_time * 1000:.2f} ms") print(f"Number of runs: {len(times)}") print(f"Average latency: {np.mean(times) * 1000:.2f} ms") print(f"Median latency: {np.median(times) * 1000:.2f} ms") diff --git a/torchprime/experimental/benchmark/hf_models_forward_eager.py b/torchprime/experimental/benchmark/hf_models_forward_eager.py index ab0c644c..759ff2ea 100644 --- a/torchprime/experimental/benchmark/hf_models_forward_eager.py +++ b/torchprime/experimental/benchmark/hf_models_forward_eager.py @@ -43,7 +43,7 @@ def main(args): with torch.no_grad(): # Assign to a variable to prevent garbage collection before sync. logits = model_tpu(input_ids).logits - xm.wait_device_ops() + torch_xla.sync() preheat_end_time = time.perf_counter() preheat_time = preheat_end_time - preheat_start_time print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms")