diff --git a/CHANGES.rst b/CHANGES.rst index cbb5a4f..c676a40 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,23 @@ Version history =============== +**UNRELEASED** + +- **BACKWARD INCOMPATIBLE** API changes (for those who customize code generation by + subclassing the existing generators): + + * Added new optional keyword argument, ``explicit_foreign_keys`` to + ``DeclarativeGenerator``, to force foreign keys to be rendered as + ``ClassName.attribute_name`` string references + * Removed the ``render_relationship_args()`` method from the SQLModel generator + * Added two new methods for customizing relationship rendering in + ``DeclarativeGenerator``: + + * ``render_relationship_annotation()``: returns the appropriate type annotation + (without the ``Mapped`` wrapper) for the relationship + * ``render_relationship_arguments()``: returns a dictionary of keyword arguments to + ``sqlalchemy.orm.relationship()`` + **4.0.0rc3** - **BACKWARD INCOMPATIBLE** Relationship names changed when multiple FKs or junction tables @@ -14,6 +31,8 @@ Version history ``students_enrollments``). Use ``--options nofknames`` to revert to old behavior. (PR by @sheinbergon) - Fixed ``Index`` kwargs (e.g. ``mysql_length``) being ignored during code generation (PR by @luliangce) +- Fixed the SQLModel generator not adding the ``foreign_keys`` parameters when + generating multiple relationships between the same two tables **4.0.0rc2** diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index c5eef94..83fe982 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -5,7 +5,7 @@ import sys from abc import ABCMeta, abstractmethod from collections import defaultdict -from collections.abc import Collection, Iterable, Sequence +from collections.abc import Collection, Iterable, Mapping, Sequence from dataclasses import dataclass from importlib import import_module from inspect import Parameter @@ -1001,10 +1001,12 @@ def __init__( *, indentation: str = " ", base_class_name: str = "Base", + explicit_foreign_keys: bool = False, ): super().__init__(metadata, bind, options, indentation=indentation) self.base_class_name: str = base_class_name self.inflect_engine = inflect.engine() + self.explicit_foreign_keys = explicit_foreign_keys def generate_base(self) -> None: self.base = Base( @@ -1626,6 +1628,33 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str: return f"{column_attr.name}: Mapped[{rendered_column_python_type}] = {rendered_column}" def render_relationship(self, relationship: RelationshipAttribute) -> str: + kwargs = self.render_relationship_arguments(relationship) + annotation = self.render_relationship_annotation(relationship) + rendered_relationship = render_callable( + "relationship", repr(relationship.target.name), kwargs=kwargs + ) + return f"{relationship.name}: Mapped[{annotation}] = {rendered_relationship}" + + def render_relationship_annotation( + self, relationship: RelationshipAttribute + ) -> str: + match relationship.type: + case RelationshipType.ONE_TO_MANY: + return f"list[{relationship.target.name!r}]" + case RelationshipType.ONE_TO_ONE | RelationshipType.MANY_TO_ONE: + if relationship.constraint and any( + col.nullable for col in relationship.constraint.columns + ): + self.add_literal_import("typing", "Optional") + return f"Optional[{relationship.target.name!r}]" + else: + return f"'{relationship.target.name}'" + case RelationshipType.MANY_TO_MANY: + return f"list[{relationship.target.name!r}]" + + def render_relationship_arguments( + self, relationship: RelationshipAttribute + ) -> Mapping[str, Any]: def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str: rendered = [] for attr in column_attrs: @@ -1641,7 +1670,7 @@ def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str: render_as_string = False # Assume that column_attrs are all in relationship.source or none for attr in column_attrs: - if attr.model is relationship.source: + if not self.explicit_foreign_keys and attr.model is relationship.source: rendered.append(attr.name) else: rendered.append(f"{attr.model.name}.{attr.name}") @@ -1697,33 +1726,7 @@ def render_join(terms: list[JoinType]) -> str: if relationship.backref: kwargs["back_populates"] = repr(relationship.backref.name) - rendered_relationship = render_callable( - "relationship", repr(relationship.target.name), kwargs=kwargs - ) - - relationship_type: str - if relationship.type == RelationshipType.ONE_TO_MANY: - relationship_type = f"list['{relationship.target.name}']" - elif relationship.type in ( - RelationshipType.ONE_TO_ONE, - RelationshipType.MANY_TO_ONE, - ): - relationship_type = f"'{relationship.target.name}'" - if relationship.constraint and any( - col.nullable for col in relationship.constraint.columns - ): - self.add_literal_import("typing", "Optional") - relationship_type = f"Optional[{relationship_type}]" - elif relationship.type == RelationshipType.MANY_TO_MANY: - relationship_type = f"list['{relationship.target.name}']" - else: - self.add_literal_import("typing", "Any") - relationship_type = "Any" - - return ( - f"{relationship.name}: Mapped[{relationship_type}] " - f"= {rendered_relationship}" - ) + return kwargs class DataclassGenerator(DeclarativeGenerator): @@ -1778,6 +1781,7 @@ def __init__( options, indentation=indentation, base_class_name=base_class_name, + explicit_foreign_keys=True, ) @property @@ -1858,34 +1862,26 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str: return f"{column_attr.name}: {rendered_column_python_type} = {rendered_field}" def render_relationship(self, relationship: RelationshipAttribute) -> str: - rendered = super().render_relationship(relationship).partition(" = ")[2] - args = self.render_relationship_args(rendered) - kwargs: dict[str, Any] = {} - annotation = repr(relationship.target.name) + kwargs = self.render_relationship_arguments(relationship) + annotation = self.render_relationship_annotation(relationship) + + native_kwargs: dict[str, Any] = {} + non_native_kwargs: dict[str, Any] = {} + for key, value in kwargs.items(): + # The following keyword arguments are natively supported in Relationship + if key in ("back_populates", "cascade_delete", "passive_deletes"): + native_kwargs[key] = value + else: + non_native_kwargs[key] = value - if relationship.type in ( - RelationshipType.ONE_TO_MANY, - RelationshipType.MANY_TO_MANY, - ): - annotation = f"list[{annotation}]" - else: - self.add_literal_import("typing", "Optional") - annotation = f"Optional[{annotation}]" + if non_native_kwargs: + native_kwargs["sa_relationship_kwargs"] = ( + "{" + + ", ".join( + f"{key!r}: {value}" for key, value in non_native_kwargs.items() + ) + + "}" + ) - rendered_field = render_callable("Relationship", *args, kwargs=kwargs) + rendered_field = render_callable("Relationship", kwargs=native_kwargs) return f"{relationship.name}: {annotation} = {rendered_field}" - - def render_relationship_args(self, arguments: str) -> list[str]: - argument_list = arguments.split(",") - # delete ')' and ' ' from args - argument_list[-1] = argument_list[-1][:-1] - argument_list = [argument[1:] for argument in argument_list] - - rendered_args: list[str] = [] - for arg in argument_list: - if "back_populates" in arg: - rendered_args.append(arg) - if "uselist=False" in arg: - rendered_args.append("sa_relationship_kwargs={'uselist': False}") - - return rendered_args diff --git a/tests/test_generator_sqlmodel.py b/tests/test_generator_sqlmodel.py index 55a7f11..5314a79 100644 --- a/tests/test_generator_sqlmodel.py +++ b/tests/test_generator_sqlmodel.py @@ -142,6 +142,66 @@ class SimpleGoods(SQLModel, table=True): ) +def test_onetomany_multiref(generator: CodeGenerator) -> None: + Table( + "simple_items_multiref", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_container_id", INTEGER), + Column("top_container_id", INTEGER, nullable=False), + ForeignKeyConstraint( + ["parent_container_id"], ["simple_containers_multiref.id"] + ), + ForeignKeyConstraint(["top_container_id"], ["simple_containers_multiref.id"]), + ) + Table( + "simple_containers_multiref", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from typing import Optional + + from sqlalchemy import Column, ForeignKey, Integer + from sqlmodel import Field, Relationship, SQLModel + + class SimpleContainersMultiref(SQLModel, table=True): + __tablename__ = 'simple_containers_multiref' + + id: int = Field(sa_column=Column('id', Integer, primary_key=True)) + + simple_items_multiref_parent_container: list['SimpleItemsMultiref'] = \ +Relationship(back_populates='parent_container', sa_relationship_kwargs={\ +'foreign_keys': '[SimpleItemsMultiref.parent_container_id]'}) + simple_items_multiref_top_container: list['SimpleItemsMultiref'] = \ +Relationship(back_populates='top_container', sa_relationship_kwargs={'foreign_keys': \ +'[SimpleItemsMultiref.top_container_id]'}) + + + class SimpleItemsMultiref(SQLModel, table=True): + __tablename__ = 'simple_items_multiref' + + id: int = Field(sa_column=Column('id', Integer, primary_key=True)) + top_container_id: int = \ +Field(sa_column=Column('top_container_id', \ +ForeignKey('simple_containers_multiref.id'), nullable=False)) + parent_container_id: Optional[int] = \ +Field(default=None, sa_column=Column('parent_container_id', \ +ForeignKey('simple_containers_multiref.id'))) + + parent_container: Optional['SimpleContainersMultiref'] = Relationship(\ +back_populates='simple_items_multiref_parent_container', sa_relationship_kwargs={\ +'foreign_keys': '[SimpleItemsMultiref.parent_container_id]'}) + top_container: 'SimpleContainersMultiref' = Relationship(\ +back_populates='simple_items_multiref_top_container', sa_relationship_kwargs={\ +'foreign_keys': '[SimpleItemsMultiref.top_container_id]'}) + """, + ) + + def test_onetoone(generator: CodeGenerator) -> None: Table( "simple_onetoone", @@ -167,7 +227,7 @@ class OtherItems(SQLModel, table=True): id: int = Field(sa_column=Column('id', Integer, primary_key=True)) simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\ -sa_relationship_kwargs={'uselist': False}, back_populates='other_item') +back_populates='other_item', sa_relationship_kwargs={'uselist': False}) class SimpleOnetoone(SQLModel, table=True):