diff --git a/symforce/codegen/__init__.py b/symforce/codegen/__init__.py index 3c5b22286..91bb100e8 100644 --- a/symforce/codegen/__init__.py +++ b/symforce/codegen/__init__.py @@ -11,4 +11,5 @@ from .codegen_config import CodegenConfig from .backends.cpp.cpp_config import CppConfig +from .backends.javascript.javascript_config import JavascriptConfig from .backends.python.python_config import PythonConfig diff --git a/symforce/codegen/backends/cpp/cpp_config.py b/symforce/codegen/backends/cpp/cpp_config.py index 0a61ce886..cc30af628 100644 --- a/symforce/codegen/backends/cpp/cpp_config.py +++ b/symforce/codegen/backends/cpp/cpp_config.py @@ -72,3 +72,9 @@ def printer(self) -> CodePrinter: @staticmethod def format_data_accessor(prefix: str, index: int) -> str: return f"{prefix}.Data()[{index}]" + + @staticmethod + def format_matrix_accessor(key: str, i: int, j: int = None) -> str: + if j is None: + return f"{key}({i}, {0})" + return f"{key}({i}, {j})" diff --git a/symforce/codegen/backends/javascript/__init__.py b/symforce/codegen/backends/javascript/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/symforce/codegen/backends/javascript/javascript_config.py b/symforce/codegen/backends/javascript/javascript_config.py new file mode 100644 index 000000000..890bb9d41 --- /dev/null +++ b/symforce/codegen/backends/javascript/javascript_config.py @@ -0,0 +1,58 @@ +# ---------------------------------------------------------------------------- +# SymForce - Copyright 2022, Skydio, Inc. +# This source code is under the Apache 2.0 license found in the LICENSE file. +# ---------------------------------------------------------------------------- + +from __future__ import annotations +from dataclasses import dataclass +from pathlib import Path + +from symforce import typing as T +from symforce.codegen.codegen_config import CodegenConfig + + +CURRENT_DIR = Path(__file__).parent + + +@dataclass +class JavascriptConfig(CodegenConfig): + """ + Code generation config for the javascript backend. + + Args: + doc_comment_line_prefix: Prefix applied to each line in a docstring + line_length: Maximum allowed line length in docstrings; used for formatting docstrings. + use_eigen_types: Use eigen_lcm types for vectors instead of lists + autoformat: Run a code formatter on the generated code + matrix_is_1D: geo.Matrix symbols get formatted as a 1D array + """ + + doc_comment_line_prefix: str = " * " + line_length: int = 100 + use_eigen_types: bool = True + # NOTE(hayk): Add JS autoformatter + autoformat: bool = False + + @classmethod + def backend_name(cls) -> str: + return "javascript" + + @classmethod + def template_dir(cls) -> Path: + return CURRENT_DIR / "templates" + + def templates_to_render(self, generated_file_name: str) -> T.List[T.Tuple[str, str]]: + return [ + ("function/FUNCTION.js.jinja", f"{generated_file_name}.js"), + ] + + def printer(self) -> "sm.CodePrinter": + from symforce.codegen.printers import javascript_code_printer + + return javascript_code_printer.JavascriptCodePrinter() + + @staticmethod + def format_matrix_accessor(key: str, i: int, j: int = None) -> str: + if j is None: + return f"{key}[{i}]" + return f"{key}[{i}][{j}]" diff --git a/symforce/codegen/backends/javascript/templates/function/FUNCTION.js.jinja b/symforce/codegen/backends/javascript/templates/function/FUNCTION.js.jinja new file mode 100644 index 000000000..75db3630c --- /dev/null +++ b/symforce/codegen/backends/javascript/templates/function/FUNCTION.js.jinja @@ -0,0 +1,11 @@ +{# ------------------------------------------------------------------------- #} +{# Function codegen template for Javascript #} +{# ------------------------------------------------------------------------- #} +{%- import "../util/util.jinja" as util with context -%} + +{% if spec.docstring %} +{{ util.print_docstring(spec.docstring) }} +{% endif %} +{{ util.function_declaration(spec) -}} { +{{ util.expr_code(spec) }} +} diff --git a/symforce/codegen/backends/javascript/templates/util/util.jinja b/symforce/codegen/backends/javascript/templates/util/util.jinja new file mode 100644 index 000000000..4ae07e23e --- /dev/null +++ b/symforce/codegen/backends/javascript/templates/util/util.jinja @@ -0,0 +1,97 @@ +{# ------------------------------------------------------------------------- #} +{# Utilities for Javascript code generation templates. #} +{# ------------------------------------------------------------------------- #} + +{# ------------------------------------------------------------------------- #} + + {# Format function docstring + # + # Args: + # docstring (str): + #} +{% macro print_docstring(docstring) %} +{%- if docstring %} + +/* +{%- for line in docstring.split('\n') %} +*{{ ' {}'.format(line).rstrip() }} +{% endfor -%} +*/ +{%- endif -%} +{% endmacro %} + +{# ------------------------------------------------------------------------- #} + +{# Generate function declaration + # + # Args: + # spec (Codegen): + #} +{%- macro function_declaration(spec) -%} +function {{ camelcase_to_snakecase(spec.name) }}( + {%- for name in spec.inputs.keys() -%} + {{ name }}{% if not loop.last %}, {% endif %} + {%- endfor -%}) +{% endmacro -%} + +{# ------------------------------------------------------------------------- #} + +{# Generate inner code for computing the given expression. + # + # Args: + # spec (Codegen): + #} +{% macro expr_code(spec) %} + // Total ops: {{ spec.print_code_results.total_ops }} + + // Input arrays + {% for name, type in spec.inputs.items() %} + {% set T = python_util.get_type(type) %} + {% if not issubclass(T, Values) and not issubclass(T, Matrix) and not is_symbolic(type) and not is_sequence(type) %} + _{{ name }} = {{ name }}.data + {% endif %} + {% endfor %} + + // Intermediate terms ({{ spec.print_code_results.intermediate_terms | length }}) + {% for lhs, rhs in spec.print_code_results.intermediate_terms %} + const {{ lhs }} = {{ rhs }}; + {% endfor %} + + // Output terms ({{ spec.outputs.items() | length }}) + {% for name, type, terms in spec.print_code_results.dense_terms %} + {%- set T = python_util.get_type(type) -%} + {% if issubclass(T, Matrix) and type.shape[1] > 1 %} + {% set rows = type.shape[0] %} + {% set cols = type.shape[1] %} + let _{{ name }} = [...Array({{ rows }})].map(e => Array({{ cols }})); + {% set ns = namespace(iter=0) %} + {% for i in range(rows) %} + {% for j in range(cols) %} + _{{ name }}[{{ i }}][{{ j }}] = {{ terms[ns.iter][1] }}; + {% set ns.iter = ns.iter + 1 %} + {% endfor %} + {% endfor %} + {% elif not is_symbolic(type) %} + {% set dims = ops.StorageOps.storage_dim(type) %} + let _{{name}} = new Array({{ dims }}); + {% for i in range(dims) %} + _{{ name }}[{{ i }}] = {{ terms[i][1] }}; + {% endfor %} + {% else %} + const _{{name}} = {{ terms[0][1] }}; + {% endif %} + + {% endfor %} + return { + {% for name, type in spec.outputs.items() %} + {% set T = python_util.get_type(type) %} + {% if issubclass(T, (Matrix, Values)) or is_sequence(type) or is_symbolic(type) %} + {{ name }}: _{{name}} + {%- else %} + {{ name }}: sym.{{T.__name__}}.from_storage(_{{name}}) + {% endif %} + {% if not loop.last %}, {% endif %} + + {% endfor %} + }; +{% endmacro %} diff --git a/symforce/codegen/backends/python/python_config.py b/symforce/codegen/backends/python/python_config.py index 5993d864c..11ac91c67 100644 --- a/symforce/codegen/backends/python/python_config.py +++ b/symforce/codegen/backends/python/python_config.py @@ -52,6 +52,12 @@ def templates_to_render(self, generated_file_name: str) -> T.List[T.Tuple[str, s ("function/__init__.py.jinja", "__init__.py"), ] + @staticmethod + def format_matrix_accessor(key: str, i: int, j: int = None) -> str: + if j is None: + return f"{key}[{i}]" + return f"{key}[{i}, {j}]" + def printer(self) -> CodePrinter: from symforce.codegen.backends.python import python_code_printer diff --git a/symforce/codegen/codegen_config.py b/symforce/codegen/codegen_config.py index dda67284a..950115890 100644 --- a/symforce/codegen/codegen_config.py +++ b/symforce/codegen/codegen_config.py @@ -77,3 +77,12 @@ def format_data_accessor(prefix: str, index: int) -> str: Format data for accessing a data array in code. """ return f"{prefix}.data[{index}]" + + @staticmethod + @abstractmethod + def format_matrix_accessor(key: str, i: int, j: int = None) -> str: + """ + Format accessor for 2D matrices. If j is None, it is a 1D vector type, which for some + languages is accessed with 2D indices and in some with 1D. + """ + pass diff --git a/symforce/codegen/codegen_util.py b/symforce/codegen/codegen_util.py index b8a6fd360..ac078679c 100644 --- a/symforce/codegen/codegen_util.py +++ b/symforce/codegen/codegen_util.py @@ -340,9 +340,12 @@ def get_formatted_list( formatted_symbols = [sf.Symbol(key)] flattened_value = [value] elif issubclass(arg_cls, sf.Matrix): - if config.matrix_is_1d: - # TODO(nathan): Not sure this works for 2D matrices. Get rid of this. - formatted_symbols = [sf.Symbol(f"{key}[{j}]") for j in range(storage_dim)] + if value.shape[1] == 1: + # Pass in None as the second index for 1D matrices, so the per-backend config can + # decide whether to use 1D or 2D indexing, depending on the language. + formatted_symbols = [] + for i in range(value.shape[0]): + formatted_symbols.append(sf.Symbol(config.format_matrix_accessor(key, i, None))) else: # NOTE(brad): The order of the symbols must match the storage order of sf.Matrix # (as returned by sf.Matrix.to_storage). Hence, if there storage order were @@ -351,7 +354,9 @@ def get_formatted_list( formatted_symbols = [] for j in range(value.shape[1]): for i in range(value.shape[0]): - formatted_symbols.append(sf.Symbol(f"{key}({i}, {j})")) + formatted_symbols.append( + sf.Symbol(config.format_matrix_accessor(key, i, j)) + ) flattened_value = ops.StorageOps.to_storage(value) diff --git a/symforce/codegen/printers/javascript_code_printer.py b/symforce/codegen/printers/javascript_code_printer.py new file mode 100644 index 000000000..0a41842fa --- /dev/null +++ b/symforce/codegen/printers/javascript_code_printer.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# SymForce - Copyright 2022, Skydio, Inc. +# This source code is under the Apache 2.0 license found in the LICENSE file. +# ---------------------------------------------------------------------------- + +from sympy.printing.jscode import JavascriptCodePrinter as SympyJsCodePrinter + +from symforce import typing as T + + +class JavascriptCodePrinter(SympyJsCodePrinter): + """ + Symforce customized code printer for Javascript. Modifies the Sympy printing + behavior for codegen compatibility and efficiency. + """ + + def __init__(self, settings: T.Dict[str, T.Any] = None) -> None: + settings = dict(settings or {},) + super().__init__(settings) diff --git a/symforce/codegen/template_util.py b/symforce/codegen/template_util.py index d36e1e31f..ea73912ba 100644 --- a/symforce/codegen/template_util.py +++ b/symforce/codegen/template_util.py @@ -29,10 +29,12 @@ class FileType(enum.Enum): CUDA = enum.auto() LCM = enum.auto() MAKEFILE = enum.auto() + JAVASCRIPT = enum.auto() TYPESCRIPT = enum.auto() @staticmethod def from_extension(extension: str) -> FileType: + # TODO(hayk): Move up to language-specific directory. (tag=centralize-language-diffs) if extension in ("c", "cpp", "cxx", "cc", "tcc", "h", "hpp", "hxx", "hh"): return FileType.CPP elif extension in ("cu", "cuh"): @@ -47,6 +49,8 @@ def from_extension(extension: str) -> FileType: return FileType.MAKEFILE elif extension == "ts": return FileType.TYPESCRIPT + elif extension == "js": + return FileType.JAVASCRIPT else: raise ValueError(f"Could not get FileType from extension {extension}") @@ -63,7 +67,7 @@ def comment_prefix(self) -> str: """ Return the comment prefix for this file type. """ - if self in (FileType.CPP, FileType.CUDA, FileType.LCM): + if self in (FileType.CPP, FileType.CUDA, FileType.LCM, FileType.JAVASCRIPT): return "//" elif self in (FileType.PYTHON, FileType.PYTHON_INTERFACE): return "#" diff --git a/test/symforce_javascript_codegen_test.py b/test/symforce_javascript_codegen_test.py new file mode 100644 index 000000000..ea3aa3dce --- /dev/null +++ b/test/symforce_javascript_codegen_test.py @@ -0,0 +1,52 @@ +# ---------------------------------------------------------------------------- +# SymForce - Copyright 2022, Skydio, Inc. +# This source code is under the Apache 2.0 license found in the LICENSE file. +# ---------------------------------------------------------------------------- +from pathlib import Path + +from symforce import codegen +from symforce import geo +from symforce import logger +from symforce import path_util +from symforce import sympy as sm +from symforce import typing as T +from symforce.test_util import TestCase + + +class SymforceJavascriptCodegenTest(TestCase): + """ + Simple test for the Javascript codegen backend. + """ + + @staticmethod + def javascript_codegen_example( + a: T.Scalar, b: geo.V2, c: geo.M22, epsilon: T.Scalar + ) -> T.Tuple[geo.V3, geo.M22, T.Scalar]: + return ( + geo.V3(a + c[0], sm.sin(b[0]) ** a, b[1] ** 2 / (a - b[0] - c[1] - epsilon)), + geo.M22( + [[-sm.atan2(b[1], a), (a + b[0]) / c[1, :].norm(epsilon=epsilon)], [1, c[1, 0]]] + ), + a ** 2, + ) + + def test_javascript_codegen(self) -> None: + for config in (codegen.PythonConfig(), codegen.CppConfig(), codegen.JavascriptConfig()): + cg = codegen.Codegen.function( + func=self.javascript_codegen_example, + config=config, + output_names=["d", "e", "f"], + ) + out_path = cg.generate_function().generated_files[0] + + logger.debug(Path(out_path).read_text()) + + if config.backend_name() == "javascript": + self.compare_or_update_file( + path=path_util.symforce_dir() / "test" / "test_data" / out_path.name, + new_file=out_path, + ) + + +if __name__ == "__main__": + TestCase.main() diff --git a/test/test_data/javascript_codegen_example.js b/test/test_data/javascript_codegen_example.js new file mode 100644 index 000000000..3b6460a5b --- /dev/null +++ b/test/test_data/javascript_codegen_example.js @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// This file was autogenerated by symforce from template: +// backends/javascript/templates/function/FUNCTION.js.jinja +// Do NOT modify by hand. +// ----------------------------------------------------------------------------- + + +/** +* This function was autogenerated from a symbolic function. Do not modify by hand. +* +* Symbolic function: javascript_codegen_example +* +* Args: +* a: Scalar +* b: Matrix21 +* c: Matrix22 +* epsilon: Scalar +* +* Outputs: +* d: Matrix31 +* e: Matrix22 +* f: Scalar +*/ +function javascript_codegen_example(a, b, c, epsilon) +{ + // Total ops: 18 + + // Input arrays + + // Intermediate terms (0) + + // Output terms (3) + let _d = new Array(3); + _d[0] = a + c[0][0]; + _d[1] = Math.pow(Math.sin(b[0]), a); + _d[2] = Math.pow(b[1], 2)/(a - b[0] - c[0][1] - epsilon); + + let _e = [...Array(2)].map(e => Array(2)); + _e[0][0] = -Math.atan2(b[1], a); + _e[0][1] = 1; + _e[1][0] = (a + b[0])/Math.sqrt(Math.pow(c[1][0], 2) + Math.pow(c[1][1], 2) + epsilon); + _e[1][1] = c[1][0]; + + const _f = Math.pow(a, 2); + + return { + d: _d, + e: _e, + f: _f + }; + +}