diff --git a/CHANGES.rst b/CHANGES.rst index 30bbad68..cbb5a4fc 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,10 +1,14 @@ Version history =============== -**UNRELEASED** +**4.0.0rc3** - **BACKWARD INCOMPATIBLE** Relationship names changed when multiple FKs or junction tables connect to the same target table. Regenerating models will break existing code. +- Added support for generating Python enum classes for ``ARRAY(Enum(...))`` columns + (e.g., PostgreSQL ``ARRAY(ENUM)``). Supports named/unnamed enums, shared enums across + columns, and multi-dimensional arrays. Respects ``--options nonativeenums``. + (PR by @sheinbergon) - Improved relationship naming: one-to-many uses FK column names (e.g., ``simple_items_parent_container``), many-to-many uses junction table names (e.g., ``students_enrollments``). Use ``--options nofknames`` to revert to old behavior. (PR by @sheinbergon) diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 8318d050..c5eef945 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -554,6 +554,20 @@ def render_column_type(self, column: Column[Any]) -> str: args = [] kwargs: dict[str, Any] = {} + + # Check if this is an ARRAY column with an Enum item type mapped to a Python enum class + if isinstance(column_type, ARRAY) and isinstance(column_type.item_type, Enum): + if enum_class_name := self.enum_classes.get( + (column.table.name, column.name) + ): + self.add_import(ARRAY) + self.add_import(Enum) + rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])" + if column_type.dimensions is not None: + kwargs["dimensions"] = repr(column_type.dimensions) + + return render_callable("ARRAY", rendered_enum, kwargs=kwargs) + sig = inspect.signature(column_type.__class__.__init__) defaults = {param.name: param.default for param in sig.parameters.values()} missing = object() @@ -809,6 +823,31 @@ def render_enum_classes(self) -> str: def fix_column_types(self, table: Table) -> None: """Adjust the reflected column types.""" + + def fix_enum_column(col_name: str, enum_type: Enum) -> None: + if (table.name, col_name) in self.enum_classes: + return + + if enum_type.name: + existing_class = None + for (_, _), cls in self.enum_classes.items(): + if cls == self._enum_name_to_class_name(enum_type.name): + existing_class = cls + break + + if existing_class: + enum_class_name = existing_class + else: + enum_class_name = self._enum_name_to_class_name(enum_type.name) + if enum_class_name not in self.enum_values: + self.enum_values[enum_class_name] = list(enum_type.enums) + else: + enum_class_name = self._create_enum_class( + table.name, col_name, list(enum_type.enums) + ) + + self.enum_classes[(table.name, col_name)] = enum_class_name + # Detect check constraints for boolean and enum columns for constraint in table.constraints.copy(): if isinstance(constraint, CheckConstraint): @@ -852,37 +891,16 @@ def fix_column_types(self, table: Table) -> None: and isinstance(column.type, Enum) and column.type.enums ): - if column.type.name: - # Named enum - create shared enum class if not already created - if (table.name, column.name) not in self.enum_classes: - # Check if we've already created an enum for this name - existing_class = None - for (t, c), cls in self.enum_classes.items(): - if cls == self._enum_name_to_class_name(column.type.name): - existing_class = cls - break - - if existing_class: - enum_class_name = existing_class - else: - # Create new enum class from the enum's name - enum_class_name = self._enum_name_to_class_name( - column.type.name - ) - # Register the enum values if not already registered - if enum_class_name not in self.enum_values: - self.enum_values[enum_class_name] = list( - column.type.enums - ) + fix_enum_column(column.name, column.type) - self.enum_classes[(table.name, column.name)] = enum_class_name - else: - # Unnamed enum - create enum class per column - if (table.name, column.name) not in self.enum_classes: - enum_class_name = self._create_enum_class( - table.name, column.name, list(column.type.enums) - ) - self.enum_classes[(table.name, column.name)] = enum_class_name + # Handle ARRAY columns with Enum item types (e.g., PostgreSQL ARRAY(ENUM)) + elif ( + "nonativeenums" not in self.options + and isinstance(column.type, ARRAY) + and isinstance(column.type.item_type, Enum) + and column.type.item_type.enums + ): + fix_enum_column(column.name, column.type.item_type) if not self.keep_dialect_types: try: diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index e6fe9d30..87fb0117 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -2815,3 +2815,235 @@ class Users(Base): status: Mapped[UsersStatus] = mapped_column(Enum(UsersStatus, values_callable=lambda cls: [member.value for member in cls]), nullable=False) """, ) + + +def test_array_enum_named(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column( + "roles", + ARRAY(SAEnum("admin", "user", "moderator", name="role_enum")), + nullable=False, + ), + ) + + validate_code( + generator.generate(), + """\ + import enum + + from sqlalchemy import ARRAY, Enum, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class RoleEnum(str, enum.Enum): + ADMIN = 'admin' + USER = 'user' + MODERATOR = 'moderator' + + + class Users(Base): + __tablename__ = 'users' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False) + """, + ) + + +def test_array_enum_unnamed(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column( + "roles", + ARRAY(SAEnum("admin", "user")), + nullable=False, + ), + ) + + validate_code( + generator.generate(), + """\ + import enum + + from sqlalchemy import ARRAY, Enum, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class UsersRoles(str, enum.Enum): + ADMIN = 'admin' + USER = 'user' + + + class Users(Base): + __tablename__ = 'users' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + roles: Mapped[list[UsersRoles]] = mapped_column(ARRAY(Enum(UsersRoles, values_callable=lambda cls: [member.value for member in cls])), nullable=False) + """, + ) + + +def test_array_enum_nullable(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column( + "roles", + ARRAY(SAEnum("admin", "user", name="role_enum")), + ), + ) + + validate_code( + generator.generate(), + """\ + from typing import Optional + import enum + + from sqlalchemy import ARRAY, Enum, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class RoleEnum(str, enum.Enum): + ADMIN = 'admin' + USER = 'user' + + + class Users(Base): + __tablename__ = 'users' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]))) + """, + ) + + +def test_array_enum_with_dimensions(generator: CodeGenerator) -> None: + Table( + "items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column( + "tag_matrix", + ARRAY(SAEnum("a", "b", name="tag_enum"), dimensions=2), + nullable=False, + ), + ) + + validate_code( + generator.generate(), + """\ + import enum + + from sqlalchemy import ARRAY, Enum, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class TagEnum(str, enum.Enum): + A = 'a' + B = 'b' + + + class Items(Base): + __tablename__ = 'items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls]), dimensions=2), nullable=False) + """, + ) + + +def test_array_enum_nonativeenums_option(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column( + "roles", + ARRAY(SAEnum("admin", "user", name="role_enum")), + nullable=False, + ), + ) + + generator = DeclarativeGenerator( + generator.metadata, generator.bind, ["nonativeenums"] + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import ARRAY, Enum, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class Users(Base): + __tablename__ = 'users' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + roles: Mapped[list[str]] = mapped_column(ARRAY(Enum('admin', 'user', name='role_enum')), nullable=False) + """, + ) + + +def test_array_enum_shared_with_regular_enum(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column( + "primary_role", + SAEnum("admin", "user", name="role_enum"), + nullable=False, + ), + Column( + "all_roles", + ARRAY(SAEnum("admin", "user", name="role_enum")), + nullable=False, + ), + ) + + validate_code( + generator.generate(), + """\ + import enum + + from sqlalchemy import ARRAY, Enum, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class RoleEnum(str, enum.Enum): + ADMIN = 'admin' + USER = 'user' + + + class Users(Base): + __tablename__ = 'users' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False) + all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False) + """, + ) diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index fcd73f16..220f7ad6 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -4,6 +4,7 @@ import pytest from _pytest.fixtures import FixtureRequest +from sqlalchemy import Enum as SAEnum from sqlalchemy import TypeDecorator from sqlalchemy.dialects import mysql, postgresql, registry from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql @@ -21,7 +22,7 @@ ) from sqlalchemy.sql.expression import text from sqlalchemy.sql.sqltypes import DateTime, NullType -from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text +from sqlalchemy.types import ARRAY, INTEGER, NUMERIC, SMALLINT, VARCHAR, Text from sqlacodegen.generators import CodeGenerator, TablesGenerator @@ -320,6 +321,83 @@ class StatusEnum(str, enum.Enum): ) +def test_array_enum_named(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("roles", ARRAY(SAEnum("admin", "user", "moderator", name="role_enum"))), + ) + + validate_code( + generator.generate(), + """\ + import enum + + from sqlalchemy import ARRAY, Column, Enum, Integer, MetaData, Table + + metadata = MetaData() + + + class RoleEnum(str, enum.Enum): + ADMIN = 'admin' + USER = 'user' + MODERATOR = 'moderator' + + + t_users = Table( + 'users', metadata, + Column('id', Integer, primary_key=True), + Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]))) + ) + """, + ) + + +def test_array_enum_shared(generator: CodeGenerator) -> None: + Table( + "users", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("roles", ARRAY(SAEnum("admin", "user", name="role_enum"))), + ) + Table( + "groups", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("allowed_roles", ARRAY(SAEnum("admin", "user", name="role_enum"))), + ) + + validate_code( + generator.generate(), + """\ + import enum + + from sqlalchemy import ARRAY, Column, Enum, Integer, MetaData, Table + + metadata = MetaData() + + + class RoleEnum(str, enum.Enum): + ADMIN = 'admin' + USER = 'user' + + + t_groups = Table( + 'groups', metadata, + Column('id', Integer, primary_key=True), + Column('allowed_roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]))) + ) + + t_users = Table( + 'users', metadata, + Column('id', Integer, primary_key=True), + Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]))) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) def test_domain_text(generator: CodeGenerator) -> None: Table(