Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
Version history
===============

**UNRELEASED**
**4.0.0rc3**

- **BACKWARD INCOMPATIBLE** Relationship names changed when multiple FKs or junction tables
connect to the same target table. Regenerating models will break existing code.
- Added support for generating Python enum classes for ``ARRAY(Enum(...))`` columns
(e.g., PostgreSQL ``ARRAY(ENUM)``). Supports named/unnamed enums, shared enums across
columns, and multi-dimensional arrays. Respects ``--options nonativeenums``.
(PR by @sheinbergon)
- Improved relationship naming: one-to-many uses FK column names (e.g.,
``simple_items_parent_container``), many-to-many uses junction table names (e.g.,
``students_enrollments``). Use ``--options nofknames`` to revert to old behavior. (PR by @sheinbergon)
Expand Down
78 changes: 48 additions & 30 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,20 @@ def render_column_type(self, column: Column[Any]) -> str:

args = []
kwargs: dict[str, Any] = {}

# Check if this is an ARRAY column with an Enum item type mapped to a Python enum class
if isinstance(column_type, ARRAY) and isinstance(column_type.item_type, Enum):
if enum_class_name := self.enum_classes.get(
(column.table.name, column.name)
):
self.add_import(ARRAY)
self.add_import(Enum)
rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])"
if column_type.dimensions is not None:
kwargs["dimensions"] = repr(column_type.dimensions)

return render_callable("ARRAY", rendered_enum, kwargs=kwargs)

sig = inspect.signature(column_type.__class__.__init__)
defaults = {param.name: param.default for param in sig.parameters.values()}
missing = object()
Expand Down Expand Up @@ -809,6 +823,31 @@ def render_enum_classes(self) -> str:

def fix_column_types(self, table: Table) -> None:
"""Adjust the reflected column types."""

def fix_enum_column(col_name: str, enum_type: Enum) -> None:
if (table.name, col_name) in self.enum_classes:
return

if enum_type.name:
existing_class = None
for (_, _), cls in self.enum_classes.items():
if cls == self._enum_name_to_class_name(enum_type.name):
existing_class = cls
break

if existing_class:
enum_class_name = existing_class
else:
enum_class_name = self._enum_name_to_class_name(enum_type.name)
if enum_class_name not in self.enum_values:
self.enum_values[enum_class_name] = list(enum_type.enums)
else:
enum_class_name = self._create_enum_class(
table.name, col_name, list(enum_type.enums)
)

self.enum_classes[(table.name, col_name)] = enum_class_name

# Detect check constraints for boolean and enum columns
for constraint in table.constraints.copy():
if isinstance(constraint, CheckConstraint):
Expand Down Expand Up @@ -852,37 +891,16 @@ def fix_column_types(self, table: Table) -> None:
and isinstance(column.type, Enum)
and column.type.enums
):
if column.type.name:
# Named enum - create shared enum class if not already created
if (table.name, column.name) not in self.enum_classes:
# Check if we've already created an enum for this name
existing_class = None
for (t, c), cls in self.enum_classes.items():
if cls == self._enum_name_to_class_name(column.type.name):
existing_class = cls
break

if existing_class:
enum_class_name = existing_class
else:
# Create new enum class from the enum's name
enum_class_name = self._enum_name_to_class_name(
column.type.name
)
# Register the enum values if not already registered
if enum_class_name not in self.enum_values:
self.enum_values[enum_class_name] = list(
column.type.enums
)
fix_enum_column(column.name, column.type)

self.enum_classes[(table.name, column.name)] = enum_class_name
else:
# Unnamed enum - create enum class per column
if (table.name, column.name) not in self.enum_classes:
enum_class_name = self._create_enum_class(
table.name, column.name, list(column.type.enums)
)
self.enum_classes[(table.name, column.name)] = enum_class_name
# Handle ARRAY columns with Enum item types (e.g., PostgreSQL ARRAY(ENUM))
elif (
"nonativeenums" not in self.options
and isinstance(column.type, ARRAY)
and isinstance(column.type.item_type, Enum)
and column.type.item_type.enums
):
fix_enum_column(column.name, column.type.item_type)

if not self.keep_dialect_types:
try:
Expand Down
232 changes: 232 additions & 0 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2815,3 +2815,235 @@ class Users(Base):
status: Mapped[UsersStatus] = mapped_column(Enum(UsersStatus, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
""",
)


def test_array_enum_named(generator: CodeGenerator) -> None:
Table(
"users",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"roles",
ARRAY(SAEnum("admin", "user", "moderator", name="role_enum")),
nullable=False,
),
)

validate_code(
generator.generate(),
"""\
import enum
from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class RoleEnum(str, enum.Enum):
ADMIN = 'admin'
USER = 'user'
MODERATOR = 'moderator'
class Users(Base):
__tablename__ = 'users'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
""",
)


def test_array_enum_unnamed(generator: CodeGenerator) -> None:
Table(
"users",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"roles",
ARRAY(SAEnum("admin", "user")),
nullable=False,
),
)

validate_code(
generator.generate(),
"""\
import enum
from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class UsersRoles(str, enum.Enum):
ADMIN = 'admin'
USER = 'user'
class Users(Base):
__tablename__ = 'users'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
roles: Mapped[list[UsersRoles]] = mapped_column(ARRAY(Enum(UsersRoles, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
""",
)


def test_array_enum_nullable(generator: CodeGenerator) -> None:
Table(
"users",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"roles",
ARRAY(SAEnum("admin", "user", name="role_enum")),
),
)

validate_code(
generator.generate(),
"""\
from typing import Optional
import enum
from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class RoleEnum(str, enum.Enum):
ADMIN = 'admin'
USER = 'user'
class Users(Base):
__tablename__ = 'users'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
""",
)


def test_array_enum_with_dimensions(generator: CodeGenerator) -> None:
Table(
"items",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"tag_matrix",
ARRAY(SAEnum("a", "b", name="tag_enum"), dimensions=2),
nullable=False,
),
)

validate_code(
generator.generate(),
"""\
import enum
from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class TagEnum(str, enum.Enum):
A = 'a'
B = 'b'
class Items(Base):
__tablename__ = 'items'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls]), dimensions=2), nullable=False)
""",
)


def test_array_enum_nonativeenums_option(generator: CodeGenerator) -> None:
Table(
"users",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"roles",
ARRAY(SAEnum("admin", "user", name="role_enum")),
nullable=False,
),
)

generator = DeclarativeGenerator(
generator.metadata, generator.bind, ["nonativeenums"]
)

validate_code(
generator.generate(),
"""\
from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class Users(Base):
__tablename__ = 'users'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
roles: Mapped[list[str]] = mapped_column(ARRAY(Enum('admin', 'user', name='role_enum')), nullable=False)
""",
)


def test_array_enum_shared_with_regular_enum(generator: CodeGenerator) -> None:
Table(
"users",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"primary_role",
SAEnum("admin", "user", name="role_enum"),
nullable=False,
),
Column(
"all_roles",
ARRAY(SAEnum("admin", "user", name="role_enum")),
nullable=False,
),
)

validate_code(
generator.generate(),
"""\
import enum
from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class RoleEnum(str, enum.Enum):
ADMIN = 'admin'
USER = 'user'
class Users(Base):
__tablename__ = 'users'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
""",
)
Loading