From b9e057f0827b8ed5851ec7e7a8b9ed622019f065 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 9 Jul 2025 02:40:18 +0000 Subject: [PATCH 01/25] chore(internal): bump pinned h11 dep --- requirements-dev.lock | 4 ++-- requirements.lock | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements-dev.lock b/requirements-dev.lock index c1bb9eb..087efaa 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -48,9 +48,9 @@ filelock==3.12.4 frozenlist==1.6.2 # via aiohttp # via aiosignal -h11==0.14.0 +h11==0.16.0 # via httpcore -httpcore==1.0.2 +httpcore==1.0.9 # via httpx httpx==0.28.1 # via httpx-aiohttp diff --git a/requirements.lock b/requirements.lock index 9a36b80..a4b1e3b 100644 --- a/requirements.lock +++ b/requirements.lock @@ -36,9 +36,9 @@ exceptiongroup==1.2.2 frozenlist==1.6.2 # via aiohttp # via aiosignal -h11==0.14.0 +h11==0.16.0 # via httpcore -httpcore==1.0.2 +httpcore==1.0.9 # via httpx httpx==0.28.1 # via httpx-aiohttp From 144f8ca076c92c7e68403d6a0a49ce1caeee69b9 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 9 Jul 2025 02:59:16 +0000 Subject: [PATCH 02/25] chore(package): mark python 3.13 as supported --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b989bd9..f929b4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", From 79004f855735cc872d0050f781bf8b84d04fd592 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 10 Jul 2025 02:54:17 +0000 Subject: [PATCH 03/25] fix(parsing): correctly handle nested discriminated unions --- src/zeroentropy/_models.py | 13 ++++++----- tests/test_models.py | 45 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/zeroentropy/_models.py b/src/zeroentropy/_models.py index 4f21498..528d568 100644 --- a/src/zeroentropy/_models.py +++ b/src/zeroentropy/_models.py @@ -2,9 +2,10 @@ import os import inspect -from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast +from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( + List, Unpack, Literal, ClassVar, @@ -366,7 +367,7 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: if type_ is None: raise RuntimeError(f"Unexpected field type is None for {key}") - return construct_type(value=value, type_=type_) + return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None)) def is_basemodel(type_: type) -> bool: @@ -420,7 +421,7 @@ def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: return cast(_T, construct_type(value=value, type_=type_)) -def construct_type(*, value: object, type_: object) -> object: +def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object: """Loose coercion to the expected type with construction of nested values. If the given value does not match the expected type then it is returned as-is. @@ -438,8 +439,10 @@ def construct_type(*, value: object, type_: object) -> object: type_ = type_.__value__ # type: ignore[unreachable] # unwrap `Annotated[T, ...]` -> `T` - if is_annotated_type(type_): - meta: tuple[Any, ...] = get_args(type_)[1:] + if metadata is not None: + meta: tuple[Any, ...] = tuple(metadata) + elif is_annotated_type(type_): + meta = get_args(type_)[1:] type_ = extract_type_arg(type_, 0) else: meta = tuple() diff --git a/tests/test_models.py b/tests/test_models.py index 19461b3..cc3baf1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -889,3 +889,48 @@ class ModelB(BaseModel): ) assert isinstance(m, ModelB) + + +def test_nested_discriminated_union() -> None: + class InnerType1(BaseModel): + type: Literal["type_1"] + + class InnerModel(BaseModel): + inner_value: str + + class InnerType2(BaseModel): + type: Literal["type_2"] + some_inner_model: InnerModel + + class Type1(BaseModel): + base_type: Literal["base_type_1"] + value: Annotated[ + Union[ + InnerType1, + InnerType2, + ], + PropertyInfo(discriminator="type"), + ] + + class Type2(BaseModel): + base_type: Literal["base_type_2"] + + T = Annotated[ + Union[ + Type1, + Type2, + ], + PropertyInfo(discriminator="base_type"), + ] + + model = construct_type( + type_=T, + value={ + "base_type": "base_type_1", + "value": { + "type": "type_2", + }, + }, + ) + assert isinstance(model, Type1) + assert isinstance(model.value, InnerType2) From 3ccc314b021100775fe854d7225405c63b299599 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 11 Jul 2025 03:13:17 +0000 Subject: [PATCH 04/25] chore(readme): fix version rendering on pypi --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 239e0f2..bd015b2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # ZeroEntropy Python SDK -[![PyPI version]()](https://pypi.org/project/zeroentropy/) + +[![PyPI version](https://img.shields.io/pypi/v/zeroentropy.svg?label=pypi%20(stable))](https://pypi.org/project/zeroentropy/) The ZeroEntropy Python SDK provides convenient access to the [ZeroEntropy REST API](https://docs.zeroentropy.dev/api-reference/) from any Python 3.8+ application. From dfd4e866616c9c7e9fa6ffa3b757abf6d8cd9ebc Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sat, 12 Jul 2025 02:19:40 +0000 Subject: [PATCH 05/25] fix(client): don't send Content-Type header on GET requests --- pyproject.toml | 2 +- src/zeroentropy/_base_client.py | 11 +++++++++-- tests/test_client.py | 4 ++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f929b4a..57d9123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ Homepage = "https://github.com/zeroentropy-ai/zeroentropy-python" Repository = "https://github.com/zeroentropy-ai/zeroentropy-python" [project.optional-dependencies] -aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.6"] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"] [tool.rye] managed = true diff --git a/src/zeroentropy/_base_client.py b/src/zeroentropy/_base_client.py index 42f3bdd..8b4b33c 100644 --- a/src/zeroentropy/_base_client.py +++ b/src/zeroentropy/_base_client.py @@ -529,6 +529,15 @@ def _build_request( # work around https://github.com/encode/httpx/discussions/2880 kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")} + is_body_allowed = options.method.lower() != "get" + + if is_body_allowed: + kwargs["json"] = json_data if is_given(json_data) else None + kwargs["files"] = files + else: + headers.pop("Content-Type", None) + kwargs.pop("data", None) + # TODO: report this error to httpx return self._client.build_request( # pyright: ignore[reportUnknownMemberType] headers=headers, @@ -540,8 +549,6 @@ def _build_request( # so that passing a `TypedDict` doesn't cause an error. # https://github.com/microsoft/pyright/issues/3526#event-6715453066 params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None, - json=json_data if is_given(json_data) else None, - files=files, **kwargs, ) diff --git a/tests/test_client.py b/tests/test_client.py index b7b0f8c..6198354 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -464,7 +464,7 @@ def test_request_extra_query(self) -> None: def test_multipart_repeating_array(self, client: ZeroEntropy) -> None: request = client._build_request( FinalRequestOptions.construct( - method="get", + method="post", url="/foo", headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, json_data={"array": ["foo", "bar"]}, @@ -1275,7 +1275,7 @@ def test_request_extra_query(self) -> None: def test_multipart_repeating_array(self, async_client: AsyncZeroEntropy) -> None: request = async_client._build_request( FinalRequestOptions.construct( - method="get", + method="post", url="/foo", headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, json_data={"array": ["foo", "bar"]}, From de35b9dc8c136eeaa0909076488f152a5f885398 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Tue, 15 Jul 2025 02:19:02 +0000 Subject: [PATCH 06/25] feat: clean up environment call outs --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index bd015b2..a3c0da0 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,6 @@ pip install --pre zeroentropy[aiohttp] Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`: ```python -import os import asyncio from zeroentropy import DefaultAioHttpClient from zeroentropy import AsyncZeroEntropy @@ -102,7 +101,7 @@ from zeroentropy import AsyncZeroEntropy async def main() -> None: async with AsyncZeroEntropy( - api_key=os.environ.get("ZEROENTROPY_API_KEY"), # This is the default and can be omitted + api_key="My API Key", http_client=DefaultAioHttpClient(), ) as client: response = await client.documents.add( From a8b35a89d5b113eb662bd3113c9433dc5f1382f9 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Tue, 22 Jul 2025 02:25:25 +0000 Subject: [PATCH 07/25] fix(parsing): ignore empty metadata --- src/zeroentropy/_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroentropy/_models.py b/src/zeroentropy/_models.py index 528d568..ffcbf67 100644 --- a/src/zeroentropy/_models.py +++ b/src/zeroentropy/_models.py @@ -439,7 +439,7 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any] type_ = type_.__value__ # type: ignore[unreachable] # unwrap `Annotated[T, ...]` -> `T` - if metadata is not None: + if metadata is not None and len(metadata) > 0: meta: tuple[Any, ...] = tuple(metadata) elif is_annotated_type(type_): meta = get_args(type_)[1:] From 5530ece04b51a3300c4202b66d4e50101c374a60 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Jul 2025 02:27:59 +0000 Subject: [PATCH 08/25] fix(parsing): parse extra field types --- src/zeroentropy/_models.py | 25 +++++++++++++++++++++++-- tests/test_models.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/zeroentropy/_models.py b/src/zeroentropy/_models.py index ffcbf67..b8387ce 100644 --- a/src/zeroentropy/_models.py +++ b/src/zeroentropy/_models.py @@ -208,14 +208,18 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] else: fields_values[name] = field_get_default(field) + extra_field_type = _get_extra_fields_type(__cls) + _extra = {} for key, value in values.items(): if key not in model_fields: + parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value + if PYDANTIC_V2: - _extra[key] = value + _extra[key] = parsed else: _fields_set.add(key) - fields_values[key] = value + fields_values[key] = parsed object.__setattr__(m, "__dict__", fields_values) @@ -370,6 +374,23 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None)) +def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None: + if not PYDANTIC_V2: + # TODO + return None + + schema = cls.__pydantic_core_schema__ + if schema["type"] == "model": + fields = schema["schema"] + if fields["type"] == "model-fields": + extras = fields.get("extras_schema") + if extras and "cls" in extras: + # mypy can't narrow the type + return extras["cls"] # type: ignore[no-any-return] + + return None + + def is_basemodel(type_: type) -> bool: """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" if is_union(type_): diff --git a/tests/test_models.py b/tests/test_models.py index cc3baf1..e5bdc49 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Union, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast from datetime import datetime, timezone from typing_extensions import Literal, Annotated, TypeAliasType @@ -934,3 +934,30 @@ class Type2(BaseModel): ) assert isinstance(model, Type1) assert isinstance(model.value, InnerType2) + + +@pytest.mark.skipif(not PYDANTIC_V2, reason="this is only supported in pydantic v2 for now") +def test_extra_properties() -> None: + class Item(BaseModel): + prop: int + + class Model(BaseModel): + __pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + other: str + + if TYPE_CHECKING: + + def __getattr__(self, attr: str) -> Item: ... + + model = construct_type( + type_=Model, + value={ + "a": {"prop": 1}, + "other": "foo", + }, + ) + assert isinstance(model, Model) + assert model.a.prop == 1 + assert isinstance(model.a, Item) + assert model.other == "foo" From c4f86076c0b09422a445ffa91e6880a085fabacd Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 05:57:08 +0000 Subject: [PATCH 09/25] chore(project): add settings file for vscode --- .gitignore | 1 - .vscode/settings.json | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 8779740..95ceb18 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ .prism.log -.vscode _dev __pycache__ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..5b01030 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.importFormat": "relative", +} From 4e036110dec2d075f4de4cf44691c131dfa545b7 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 31 Jul 2025 08:31:24 +0000 Subject: [PATCH 10/25] feat(client): support file upload requests --- src/zeroentropy/_base_client.py | 5 ++++- src/zeroentropy/_files.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/zeroentropy/_base_client.py b/src/zeroentropy/_base_client.py index 8b4b33c..8548951 100644 --- a/src/zeroentropy/_base_client.py +++ b/src/zeroentropy/_base_client.py @@ -532,7 +532,10 @@ def _build_request( is_body_allowed = options.method.lower() != "get" if is_body_allowed: - kwargs["json"] = json_data if is_given(json_data) else None + if isinstance(json_data, bytes): + kwargs["content"] = json_data + else: + kwargs["json"] = json_data if is_given(json_data) else None kwargs["files"] = files else: headers.pop("Content-Type", None) diff --git a/src/zeroentropy/_files.py b/src/zeroentropy/_files.py index 715cc20..cc14c14 100644 --- a/src/zeroentropy/_files.py +++ b/src/zeroentropy/_files.py @@ -69,12 +69,12 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes: return file if is_tuple_t(file): - return (file[0], _read_file_content(file[1]), *file[2:]) + return (file[0], read_file_content(file[1]), *file[2:]) raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") -def _read_file_content(file: FileContent) -> HttpxFileContent: +def read_file_content(file: FileContent) -> HttpxFileContent: if isinstance(file, os.PathLike): return pathlib.Path(file).read_bytes() return file @@ -111,12 +111,12 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: return file if is_tuple_t(file): - return (file[0], await _async_read_file_content(file[1]), *file[2:]) + return (file[0], await async_read_file_content(file[1]), *file[2:]) raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") -async def _async_read_file_content(file: FileContent) -> HttpxFileContent: +async def async_read_file_content(file: FileContent) -> HttpxFileContent: if isinstance(file, os.PathLike): return await anyio.Path(file).read_bytes() From 80e5aae0af40e80b5962b0dbf2d5351d9df4432e Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 6 Aug 2025 10:44:41 +0000 Subject: [PATCH 11/25] chore(internal): fix ruff target version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 57d9123..a8bd9b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,7 +159,7 @@ reportPrivateUsage = false [tool.ruff] line-length = 120 output-format = "grouped" -target-version = "py37" +target-version = "py38" [tool.ruff.format] docstring-code-format = true From 0deb0ff394e034254520520675869023912cecc1 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sat, 9 Aug 2025 06:03:13 +0000 Subject: [PATCH 12/25] chore: update @stainless-api/prism-cli to v5.15.0 --- scripts/mock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/mock b/scripts/mock index d2814ae..0b28f6e 100755 --- a/scripts/mock +++ b/scripts/mock @@ -21,7 +21,7 @@ echo "==> Starting mock server with URL ${URL}" # Run prism mock on the given spec if [ "$1" == "--daemon" ]; then - npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" &> .prism.log & + npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &> .prism.log & # Wait for server to come online echo -n "Waiting for server" @@ -37,5 +37,5 @@ if [ "$1" == "--daemon" ]; then echo else - npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" + npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" fi From e5a7e0f1592dcfe868064dcdea1446b358662461 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sat, 9 Aug 2025 06:05:29 +0000 Subject: [PATCH 13/25] chore(internal): update comment in script --- scripts/test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test b/scripts/test index 2b87845..dbeda2d 100755 --- a/scripts/test +++ b/scripts/test @@ -43,7 +43,7 @@ elif ! prism_is_running ; then echo -e "To run the server, pass in the path or url of your OpenAPI" echo -e "spec to the prism command:" echo - echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" + echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}" echo exit 1 From 7e76e087f68489a4a24ac990d48dd665d81f7181 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 08:15:37 +0000 Subject: [PATCH 14/25] chore: update github action --- .github/workflows/ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5743b12..1e4484c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: run: ./scripts/lint build: - if: github.repository == 'stainless-sdks/zeroentropy-python' && (github.event_name == 'push' || github.event.pull_request.head.repo.fork) + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork timeout-minutes: 10 name: build permissions: @@ -61,12 +61,14 @@ jobs: run: rye build - name: Get GitHub OIDC Token + if: github.repository == 'stainless-sdks/zeroentropy-python' id: github-oidc uses: actions/github-script@v6 with: script: core.setOutput('github_token', await core.getIDToken()); - name: Upload tarball + if: github.repository == 'stainless-sdks/zeroentropy-python' env: URL: https://pkg.stainless.com/s AUTH: ${{ steps.github-oidc.outputs.github_token }} From bb0ac37c96cf27acdcc4e578e9419cbe4a457e18 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 06:25:44 +0000 Subject: [PATCH 15/25] chore(internal): change ci workflow machines --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e4484c..182006b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: permissions: contents: read id-token: write - runs-on: depot-ubuntu-24.04 + runs-on: ${{ github.repository == 'stainless-sdks/zeroentropy-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} steps: - uses: actions/checkout@v4 From 8f7795235684015a6cd419e2bab68b99ecf63b23 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 27 Aug 2025 09:29:49 +0000 Subject: [PATCH 16/25] fix: avoid newer type syntax --- src/zeroentropy/_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroentropy/_models.py b/src/zeroentropy/_models.py index b8387ce..92f7c10 100644 --- a/src/zeroentropy/_models.py +++ b/src/zeroentropy/_models.py @@ -304,7 +304,7 @@ def model_dump( exclude_none=exclude_none, ) - return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped @override def model_dump_json( From 0c89992ab3ef88427c150b72fe510e26a202db43 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 27 Aug 2025 09:34:03 +0000 Subject: [PATCH 17/25] chore(internal): update pyright exclude list --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a8bd9b7..0c43a8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ exclude = [ "_dev", ".venv", ".nox", + ".git", ] reportImplicitOverride = true From 856feb326411abd1fc736eee5a64401f526285ca Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 04:54:17 +0000 Subject: [PATCH 18/25] chore(internal): add Sequence related utils --- src/zeroentropy/_types.py | 36 +++++++++++++++++++++++++++++- src/zeroentropy/_utils/__init__.py | 1 + src/zeroentropy/_utils/_typing.py | 5 +++++ tests/utils.py | 10 ++++++++- 4 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/zeroentropy/_types.py b/src/zeroentropy/_types.py index 308c2cd..60a5cab 100644 --- a/src/zeroentropy/_types.py +++ b/src/zeroentropy/_types.py @@ -13,10 +13,21 @@ Mapping, TypeVar, Callable, + Iterator, Optional, Sequence, ) -from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import ( + Set, + Literal, + Protocol, + TypeAlias, + TypedDict, + SupportsIndex, + overload, + override, + runtime_checkable, +) import httpx import pydantic @@ -217,3 +228,26 @@ class _GenericAlias(Protocol): class HttpxSendArgs(TypedDict, total=False): auth: httpx.Auth follow_redirects: bool + + +_T_co = TypeVar("_T_co", covariant=True) + + +if TYPE_CHECKING: + # This works because str.__contains__ does not accept object (either in typeshed or at runtime) + # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285 + class SequenceNotStr(Protocol[_T_co]): + @overload + def __getitem__(self, index: SupportsIndex, /) -> _T_co: ... + @overload + def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ... + def __contains__(self, value: object, /) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T_co]: ... + def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ... + def count(self, value: Any, /) -> int: ... + def __reversed__(self) -> Iterator[_T_co]: ... +else: + # just point this to a normal `Sequence` at runtime to avoid having to special case + # deserializing our custom sequence type + SequenceNotStr = Sequence diff --git a/src/zeroentropy/_utils/__init__.py b/src/zeroentropy/_utils/__init__.py index d4fda26..ca547ce 100644 --- a/src/zeroentropy/_utils/__init__.py +++ b/src/zeroentropy/_utils/__init__.py @@ -38,6 +38,7 @@ extract_type_arg as extract_type_arg, is_iterable_type as is_iterable_type, is_required_type as is_required_type, + is_sequence_type as is_sequence_type, is_annotated_type as is_annotated_type, is_type_alias_type as is_type_alias_type, strip_annotated_type as strip_annotated_type, diff --git a/src/zeroentropy/_utils/_typing.py b/src/zeroentropy/_utils/_typing.py index 1bac954..845cd6b 100644 --- a/src/zeroentropy/_utils/_typing.py +++ b/src/zeroentropy/_utils/_typing.py @@ -26,6 +26,11 @@ def is_list_type(typ: type) -> bool: return (get_origin(typ) or typ) == list +def is_sequence_type(typ: type) -> bool: + origin = get_origin(typ) or typ + return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence + + def is_iterable_type(typ: type) -> bool: """If the given type is `typing.Iterable[T]`""" origin = get_origin(typ) or typ diff --git a/tests/utils.py b/tests/utils.py index 0b96481..46f0df3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import inspect import traceback import contextlib -from typing import Any, TypeVar, Iterator, cast +from typing import Any, TypeVar, Iterator, Sequence, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -15,6 +15,7 @@ is_list_type, is_union_type, extract_type_arg, + is_sequence_type, is_annotated_type, is_type_alias_type, ) @@ -71,6 +72,13 @@ def assert_matches_type( if is_list_type(type_): return _assert_list_type(type_, value) + if is_sequence_type(type_): + assert isinstance(value, Sequence) + inner_type = get_args(type_)[0] + for entry in value: # type: ignore + assert_type(inner_type, entry) # type: ignore + return + if origin == str: assert isinstance(value, str) elif origin == int: From 37625fd6bb4efb5a4926e565eaa7b25852d75f51 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 3 Sep 2025 04:12:54 +0000 Subject: [PATCH 19/25] feat(types): replace List[str] with SequenceNotStr in params --- src/zeroentropy/_utils/_transform.py | 6 ++++++ src/zeroentropy/resources/documents.py | 12 ++++++------ src/zeroentropy/resources/models.py | 8 ++++---- src/zeroentropy/types/document_add_params.py | 8 +++++--- src/zeroentropy/types/document_update_params.py | 6 ++++-- src/zeroentropy/types/model_rerank_params.py | 6 ++++-- 6 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/zeroentropy/_utils/_transform.py b/src/zeroentropy/_utils/_transform.py index b0cc20a..f0bcefd 100644 --- a/src/zeroentropy/_utils/_transform.py +++ b/src/zeroentropy/_utils/_transform.py @@ -16,6 +16,7 @@ lru_cache, is_mapping, is_iterable, + is_sequence, ) from .._files import is_base64_file_input from ._typing import ( @@ -24,6 +25,7 @@ extract_type_arg, is_iterable_type, is_required_type, + is_sequence_type, is_annotated_type, strip_annotated_type, ) @@ -184,6 +186,8 @@ def _transform_recursive( (is_list_type(stripped_type) and is_list(data)) # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) ): # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually # intended as an iterable, so we don't transform it. @@ -346,6 +350,8 @@ async def _async_transform_recursive( (is_list_type(stripped_type) and is_list(data)) # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) ): # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually # intended as an iterable, so we don't transform it. diff --git a/src/zeroentropy/resources/documents.py b/src/zeroentropy/resources/documents.py index bbe7f76..1330909 100644 --- a/src/zeroentropy/resources/documents.py +++ b/src/zeroentropy/resources/documents.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Dict, List, Union, Optional +from typing import Dict, Union, Optional from typing_extensions import Literal import httpx @@ -15,7 +15,7 @@ document_get_info_list_params, document_get_page_info_params, ) -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, SequenceNotStr from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -63,7 +63,7 @@ def update( collection_name: str, path: str, index_status: Optional[Literal["not_parsed", "not_indexed"]] | NotGiven = NOT_GIVEN, - metadata: Optional[Dict[str, Union[str, List[str]]]] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -182,7 +182,7 @@ def add( collection_name: str, content: document_add_params.Content, path: str, - metadata: Dict[str, Union[str, List[str]]] | NotGiven = NOT_GIVEN, + metadata: Dict[str, Union[str, SequenceNotStr[str]]] | NotGiven = NOT_GIVEN, overwrite: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -466,7 +466,7 @@ async def update( collection_name: str, path: str, index_status: Optional[Literal["not_parsed", "not_indexed"]] | NotGiven = NOT_GIVEN, - metadata: Optional[Dict[str, Union[str, List[str]]]] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -585,7 +585,7 @@ async def add( collection_name: str, content: document_add_params.Content, path: str, - metadata: Dict[str, Union[str, List[str]]] | NotGiven = NOT_GIVEN, + metadata: Dict[str, Union[str, SequenceNotStr[str]]] | NotGiven = NOT_GIVEN, overwrite: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. diff --git a/src/zeroentropy/resources/models.py b/src/zeroentropy/resources/models.py index 09db25a..4fd006f 100644 --- a/src/zeroentropy/resources/models.py +++ b/src/zeroentropy/resources/models.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional import httpx from ..types import model_rerank_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, SequenceNotStr from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -46,7 +46,7 @@ def with_streaming_response(self) -> ModelsResourceWithStreamingResponse: def rerank( self, *, - documents: List[str], + documents: SequenceNotStr[str], query: str, model: str | NotGiven = NOT_GIVEN, top_n: Optional[int] | NotGiven = NOT_GIVEN, @@ -126,7 +126,7 @@ def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse: async def rerank( self, *, - documents: List[str], + documents: SequenceNotStr[str], query: str, model: str | NotGiven = NOT_GIVEN, top_n: Optional[int] | NotGiven = NOT_GIVEN, diff --git a/src/zeroentropy/types/document_add_params.py b/src/zeroentropy/types/document_add_params.py index ca0ab96..32608ea 100644 --- a/src/zeroentropy/types/document_add_params.py +++ b/src/zeroentropy/types/document_add_params.py @@ -2,9 +2,11 @@ from __future__ import annotations -from typing import Dict, List, Union +from typing import Dict, Union from typing_extensions import Literal, Required, TypeAlias, TypedDict +from .._types import SequenceNotStr + __all__ = [ "DocumentAddParams", "Content", @@ -37,7 +39,7 @@ class DocumentAddParams(TypedDict, total=False): unless `overwrite` is set to `true`. """ - metadata: Dict[str, Union[str, List[str]]] + metadata: Dict[str, Union[str, SequenceNotStr[str]]] """ This is a metadata JSON object that can be used to assign various metadata attributes to your document. The provided object must match the type @@ -65,7 +67,7 @@ class ContentAPITextDocument(TypedDict, total=False): class ContentAPITextPagesDocument(TypedDict, total=False): - pages: Required[List[str]] + pages: Required[SequenceNotStr[str]] """The content of this document, as an array of strings. Each string will be the content of a full page, and can be retrieved using the diff --git a/src/zeroentropy/types/document_update_params.py b/src/zeroentropy/types/document_update_params.py index 84acbe0..01e4e05 100644 --- a/src/zeroentropy/types/document_update_params.py +++ b/src/zeroentropy/types/document_update_params.py @@ -2,9 +2,11 @@ from __future__ import annotations -from typing import Dict, List, Union, Optional +from typing import Dict, Union, Optional from typing_extensions import Literal, Required, TypedDict +from .._types import SequenceNotStr + __all__ = ["DocumentUpdateParams"] @@ -27,7 +29,7 @@ class DocumentUpdateParams(TypedDict, total=False): failure. """ - metadata: Optional[Dict[str, Union[str, List[str]]]] + metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] """ If this field is provided, the given metadata json will replace the document's existing metadata json. In other words, if you want to add a new field, you will diff --git a/src/zeroentropy/types/model_rerank_params.py b/src/zeroentropy/types/model_rerank_params.py index e6d49fe..1005e7f 100644 --- a/src/zeroentropy/types/model_rerank_params.py +++ b/src/zeroentropy/types/model_rerank_params.py @@ -2,14 +2,16 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from typing_extensions import Required, TypedDict +from .._types import SequenceNotStr + __all__ = ["ModelRerankParams"] class ModelRerankParams(TypedDict, total=False): - documents: Required[List[str]] + documents: Required[SequenceNotStr[str]] """The list of documents to rerank. Each document is a string.""" query: Required[str] From 372675f01b9fd520b0f49749b85314a3333dfa2a Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 4 Sep 2025 04:23:39 +0000 Subject: [PATCH 20/25] feat: improve future compat with pydantic v3 --- src/zeroentropy/_base_client.py | 6 +- src/zeroentropy/_compat.py | 96 +++++++-------- src/zeroentropy/_models.py | 80 ++++++------- src/zeroentropy/_utils/__init__.py | 10 +- src/zeroentropy/_utils/_compat.py | 45 +++++++ src/zeroentropy/_utils/_datetime_parse.py | 136 ++++++++++++++++++++++ src/zeroentropy/_utils/_transform.py | 6 +- src/zeroentropy/_utils/_typing.py | 2 +- src/zeroentropy/_utils/_utils.py | 1 - tests/test_models.py | 48 ++++---- tests/test_transform.py | 16 +-- tests/test_utils/test_datetime_parse.py | 110 +++++++++++++++++ tests/utils.py | 8 +- 13 files changed, 432 insertions(+), 132 deletions(-) create mode 100644 src/zeroentropy/_utils/_compat.py create mode 100644 src/zeroentropy/_utils/_datetime_parse.py create mode 100644 tests/test_utils/test_datetime_parse.py diff --git a/src/zeroentropy/_base_client.py b/src/zeroentropy/_base_client.py index 8548951..5b7681b 100644 --- a/src/zeroentropy/_base_client.py +++ b/src/zeroentropy/_base_client.py @@ -59,7 +59,7 @@ ModelBuilderProtocol, ) from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping -from ._compat import PYDANTIC_V2, model_copy, model_dump +from ._compat import PYDANTIC_V1, model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type from ._response import ( APIResponse, @@ -232,7 +232,7 @@ def _set_private_attributes( model: Type[_T], options: FinalRequestOptions, ) -> None: - if PYDANTIC_V2 and getattr(self, "__pydantic_private__", None) is None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: self.__pydantic_private__ = {} self._model = model @@ -320,7 +320,7 @@ def _set_private_attributes( client: AsyncAPIClient, options: FinalRequestOptions, ) -> None: - if PYDANTIC_V2 and getattr(self, "__pydantic_private__", None) is None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: self.__pydantic_private__ = {} self._model = model diff --git a/src/zeroentropy/_compat.py b/src/zeroentropy/_compat.py index 92d9ee6..bdef67f 100644 --- a/src/zeroentropy/_compat.py +++ b/src/zeroentropy/_compat.py @@ -12,14 +12,13 @@ _T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) -# --------------- Pydantic v2 compatibility --------------- +# --------------- Pydantic v2, v3 compatibility --------------- # Pyright incorrectly reports some of our functions as overriding a method when they don't # pyright: reportIncompatibleMethodOverride=false -PYDANTIC_V2 = pydantic.VERSION.startswith("2.") +PYDANTIC_V1 = pydantic.VERSION.startswith("1.") -# v1 re-exports if TYPE_CHECKING: def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001 @@ -44,90 +43,92 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 ... else: - if PYDANTIC_V2: - from pydantic.v1.typing import ( + # v1 re-exports + if PYDANTIC_V1: + from pydantic.typing import ( get_args as get_args, is_union as is_union, get_origin as get_origin, is_typeddict as is_typeddict, is_literal_type as is_literal_type, ) - from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime else: - from pydantic.typing import ( + from ._utils import ( get_args as get_args, is_union as is_union, get_origin as get_origin, + parse_date as parse_date, is_typeddict as is_typeddict, + parse_datetime as parse_datetime, is_literal_type as is_literal_type, ) - from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # refactored config if TYPE_CHECKING: from pydantic import ConfigDict as ConfigDict else: - if PYDANTIC_V2: - from pydantic import ConfigDict - else: + if PYDANTIC_V1: # TODO: provide an error message here? ConfigDict = None + else: + from pydantic import ConfigDict as ConfigDict # renamed methods / properties def parse_obj(model: type[_ModelT], value: object) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(value) - else: + if PYDANTIC_V1: return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + else: + return model.model_validate(value) def field_is_required(field: FieldInfo) -> bool: - if PYDANTIC_V2: - return field.is_required() - return field.required # type: ignore + if PYDANTIC_V1: + return field.required # type: ignore + return field.is_required() def field_get_default(field: FieldInfo) -> Any: value = field.get_default() - if PYDANTIC_V2: - from pydantic_core import PydanticUndefined - - if value == PydanticUndefined: - return None + if PYDANTIC_V1: return value + from pydantic_core import PydanticUndefined + + if value == PydanticUndefined: + return None return value def field_outer_type(field: FieldInfo) -> Any: - if PYDANTIC_V2: - return field.annotation - return field.outer_type_ # type: ignore + if PYDANTIC_V1: + return field.outer_type_ # type: ignore + return field.annotation def get_model_config(model: type[pydantic.BaseModel]) -> Any: - if PYDANTIC_V2: - return model.model_config - return model.__config__ # type: ignore + if PYDANTIC_V1: + return model.__config__ # type: ignore + return model.model_config def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: - if PYDANTIC_V2: - return model.model_fields - return model.__fields__ # type: ignore + if PYDANTIC_V1: + return model.__fields__ # type: ignore + return model.model_fields def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT: - if PYDANTIC_V2: - return model.model_copy(deep=deep) - return model.copy(deep=deep) # type: ignore + if PYDANTIC_V1: + return model.copy(deep=deep) # type: ignore + return model.model_copy(deep=deep) def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: - if PYDANTIC_V2: - return model.model_dump_json(indent=indent) - return model.json(indent=indent) # type: ignore + if PYDANTIC_V1: + return model.json(indent=indent) # type: ignore + return model.model_dump_json(indent=indent) def model_dump( @@ -139,14 +140,14 @@ def model_dump( warnings: bool = True, mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: - if PYDANTIC_V2 or hasattr(model, "model_dump"): + if (not PYDANTIC_V1) or hasattr(model, "model_dump"): return model.model_dump( mode=mode, exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, # warnings are not supported in Pydantic v1 - warnings=warnings if PYDANTIC_V2 else True, + warnings=True if PYDANTIC_V1 else warnings, ) return cast( "dict[str, Any]", @@ -159,9 +160,9 @@ def model_dump( def model_parse(model: type[_ModelT], data: Any) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(data) - return model.parse_obj(data) # pyright: ignore[reportDeprecated] + if PYDANTIC_V1: + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + return model.model_validate(data) # generic models @@ -170,17 +171,16 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT: class GenericModel(pydantic.BaseModel): ... else: - if PYDANTIC_V2: + if PYDANTIC_V1: + import pydantic.generics + + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... + else: # there no longer needs to be a distinction in v2 but # we still have to create our own subclass to avoid # inconsistent MRO ordering errors class GenericModel(pydantic.BaseModel): ... - else: - import pydantic.generics - - class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... - # cached properties if TYPE_CHECKING: diff --git a/src/zeroentropy/_models.py b/src/zeroentropy/_models.py index 92f7c10..3a6017e 100644 --- a/src/zeroentropy/_models.py +++ b/src/zeroentropy/_models.py @@ -50,7 +50,7 @@ strip_annotated_type, ) from ._compat import ( - PYDANTIC_V2, + PYDANTIC_V1, ConfigDict, GenericModel as BaseGenericModel, get_args, @@ -81,11 +81,7 @@ class _ConfigProtocol(Protocol): class BaseModel(pydantic.BaseModel): - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) - ) - else: + if PYDANTIC_V1: @property @override @@ -95,6 +91,10 @@ def model_fields_set(self) -> set[str]: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] extra: Any = pydantic.Extra.allow # type: ignore + else: + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) def to_dict( self, @@ -215,25 +215,25 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] if key not in model_fields: parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value - if PYDANTIC_V2: - _extra[key] = parsed - else: + if PYDANTIC_V1: _fields_set.add(key) fields_values[key] = parsed + else: + _extra[key] = parsed object.__setattr__(m, "__dict__", fields_values) - if PYDANTIC_V2: - # these properties are copied from Pydantic's `model_construct()` method - object.__setattr__(m, "__pydantic_private__", None) - object.__setattr__(m, "__pydantic_extra__", _extra) - object.__setattr__(m, "__pydantic_fields_set__", _fields_set) - else: + if PYDANTIC_V1: # init_private_attributes() does not exist in v2 m._init_private_attributes() # type: ignore # copied from Pydantic v1's `construct()` method object.__setattr__(m, "__fields_set__", _fields_set) + else: + # these properties are copied from Pydantic's `model_construct()` method + object.__setattr__(m, "__pydantic_private__", None) + object.__setattr__(m, "__pydantic_extra__", _extra) + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) return m @@ -243,7 +243,7 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] # although not in practice model_construct = construct - if not PYDANTIC_V2: + if PYDANTIC_V1: # we define aliases for some of the new pydantic v2 methods so # that we can just document these methods without having to specify # a specific pydantic version as some users may not know which @@ -363,10 +363,10 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: if value is None: return field_get_default(field) - if PYDANTIC_V2: - type_ = field.annotation - else: + if PYDANTIC_V1: type_ = cast(type, field.outer_type_) # type: ignore + else: + type_ = field.annotation # type: ignore if type_ is None: raise RuntimeError(f"Unexpected field type is None for {key}") @@ -375,7 +375,7 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None: - if not PYDANTIC_V2: + if PYDANTIC_V1: # TODO return None @@ -628,30 +628,30 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, for variant in get_args(union): variant = strip_annotated_type(variant) if is_basemodel_type(variant): - if PYDANTIC_V2: - field = _extract_field_schema_pv2(variant, discriminator_field_name) - if not field: + if PYDANTIC_V1: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: continue # Note: if one variant defines an alias then they all should - discriminator_alias = field.get("serialization_alias") - - field_schema = field["schema"] + discriminator_alias = field_info.alias - if field_schema["type"] == "literal": - for entry in cast("LiteralSchema", field_schema)["expected"]: + if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): + for entry in get_args(annotation): if isinstance(entry, str): mapping[entry] = variant else: - field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] - if not field_info: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: continue # Note: if one variant defines an alias then they all should - discriminator_alias = field_info.alias + discriminator_alias = field.get("serialization_alias") - if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): - for entry in get_args(annotation): + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in cast("LiteralSchema", field_schema)["expected"]: if isinstance(entry, str): mapping[entry] = variant @@ -714,7 +714,7 @@ class GenericModel(BaseGenericModel, BaseModel): pass -if PYDANTIC_V2: +if not PYDANTIC_V1: from pydantic import TypeAdapter as _TypeAdapter _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) @@ -782,12 +782,12 @@ class FinalRequestOptions(pydantic.BaseModel): json_data: Union[Body, None] = None extra_json: Union[AnyMapping, None] = None - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - else: + if PYDANTIC_V1: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] arbitrary_types_allowed: bool = True + else: + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) def get_max_retries(self, max_retries: int) -> int: if isinstance(self.max_retries, NotGiven): @@ -820,9 +820,9 @@ def construct( # type: ignore key: strip_not_given(value) for key, value in values.items() } - if PYDANTIC_V2: - return super().model_construct(_fields_set, **kwargs) - return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + if PYDANTIC_V1: + return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + return super().model_construct(_fields_set, **kwargs) if not TYPE_CHECKING: # type checkers incorrectly complain about this assignment diff --git a/src/zeroentropy/_utils/__init__.py b/src/zeroentropy/_utils/__init__.py index ca547ce..dc64e29 100644 --- a/src/zeroentropy/_utils/__init__.py +++ b/src/zeroentropy/_utils/__init__.py @@ -10,7 +10,6 @@ lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, - parse_date as parse_date, is_iterable as is_iterable, is_sequence as is_sequence, coerce_float as coerce_float, @@ -23,7 +22,6 @@ coerce_boolean as coerce_boolean, coerce_integer as coerce_integer, file_from_path as file_from_path, - parse_datetime as parse_datetime, strip_not_given as strip_not_given, deepcopy_minimal as deepcopy_minimal, get_async_library as get_async_library, @@ -32,6 +30,13 @@ maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, ) +from ._compat import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, +) from ._typing import ( is_list_type as is_list_type, is_union_type as is_union_type, @@ -56,3 +61,4 @@ function_has_argument as function_has_argument, assert_signatures_in_sync as assert_signatures_in_sync, ) +from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime diff --git a/src/zeroentropy/_utils/_compat.py b/src/zeroentropy/_utils/_compat.py new file mode 100644 index 0000000..dd70323 --- /dev/null +++ b/src/zeroentropy/_utils/_compat.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import sys +import typing_extensions +from typing import Any, Type, Union, Literal, Optional +from datetime import date, datetime +from typing_extensions import get_args as _get_args, get_origin as _get_origin + +from .._types import StrBytesIntFloat +from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime + +_LITERAL_TYPES = {Literal, typing_extensions.Literal} + + +def get_args(tp: type[Any]) -> tuple[Any, ...]: + return _get_args(tp) + + +def get_origin(tp: type[Any]) -> type[Any] | None: + return _get_origin(tp) + + +def is_union(tp: Optional[Type[Any]]) -> bool: + if sys.version_info < (3, 10): + return tp is Union # type: ignore[comparison-overlap] + else: + import types + + return tp is Union or tp is types.UnionType + + +def is_typeddict(tp: Type[Any]) -> bool: + return typing_extensions.is_typeddict(tp) + + +def is_literal_type(tp: Type[Any]) -> bool: + return get_origin(tp) in _LITERAL_TYPES + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + return _parse_date(value) + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + return _parse_datetime(value) diff --git a/src/zeroentropy/_utils/_datetime_parse.py b/src/zeroentropy/_utils/_datetime_parse.py new file mode 100644 index 0000000..7cb9d9e --- /dev/null +++ b/src/zeroentropy/_utils/_datetime_parse.py @@ -0,0 +1,136 @@ +""" +This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py +without the Pydantic v1 specific errors. +""" + +from __future__ import annotations + +import re +from typing import Dict, Union, Optional +from datetime import date, datetime, timezone, timedelta + +from .._types import StrBytesIntFloat + +date_expr = r"(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})" +time_expr = ( + r"(?P\d{1,2}):(?P\d{1,2})" + r"(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?" + r"(?PZ|[+-]\d{2}(?::?\d{2})?)?$" +) + +date_re = re.compile(f"{date_expr}$") +datetime_re = re.compile(f"{date_expr}[T ]{time_expr}") + + +EPOCH = datetime(1970, 1, 1) +# if greater than this, the number is in ms, if less than or equal it's in seconds +# (in seconds this is 11th October 2603, in ms it's 20th August 1970) +MS_WATERSHED = int(2e10) +# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9 +MAX_NUMBER = int(3e20) + + +def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]: + if isinstance(value, (int, float)): + return value + try: + return float(value) + except ValueError: + return None + except TypeError: + raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None + + +def _from_unix_seconds(seconds: Union[int, float]) -> datetime: + if seconds > MAX_NUMBER: + return datetime.max + elif seconds < -MAX_NUMBER: + return datetime.min + + while abs(seconds) > MS_WATERSHED: + seconds /= 1000 + dt = EPOCH + timedelta(seconds=seconds) + return dt.replace(tzinfo=timezone.utc) + + +def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]: + if value == "Z": + return timezone.utc + elif value is not None: + offset_mins = int(value[-2:]) if len(value) > 3 else 0 + offset = 60 * int(value[1:3]) + offset_mins + if value[0] == "-": + offset = -offset + return timezone(timedelta(minutes=offset)) + else: + return None + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + """ + Parse a datetime/int/float/string and return a datetime.datetime. + + This function supports time zone offsets. When the input contains one, + the output uses a timezone with a fixed offset from UTC. + + Raise ValueError if the input is well formatted but not a valid datetime. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, datetime): + return value + + number = _get_numeric(value, "datetime") + if number is not None: + return _from_unix_seconds(number) + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + + match = datetime_re.match(value) + if match is None: + raise ValueError("invalid datetime format") + + kw = match.groupdict() + if kw["microsecond"]: + kw["microsecond"] = kw["microsecond"].ljust(6, "0") + + tzinfo = _parse_timezone(kw.pop("tzinfo")) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_["tzinfo"] = tzinfo + + return datetime(**kw_) # type: ignore + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + """ + Parse a date/int/float/string and return a datetime.date. + + Raise ValueError if the input is well formatted but not a valid date. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, date): + if isinstance(value, datetime): + return value.date() + else: + return value + + number = _get_numeric(value, "date") + if number is not None: + return _from_unix_seconds(number).date() + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + match = date_re.match(value) + if match is None: + raise ValueError("invalid date format") + + kw = {k: int(v) for k, v in match.groupdict().items()} + + try: + return date(**kw) + except ValueError: + raise ValueError("invalid date format") from None diff --git a/src/zeroentropy/_utils/_transform.py b/src/zeroentropy/_utils/_transform.py index f0bcefd..c19124f 100644 --- a/src/zeroentropy/_utils/_transform.py +++ b/src/zeroentropy/_utils/_transform.py @@ -19,6 +19,7 @@ is_sequence, ) from .._files import is_base64_file_input +from ._compat import get_origin, is_typeddict from ._typing import ( is_list_type, is_union_type, @@ -29,7 +30,6 @@ is_annotated_type, strip_annotated_type, ) -from .._compat import get_origin, model_dump, is_typeddict _T = TypeVar("_T") @@ -169,6 +169,8 @@ def _transform_recursive( Defaults to the same value as the `annotation` argument. """ + from .._compat import model_dump + if inner_type is None: inner_type = annotation @@ -333,6 +335,8 @@ async def _async_transform_recursive( Defaults to the same value as the `annotation` argument. """ + from .._compat import model_dump + if inner_type is None: inner_type = annotation diff --git a/src/zeroentropy/_utils/_typing.py b/src/zeroentropy/_utils/_typing.py index 845cd6b..193109f 100644 --- a/src/zeroentropy/_utils/_typing.py +++ b/src/zeroentropy/_utils/_typing.py @@ -15,7 +15,7 @@ from ._utils import lru_cache from .._types import InheritsGeneric -from .._compat import is_union as _is_union +from ._compat import is_union as _is_union def is_annotated_type(typ: type) -> bool: diff --git a/src/zeroentropy/_utils/_utils.py b/src/zeroentropy/_utils/_utils.py index ea3cf3f..f081859 100644 --- a/src/zeroentropy/_utils/_utils.py +++ b/src/zeroentropy/_utils/_utils.py @@ -22,7 +22,6 @@ import sniffio from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike -from .._compat import parse_date as parse_date, parse_datetime as parse_datetime _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) diff --git a/tests/test_models.py b/tests/test_models.py index e5bdc49..3b040dd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -8,7 +8,7 @@ from pydantic import Field from zeroentropy._utils import PropertyInfo -from zeroentropy._compat import PYDANTIC_V2, parse_obj, model_dump, model_json +from zeroentropy._compat import PYDANTIC_V1, parse_obj, model_dump, model_json from zeroentropy._models import BaseModel, construct_type @@ -294,12 +294,12 @@ class Model(BaseModel): assert cast(bool, m.foo) is True m = Model.construct(foo={"name": 3}) - if PYDANTIC_V2: - assert isinstance(m.foo, Submodel1) - assert m.foo.name == 3 # type: ignore - else: + if PYDANTIC_V1: assert isinstance(m.foo, Submodel2) assert m.foo.name == "3" + else: + assert isinstance(m.foo, Submodel1) + assert m.foo.name == 3 # type: ignore def test_list_of_unions() -> None: @@ -426,10 +426,10 @@ class Model(BaseModel): expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc) - if PYDANTIC_V2: - expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' - else: + if PYDANTIC_V1: expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}' + else: + expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' model = Model.construct(created_at="2019-12-27T18:11:19.117Z") assert model.created_at == expected @@ -531,7 +531,7 @@ class Model2(BaseModel): assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} assert m4.to_dict(mode="json") == {"created_at": time_str} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_dict(warnings=False) @@ -556,7 +556,7 @@ class Model(BaseModel): assert m3.model_dump() == {"foo": None} assert m3.model_dump(exclude_none=True) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump(round_trip=True) @@ -580,10 +580,10 @@ class Model(BaseModel): assert json.loads(m.to_json()) == {"FOO": "hello"} assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"} - if PYDANTIC_V2: - assert m.to_json(indent=None) == '{"FOO":"hello"}' - else: + if PYDANTIC_V1: assert m.to_json(indent=None) == '{"FOO": "hello"}' + else: + assert m.to_json(indent=None) == '{"FOO":"hello"}' m2 = Model() assert json.loads(m2.to_json()) == {} @@ -595,7 +595,7 @@ class Model(BaseModel): assert json.loads(m3.to_json()) == {"FOO": None} assert json.loads(m3.to_json(exclude_none=True)) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_json(warnings=False) @@ -622,7 +622,7 @@ class Model(BaseModel): assert json.loads(m3.model_dump_json()) == {"foo": None} assert json.loads(m3.model_dump_json(exclude_none=True)) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump_json(round_trip=True) @@ -679,12 +679,12 @@ class B(BaseModel): ) assert isinstance(m, A) assert m.type == "a" - if PYDANTIC_V2: - assert m.data == 100 # type: ignore[comparison-overlap] - else: + if PYDANTIC_V1: # pydantic v1 automatically converts inputs to strings # if the expected type is a str assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] def test_discriminated_unions_unknown_variant() -> None: @@ -768,12 +768,12 @@ class B(BaseModel): ) assert isinstance(m, A) assert m.foo_type == "a" - if PYDANTIC_V2: - assert m.data == 100 # type: ignore[comparison-overlap] - else: + if PYDANTIC_V1: # pydantic v1 automatically converts inputs to strings # if the expected type is a str assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: @@ -833,7 +833,7 @@ class B(BaseModel): assert UnionType.__discriminator__ is discriminator -@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") def test_type_alias_type() -> None: Alias = TypeAliasType("Alias", str) # pyright: ignore @@ -849,7 +849,7 @@ class Model(BaseModel): assert m.union == "bar" -@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") def test_field_named_cls() -> None: class Model(BaseModel): cls: str @@ -936,7 +936,7 @@ class Type2(BaseModel): assert isinstance(model.value, InnerType2) -@pytest.mark.skipif(not PYDANTIC_V2, reason="this is only supported in pydantic v2 for now") +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now") def test_extra_properties() -> None: class Item(BaseModel): prop: int diff --git a/tests/test_transform.py b/tests/test_transform.py index 956dba0..f1036bb 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -15,7 +15,7 @@ parse_datetime, async_transform as _async_transform, ) -from zeroentropy._compat import PYDANTIC_V2 +from zeroentropy._compat import PYDANTIC_V1 from zeroentropy._models import BaseModel _T = TypeVar("_T") @@ -189,7 +189,7 @@ class DateModel(BaseModel): @pytest.mark.asyncio async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - tz = "Z" if PYDANTIC_V2 else "+00:00" + tz = "+00:00" if PYDANTIC_V1 else "Z" assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] @@ -297,11 +297,11 @@ async def test_pydantic_unknown_field(use_async: bool) -> None: @pytest.mark.asyncio async def test_pydantic_mismatched_types(use_async: bool) -> None: model = MyModel.construct(foo=True) - if PYDANTIC_V2: + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: with pytest.warns(UserWarning): params = await transform(model, Any, use_async) - else: - params = await transform(model, Any, use_async) assert cast(Any, params) == {"foo": True} @@ -309,11 +309,11 @@ async def test_pydantic_mismatched_types(use_async: bool) -> None: @pytest.mark.asyncio async def test_pydantic_mismatched_object_type(use_async: bool) -> None: model = MyModel.construct(foo=MyModel.construct(hello="world")) - if PYDANTIC_V2: + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: with pytest.warns(UserWarning): params = await transform(model, Any, use_async) - else: - params = await transform(model, Any, use_async) assert cast(Any, params) == {"foo": {"hello": "world"}} diff --git a/tests/test_utils/test_datetime_parse.py b/tests/test_utils/test_datetime_parse.py new file mode 100644 index 0000000..2b214ba --- /dev/null +++ b/tests/test_utils/test_datetime_parse.py @@ -0,0 +1,110 @@ +""" +Copied from https://github.com/pydantic/pydantic/blob/v1.10.22/tests/test_datetime_parse.py +with modifications so it works without pydantic v1 imports. +""" + +from typing import Type, Union +from datetime import date, datetime, timezone, timedelta + +import pytest + +from zeroentropy._utils import parse_date, parse_datetime + + +def create_tz(minutes: int) -> timezone: + return timezone(timedelta(minutes=minutes)) + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + ("1494012444.883309", date(2017, 5, 5)), + (b"1494012444.883309", date(2017, 5, 5)), + (1_494_012_444.883_309, date(2017, 5, 5)), + ("1494012444", date(2017, 5, 5)), + (1_494_012_444, date(2017, 5, 5)), + (0, date(1970, 1, 1)), + ("2012-04-23", date(2012, 4, 23)), + (b"2012-04-23", date(2012, 4, 23)), + ("2012-4-9", date(2012, 4, 9)), + (date(2012, 4, 9), date(2012, 4, 9)), + (datetime(2012, 4, 9, 12, 15), date(2012, 4, 9)), + # Invalid inputs + ("x20120423", ValueError), + ("2012-04-56", ValueError), + (19_999_999_999, date(2603, 10, 11)), # just before watershed + (20_000_000_001, date(1970, 8, 20)), # just after watershed + (1_549_316_052, date(2019, 2, 4)), # nowish in s + (1_549_316_052_104, date(2019, 2, 4)), # nowish in ms + (1_549_316_052_104_324, date(2019, 2, 4)), # nowish in μs + (1_549_316_052_104_324_096, date(2019, 2, 4)), # nowish in ns + ("infinity", date(9999, 12, 31)), + ("inf", date(9999, 12, 31)), + (float("inf"), date(9999, 12, 31)), + ("infinity ", date(9999, 12, 31)), + (int("1" + "0" * 100), date(9999, 12, 31)), + (1e1000, date(9999, 12, 31)), + ("-infinity", date(1, 1, 1)), + ("-inf", date(1, 1, 1)), + ("nan", ValueError), + ], +) +def test_date_parsing(value: Union[str, bytes, int, float], result: Union[date, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_date(value) + else: + assert parse_date(value) == result + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + # values in seconds + ("1494012444.883309", datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + (1_494_012_444.883_309, datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + ("1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (b"1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (1_494_012_444, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + # values in ms + ("1494012444000.883309", datetime(2017, 5, 5, 19, 27, 24, 883, tzinfo=timezone.utc)), + ("-1494012444000.883309", datetime(1922, 8, 29, 4, 32, 35, 999117, tzinfo=timezone.utc)), + (1_494_012_444_000, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + ("2012-04-23T09:15:00", datetime(2012, 4, 23, 9, 15)), + ("2012-4-9 4:8:16", datetime(2012, 4, 9, 4, 8, 16)), + ("2012-04-23T09:15:00Z", datetime(2012, 4, 23, 9, 15, 0, 0, timezone.utc)), + ("2012-4-9 4:8:16-0320", datetime(2012, 4, 9, 4, 8, 16, 0, create_tz(-200))), + ("2012-04-23T10:20:30.400+02:30", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(150))), + ("2012-04-23T10:20:30.400+02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(120))), + ("2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (b"2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (datetime(2017, 5, 5), datetime(2017, 5, 5)), + (0, datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)), + # Invalid inputs + ("x20120423091500", ValueError), + ("2012-04-56T09:15:90", ValueError), + ("2012-04-23T11:05:00-25:00", ValueError), + (19_999_999_999, datetime(2603, 10, 11, 11, 33, 19, tzinfo=timezone.utc)), # just before watershed + (20_000_000_001, datetime(1970, 8, 20, 11, 33, 20, 1000, tzinfo=timezone.utc)), # just after watershed + (1_549_316_052, datetime(2019, 2, 4, 21, 34, 12, 0, tzinfo=timezone.utc)), # nowish in s + (1_549_316_052_104, datetime(2019, 2, 4, 21, 34, 12, 104_000, tzinfo=timezone.utc)), # nowish in ms + (1_549_316_052_104_324, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in μs + (1_549_316_052_104_324_096, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in ns + ("infinity", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf ", datetime(9999, 12, 31, 23, 59, 59, 999999)), + (1e50, datetime(9999, 12, 31, 23, 59, 59, 999999)), + (float("inf"), datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("-infinity", datetime(1, 1, 1, 0, 0)), + ("-inf", datetime(1, 1, 1, 0, 0)), + ("nan", ValueError), + ], +) +def test_datetime_parsing(value: Union[str, bytes, int, float], result: Union[datetime, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_datetime(value) + else: + assert parse_datetime(value) == result diff --git a/tests/utils.py b/tests/utils.py index 46f0df3..4953809 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,7 @@ is_annotated_type, is_type_alias_type, ) -from zeroentropy._compat import PYDANTIC_V2, field_outer_type, get_model_fields +from zeroentropy._compat import PYDANTIC_V1, field_outer_type, get_model_fields from zeroentropy._models import BaseModel BaseModelT = TypeVar("BaseModelT", bound=BaseModel) @@ -28,12 +28,12 @@ def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool: for name, field in get_model_fields(model).items(): field_value = getattr(value, name) - if PYDANTIC_V2: - allow_none = False - else: + if PYDANTIC_V1: # in v1 nullability was structured differently # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields allow_none = getattr(field, "allow_none", False) + else: + allow_none = False assert_matches_type( field_outer_type(field), From 97977db20f3f27cb1b01f4be6e3538796c0a5414 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 5 Sep 2025 04:58:17 +0000 Subject: [PATCH 21/25] chore(internal): move mypy configurations to `pyproject.toml` file --- mypy.ini | 50 ------------------------------------------------ pyproject.toml | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 50 deletions(-) delete mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 88c874b..0000000 --- a/mypy.ini +++ /dev/null @@ -1,50 +0,0 @@ -[mypy] -pretty = True -show_error_codes = True - -# Exclude _files.py because mypy isn't smart enough to apply -# the correct type narrowing and as this is an internal module -# it's fine to just use Pyright. -# -# We also exclude our `tests` as mypy doesn't always infer -# types correctly and Pyright will still catch any type errors. -exclude = ^(src/zeroentropy/_files\.py|_dev/.*\.py|tests/.*)$ - -strict_equality = True -implicit_reexport = True -check_untyped_defs = True -no_implicit_optional = True - -warn_return_any = True -warn_unreachable = True -warn_unused_configs = True - -# Turn these options off as it could cause conflicts -# with the Pyright options. -warn_unused_ignores = False -warn_redundant_casts = False - -disallow_any_generics = True -disallow_untyped_defs = True -disallow_untyped_calls = True -disallow_subclassing_any = True -disallow_incomplete_defs = True -disallow_untyped_decorators = True -cache_fine_grained = True - -# By default, mypy reports an error if you assign a value to the result -# of a function call that doesn't return anything. We do this in our test -# cases: -# ``` -# result = ... -# assert result is None -# ``` -# Changing this codegen to make mypy happy would increase complexity -# and would not be worth it. -disable_error_code = func-returns-value,overload-cannot-match - -# https://github.com/python/mypy/issues/12162 -[mypy.overrides] -module = "black.files.*" -ignore_errors = true -ignore_missing_imports = true diff --git a/pyproject.toml b/pyproject.toml index 0c43a8f..5400c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,58 @@ reportOverlappingOverload = false reportImportCycles = false reportPrivateUsage = false +[tool.mypy] +pretty = true +show_error_codes = true + +# Exclude _files.py because mypy isn't smart enough to apply +# the correct type narrowing and as this is an internal module +# it's fine to just use Pyright. +# +# We also exclude our `tests` as mypy doesn't always infer +# types correctly and Pyright will still catch any type errors. +exclude = ['src/zeroentropy/_files.py', '_dev/.*.py', 'tests/.*'] + +strict_equality = true +implicit_reexport = true +check_untyped_defs = true +no_implicit_optional = true + +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true + +# Turn these options off as it could cause conflicts +# with the Pyright options. +warn_unused_ignores = false +warn_redundant_casts = false + +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_subclassing_any = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +cache_fine_grained = true + +# By default, mypy reports an error if you assign a value to the result +# of a function call that doesn't return anything. We do this in our test +# cases: +# ``` +# result = ... +# assert result is None +# ``` +# Changing this codegen to make mypy happy would increase complexity +# and would not be worth it. +disable_error_code = "func-returns-value,overload-cannot-match" + +# https://github.com/python/mypy/issues/12162 +[[tool.mypy.overrides]] +module = "black.files.*" +ignore_errors = true +ignore_missing_imports = true + + [tool.ruff] line-length = 120 output-format = "grouped" From ae111a642a1da930f319ce62f1a28f203881cfe5 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:44:29 +0000 Subject: [PATCH 22/25] chore(tests): simplify `get_platform` test `nest_asyncio` is archived and broken on some platforms so it's not worth keeping in our test suite. --- pyproject.toml | 1 - requirements-dev.lock | 1 - tests/test_client.py | 53 +++++-------------------------------------- 3 files changed, 6 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5400c8d..a3c7371 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,6 @@ dev-dependencies = [ "dirty-equals>=0.6.0", "importlib-metadata>=6.7.0", "rich>=13.7.1", - "nest_asyncio==1.6.0", "pytest-xdist>=3.6.1", ] diff --git a/requirements-dev.lock b/requirements-dev.lock index 087efaa..00484cd 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -75,7 +75,6 @@ multidict==6.4.4 mypy==1.14.1 mypy-extensions==1.0.0 # via mypy -nest-asyncio==1.6.0 nodeenv==1.8.0 # via pyright nox==2023.4.22 diff --git a/tests/test_client.py b/tests/test_client.py index 6198354..a97e096 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,13 +6,10 @@ import os import sys import json -import time import asyncio import inspect -import subprocess import tracemalloc from typing import Any, Union, cast -from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -23,14 +20,17 @@ from zeroentropy import ZeroEntropy, AsyncZeroEntropy, APIResponseValidationError from zeroentropy._types import Omit +from zeroentropy._utils import asyncify from zeroentropy._models import BaseModel, FinalRequestOptions from zeroentropy._exceptions import APIStatusError, APITimeoutError, ZeroEntropyError, APIResponseValidationError from zeroentropy._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, + OtherPlatform, DefaultHttpxClient, DefaultAsyncHttpxClient, + get_platform, make_request_options, ) @@ -1639,50 +1639,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" - def test_get_platform(self) -> None: - # A previous implementation of asyncify could leave threads unterminated when - # used with nest_asyncio. - # - # Since nest_asyncio.apply() is global and cannot be un-applied, this - # test is run in a separate process to avoid affecting other tests. - test_code = dedent(""" - import asyncio - import nest_asyncio - import threading - - from zeroentropy._utils import asyncify - from zeroentropy._base_client import get_platform - - async def test_main() -> None: - result = await asyncify(get_platform)() - print(result) - for thread in threading.enumerate(): - print(thread.name) - - nest_asyncio.apply() - asyncio.run(test_main()) - """) - with subprocess.Popen( - [sys.executable, "-c", test_code], - text=True, - ) as process: - timeout = 10 # seconds - - start_time = time.monotonic() - while True: - return_code = process.poll() - if return_code is not None: - if return_code != 0: - raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code") - - # success - break - - if time.monotonic() - start_time > timeout: - process.kill() - raise AssertionError("calling get_platform using asyncify resulted in a hung process") - - time.sleep(0.1) + async def test_get_platform(self) -> None: + platform = await asyncify(get_platform)() + assert isinstance(platform, (str, OtherPlatform)) async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: # Test that the proxy environment variables are set correctly From 7b6f1c52caff5579b0f42534ae393bf2bd634e2b Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:28:53 +0000 Subject: [PATCH 23/25] feat(api): manual updates --- .stats.yml | 4 +- README.md | 5 +- pyproject.toml | 12 +- requirements-dev.lock | 9 +- requirements.lock | 9 +- scripts/bootstrap | 14 +- src/zeroentropy/__init__.py | 4 +- src/zeroentropy/_base_client.py | 18 +- src/zeroentropy/_client.py | 16 +- src/zeroentropy/_models.py | 64 ++- src/zeroentropy/_qs.py | 14 +- src/zeroentropy/_streaming.py | 10 +- src/zeroentropy/_types.py | 29 +- src/zeroentropy/_utils/_sync.py | 34 +- src/zeroentropy/_utils/_transform.py | 4 +- src/zeroentropy/_utils/_utils.py | 10 +- src/zeroentropy/resources/collections.py | 42 +- src/zeroentropy/resources/documents.py | 74 ++-- src/zeroentropy/resources/models.py | 73 +++- src/zeroentropy/resources/queries.py | 68 ++-- src/zeroentropy/resources/status.py | 10 +- .../types/collection_add_params.py | 10 + src/zeroentropy/types/document_add_params.py | 9 +- src/zeroentropy/types/model_rerank_params.py | 25 +- .../types/model_rerank_response.py | 3 +- .../types/query_top_documents_params.py | 9 +- .../types/status_get_status_response.py | 6 + tests/api_resources/test_collections.py | 16 + tests/api_resources/test_models.py | 12 +- tests/test_client.py | 364 ++++++++++-------- tests/test_models.py | 8 +- tests/test_transform.py | 11 +- 32 files changed, 595 insertions(+), 401 deletions(-) diff --git a/.stats.yml b/.stats.yml index 569eead..3b44efb 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 14 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/zeroentropy%2Fzeroentropy-bd2f55f423e09b74f83cbad6034fb76f7052363308d02533a908b49543cff459.yml -openapi_spec_hash: 6d7566ebda7fecac4069744949d547e0 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/zeroentropy%2Fzeroentropy-c95681b13dc56e64126746c6e546b564c7f802ae567fc9ccc1aeb8eddd40bb1e.yml +openapi_spec_hash: 2ac723122fe938e384f11b5cf19e85ec config_hash: e07cdee04c971e1db74e91a5a4cd981c diff --git a/README.md b/README.md index a3c0da0..effdcf2 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,7 @@ [![PyPI version](https://img.shields.io/pypi/v/zeroentropy.svg?label=pypi%20(stable))](https://pypi.org/project/zeroentropy/) -The ZeroEntropy Python SDK provides convenient access to the [ZeroEntropy REST API](https://docs.zeroentropy.dev/api-reference/) from any Python 3.8+ -application. +The ZeroEntropy Python SDK provides convenient type-safe access to the [ZeroEntropy REST API](https://docs.zeroentropy.dev/api-reference/) from any Python 3.9+ application. In order to get an API Key, you can visit our [dashboard](https://dashboard.zeroentropy.dev/). @@ -449,7 +448,7 @@ print(zeroentropy.__version__) ## Requirements -Python 3.8 or higher. +Python 3.9 or higher. ## Contributing diff --git a/pyproject.toml b/pyproject.toml index a3c7371..30d0058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,16 +15,16 @@ dependencies = [ "distro>=1.7.0, <2", "sniffio", ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", @@ -39,7 +39,7 @@ Homepage = "https://github.com/zeroentropy-ai/zeroentropy-python" Repository = "https://github.com/zeroentropy-ai/zeroentropy-python" [project.optional-dependencies] -aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] [tool.rye] managed = true @@ -141,7 +141,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.8" +pythonVersion = "3.9" exclude = [ "_dev", @@ -224,6 +224,8 @@ select = [ "B", # remove unused imports "F401", + # check for missing future annotations + "FA102", # bare except statements "E722", # unused arguments @@ -246,6 +248,8 @@ unfixable = [ "T203", ] +extend-safe-fixes = ["FA102"] + [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" diff --git a/requirements-dev.lock b/requirements-dev.lock index 00484cd..1b78d4c 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -56,7 +56,7 @@ httpx==0.28.1 # via httpx-aiohttp # via respx # via zeroentropy -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via zeroentropy idna==3.4 # via anyio @@ -88,9 +88,9 @@ pluggy==1.5.0 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.10.3 +pydantic==2.11.9 # via zeroentropy -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic pygments==2.18.0 # via rich @@ -125,7 +125,10 @@ typing-extensions==4.12.2 # via pydantic # via pydantic-core # via pyright + # via typing-inspection # via zeroentropy +typing-inspection==0.4.1 + # via pydantic virtualenv==20.24.5 # via nox yarl==1.20.0 diff --git a/requirements.lock b/requirements.lock index a4b1e3b..13241b6 100644 --- a/requirements.lock +++ b/requirements.lock @@ -43,7 +43,7 @@ httpcore==1.0.9 httpx==0.28.1 # via httpx-aiohttp # via zeroentropy -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via zeroentropy idna==3.4 # via anyio @@ -55,9 +55,9 @@ multidict==6.4.4 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.10.3 +pydantic==2.11.9 # via zeroentropy -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic sniffio==1.3.0 # via anyio @@ -67,6 +67,9 @@ typing-extensions==4.12.2 # via multidict # via pydantic # via pydantic-core + # via typing-inspection # via zeroentropy +typing-inspection==0.4.1 + # via pydantic yarl==1.20.0 # via aiohttp diff --git a/scripts/bootstrap b/scripts/bootstrap index e84fe62..b430fee 100755 --- a/scripts/bootstrap +++ b/scripts/bootstrap @@ -4,10 +4,18 @@ set -e cd "$(dirname "$0")/.." -if ! command -v rye >/dev/null 2>&1 && [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ] && [ -t 0 ]; then brew bundle check >/dev/null 2>&1 || { - echo "==> Installing Homebrew dependencies…" - brew bundle + echo -n "==> Install Homebrew dependencies? (y/N): " + read -r response + case "$response" in + [yY][eE][sS]|[yY]) + brew bundle + ;; + *) + ;; + esac + echo } fi diff --git a/src/zeroentropy/__init__.py b/src/zeroentropy/__init__.py index 196eb66..c0db229 100644 --- a/src/zeroentropy/__init__.py +++ b/src/zeroentropy/__init__.py @@ -3,7 +3,7 @@ import typing as _t from . import types -from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given from ._utils import file_from_path from ._client import ( Client, @@ -48,7 +48,9 @@ "ProxiesTypes", "NotGiven", "NOT_GIVEN", + "not_given", "Omit", + "omit", "ZeroEntropyError", "APIError", "APIStatusError", diff --git a/src/zeroentropy/_base_client.py b/src/zeroentropy/_base_client.py index 5b7681b..2769d40 100644 --- a/src/zeroentropy/_base_client.py +++ b/src/zeroentropy/_base_client.py @@ -42,7 +42,6 @@ from ._qs import Querystring from ._files import to_httpx_files, async_to_httpx_files from ._types import ( - NOT_GIVEN, Body, Omit, Query, @@ -57,6 +56,7 @@ RequestOptions, HttpxRequestFiles, ModelBuilderProtocol, + not_given, ) from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping from ._compat import PYDANTIC_V1, model_copy, model_dump @@ -145,9 +145,9 @@ def __init__( def __init__( self, *, - url: URL | NotGiven = NOT_GIVEN, - json: Body | NotGiven = NOT_GIVEN, - params: Query | NotGiven = NOT_GIVEN, + url: URL | NotGiven = not_given, + json: Body | NotGiven = not_given, + params: Query | NotGiven = not_given, ) -> None: self.url = url self.json = json @@ -595,7 +595,7 @@ def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalReques # we internally support defining a temporary header to override the # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` # see _response.py for implementation details - override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN) + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, not_given) if is_given(override_cast_to): options.headers = headers return cast(Type[ResponseT], override_cast_to) @@ -825,7 +825,7 @@ def __init__( version: str, base_url: str | URL, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1356,7 +1356,7 @@ def __init__( base_url: str | URL, _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1818,8 +1818,8 @@ def make_request_options( extra_query: Query | None = None, extra_body: Body | None = None, idempotency_key: str | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - post_parser: PostParser | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + post_parser: PostParser | NotGiven = not_given, ) -> RequestOptions: """Create a dict of type RequestOptions without keys of NotGiven values.""" options: RequestOptions = {} diff --git a/src/zeroentropy/_client.py b/src/zeroentropy/_client.py index abf8acc..c87e7f0 100644 --- a/src/zeroentropy/_client.py +++ b/src/zeroentropy/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Any, Union, Mapping +from typing import Any, Mapping from typing_extensions import Self, override import httpx @@ -11,13 +11,13 @@ from . import _exceptions from ._qs import Querystring from ._types import ( - NOT_GIVEN, Omit, Timeout, NotGiven, Transport, ProxiesTypes, RequestOptions, + not_given, ) from ._utils import is_given, get_async_library from ._version import __version__ @@ -59,7 +59,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -138,9 +138,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -235,7 +235,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -314,9 +314,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, diff --git a/src/zeroentropy/_models.py b/src/zeroentropy/_models.py index 3a6017e..ca9500b 100644 --- a/src/zeroentropy/_models.py +++ b/src/zeroentropy/_models.py @@ -2,6 +2,7 @@ import os import inspect +import weakref from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( @@ -256,13 +257,15 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + context: Any | None = None, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -271,16 +274,24 @@ def model_dump( Args: mode: The mode in which `to_python` should run. - If mode is 'json', the dictionary will only contain JSON serializable types. - If mode is 'python', the dictionary may contain any Python objects. - include: A list of fields to include in the output. - exclude: A list of fields to exclude from the output. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + include: A set of fields to include in the output. + exclude: A set of fields to exclude from the output. + context: Additional context to pass to the serializer. by_alias: Whether to use the field's alias in the dictionary key if defined. - exclude_unset: Whether to exclude fields that are unset or None from the output. - exclude_defaults: Whether to exclude fields that are set to their default value from the output. - exclude_none: Whether to exclude fields that have a value of `None` from the output. - round_trip: Whether to enable serialization and deserialization round-trip support. - warnings: Whether to log warnings when invalid fields are encountered. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + exclude_computed_fields: Whether to exclude computed fields. + While this can be useful for round-tripping, it is usually recommended to use the dedicated + `round_trip` parameter instead. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, + "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + fallback: A function to call when an unknown value is encountered. If not provided, + a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. Returns: A dictionary representation of the model. @@ -295,10 +306,14 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -311,15 +326,18 @@ def model_dump_json( self, *, indent: int | None = None, + ensure_ascii: bool = False, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + context: Any | None = None, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -348,11 +366,17 @@ def model_dump_json( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") + if ensure_ascii != False: + raise ValueError("ensure_ascii is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -567,6 +591,9 @@ class CachedDiscriminatorType(Protocol): __discriminator__: DiscriminatorDetails +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() + + class DiscriminatorDetails: field_name: str """The name of the discriminator field in the variant class, e.g. @@ -609,8 +636,9 @@ def __init__( def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ + cached = DISCRIMINATOR_CACHE.get(union) + if cached is not None: + return cached discriminator_field_name: str | None = None @@ -663,7 +691,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, discriminator_field=discriminator_field_name, discriminator_alias=discriminator_alias, ) - cast(CachedDiscriminatorType, union).__discriminator__ = details + DISCRIMINATOR_CACHE.setdefault(union, details) return details diff --git a/src/zeroentropy/_qs.py b/src/zeroentropy/_qs.py index 274320c..ada6fd3 100644 --- a/src/zeroentropy/_qs.py +++ b/src/zeroentropy/_qs.py @@ -4,7 +4,7 @@ from urllib.parse import parse_qs, urlencode from typing_extensions import Literal, get_args -from ._types import NOT_GIVEN, NotGiven, NotGivenOr +from ._types import NotGiven, not_given from ._utils import flatten _T = TypeVar("_T") @@ -41,8 +41,8 @@ def stringify( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> str: return urlencode( self.stringify_items( @@ -56,8 +56,8 @@ def stringify_items( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> list[tuple[str, str]]: opts = Options( qs=self, @@ -143,8 +143,8 @@ def __init__( self, qs: Querystring = _qs, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> None: self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format diff --git a/src/zeroentropy/_streaming.py b/src/zeroentropy/_streaming.py index 1be1801..9f17ee2 100644 --- a/src/zeroentropy/_streaming.py +++ b/src/zeroentropy/_streaming.py @@ -57,9 +57,8 @@ def __stream__(self) -> Iterator[_T]: for sse in iterator: yield process_data(data=sse.json(), cast_to=cast_to, response=response) - # Ensure the entire stream is consumed - for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + response.close() def __enter__(self) -> Self: return self @@ -121,9 +120,8 @@ async def __stream__(self) -> AsyncIterator[_T]: async for sse in iterator: yield process_data(data=sse.json(), cast_to=cast_to, response=response) - # Ensure the entire stream is consumed - async for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + await response.aclose() async def __aenter__(self) -> Self: return self diff --git a/src/zeroentropy/_types.py b/src/zeroentropy/_types.py index 60a5cab..ac690c6 100644 --- a/src/zeroentropy/_types.py +++ b/src/zeroentropy/_types.py @@ -117,18 +117,21 @@ class RequestOptions(TypedDict, total=False): # Sentinel class used until PEP 0661 is accepted class NotGiven: """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). + For parameters with a meaningful None value, we need to distinguish between + the user explicitly passing None, and the user not passing the parameter at + all. + + User code shouldn't need to use not_given directly. For example: ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... + def create(timeout: Timeout | None | NotGiven = not_given): ... - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. + create(timeout=1) # 1s timeout + create(timeout=None) # No timeout + create() # Default timeout behavior ``` """ @@ -140,13 +143,14 @@ def __repr__(self) -> str: return "NOT_GIVEN" -NotGivenOr = Union[_T, NotGiven] +not_given = NotGiven() +# for backwards compatibility: NOT_GIVEN = NotGiven() class Omit: - """In certain situations you need to be able to represent a case where a default value has - to be explicitly removed and `None` is not an appropriate substitute, for example: + """ + To explicitly omit something from being sent in a request, use `omit`. ```py # as the default `Content-Type` header is `application/json` that will be sent @@ -156,8 +160,8 @@ class Omit: # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' client.post(..., headers={"Content-Type": "multipart/form-data"}) - # instead you can remove the default `application/json` header by passing Omit - client.post(..., headers={"Content-Type": Omit()}) + # instead you can remove the default `application/json` header by passing omit + client.post(..., headers={"Content-Type": omit}) ``` """ @@ -165,6 +169,9 @@ def __bool__(self) -> Literal[False]: return False +omit = Omit() + + @runtime_checkable class ModelBuilderProtocol(Protocol): @classmethod diff --git a/src/zeroentropy/_utils/_sync.py b/src/zeroentropy/_utils/_sync.py index ad7ec71..f6027c1 100644 --- a/src/zeroentropy/_utils/_sync.py +++ b/src/zeroentropy/_utils/_sync.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import asyncio import functools -import contextvars -from typing import Any, TypeVar, Callable, Awaitable +from typing import TypeVar, Callable, Awaitable from typing_extensions import ParamSpec import anyio @@ -15,34 +13,11 @@ T_ParamSpec = ParamSpec("T_ParamSpec") -if sys.version_info >= (3, 9): - _asyncio_to_thread = asyncio.to_thread -else: - # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - # for Python 3.8 support - async def _asyncio_to_thread( - func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs - ) -> Any: - """Asynchronously run function *func* in a separate thread. - - Any *args and **kwargs supplied for this function are directly passed - to *func*. Also, the current :class:`contextvars.Context` is propagated, - allowing context variables from the main thread to be accessed in the - separate thread. - - Returns a coroutine that can be awaited to get the eventual result of *func*. - """ - loop = asyncio.events.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - async def to_thread( func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> T_Retval: if sniffio.current_async_library() == "asyncio": - return await _asyncio_to_thread(func, *args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) return await anyio.to_thread.run_sync( functools.partial(func, *args, **kwargs), @@ -53,10 +28,7 @@ async def to_thread( def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ Take a blocking function and create an async one that receives the same - positional and keyword arguments. For python version 3.9 and above, it uses - asyncio.to_thread to run the function in a separate thread. For python version - 3.8, it uses locally defined copy of the asyncio.to_thread function which was - introduced in python 3.9. + positional and keyword arguments. Usage: diff --git a/src/zeroentropy/_utils/_transform.py b/src/zeroentropy/_utils/_transform.py index c19124f..5207549 100644 --- a/src/zeroentropy/_utils/_transform.py +++ b/src/zeroentropy/_utils/_transform.py @@ -268,7 +268,7 @@ def _transform_typeddict( annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): if not is_given(value): - # we don't need to include `NotGiven` values here as they'll + # we don't need to include omitted values here as they'll # be stripped out before the request is sent anyway continue @@ -434,7 +434,7 @@ async def _async_transform_typeddict( annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): if not is_given(value): - # we don't need to include `NotGiven` values here as they'll + # we don't need to include omitted values here as they'll # be stripped out before the request is sent anyway continue diff --git a/src/zeroentropy/_utils/_utils.py b/src/zeroentropy/_utils/_utils.py index f081859..eec7f4a 100644 --- a/src/zeroentropy/_utils/_utils.py +++ b/src/zeroentropy/_utils/_utils.py @@ -21,7 +21,7 @@ import sniffio -from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike +from .._types import Omit, NotGiven, FileTypes, HeadersLike _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) @@ -63,7 +63,7 @@ def _extract_items( try: key = path[index] except IndexError: - if isinstance(obj, NotGiven): + if not is_given(obj): # no value was provided - we can safely ignore return [] @@ -126,14 +126,14 @@ def _extract_items( return [] -def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: - return not isinstance(obj, NotGiven) +def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) and not isinstance(obj, Omit) # Type safe methods for narrowing types with TypeVars. # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], # however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. +# care about the contained types we can safely use `object` in its place. # # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. # `is_*` is for when you're dealing with an unknown input diff --git a/src/zeroentropy/resources/collections.py b/src/zeroentropy/resources/collections.py index 24aeaa8..505184a 100644 --- a/src/zeroentropy/resources/collections.py +++ b/src/zeroentropy/resources/collections.py @@ -5,7 +5,7 @@ import httpx from ..types import collection_add_params, collection_delete_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -52,7 +52,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CollectionDeleteResponse: """ Deletes a collection. @@ -84,12 +84,13 @@ def add( self, *, collection_name: str, + num_shards: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CollectionAddResponse: """ Adds a collection. @@ -101,6 +102,12 @@ def add( characters. If special characters are used, then the UTF-8 encoded string cannot exceed 1024 bytes. + num_shards: [ADVANCED] The number of shards to use for this collection. By using K shards, + your documents can index with K times more throughput. However, queries will be + automatically sent to all K shards and then aggregated. For large collections, + this can make queries faster. But for small collections, this will make queries + slower. `num_shards` must be one of [1, 8, 16, 32, 64]. The default is 1. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -111,7 +118,13 @@ def add( """ return self._post( "/collections/add-collection", - body=maybe_transform({"collection_name": collection_name}, collection_add_params.CollectionAddParams), + body=maybe_transform( + { + "collection_name": collection_name, + "num_shards": num_shards, + }, + collection_add_params.CollectionAddParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -126,7 +139,7 @@ def get_list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CollectionGetListResponse: """Gets a complete list of all of your collections.""" return self._post( @@ -167,7 +180,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CollectionDeleteResponse: """ Deletes a collection. @@ -201,12 +214,13 @@ async def add( self, *, collection_name: str, + num_shards: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CollectionAddResponse: """ Adds a collection. @@ -218,6 +232,12 @@ async def add( characters. If special characters are used, then the UTF-8 encoded string cannot exceed 1024 bytes. + num_shards: [ADVANCED] The number of shards to use for this collection. By using K shards, + your documents can index with K times more throughput. However, queries will be + automatically sent to all K shards and then aggregated. For large collections, + this can make queries faster. But for small collections, this will make queries + slower. `num_shards` must be one of [1, 8, 16, 32, 64]. The default is 1. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -229,7 +249,11 @@ async def add( return await self._post( "/collections/add-collection", body=await async_maybe_transform( - {"collection_name": collection_name}, collection_add_params.CollectionAddParams + { + "collection_name": collection_name, + "num_shards": num_shards, + }, + collection_add_params.CollectionAddParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -245,7 +269,7 @@ async def get_list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CollectionGetListResponse: """Gets a complete list of all of your collections.""" return await self._post( diff --git a/src/zeroentropy/resources/documents.py b/src/zeroentropy/resources/documents.py index 1330909..22a8e32 100644 --- a/src/zeroentropy/resources/documents.py +++ b/src/zeroentropy/resources/documents.py @@ -15,7 +15,7 @@ document_get_info_list_params, document_get_page_info_params, ) -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, SequenceNotStr +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -62,14 +62,14 @@ def update( *, collection_name: str, path: str, - index_status: Optional[Literal["not_parsed", "not_indexed"]] | NotGiven = NOT_GIVEN, - metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] | NotGiven = NOT_GIVEN, + index_status: Optional[Literal["not_parsed", "not_indexed"]] | Omit = omit, + metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentUpdateResponse: """Updates a document. @@ -81,9 +81,9 @@ def update( `index_status` of `indexed`. After this call, the document will have an `index_status` of `not_indexed`, since the document will need to reindex with the new metadata. - - When updating with a non-null `index_status`, setting it to - `not_parsed or `not_indexed`requires that the document must have`index_status`of`parsing_failed`or`indexing_failed`, - respectively. + - When updating with a non-null `index_status`, setting it to `not_parsed` or + `not_indexed` requires that the document must have `index_status` of + `parsing_failed` or `indexing_failed`, respectively. A `404 Not Found` status code will be returned, if the provided collection name or document path does not exist. @@ -139,7 +139,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentDeleteResponse: """ Deletes a document @@ -182,14 +182,14 @@ def add( collection_name: str, content: document_add_params.Content, path: str, - metadata: Dict[str, Union[str, SequenceNotStr[str]]] | NotGiven = NOT_GIVEN, - overwrite: bool | NotGiven = NOT_GIVEN, + metadata: Dict[str, Union[str, SequenceNotStr[str]]] | Omit = omit, + overwrite: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentAddResponse: """ Adds a document to a given collection. @@ -261,13 +261,13 @@ def get_info( *, collection_name: str, path: str, - include_content: bool | NotGiven = NOT_GIVEN, + include_content: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentGetInfoResponse: """Retrieves information about a specific document. @@ -315,15 +315,15 @@ def get_info_list( self, *, collection_name: str, - limit: int | NotGiven = NOT_GIVEN, - path_gt: Optional[str] | NotGiven = NOT_GIVEN, - path_prefix: Optional[str] | NotGiven = NOT_GIVEN, + limit: int | Omit = omit, + path_gt: Optional[str] | Omit = omit, + path_prefix: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SyncGetDocumentInfoListCursor[DocumentGetInfoListResponse]: """ Retrives a list of document metadata information that matches the provided @@ -383,13 +383,13 @@ def get_page_info( collection_name: str, page_index: int, path: str, - include_content: bool | NotGiven = NOT_GIVEN, + include_content: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentGetPageInfoResponse: """Retrieves information about a specific page. @@ -465,14 +465,14 @@ async def update( *, collection_name: str, path: str, - index_status: Optional[Literal["not_parsed", "not_indexed"]] | NotGiven = NOT_GIVEN, - metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] | NotGiven = NOT_GIVEN, + index_status: Optional[Literal["not_parsed", "not_indexed"]] | Omit = omit, + metadata: Optional[Dict[str, Union[str, SequenceNotStr[str]]]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentUpdateResponse: """Updates a document. @@ -484,9 +484,9 @@ async def update( `index_status` of `indexed`. After this call, the document will have an `index_status` of `not_indexed`, since the document will need to reindex with the new metadata. - - When updating with a non-null `index_status`, setting it to - `not_parsed or `not_indexed`requires that the document must have`index_status`of`parsing_failed`or`indexing_failed`, - respectively. + - When updating with a non-null `index_status`, setting it to `not_parsed` or + `not_indexed` requires that the document must have `index_status` of + `parsing_failed` or `indexing_failed`, respectively. A `404 Not Found` status code will be returned, if the provided collection name or document path does not exist. @@ -542,7 +542,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentDeleteResponse: """ Deletes a document @@ -585,14 +585,14 @@ async def add( collection_name: str, content: document_add_params.Content, path: str, - metadata: Dict[str, Union[str, SequenceNotStr[str]]] | NotGiven = NOT_GIVEN, - overwrite: bool | NotGiven = NOT_GIVEN, + metadata: Dict[str, Union[str, SequenceNotStr[str]]] | Omit = omit, + overwrite: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentAddResponse: """ Adds a document to a given collection. @@ -664,13 +664,13 @@ async def get_info( *, collection_name: str, path: str, - include_content: bool | NotGiven = NOT_GIVEN, + include_content: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentGetInfoResponse: """Retrieves information about a specific document. @@ -718,15 +718,15 @@ def get_info_list( self, *, collection_name: str, - limit: int | NotGiven = NOT_GIVEN, - path_gt: Optional[str] | NotGiven = NOT_GIVEN, - path_prefix: Optional[str] | NotGiven = NOT_GIVEN, + limit: int | Omit = omit, + path_gt: Optional[str] | Omit = omit, + path_prefix: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AsyncPaginator[DocumentGetInfoListResponse, AsyncGetDocumentInfoListCursor[DocumentGetInfoListResponse]]: """ Retrives a list of document metadata information that matches the provided @@ -786,13 +786,13 @@ async def get_page_info( collection_name: str, page_index: int, path: str, - include_content: bool | NotGiven = NOT_GIVEN, + include_content: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentGetPageInfoResponse: """Retrieves information about a specific page. diff --git a/src/zeroentropy/resources/models.py b/src/zeroentropy/resources/models.py index 4fd006f..97fe9d0 100644 --- a/src/zeroentropy/resources/models.py +++ b/src/zeroentropy/resources/models.py @@ -3,11 +3,12 @@ from __future__ import annotations from typing import Optional +from typing_extensions import Literal import httpx from ..types import model_rerank_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, SequenceNotStr +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -47,15 +48,16 @@ def rerank( self, *, documents: SequenceNotStr[str], + model: str, query: str, - model: str | NotGiven = NOT_GIVEN, - top_n: Optional[int] | NotGiven = NOT_GIVEN, + latency: Optional[Literal["fast", "slow"]] | Omit = omit, + top_n: Optional[int] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ModelRerankResponse: """ Reranks the provided documents, according to the provided query. @@ -63,19 +65,34 @@ def rerank( The results will be sorted by descending order of relevance. For each document, the index and the score will be returned. The index is relative to the documents array that was passed in. The score is the query-document relevancy determined - by the reranker model. The value will be returned in descending order to + by the reranker model. The results will be returned in descending order of relevance. + Organizations will, by default, have a ratelimit of `2,500,000` + bytes-per-minute. If this is exceeded, requests will be throttled into + `latency: "slow"` mode, up to `10,000,000` bytes-per-minute. If even this is + exceeded, you will get a `429` error. To request higher ratelimits, please + contact [founders@zeroentropy.dev](mailto:founders@zeroentropy.dev) or message + us on [Discord](https://go.zeroentropy.dev/discord) or + [Slack](https://go.zeroentropy.dev/slack)! + Args: documents: The list of documents to rerank. Each document is a string. - query: The query to rerank the documents by. Results will be in descending order of - relevance. + model: The model ID to use for reranking. Options are: ["zerank-2", "zerank-1", + "zerank-1-small"] + + query: The query to rerank the documents by. - model: The model ID to use for reranking. Options are: ["zerank-1-large"] + latency: Whether the call will be inferenced "fast" or "slow". RateLimits for slow API + calls are orders of magnitude higher, but you can expect >10 second latency. + Fast inferences are guaranteed subsecond, but rate limits are lower. If not + specified, first a "fast" call will be attempted, but if you have exceeded your + fast rate limit, then a slow call will be executed. If explicitly set to "fast", + then 429 will be returned if it cannot be executed fast. top_n: If provided, then only the top `n` documents will be returned in the results - array. + array. Otherwise, `n` will be the length of the provided documents array. extra_headers: Send extra headers @@ -90,8 +107,9 @@ def rerank( body=maybe_transform( { "documents": documents, - "query": query, "model": model, + "query": query, + "latency": latency, "top_n": top_n, }, model_rerank_params.ModelRerankParams, @@ -127,15 +145,16 @@ async def rerank( self, *, documents: SequenceNotStr[str], + model: str, query: str, - model: str | NotGiven = NOT_GIVEN, - top_n: Optional[int] | NotGiven = NOT_GIVEN, + latency: Optional[Literal["fast", "slow"]] | Omit = omit, + top_n: Optional[int] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ModelRerankResponse: """ Reranks the provided documents, according to the provided query. @@ -143,19 +162,34 @@ async def rerank( The results will be sorted by descending order of relevance. For each document, the index and the score will be returned. The index is relative to the documents array that was passed in. The score is the query-document relevancy determined - by the reranker model. The value will be returned in descending order to + by the reranker model. The results will be returned in descending order of relevance. + Organizations will, by default, have a ratelimit of `2,500,000` + bytes-per-minute. If this is exceeded, requests will be throttled into + `latency: "slow"` mode, up to `10,000,000` bytes-per-minute. If even this is + exceeded, you will get a `429` error. To request higher ratelimits, please + contact [founders@zeroentropy.dev](mailto:founders@zeroentropy.dev) or message + us on [Discord](https://go.zeroentropy.dev/discord) or + [Slack](https://go.zeroentropy.dev/slack)! + Args: documents: The list of documents to rerank. Each document is a string. - query: The query to rerank the documents by. Results will be in descending order of - relevance. + model: The model ID to use for reranking. Options are: ["zerank-2", "zerank-1", + "zerank-1-small"] + + query: The query to rerank the documents by. - model: The model ID to use for reranking. Options are: ["zerank-1-large"] + latency: Whether the call will be inferenced "fast" or "slow". RateLimits for slow API + calls are orders of magnitude higher, but you can expect >10 second latency. + Fast inferences are guaranteed subsecond, but rate limits are lower. If not + specified, first a "fast" call will be attempted, but if you have exceeded your + fast rate limit, then a slow call will be executed. If explicitly set to "fast", + then 429 will be returned if it cannot be executed fast. top_n: If provided, then only the top `n` documents will be returned in the results - array. + array. Otherwise, `n` will be the length of the provided documents array. extra_headers: Send extra headers @@ -170,8 +204,9 @@ async def rerank( body=await async_maybe_transform( { "documents": documents, - "query": query, "model": model, + "query": query, + "latency": latency, "top_n": top_n, }, model_rerank_params.ModelRerankParams, diff --git a/src/zeroentropy/resources/queries.py b/src/zeroentropy/resources/queries.py index 635e01a..6e89238 100644 --- a/src/zeroentropy/resources/queries.py +++ b/src/zeroentropy/resources/queries.py @@ -8,7 +8,7 @@ import httpx from ..types import query_top_pages_params, query_top_snippets_params, query_top_documents_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -51,17 +51,17 @@ def top_documents( *, collection_name: str, k: int, - query: str, - filter: Optional[Dict[str, object]] | NotGiven = NOT_GIVEN, - include_metadata: bool | NotGiven = NOT_GIVEN, - latency_mode: Literal["low", "high"] | NotGiven = NOT_GIVEN, - reranker: Optional[str] | NotGiven = NOT_GIVEN, + query: Optional[str], + filter: Optional[Dict[str, object]] | Omit = omit, + include_metadata: bool | Omit = omit, + latency_mode: Literal["low", "high"] | Omit = omit, + reranker: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryTopDocumentsResponse: """ Get the top K documents that match the given query @@ -74,6 +74,9 @@ def top_documents( 2048, inclusive. query: The natural language query to search with. This cannot exceed 4096 UTF-8 bytes. + If `null`, then the sort will be undefined. The purpose of `null` is to do + faster metadata filter searches without care for relevancy. Cost per query is + unchanged. filter: The query filter to apply. Please read [Metadata Filtering](/metadata-filtering) for more information. If not provided, then all documents will be searched. @@ -125,15 +128,15 @@ def top_pages( collection_name: str, k: int, query: str, - filter: Optional[Dict[str, object]] | NotGiven = NOT_GIVEN, - include_content: bool | NotGiven = NOT_GIVEN, - latency_mode: Literal["low", "high"] | NotGiven = NOT_GIVEN, + filter: Optional[Dict[str, object]] | Omit = omit, + include_content: bool | Omit = omit, + latency_mode: Literal["low", "high"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryTopPagesResponse: """ Get the top K pages that match the given query @@ -191,16 +194,16 @@ def top_snippets( collection_name: str, k: int, query: str, - filter: Optional[Dict[str, object]] | NotGiven = NOT_GIVEN, - include_document_metadata: bool | NotGiven = NOT_GIVEN, - precise_responses: bool | NotGiven = NOT_GIVEN, - reranker: Optional[str] | NotGiven = NOT_GIVEN, + filter: Optional[Dict[str, object]] | Omit = omit, + include_document_metadata: bool | Omit = omit, + precise_responses: bool | Omit = omit, + reranker: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryTopSnippetsResponse: """ Get the top K snippets that match the given query. @@ -289,17 +292,17 @@ async def top_documents( *, collection_name: str, k: int, - query: str, - filter: Optional[Dict[str, object]] | NotGiven = NOT_GIVEN, - include_metadata: bool | NotGiven = NOT_GIVEN, - latency_mode: Literal["low", "high"] | NotGiven = NOT_GIVEN, - reranker: Optional[str] | NotGiven = NOT_GIVEN, + query: Optional[str], + filter: Optional[Dict[str, object]] | Omit = omit, + include_metadata: bool | Omit = omit, + latency_mode: Literal["low", "high"] | Omit = omit, + reranker: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryTopDocumentsResponse: """ Get the top K documents that match the given query @@ -312,6 +315,9 @@ async def top_documents( 2048, inclusive. query: The natural language query to search with. This cannot exceed 4096 UTF-8 bytes. + If `null`, then the sort will be undefined. The purpose of `null` is to do + faster metadata filter searches without care for relevancy. Cost per query is + unchanged. filter: The query filter to apply. Please read [Metadata Filtering](/metadata-filtering) for more information. If not provided, then all documents will be searched. @@ -363,15 +369,15 @@ async def top_pages( collection_name: str, k: int, query: str, - filter: Optional[Dict[str, object]] | NotGiven = NOT_GIVEN, - include_content: bool | NotGiven = NOT_GIVEN, - latency_mode: Literal["low", "high"] | NotGiven = NOT_GIVEN, + filter: Optional[Dict[str, object]] | Omit = omit, + include_content: bool | Omit = omit, + latency_mode: Literal["low", "high"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryTopPagesResponse: """ Get the top K pages that match the given query @@ -429,16 +435,16 @@ async def top_snippets( collection_name: str, k: int, query: str, - filter: Optional[Dict[str, object]] | NotGiven = NOT_GIVEN, - include_document_metadata: bool | NotGiven = NOT_GIVEN, - precise_responses: bool | NotGiven = NOT_GIVEN, - reranker: Optional[str] | NotGiven = NOT_GIVEN, + filter: Optional[Dict[str, object]] | Omit = omit, + include_document_metadata: bool | Omit = omit, + precise_responses: bool | Omit = omit, + reranker: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryTopSnippetsResponse: """ Get the top K snippets that match the given query. diff --git a/src/zeroentropy/resources/status.py b/src/zeroentropy/resources/status.py index 5cba378..931573d 100644 --- a/src/zeroentropy/resources/status.py +++ b/src/zeroentropy/resources/status.py @@ -7,7 +7,7 @@ import httpx from ..types import status_get_status_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -46,13 +46,13 @@ def with_streaming_response(self) -> StatusResourceWithStreamingResponse: def get_status( self, *, - collection_name: Optional[str] | NotGiven = NOT_GIVEN, + collection_name: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StatusGetStatusResponse: """ Gets the current indexing status across all documents. @@ -109,13 +109,13 @@ def with_streaming_response(self) -> AsyncStatusResourceWithStreamingResponse: async def get_status( self, *, - collection_name: Optional[str] | NotGiven = NOT_GIVEN, + collection_name: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StatusGetStatusResponse: """ Gets the current indexing status across all documents. diff --git a/src/zeroentropy/types/collection_add_params.py b/src/zeroentropy/types/collection_add_params.py index c951146..595e7af 100644 --- a/src/zeroentropy/types/collection_add_params.py +++ b/src/zeroentropy/types/collection_add_params.py @@ -14,3 +14,13 @@ class CollectionAddParams(TypedDict, total=False): The maximum length of this string is 1024 characters. If special characters are used, then the UTF-8 encoded string cannot exceed 1024 bytes. """ + + num_shards: int + """[ADVANCED] The number of shards to use for this collection. + + By using K shards, your documents can index with K times more throughput. + However, queries will be automatically sent to all K shards and then aggregated. + For large collections, this can make queries faster. But for small collections, + this will make queries slower. `num_shards` must be one of [1, 8, 16, 32, 64]. + The default is 1. + """ diff --git a/src/zeroentropy/types/document_add_params.py b/src/zeroentropy/types/document_add_params.py index 32608ea..d42e74c 100644 --- a/src/zeroentropy/types/document_add_params.py +++ b/src/zeroentropy/types/document_add_params.py @@ -75,8 +75,13 @@ class ContentAPITextPagesDocument(TypedDict, total=False): the second string has index 1. """ - type: Required[Literal["text-pages"]] - """This field must be `text-pages`""" + type: Required[Literal["text-pages", "text-pages-unordered"]] + """This field must be `text-pages` or `text-pages-unordered`. + + When `unordered` is provided, it is assumed that consecutive pages aren't meant + to be read one after another. For example, PDFs are ordered, and CSVs are + unordered. + """ class ContentAPIBinaryDocument(TypedDict, total=False): diff --git a/src/zeroentropy/types/model_rerank_params.py b/src/zeroentropy/types/model_rerank_params.py index 1005e7f..03af8c0 100644 --- a/src/zeroentropy/types/model_rerank_params.py +++ b/src/zeroentropy/types/model_rerank_params.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Optional -from typing_extensions import Required, TypedDict +from typing_extensions import Literal, Required, TypedDict from .._types import SequenceNotStr @@ -14,17 +14,28 @@ class ModelRerankParams(TypedDict, total=False): documents: Required[SequenceNotStr[str]] """The list of documents to rerank. Each document is a string.""" - query: Required[str] - """The query to rerank the documents by. + model: Required[str] + """The model ID to use for reranking. - Results will be in descending order of relevance. + Options are: ["zerank-2", "zerank-1", "zerank-1-small"] """ - model: str - """The model ID to use for reranking. Options are: ["zerank-1-large"]""" + query: Required[str] + """The query to rerank the documents by.""" + + latency: Optional[Literal["fast", "slow"]] + """Whether the call will be inferenced "fast" or "slow". + + RateLimits for slow API calls are orders of magnitude higher, but you can + expect >10 second latency. Fast inferences are guaranteed subsecond, but rate + limits are lower. If not specified, first a "fast" call will be attempted, but + if you have exceeded your fast rate limit, then a slow call will be executed. If + explicitly set to "fast", then 429 will be returned if it cannot be executed + fast. + """ top_n: Optional[int] """ If provided, then only the top `n` documents will be returned in the results - array. + array. Otherwise, `n` will be the length of the provided documents array. """ diff --git a/src/zeroentropy/types/model_rerank_response.py b/src/zeroentropy/types/model_rerank_response.py index ff8975a..fed8fe1 100644 --- a/src/zeroentropy/types/model_rerank_response.py +++ b/src/zeroentropy/types/model_rerank_response.py @@ -19,7 +19,8 @@ class Result(BaseModel): This number will range between 0.0 and 1.0. This score is dependent on only the query and the scored document; other documents do not affect this score. This - value is deterministic, but may vary slightly due to floating point error. + value is intended to be deterministic, but it may vary slightly due to floating + point error. """ diff --git a/src/zeroentropy/types/query_top_documents_params.py b/src/zeroentropy/types/query_top_documents_params.py index 0eb1eed..e5f74b5 100644 --- a/src/zeroentropy/types/query_top_documents_params.py +++ b/src/zeroentropy/types/query_top_documents_params.py @@ -19,8 +19,13 @@ class QueryTopDocumentsParams(TypedDict, total=False): returned. This number must be between 1 and 2048, inclusive. """ - query: Required[str] - """The natural language query to search with. This cannot exceed 4096 UTF-8 bytes.""" + query: Required[Optional[str]] + """The natural language query to search with. + + This cannot exceed 4096 UTF-8 bytes. If `null`, then the sort will be undefined. + The purpose of `null` is to do faster metadata filter searches without care for + relevancy. Cost per query is unchanged. + """ filter: Optional[Dict[str, object]] """The query filter to apply. diff --git a/src/zeroentropy/types/status_get_status_response.py b/src/zeroentropy/types/status_get_status_response.py index f9aaf57..7288146 100644 --- a/src/zeroentropy/types/status_get_status_response.py +++ b/src/zeroentropy/types/status_get_status_response.py @@ -20,6 +20,12 @@ class StatusGetStatusResponse(BaseModel): please contact us at `founders@zeroentropy.dev` to assist. """ + num_indexed_bytes: int + """The total number of bytes used by documents that are currently indexed. + + Measured as UTF-8 bytes. For PDF/DOCX/PPT/etc, this is of the OCR'ed text. + """ + num_indexed_documents: int """The number of documents that are currently indexed.""" diff --git a/tests/api_resources/test_collections.py b/tests/api_resources/test_collections.py index 249c483..b0c599d 100644 --- a/tests/api_resources/test_collections.py +++ b/tests/api_resources/test_collections.py @@ -59,6 +59,14 @@ def test_method_add(self, client: ZeroEntropy) -> None: ) assert_matches_type(CollectionAddResponse, collection, path=["response"]) + @parametrize + def test_method_add_with_all_params(self, client: ZeroEntropy) -> None: + collection = client.collections.add( + collection_name="collection_name", + num_shards=0, + ) + assert_matches_type(CollectionAddResponse, collection, path=["response"]) + @parametrize def test_raw_response_add(self, client: ZeroEntropy) -> None: response = client.collections.with_raw_response.add( @@ -152,6 +160,14 @@ async def test_method_add(self, async_client: AsyncZeroEntropy) -> None: ) assert_matches_type(CollectionAddResponse, collection, path=["response"]) + @parametrize + async def test_method_add_with_all_params(self, async_client: AsyncZeroEntropy) -> None: + collection = await async_client.collections.add( + collection_name="collection_name", + num_shards=0, + ) + assert_matches_type(CollectionAddResponse, collection, path=["response"]) + @parametrize async def test_raw_response_add(self, async_client: AsyncZeroEntropy) -> None: response = await async_client.collections.with_raw_response.add( diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index 23ffa4f..0022c11 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -21,6 +21,7 @@ class TestModels: def test_method_rerank(self, client: ZeroEntropy) -> None: model = client.models.rerank( documents=["string"], + model="model", query="query", ) assert_matches_type(ModelRerankResponse, model, path=["response"]) @@ -29,8 +30,9 @@ def test_method_rerank(self, client: ZeroEntropy) -> None: def test_method_rerank_with_all_params(self, client: ZeroEntropy) -> None: model = client.models.rerank( documents=["string"], - query="query", model="model", + query="query", + latency="fast", top_n=0, ) assert_matches_type(ModelRerankResponse, model, path=["response"]) @@ -39,6 +41,7 @@ def test_method_rerank_with_all_params(self, client: ZeroEntropy) -> None: def test_raw_response_rerank(self, client: ZeroEntropy) -> None: response = client.models.with_raw_response.rerank( documents=["string"], + model="model", query="query", ) @@ -51,6 +54,7 @@ def test_raw_response_rerank(self, client: ZeroEntropy) -> None: def test_streaming_response_rerank(self, client: ZeroEntropy) -> None: with client.models.with_streaming_response.rerank( documents=["string"], + model="model", query="query", ) as response: assert not response.is_closed @@ -71,6 +75,7 @@ class TestAsyncModels: async def test_method_rerank(self, async_client: AsyncZeroEntropy) -> None: model = await async_client.models.rerank( documents=["string"], + model="model", query="query", ) assert_matches_type(ModelRerankResponse, model, path=["response"]) @@ -79,8 +84,9 @@ async def test_method_rerank(self, async_client: AsyncZeroEntropy) -> None: async def test_method_rerank_with_all_params(self, async_client: AsyncZeroEntropy) -> None: model = await async_client.models.rerank( documents=["string"], - query="query", model="model", + query="query", + latency="fast", top_n=0, ) assert_matches_type(ModelRerankResponse, model, path=["response"]) @@ -89,6 +95,7 @@ async def test_method_rerank_with_all_params(self, async_client: AsyncZeroEntrop async def test_raw_response_rerank(self, async_client: AsyncZeroEntropy) -> None: response = await async_client.models.with_raw_response.rerank( documents=["string"], + model="model", query="query", ) @@ -101,6 +108,7 @@ async def test_raw_response_rerank(self, async_client: AsyncZeroEntropy) -> None async def test_streaming_response_rerank(self, async_client: AsyncZeroEntropy) -> None: async with async_client.models.with_streaming_response.rerank( documents=["string"], + model="model", query="query", ) as response: assert not response.is_closed diff --git a/tests/test_client.py b/tests/test_client.py index a97e096..38f4ee5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -59,51 +59,49 @@ def _get_open_connections(client: ZeroEntropy | AsyncZeroEntropy) -> int: class TestZeroEntropy: - client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - def test_raw_response(self, respx_mock: MockRouter) -> None: + def test_raw_response(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, client: ZeroEntropy) -> None: + copied = client.copy() + assert id(copied) != id(client) - copied = self.client.copy(api_key="another My API Key") + copied = client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, client: ZeroEntropy) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: client = ZeroEntropy( @@ -138,6 +136,7 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() def test_copy_default_query(self) -> None: client = ZeroEntropy( @@ -175,13 +174,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + client.close() + + def test_copy_signature(self, client: ZeroEntropy) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -192,12 +193,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, client: ZeroEntropy) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -254,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + def test_request_timeout(self, client: ZeroEntropy) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(100.0) @@ -274,6 +273,8 @@ def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + client.close() + def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: @@ -285,6 +286,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + client.close() + # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: client = ZeroEntropy( @@ -295,6 +298,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + client.close() + # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = ZeroEntropy( @@ -305,6 +310,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + client.close() + async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: @@ -316,14 +323,14 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = ZeroEntropy( + test_client = ZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = ZeroEntropy( + test_client2 = ZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -332,10 +339,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + test_client.close() + test_client2.close() + def test_validate_headers(self) -> None: client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -364,8 +374,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + client.close() + + def test_request_extra_json(self, client: ZeroEntropy) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -376,7 +388,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -387,7 +399,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -398,8 +410,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: ZeroEntropy) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -409,7 +421,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -420,8 +432,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: ZeroEntropy) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -434,7 +446,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -448,7 +460,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -491,7 +503,7 @@ def test_multipart_repeating_array(self, client: ZeroEntropy) -> None: ] @pytest.mark.respx(base_url=base_url) - def test_basic_union_response(self, respx_mock: MockRouter) -> None: + def test_basic_union_response(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: class Model1(BaseModel): name: str @@ -500,12 +512,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + def test_union_response_different_types(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -516,18 +528,18 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -543,7 +555,7 @@ class Model(BaseModel): ) ) - response = self.client.get("/foo", cast_to=Model) + response = client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 @@ -557,6 +569,8 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" + client.close() + def test_base_url_env(self) -> None: with update_env(ZEROENTROPY_BASE_URL="http://localhost:5000/from/env"): client = ZeroEntropy(api_key=api_key, _strict_response_validation=True) @@ -586,6 +600,7 @@ def test_base_url_trailing_slash(self, client: ZeroEntropy) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -611,6 +626,7 @@ def test_base_url_no_trailing_slash(self, client: ZeroEntropy) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -636,35 +652,36 @@ def test_absolute_request_url(self, client: ZeroEntropy) -> None: ), ) assert request.url == "https://myapi.com/foo" + client.close() def test_copied_client_does_not_close_http(self) -> None: - client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied - assert not client.is_closed() + assert not test_client.is_closed() def test_client_context_manager(self) -> None: - client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - with client as c2: - assert c2 is client + test_client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + def test_client_response_validation_error(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - self.client.get("/foo", cast_to=Model) + client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -686,11 +703,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = client.get("/foo", cast_to=Model) + response = non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + strict_client.close() + non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -713,9 +733,9 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = ZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: ZeroEntropy + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) @@ -729,7 +749,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien with pytest.raises(APITimeoutError): client.status.with_streaming_response.get_status().__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -738,7 +758,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client with pytest.raises(APIStatusError): client.status.with_streaming_response.get_status().__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -840,83 +860,77 @@ def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - def test_follow_redirects(self, respx_mock: MockRouter) -> None: + def test_follow_redirects(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: ZeroEntropy) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - self.client.post( - "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response - ) + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" class TestAsyncZeroEntropy: - client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response(self, respx_mock: MockRouter) -> None: + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncZeroEntropy) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncZeroEntropy) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, async_client: AsyncZeroEntropy) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) - copied = self.client.copy(api_key="another My API Key") + copied = async_client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert async_client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, async_client: AsyncZeroEntropy) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = async_client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert async_client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(async_client.timeout, httpx.Timeout) - def test_copy_default_headers(self) -> None: + async def test_copy_default_headers(self) -> None: client = AsyncZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) @@ -949,8 +963,9 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() - def test_copy_default_query(self) -> None: + async def test_copy_default_query(self) -> None: client = AsyncZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} ) @@ -986,13 +1001,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + await client.close() + + def test_copy_signature(self, async_client: AsyncZeroEntropy) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + async_client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(async_client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -1003,12 +1020,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, async_client: AsyncZeroEntropy) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = async_client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -1065,12 +1082,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - async def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + async def test_request_timeout(self, async_client: AsyncZeroEntropy) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( + request = async_client._build_request( FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) ) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1085,6 +1102,8 @@ async def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + await client.close() + async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: @@ -1096,6 +1115,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + await client.close() + # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: client = AsyncZeroEntropy( @@ -1106,6 +1127,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + await client.close() + # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = AsyncZeroEntropy( @@ -1116,6 +1139,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + await client.close() + def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: @@ -1126,15 +1151,15 @@ def test_invalid_http_client(self) -> None: http_client=cast(Any, http_client), ) - def test_default_headers_option(self) -> None: - client = AsyncZeroEntropy( + async def test_default_headers_option(self) -> None: + test_client = AsyncZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncZeroEntropy( + test_client2 = AsyncZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -1143,10 +1168,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + await test_client.close() + await test_client2.close() + def test_validate_headers(self) -> None: client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -1157,7 +1185,7 @@ def test_validate_headers(self) -> None: client2 = AsyncZeroEntropy(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 - def test_default_query_option(self) -> None: + async def test_default_query_option(self) -> None: client = AsyncZeroEntropy( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} ) @@ -1175,8 +1203,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + await client.close() + + def test_request_extra_json(self, client: ZeroEntropy) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1187,7 +1217,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1198,7 +1228,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1209,8 +1239,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: ZeroEntropy) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1220,7 +1250,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1231,8 +1261,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: ZeroEntropy) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1245,7 +1275,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1259,7 +1289,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1302,7 +1332,7 @@ def test_multipart_repeating_array(self, async_client: AsyncZeroEntropy) -> None ] @pytest.mark.respx(base_url=base_url) - async def test_basic_union_response(self, respx_mock: MockRouter) -> None: + async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncZeroEntropy) -> None: class Model1(BaseModel): name: str @@ -1311,12 +1341,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - async def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncZeroEntropy) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -1327,18 +1357,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncZeroEntropy + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1354,11 +1386,11 @@ class Model(BaseModel): ) ) - response = await self.client.get("/foo", cast_to=Model) + response = await async_client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 - def test_base_url_setter(self) -> None: + async def test_base_url_setter(self) -> None: client = AsyncZeroEntropy( base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True ) @@ -1368,7 +1400,9 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" - def test_base_url_env(self) -> None: + await client.close() + + async def test_base_url_env(self) -> None: with update_env(ZEROENTROPY_BASE_URL="http://localhost:5000/from/env"): client = AsyncZeroEntropy(api_key=api_key, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1388,7 +1422,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncZeroEntropy) -> None: + async def test_base_url_trailing_slash(self, client: AsyncZeroEntropy) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1397,6 +1431,7 @@ def test_base_url_trailing_slash(self, client: AsyncZeroEntropy) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1413,7 +1448,7 @@ def test_base_url_trailing_slash(self, client: AsyncZeroEntropy) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncZeroEntropy) -> None: + async def test_base_url_no_trailing_slash(self, client: AsyncZeroEntropy) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1422,6 +1457,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncZeroEntropy) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1438,7 +1474,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncZeroEntropy) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncZeroEntropy) -> None: + async def test_absolute_request_url(self, client: AsyncZeroEntropy) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1447,37 +1483,39 @@ def test_absolute_request_url(self, client: AsyncZeroEntropy) -> None: ), ) assert request.url == "https://myapi.com/foo" + await client.close() async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied await asyncio.sleep(0.2) - assert not client.is_closed() + assert not test_client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - async with client as c2: - assert c2 is client + test_client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + async def test_client_response_validation_error( + self, respx_mock: MockRouter, async_client: AsyncZeroEntropy + ) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - await self.client.get("/foo", cast_to=Model) + await async_client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -1488,7 +1526,6 @@ async def test_client_max_retries_validation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str @@ -1500,11 +1537,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = await client.get("/foo", cast_to=Model) + response = await non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + await strict_client.close() + await non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -1527,13 +1567,12 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - @pytest.mark.asyncio - async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncZeroEntropy(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncZeroEntropy + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -1546,7 +1585,7 @@ async def test_retrying_timeout_errors_doesnt_leak( with pytest.raises(APITimeoutError): await async_client.status.with_streaming_response.get_status().__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -1557,12 +1596,11 @@ async def test_retrying_status_errors_doesnt_leak( with pytest.raises(APIStatusError): await async_client.status.with_streaming_response.get_status().__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, @@ -1594,7 +1632,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_omit_retry_count_header( self, async_client: AsyncZeroEntropy, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1618,7 +1655,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("zeroentropy._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_overwrite_retry_count_header( self, async_client: AsyncZeroEntropy, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1666,26 +1702,26 @@ async def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncZeroEntropy) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncZeroEntropy) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - await self.client.post( + await async_client.post( "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response ) diff --git a/tests/test_models.py b/tests/test_models.py index 3b040dd..ed031c7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from zeroentropy._utils import PropertyInfo from zeroentropy._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from zeroentropy._models import BaseModel, construct_type +from zeroentropy._models import DISCRIMINATOR_CACHE, BaseModel, construct_type class BasicModel(BaseModel): @@ -809,7 +809,7 @@ class B(BaseModel): UnionType = cast(Any, Union[A, B]) - assert not hasattr(UnionType, "__discriminator__") + assert not DISCRIMINATOR_CACHE.get(UnionType) m = construct_type( value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) @@ -818,7 +818,7 @@ class B(BaseModel): assert m.type == "b" assert m.data == "foo" # type: ignore[comparison-overlap] - discriminator = UnionType.__discriminator__ + discriminator = DISCRIMINATOR_CACHE.get(UnionType) assert discriminator is not None m = construct_type( @@ -830,7 +830,7 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache - assert UnionType.__discriminator__ is discriminator + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") diff --git a/tests/test_transform.py b/tests/test_transform.py index f1036bb..b338f37 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -8,7 +8,7 @@ import pytest -from zeroentropy._types import NOT_GIVEN, Base64FileInput +from zeroentropy._types import Base64FileInput, omit, not_given from zeroentropy._utils import ( PropertyInfo, transform as _transform, @@ -450,4 +450,11 @@ async def test_transform_skipping(use_async: bool) -> None: @pytest.mark.asyncio async def test_strips_notgiven(use_async: bool) -> None: assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} - assert await transform({"foo_bar": NOT_GIVEN}, Foo1, use_async) == {} + assert await transform({"foo_bar": not_given}, Foo1, use_async) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_strips_omit(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": omit}, Foo1, use_async) == {} From 997eb22ff575e29e350a4101422116c202e3353c Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:24:45 +0000 Subject: [PATCH 24/25] codegen metadata --- .stats.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.stats.yml b/.stats.yml index 3b44efb..6531ade 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 14 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/zeroentropy%2Fzeroentropy-c95681b13dc56e64126746c6e546b564c7f802ae567fc9ccc1aeb8eddd40bb1e.yml openapi_spec_hash: 2ac723122fe938e384f11b5cf19e85ec -config_hash: e07cdee04c971e1db74e91a5a4cd981c +config_hash: 3be2ee54cbc850c508c90b9ffae2efe5 From 335af274c1d93ae7a1261cc2ed0248208ab201b6 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:29:11 +0000 Subject: [PATCH 25/25] release: 0.1.0-alpha.7 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 38 +++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- src/zeroentropy/_version.py | 2 +- 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 4f9005e..b5db7ce 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.6" + ".": "0.1.0-alpha.7" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index c57f354..a48de33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,43 @@ # Changelog +## 0.1.0-alpha.7 (2025-11-24) + +Full Changelog: [v0.1.0-alpha.6...v0.1.0-alpha.7](https://github.com/zeroentropy-ai/zeroentropy-python/compare/v0.1.0-alpha.6...v0.1.0-alpha.7) + +### Features + +* **api:** manual updates ([7b6f1c5](https://github.com/zeroentropy-ai/zeroentropy-python/commit/7b6f1c52caff5579b0f42534ae393bf2bd634e2b)) +* clean up environment call outs ([de35b9d](https://github.com/zeroentropy-ai/zeroentropy-python/commit/de35b9dc8c136eeaa0909076488f152a5f885398)) +* **client:** support file upload requests ([4e03611](https://github.com/zeroentropy-ai/zeroentropy-python/commit/4e036110dec2d075f4de4cf44691c131dfa545b7)) +* improve future compat with pydantic v3 ([372675f](https://github.com/zeroentropy-ai/zeroentropy-python/commit/372675f01b9fd520b0f49749b85314a3333dfa2a)) +* **types:** replace List[str] with SequenceNotStr in params ([37625fd](https://github.com/zeroentropy-ai/zeroentropy-python/commit/37625fd6bb4efb5a4926e565eaa7b25852d75f51)) + + +### Bug Fixes + +* avoid newer type syntax ([8f77952](https://github.com/zeroentropy-ai/zeroentropy-python/commit/8f7795235684015a6cd419e2bab68b99ecf63b23)) +* **client:** don't send Content-Type header on GET requests ([dfd4e86](https://github.com/zeroentropy-ai/zeroentropy-python/commit/dfd4e866616c9c7e9fa6ffa3b757abf6d8cd9ebc)) +* **parsing:** correctly handle nested discriminated unions ([79004f8](https://github.com/zeroentropy-ai/zeroentropy-python/commit/79004f855735cc872d0050f781bf8b84d04fd592)) +* **parsing:** ignore empty metadata ([a8b35a8](https://github.com/zeroentropy-ai/zeroentropy-python/commit/a8b35a89d5b113eb662bd3113c9433dc5f1382f9)) +* **parsing:** parse extra field types ([5530ece](https://github.com/zeroentropy-ai/zeroentropy-python/commit/5530ece04b51a3300c4202b66d4e50101c374a60)) + + +### Chores + +* **internal:** add Sequence related utils ([856feb3](https://github.com/zeroentropy-ai/zeroentropy-python/commit/856feb326411abd1fc736eee5a64401f526285ca)) +* **internal:** bump pinned h11 dep ([b9e057f](https://github.com/zeroentropy-ai/zeroentropy-python/commit/b9e057f0827b8ed5851ec7e7a8b9ed622019f065)) +* **internal:** change ci workflow machines ([bb0ac37](https://github.com/zeroentropy-ai/zeroentropy-python/commit/bb0ac37c96cf27acdcc4e578e9419cbe4a457e18)) +* **internal:** fix ruff target version ([80e5aae](https://github.com/zeroentropy-ai/zeroentropy-python/commit/80e5aae0af40e80b5962b0dbf2d5351d9df4432e)) +* **internal:** move mypy configurations to `pyproject.toml` file ([97977db](https://github.com/zeroentropy-ai/zeroentropy-python/commit/97977db20f3f27cb1b01f4be6e3538796c0a5414)) +* **internal:** update comment in script ([e5a7e0f](https://github.com/zeroentropy-ai/zeroentropy-python/commit/e5a7e0f1592dcfe868064dcdea1446b358662461)) +* **internal:** update pyright exclude list ([0c89992](https://github.com/zeroentropy-ai/zeroentropy-python/commit/0c89992ab3ef88427c150b72fe510e26a202db43)) +* **package:** mark python 3.13 as supported ([144f8ca](https://github.com/zeroentropy-ai/zeroentropy-python/commit/144f8ca076c92c7e68403d6a0a49ce1caeee69b9)) +* **project:** add settings file for vscode ([c4f8607](https://github.com/zeroentropy-ai/zeroentropy-python/commit/c4f86076c0b09422a445ffa91e6880a085fabacd)) +* **readme:** fix version rendering on pypi ([3ccc314](https://github.com/zeroentropy-ai/zeroentropy-python/commit/3ccc314b021100775fe854d7225405c63b299599)) +* **tests:** simplify `get_platform` test ([ae111a6](https://github.com/zeroentropy-ai/zeroentropy-python/commit/ae111a642a1da930f319ce62f1a28f203881cfe5)) +* update @stainless-api/prism-cli to v5.15.0 ([0deb0ff](https://github.com/zeroentropy-ai/zeroentropy-python/commit/0deb0ff394e034254520520675869023912cecc1)) +* update github action ([7e76e08](https://github.com/zeroentropy-ai/zeroentropy-python/commit/7e76e087f68489a4a24ac990d48dd665d81f7181)) + ## 0.1.0-alpha.6 (2025-07-08) Full Changelog: [v0.1.0-alpha.5...v0.1.0-alpha.6](https://github.com/zeroentropy-ai/zeroentropy-python/compare/v0.1.0-alpha.5...v0.1.0-alpha.6) diff --git a/pyproject.toml b/pyproject.toml index 30d0058..82bffd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "zeroentropy" -version = "0.1.0-alpha.6" +version = "0.1.0-alpha.7" description = "The official Python library for the ZeroEntropy API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/zeroentropy/_version.py b/src/zeroentropy/_version.py index 24ebcef..50f41a0 100644 --- a/src/zeroentropy/_version.py +++ b/src/zeroentropy/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "zeroentropy" -__version__ = "0.1.0-alpha.6" # x-release-please-version +__version__ = "0.1.0-alpha.7" # x-release-please-version