
torch.compile crashes when sizes are unbacked SymInts. #243

@rwkeane

Description

Describe the bug
Unbacked SymInts (i.e. sizes that aren't known at trace time) cause SegmentedPolynomial to crash. This was just the first call I tried, so I would expect the same issue to be present elsewhere too.

To Reproduce

Run the following script. Note that my own code hits the same issue without setting torch._dynamo.config.capture_dynamic_output_shape_ops = True (so the failure is not limited to that setting), but it was significantly easier to create a repro that does use it.

"""Minimal repro: cuequivariance fails with unbacked SymInts from torch.compile.

Bug: segmented_polynomial_fused_tp.py:205 does `if size != 1:` with unbacked SymInt.
"""

import torch
from cuequivariance import Irreps
from cuequivariance.group_theory.descriptors import full_tensor_product
from cuequivariance_torch import SegmentedPolynomial

torch._dynamo.config.capture_dynamic_output_shape_ops = True


def create_cg_module(device):
    lhs_irreps = Irreps("SO3", [(4, 0), (2, 1)])
    rhs_irreps = Irreps("SO3", [(3, 0)])
    equation = full_tensor_product(lhs_irreps, rhs_irreps, None)
    return SegmentedPolynomial(equation.polynomial, method="fused_tp").to(device)


@torch.compile(fullgraph=True)
def test_with_cuequivariance(positions, radius, cg_module, call_cg):
    # Create unbacked SymInt via nonzero
    dist = torch.cdist(positions, positions)
    mask = dist < radius
    edge_indices = torch.nonzero(mask)
    num_edges = edge_indices.shape[0]  # unbacked SymInt 'u0'
    
    # Create features with dynamic shape
    lhs = torch.randn(num_edges, 10, device=positions.device)
    rhs = torch.randn(num_edges, 3, device=positions.device)
    
    if not call_cg:
        return lhs.sum() + rhs.sum()  # Works fine with unbacked SymInts
    
    # FAILS HERE: cuequivariance can't handle unbacked SymInt
    result = cg_module([lhs, rhs])
    return result[0]


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cg_module = create_cg_module(device)
    positions = torch.randn(1, 50, 3, device=device)
    
    # Test 1: Without cuequivariance - works
    print("Test 1 (no cuequivariance):", end=" ")
    result = test_with_cuequivariance(positions, 0.5, cg_module, False)
    print("✓ Pass")
    
    # Test 2: With cuequivariance - fails
    print("Test 2 (with cuequivariance):", end=" ")
    result = test_with_cuequivariance(positions, 0.5, cg_module, True)
    print("✗ Unexpected pass")

It crashes with the following error:

ryan@ryan-dev-box:~/src/env$ TORCHDYNAMO_VERBOSE=1 python /home/ryan/src/env/cuequivariance_bug_final_repro.py 2>&1
/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py:89: UserWarning: `math_dtype` is not provided for method `fused_tp`: using float32.
  warnings.warn(
Test 1 (no cuequivariance): ✓ Pass
Test 2 (with cuequivariance): Traceback (most recent call last):
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py", line 1410, in evaluate_expr
    return guard_scalar(self.sym_num)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1519, in guard_scalar
    return guard_bool(a)
           ^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1711, in guard_bool
    return a.node.guard_bool("", 0)  # NB: uses Python backtrace
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
    r = self.evaluate()
        ^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7223, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7323, in evaluate_expr
    return self._inner_evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7346, in _inner_evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7570, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Caused by: if size != 1:  # cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py:205 in forward (_dynamo/variables/tensor.py:1410 in evaluate_expr)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 37, in test_with_cuequivariance
    result = cg_module([lhs, rhs])
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial.py", line 283, in forward
    return self.m(inputs, input_indices, output_shapes, output_indices)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py", line 205, in forward
    if size != 1:

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 53, in <module>
    result = test_with_cuequivariance(positions, 0.5, cg_module, True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
    tracer.run()
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
    super().run()
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
    self._call(inst)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
    self.call_function(fn, args, kwargs)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 1000, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 529, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1210, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3698, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3901, in inline_call_
    self.run()
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
    self._call(inst)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
    self.call_function(fn, args, kwargs)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 1000, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 529, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1210, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3698, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3901, in inline_call_
    self.run()
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 790, in inner
    eval_result = value.evaluate_expr(self.output)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py", line 1415, in evaluate_expr
    raise UserError(  # noqa: B904
torch._dynamo.exc.UserError: Consider annotating your code using torch._check*(). Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Caused by: if size != 1:  # cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py:205 in forward (_dynamo/variables/tensor.py:1410 in evaluate_expr)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 37, in test_with_cuequivariance
    result = cg_module([lhs, rhs])
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial.py", line 283, in forward
    return self.m(inputs, input_indices, output_shapes, output_indices)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py", line 205, in forward
    if size != 1:

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/home/ryan/src/env/cuequivariance_bug_final_repro.py", line 37, in test_with_cuequivariance
    result = cg_module([lhs, rhs])
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial.py", line 283, in forward
    return self.m(inputs, input_indices, output_shapes, output_indices)
  File "/home/ryan/anaconda3/envs/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/segmented_polynomial_fused_tp.py", line 205, in forward
    if size != 1:

Expected behavior
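torch.compile should be able to trace SegmentedPolynomial when its input sizes are unbacked SymInts, just as it already does for the plain tensor operations in Test 1, instead of raising GuardOnDataDependentSymNode.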

GPU HW/SW:
Here are the full details of my system, as previously reported in a PyTorch bug.

python collect_env.py

Collecting environment information...
PyTorch version: 2.8.0+cu129
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A

OS: Ubuntu 25.10 (x86_64)
GCC version: (Ubuntu 15.2.0-4ubuntu4) 15.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.42

Python version: 3.12.0 | packaged by Anaconda, Inc. | (main, Oct  2 2023, 17:29:18) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.17.0-6-generic-x86_64-with-glibc2.42
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5060 Ti
Nvidia driver version: 580.95.05
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           48 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  12
On-line CPU(s) list:                     0-11
Vendor ID:                               AuthenticAMD
Model name:                              AMD Ryzen 5 7600X 6-Core Processor
CPU family:                              25
Model:                                   97
Thread(s) per core:                      2
Core(s) per socket:                      6
Socket(s):                               1
Stepping:                                2
Frequency boost:                         enabled
CPU(s) scaling MHz:                      70%
CPU max MHz:                             5457.1050
CPU min MHz:                             427.3640
BogoMIPS:                                9381.80
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpuid_fault cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d amd_lbr_pmc_freeze
Virtualization:                          AMD-V
L1d cache:                               192 KiB (6 instances)
L1i cache:                               192 KiB (6 instances)
L2 cache:                                6 MiB (6 instances)
L3 cache:                                32 MiB (1 instance)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-11
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Mitigation; Safe RET
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Mitigation; Clear CPU buffers
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] cuequivariance-ops-torch-cu12==0.7.0
[pip3] cuequivariance-torch==0.7.0
[pip3] mypy==1.18.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.9.1.4
[pip3] nvidia-cuda-cupti-cu12==12.9.79
[pip3] nvidia-cuda-nvrtc-cu12==12.9.86
[pip3] nvidia-cuda-runtime-cu12==12.9.79
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.4.1.4
[pip3] nvidia-curand-cu12==10.3.10.19
[pip3] nvidia-cusolver-cu12==11.7.5.82
[pip3] nvidia-cusparse-cu12==12.5.10.65
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.9.86
[pip3] nvidia-nvtx-cu12==12.9.79
[pip3] torch==2.8.0+cu129
[pip3] torch_cluster==1.6.3+pt28cu129
[pip3] torch_geometric==2.5.2
[pip3] torch_scatter==2.1.2+pt28cu129
[pip3] torchaudio==2.8.0+cu129
[pip3] torchvision==0.23.0+cu129
[pip3] triton==3.4.0
[conda] cuequivariance-ops-torch-cu12    0.7.0            pypi_0           pypi
[conda] cuequivariance-torch             0.7.0            pypi_0           pypi
[conda] numpy                            2.2.6            pypi_0           pypi
[conda] nvidia-cublas-cu12               12.9.1.4         pypi_0           pypi
[conda] nvidia-cuda-cupti-cu12           12.9.79          pypi_0           pypi
[conda] nvidia-cuda-nvrtc-cu12           12.9.86          pypi_0           pypi
[conda] nvidia-cuda-runtime-cu12         12.9.79          pypi_0           pypi
[conda] nvidia-cudnn-cu12                9.10.2.21        pypi_0           pypi
[conda] nvidia-cufft-cu12                11.4.1.4         pypi_0           pypi
[conda] nvidia-curand-cu12               10.3.10.19       pypi_0           pypi
[conda] nvidia-cusolver-cu12             11.7.5.82        pypi_0           pypi
[conda] nvidia-cusparse-cu12             12.5.10.65       pypi_0           pypi
[conda] nvidia-cusparselt-cu12           0.7.1            pypi_0           pypi
[conda] nvidia-nccl-cu12                 2.27.3           pypi_0           pypi
[conda] nvidia-nvjitlink-cu12            12.9.86          pypi_0           pypi
[conda] nvidia-nvtx-cu12                 12.9.79          pypi_0           pypi
[conda] torch                            2.8.0+cu129      pypi_0           pypi
[conda] torch-cluster                    1.6.3+pt28cu129  pypi_0           pypi
[conda] torch-geometric                  2.5.2            pypi_0           pypi
[conda] torch-scatter                    2.1.2+pt28cu129  pypi_0           pypi
[conda] torchaudio                       2.8.0+cu129      pypi_0           pypi
[conda] torchvision                      0.23.0+cu129     pypi_0           pypi
[conda] triton                           3.4.0            pypi_0           pypi
