diff --git a/python/sdist/amici/swig.py b/python/sdist/amici/swig.py index e0c518957f..703d4eeffc 100644 --- a/python/sdist/amici/swig.py +++ b/python/sdist/amici/swig.py @@ -18,28 +18,26 @@ class TypeHintFixer(ast.NodeTransformer): "size_t": ast.Name("int"), "bool": ast.Name("bool"), "boolean": ast.Name("bool"), - "std::unique_ptr< amici::Solver >": ast.Constant("Solver"), - "amici::InternalSensitivityMethod": ast.Constant( + "std::unique_ptr< amici::Solver >": ast.Name("Solver"), + "amici::InternalSensitivityMethod": ast.Name( "InternalSensitivityMethod" ), - "amici::InterpolationType": ast.Constant("InterpolationType"), - "amici::LinearMultistepMethod": ast.Constant("LinearMultistepMethod"), - "amici::LinearSolver": ast.Constant("LinearSolver"), - "amici::Model *": ast.Constant("Model"), - "amici::Model const *": ast.Constant("Model"), - "amici::NewtonDampingFactorMode": ast.Constant( - "NewtonDampingFactorMode" - ), - "amici::NonlinearSolverIteration": ast.Constant( + "amici::InterpolationType": ast.Name("InterpolationType"), + "amici::LinearMultistepMethod": ast.Name("LinearMultistepMethod"), + "amici::LinearSolver": ast.Name("LinearSolver"), + "amici::Model *": ast.Name("Model"), + "amici::Model const *": ast.Name("Model"), + "amici::NewtonDampingFactorMode": ast.Name("NewtonDampingFactorMode"), + "amici::NonlinearSolverIteration": ast.Name( "NonlinearSolverIteration" ), - "amici::ObservableScaling": ast.Constant("ObservableScaling"), - "amici::ParameterScaling": ast.Constant("ParameterScaling"), - "amici::RDataReporting": ast.Constant("RDataReporting"), - "amici::SensitivityMethod": ast.Constant("SensitivityMethod"), - "amici::SensitivityOrder": ast.Constant("SensitivityOrder"), - "amici::Solver *": ast.Constant("Solver"), - "amici::SteadyStateSensitivityMode": ast.Constant( + "amici::ObservableScaling": ast.Name("ObservableScaling"), + "amici::ParameterScaling": ast.Name("ParameterScaling"), + "amici::RDataReporting": ast.Name("RDataReporting"), + "amici::SensitivityMethod": ast.Name("SensitivityMethod"), + "amici::SensitivityOrder": ast.Name("SensitivityOrder"), + "amici::Solver *": ast.Name("Solver"), + "amici::SteadyStateSensitivityMode": ast.Name( "SteadyStateSensitivityMode" ), "amici::realtype": ast.Name("float"), @@ -49,15 +47,23 @@ class TypeHintFixer(ast.NodeTransformer): "StringVector": ast.Name("Sequence[str]"), "std::string": ast.Name("str"), "std::string const &": ast.Name("str"), - "std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"), - "std::unique_ptr< amici::ReturnData >": ast.Constant("ReturnData"), + "std::unique_ptr< amici::ExpData >": ast.Name("ExpData"), + "std::unique_ptr< amici::ReturnData >": ast.Name("ReturnData"), "std::vector< amici::ParameterScaling," - "std::allocator< amici::ParameterScaling > > const &": ast.Constant( + "std::allocator< amici::ParameterScaling > > const &": ast.Name( "ParameterScalingVector" ), - "H5::H5File": None, } + def __init__(self): + super().__init__() + + # Add all mapped-to type names to the mapping dict to convert any + # quoted occurrences of those types to unquoted types + for annot in list(self.mapping.values()): + if isinstance(annot, ast.Name): + self.mapping[annot.id] = annot + def visit_FunctionDef(self, node): # convert type/rtype from docstring to annotation, if possible. # those may be c++ types, not valid in python, that need to be @@ -140,7 +146,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): for line_no, line in enumerate(docstring): if type_str := self.extract_rtype(line): # handle `:rtype:` - node.returns = ast.Constant(type_str) + node.returns = self.mapping.get( + type_str, ast.Constant(type_str) + ) lines_to_remove.add(line_no) continue @@ -149,7 +157,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): # handle `:type ...:` for arg in node.args.args: if arg.arg == arg_name: - arg.annotation = ast.Constant(type_str) + arg.annotation = self.mapping.get( + type_str, ast.Name(type_str) + ) lines_to_remove.add(line_no) if lines_to_remove: @@ -160,7 +170,7 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): for line_no, line in enumerate(docstring) if line_no not in lines_to_remove ) - node.body[0].value = ast.Str(new_docstring) + node.body[0].value = ast.Constant(new_docstring) @staticmethod def extract_type(line: str) -> tuple[str, str] | tuple[None, None]: