diff --git a/examples/playground/einsum.ipynb b/examples/playground/einsum.ipynb new file mode 100644 index 0000000..a265856 --- /dev/null +++ b/examples/playground/einsum.ipynb @@ -0,0 +1,80 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "51f8501a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from nkipy.core.trace import NKIPyKernel\n", + "from nkipy.core.compile import lower_to_nki\n", + "from nkipy.runtime.execute import simulate_traced_kernel, baremetal_run_traced_kernel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "657ca110", + "metadata": {}, + "outputs": [], + "source": [ + "def einsum_matmul(A, B):\n", + " return np.einsum('ik,kj->ij', A, B)\n", + "\n", + "A = np.random.rand(2, 3).astype(np.float32)\n", + "B = np.random.rand(3, 4).astype(np.float32)\n", + "\n", + "expected = einsum_matmul(A, B)\n", + "expected" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2a4b5a9", + "metadata": {}, + "outputs": [], + "source": [ + "traced_kernel = NKIPyKernel.trace(einsum_matmul)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9258c26", + "metadata": {}, + "outputs": [], + "source": [ + "out_nkipy = simulate_traced_kernel(traced_kernel, A, B)\n", + "print(\"Is the simulated output the same as Numpy? \", np.allclose(out_nkipy, expected))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0e487ac", + "metadata": {}, + "outputs": [], + "source": [ + "out_baremetal = baremetal_run_traced_kernel(traced_kernel, A, B)\n", + "print(\"Is the baremetal output the same as Numpy? 
\", np.allclose(out_baremetal, expected))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/playground/nkipy_einsum.py b/examples/playground/nkipy_einsum.py new file mode 100644 index 0000000..885555c --- /dev/null +++ b/examples/playground/nkipy_einsum.py @@ -0,0 +1,189 @@ +import numpy as np +from nkipy.core.trace import NKIPyKernel +from nkipy.runtime.execute import baremetal_run_traced_kernel, simulate_traced_kernel + +print("=" * 80) +print("EINSUM OPERATION TESTS") +print("=" * 80) + + +def run_test(test_func, *test_args): + """Helper to trace, simulate, and run on baremetal.""" + # Run numpy version to get expected output + expected = test_func(*test_args) + print(f"Input shapes: {[a.shape for a in test_args if hasattr(a, 'shape')]}") + if hasattr(expected, "shape"): + print(f"Output shape: {expected.shape}") + else: + print(f"Output: {expected}") + + traced_kernel = NKIPyKernel.trace(test_func) + + # Simulation + out_nkipy = simulate_traced_kernel(traced_kernel, *test_args) + sim_match = np.allclose(out_nkipy, expected) + print(f"Simulation matches NumPy? {sim_match}") + + # Baremetal + try: + out_baremetal = baremetal_run_traced_kernel(traced_kernel, *test_args) + bm_match = np.allclose(out_baremetal, expected) + print(f"Baremetal matches NumPy? {bm_match}") + except Exception as e: + print(f"Baremetal test skipped/failed: {type(e).__name__} - {e}") + + +# ============================================================================= +# 1. Matrix Multiplication +# ============================================================================= +print("\n1. 
Matrix Multiplication (ik,kj->ij)") +print("-" * 80) + + +def einsum_matmul(A, B): + """Standard matrix multiply: (i, k) x (k, j) -> (i, j)""" + return np.einsum("ik,kj->ij", A, B) + + +A = np.random.rand(2, 3).astype(np.float32) +B = np.random.rand(3, 4).astype(np.float32) +run_test(einsum_matmul, A, B) + + +# ============================================================================= +# 2. Batch Matrix Multiplication +# ============================================================================= +print("\n2. Batch Matrix Multiplication (bik,bkj->bij)") +print("-" * 80) + + +def einsum_batch_matmul(A, B): + """Batch matrix multiply: (batch, i, k) x (batch, k, j) -> (batch, i, j)""" + return np.einsum("bik,bkj->bij", A, B) + + +A = np.random.rand(5, 2, 3).astype(np.float32) +B = np.random.rand(5, 3, 4).astype(np.float32) +run_test(einsum_batch_matmul, A, B) + + +# ============================================================================= +# 3. Dot Product (Inner Product) +# ============================================================================= +print("\n3. Dot Product (i,i->)") +print("-" * 80) + + +def einsum_dot(a, b): + """Dot product of two vectors: sum(a * b)""" + return np.einsum("i,i->", a, b) + + +a = np.array([1, 2, 3], dtype=np.float32) +b = np.array([4, 5, 6], dtype=np.float32) +run_test(einsum_dot, a, b) + + +# ============================================================================= +# 4. Outer Product +# ============================================================================= +print("\n4. Outer Product (i,j->ij)") +print("-" * 80) + + +def einsum_outer(a, b): + """Outer product: (i,) x (j,) -> (i, j)""" + return np.einsum("i,j->ij", a, b) + + +a = np.array([1, 2, 3], dtype=np.float32) +b = np.array([4, 5], dtype=np.float32) +run_test(einsum_outer, a, b) + + +# ============================================================================= +# 5. 
Element-wise Multiply and Sum (Frobenius inner product) +# ============================================================================= +print("\n5. Element-wise Multiply and Sum (ij,ij->)") +print("-" * 80) + + +def einsum_hadamard_sum(A, B): + """Element-wise multiply then sum all: sum(A * B)""" + return np.einsum("ij,ij->", A, B) + + +A = np.array([[1, 2], [3, 4]], dtype=np.float32) +B = np.array([[5, 6], [7, 8]], dtype=np.float32) +run_test(einsum_hadamard_sum, A, B) + + +# ============================================================================= +# 6. Transpose +# ============================================================================= +print("\n6. Transpose (ij->ji)") +print("-" * 80) + + +def einsum_transpose(A): + """Matrix transpose: (i, j) -> (j, i)""" + return np.einsum("ij->ji", A) + + +A = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) +run_test(einsum_transpose, A) + + +# ============================================================================= +# 7. Sum Along Axis +# ============================================================================= +print("\n7. Sum Along Axis (ij->i)") +print("-" * 80) + + +def einsum_sum_axis(A): + """Sum along last axis: (i, j) -> (i,)""" + return np.einsum("ij->i", A) + + +A = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) +run_test(einsum_sum_axis, A) + + +# ============================================================================= +# 8. Batched Dot Product +# ============================================================================= +print("\n8. Batched Dot Product (bi,bi->b)") +print("-" * 80) + + +def einsum_batch_dot(A, B): + """Dot product for each pair in batch: (batch, i) x (batch, i) -> (batch,)""" + return np.einsum("bi,bi->b", A, B) + + +A = np.random.rand(5, 10).astype(np.float32) +B = np.random.rand(5, 10).astype(np.float32) +run_test(einsum_batch_dot, A, B) + + +# ============================================================================= +# 9. 
Tensor Contraction +# ============================================================================= +print("\n9. Tensor Contraction (ijk,jkl->il)") +print("-" * 80) + + +def einsum_tensor_contract(A, B): + """Contract on middle dimensions: (i,j,k) x (j,k,l) -> (i,l)""" + return np.einsum("ijk,jkl->il", A, B) + + +A = np.random.rand(2, 3, 4).astype(np.float32) +B = np.random.rand(3, 4, 5).astype(np.float32) +run_test(einsum_tensor_contract, A, B) + + +print("\n" + "=" * 80) +print("TESTS COMPLETE") +print("=" * 80) diff --git a/nkipy/src/nkipy/core/_numpy_dispatch.py b/nkipy/src/nkipy/core/_numpy_dispatch.py index e31ecb3..ee714d0 100644 --- a/nkipy/src/nkipy/core/_numpy_dispatch.py +++ b/nkipy/src/nkipy/core/_numpy_dispatch.py @@ -90,6 +90,9 @@ def register_all_numpy_apis(): # Linear algebra _register_numpy_api(np.matmul, ops.matmul) + # Einstein summation + _register_numpy_api(np.einsum, ops.einsum) + # Transform operations _register_numpy_api(np.reshape, ops.reshape) _register_numpy_api(np.transpose, ops.transpose) diff --git a/nkipy/src/nkipy/core/ops/__init__.py b/nkipy/src/nkipy/core/ops/__init__.py index 31f77d4..2802a5d 100644 --- a/nkipy/src/nkipy/core/ops/__init__.py +++ b/nkipy/src/nkipy/core/ops/__init__.py @@ -89,6 +89,11 @@ # ----------------------------------------------------------------------------- from nkipy.core.ops.linalg import matmul +# ----------------------------------------------------------------------------- +# Einstein summation +# ----------------------------------------------------------------------------- +from nkipy.core.ops.einsum import einsum + # ----------------------------------------------------------------------------- # Neural network operations # ----------------------------------------------------------------------------- @@ -208,6 +213,8 @@ "full_like", # Linalg "matmul", + # Einsum + "einsum", # Reduction "sum", "max", diff --git a/nkipy/src/nkipy/core/ops/einsum.py b/nkipy/src/nkipy/core/ops/einsum.py new file mode 
100644 index 0000000..42ad29f --- /dev/null +++ b/nkipy/src/nkipy/core/ops/einsum.py @@ -0,0 +1,580 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Einstein summation (einsum) operation. + +Provides a flexible interface for expressing tensor contractions using +Einstein notation, such as matrix multiplication, batch operations, +traces, and more. +""" + +from typing import Dict, List, Set, Tuple + +import numpy as np + +from nkipy.core.ops._registry import Op + +# ============================================================================= +# Einsum Subscript Parsing +# ============================================================================= + + +def parse_einsum_subscripts( + subscripts: str, num_operands: int +) -> Tuple[List[str], str]: + """Parse einsum subscript string into input and output specifications. + + Args: + subscripts: Einstein notation string (e.g., 'ij,jk->ik' or 'ij,jk') + num_operands: Number of input operands + + Returns: + Tuple of (input_specs, output_spec) where: + - input_specs: List of strings, one per operand (e.g., ['ij', 'jk']) + - output_spec: Output string (e.g., 'ik'), inferred if not provided + + Examples: + >>> parse_einsum_subscripts('ij,jk->ik', 2) + (['ij', 'jk'], 'ik') + + >>> parse_einsum_subscripts('ii', 1) # trace - inferred output is '' + (['ii'], '') + """ + # Remove whitespace + subscripts = subscripts.replace(" ", "") + + # Check for explicit output specification + if "->" in subscripts: + input_str, output_spec = subscripts.split("->") + input_specs = input_str.split(",") + else: + input_specs = subscripts.split(",") + # Infer output: all indices that appear exactly once across all inputs + all_indices: Dict[str, int] = {} + for spec in input_specs: + for idx in spec: + all_indices[idx] = all_indices.get(idx, 0) + 1 + + # Output contains indices that appear exactly once, in order of first appearance + # For implicit output with ..., we need to 
keep ... if present + output_spec = "" + seen: Set[str] = set() + + # Collect all indices that appear exactly once + unique_indices = sorted( + [idx for idx, count in all_indices.items() if count == 1 and idx != "."] + ) + + # In implicit mode, we preserve order of appearance + for spec in input_specs: + for idx in spec: + if idx in unique_indices and idx not in seen: + output_spec += idx + seen.add(idx) + + if len(input_specs) != num_operands: + raise ValueError( + f"Number of subscripts ({len(input_specs)}) does not match " + f"number of operands ({num_operands})" + ) + + return input_specs, output_spec + + +def analyze_einsum_pattern( + input_specs: List[str], output_spec: str, shapes: List[Tuple[int, ...]] +) -> Dict: + """Analyze einsum pattern to determine dimension mapping and operation type. + + Args: + input_specs: List of input subscript strings + output_spec: Output subscript string + shapes: List of input tensor shapes + + Returns: + Dictionary containing: + - 'input_dims': Dict mapping each input index to dimension size + - 'contracting_dims': Set of indices being contracted (summed over) + - 'batch_dims': Set of indices appearing in all inputs and output + - 'output_order': List of indices in output order + """ + # Build index -> dimension size mapping + input_dims: Dict[str, int] = {} + for spec, shape in zip(input_specs, shapes): + if len(spec) != len(shape): + raise ValueError( + f"Subscript '{spec}' has {len(spec)} indices but shape has " + f"{len(shape)} dimensions: {shape}" + ) + for idx, size in zip(spec, shape): + if idx in input_dims and input_dims[idx] != size: + raise ValueError( + f"Index '{idx}' has inconsistent dimensions: " + f"{input_dims[idx]} vs {size}" + ) + input_dims[idx] = size + + # Collect all unique indices + all_indices = set() + for spec in input_specs: + all_indices.update(spec) + + # Determine contracting dimensions (in inputs but not output) + contracting_dims = all_indices - set(output_spec) + + # Determine batch 
dimensions (in all inputs and output) + batch_dims = set(output_spec) + for spec in input_specs: + batch_dims &= set(spec) + + return { + "input_dims": input_dims, + "contracting_dims": contracting_dims, + "batch_dims": batch_dims, + "output_order": list(output_spec), + } + + +# ============================================================================= +# Einsum Operation +# ============================================================================= +einsum = Op("einsum") + + +@einsum.impl("hlo") +def _einsum_hlo(subscripts, *operands, dtype=None): + """Einstein summation convention on tensors (HLO implementation). + + Implements einsum using HLO operations: transpose, dot_general, reduce. + Supports common patterns like matrix multiplication, batch operations, + traces, outer products, and more. + + Args: + subscripts: Einstein notation string (e.g., 'ij,jk->ik') + *operands: Input tensors + dtype: Optional output dtype (if None, inferred from inputs) + + Returns: + Result tensor according to einsum specification + + Examples: + >>> # Matrix multiply + >>> einsum('ij,jk->ik', A, B) + + >>> # Batch matrix multiply + >>> einsum('bij,bjk->bik', A, B) + + >>> # Trace + >>> einsum('ii->', A) + + >>> # Outer product + >>> einsum('i,j->ij', a, b) + """ + from nkipy.core.backend.hlo import get_hlo_context + from nkipy.core.tensor import NKIPyTensorRef + + if not operands: + raise ValueError("einsum requires at least one operand") + + # Get shapes + shapes = [] + real_operands = [] + ctx = get_hlo_context() + + for op in operands: + if isinstance(op, NKIPyTensorRef): + real_operands.append(op) + shapes.append(op.backend_tensor.shape) + else: + # Assume it's an HLO tensor or similar + # Wrappping it might be needed if we call tensor ops? + # The original code handled wrapping later. + # We need shapes now for ellipsis expansion. 
+ real_operands.append(op) + shapes.append(op.shape) + + # Parse subscripts + input_specs, output_spec = parse_einsum_subscripts(subscripts, len(operands)) + + # Handle repeated indices (Diagonal/Trace) + # This might modify operands (insert diagonal ops) and specs + cleaned_input_specs = [] + processed_operands = [] + + for i, (spec, op) in enumerate(zip(input_specs, real_operands)): + # Check for repeated indices + if len(set(spec)) != len(spec): + new_op, new_spec = _handle_repeated_indices(ctx, op, spec) + processed_operands.append(new_op) + cleaned_input_specs.append(new_spec) + else: + processed_operands.append(op) + cleaned_input_specs.append(spec) + + input_specs = cleaned_input_specs + + # Refresh shapes after potential diagonal reductions + hlo_operands = [] + final_shapes = [] + for op in processed_operands: + if isinstance(op, NKIPyTensorRef): + hlo_operands.append(op.backend_tensor) + final_shapes.append(op.backend_tensor.shape) + else: + hlo_operands.append(op) + final_shapes.append(op.shape) + + # Analyze pattern + analysis = analyze_einsum_pattern(input_specs, output_spec, final_shapes) + + # Handle special cases for optimization + if len(hlo_operands) == 1: + return _einsum_unary( + ctx, hlo_operands[0], input_specs[0], output_spec, analysis + ) + elif len(hlo_operands) == 2: + return _einsum_binary( + ctx, + hlo_operands[0], + hlo_operands[1], + input_specs[0], + input_specs[1], + output_spec, + analysis, + ) + else: + # General case: reduce to binary operations + return _einsum_nary(ctx, hlo_operands, input_specs, output_spec, analysis) + + +def _handle_repeated_indices(ctx, operand, spec: str): + """Handle repeated indices in a single spec (e.g., 'ii') by taking diagonal.""" + import collections + + from nkipy.core.backend.hlo import as_hlo_tensor + from nkipy.core.tensor import NKIPyTensorRef + + current_operand = operand + if isinstance(current_operand, NKIPyTensorRef): + current_operand = current_operand.backend_tensor + current_spec = 
list(spec) + + while True: + counts = collections.Counter(current_spec) + repeated = [char for char, count in counts.items() if count > 1] + + if not repeated: + break + + # Handle first repeated index + idx = repeated[0] + # Find first two positions + positions = [i for i, char in enumerate(current_spec) if char == idx] + pos1, pos2 = positions[0], positions[1] + + # Verify dimensions + shape = current_operand.shape + if shape[pos1] != shape[pos2]: + raise ValueError( + f"Repeated index {idx} has incompatible dimensions {shape[pos1]} and {shape[pos2]}" + ) + + dim_size = shape[pos1] + + # Move pos1 and pos2 to the end + # Permutation: All other indices + pos1 + pos2 + other_indices = [i for i in range(len(shape)) if i != pos1 and i != pos2] + perm = other_indices + [pos1, pos2] + + current_operand = ctx.build_op( + "transpose", + [current_operand], + tuple(shape[i] for i in perm), + current_operand.dtype, + {"permutation": perm}, + ) + + # Now shape is (..., N, N) + # Create Identity Mask (N, N) + # iota dimension 0 + iota0 = ctx.build_op( + "iota", [], (dim_size, dim_size), "int32", {"iota_dimension": 0} + ) + # iota dimension 1 + iota1 = ctx.build_op( + "iota", [], (dim_size, dim_size), "int32", {"iota_dimension": 1} + ) + + # Mask = (iota0 == iota1) + pred = ctx.build_op( + "compare", + [iota0, iota1], + (dim_size, dim_size), + np.bool_, + {"comparison_direction": "EQ"}, + ) + + # Convert to dtype + mask = ctx.build_op( + "convert", [pred], (dim_size, dim_size), current_operand.dtype, {} + ) + + # Broadcast mask to matches current_operand magnitude + # Mask has shape (N, N). Operand has (..., N, N). + # We broadcast mask to operands shape. + # Dimensions to broadcast are the '...' ones (0 to len-3). + # We map the mask dimensions [0, 1] to Result dimensions [rank-2, rank-1]. 
+ + rank = len(current_operand.shape) + mask_broadcast = ctx.build_op( + "broadcast", + [mask], + current_operand.shape, + current_operand.dtype, + {"broadcast_dimensions": [rank - 2, rank - 1]}, + ) + + # Multiply + masked_op = ctx.build_op( + "multiply", + [current_operand, mask_broadcast], + current_operand.shape, + current_operand.dtype, + ) + + # Reduce sum over the last dimension (pos2) - which is now at rank-1 + # Reduce dims: [rank-1] + # Init value for add: 0.0 + init_val = as_hlo_tensor(ctx, 0.0, current_operand.dtype) + + reduced_shape = current_operand.shape[:-1] + current_operand = ctx.build_op( + "reduce", + [masked_op, init_val], + reduced_shape, + current_operand.dtype, + {"dimensions": [rank - 1], "computation": "add"}, + ) + + # Update spec + # We removed the char at pos2 (which was moved to end). + # The char at pos1 (which was moved to rank-2) is now at rank-1 (end). + # The other chars are at 0 ... rank-2. + # So new spec order is: [others] + [idx]. + + new_spec_list = [current_spec[i] for i in other_indices] + [idx] + current_spec = new_spec_list + + return current_operand, "".join(current_spec) + + +def _einsum_unary(ctx, operand, input_spec, output_spec, analysis): + """Handle single-operand einsum (transpose, trace, reduction).""" + from nkipy.core.backend.hlo import as_hlo_tensor + from nkipy.core.tensor import NKIPyTensorRef + + # If output is empty, it's a full reduction + if not output_spec: + # Reduce all dimensions + init_tensor = as_hlo_tensor(ctx, 0.0, operand.dtype) + result = ctx.build_op( + "reduce", + [operand, init_tensor], + (), # scalar output + operand.dtype, + { + "dimensions": list(range(len(operand.shape))), + "computation": "add", + }, + ) + return NKIPyTensorRef(result) + + # Use analysis to determine which dimensions to reduce + # Contracting dims are those in input but not in output + dims_to_reduce = [] + output_dims = [] + + for i, idx in enumerate(input_spec): + if idx in analysis["contracting_dims"]: + 
dims_to_reduce.append(i) + else: + output_dims.append((idx, i, analysis["input_dims"][idx])) + + # Sort output dimensions by their order in output_spec + # output_dims contains (idx, original_pos, size) + # The 'operand' tensor currently has these dimensions in the order they appeared in input_spec (minus reduced ones). + # XLA Reduce preserves relative order of remaining dimensions. + + current_indices = [idx for idx, _, _ in output_dims] + + # If there are dimensions to reduce + if dims_to_reduce: + # The shape expected by reduce op is the shape of the RESULT of reduction? + # Or the shape of the operands? + # Usually XLA build_op('reduce') might take output shape? + # If so, it should match the input-ordered result (since no transpose happens during reduce). + reduced_shape = tuple(size for _, _, size in output_dims) + + init_tensor = as_hlo_tensor(ctx, 0.0, operand.dtype) + operand = ctx.build_op( + "reduce", + [operand, init_tensor], + reduced_shape, + operand.dtype, + { + "dimensions": dims_to_reduce, + "computation": "add", + }, + ) + + # Check if we need to transpose to match output spec + if current_indices != list(output_spec): + # Build permutation + # We want output to be output_spec. + # Current tensor has dims in `current_indices` order. + # We need to mapp `current_indices` -> `output_spec`. + # Transpose perm[i] is the index in input that maps to output[i]. 
+ + try: + perm = [current_indices.index(idx) for idx in output_spec] + except ValueError as e: + # Should not happen if analysis is correct + raise RuntimeError( + f"Internal einsum error: indices mismatch {current_indices} vs {output_spec}" + ) from e + + transposed_shape = tuple(operand.shape[i] for i in perm) + operand = ctx.build_op( + "transpose", + [operand], + transposed_shape, + operand.dtype, + {"permutation": perm}, + ) + + return NKIPyTensorRef(operand) + + +def _einsum_binary(ctx, lhs, rhs, lhs_spec, rhs_spec, output_spec, analysis): + """Handle two-operand einsum (matmul, outer product, etc.).""" + from nkipy.core.tensor import NKIPyTensorRef + + # Find contracting, batch, and free dimensions + lhs_indices = list(lhs_spec) + rhs_indices = list(rhs_spec) + + contracting_dims = analysis["contracting_dims"] + + # Identify dimension roles for each operand + lhs_contracting = [ + i for i, idx in enumerate(lhs_indices) if idx in contracting_dims + ] + rhs_contracting = [ + i for i, idx in enumerate(rhs_indices) if idx in contracting_dims + ] + + lhs_batch = [ + i for i, idx in enumerate(lhs_indices) if idx in analysis["batch_dims"] + ] + rhs_batch = [ + i for i, idx in enumerate(rhs_indices) if idx in analysis["batch_dims"] + ] + + # Compute output shape + output_shape = tuple(analysis["input_dims"][idx] for idx in output_spec) + + # Use dot_general for contraction + if contracting_dims: + result = ctx.build_op( + "dot", + [lhs, rhs], + output_shape, + lhs.dtype, + { + "lhs_contracting_dimensions": lhs_contracting, + "rhs_contracting_dimensions": rhs_contracting, + "lhs_batch_dimensions": lhs_batch, + "rhs_batch_dimensions": rhs_batch, + }, + ) + else: + # No contraction - it's an outer product or broadcast multiply + # Reshape both operands to have compatible shapes, then multiply + # For now, use broadcasting via reshape + multiply + + # Determine the position of each operand's dimensions in output + lhs_out_positions = [output_spec.index(idx) for idx in 
lhs_indices] + rhs_out_positions = [output_spec.index(idx) for idx in rhs_indices] + + # Broadcast lhs to output shape + lhs_broadcasted = ctx.build_op( + "broadcast", + [lhs], + output_shape, + lhs.dtype, + {"broadcast_dimensions": lhs_out_positions}, + ) + + # Broadcast rhs to output shape + rhs_broadcasted = ctx.build_op( + "broadcast", + [rhs], + output_shape, + rhs.dtype, + {"broadcast_dimensions": rhs_out_positions}, + ) + + # Multiply + result = ctx.build_op( + "multiply", [lhs_broadcasted, rhs_broadcasted], output_shape, lhs.dtype + ) + + return NKIPyTensorRef(result) + + +def _einsum_nary(ctx, operands, input_specs, output_spec, analysis): + """Handle n-ary einsum by reducing to binary operations.""" + # Chain binary operations left-to-right + result = operands[0] + current_spec = input_specs[0] + + for i in range(1, len(operands)): + # Determine intermediate output spec (union of remaining indices) + remaining_specs = input_specs[i:] + remaining_indices = set(output_spec) + for spec in remaining_specs: + remaining_indices.update(spec) + + # Build intermediate spec in canonical order + intermediate_spec = "".join( + idx + for idx in current_spec + input_specs[i] + if idx in remaining_indices + and idx + not in "".join( + idx for idx in current_spec + input_specs[i] if idx in remaining_indices + )[: current_spec.index(idx) if idx in current_spec else len(current_spec)] + ) + + # Perform binary einsum + shapes = [result.shape, operands[i].shape] + sub_analysis = analyze_einsum_pattern( + [current_spec, input_specs[i]], + intermediate_spec if i < len(operands) - 1 else output_spec, + shapes, + ) + + result_ref = _einsum_binary( + ctx, + result, + operands[i], + current_spec, + input_specs[i], + intermediate_spec if i < len(operands) - 1 else output_spec, + sub_analysis, + ) + result = result_ref.backend_tensor + current_spec = intermediate_spec if i < len(operands) - 1 else output_spec + + from nkipy.core.tensor import NKIPyTensorRef + + return 
NKIPyTensorRef(result) diff --git a/tests/kernels/einsum.py b/tests/kernels/einsum.py new file mode 100644 index 0000000..95afa51 --- /dev/null +++ b/tests/kernels/einsum.py @@ -0,0 +1,193 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Einstein summation (einsum) kernel specifications. + +This module provides kernel specifications for testing various einsum patterns: +- Matrix multiplication +- Batch operations +- Reductions +- Transposes +- Outer products +""" + +import numpy as np +from nkipy.core.specs import CommonTypes, KernelSpec, ShapeSpec, TensorInputSpec + +# ============================================================================= +# Matrix Operations +# ============================================================================= + + +def matmul_einsum(A, B): + """Matrix multiplication using einsum: ij,jk->ik""" + return np.einsum("ij,jk->ik", A, B) + + +def batch_matmul_einsum(A, B): + """Batch matrix multiplication using einsum: bij,bjk->bik""" + return np.einsum("bij,bjk->bik", A, B) + + +# ============================================================================= +# Transpose and Permutation +# ============================================================================= + + +def transpose_einsum(A): + """Transpose using einsum: ij->ji""" + return np.einsum("ij->ji", A) + + +def permute_dims_einsum(A): + """Permute dimensions using einsum: ijk->kij""" + return np.einsum("ijk->kij", A) + + +# ============================================================================= +# Reductions +# ============================================================================= + + +def sum_axis_einsum(A): + """Sum along axis using einsum: ij->i""" + return np.einsum("ij->i", A) + + +# ============================================================================= +# Outer Products +# ============================================================================= + + +def 
outer_product_einsum(a, b): + """Outer product using einsum: i,j->ij""" + return np.einsum("i,j->ij", a, b) + + +# ============================================================================= +# Advanced Patterns +# ============================================================================= + + +def dot_product_einsum(a, b): + """Dot product using einsum: i,i->""" + return np.einsum("i,i->", a, b) + + +# ============================================================================= +# Kernel Specifications +# ============================================================================= + +kernel_specs = [ + # Matrix multiplication + KernelSpec( + function=matmul_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None], default=(32, 64)), + dtype_spec=CommonTypes.FLOATS, + description="First matrix (M, K)", + ), + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None], default=(64, 48)), + dtype_spec=CommonTypes.FLOATS, + description="Second matrix (K, N)", + ), + ], + is_pure_numpy=True, + description="Matrix multiplication via einsum (ij,jk->ik)", + ), + # Batch matrix multiplication + KernelSpec( + function=batch_matmul_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None, None], default=(4, 32, 64)), + dtype_spec=CommonTypes.FLOATS, + description="First batched matrix (B, M, K)", + ), + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None, None], default=(4, 64, 48)), + dtype_spec=CommonTypes.FLOATS, + description="Second batched matrix (B, K, N)", + ), + ], + is_pure_numpy=True, + description="Batch matrix multiplication via einsum (bij,bjk->bik)", + ), + # Transpose + KernelSpec( + function=transpose_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None], default=(32, 64)), + dtype_spec=CommonTypes.FLOATS, + description="Matrix to transpose", + ), + ], + is_pure_numpy=True, + description="2D transpose via einsum (ij->ji)", + ), + # Permute dimensions + KernelSpec( + 
function=permute_dims_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None, None], default=(4, 32, 64)), + dtype_spec=CommonTypes.FLOATS, + description="3D tensor to permute", + ), + ], + is_pure_numpy=True, + description="3D permutation via einsum (ijk->kij)", + ), + # Sum along axis + KernelSpec( + function=sum_axis_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None, None], default=(32, 64)), + dtype_spec=CommonTypes.FLOATS, + description="Matrix to reduce", + ), + ], + is_pure_numpy=True, + description="Sum along last axis via einsum (ij->i)", + ), + # Outer product + KernelSpec( + function=outer_product_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None], default=(32,)), + dtype_spec=CommonTypes.FLOATS, + description="First vector", + ), + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None], default=(64,)), + dtype_spec=CommonTypes.FLOATS, + description="Second vector", + ), + ], + is_pure_numpy=True, + description="Outer product via einsum (i,j->ij)", + ), + # Dot product + KernelSpec( + function=dot_product_einsum, + inputs=[ + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None], default=(128,)), + dtype_spec=CommonTypes.FLOATS, + description="First vector", + ), + TensorInputSpec( + shape_spec=ShapeSpec(dims=[None], default=(128,)), + dtype_spec=CommonTypes.FLOATS, + description="Second vector", + ), + ], + is_pure_numpy=True, + description="Dot product via einsum (i,i->)", + ), +] diff --git a/tests/unit/test_tensor_api.py b/tests/unit/test_tensor_api.py index 06b3c6b..d0fb7e1 100644 --- a/tests/unit/test_tensor_api.py +++ b/tests/unit/test_tensor_api.py @@ -1735,5 +1735,332 @@ def kernel_with_constant_one(x): ) +@pytest.mark.parametrize( + "shape_a,shape_b", + [ + ((32, 64), (64, 48)), + ((128, 256), (256, 512)), + ((256, 512), (512, 1024)), + ((1, 128), (128, 256)), + ((256, 128), (128, 1)), + ], +) +def test_einsum_matmul(sim_mode, shape_a, shape_b): + """Test einsum matrix 
@pytest.mark.parametrize(
    "shape_a,shape_b",
    [
        ((32, 64), (64, 48)),
        ((128, 256), (256, 512)),
        ((256, 512), (512, 1024)),
        ((1, 128), (128, 256)),
        ((256, 128), (128, 1)),
    ],
)
def test_einsum_matmul(sim_mode, shape_a, shape_b):
    """Einsum matrix multiplication (ij,jk->ik) matches NumPy in sim and on HW."""

    def einsum_fn(a, b):
        return np.einsum("ij,jk->ik", a, b)

    # Fixed seed so simulation and baremetal runs see identical data.
    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=shape_a).astype(np.float32)
    rhs = np.random.uniform(low=0.0, high=1.0, size=shape_b).astype(np.float32)
    expected = np.einsum("ij,jk->ik", lhs, rhs)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out)


