@isVoid (Collaborator) commented Jan 23, 2026

This adds a template argument deduction module and unit tests, and it incorporates the recently added argument intent concept from Numbast PR #278 so visible arity and pointer passing stay consistent across templated overloads.

At a high level, the module accepts a list of ast_canopy-parsed function templates plus optional intent overrides, and returns FunctionTemplate instances with deduced argument types for the original parameter names.

Tested cases:

  • Overload selection by visible arity.
    • C++: template <typename T> __device__ T add(T a, T b); and template <typename T> __device__ T add(T a, T b, T c);
    • Example: add(int, int) picks the two-arg overload and deduces T=int.
  • Conflicting placeholder deduction skips an overload.
    • C++: template <typename T> __device__ T add(T a, T b);
    • Example: add(int, float) yields no specialization because T conflicts.
  • Non-templated parameter types must match.
    • C++: template <typename T> __device__ T add_int(int a, T b);
    • Example: add_int(int, float) specializes, add_int(float, float) does not.
  • Return-only placeholders are skipped.
    • C++: template <typename T> __device__ T return_only();
    • Example: return_only<T>() cannot deduce T with no args, so it is skipped.
  • Intent overrides including out_return.
    • C++: template <typename T> __device__ void store_ref(T &out, T value);
    • Examples: store_ref(CPointer(int), int) with overrides={"out": "out_ptr"} or {"out": "inout_ptr"} deduces T=int; store_ref(int) with overrides={"out": "out_return"} also deduces T=int via hidden out param.
  • Invalid overrides surface as errors.
    • C++: template <typename T> __device__ void bad_out(T value);
    • Example: bad_out(value=out_ptr) reports a ValueError in intent_errors.
  • Struct method specialization.
    • C++: struct Box { template <typename T> __device__ T mul(T a, T b) const; };
    • Example: Box::mul(float, float) deduces T=float.
  • Unmappable Numba args are skipped.
    • C++: template <typename T> __device__ T add(T a, T b);
    • Example: add(float32[:], float32[:]) yields no specialization.
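
The intent-override cases above depend on which parameters stay visible to the caller. The following is a minimal sketch of that idea only, not the module's implementation: `visible_params` is a hypothetical helper, the default intent name `"in"` is an assumption, and so is the rule that only reference parameters accept an output intent.

```python
# Hypothetical stand-in for an intent planner; names and rules are assumptions.
VALID_INTENTS = {"in", "out_ptr", "inout_ptr", "out_return"}


def visible_params(params, overrides=None):
    """Return the parameter names a caller must pass.

    params is a list of (name, cxx_type) pairs; overrides maps a
    parameter name to an intent string.
    """
    overrides = overrides or {}
    unknown = set(overrides) - {name for name, _ in params}
    if unknown:
        raise ValueError(f"overrides name unknown parameters: {sorted(unknown)}")

    visible = []
    for name, cxx_type in params:
        intent = overrides.get(name, "in")
        if intent not in VALID_INTENTS:
            raise ValueError(f"{name}: unknown intent {intent!r}")
        # Assumption: output intents only make sense for reference parameters.
        if intent != "in" and not cxx_type.rstrip().endswith("&"):
            raise ValueError(f"{name}: intent {intent!r} needs a reference parameter")
        # out_return parameters are hidden from the caller and returned instead.
        if intent != "out_return":
            visible.append(name)
    return visible


store_ref = [("out", "T &"), ("value", "T")]
print(visible_params(store_ref, {"out": "out_ptr"}))
print(visible_params(store_ref, {"out": "out_return"}))
```

Under these assumptions, `out_ptr` keeps both parameters visible (arity 2), while `out_return` hides `out` (arity 1), which is why `store_ref(int)` can still match.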

This is currently a standalone module that is not wired into any existing binding generation, so it can be tested and reasoned about independently.

Summary by CodeRabbit

  • New Features

    • Added template overload deduction and type specialization for templated functions and methods.
  • Improvements

    • Better handling of pointer/reference parameters, placeholder resolution, and conflict detection.
    • Improved mapping from runtime argument types to template parameters and optional debug tracing.
  • Bug Fixes

    • Fixed pointer-type stringification so pointer types are recognized correctly.
  • Tests

    • Added comprehensive tests for overload selection, conflicts, overrides, pointer/ref cases, and method specialization.
  • Chores

    • CI: ensure compiler availability before running style checks.


coderabbitai bot (Contributor) commented Jan 23, 2026

📝 Walkthrough

Implements a new template overload deduction engine that maps Numba types to C++-style strings, deduces template parameters from argument types, specializes templated functions and struct methods, adds tests, updates pointer type handling, and tweaks CI for g++ availability.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Core deduction module (numbast/src/numbast/deduction.py) | New module implementing template overload deduction: debug utilities, C++-style type normalization, Numba→C++ mapping, pattern-based template parameter extraction, placeholder replacement, function/struct-method specialization, unresolved-placeholder detection, and the public entry deduce_templated_overloads. |
| Deduction tests (numbast/tests/test_deduction.py) | New pytest suite exercising deduce_templated_overloads: arity selection, conflicting deductions, non-templated param matching, return-placeholder skipping, pointer/override handling and related errors, struct-method specialization, and unmappable-argument cases; includes fixtures parsing CUDA header declarations. |
| Type string handling tweak (numbast/src/numbast/types.py) | Adds pointer handling in to_c_type_str by recognizing nbtypes.CPointer and converting the pointee type with a trailing * before consulting existing type maps. |
| CI style check (ci/check_style.sh) | Adds apt-get update and apt-get install -y g++ prior to installing pre-commit to ensure g++ is present in the environment. |
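
The pointer handling summarized for types.py can be pictured with a toy model. This sketch does not use Numba: `Scalar`, `Pointer`, `C_NAMES`, and `toy_to_c_type_str` are hypothetical stand-ins for the real types and type maps, and the exact spacing before `*` is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Scalar:
    """Stand-in for a scalar Numba type, identified by name."""
    name: str


@dataclass(frozen=True)
class Pointer:
    """Stand-in for a pointer type; dtype is the pointee type."""
    dtype: object


# Hypothetical subset of a Numba-to-C type-name map.
C_NAMES = {"int32": "int", "float32": "float", "float64": "double"}


def toy_to_c_type_str(ty):
    # Handle pointers first: stringify the pointee, then append '*'.
    if isinstance(ty, Pointer):
        return toy_to_c_type_str(ty.dtype) + "*"
    # Otherwise fall back to the plain lookup table.
    if isinstance(ty, Scalar) and ty.name in C_NAMES:
        return C_NAMES[ty.name]
    raise ValueError(f"no C type string for {ty!r}")


print(toy_to_c_type_str(Pointer(Scalar("int32"))))
```

Checking the pointer case before the table lookup lets nested pointers resolve recursively, and unknown pointee types still raise the same ValueError as unknown scalars.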

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Deductor as deduce_templated_overloads
    participant IntentPlanner as Intent Planner
    participant TypeDeducer as Type Pattern Deducer
    participant Specializer as Specializer
    participant Validator as Placeholder Validator

    Caller->>Deductor: provide overloads, args, overrides, debug
    activate Deductor

    alt overrides provided
        Deductor->>IntentPlanner: compute intent plan (visible params)
        IntentPlanner-->>Deductor: visible params + intent errors
    else no overrides
        Deductor->>Deductor: use all function parameters
    end

    loop per overload
        Deductor->>Deductor: validate arity against visible params
        alt arity matches
            loop per visible parameter
                Deductor->>TypeDeducer: match param pattern with arg type
                TypeDeducer-->>Deductor: mapping fragment or conflict
            end
            Deductor->>Specializer: apply combined mappings to template
            Specializer-->>Deductor: specialized function/struct method
            Deductor->>Validator: check for unresolved placeholders
            Validator-->>Deductor: resolved? accept : discard
        else arity mismatch
            Deductor-->>Deductor: skip overload
        end
    end

    Deductor-->>Caller: return specialized templates + collected errors
    deactivate Deductor
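
The per-overload loop in the diagram can be sketched in a few lines. This is a simplified illustration under stated assumptions, not the module's code: `Overload` and `specialize` are hypothetical, types are plain strings, and a parameter is either a bare placeholder or a concrete type that must match exactly.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Overload:
    """Toy model of a function template (hypothetical, strings only)."""
    placeholders: Tuple[str, ...]
    return_type: str
    params: Tuple[str, ...]


def specialize(overloads, args):
    results = []
    for ov in overloads:
        # Step 1: skip overloads whose arity does not match the call.
        if len(ov.params) != len(args):
            continue
        # Step 2: deduce placeholders parameter by parameter.
        mapping = {}
        ok = True
        for param, arg in zip(ov.params, args):
            if param in ov.placeholders:
                # A second, different value for the same placeholder conflicts.
                if mapping.setdefault(param, arg) != arg:
                    ok = False
                    break
            elif param != arg:
                # Non-templated parameters must match the argument type.
                ok = False
                break
        # Step 3: discard overloads with conflicts or unresolved placeholders.
        if not ok or any(ph not in mapping for ph in ov.placeholders):
            continue
        # Step 4: substitute the deduced types into the signature.
        results.append(
            Overload(
                placeholders=(),
                return_type=mapping.get(ov.return_type, ov.return_type),
                params=tuple(mapping.get(p, p) for p in ov.params),
            )
        )
    return results


add2 = Overload(("T",), "T", ("T", "T"))
add3 = Overload(("T",), "T", ("T", "T", "T"))
print(specialize([add2, add3], ("int", "int")))
```

In this toy model, a return-only template such as `Overload(("T",), "T", ())` passes the arity check for a zero-argument call but is dropped at step 3 because `T` is never deduced.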

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🐰

I hopped through types and patterns bright,
I nudged placeholders into the light,
Arguments matched, templates became new,
Functions specialized, tidy and true,
A happy nibble — deduction's through!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.31% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title 'Add template argument deduction module and tests' directly and concisely describes the main change: introducing a new deduction module and its corresponding tests. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 195-262: The except block in deduce_templated_overloads currently
catches Exception when calling compute_intent_plan; narrow this to only the
documented exceptions by changing the handler to except (ValueError, TypeError)
as exc so only intent-related ValueError/TypeError cases are appended to
intent_errors; locate the try/except around compute_intent_plan (call site:
compute_intent_plan, variables plan/intent_errors) and replace the broad
Exception catch with the tuple of specific exception types.
- Around line 64-69: The helper _numba_arg_to_cxx_type currently catches all
Exceptions when calling to_c_type_str(arg), which can hide programming errors;
change the broad except to catch only ValueError (the error to_c_type_str raises
for unknown Numba types) so that other exceptions (e.g.,
AttributeError/TypeError) still surface; keep the behavior of normalizing the
arg via _normalize_numba_arg_type and returning None on ValueError after
attempting to normalize the C++ type with _normalize_cxx_type_str.
- Around line 72-77: _param_type_matches_arg calls to_numba_type(cxx_type) which
can raise KeyError for unknown C++ types; wrap the to_numba_type call in a
try/except and handle the KeyError by treating unknown mappings the same way as
nbtypes.undefined (i.e., return True or otherwise mark as compatible). Update
the function _param_type_matches_arg to catch KeyError from to_numba_type, log
or comment if desired, and then fallback to the existing branch that returns
True for undefined types; keep using _normalize_numba_arg_type(arg) for the
final comparison.
- Around line 80-105: _deduce_from_type_pattern currently records each
placeholder only once in order while pattern replacement creates a capture group
for every occurrence; update the logic so order records the placeholder for each
occurrence (i.e., append ph each time you replace it) so that match.groups() can
be zipped to every placeholder occurrence and conflicting values are checked
against previously deduced values. Concretely, iterate through cxx_type
occurrences (or perform the replacements with a callback) and for each found
placeholder add a corresponding r"(.*?)" to the regex and append that
placeholder to order; then keep the existing match, strip, emptiness check and
conflict detection against deduced in the same way.

Comment on lines +64 to +69
def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None:
    arg = _normalize_numba_arg_type(arg)
    try:
        return _normalize_cxx_type_str(to_c_type_str(arg))
    except Exception:
        return None

🧹 Nitpick | 🔵 Trivial

Narrow the exception handling to ValueError.

Based on the to_c_type_str implementation in numbast/types.py, it raises ValueError for unknown Numba types. Catching bare Exception could mask unrelated programming errors (e.g., AttributeError, TypeError).

♻️ Proposed fix
 def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None:
     arg = _normalize_numba_arg_type(arg)
     try:
         return _normalize_cxx_type_str(to_c_type_str(arg))
-    except Exception:
+    except ValueError:
         return None
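
A small self-contained demonstration of why the narrower handler matters. The names here (`KNOWN`, `to_c`, `arg_to_cxx_broad`, `arg_to_cxx_narrow`) are hypothetical stand-ins, not the numbast functions; the only behavior taken from the comment above is that unknown types raise ValueError.

```python
# Hypothetical lookup table standing in for the real type maps.
KNOWN = {"int32": "int", "float32": "float"}


def to_c(name):
    """Stand-in converter: ValueError for unknown names, as in the comment."""
    key = name.lower()  # raises AttributeError if name is not a string
    if key not in KNOWN:
        raise ValueError(f"unknown type: {name}")
    return KNOWN[key]


def arg_to_cxx_broad(name):
    try:
        return to_c(name)
    except Exception:  # also swallows AttributeError from a caller bug
        return None


def arg_to_cxx_narrow(name):
    try:
        return to_c(name)
    except ValueError:  # only the expected "unknown type" case
        return None


# Passing None is a programming error, not an unknown type.
print(arg_to_cxx_broad(None))  # silently looks like "unmappable"
try:
    arg_to_cxx_narrow(None)
except AttributeError as exc:
    print("surfaced:", type(exc).__name__)
```

With the broad handler, a caller bug is indistinguishable from an unmappable argument and the overload is silently skipped; with the narrow handler the bug propagates.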

Comment on lines +72 to +77
def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
    """Best-effort compatibility check for non-templated parameters."""
    nb_expected = to_numba_type(cxx_type)
    if nb_expected is nbtypes.undefined:
        return True
    return nb_expected == _normalize_numba_arg_type(arg)

⚠️ Potential issue | 🟡 Minor

Missing exception handling for unknown C++ types.

to_numba_type can raise KeyError when cxx_type is not in CTYPE_MAPS. This could cause the deduction process to fail unexpectedly for unsupported types. Consider handling the exception gracefully, similar to _numba_arg_to_cxx_type.

🐛 Proposed fix
 def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
     """Best-effort compatibility check for non-templated parameters."""
-    nb_expected = to_numba_type(cxx_type)
+    try:
+        nb_expected = to_numba_type(cxx_type)
+    except KeyError:
+        # Unknown C++ type; assume compatible as a best-effort fallback.
+        return True
     if nb_expected is nbtypes.undefined:
         return True
     return nb_expected == _normalize_numba_arg_type(arg)

Comment on lines +80 to +105
def _deduce_from_type_pattern(
    cxx_type: str, arg_cxx: str, placeholders: Iterable[str]
) -> dict[str, str] | None:
    placeholders_in_type = [p for p in placeholders if p in cxx_type]
    if not placeholders_in_type:
        return None

    pattern = re.escape(cxx_type)
    order: list[str] = []
    for ph in placeholders_in_type:
        pattern = pattern.replace(re.escape(ph), r"(.*?)")
        order.append(ph)

    match = re.fullmatch(pattern, arg_cxx)
    if not match:
        return None

    deduced: dict[str, str] = {}
    for ph, value in zip(order, match.groups()):
        value = value.strip()
        if not value:
            return None
        if ph in deduced and deduced[ph] != value:
            return None
        deduced[ph] = value
    return deduced

⚠️ Potential issue | 🟠 Major

Pattern matching bug when placeholder appears multiple times in a type.

When a placeholder appears multiple times in cxx_type (e.g., "pair<T, T>" or "Map<T, T>"), str.replace creates multiple capture groups, but order only contains one entry per unique placeholder. The zip truncates to the shorter iterable, ignoring subsequent occurrences.

Example: cxx_type="pair<T, T>", arg_cxx="pair<int, float>"

  • Pattern becomes pair<(.*?), (.*?)> with 2 groups
  • order = ["T"]
  • Match groups: ("int", "float")
  • zip(["T"], ("int", "float")) yields only ("T", "int")
  • Returns {"T": "int"} — incorrectly deduces despite conflicting values

The second occurrence float is never checked against the first deduction int.

🐛 Proposed fix
 def _deduce_from_type_pattern(
     cxx_type: str, arg_cxx: str, placeholders: Iterable[str]
 ) -> dict[str, str] | None:
     placeholders_in_type = [p for p in placeholders if p in cxx_type]
     if not placeholders_in_type:
         return None
 
     pattern = re.escape(cxx_type)
     order: list[str] = []
     for ph in placeholders_in_type:
-        pattern = pattern.replace(re.escape(ph), r"(.*?)")
-        order.append(ph)
+        escaped_ph = re.escape(ph)
+        # Count occurrences and replace each, tracking in order
+        count = pattern.count(escaped_ph)
+        pattern = pattern.replace(escaped_ph, r"(.*?)")
+        order.extend([ph] * count)
 
     match = re.fullmatch(pattern, arg_cxx)
     if not match:
         return None
 
     deduced: dict[str, str] = {}
     for ph, value in zip(order, match.groups()):
         value = value.strip()
         if not value:
             return None
         if ph in deduced and deduced[ph] != value:
             return None
         deduced[ph] = value
     return deduced
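
One point worth checking in any fix: capture groups are numbered by position in the pattern, so `order` has to follow the same left-to-right sequence. Extending `order` per placeholder lists occurrences grouped by name, which may not line up with the groups when two placeholders interleave (for example `Map<K, V, K>`). A minimal sketch of a position-ordered variant follows; `deduce_from_type_pattern` is a hypothetical rewrite, and the use of word boundaries (so `T` does not match inside `Tensor`) is an added assumption rather than something the original code does.

```python
import re
from typing import Dict, Iterable, Optional


def deduce_from_type_pattern(
    cxx_type: str, arg_cxx: str, placeholders: Iterable[str]
) -> Optional[Dict[str, str]]:
    # Keep only placeholders that occur as whole identifiers in the pattern.
    names = [
        p for p in placeholders
        if re.search(rf"\b{re.escape(p)}\b", cxx_type)
    ]
    if not names:
        return None

    # One alternation over all placeholders, scanned left to right.
    token = re.compile("|".join(rf"\b{re.escape(p)}\b" for p in names))

    parts = []   # regex fragments, in order
    order = []   # placeholder name for each capture group, in order
    pos = 0
    for m in token.finditer(cxx_type):
        parts.append(re.escape(cxx_type[pos:m.start()]))
        parts.append(r"(.*?)")
        order.append(m.group(0))
        pos = m.end()
    parts.append(re.escape(cxx_type[pos:]))

    match = re.fullmatch("".join(parts), arg_cxx)
    if not match:
        return None

    deduced: Dict[str, str] = {}
    for ph, value in zip(order, match.groups()):
        value = value.strip()
        if not value:
            return None
        # Every occurrence must agree with the first deduced value.
        if deduced.setdefault(ph, value) != value:
            return None
    return deduced


print(deduce_from_type_pattern("pair<T, T>", "pair<int, float>", ["T"]))
print(deduce_from_type_pattern("Map<K, V, K>", "Map<int, float, int>", ["K", "V"]))
```

Because `order` is built during the same scan that emits each capture group, the two stay aligned regardless of how many placeholders appear or in what sequence. Nested template arguments that contain commas would still need a real parser.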

Comment on lines +195 to +262
def deduce_templated_overloads(
    *,
    qualname: str,
    overloads: list[FunctionTemplate],
    args: tuple[nbtypes.Type, ...],
    overrides: dict | None = None,
    debug: bool | None = None,
) -> tuple[list[FunctionTemplate], list[Exception]]:
    """
    Perform template argument deduction for templated method overloads.
    Returns a list of FunctionTemplate objects with fully-specialized
    Function/Method types, plus any arg_intent-related errors encountered
    while computing visible arity.
    Enable debug output by passing debug=True or setting the
    NUMBAST_TAD_DEBUG=1 environment variable.
    """
    specialized: list[FunctionTemplate] = []
    intent_errors: list[Exception] = []

    _debug_print(
        debug,
        f"begin: {qualname}, overloads={len(overloads)}, args={len(args)}, "
        f"overrides={'yes' if overrides else 'no'}",
    )

    for idx, templ in enumerate(overloads):
        _debug_print(
            debug,
            f"overload[{idx}] {templ.function.name}: "
            f"params={[p.type_.unqualified_non_ref_type_name for p in templ.function.params]}",
        )
        if overrides is None:
            visible_param_indices = tuple(range(len(templ.function.params)))
            pass_ptr_mask = tuple(False for _ in visible_param_indices)
        else:
            try:
                plan = compute_intent_plan(
                    params=templ.function.params,
                    param_types=templ.function.param_types,
                    overrides=overrides,
                    allow_out_return=True,
                )
            except Exception as exc:
                intent_errors.append(exc)
                _debug_print(
                    debug,
                    f"  intent plan error: {exc}",
                )
                continue
            visible_param_indices = plan.visible_param_indices
            pass_ptr_mask = plan.pass_ptr_mask
            _debug_print(
                debug,
                "  intent plan: "
                f"visible={plan.visible_param_indices}, "
                f"out_return={plan.out_return_indices}, "
                f"pass_ptr={plan.pass_ptr_mask}",
            )

        if len(visible_param_indices) != len(args):
            _debug_print(
                debug,
                "  skip: visible arity mismatch "
                f"(visible={len(visible_param_indices)} vs args={len(args)})",
            )
            continue

🧹 Nitpick | 🔵 Trivial

LGTM with optional refinement.

The main deduction function is well-structured with comprehensive debug logging. The control flow correctly handles arity filtering, intent planning, and error collection.

Consider narrowing the exception handling at line 239 to (ValueError, TypeError) based on compute_intent_plan's documented exceptions, though the current broad catch is defensible for surfacing all intent-related errors.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI script currently calls "sudo apt-get update" and "sudo
apt-get install -y g++" which fails when sudo isn't present; update the install
block to check for root (use [ "$(id -u)" -eq 0 ]), and if running as root run
"apt-get update" and "apt-get install -y g++", otherwise fall back to "sudo
apt-get update" and "sudo apt-get install -y g++" so the commands succeed both
in root CI containers and non-root environments.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI install step can hang on interactive prompts; update
the apt-get invocation in ci/check_style.sh by setting the
DEBIAN_FRONTEND=noninteractive environment variable when running apt-get install
so it runs non-interactively (e.g., prefix the apt-get install -y g++ command
with DEBIAN_FRONTEND=noninteractive) while keeping apt-get update and the -y
flag intact; this ensures apt-get update and apt-get install -y g++ complete
reliably in the CI container.
