Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,26 @@ class TypeHintFixer(ast.NodeTransformer):
"size_t": ast.Name("int"),
"bool": ast.Name("bool"),
"boolean": ast.Name("bool"),
"std::unique_ptr< amici::Solver >": ast.Constant("Solver"),
"amici::InternalSensitivityMethod": ast.Constant(
"std::unique_ptr< amici::Solver >": ast.Name("Solver"),
"amici::InternalSensitivityMethod": ast.Name(
"InternalSensitivityMethod"
),
"amici::InterpolationType": ast.Constant("InterpolationType"),
"amici::LinearMultistepMethod": ast.Constant("LinearMultistepMethod"),
"amici::LinearSolver": ast.Constant("LinearSolver"),
"amici::Model *": ast.Constant("Model"),
"amici::Model const *": ast.Constant("Model"),
"amici::NewtonDampingFactorMode": ast.Constant(
"NewtonDampingFactorMode"
),
"amici::NonlinearSolverIteration": ast.Constant(
"amici::InterpolationType": ast.Name("InterpolationType"),
"amici::LinearMultistepMethod": ast.Name("LinearMultistepMethod"),
"amici::LinearSolver": ast.Name("LinearSolver"),
"amici::Model *": ast.Name("Model"),
"amici::Model const *": ast.Name("Model"),
"amici::NewtonDampingFactorMode": ast.Name("NewtonDampingFactorMode"),
"amici::NonlinearSolverIteration": ast.Name(
"NonlinearSolverIteration"
),
"amici::ObservableScaling": ast.Constant("ObservableScaling"),
"amici::ParameterScaling": ast.Constant("ParameterScaling"),
"amici::RDataReporting": ast.Constant("RDataReporting"),
"amici::SensitivityMethod": ast.Constant("SensitivityMethod"),
"amici::SensitivityOrder": ast.Constant("SensitivityOrder"),
"amici::Solver *": ast.Constant("Solver"),
"amici::SteadyStateSensitivityMode": ast.Constant(
"amici::ObservableScaling": ast.Name("ObservableScaling"),
"amici::ParameterScaling": ast.Name("ParameterScaling"),
"amici::RDataReporting": ast.Name("RDataReporting"),
"amici::SensitivityMethod": ast.Name("SensitivityMethod"),
"amici::SensitivityOrder": ast.Name("SensitivityOrder"),
"amici::Solver *": ast.Name("Solver"),
"amici::SteadyStateSensitivityMode": ast.Name(
"SteadyStateSensitivityMode"
),
"amici::realtype": ast.Name("float"),
Expand All @@ -49,15 +47,23 @@ class TypeHintFixer(ast.NodeTransformer):
"StringVector": ast.Name("Sequence[str]"),
"std::string": ast.Name("str"),
"std::string const &": ast.Name("str"),
"std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"),
"std::unique_ptr< amici::ReturnData >": ast.Constant("ReturnData"),
"std::unique_ptr< amici::ExpData >": ast.Name("ExpData"),
"std::unique_ptr< amici::ReturnData >": ast.Name("ReturnData"),
"std::vector< amici::ParameterScaling,"
"std::allocator< amici::ParameterScaling > > const &": ast.Constant(
"std::allocator< amici::ParameterScaling > > const &": ast.Name(
"ParameterScalingVector"
),
"H5::H5File": None,
}

def __init__(self):
super().__init__()

# Add all mapped-to type names to the mapping dict to convert any
# quoted occurrences of those types to unquoted types
for annot in list(self.mapping.values()):
if isinstance(annot, ast.Name):
self.mapping[annot.id] = annot

def visit_FunctionDef(self, node):
# convert type/rtype from docstring to annotation, if possible.
# those may be c++ types, not valid in python, that need to be
Expand Down Expand Up @@ -140,7 +146,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
for line_no, line in enumerate(docstring):
if type_str := self.extract_rtype(line):
# handle `:rtype:`
node.returns = ast.Constant(type_str)
node.returns = self.mapping.get(
type_str, ast.Constant(type_str)
)
lines_to_remove.add(line_no)
continue

Expand All @@ -149,7 +157,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
# handle `:type ...:`
for arg in node.args.args:
if arg.arg == arg_name:
arg.annotation = ast.Constant(type_str)
arg.annotation = self.mapping.get(
type_str, ast.Name(type_str)
)
lines_to_remove.add(line_no)

if lines_to_remove:
Expand All @@ -160,7 +170,7 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
for line_no, line in enumerate(docstring)
if line_no not in lines_to_remove
)
node.body[0].value = ast.Str(new_docstring)
node.body[0].value = ast.Constant(new_docstring)

@staticmethod
def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:
Expand Down
Loading