@pytest.mark.parametrize(
    "shape_a,shape_b",
    [
        ((2, 32, 64), (2, 64, 48)),
        ((4, 128, 256), (4, 256, 128)),
        ((8, 64, 128), (8, 128, 64)),
        ((1, 128, 256), (1, 256, 128)),
    ],
)
def test_einsum_batch_matmul(sim_mode, shape_a, shape_b):
    """Einsum batch matrix multiplication (bij,bjk->bik) matches NumPy."""

    def einsum_fn(a, b):
        return np.einsum("bij,bjk->bik", a, b)

    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=shape_a).astype(np.float32)
    rhs = np.random.uniform(low=0.0, high=1.0, size=shape_b).astype(np.float32)
    expected = np.einsum("bij,bjk->bik", lhs, rhs)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out)


@pytest.mark.parametrize(
    "shape_a,shape_b",
    [
        ((128,), (128,)),
        ((256,), (256,)),
        ((512,), (512,)),
        ((1024,), (1024,)),
    ],
)
def test_einsum_dot_product(sim_mode, shape_a, shape_b):
    """Einsum dot product (i,i->) matches NumPy in sim and on hardware."""

    def einsum_fn(a, b):
        return np.einsum("i,i->", a, b)

    np.random.seed(0)
    vec_a = np.random.uniform(low=0.0, high=1.0, size=shape_a).astype(np.float32)
    vec_b = np.random.uniform(low=0.0, high=1.0, size=shape_b).astype(np.float32)
    expected = np.einsum("i,i->", vec_a, vec_b)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, vec_a, vec_b), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, vec_a, vec_b)
        baremetal_assert_allclose(expected, hw_out)
