diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs index df68be1bc1..d06ce3d5d1 100644 --- a/qdp/qdp-python/src/lib.rs +++ b/qdp/qdp-python/src/lib.rs @@ -139,6 +139,37 @@ impl Drop for QuantumTensor { unsafe impl Send for QuantumTensor {} unsafe impl Sync for QuantumTensor {} +/// Helper to detect PyTorch tensor +fn is_pytorch_tensor(obj: &Bound<'_, PyAny>) -> PyResult { + let type_obj = obj.get_type(); + let name = type_obj.name()?; + if name != "Tensor" { + return Ok(false); + } + let module = type_obj.module()?; + let module_name = module.to_str()?; + Ok(module_name == "torch") +} + +/// Helper to validate tensor +fn validate_tensor(tensor: &Bound<'_, PyAny>) -> PyResult<()> { + if !is_pytorch_tensor(tensor)? { + return Err(PyRuntimeError::new_err("Object is not a PyTorch Tensor")); + } + + let device = tensor.getattr("device")?; + let device_type: String = device.getattr("type")?.extract()?; + + if device_type != "cpu" { + return Err(PyRuntimeError::new_err(format!( + "Only CPU tensors are currently supported for this path. Got device: {}", + device_type + ))); + } + + Ok(()) +} + /// PyO3 wrapper for QdpEngine /// /// Provides Python bindings for GPU-accelerated quantum state encoding. @@ -215,6 +246,42 @@ impl QdpEngine { }) } + /// Encode from PyTorch Tensor + /// + /// Args: + /// tensor: PyTorch Tensor (must be on CPU) + /// num_qubits: Number of qubits for encoding + /// encoding_method: Encoding strategy + /// + /// Returns: + /// QuantumTensor: DLPack-compatible tensor + fn encode_tensor( + &self, + tensor: &Bound<'_, PyAny>, + num_qubits: usize, + encoding_method: &str, + ) -> PyResult { + validate_tensor(tensor)?; + + // NOTE(perf): `tolist()` + `extract()` makes extra copies (Tensor -> Python list -> Vec). + // TODO: follow-up PR can use `numpy()`/buffer protocol (and possibly pinned host memory) + // to reduce copy overhead. + let data: Vec = tensor + .call_method0("flatten")? + .call_method0("tolist")? + .extract()?; + + let ptr = self + .engine + .encode(&data, num_qubits, encoding_method) + .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed: {}", e)))?; + + Ok(QuantumTensor { + ptr, + consumed: false, + }) + } + /// Encode from Parquet file /// /// Args: diff --git a/qdp/qdp-python/tests/test_bindings.py b/qdp/qdp-python/tests/test_bindings.py index 7808abc8c9..ea23aceb76 100644 --- a/qdp/qdp-python/tests/test_bindings.py +++ b/qdp/qdp-python/tests/test_bindings.py @@ -77,9 +77,10 @@ def test_dlpack_device_id_non_zero(): qtensor = engine.encode(data, 2, "amplitude") device_info = qtensor.__dlpack_device__() - assert device_info == (2, device_id), ( - f"Expected (2, {device_id}) for CUDA device {device_id}" - ) + assert device_info == ( + 2, + device_id, + ), f"Expected (2, {device_id}) for CUDA device {device_id}" # Verify PyTorch integration works with non-zero device_id torch_tensor = torch.from_dlpack(qtensor) @@ -143,3 +144,48 @@ def test_pytorch_precision_float64(): torch_tensor = torch.from_dlpack(qtensor) assert torch_tensor.dtype == torch.complex128 + + +@pytest.mark.gpu +def test_encode_tensor_cpu(): + """Test encoding from CPU PyTorch tensor.""" + pytest.importorskip("torch") + import torch + from mahout_qdp import QdpEngine + + if not torch.cuda.is_available(): + pytest.skip("GPU required for QdpEngine") + + engine = QdpEngine(0) + data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) + qtensor = engine.encode_tensor(data, 2, "amplitude") + + # Verify result + torch_tensor = torch.from_dlpack(qtensor) + assert torch_tensor.is_cuda + assert torch_tensor.shape == (1, 4) + + +@pytest.mark.gpu +def test_encode_tensor_errors(): + """Test error handling for encode_tensor.""" + pytest.importorskip("torch") + import torch + from mahout_qdp import QdpEngine + + if not torch.cuda.is_available(): + pytest.skip("GPU required for QdpEngine") + + engine = QdpEngine(0) + + # Test non-tensor input + with pytest.raises(RuntimeError, match="Object is not a PyTorch Tensor"): + engine.encode_tensor([1.0, 2.0], 1, "amplitude") + + # Test GPU tensor input (should fail as only CPU is supported for this path) + if torch.cuda.is_available(): + gpu_tensor = torch.tensor([1.0, 2.0], device="cuda:0") + with pytest.raises( + RuntimeError, match="Only CPU tensors are currently supported" + ): + engine.encode_tensor(gpu_tensor, 1, "amplitude")