Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions compiler/back_end/cpp/generated_code_templates
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,18 @@ ${write_fields}
bool SizeIsKnown() const { return IntrinsicSizeIn${units}().Ok(); }


// ** static_size_from_parameters_declaration ** ///////////////////////////////
static constexpr ::std::size_t IntrinsicSizeIn${units}(${parameters});

// ** static_size_from_parameters_definition ** ////////////////////////////////
namespace ${parent_type} {
inline constexpr ::std::size_t IntrinsicSizeIn${units}(${parameters}) {
${subexpressions}
return static_cast</**/ ::std::size_t>((${read_value}).ValueOrDefault());
}
} // namespace ${parent_type}


// ** ok_method_test ** ////////////////////////////////////////////////////////
// If we don't have enough information to determine whether ${field} is
// present in the structure, then structure.Ok() should be false.
Expand Down
233 changes: 221 additions & 12 deletions compiler/back_end/cpp/header_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,104 @@ def render_field(self, expression, ir, subexpressions):
)


def _expression_only_depends_on_parameters(expression, parameter_names):
"""Checks if an expression only depends on parameters (not fields).

Arguments:
expression: The expression to check.
parameter_names: A set of parameter names that are valid dependencies.

Returns:
True if the expression only depends on constants and the given parameters,
False if it depends on any fields.
"""
expression = ir_data_utils.reader(expression)

# Constant expressions are fine
if expression.type.which_type == "integer":
if expression.type.integer.modulus == "infinity":
return True
elif expression.type.which_type == "boolean":
if expression.type.boolean.has_field("value"):
return True
elif expression.type.which_type == "enumeration":
if expression.type.enumeration.has_field("value"):
return True

# Check the expression kind
if expression.which_expression == "constant":
return True
elif expression.which_expression == "constant_reference":
return True
elif expression.which_expression == "boolean_constant":
return True
elif expression.which_expression == "field_reference":
# Field references can be parameters if they reference a parameter field.
# Parameters are represented as field_references with a single-element path
# where the field name matches a parameter name.
path = expression.field_reference.path
if len(path) == 1:
field_name = path[0].canonical_name.object_path[-1]
if field_name in parameter_names:
return True
return False
elif expression.which_expression == "builtin_reference":
# Check if this is a parameter reference
name = expression.builtin_reference.canonical_name.object_path[-1]
return name in parameter_names
elif expression.which_expression == "function":
# Recursively check all arguments
for arg in expression.function.args:
if not _expression_only_depends_on_parameters(arg, parameter_names):
return False
return True
elif expression.which_expression is None:
return True

return False


class _StaticParameterFieldRenderer(object):
"""Renderer for expressions in static parameterized functions.

This renderer is used to generate code for expressions that only depend on
parameters, where the parameters are passed as function arguments rather
than accessed through a view.
"""

def __init__(self, parameter_names, ir):
"""Initialize the renderer.

Arguments:
parameter_names: A set of parameter names.
ir: The IR for type lookups.
"""
self._parameter_names = parameter_names
self._ir = ir

def render_existence(self, expression, subexpressions):
# Parameters always exist
del expression, subexpressions # Unused
return _maybe_type("bool") + "(true)"

def render_field(self, expression, ir, subexpressions):
# For static parameter functions, field references that point to parameters
# should be rendered as direct variable access.
path = expression.field_reference.path
if len(path) == 1:
field_name = path[0].canonical_name.object_path[-1]
if field_name in self._parameter_names:
expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
return "{}({})".format(_maybe_type(expression_cpp_type), field_name)
# This shouldn't happen as we filtered expressions to only those that
# depend on parameters
assert False, (
"Non-parameter field reference in static parameter expression: {}".format(
expression
)
)


class _SubexpressionStore(object):
"""Holder for subexpressions to be assigned to local variables."""

Expand Down Expand Up @@ -849,16 +947,22 @@ def _render_expression(expression, ir, field_reader=None, subexpressions=None):
result = _render_builtin_operation(expression, ir, field_reader, subexpressions)
elif expression.which_expression == "field_reference":
result = field_reader.render_field(expression, ir, subexpressions)
elif (
expression.which_expression == "builtin_reference"
and expression.builtin_reference.canonical_name.object_path[-1]
== "$logical_value"
):
return _ExpressionResult(
_maybe_type("decltype(emboss_reserved_local_value)")
+ "(emboss_reserved_local_value)",
False,
)
elif expression.which_expression == "builtin_reference":
name = expression.builtin_reference.canonical_name.object_path[-1]
if name == "$logical_value":
return _ExpressionResult(
_maybe_type("decltype(emboss_reserved_local_value)")
+ "(emboss_reserved_local_value)",
False,
)
elif isinstance(field_reader, _StaticParameterFieldRenderer):
# For static parameter functions, render parameter as direct variable
expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
result = "{}({})".format(_maybe_type(expression_cpp_type), name)
else:
# This shouldn't happen - non-logical-value builtin references should
# be handled by type checking
result = None