@pytest.mark.parametrize(
    "shape_a,shape_b",
    [
        ((32,), (64,)),
        ((64,), (128,)),
        ((128,), (256,)),
        ((16,), (32,)),
    ],
)
def test_einsum_outer_product(sim_mode, shape_a, shape_b):
    """Einsum outer product (i,j->ij) matches NumPy in sim and on hardware."""

    def einsum_fn(a, b):
        return np.einsum("i,j->ij", a, b)

    # Fixed seed so simulation and baremetal runs see identical data.
    np.random.seed(0)
    vec_a = np.random.uniform(low=0.0, high=1.0, size=shape_a).astype(np.float32)
    vec_b = np.random.uniform(low=0.0, high=1.0, size=shape_b).astype(np.float32)
    expected = np.einsum("i,j->ij", vec_a, vec_b)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, vec_a, vec_b), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, vec_a, vec_b)
        baremetal_assert_allclose(expected, hw_out)


@pytest.mark.parametrize(
    "shape",
    [
        (32, 64),
        (128, 256),
        (256, 512),
        (512, 1024),
    ],
)
def test_einsum_transpose(sim_mode, shape):
    """Einsum transpose (ij->ji) matches NumPy in sim and on hardware."""

    def einsum_fn(a):
        return np.einsum("ij->ji", a)

    np.random.seed(0)
    mat = np.random.uniform(low=0.0, high=1.0, size=shape).astype(np.float32)
    expected = np.einsum("ij->ji", mat)

    simulate_assert_allclose(simulate_kernel_unified(einsum_fn, sim_mode, mat), expected)

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, mat)
        baremetal_assert_allclose(expected, hw_out)


