-
Notifications
You must be signed in to change notification settings - Fork 17
Add template argument deduction module and tests #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
📝 WalkthroughWalkthroughImplements a new template overload deduction engine that maps Numba types to C++-style strings, deduces template parameters from argument types, specializes templated functions and struct methods, adds tests, updates pointer type handling, and tweaks CI for g++ availability. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant Deductor as deduce_templated_overloads
participant IntentPlanner as Intent Planner
participant TypeDeducer as Type Pattern Deducer
participant Specializer as Specializer
participant Validator as Placeholder Validator
Caller->>Deductor: provide overloads, args, overrides, debug
activate Deductor
alt overrides provided
Deductor->>IntentPlanner: compute intent plan (visible params)
IntentPlanner-->>Deductor: visible params + intent errors
else no overrides
Deductor->>Deductor: use all function parameters
end
loop per overload
Deductor->>Deductor: validate arity against visible params
alt arity matches
loop per visible parameter
Deductor->>TypeDeducer: match param pattern with arg type
TypeDeducer-->>Deductor: mapping fragment or conflict
end
Deductor->>Specializer: apply combined mappings to template
Specializer-->>Deductor: specialized function/struct method
Deductor->>Validator: check for unresolved placeholders
Validator-->>Deductor: resolved? accept : discard
else arity mismatch
Deductor-->>Deductor: skip overload
end
end
Deductor-->>Caller: return specialized templates + collected errors
deactivate Deductor
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 195-262: The except block in deduce_templated_overloads currently
catches Exception when calling compute_intent_plan; narrow this to only the
documented exceptions by changing the handler to except (ValueError, TypeError)
as exc so only intent-related ValueError/TypeError cases are appended to
intent_errors; locate the try/except around compute_intent_plan (call site:
compute_intent_plan, variables plan/intent_errors) and replace the broad
Exception catch with the tuple of specific exception types.
- Around line 64-69: The helper _numba_arg_to_cxx_type currently catches all
Exceptions when calling to_c_type_str(arg), which can hide programming errors;
change the broad except to catch only ValueError (the error to_c_type_str raises
for unknown Numba types) so that other exceptions (e.g.,
AttributeError/TypeError) still surface; keep the behavior of normalizing the
arg via _normalize_numba_arg_type and returning None on ValueError after
attempting to normalize the C++ type with _normalize_cxx_type_str.
- Around line 72-77: _param_type_matches_arg calls to_numba_type(cxx_type) which
can raise KeyError for unknown C++ types; wrap the to_numba_type call in a
try/except and handle the KeyError by treating unknown mappings the same way as
nbtypes.undefined (i.e., return True or otherwise mark as compatible). Update
the function _param_type_matches_arg to catch KeyError from to_numba_type, log
or comment if desired, and then fallback to the existing branch that returns
True for undefined types; keep using _normalize_numba_arg_type(arg) for the
final comparison.
- Around line 80-105: _deduce_from_type_pattern currently records each
placeholder only once in order while pattern replacement creates a capture group
for every occurrence; update the logic so order records the placeholder for each
occurrence (i.e., append ph each time you replace it) so that match.groups() can
be zipped to every placeholder occurrence and conflicting values are checked
against previously deduced values. Concretely, iterate through cxx_type
occurrences (or perform the replacements with a callback) and for each found
placeholder add a corresponding r"(.*?)" to the regex and append that
placeholder to order; then keep the existing match, strip, emptiness check and
conflict detection against deduced in the same way.
| def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None: | ||
| arg = _normalize_numba_arg_type(arg) | ||
| try: | ||
| return _normalize_cxx_type_str(to_c_type_str(arg)) | ||
| except Exception: | ||
| return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧹 Nitpick | 🔵 Trivial
Narrow the exception handling to ValueError.
Based on the to_c_type_str implementation in numbast/types.py, it raises ValueError for unknown Numba types. Catching bare Exception could mask unrelated programming errors (e.g., AttributeError, TypeError).
♻️ Proposed fix
def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None:
arg = _normalize_numba_arg_type(arg)
try:
return _normalize_cxx_type_str(to_c_type_str(arg))
- except Exception:
+ except ValueError:
return None📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None: | |
| arg = _normalize_numba_arg_type(arg) | |
| try: | |
| return _normalize_cxx_type_str(to_c_type_str(arg)) | |
| except Exception: | |
| return None | |
| def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None: | |
| arg = _normalize_numba_arg_type(arg) | |
| try: | |
| return _normalize_cxx_type_str(to_c_type_str(arg)) | |
| except ValueError: | |
| return None |
🤖 Prompt for AI Agents
In `@numbast/src/numbast/deduction.py` around lines 64 - 69, The helper
_numba_arg_to_cxx_type currently catches all Exceptions when calling
to_c_type_str(arg), which can hide programming errors; change the broad except
to catch only ValueError (the error to_c_type_str raises for unknown Numba
types) so that other exceptions (e.g., AttributeError/TypeError) still surface;
keep the behavior of normalizing the arg via _normalize_numba_arg_type and
returning None on ValueError after attempting to normalize the C++ type with
_normalize_cxx_type_str.
| def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool: | ||
| """Best-effort compatibility check for non-templated parameters.""" | ||
| nb_expected = to_numba_type(cxx_type) | ||
| if nb_expected is nbtypes.undefined: | ||
| return True | ||
| return nb_expected == _normalize_numba_arg_type(arg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing exception handling for unknown C++ types.
to_numba_type can raise KeyError when cxx_type is not in CTYPE_MAPS. This could cause the deduction process to fail unexpectedly for unsupported types. Consider handling the exception gracefully, similar to _numba_arg_to_cxx_type.
🐛 Proposed fix
def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
"""Best-effort compatibility check for non-templated parameters."""
- nb_expected = to_numba_type(cxx_type)
+ try:
+ nb_expected = to_numba_type(cxx_type)
+ except KeyError:
+ # Unknown C++ type; assume compatible as a best-effort fallback.
+ return True
if nb_expected is nbtypes.undefined:
return True
return nb_expected == _normalize_numba_arg_type(arg)🤖 Prompt for AI Agents
In `@numbast/src/numbast/deduction.py` around lines 72 - 77,
_param_type_matches_arg calls to_numba_type(cxx_type) which can raise KeyError
for unknown C++ types; wrap the to_numba_type call in a try/except and handle
the KeyError by treating unknown mappings the same way as nbtypes.undefined
(i.e., return True or otherwise mark as compatible). Update the function
_param_type_matches_arg to catch KeyError from to_numba_type, log or comment if
desired, and then fallback to the existing branch that returns True for
undefined types; keep using _normalize_numba_arg_type(arg) for the final
comparison.
| def _deduce_from_type_pattern( | ||
| cxx_type: str, arg_cxx: str, placeholders: Iterable[str] | ||
| ) -> dict[str, str] | None: | ||
| placeholders_in_type = [p for p in placeholders if p in cxx_type] | ||
| if not placeholders_in_type: | ||
| return None | ||
|
|
||
| pattern = re.escape(cxx_type) | ||
| order: list[str] = [] | ||
| for ph in placeholders_in_type: | ||
| pattern = pattern.replace(re.escape(ph), r"(.*?)") | ||
| order.append(ph) | ||
|
|
||
| match = re.fullmatch(pattern, arg_cxx) | ||
| if not match: | ||
| return None | ||
|
|
||
| deduced: dict[str, str] = {} | ||
| for ph, value in zip(order, match.groups()): | ||
| value = value.strip() | ||
| if not value: | ||
| return None | ||
| if ph in deduced and deduced[ph] != value: | ||
| return None | ||
| deduced[ph] = value | ||
| return deduced |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pattern matching bug when placeholder appears multiple times in a type.
When a placeholder appears multiple times in cxx_type (e.g., "pair<T, T>" or "Map<T, T>"), str.replace creates multiple capture groups, but order only contains one entry per unique placeholder. The zip truncates to the shorter iterable, ignoring subsequent occurrences.
Example: cxx_type="pair<T, T>", arg_cxx="pair<int, float>"
- Pattern becomes
pair<(.*?), (.*?)>with 2 groups order = ["T"]- Match groups:
("int", "float") zip(["T"], ("int", "float"))yields only("T", "int")- Returns
{"T": "int"}— incorrectly deduces despite conflicting values
The second occurrence float is never checked against the first deduction int.
🐛 Proposed fix
def _deduce_from_type_pattern(
cxx_type: str, arg_cxx: str, placeholders: Iterable[str]
) -> dict[str, str] | None:
placeholders_in_type = [p for p in placeholders if p in cxx_type]
if not placeholders_in_type:
return None
pattern = re.escape(cxx_type)
order: list[str] = []
for ph in placeholders_in_type:
- pattern = pattern.replace(re.escape(ph), r"(.*?)")
- order.append(ph)
+ escaped_ph = re.escape(ph)
+ # Count occurrences and replace each, tracking in order
+ count = pattern.count(escaped_ph)
+ pattern = pattern.replace(escaped_ph, r"(.*?)")
+ order.extend([ph] * count)
match = re.fullmatch(pattern, arg_cxx)
if not match:
return None
deduced: dict[str, str] = {}
for ph, value in zip(order, match.groups()):
value = value.strip()
if not value:
return None
if ph in deduced and deduced[ph] != value:
return None
deduced[ph] = value
return deduced🤖 Prompt for AI Agents
In `@numbast/src/numbast/deduction.py` around lines 80 - 105,
_deduce_from_type_pattern currently records each placeholder only once in order
while pattern replacement creates a capture group for every occurrence; update
the logic so order records the placeholder for each occurrence (i.e., append ph
each time you replace it) so that match.groups() can be zipped to every
placeholder occurrence and conflicting values are checked against previously
deduced values. Concretely, iterate through cxx_type occurrences (or perform the
replacements with a callback) and for each found placeholder add a corresponding
r"(.*?)" to the regex and append that placeholder to order; then keep the
existing match, strip, emptiness check and conflict detection against deduced in
the same way.
| def deduce_templated_overloads( | ||
| *, | ||
| qualname: str, | ||
| overloads: list[FunctionTemplate], | ||
| args: tuple[nbtypes.Type, ...], | ||
| overrides: dict | None = None, | ||
| debug: bool | None = None, | ||
| ) -> tuple[list[FunctionTemplate], list[Exception]]: | ||
| """ | ||
| Perform template argument deduction for templated method overloads. | ||
| Returns a list of FunctionTemplate objects with fully-specialized | ||
| Function/Method types, plus any arg_intent-related errors encountered | ||
| while computing visible arity. | ||
| Enable debug output by passing debug=True or setting the | ||
| NUMBAST_TAD_DEBUG=1 environment variable. | ||
| """ | ||
| specialized: list[FunctionTemplate] = [] | ||
| intent_errors: list[Exception] = [] | ||
|
|
||
| _debug_print( | ||
| debug, | ||
| f"begin: {qualname}, overloads={len(overloads)}, args={len(args)}, " | ||
| f"overrides={'yes' if overrides else 'no'}", | ||
| ) | ||
|
|
||
| for idx, templ in enumerate(overloads): | ||
| _debug_print( | ||
| debug, | ||
| f"overload[{idx}] {templ.function.name}: " | ||
| f"params={[p.type_.unqualified_non_ref_type_name for p in templ.function.params]}", | ||
| ) | ||
| if overrides is None: | ||
| visible_param_indices = tuple(range(len(templ.function.params))) | ||
| pass_ptr_mask = tuple(False for _ in visible_param_indices) | ||
| else: | ||
| try: | ||
| plan = compute_intent_plan( | ||
| params=templ.function.params, | ||
| param_types=templ.function.param_types, | ||
| overrides=overrides, | ||
| allow_out_return=True, | ||
| ) | ||
| except Exception as exc: | ||
| intent_errors.append(exc) | ||
| _debug_print( | ||
| debug, | ||
| f" intent plan error: {exc}", | ||
| ) | ||
| continue | ||
| visible_param_indices = plan.visible_param_indices | ||
| pass_ptr_mask = plan.pass_ptr_mask | ||
| _debug_print( | ||
| debug, | ||
| " intent plan: " | ||
| f"visible={plan.visible_param_indices}, " | ||
| f"out_return={plan.out_return_indices}, " | ||
| f"pass_ptr={plan.pass_ptr_mask}", | ||
| ) | ||
|
|
||
| if len(visible_param_indices) != len(args): | ||
| _debug_print( | ||
| debug, | ||
| " skip: visible arity mismatch " | ||
| f"(visible={len(visible_param_indices)} vs args={len(args)})", | ||
| ) | ||
| continue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧹 Nitpick | 🔵 Trivial
LGTM with optional refinement.
The main deduction function is well-structured with comprehensive debug logging. The control flow correctly handles arity filtering, intent planning, and error collection.
Consider narrowing the exception handling at line 239 to (ValueError, TypeError) based on compute_intent_plan's documented exceptions, though the current broad catch is defensible for surfacing all intent-related errors.
🤖 Prompt for AI Agents
In `@numbast/src/numbast/deduction.py` around lines 195 - 262, The except block in
deduce_templated_overloads currently catches Exception when calling
compute_intent_plan; narrow this to only the documented exceptions by changing
the handler to except (ValueError, TypeError) as exc so only intent-related
ValueError/TypeError cases are appended to intent_errors; locate the try/except
around compute_intent_plan (call site: compute_intent_plan, variables
plan/intent_errors) and replace the broad Exception catch with the tuple of
specific exception types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI script currently calls "sudo apt-get update" and "sudo
apt-get install -y g++" which fails when sudo isn't present; update the install
block to check for root (use [ "$(id -u)" -eq 0 ]), and if running as root run
"apt-get update" and "apt-get install -y g++", otherwise fall back to "sudo
apt-get update" and "sudo apt-get install -y g++" so the commands succeed both
in root CI containers and non-root environments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI install step can hang on interactive prompts; update
the apt-get invocation in ci/check_style.sh by setting the
DEBIAN_FRONTEND=noninteractive environment variable when running apt-get install
so it runs non-interactively (e.g., prefix the apt-get install -y g++ command
with DEBIAN_FRONTEND=noninteractive) while keeping apt-get update and the -y
flag intact; this ensures apt-get update and apt-get install -y g++ complete
reliably in the CI container.
This adds a template argument deduction module and unit tests, and it incorporates the recently added argument intent concept from Numbast PR #278 so visible arity and pointer passing stay consistent across templated overloads.
At a high level, the module accepts a list of
ast_canopy-parsed function templates plus optional intent overrides, and returnsFunctionTemplateinstances with deduced argument types for the original parameter names.Tested cases:
template <typename T> __device__ T add(T a, T b);andtemplate <typename T> __device__ T add(T a, T b, T c);add(int, int)picks the two-arg overload and deducesT=int.template <typename T> __device__ T add(T a, T b);add(int, float)yields no specialization becauseTconflicts.template <typename T> __device__ T add_int(int a, T b);add_int(int, float)specializes,add_int(float, float)does not.template <typename T> __device__ T return_only();return_only<T>()cannot deduceTwith no args, so it is skipped.template <typename T> __device__ void store_ref(T &out, T value);store_ref(CPointer(int), int)withoverrides={"out": "out_ptr"}or{"out": "inout_ptr"}deducesT=int;store_ref(int)withoverrides={"out": "out_return"}also deducesT=intvia hidden out param.template <typename T> __device__ void bad_out(T value);bad_out(value=out_ptr)reports aValueErrorinintent_errors.struct Box { template <typename T> __device__ T mul(T a, T b) const; };Box::mul(float, float)deducesT=float.template <typename T> __device__ T add(T a, T b);add(float32[:], float32[:])yields no specialization.This currently is a standalone module that's not wired into any existing binding generation so that it can be tested and reasoned independently.
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Tests
Chores
✏️ Tip: You can customize this high-level summary in your review settings.