# Any of the constant expression types should have been handled in the
# previous section.
Expand Down Expand Up @@ -1232,6 +1336,105 @@ def _render_size_method(fields, ir):
assert False, "Expected a $size_in_bits or $size_in_bytes field."


def _generate_static_size_from_parameters(type_ir, ir):
"""Generates a static IntrinsicSizeInBytes(params) function if applicable.

For structs where the size depends only on parameters (not on field values),
generates a static function that computes the size from parameters alone.

Arguments:
type_ir: The IR for the struct definition.
ir: The full IR; used for type lookups.

Returns:
A tuple of (declaration, definition) strings, or ("", "") if not applicable.
"""
# Only applicable if the struct has parameters
if not type_ir.runtime_parameter:
return "", ""

# Find the size field ($size_in_bits or $size_in_bytes)
size_field = None
for field in type_ir.structure.field:
if field.name.name.text in ("$size_in_bits", "$size_in_bytes"):
size_field = field
break

if size_field is None:
return "", ""

# Check if the size expression already is constant (no need for parameterized version)
if (
_render_expression(size_field.read_transform, ir).is_constant
and _render_expression(size_field.existence_condition, ir).is_constant
):
return "", ""

# Get the set of parameter names
parameter_names = set()
for param in type_ir.runtime_parameter:
parameter_names.add(param.name.name.text)

# Check if the size expression only depends on parameters
if not _expression_only_depends_on_parameters(
size_field.read_transform, parameter_names
):
return "", ""

# Also check existence condition
if not _expression_only_depends_on_parameters(
size_field.existence_condition, parameter_names
):
return "", ""

# Generate the static function
type_name = type_ir.name.name.text
units = "Bits" if size_field.name.name.text == "$size_in_bits" else "Bytes"

# Build parameter list
param_list = []
for param in type_ir.runtime_parameter:
param_type = _cpp_basic_type_for_expression_type(param.type, ir)
param_name = param.name.name.text
param_list.append("{} {}".format(param_type, param_name))

parameters = ", ".join(param_list)

# Render the expression with the static parameter renderer
static_field_reader = _StaticParameterFieldRenderer(parameter_names, ir)
subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
read_value = _render_expression(
size_field.read_transform,
ir,
field_reader=static_field_reader,
subexpressions=subexpressions,
)

# Build subexpressions string
subexpr_str = "".join(
[
" const auto {} = {};\n".format(subexpr_name, subexpr)
for subexpr_name, subexpr in subexpressions.subexprs()
]
)

declaration = code_template.format_template(
_TEMPLATES.static_size_from_parameters_declaration,
units=units,
parameters=parameters,
)
definition = code_template.format_template(
_TEMPLATES.static_size_from_parameters_definition,
parent_type=type_name,
units=units,
parameters=parameters,
subexpressions=subexpr_str,
read_value=read_value.rendered,
)

return declaration, definition


def _visibility_for_field(field_ir):
"""Returns the C++ visibility for field_ir within its parent view."""
# Generally, the Google style guide for hand-written C++ forbids having
Expand Down Expand Up @@ -1490,13 +1693,19 @@ def _generate_structure_definition(type_ir, ir, config: Config):
else:
text_stream_methods = ""

# Generate static IntrinsicSizeInBytes(params) if size only depends on parameters
static_size_declaration, static_size_definition = (
_generate_static_size_from_parameters(type_ir, ir)
)