@pytest.mark.parametrize(
    "shape",
    [
        (32, 64),
        (128, 256),
        (256, 512),
    ],
)
def test_einsum_sum_axis(sim_mode, shape):
    """Einsum reduction along the last axis (ij->i) matches NumPy."""

    def einsum_fn(a):
        return np.einsum("ij->i", a)

    np.random.seed(0)
    mat = np.random.uniform(low=0.0, high=1.0, size=shape).astype(np.float32)
    expected = np.einsum("ij->i", mat)

    simulate_assert_allclose(simulate_kernel_unified(einsum_fn, sim_mode, mat), expected)

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, mat)
        baremetal_assert_allclose(expected, hw_out)
@pytest.mark.parametrize(
    "shape_a,shape_b",
    [
        ((32, 64), (32, 64)),
        ((128, 256), (128, 256)),
        ((256, 512), (256, 512)),
    ],
)
def test_einsum_hadamard_sum(sim_mode, shape_a, shape_b):
    """Einsum element-wise multiply + full reduction (ij,ij->) matches NumPy."""

    def einsum_fn(a, b):
        return np.einsum("ij,ij->", a, b)

    # Fixed seed so simulation and baremetal runs see identical data.
    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=shape_a).astype(np.float32)
    rhs = np.random.uniform(low=0.0, high=1.0, size=shape_b).astype(np.float32)
    expected = np.einsum("ij,ij->", lhs, rhs)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out)


@pytest.mark.parametrize(
    "shape_a,shape_b",
    [
        ((4, 128), (4, 128)),
        ((8, 256), (8, 256)),
        ((16, 512), (16, 512)),
    ],
)
def test_einsum_batch_dot(sim_mode, shape_a, shape_b):
    """Einsum batched dot product (bi,bi->b) matches NumPy."""

    def einsum_fn(a, b):
        return np.einsum("bi,bi->b", a, b)

    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=shape_a).astype(np.float32)
    rhs = np.random.uniform(low=0.0, high=1.0, size=shape_b).astype(np.float32)
    expected = np.einsum("bi,bi->b", lhs, rhs)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out)
@pytest.mark.parametrize(
    "shape",
    [
        (2, 3, 4),
        (4, 8, 16),
        (8, 16, 32),
    ],
)
def test_einsum_permute_3d(sim_mode, shape):
    """Einsum 3D axis permutation (ijk->kij) matches NumPy."""

    def einsum_fn(a):
        return np.einsum("ijk->kij", a)

    # Fixed seed so simulation and baremetal runs see identical data.
    np.random.seed(0)
    tensor = np.random.uniform(low=0.0, high=1.0, size=shape).astype(np.float32)
    expected = np.einsum("ijk->kij", tensor)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, tensor), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, tensor)
        baremetal_assert_allclose(expected, hw_out)


