From b1b6d0e3a761b331f01c25c8ece728f7c635ef04 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 29 Oct 2025 15:32:15 +0000 Subject: [PATCH 1/4] Fix RecursionError with Dict relationships in get_relationship_to() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed a bug where using Dict or Mapping type hints with Relationship() would cause infinite recursion. The get_relationship_to() function now properly handles dict/Mapping types by extracting the value type (second type argument), similar to how it handles List types. Changes: - Added handling for dict and Mapping origins in get_relationship_to() - Extracts the value type from Dict[K, V] or Mapping[K, V] annotations - Added comprehensive tests for Dict relationships with attribute_mapped_collection This resolves the RecursionError that occurred when defining relationships like: children: Dict[str, Child] = Relationship(...) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- sqlmodel/_compat.py | 12 +++++ tests/test_dict_relationship_recursion.py | 60 +++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/test_dict_relationship_recursion.py diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 230f8cc362..e204b1778a 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -179,6 +179,18 @@ def get_relationship_to( elif origin is list: use_annotation = get_args(annotation)[0] + # If a dict or Mapping, get the value type (second type argument) + elif origin is dict or origin is Mapping: + args = get_args(annotation) + if len(args) >= 2: + # For Dict[K, V] or Mapping[K, V], we want the value type (V) + use_annotation = args[1] + else: + raise ValueError( + f"Dict/Mapping relationship field '{name}' must have both key " + "and value type arguments (e.g., Dict[str, Model])" + ) + return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation ) diff --git a/tests/test_dict_relationship_recursion.py b/tests/test_dict_relationship_recursion.py new file mode 100644 index 0000000000..37d8a6aca9 --- /dev/null +++ b/tests/test_dict_relationship_recursion.py @@ -0,0 +1,60 @@ +"""Test for Dict relationship recursion bug fix.""" +from typing import Dict + +import pytest +from sqlalchemy.orm.collections import attribute_mapped_collection +from sqlmodel import Field, Relationship, SQLModel + + +def test_dict_relationship_pattern(): + """Test that Dict relationships with attribute_mapped_collection work.""" + + # Create a minimal reproduction of the pattern + # This should not raise a RecursionError + + class TestChild(SQLModel, table=True): + __tablename__ = "test_child" + id: int = Field(primary_key=True) + key: str = Field(nullable=False) + parent_id: int = Field(foreign_key="test_parent.id") + parent: "TestParent" = Relationship(back_populates="children") + + class TestParent(SQLModel, table=True): + __tablename__ = "test_parent" + id: int = Field(primary_key=True) + children: Dict[str, "TestChild"] = Relationship( + back_populates="parent", + sa_relationship_kwargs={ + "collection_class": attribute_mapped_collection("key") + }, + ) + + # If we got here without RecursionError, the bug is fixed + assert TestParent.__tablename__ == "test_parent" + assert TestChild.__tablename__ == "test_child" + + +def test_dict_relationship_with_optional(): + """Test that Optional[Dict[...]] relationships also work.""" + from typing import Optional + + class Child(SQLModel, table=True): + __tablename__ = "child" + id: int = Field(primary_key=True) + key: str = Field(nullable=False) + parent_id: int = Field(foreign_key="parent.id") + parent: Optional["Parent"] = Relationship(back_populates="children") + + class Parent(SQLModel, table=True): + __tablename__ = "parent" + id: int = Field(primary_key=True) + children: Optional[Dict[str, "Child"]] = Relationship( + back_populates="parent", + sa_relationship_kwargs={ + "collection_class": attribute_mapped_collection("key") + }, + ) + + # If we got here without RecursionError, the bug is fixed + assert Parent.__tablename__ == "parent" + assert Child.__tablename__ == "child" From 8c2e4c4a9cb10247d216a24a778136bf79a09fa5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:10:47 +0000 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_dict_relationship_recursion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dict_relationship_recursion.py b/tests/test_dict_relationship_recursion.py index 37d8a6aca9..9993b6843d 100644 --- a/tests/test_dict_relationship_recursion.py +++ b/tests/test_dict_relationship_recursion.py @@ -1,7 +1,7 @@ """Test for Dict relationship recursion bug fix.""" + from typing import Dict -import pytest from sqlalchemy.orm.collections import attribute_mapped_collection from sqlmodel import Field, Relationship, SQLModel From 090be97fb1ddf86103b34c70ea754423f967bee3 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 26 Nov 2025 20:21:20 +0000 Subject: [PATCH 3/4] Fix sa_column/sa_type lost when using Annotated with validators When using Annotated[type, Field(sa_column=...), Validator(...)], Pydantic V2 creates a new pydantic.fields.FieldInfo that doesn't preserve SQLModel-specific attributes like sa_column and sa_type. This caused timezone-aware datetime columns defined with DateTime(timezone=True) to lose their timezone setting. The fix extracts the SQLModel FieldInfo from the original Annotated type's metadata, preserving sa_column and sa_type attributes even when Pydantic V2 merges the field info. Fixes timezone-aware datetime columns losing timezone=True when using Annotated types with Pydantic validators. --- sqlmodel/main.py | 61 ++++++++++++++++++-- tests/test_annotated_sa_column.py | 94 +++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 5 deletions(-) create mode 100644 tests/test_annotated_sa_column.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..3b79893f3b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -54,7 +54,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import Annotated, Literal, TypeAlias, deprecated, get_args, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -562,7 +562,8 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) + original_annotation = new_cls.__annotations__.get(k) + col = get_column_from_field(v, original_annotation) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. @@ -646,12 +647,44 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: Any) -> Any: +def _get_sqlmodel_field_info_from_annotation(annotation: Any) -> Optional["FieldInfo"]: + """Extract SQLModel FieldInfo from an Annotated type's metadata. + + When using Annotated[type, Field(...), Validator(...)], Pydantic V2 may create + a new pydantic.fields.FieldInfo that doesn't preserve SQLModel-specific attributes + like sa_column and sa_type. This function looks through the Annotated metadata + to find the original SQLModel FieldInfo. + """ + if get_origin(annotation) is not Annotated: + return None + for arg in get_args(annotation)[1:]: # Skip the first arg (the actual type) + if isinstance(arg, FieldInfo): + return arg + return None + + +def get_sqlalchemy_type(field: Any, original_annotation: Any = None) -> Any: if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + # If sa_type not found on field_info, check if it's in the Annotated metadata + # This handles the case where Pydantic V2 creates a new FieldInfo losing SQLModel attrs + if sa_type is Undefined and IS_PYDANTIC_V2: + # First try field_info.annotation (may be unpacked by Pydantic) + annotation = getattr(field_info, "annotation", None) + if annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation(annotation) + if sqlmodel_field_info is not None: + sa_type = getattr(sqlmodel_field_info, "sa_type", Undefined) + # If still not found, try the original annotation from the class + if sa_type is Undefined and original_annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation( + original_annotation + ) + if sqlmodel_field_info is not None: + sa_type = getattr(sqlmodel_field_info, "sa_type", Undefined) if sa_type is not Undefined: return sa_type @@ -703,15 +736,33 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field( + field: Any, original_annotation: Any = None +) -> Column: # type: ignore if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) + # If sa_column not found on field_info, check if it's in the Annotated metadata + # This handles the case where Pydantic V2 creates a new FieldInfo losing SQLModel attrs + if sa_column is Undefined and IS_PYDANTIC_V2: + # First try field_info.annotation (may be unpacked by Pydantic) + annotation = getattr(field_info, "annotation", None) + if annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation(annotation) + if sqlmodel_field_info is not None: + sa_column = getattr(sqlmodel_field_info, "sa_column", Undefined) + # If still not found, try the original annotation from the class + if sa_column is Undefined and original_annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation( + original_annotation + ) + if sqlmodel_field_info is not None: + sa_column = getattr(sqlmodel_field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column - sa_type = get_sqlalchemy_type(field) + sa_type = get_sqlalchemy_type(field, original_annotation) primary_key = getattr(field_info, "primary_key", Undefined) if primary_key is Undefined: primary_key = False diff --git a/tests/test_annotated_sa_column.py b/tests/test_annotated_sa_column.py new file mode 100644 index 0000000000..2ac26e8a09 --- /dev/null +++ b/tests/test_annotated_sa_column.py @@ -0,0 +1,94 @@ +"""Tests for Annotated fields with sa_column and Pydantic validators. + +When using Annotated[type, Field(sa_column=...), Validator(...)], Pydantic V2 may +create a new FieldInfo that doesn't preserve SQLModel-specific attributes like +sa_column. These tests ensure the sa_column is properly extracted from the +Annotated metadata. +""" + +from datetime import datetime +from typing import Annotated, Optional + +from pydantic import AfterValidator, BeforeValidator +from sqlalchemy import Column, DateTime, String +from sqlmodel import Field, SQLModel + + +def test_annotated_sa_column_with_validators() -> None: + """Test that sa_column is preserved when using Annotated with validators.""" + + def before_validate(v: datetime) -> datetime: + return v + + def after_validate(v: datetime) -> datetime: + return v + + class Position(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + timestamp: Annotated[ + datetime, + Field( + sa_column=Column( + DateTime(timezone=True), nullable=False, index=True + ) + ), + BeforeValidator(before_validate), + AfterValidator(after_validate), + ] + + # Verify the column type has timezone=True + assert Position.__table__.c.timestamp.type.timezone is True + assert Position.__table__.c.timestamp.nullable is False + assert Position.__table__.c.timestamp.index is True + + +def test_annotated_sa_column_with_single_validator() -> None: + """Test sa_column with just one validator.""" + + def validate_name(v: str) -> str: + return v.strip() + + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: Annotated[ + str, + Field(sa_column=Column(String(100), nullable=False, unique=True)), + AfterValidator(validate_name), + ] + + assert isinstance(Item.__table__.c.name.type, String) + assert Item.__table__.c.name.type.length == 100 + assert Item.__table__.c.name.nullable is False + assert Item.__table__.c.name.unique is True + + +def test_annotated_sa_column_without_validators() -> None: + """Test that sa_column still works with Annotated but no validators.""" + + class Record(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Annotated[ + datetime, + Field(sa_column=Column(DateTime(timezone=True), nullable=False)), + ] + + assert Record.__table__.c.created_at.type.timezone is True + assert Record.__table__.c.created_at.nullable is False + + +def test_annotated_sa_type_with_validators() -> None: + """Test that sa_type is preserved when using Annotated with validators.""" + + def validate_timestamp(v: datetime) -> datetime: + return v + + class Event(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + occurred_at: Annotated[ + datetime, + Field(sa_type=DateTime(timezone=True)), + AfterValidator(validate_timestamp), + ] + + # Verify the column type has timezone=True + assert Event.__table__.c.occurred_at.type.timezone is True From 5dda6aae6f6d6247bf4a0315f558421fdbc411dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 20:24:45 +0000 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 15 ++++++++++----- tests/test_annotated_sa_column.py | 4 +--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3b79893f3b..529e20388e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -54,7 +54,14 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Annotated, Literal, TypeAlias, deprecated, get_args, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + deprecated, + get_args, + get_origin, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -647,7 +654,7 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def _get_sqlmodel_field_info_from_annotation(annotation: Any) -> Optional["FieldInfo"]: +def _get_sqlmodel_field_info_from_annotation(annotation: Any) -> Optional[FieldInfo]: """Extract SQLModel FieldInfo from an Annotated type's metadata. When using Annotated[type, Field(...), Validator(...)], Pydantic V2 may create @@ -736,9 +743,7 @@ def get_sqlalchemy_type(field: Any, original_annotation: Any = None) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field( - field: Any, original_annotation: Any = None -) -> Column: # type: ignore +def get_column_from_field(field: Any, original_annotation: Any = None) -> Column: # type: ignore if IS_PYDANTIC_V2: field_info = field else: diff --git a/tests/test_annotated_sa_column.py b/tests/test_annotated_sa_column.py index 2ac26e8a09..9ec01c41e5 100644 --- a/tests/test_annotated_sa_column.py +++ b/tests/test_annotated_sa_column.py @@ -28,9 +28,7 @@ class Position(SQLModel, table=True): timestamp: Annotated[ datetime, Field( - sa_column=Column( - DateTime(timezone=True), nullable=False, index=True - ) + sa_column=Column(DateTime(timezone=True), nullable=False, index=True) ), BeforeValidator(before_validate), AfterValidator(after_validate),