class_forward_declarations = code_template.format_template(
_TEMPLATES.structure_view_declaration, name=type_name
)
class_bodies = code_template.format_template(
_TEMPLATES.structure_view_class,
name=type_ir.name.canonical_name.object_path[-1],
size_method=_render_size_method(type_ir.structure.field, ir),
size_method=_render_size_method(type_ir.structure.field, ir)
+ static_size_declaration,
field_method_declarations="".join(field_method_declarations),
field_ok_checks="\n".join(ok_method_clauses),
parameter_ok_checks="\n".join(parameter_checks),
Expand All @@ -1514,7 +1723,7 @@ def _generate_structure_definition(type_ir, ir, config: Config):
initialize_parameters_initialized_true=(initialize_parameters_initialized_true),
units=units,
)
method_definitions = "\n".join(field_method_definitions)
method_definitions = "\n".join(field_method_definitions) + static_size_definition
early_virtual_field_types = "\n".join(virtual_field_type_definitions)
all_field_helper_type_definitions = "\n".join(field_helper_type_definitions)
return (
Expand Down
19 changes: 19 additions & 0 deletions compiler/back_end/cpp/testcode/parameters_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,25 @@ TEST(Axes, VirtualUsingParameter) {
EXPECT_EQ(3, view.axis_count_plus_one().Read());
}

TEST(Axes, StaticIntrinsicSizeInBytesFromParameters) {
// Test the static IntrinsicSizeInBytes function that takes parameters directly
// For Axes struct, the size is axes * 4 bytes
EXPECT_EQ(0U, Axes::IntrinsicSizeInBytes(0));
EXPECT_EQ(4U, Axes::IntrinsicSizeInBytes(1));
EXPECT_EQ(8U, Axes::IntrinsicSizeInBytes(2));
EXPECT_EQ(12U, Axes::IntrinsicSizeInBytes(3));
EXPECT_EQ(16U, Axes::IntrinsicSizeInBytes(4));
EXPECT_EQ(40U, Axes::IntrinsicSizeInBytes(10));

// Verify static function matches instance method
::std::array<char, 12> values = {1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0};
auto view = MakeAxesView(2, &values);
EXPECT_EQ(view.SizeInBytes(), Axes::IntrinsicSizeInBytes(2));

auto view3 = MakeAxesView(3, &values);
EXPECT_EQ(view3.SizeInBytes(), Axes::IntrinsicSizeInBytes(3));
}

TEST(AxesEnvelope, FieldPassedAsParameter) {
::std::array<unsigned char, 9> values = {2, 0, 0, 0, 0x80, 0, 100, 0, 0};
auto view = MakeAxesEnvelopeView(&values);
Expand Down
19 changes: 19 additions & 0 deletions testdata/golden_cpp/parameters.emb.h
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,7 @@ if (!parameters_initialized_) return false;
}
bool SizeIsKnown() const { return IntrinsicSizeInBytes().Ok(); }

static constexpr ::std::size_t IntrinsicSizeInBytes(::std::int32_t axes);


template <typename OtherStorage>
Expand Down Expand Up @@ -9039,6 +9040,24 @@ GenericAxesView<
Storage>::EmbossReservedDollarVirtualMinSizeInBytesView::UncheckedRead() {
return Axes::MinSizeInBytes();
}
namespace Axes {
inline constexpr ::std::size_t IntrinsicSizeInBytes(::std::int32_t axes) {
const auto emboss_reserved_local_subexpr_1 = ::emboss::support::Maybe</**/::std::int32_t>(axes);
const auto emboss_reserved_local_subexpr_2 = ::emboss::support::Product</**/::std::int32_t, ::std::int32_t, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_1, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(4LL)));
const auto emboss_reserved_local_subexpr_3 = ::emboss::support::Sum</**/::std::int32_t, ::std::int32_t, ::std::int32_t, ::std::int32_t>(::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)), emboss_reserved_local_subexpr_2);
const auto emboss_reserved_local_subexpr_4 = ::emboss::support::Choice</**/::std::int32_t, ::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(::emboss::support::Maybe</**/bool>(true), emboss_reserved_local_subexpr_3, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)));
const auto emboss_reserved_local_subexpr_5 = ::emboss::support::GreaterThan</**/::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_1, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)));
const auto emboss_reserved_local_subexpr_6 = ::emboss::support::Choice</**/::std::int32_t, ::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_5, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(4LL)), ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)));
const auto emboss_reserved_local_subexpr_7 = ::emboss::support::GreaterThan</**/::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_1, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(1LL)));
const auto emboss_reserved_local_subexpr_8 = ::emboss::support::Choice</**/::std::int32_t, ::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_7, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(8LL)), ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)));
const auto emboss_reserved_local_subexpr_9 = ::emboss::support::GreaterThan</**/::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_1, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(2LL)));
const auto emboss_reserved_local_subexpr_10 = ::emboss::support::Choice</**/::std::int32_t, ::std::int32_t, bool, ::std::int32_t, ::std::int32_t>(emboss_reserved_local_subexpr_9, ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(12LL)), ::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)));
const auto emboss_reserved_local_subexpr_11 = ::emboss::support::Maximum</**/::std::int32_t, ::std::int32_t, ::std::int32_t, ::std::int32_t, ::std::int32_t, ::std::int32_t, ::std::int32_t>(::emboss::support::Maybe</**/::std::int32_t>(static_cast</**/::std::int32_t>(0LL)), emboss_reserved_local_subexpr_4, emboss_reserved_local_subexpr_6, emboss_reserved_local_subexpr_8, emboss_reserved_local_subexpr_10);

return static_cast</**/ ::std::size_t>((emboss_reserved_local_subexpr_11).ValueOrDefault());
}
} // namespace Axes

namespace AxisPair {

} // namespace AxisPair
Expand Down