def test_einsum_implicit_output(sim_mode):
    """Einsum with implicit output subscripts ("ik,kj" implies -> "ij")."""

    def einsum_fn(a, b):
        return np.einsum("ik,kj", a, b)  # Implicit: -> 'ij'

    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=(32, 64)).astype(np.float32)
    rhs = np.random.uniform(low=0.0, high=1.0, size=(64, 48)).astype(np.float32)
    expected = np.einsum("ik,kj", lhs, rhs)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out)


@pytest.mark.parametrize("dtype", [np.float32, np.float16])
def test_einsum_dtypes(sim_mode, dtype):
    """Einsum matmul across supported float dtypes, with relaxed fp16 tolerance."""

    def einsum_fn(a, b):
        return np.einsum("ij,jk->ik", a, b)

    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=(64, 128)).astype(dtype)
    rhs = np.random.uniform(low=0.0, high=1.0, size=(128, 64)).astype(dtype)
    expected = np.einsum("ij,jk->ik", lhs, rhs)

    # float16 accumulates noticeably more rounding error than float32.
    rtol = 1e-2 if dtype == np.float16 else 1e-5

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected, rtol=rtol
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out, rtol=rtol)
def test_einsum_tensor_contraction(sim_mode):
    """Einsum double-index tensor contraction (ijk,jkl->il) matches NumPy."""

    def einsum_fn(a, b):
        return np.einsum("ijk,jkl->il", a, b)

    # Fixed seed so simulation and baremetal runs see identical data.
    np.random.seed(0)
    lhs = np.random.uniform(low=0.0, high=1.0, size=(2, 3, 4)).astype(np.float32)
    rhs = np.random.uniform(low=0.0, high=1.0, size=(3, 4, 5)).astype(np.float32)
    expected = np.einsum("ijk,jkl->il", lhs, rhs)

    simulate_assert_allclose(
        simulate_kernel_unified(einsum_fn, sim_mode, lhs, rhs), expected
    )

    if NEURON_AVAILABLE:
        hw_out = baremetal_run_kernel_unified(einsum_fn, sim_mode, lhs, rhs)
        baremetal_assert_allclose(expected, hw_out)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])