feat: improve future compat with pydantic v3

This commit is contained in:
stainless-app[bot] 2025-09-04 02:54:45 +00:00
parent 979c43dbc7
commit 5d58a795ce
13 changed files with 437 additions and 137 deletions

View file

@ -50,7 +50,7 @@ from ._utils import (
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
PYDANTIC_V1,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
@ -81,11 +81,7 @@ class _ConfigProtocol(Protocol):
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
else:
if PYDANTIC_V1:
@property
@override
@ -95,6 +91,10 @@ class BaseModel(pydantic.BaseModel):
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
else:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
def to_dict(
self,
@ -215,25 +215,25 @@ class BaseModel(pydantic.BaseModel):
if key not in model_fields:
parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
if PYDANTIC_V2:
_extra[key] = parsed
else:
if PYDANTIC_V1:
_fields_set.add(key)
fields_values[key] = parsed
else:
_extra[key] = parsed
object.__setattr__(m, "__dict__", fields_values)
if PYDANTIC_V2:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None)
object.__setattr__(m, "__pydantic_extra__", _extra)
object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
else:
if PYDANTIC_V1:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set)
else:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None)
object.__setattr__(m, "__pydantic_extra__", _extra)
object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
return m
@ -243,7 +243,7 @@ class BaseModel(pydantic.BaseModel):
# although not in practice
model_construct = construct
if not PYDANTIC_V2:
if PYDANTIC_V1:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
@ -363,10 +363,10 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V2:
type_ = field.annotation
else:
if PYDANTIC_V1:
type_ = cast(type, field.outer_type_) # type: ignore
else:
type_ = field.annotation # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
@ -375,7 +375,7 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
if not PYDANTIC_V2:
if PYDANTIC_V1:
# TODO
return None
@ -628,7 +628,19 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V2:
if PYDANTIC_V1:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
else:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
@ -642,18 +654,6 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
@ -714,7 +714,7 @@ else:
pass
if PYDANTIC_V2:
if not PYDANTIC_V1:
from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
@ -782,12 +782,12 @@ class FinalRequestOptions(pydantic.BaseModel):
json_data: Union[Body, None] = None
extra_json: Union[AnyMapping, None] = None
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
else:
if PYDANTIC_V1:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
arbitrary_types_allowed: bool = True
else:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
def get_max_retries(self, max_retries: int) -> int:
if isinstance(self.max_retries, NotGiven):
@ -820,9 +820,9 @@ class FinalRequestOptions(pydantic.BaseModel):
key: strip_not_given(value)
for key, value in values.items()
}
if PYDANTIC_V2:
return super().model_construct(_fields_set, **kwargs)
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
if PYDANTIC_V1:
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
return super().model_construct(_fields_set, **kwargs)
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment