Changes from all commits
134 commits
426e89d
Created branch, added codeowners
Jan 30, 2025
10d7bb3
Initial migration from internal repo (#5)
STFleming Jan 30, 2025
5ba98f9
BERT builder flow arguments for fifosim n_inferences (#6)
STFleming Feb 17, 2025
ccd023b
[SoftMax] New Improved SoftMax (#11)
STFleming Mar 7, 2025
ad639bc
[BugFix] Issues with incorrect configuration of SIMD for ShuffleB nod…
STFleming Mar 7, 2025
4afffcd
Adding cycle testing to custom op test scripts (#7)
penrosed Mar 7, 2025
fcd7bc3
Added a custom step that extracts metadata for the shell integration …
STFleming Mar 13, 2025
fab2842
[TinyBERT] Removing accidentally included start_step in the endtoend …
STFleming Mar 18, 2025
d7fb002
Removing rtlsim_backend after pyverilator deprecation (#16)
STFleming Mar 19, 2025
0c72dda
Name stylize BrainSmith --> Brainsmith (#17)
tafk7 Mar 20, 2025
dbfbe67
[TinyBERT] Add ref IO to stitched_ip as part of metadata handover (#18)
STFleming Mar 27, 2025
3d8aac0
[Testing] Created OpTest class for abstracting CustomOp tests (#19)
penrosed Apr 3, 2025
99e2aa2
Initial repository structure (#20)
tafk7 Apr 11, 2025
15fb647
Add Custom ONNXSCRIPT repository to BrainSmith (#21)
jsmonson Apr 17, 2025
752bd39
Revert "Add Custom ONNXSCRIPT repository to BrainSmith (#21)" (#22)
tafk7 Apr 18, 2025
17fc5ca
[CustomOps] Update brainsmith custom ops with changes on finn side (#25)
auphelia Apr 28, 2025
3dcfe0b
Initial continuous integration tests (#24)
tafk7 Apr 29, 2025
ff45805
Revert onnxscript add Revert (#26)
jsmonson May 2, 2025
ec39f0d
Fix Dynamic Matmul Initial Config For BERT-Large (#28)
jsmonson May 2, 2025
fc73217
fix argparse arg that could never be false (#30)
jsmonson May 27, 2025
4530385
Patch Pull Request #30: Update args variable to match new argument na…
jsmonson May 28, 2025
7a410b2
update pytorch to 2.7 (#34)
jsmonson May 29, 2025
30e48ad
[Hotfix] Cleanup CI runner artifacts (#33)
tafk7 May 29, 2025
40abeed
update brevitas commit hash (#36)
jsmonson May 30, 2025
937a639
Set onnxscript to a fixed commit id (#37)
jsmonson Jun 2, 2025
84b5301
Hardware Kernel Generator: RTL Parser & wrapper generation (#32)
tafk7 Jun 4, 2025
3b50184
Add BERT-Large CI Test (#40)
jsmonson Jun 11, 2025
9dc7ae8
Docker workflow modernization (#38)
tafk7 Jun 17, 2025
d14fbba
Update FINN (#41)
auphelia Jun 25, 2025
56b90fd
Core DSE & Plugin Library (#44)
tafk7 Aug 1, 2025
1ebf35c
[Deps] Update and fix finn and qonnx deps
auphelia Aug 27, 2025
d3e9f4d
Merge remote-tracking branch 'origin/main' into develop-branch-update
Aug 29, 2025
dedc4cc
Merge branch 'dev/update_finn_qonnx' of github.com:microsoft/Brainsmi…
Aug 29, 2025
eee7ee0
Merge branch 'main' of github.com:microsoft/Brainsmith into develop
Aug 29, 2025
b2e4c44
[Deps] Update and fix finn and qonnx deps (#54)
jsmonson Aug 29, 2025
a41bf20
Hotfix: Update and freeze FINN dependency #2 (#55)
jsmonson Aug 29, 2025
87493eb
Kernel Integrator (#48)
tafk7 Sep 1, 2025
f407e99
update transformers, add onnxscript, update brevitas
Sep 12, 2025
ce6614e
update transformers, add onnxscript, update brevitas
Sep 12, 2025
c75bd5a
add loop rolling step
Sep 12, 2025
5c5c7f5
add loop rolling step
Sep 12, 2025
4bc48b4
add bert dynamo export
Sep 12, 2025
a27278a
add bert dynamo export
Sep 12, 2025
db18921
preserve metadata through simplify operation
Sep 12, 2025
ab375b6
preserve metadata through simplify operation
Sep 12, 2025
f741dfc
add bert mlo demo
Sep 12, 2025
9c7a702
add bert mlo demo
Sep 12, 2025
85fb1a2
reinster from white space
Sep 12, 2025
4030c96
reinster from white space
Sep 12, 2025
9cecf0d
update to onnxscript 0.5.0
Sep 15, 2025
812f9cb
update to onnxscript 0.5.0
Sep 15, 2025
ce1c84a
remove custom onnxscript repo
Sep 15, 2025
3ebe70f
remove custom onnxscript repo
Sep 15, 2025
2274c77
update additional onnx script location
Sep 15, 2025
1960795
update additional onnx script location
Sep 15, 2025
671ee7a
added split large fifo option for MLO
Sep 15, 2025
b6621cd
added split large fifo option for MLO
Sep 15, 2025
216cbb0
Merge pull request #58 from microsoft/develop-branch-update
tafk7 Sep 19, 2025
729d969
Integration tests & updated docs (#59)
tafk7 Sep 22, 2025
66fa33a
Merge branch 'develop' of github.com:microsoft/Brainsmith into dev/jo…
Sep 24, 2025
b47d7a7
Merge branch 'develop' of github.com:microsoft/Brainsmith into dev/jo…
Sep 24, 2025
cd4f3d5
align with develop branch
Sep 24, 2025
8481ba2
align with develop branch
Sep 24, 2025
b211b3a
update bash script to match bert_demo.py args
Sep 26, 2025
22b736c
initial testing of a trained single layer BERT model being passed thr…
STFleming Sep 29, 2025
7d652e3
Changing the input datatype to match the new input
STFleming Sep 30, 2025
9634b70
Collapsing some of the additional mul nodes (thanks @auphelia)
STFleming Sep 30, 2025
466eafc
Small changes to work with the indices from the crop node that gets g…
STFleming Sep 30, 2025
96adde7
Getting up initial training script for pipecleaner model
STFleming Oct 2, 2025
9e35722
Merge remote-tracking branch 'origin/develop' into dev/joshmonson/add…
Oct 2, 2025
f87cedb
Merge branch 'dev/joshmonson/add-loop-rolling' of github.com:microsof…
Oct 2, 2025
a8d2ed1
return missing lines
Oct 2, 2025
cac2921
Merge branch 'develop' of github.com:microsoft/Brainsmith into dev/jo…
Oct 2, 2025
64d2cf3
Add Round and Clip Thresholds Step for MLO
Oct 2, 2025
aef87a1
Fixing commit for dynamo export thanks @auphelia
STFleming Oct 3, 2025
a32d7fc
Making Quantization work inside the Brainsmith container
STFleming Oct 3, 2025
db5ee89
FIFO configuration for a single layer for faster deployment
STFleming Oct 3, 2025
995e724
Adding some precalculated FIFO depths
STFleming Oct 3, 2025
d5726d4
Pointing at latest changes from @auphelia
STFleming Oct 3, 2025
51120ab
Produce a dcp
STFleming Oct 3, 2025
91d4916
update finn configs
Oct 9, 2025
452eecc
4-bit weights are currently broken due to fetch weights.
Oct 11, 2025
c5ec519
Merging with Josh's latest MLO tests
STFleming Oct 13, 2025
646a4dc
fetch repos pointing at the appropriate branches
STFleming Oct 13, 2025
61ec09a
Fix typo
STFleming Oct 13, 2025
01ebe97
removing cleanup from here in case it is removing metadata
STFleming Oct 13, 2025
d631499
[Transforms] Add node metadata propagation to bsmith transforms
auphelia Oct 13, 2025
5148465
Merge branch 'dev/auphelia/propagate-metadata' into dev/sfleming/trai…
auphelia Oct 13, 2025
ab30a8d
[Transforms] Add node metadata propagation to bsmith transforms (#72)
auphelia Oct 13, 2025
7d5868b
[BertFlow] Add loop body hierarchies
auphelia Oct 14, 2025
a3851a1
fix the metadata issue.
Oct 15, 2025
d8af33c
forgotten file from last commit
Oct 15, 2025
e6e9f2e
Merge pull request #77 from microsoft/dev/joshmonson/trainedbert_mlo_…
STFleming Oct 16, 2025
8c1d0dc
Merge remote-tracking branch 'origin/dev/joshmonson/add-loop-rolling'…
STFleming Oct 20, 2025
a7678e9
Added an initial folding configuration to try and get the end2end flo…
STFleming Oct 20, 2025
a51639f
[BertMLO] Add prefix to folding config
auphelia Oct 20, 2025
95cfca9
[BertMLO] Add first iteration of folding config json containing all n…
auphelia Oct 21, 2025
9474188
update loop_body_hierarhcy to list of lists
Oct 21, 2025
65022cf
Adding back in the head removal to avoid the automated partitioning.
STFleming Oct 22, 2025
93699b1
[TrainedBERT] removing the duplicate generate_reference_io
STFleming Oct 22, 2025
950c494
Merge pull request #78 from microsoft/dev/joshmonson/run_untrained_mlo
STFleming Oct 22, 2025
6aad596
[ShellHandover] Updated the shell handover generation to include spec…
STFleming Oct 30, 2025
6e039d9
[Crop] Update crop node execute node fct to use hlsbackend
auphelia Oct 31, 2025
5f5993b
initial documentation for loop-rolling
Oct 31, 2025
0ed5942
Roll-back some of the claims made by the AI.
Oct 31, 2025
9e3e445
remove GPT since we don't support that
Oct 31, 2025
6f9796e
additional explanations
Oct 31, 2025
ae06ebc
updates that need to be reviewed.
Oct 31, 2025
ea65d83
[LayernormHLS] Update execute node fct
auphelia Nov 4, 2025
f76211d
almost done
Nov 7, 2025
e1cb39b
asked ai to review and test the code snippets. They at least run but …
Nov 7, 2025
6091c48
Merge branch 'dev/sfleming/trainedbert_mlo' of github.com:microsoft/B…
Nov 7, 2025
68e5bcb
one more fix
Nov 7, 2025
b2d6e1e
docs: update CLI syntax, enhance styling, and add image lightbox support
tafk7 Nov 9, 2025
4a61e2a
Polish mkdocs site
tafk7 Nov 10, 2025
58681b7
Merge branch 'develop' into dev/tafk/clean-mlo-merge
tafk7 Nov 10, 2025
8bbd49f
Add imports for training
tafk7 Nov 10, 2025
cde2ce3
Merge branch 'dev/tafk/docs-v010' into dev/tafk/clean-mlo-merge
tafk7 Nov 10, 2025
2171a8b
refactor: MLO metadata propagation and build pipeline reorganization
tafk7 Nov 10, 2025
3f83875
Release v0.1.0
tafk7 Nov 11, 2025
e2df546
Rename expand norms steps file for clarity
tafk7 Nov 12, 2025
497dd50
Merge branch 'main' into dev/tafk/clean-mlo-merge
tafk7 Nov 12, 2025
ffbdc34
Missed merge resolutions
tafk7 Nov 12, 2025
81f1228
Merge pull request #90 from microsoft/develop
tafk7 Nov 13, 2025
d133d2d
Merge pull request #92 from microsoft/develop
tafk7 Nov 13, 2025
6b1e9ef
Merge pull request #95 from microsoft/develop
tafk7 Nov 14, 2025
9fea045
Leftover merge cleanup
tafk7 Nov 14, 2025
96a01c6
Merge remote-tracking branch 'origin' into dev/tafk/clean-mlo-merge
tafk7 Nov 14, 2025
dd90fab
Merge remote-tracking branch 'origin/develop' into dev/tafk/clean-mlo…
tafk7 Nov 14, 2025
8ecb3ca
Moveup bitwidth
tafk7 Nov 14, 2025
7e079ee
feat: add mem_modes system and refactor BuildContext
tafk7 Nov 19, 2025
1df3cbe
refactor: consolidate compilation steps into logical modules
tafk7 Nov 22, 2025
c696e90
fix(thresholding): handle MLO nodes in RTL codegen
tafk7 Nov 24, 2025
6ea2677
fix: update kernel IPI signatures and import paths
tafk7 Dec 9, 2025
16 changes: 8 additions & 8 deletions brainsmith/_internal/io/dependency_installers.py
@@ -104,7 +104,7 @@ def install(self, name: str, dep: dict, dest: Path, force: bool, quiet: bool) ->
cmd.extend([dep["url"], str(dest)])

if not quiet:
logger.info("Cloning %s from %s", name, dep["url"])
logger.debug("Cloning %s from %s", name, dep['url'])

result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
@@ -170,7 +170,7 @@ def install(self, name: str, dep: dict, dest: Path, force: bool, quiet: bool) ->

try:
if not quiet:
logger.info("Downloading %s from %s", name, dep["url"])
logger.debug("Downloading %s from %s", name, dep['url'])

urlretrieve(dep["url"], zip_path)

@@ -254,7 +254,7 @@ def _install_finn_xsim(self, force: bool, quiet: bool) -> None:

# Build with finn-xsim
if not quiet:
logger.info("Building finn-xsim...")
logger.debug("Building finn-xsim...")

# Construct build command
build_cmd = ["python3", "-m", "finn.xsi.setup"]
@@ -265,16 +265,16 @@
python_cmd = " ".join(build_cmd)
bash_cmd = f"source {settings_script} && {python_cmd}"

logger.info("Running: %s", bash_cmd)
logger.debug("Running: %s", bash_cmd)

# Execute build
result = subprocess.run(["bash", "-c", bash_cmd], capture_output=True, text=True)

# Log output at INFO level (visible with --logs info)
# Log output at DEBUG level (visible with --logs debug)
if result.stdout:
for line in result.stdout.splitlines():
if line.strip():
logger.info(line)
logger.debug(line)

if result.stderr:
for line in result.stderr.splitlines():
@@ -322,7 +322,7 @@ def _install_generic_build(self, name: str, dep: dict, force: bool, quiet: bool)
raise BuildError(error_msg)

if not quiet:
logger.info("Building %s in %s", name, source_dir)
logger.debug("Building %s in %s", name, source_dir)

# Run build command
env = os.environ.copy()
@@ -334,7 +334,7 @@ def _install_generic_build(self, name: str, dep: dict, force: bool, quiet: bool)
if result.stdout:
for line in result.stdout.splitlines():
if line.strip():
logger.info(line)
logger.debug(line)

if result.stderr:
for line in result.stderr.splitlines():
4 changes: 2 additions & 2 deletions brainsmith/_version.py
@@ -2,5 +2,5 @@
# Licensed under the MIT License.

# Version information for brainsmith
__version__ = "0.0.1a"
__version_tuple__ = (0, 0, 1, "a")
__version__ = "0.1.0"
__version_tuple__ = (0, 1, 0)
62 changes: 46 additions & 16 deletions brainsmith/dataflow/builder.py
@@ -31,8 +31,10 @@
from math import gcd
from typing import TYPE_CHECKING, Any

from onnx import NodeProto
from qonnx.core.datatype import BaseDataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.basic import get_by_name

from brainsmith._internal.math import divisors

@@ -57,20 +59,16 @@ class BuildContext:
Attributes:
schema: KernelSchema defining structure
model_w: ModelWrapper for ONNX graph access
node_inputs: ONNX node input tensor names
node_outputs: ONNX node output tensor names
node: ONNX NodeProto (provides .input, .output, .name)
param_getter: Function to retrieve nodeattr values
param_setter: Function to store nodeattr values
node_name: Node name for error messages
"""

schema: KernelSchema
model_w: ModelWrapper
node_inputs: list[str]
node_outputs: list[str]
node: NodeProto
param_getter: Callable[[str], Any]
param_setter: Callable[[str, Any], None]
node_name: str = "<unknown>"


class DesignSpaceBuilder:
@@ -85,11 +83,9 @@ class DesignSpaceBuilder:
>>> context = BuildContext(
... schema=kernel_schema,
... model_w=model_wrapper,
... node_inputs=list(node.input),
... node_outputs=list(node.output),
... node=node,
... param_getter=self.get_nodeattr,
... param_setter=self.set_nodeattr,
... node_name=node.name
... )
>>> design_space = builder.build(context)
>>> point = design_space.configure({"SIMD": 64, "PE": 1})
@@ -195,12 +191,12 @@ def build(self, ctx: BuildContext) -> KernelDesignSpace:
self._ctx = ctx
self._interfaces: dict[str, Any] = {}

logger.debug(f"Building KernelDesignSpace for {ctx.node_name}")
logger.debug(f"Building KernelDesignSpace for {ctx.node.name}")

# Build input interfaces from ONNX graph
inputs: dict[str, InterfaceDesignSpace] = {}

for i, inp_name in enumerate(ctx.node_inputs):
for i, inp_name in enumerate(ctx.node.input):
if not inp_name:
continue

@@ -248,7 +244,7 @@ def build(self, ctx: BuildContext) -> KernelDesignSpace:
# Build output interfaces (may derive datatypes from inputs)
outputs: dict[str, InterfaceDesignSpace] = {}

for i, out_name in enumerate(ctx.node_outputs):
for i, out_name in enumerate(ctx.node.output):
if i >= len(ctx.schema.outputs):
logger.warning(
f"Node has output {i} but schema only defines {len(ctx.schema.outputs)} outputs"
@@ -294,7 +290,7 @@ def build(self, ctx: BuildContext) -> KernelDesignSpace:
if (e := c.check(validation_ctx))
]
if failed:
raise ValueError(f"{ctx.node_name} validation failed:\n" + "\n".join(failed))
raise ValueError(f"{ctx.node.name} validation failed:\n" + "\n".join(failed))

logger.debug(f" All {len(structural_constraints)} structural constraints passed")

Expand All @@ -317,7 +313,7 @@ def build(self, ctx: BuildContext) -> KernelDesignSpace:
parameters=all_dimensions,
)

logger.debug(f"KernelDesignSpace built successfully for {ctx.node_name}")
logger.debug(f"KernelDesignSpace built successfully for {ctx.node.name}")
return design_space

def _resolve_datatype(
@@ -696,8 +692,42 @@ def _compute_dimension_ranges(
f"{ordered_count} ordered, {discrete_count} discrete"
)

# Combine tiling + DSE dimensions
all_dimensions = {**tiling_dimensions, **dse_dimensions}
# Generate input<idx>MemType parameters from mem_modes
mem_mode_dimensions = {}
for idx, inp in enumerate(schema.inputs):
if inp.mem_modes is None:
continue

param_name = f"input{idx}MemType"

# Check if InferKernel marked this input as a weight
# Attribute presence indicates weight; absence indicates pure streaming input
attr = get_by_name(self._ctx.node.attribute, param_name)
if attr is None:
# Not a weight - skip parameter creation
logger.debug(f"Skipping {param_name}: not marked as weight by InferKernel")
continue

values = inp.mem_modes

# Support callable for context-aware filtering (e.g., MLO)
if callable(values):
values = values(self._ctx)

# Ensure frozenset for discrete parameter
if not isinstance(values, frozenset):
values = frozenset(values)

mem_mode_dimensions[param_name] = values

if mem_mode_dimensions:
logger.debug(
f"Added {len(mem_mode_dimensions)} mem_mode dimensions: "
+ ", ".join(f"{k}={v}" for k, v in mem_mode_dimensions.items())
)

# Combine tiling + DSE + mem_mode dimensions
all_dimensions = {**tiling_dimensions, **dse_dimensions, **mem_mode_dimensions}

return all_dimensions

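The mem-mode dimension derivation in the hunk above can be sketched in isolation. This is a minimal stand-alone sketch, not the real builder: `InputSpec`, `mem_mode_dimensions`, and the `marked` set are hypothetical stand-ins for `InputSchema.mem_modes` and the `input<idx>MemType` attribute that `InferKernel` sets on weight inputs.

```python
from dataclasses import dataclass

@dataclass
class InputSpec:
    # Stand-in for InputSchema: mem_modes is None, a frozenset, or a callable
    mem_modes: object = None

def mem_mode_dimensions(inputs, marked, ctx=None):
    """Build {input<idx>MemType: frozenset-of-modes} for marked weight inputs.

    `marked` stands in for the attribute-presence check: only inputs flagged
    as weights get a memory-mode DSE dimension."""
    dims = {}
    for idx, inp in enumerate(inputs):
        if inp.mem_modes is None or idx not in marked:
            continue  # pure streaming input: no memory-mode parameter
        values = inp.mem_modes
        if callable(values):
            values = values(ctx)  # context-aware filtering (e.g. MLO)
        dims[f"input{idx}MemType"] = frozenset(values)
    return dims

inputs = [InputSpec(), InputSpec(mem_modes=frozenset({"embedded", "decoupled"}))]
assert mem_mode_dimensions(inputs, marked={1}) == {
    "input1MemType": frozenset({"embedded", "decoupled"})
}
assert mem_mode_dimensions(inputs, marked=set()) == {}  # unmarked inputs skipped
```

Note the two-level gate: an input must both declare `mem_modes` in its schema and be marked as a weight before any DSE dimension is emitted.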
15 changes: 11 additions & 4 deletions brainsmith/dataflow/dse_models.py
@@ -66,6 +66,7 @@ class InterfaceDesignSpace:
datatype: Interface datatype
is_weight: Whether this is a weight tensor (constant)
tensor_name: ONNX tensor name for initializer lookups
mem_mode: Memory mode for weight inputs (embedded/decoupled/dynamic)
parallelism_dimension: OrderedParameter for stream parameter (None if no parallelism)
parallelism_param: Parameter name for stream dimension (e.g., "SIMD", "PE")
"""
@@ -88,16 +89,18 @@ class InterfaceDesignPoint:
"""Interface instance with resolved parallelization.

Flyweight pattern: references parent design space, stores only configuration-
specific stream_shape. Delegates tensor_shape, block_shape, and datatype
to design space for minimal memory overhead.
specific stream_shape and mem_mode. Delegates tensor_shape, block_shape, and
datatype to design space for minimal memory overhead.

Attributes:
design_space: Parent InterfaceDesignSpace
stream_shape: Resolved stream dimensions for this configuration
mem_mode: Memory mode for weight inputs (embedded/decoupled/dynamic)
"""

design_space: InterfaceDesignSpace
stream_shape: Shape
mem_mode: str | None = None # Memory mode (embedded/decoupled/dynamic) for weight inputs

# Convenience properties (delegate to design space)
@property
@@ -399,7 +402,7 @@ def _instantiate_interfaces(
from .template_resolution import resolve_template

configured = {}
for interface in interfaces.values():
for idx, interface in enumerate(interfaces.values()):
stream_shape = (
interface.block_shape
if interface.stream_tiling is None
@@ -413,8 +416,12 @@
)
)

# Extract mem_mode from params if this is an input with mem_modes
mem_mode_param = f"input{idx}MemType"
mem_mode = params.get(mem_mode_param)

configured_interface = InterfaceDesignPoint(
design_space=interface, stream_shape=stream_shape
design_space=interface, stream_shape=stream_shape, mem_mode=mem_mode
)
configured[interface.name] = configured_interface
interface_lookup[interface.name] = configured_interface
10 changes: 7 additions & 3 deletions brainsmith/dataflow/kernel_op.py
@@ -300,11 +300,9 @@ def _ensure_ready(self, model_w: ModelWrapper) -> None:
build_ctx = BuildContext(
schema=self.kernel_schema,
model_w=model_w,
node_inputs=list(self.onnx_node.input),
node_outputs=list(self.onnx_node.output),
node=self.onnx_node,
param_getter=self.get_nodeattr,
param_setter=self.set_nodeattr,
node_name=self.onnx_node.name,
)

try:
@@ -324,6 +322,12 @@ def _ensure_ready(self, model_w: ModelWrapper) -> None:
# OrderedParameter: use get_default() (explicit default or minimum)
initial_value = param.get_default()
else: # frozenset
# Defensive: skip empty parameter sets (shouldn't happen with new design)
if len(param) == 0:
logger.debug(
f"{self.onnx_node.name}: Skipping empty parameter {param_name}"
)
continue
# Discrete: use sorted first value
initial_value = sorted(param)[0]

27 changes: 27 additions & 0 deletions brainsmith/dataflow/schemas.py
@@ -255,6 +255,9 @@ class InputSchema:
stream_tiling: Stream tiling specification (e.g., ["SIMD"], [1, 1, 1, "PE"])
datatype: Datatype spec (None to use from ONNX, or DatatypeSpec union type to derive/optimize)
required_layout: Expected input layout (e.g., "NHWC", "NCHW"), None if no requirement
mem_modes: Memory mode options for weight inputs (frozenset or callable returning frozenset).
Valid modes: "embedded" (compile-time constant), "decoupled" (separate memory),
"dynamic"/"external" (streaming). Generates input<idx>MemType DSE parameter.
"""

# Identity
@@ -268,6 +271,9 @@
# Transformation requirements (NEW - embedded in interface)
required_layout: str | None = None

# Memory mode specification for weight inputs
mem_modes: frozenset[str] | Callable | None = None

def __post_init__(self):
"""Validate interface requirements."""
if self.required_layout and self.required_layout not in {"NCHW", "NHWC"}:
@@ -276,6 +282,21 @@ def __post_init__(self):
f"Must be 'NCHW' or 'NHWC'."
)

# Validate mem_modes if specified
if self.mem_modes is not None and not callable(self.mem_modes):
VALID_MEM_MODES = {"embedded", "decoupled", "dynamic", "external"}
if not isinstance(self.mem_modes, frozenset):
raise TypeError(
f"mem_modes for input '{self.name}' must be frozenset or callable, "
f"got {type(self.mem_modes).__name__}"
)
invalid = self.mem_modes - VALID_MEM_MODES
if invalid:
raise ValueError(
f"Invalid mem_modes {invalid} for input '{self.name}'. "
f"Valid modes: {VALID_MEM_MODES}"
)
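A runnable mirror of this validation (hedged sketch: `check_mem_modes` is a hypothetical free function; the real check lives inside `InputSchema.__post_init__`):

```python
VALID_MEM_MODES = {"embedded", "decoupled", "dynamic", "external"}

def check_mem_modes(name: str, mem_modes) -> None:
    """Reject anything that is not None, a callable, or a frozenset of known modes."""
    if mem_modes is None or callable(mem_modes):
        return  # callables are resolved later, once build context is available
    if not isinstance(mem_modes, frozenset):
        raise TypeError(
            f"mem_modes for input '{name}' must be frozenset or callable, "
            f"got {type(mem_modes).__name__}"
        )
    invalid = mem_modes - VALID_MEM_MODES
    if invalid:
        raise ValueError(f"Invalid mem_modes {invalid} for input '{name}'")

check_mem_modes("weights", frozenset({"embedded", "decoupled"}))  # passes
check_mem_modes("weights", None)                                  # passes
```

Note that callables are accepted unchecked here: their returned modes can only be validated once the build context exists.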

@property
def tiling_attrs(self) -> list[str]:
"""Extract unique template parameter names from tiling specs."""
@@ -461,6 +482,12 @@ def build_nodeattr_registry(self) -> dict[str, tuple]:
for param in template_params:
attrs[param] = ("i", False, 1) # Default 1, will be computed from factoring

# Memory mode parameters (input<idx>MemType) - auto-extracted from mem_modes
for idx, inp in enumerate(self.inputs):
if inp.mem_modes is not None:
# Add input<idx>MemType as a string parameter
attrs[f"input{idx}MemType"] = ("s", False, "embedded")

# DSE parameters (resource parameters)
for param_name, param_spec in self.dse_parameters.items():
attrs[param_name] = _infer_nodeattr_type(param_spec)
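The registry hunk above gives every `mem_modes`-bearing input a string nodeattr defaulting to `"embedded"`. A minimal sketch (`Inp` and `build_mem_mode_attrs` are hypothetical stand-ins mirroring the loop in `build_nodeattr_registry`):

```python
class Inp:
    """Stand-in for InputSchema: only the mem_modes field matters here."""
    def __init__(self, mem_modes=None):
        self.mem_modes = mem_modes

def build_mem_mode_attrs(inputs) -> dict[str, tuple]:
    attrs = {}
    for idx, inp in enumerate(inputs):
        if inp.mem_modes is not None:
            # ("s", False, "embedded"): string attr, optional, defaults to "embedded"
            attrs[f"input{idx}MemType"] = ("s", False, "embedded")
    return attrs

attrs = build_mem_mode_attrs([Inp(), Inp(frozenset({"embedded", "decoupled"}))])
assert attrs == {"input1MemType": ("s", False, "embedded")}
```

The default of `"embedded"` keeps kernels without an explicit DSE choice on compile-time-constant weights, matching the safest fallback.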