Skip to content

Commit b98a3b3

Browse files
committed
Allow values that support SupportsGt (and others) protocol for gt (and others) in Field
1 parent 5611bda commit b98a3b3

File tree

2 files changed

+97
-16
lines changed

2 files changed

+97
-16
lines changed

sqlmodel/main.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
overload,
2323
)
2424

25+
import annotated_types
2526
from pydantic import BaseModel, EmailStr
2627
from pydantic.fields import FieldInfo as PydanticFieldInfo
2728
from sqlalchemy import (
@@ -214,10 +215,10 @@ def Field(
214215
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
215216
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
216217
const: Optional[bool] = None,
217-
gt: Optional[float] = None,
218-
ge: Optional[float] = None,
219-
lt: Optional[float] = None,
220-
le: Optional[float] = None,
218+
gt: Optional[annotated_types.SupportsGt] = None,
219+
ge: Optional[annotated_types.SupportsGe] = None,
220+
lt: Optional[annotated_types.SupportsLt] = None,
221+
le: Optional[annotated_types.SupportsLe] = None,
221222
multiple_of: Optional[float] = None,
222223
max_digits: Optional[int] = None,
223224
decimal_places: Optional[int] = None,
@@ -257,10 +258,10 @@ def Field(
257258
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
258259
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
259260
const: Optional[bool] = None,
260-
gt: Optional[float] = None,
261-
ge: Optional[float] = None,
262-
lt: Optional[float] = None,
263-
le: Optional[float] = None,
261+
gt: Optional[annotated_types.SupportsGt] = None,
262+
ge: Optional[annotated_types.SupportsGe] = None,
263+
lt: Optional[annotated_types.SupportsLt] = None,
264+
le: Optional[annotated_types.SupportsLe] = None,
264265
multiple_of: Optional[float] = None,
265266
max_digits: Optional[int] = None,
266267
decimal_places: Optional[int] = None,
@@ -309,10 +310,10 @@ def Field(
309310
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
310311
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
311312
const: Optional[bool] = None,
312-
gt: Optional[float] = None,
313-
ge: Optional[float] = None,
314-
lt: Optional[float] = None,
315-
le: Optional[float] = None,
313+
gt: Optional[annotated_types.SupportsGt] = None,
314+
ge: Optional[annotated_types.SupportsGe] = None,
315+
lt: Optional[annotated_types.SupportsLt] = None,
316+
le: Optional[annotated_types.SupportsLe] = None,
316317
multiple_of: Optional[float] = None,
317318
max_digits: Optional[int] = None,
318319
decimal_places: Optional[int] = None,
@@ -342,10 +343,10 @@ def Field(
342343
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
343344
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
344345
const: Optional[bool] = None,
345-
gt: Optional[float] = None,
346-
ge: Optional[float] = None,
347-
lt: Optional[float] = None,
348-
le: Optional[float] = None,
346+
gt: Optional[annotated_types.SupportsGt] = None,
347+
ge: Optional[annotated_types.SupportsGe] = None,
348+
lt: Optional[annotated_types.SupportsLt] = None,
349+
le: Optional[annotated_types.SupportsLe] = None,
349350
multiple_of: Optional[float] = None,
350351
max_digits: Optional[int] = None,
351352
decimal_places: Optional[int] = None,

tests/test_pydantic/test_field.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,83 @@ class Model(SQLModel):
5454

5555
instance = Model(id=123, foo="bar")
5656
assert "foo=" not in repr(instance)
57+
58+
59+
def test_gt():
60+
class Model(SQLModel):
61+
int_value: int = Field(gt=10)
62+
tuple_value: tuple[int, int] = Field(gt=(1, 2))
63+
64+
Model(int_value=11, tuple_value=(1, 3))
65+
66+
with pytest.raises(ValidationError) as exc_info:
67+
Model(int_value=10, tuple_value=(1, 3))
68+
assert len(exc_info.value.errors()) == 1
69+
assert exc_info.value.errors()[0]["type"] == "greater_than"
70+
assert exc_info.value.errors()[0]["loc"] == ("int_value",)
71+
72+
with pytest.raises(ValidationError) as exc_info_2:
73+
Model(int_value=11, tuple_value=(1, 2))
74+
assert len(exc_info_2.value.errors()) == 1
75+
assert exc_info_2.value.errors()[0]["type"] == "greater_than"
76+
assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",)
77+
78+
79+
def test_ge():
80+
class Model(SQLModel):
81+
int_value: int = Field(ge=10)
82+
tuple_value: tuple[int, int] = Field(ge=(1, 2))
83+
84+
Model(int_value=10, tuple_value=(1, 2))
85+
86+
with pytest.raises(ValidationError) as exc_info:
87+
Model(int_value=9, tuple_value=(1, 2))
88+
assert len(exc_info.value.errors()) == 1
89+
assert exc_info.value.errors()[0]["type"] == "greater_than_equal"
90+
assert exc_info.value.errors()[0]["loc"] == ("int_value",)
91+
92+
with pytest.raises(ValidationError) as exc_info_2:
93+
Model(int_value=10, tuple_value=(1, 1))
94+
assert len(exc_info_2.value.errors()) == 1
95+
assert exc_info_2.value.errors()[0]["type"] == "greater_than_equal"
96+
assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",)
97+
98+
99+
def test_lt():
100+
class Model(SQLModel):
101+
int_value: int = Field(lt=10)
102+
tuple_value: tuple[int, int] = Field(lt=(1, 2))
103+
104+
Model(int_value=9, tuple_value=(1, 1))
105+
106+
with pytest.raises(ValidationError) as exc_info:
107+
Model(int_value=10, tuple_value=(1, 1))
108+
assert len(exc_info.value.errors()) == 1
109+
assert exc_info.value.errors()[0]["type"] == "less_than"
110+
assert exc_info.value.errors()[0]["loc"] == ("int_value",)
111+
112+
with pytest.raises(ValidationError) as exc_info_2:
113+
Model(int_value=9, tuple_value=(1, 2))
114+
assert len(exc_info_2.value.errors()) == 1
115+
assert exc_info_2.value.errors()[0]["type"] == "less_than"
116+
assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",)
117+
118+
119+
def test_le():
120+
class Model(SQLModel):
121+
int_value: int = Field(le=10)
122+
tuple_value: tuple[int, int] = Field(le=(1, 2))
123+
124+
Model(int_value=10, tuple_value=(1, 2))
125+
126+
with pytest.raises(ValidationError) as exc_info:
127+
Model(int_value=11, tuple_value=(1, 2))
128+
assert len(exc_info.value.errors()) == 1
129+
assert exc_info.value.errors()[0]["type"] == "less_than_equal"
130+
assert exc_info.value.errors()[0]["loc"] == ("int_value",)
131+
132+
with pytest.raises(ValidationError) as exc_info_2:
133+
Model(int_value=10, tuple_value=(1, 3))
134+
assert len(exc_info_2.value.errors()) == 1
135+
assert exc_info_2.value.errors()[0]["type"] == "less_than_equal"
136+
assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",)

0 commit comments

Comments
 (0)