diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 6d78745..091cfb1 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.9.0"
+ ".": "0.10.0"
}
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index ed70296..64eaa82 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
-configured_endpoints: 34
+configured_endpoints: 36
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/contextual-ai%2Fsunrise-c8152db455001be3f09a3bc60d63711699d2c2a4ea5f7bbc1d71726efda0fd9b.yml
openapi_spec_hash: 97719df292ca220de5d35d36f9756b95
-config_hash: ae81af9b7eb88a788a80bcf3480e0b6b
+config_hash: fdaf751580ba8a60e222e560847af1ac
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10ba2b3..b204145 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,26 @@
# Changelog
+## 0.10.0 (2025-11-11)
+
+Full Changelog: [v0.9.0...v0.10.0](https://github.com/ContextualAI/contextual-client-python/compare/v0.9.0...v0.10.0)
+
+### Features
+
+* **api:** update via SDK Studio ([921ea1c](https://github.com/ContextualAI/contextual-client-python/commit/921ea1c3e6e4432638c535c7e413c92d2e1398f5))
+
+
+### Bug Fixes
+
+* **client:** close streams without requiring full consumption ([3f212eb](https://github.com/ContextualAI/contextual-client-python/commit/3f212ebb31085b404c72d827f1d6992dd4bed24c))
+* compat with Python 3.14 ([6f2d195](https://github.com/ContextualAI/contextual-client-python/commit/6f2d1958bb397cedb94f970c361e617e01c3fdf6))
+
+
+### Chores
+
+* **internal/tests:** avoid race condition with implicit client cleanup ([f7f3568](https://github.com/ContextualAI/contextual-client-python/commit/f7f35681c6ac40661872fbdc3159e79ff764d135))
+* **internal:** grammar fix (it's -> its) ([12b822d](https://github.com/ContextualAI/contextual-client-python/commit/12b822dcede4ba84a7889775254f8b02b311ae5f))
+* **package:** drop Python 3.8 support ([c2ddf6a](https://github.com/ContextualAI/contextual-client-python/commit/c2ddf6a2d51ff845cb2dcd872dc37b934ef97199))
+
## 0.9.0 (2025-10-28)
Full Changelog: [v0.8.0...v0.9.0](https://github.com/ContextualAI/contextual-client-python/compare/v0.8.0...v0.9.0)
diff --git a/README.md b/README.md
index c05a737..ccdb4cc 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[)](https://pypi.org/project/contextual-client/)
-The Contextual AI Python library provides convenient access to the Contextual AI REST API from any Python 3.8+
+The Contextual AI Python library provides convenient access to the Contextual AI REST API from any Python 3.9+
application. The library includes type definitions for all request params and response fields,
and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx).
@@ -467,7 +467,7 @@ print(contextual.__version__)
## Requirements
-Python 3.8 or higher.
+Python 3.9 or higher.
## Contributing
diff --git a/api.md b/api.md
index 225b992..bc19f43 100644
--- a/api.md
+++ b/api.md
@@ -45,6 +45,19 @@ Methods:
- client.datastores.documents.metadata(document_id, \*, datastore_id) -> DocumentMetadata
- client.datastores.documents.set_metadata(document_id, \*, datastore_id, \*\*params) -> DocumentMetadata
+## Contents
+
+Types:
+
+```python
+from contextual.types.datastores import ContentListResponse, ContentMetadataResponse
+```
+
+Methods:
+
+- client.datastores.contents.list(datastore_id, \*\*params) -> SyncContentsPage[ContentListResponse]
+- client.datastores.contents.metadata(content_id, \*, datastore_id, \*\*params) -> ContentMetadataResponse
+
# Agents
Types:
diff --git a/pyproject.toml b/pyproject.toml
index 5d15d3e..9aba0a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "contextual-client"
-version = "0.9.0"
+version = "0.10.0"
description = "The official Python library for the Contextual AI API"
dynamic = ["readme"]
license = "Apache-2.0"
@@ -15,11 +15,10 @@ dependencies = [
"distro>=1.7.0, <2",
"sniffio",
]
-requires-python = ">= 3.8"
+requires-python = ">= 3.9"
classifiers = [
"Typing :: Typed",
"Intended Audience :: Developers",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
@@ -141,7 +140,7 @@ filterwarnings = [
# there are a couple of flags that are still disabled by
# default in strict mode as they are experimental and niche.
typeCheckingMode = "strict"
-pythonVersion = "3.8"
+pythonVersion = "3.9"
exclude = [
"_dev",
diff --git a/src/contextual/_models.py b/src/contextual/_models.py
index 6a3cd1d..fcec2cf 100644
--- a/src/contextual/_models.py
+++ b/src/contextual/_models.py
@@ -2,6 +2,7 @@
import os
import inspect
+import weakref
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
@@ -573,6 +574,9 @@ class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
+DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
+
+
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
@@ -615,8 +619,9 @@ def __init__(
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
- if isinstance(union, CachedDiscriminatorType):
- return union.__discriminator__
+ cached = DISCRIMINATOR_CACHE.get(union)
+ if cached is not None:
+ return cached
discriminator_field_name: str | None = None
@@ -669,7 +674,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
- cast(CachedDiscriminatorType, union).__discriminator__ = details
+ DISCRIMINATOR_CACHE.setdefault(union, details)
return details
diff --git a/src/contextual/_streaming.py b/src/contextual/_streaming.py
index 5f4b671..deb13c2 100644
--- a/src/contextual/_streaming.py
+++ b/src/contextual/_streaming.py
@@ -57,9 +57,8 @@ def __stream__(self) -> Iterator[_T]:
for sse in iterator:
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
- # Ensure the entire stream is consumed
- for _sse in iterator:
- ...
+ # As we might not fully consume the response stream, we need to close it explicitly
+ response.close()
def __enter__(self) -> Self:
return self
@@ -121,9 +120,8 @@ async def __stream__(self) -> AsyncIterator[_T]:
async for sse in iterator:
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
- # Ensure the entire stream is consumed
- async for _sse in iterator:
- ...
+ # As we might not fully consume the response stream, we need to close it explicitly
+ await response.aclose()
async def __aenter__(self) -> Self:
return self
diff --git a/src/contextual/_utils/_sync.py b/src/contextual/_utils/_sync.py
index ad7ec71..f6027c1 100644
--- a/src/contextual/_utils/_sync.py
+++ b/src/contextual/_utils/_sync.py
@@ -1,10 +1,8 @@
from __future__ import annotations
-import sys
import asyncio
import functools
-import contextvars
-from typing import Any, TypeVar, Callable, Awaitable
+from typing import TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec
import anyio
@@ -15,34 +13,11 @@
T_ParamSpec = ParamSpec("T_ParamSpec")
-if sys.version_info >= (3, 9):
- _asyncio_to_thread = asyncio.to_thread
-else:
- # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread
- # for Python 3.8 support
- async def _asyncio_to_thread(
- func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
- ) -> Any:
- """Asynchronously run function *func* in a separate thread.
-
- Any *args and **kwargs supplied for this function are directly passed
- to *func*. Also, the current :class:`contextvars.Context` is propagated,
- allowing context variables from the main thread to be accessed in the
- separate thread.
-
- Returns a coroutine that can be awaited to get the eventual result of *func*.
- """
- loop = asyncio.events.get_running_loop()
- ctx = contextvars.copy_context()
- func_call = functools.partial(ctx.run, func, *args, **kwargs)
- return await loop.run_in_executor(None, func_call)
-
-
async def to_thread(
func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
) -> T_Retval:
if sniffio.current_async_library() == "asyncio":
- return await _asyncio_to_thread(func, *args, **kwargs)
+ return await asyncio.to_thread(func, *args, **kwargs)
return await anyio.to_thread.run_sync(
functools.partial(func, *args, **kwargs),
@@ -53,10 +28,7 @@ async def to_thread(
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
"""
Take a blocking function and create an async one that receives the same
- positional and keyword arguments. For python version 3.9 and above, it uses
- asyncio.to_thread to run the function in a separate thread. For python version
- 3.8, it uses locally defined copy of the asyncio.to_thread function which was
- introduced in python 3.9.
+ positional and keyword arguments.
Usage:
diff --git a/src/contextual/_utils/_utils.py b/src/contextual/_utils/_utils.py
index 50d5926..eec7f4a 100644
--- a/src/contextual/_utils/_utils.py
+++ b/src/contextual/_utils/_utils.py
@@ -133,7 +133,7 @@ def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
-# care about the contained types we can safely use `object` in it's place.
+# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
diff --git a/src/contextual/_version.py b/src/contextual/_version.py
index 1cf85a0..5382f19 100644
--- a/src/contextual/_version.py
+++ b/src/contextual/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "contextual"
-__version__ = "0.9.0" # x-release-please-version
+__version__ = "0.10.0" # x-release-please-version
diff --git a/src/contextual/pagination.py b/src/contextual/pagination.py
index 9f78352..b32430e 100644
--- a/src/contextual/pagination.py
+++ b/src/contextual/pagination.py
@@ -14,6 +14,8 @@
"AsyncUsersPage",
"SyncPage",
"AsyncPage",
+ "SyncContentsPage",
+ "AsyncContentsPage",
]
_T = TypeVar("_T")
@@ -177,3 +179,47 @@ def next_page_info(self) -> Optional[PageInfo]:
return None
return PageInfo(params={"cursor": next_cursor})
+
+
+class SyncContentsPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
+ data: List[_T]
+
+ @override
+ def _get_page_items(self) -> List[_T]:
+ data = self.data
+ if not data:
+ return []
+ return data
+
+ @override
+ def next_page_info(self) -> Optional[PageInfo]:
+ offset = self._options.params.get("offset") or 0
+ if not isinstance(offset, int):
+ raise ValueError(f'Expected "offset" param to be an integer but got {offset}')
+
+ length = len(self._get_page_items())
+ current_count = offset + length
+
+ return PageInfo(params={"offset": current_count})
+
+
+class AsyncContentsPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
+ data: List[_T]
+
+ @override
+ def _get_page_items(self) -> List[_T]:
+ data = self.data
+ if not data:
+ return []
+ return data
+
+ @override
+ def next_page_info(self) -> Optional[PageInfo]:
+ offset = self._options.params.get("offset") or 0
+ if not isinstance(offset, int):
+ raise ValueError(f'Expected "offset" param to be an integer but got {offset}')
+
+ length = len(self._get_page_items())
+ current_count = offset + length
+
+ return PageInfo(params={"offset": current_count})
diff --git a/src/contextual/resources/datastores/__init__.py b/src/contextual/resources/datastores/__init__.py
index 39359dd..2b127f3 100644
--- a/src/contextual/resources/datastores/__init__.py
+++ b/src/contextual/resources/datastores/__init__.py
@@ -1,5 +1,13 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+from .contents import (
+ ContentsResource,
+ AsyncContentsResource,
+ ContentsResourceWithRawResponse,
+ AsyncContentsResourceWithRawResponse,
+ ContentsResourceWithStreamingResponse,
+ AsyncContentsResourceWithStreamingResponse,
+)
from .documents import (
DocumentsResource,
AsyncDocumentsResource,
@@ -24,6 +32,12 @@
"AsyncDocumentsResourceWithRawResponse",
"DocumentsResourceWithStreamingResponse",
"AsyncDocumentsResourceWithStreamingResponse",
+ "ContentsResource",
+ "AsyncContentsResource",
+ "ContentsResourceWithRawResponse",
+ "AsyncContentsResourceWithRawResponse",
+ "ContentsResourceWithStreamingResponse",
+ "AsyncContentsResourceWithStreamingResponse",
"DatastoresResource",
"AsyncDatastoresResource",
"DatastoresResourceWithRawResponse",
diff --git a/src/contextual/resources/datastores/contents.py b/src/contextual/resources/datastores/contents.py
new file mode 100644
index 0000000..7b2bc03
--- /dev/null
+++ b/src/contextual/resources/datastores/contents.py
@@ -0,0 +1,329 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Any, cast
+
+import httpx
+
+from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
+from ..._utils import maybe_transform, async_maybe_transform
+from ..._compat import cached_property
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from ..._response import (
+ to_raw_response_wrapper,
+ to_streamed_response_wrapper,
+ async_to_raw_response_wrapper,
+ async_to_streamed_response_wrapper,
+)
+from ...pagination import SyncContentsPage, AsyncContentsPage
+from ..._base_client import AsyncPaginator, make_request_options
+from ...types.datastores import content_list_params, content_metadata_params
+from ...types.datastores.content_list_response import ContentListResponse
+from ...types.datastores.content_metadata_response import ContentMetadataResponse
+
+__all__ = ["ContentsResource", "AsyncContentsResource"]
+
+
+class ContentsResource(SyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> ContentsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/ContextualAI/contextual-client-python#accessing-raw-response-data-eg-headers
+ """
+ return ContentsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> ContentsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/ContextualAI/contextual-client-python#with_streaming_response
+ """
+ return ContentsResourceWithStreamingResponse(self)
+
+ def list(
+ self,
+ datastore_id: str,
+ *,
+ document_id: str | Omit = omit,
+ limit: int | Omit = omit,
+ offset: int | Omit = omit,
+ search: str | Omit = omit,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> SyncContentsPage[ContentListResponse]:
+ """
+ Get Document Contents
+
+ Args:
+ datastore_id: Datastore ID of the datastore from which to retrieve the document
+
+ document_id: Document ID of the document to retrieve details for
+
+ limit: The number of content ids to be returned
+
+ offset: The offset to start retrieving content ids
+
+ search: The query to search keywords for
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not datastore_id:
+ raise ValueError(f"Expected a non-empty value for `datastore_id` but received {datastore_id!r}")
+ return self._get_api_list(
+ f"/datastores/{datastore_id}/contents",
+ page=SyncContentsPage[ContentListResponse],
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "document_id": document_id,
+ "limit": limit,
+ "offset": offset,
+ "search": search,
+ },
+ content_list_params.ContentListParams,
+ ),
+ ),
+ model=cast(Any, ContentListResponse), # Union types cannot be passed in as arguments in the type system
+ )
+
+ def metadata(
+ self,
+ content_id: str,
+ *,
+ datastore_id: str,
+ cursor: str | Omit = omit,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> ContentMetadataResponse:
+ """
+ Get Content Metadata
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not datastore_id:
+ raise ValueError(f"Expected a non-empty value for `datastore_id` but received {datastore_id!r}")
+ if not content_id:
+ raise ValueError(f"Expected a non-empty value for `content_id` but received {content_id!r}")
+ return cast(
+ ContentMetadataResponse,
+ self._get(
+ f"/datastores/{datastore_id}/contents/{content_id}/metadata",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform({"cursor": cursor}, content_metadata_params.ContentMetadataParams),
+ ),
+ cast_to=cast(
+ Any, ContentMetadataResponse
+ ), # Union types cannot be passed in as arguments in the type system
+ ),
+ )
+
+
+class AsyncContentsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncContentsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/ContextualAI/contextual-client-python#accessing-raw-response-data-eg-headers
+ """
+ return AsyncContentsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncContentsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/ContextualAI/contextual-client-python#with_streaming_response
+ """
+ return AsyncContentsResourceWithStreamingResponse(self)
+
+ def list(
+ self,
+ datastore_id: str,
+ *,
+ document_id: str | Omit = omit,
+ limit: int | Omit = omit,
+ offset: int | Omit = omit,
+ search: str | Omit = omit,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> AsyncPaginator[ContentListResponse, AsyncContentsPage[ContentListResponse]]:
+ """
+ Get Document Contents
+
+ Args:
+ datastore_id: Datastore ID of the datastore from which to retrieve the document
+
+ document_id: Document ID of the document to retrieve details for
+
+ limit: The number of content ids to be returned
+
+ offset: The offset to start retrieving content ids
+
+ search: The query to search keywords for
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not datastore_id:
+ raise ValueError(f"Expected a non-empty value for `datastore_id` but received {datastore_id!r}")
+ return self._get_api_list(
+ f"/datastores/{datastore_id}/contents",
+ page=AsyncContentsPage[ContentListResponse],
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "document_id": document_id,
+ "limit": limit,
+ "offset": offset,
+ "search": search,
+ },
+ content_list_params.ContentListParams,
+ ),
+ ),
+ model=cast(Any, ContentListResponse), # Union types cannot be passed in as arguments in the type system
+ )
+
+ async def metadata(
+ self,
+ content_id: str,
+ *,
+ datastore_id: str,
+ cursor: str | Omit = omit,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> ContentMetadataResponse:
+ """
+ Get Content Metadata
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not datastore_id:
+ raise ValueError(f"Expected a non-empty value for `datastore_id` but received {datastore_id!r}")
+ if not content_id:
+ raise ValueError(f"Expected a non-empty value for `content_id` but received {content_id!r}")
+ return cast(
+ ContentMetadataResponse,
+ await self._get(
+ f"/datastores/{datastore_id}/contents/{content_id}/metadata",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=await async_maybe_transform(
+ {"cursor": cursor}, content_metadata_params.ContentMetadataParams
+ ),
+ ),
+ cast_to=cast(
+ Any, ContentMetadataResponse
+ ), # Union types cannot be passed in as arguments in the type system
+ ),
+ )
+
+
+class ContentsResourceWithRawResponse:
+ def __init__(self, contents: ContentsResource) -> None:
+ self._contents = contents
+
+ self.list = to_raw_response_wrapper(
+ contents.list,
+ )
+ self.metadata = to_raw_response_wrapper(
+ contents.metadata,
+ )
+
+
+class AsyncContentsResourceWithRawResponse:
+ def __init__(self, contents: AsyncContentsResource) -> None:
+ self._contents = contents
+
+ self.list = async_to_raw_response_wrapper(
+ contents.list,
+ )
+ self.metadata = async_to_raw_response_wrapper(
+ contents.metadata,
+ )
+
+
+class ContentsResourceWithStreamingResponse:
+ def __init__(self, contents: ContentsResource) -> None:
+ self._contents = contents
+
+ self.list = to_streamed_response_wrapper(
+ contents.list,
+ )
+ self.metadata = to_streamed_response_wrapper(
+ contents.metadata,
+ )
+
+
+class AsyncContentsResourceWithStreamingResponse:
+ def __init__(self, contents: AsyncContentsResource) -> None:
+ self._contents = contents
+
+ self.list = async_to_streamed_response_wrapper(
+ contents.list,
+ )
+ self.metadata = async_to_streamed_response_wrapper(
+ contents.metadata,
+ )
diff --git a/src/contextual/resources/datastores/datastores.py b/src/contextual/resources/datastores/datastores.py
index 3a0b964..5442461 100644
--- a/src/contextual/resources/datastores/datastores.py
+++ b/src/contextual/resources/datastores/datastores.py
@@ -7,6 +7,14 @@
from ...types import datastore_list_params, datastore_create_params, datastore_update_params
from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from ..._utils import maybe_transform, async_maybe_transform
+from .contents import (
+ ContentsResource,
+ AsyncContentsResource,
+ ContentsResourceWithRawResponse,
+ AsyncContentsResourceWithRawResponse,
+ ContentsResourceWithStreamingResponse,
+ AsyncContentsResourceWithStreamingResponse,
+)
from ..._compat import cached_property
from .documents import (
DocumentsResource,
@@ -38,6 +46,10 @@ class DatastoresResource(SyncAPIResource):
def documents(self) -> DocumentsResource:
return DocumentsResource(self._client)
+ @cached_property
+ def contents(self) -> ContentsResource:
+ return ContentsResource(self._client)
+
@cached_property
def with_raw_response(self) -> DatastoresResourceWithRawResponse:
"""
@@ -340,6 +352,10 @@ class AsyncDatastoresResource(AsyncAPIResource):
def documents(self) -> AsyncDocumentsResource:
return AsyncDocumentsResource(self._client)
+ @cached_property
+ def contents(self) -> AsyncContentsResource:
+ return AsyncContentsResource(self._client)
+
@cached_property
def with_raw_response(self) -> AsyncDatastoresResourceWithRawResponse:
"""
@@ -664,6 +680,10 @@ def __init__(self, datastores: DatastoresResource) -> None:
def documents(self) -> DocumentsResourceWithRawResponse:
return DocumentsResourceWithRawResponse(self._datastores.documents)
+ @cached_property
+ def contents(self) -> ContentsResourceWithRawResponse:
+ return ContentsResourceWithRawResponse(self._datastores.contents)
+
class AsyncDatastoresResourceWithRawResponse:
def __init__(self, datastores: AsyncDatastoresResource) -> None:
@@ -692,6 +712,10 @@ def __init__(self, datastores: AsyncDatastoresResource) -> None:
def documents(self) -> AsyncDocumentsResourceWithRawResponse:
return AsyncDocumentsResourceWithRawResponse(self._datastores.documents)
+ @cached_property
+ def contents(self) -> AsyncContentsResourceWithRawResponse:
+ return AsyncContentsResourceWithRawResponse(self._datastores.contents)
+
class DatastoresResourceWithStreamingResponse:
def __init__(self, datastores: DatastoresResource) -> None:
@@ -720,6 +744,10 @@ def __init__(self, datastores: DatastoresResource) -> None:
def documents(self) -> DocumentsResourceWithStreamingResponse:
return DocumentsResourceWithStreamingResponse(self._datastores.documents)
+ @cached_property
+ def contents(self) -> ContentsResourceWithStreamingResponse:
+ return ContentsResourceWithStreamingResponse(self._datastores.contents)
+
class AsyncDatastoresResourceWithStreamingResponse:
def __init__(self, datastores: AsyncDatastoresResource) -> None:
@@ -747,3 +775,7 @@ def __init__(self, datastores: AsyncDatastoresResource) -> None:
@cached_property
def documents(self) -> AsyncDocumentsResourceWithStreamingResponse:
return AsyncDocumentsResourceWithStreamingResponse(self._datastores.documents)
+
+ @cached_property
+ def contents(self) -> AsyncContentsResourceWithStreamingResponse:
+ return AsyncContentsResourceWithStreamingResponse(self._datastores.contents)
diff --git a/src/contextual/types/datastores/__init__.py b/src/contextual/types/datastores/__init__.py
index 5a18d8e..764a27d 100644
--- a/src/contextual/types/datastores/__init__.py
+++ b/src/contextual/types/datastores/__init__.py
@@ -4,11 +4,15 @@
from .document_metadata import DocumentMetadata as DocumentMetadata
from .ingestion_response import IngestionResponse as IngestionResponse
+from .content_list_params import ContentListParams as ContentListParams
from .base_metadata_filter import BaseMetadataFilter as BaseMetadataFilter
from .document_list_params import DocumentListParams as DocumentListParams
+from .content_list_response import ContentListResponse as ContentListResponse
from .document_ingest_params import DocumentIngestParams as DocumentIngestParams
+from .content_metadata_params import ContentMetadataParams as ContentMetadataParams
from .list_documents_response import ListDocumentsResponse as ListDocumentsResponse
from .composite_metadata_filter import CompositeMetadataFilter as CompositeMetadataFilter
+from .content_metadata_response import ContentMetadataResponse as ContentMetadataResponse
from .base_metadata_filter_param import BaseMetadataFilterParam as BaseMetadataFilterParam
from .document_set_metadata_params import DocumentSetMetadataParams as DocumentSetMetadataParams
from .composite_metadata_filter_param import CompositeMetadataFilterParam as CompositeMetadataFilterParam
diff --git a/src/contextual/types/datastores/content_list_params.py b/src/contextual/types/datastores/content_list_params.py
new file mode 100644
index 0000000..b278fc9
--- /dev/null
+++ b/src/contextual/types/datastores/content_list_params.py
@@ -0,0 +1,21 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import TypedDict
+
+__all__ = ["ContentListParams"]
+
+
+class ContentListParams(TypedDict, total=False):
+ document_id: str
+ """Document ID of the document to retrieve details for"""
+
+ limit: int
+ """The number of content ids to be returned"""
+
+ offset: int
+ """The offset to start retrieving content ids"""
+
+ search: str
+ """The query to search keywords for"""
diff --git a/src/contextual/types/datastores/content_list_response.py b/src/contextual/types/datastores/content_list_response.py
new file mode 100644
index 0000000..36766fd
--- /dev/null
+++ b/src/contextual/types/datastores/content_list_response.py
@@ -0,0 +1,39 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Union, Optional
+from typing_extensions import Literal, Annotated, TypeAlias
+
+from pydantic import Field as FieldInfo
+
+from ..._utils import PropertyInfo
+from ..._models import BaseModel
+
+__all__ = ["ContentListResponse", "DocumentContentEntry", "StructuredContentEntry"]
+
+
+class DocumentContentEntry(BaseModel):
+ content_id: str
+ """ID of the content"""
+
+ page_number: int
+ """Page number of the content"""
+
+ content_type: Optional[Literal["unstructured"]] = None
+
+
+class StructuredContentEntry(BaseModel):
+ content_id: str
+ """ID of the content"""
+
+ table_name: str
+ """Name of the table"""
+
+ content_type: Optional[Literal["structured"]] = None
+
+ schema_: Optional[str] = FieldInfo(alias="schema", default=None)
+ """Name of the schema of the table"""
+
+
+ContentListResponse: TypeAlias = Annotated[
+ Union[DocumentContentEntry, StructuredContentEntry], PropertyInfo(discriminator="content_type")
+]
diff --git a/src/contextual/types/datastores/content_metadata_params.py b/src/contextual/types/datastores/content_metadata_params.py
new file mode 100644
index 0000000..b8e9ab3
--- /dev/null
+++ b/src/contextual/types/datastores/content_metadata_params.py
@@ -0,0 +1,13 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, TypedDict
+
+__all__ = ["ContentMetadataParams"]
+
+
+class ContentMetadataParams(TypedDict, total=False):
+ datastore_id: Required[str]
+
+ cursor: str
diff --git a/src/contextual/types/datastores/content_metadata_response.py b/src/contextual/types/datastores/content_metadata_response.py
new file mode 100644
index 0000000..025edf6
--- /dev/null
+++ b/src/contextual/types/datastores/content_metadata_response.py
@@ -0,0 +1,80 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Union, Optional
+from typing_extensions import Literal, Annotated, TypeAlias
+
+from ..._utils import PropertyInfo
+from ..._models import BaseModel
+
+__all__ = [
+ "ContentMetadataResponse",
+ "UnstructuredContentMetadata",
+ "StructuredContentMetadata",
+ "FileAnalysisContentMetadata",
+]
+
+
+class UnstructuredContentMetadata(BaseModel):
+ content_id: str
+ """Id of the content."""
+
+ content_text: str
+ """Text of the content."""
+
+ document_id: str
+ """Id of the document which the content belongs to."""
+
+ height: float
+ """Height of the image."""
+
+ page: int
+ """Page number of the content."""
+
+ page_img: str
+ """Image of the page on which the content occurs."""
+
+ width: float
+ """Width of the image."""
+
+ x0: float
+ """X coordinate of the top left corner on the bounding box."""
+
+ x1: float
+ """X coordinate of the bottom right corner on the bounding box."""
+
+ y0: float
+ """Y coordinate of the top left corner on the bounding box."""
+
+ y1: float
+ """Y coordinate of the bottom right corner on the bounding box."""
+
+ content_type: Optional[Literal["unstructured"]] = None
+
+
+class StructuredContentMetadata(BaseModel):
+ content_id: str
+ """Id of the content."""
+
+ content_text: object
+ """Text of the content."""
+
+ content_type: Optional[Literal["structured"]] = None
+
+
+class FileAnalysisContentMetadata(BaseModel):
+ content_id: str
+ """Id of the content."""
+
+ file_format: str
+ """Format of the file."""
+
+ gcp_location: str
+ """GCP location of the file."""
+
+ content_type: Optional[Literal["file_analysis"]] = None
+
+
+ContentMetadataResponse: TypeAlias = Annotated[
+ Union[UnstructuredContentMetadata, StructuredContentMetadata, FileAnalysisContentMetadata],
+ PropertyInfo(discriminator="content_type"),
+]
diff --git a/tests/api_resources/datastores/test_contents.py b/tests/api_resources/datastores/test_contents.py
new file mode 100644
index 0000000..95826d6
--- /dev/null
+++ b/tests/api_resources/datastores/test_contents.py
@@ -0,0 +1,240 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, cast
+
+import pytest
+
+from contextual import ContextualAI, AsyncContextualAI
+from tests.utils import assert_matches_type
+from contextual.pagination import SyncContentsPage, AsyncContentsPage
+from contextual.types.datastores import (
+ ContentListResponse,
+ ContentMetadataResponse,
+)
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestContents:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ def test_method_list(self, client: ContextualAI) -> None:
+ content = client.datastores.contents.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+ assert_matches_type(SyncContentsPage[ContentListResponse], content, path=["response"])
+
+ @parametrize
+ def test_method_list_with_all_params(self, client: ContextualAI) -> None:
+ content = client.datastores.contents.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ document_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ limit=0,
+ offset=0,
+ search="search",
+ )
+ assert_matches_type(SyncContentsPage[ContentListResponse], content, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: ContextualAI) -> None:
+ response = client.datastores.contents.with_raw_response.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ content = response.parse()
+ assert_matches_type(SyncContentsPage[ContentListResponse], content, path=["response"])
+
+ @parametrize
+ def test_streaming_response_list(self, client: ContextualAI) -> None:
+ with client.datastores.contents.with_streaming_response.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ content = response.parse()
+ assert_matches_type(SyncContentsPage[ContentListResponse], content, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_path_params_list(self, client: ContextualAI) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `datastore_id` but received ''"):
+ client.datastores.contents.with_raw_response.list(
+ datastore_id="",
+ )
+
+ @parametrize
+ def test_method_metadata(self, client: ContextualAI) -> None:
+ content = client.datastores.contents.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ @parametrize
+ def test_method_metadata_with_all_params(self, client: ContextualAI) -> None:
+ content = client.datastores.contents.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ cursor="cursor",
+ )
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ @parametrize
+ def test_raw_response_metadata(self, client: ContextualAI) -> None:
+ response = client.datastores.contents.with_raw_response.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ content = response.parse()
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ @parametrize
+ def test_streaming_response_metadata(self, client: ContextualAI) -> None:
+ with client.datastores.contents.with_streaming_response.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ content = response.parse()
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_path_params_metadata(self, client: ContextualAI) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `datastore_id` but received ''"):
+ client.datastores.contents.with_raw_response.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `content_id` but received ''"):
+ client.datastores.contents.with_raw_response.metadata(
+ content_id="",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+
+
+class TestAsyncContents:
+ parametrize = pytest.mark.parametrize(
+ "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+ )
+
+ @parametrize
+ async def test_method_list(self, async_client: AsyncContextualAI) -> None:
+ content = await async_client.datastores.contents.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+ assert_matches_type(AsyncContentsPage[ContentListResponse], content, path=["response"])
+
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncContextualAI) -> None:
+ content = await async_client.datastores.contents.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ document_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ limit=0,
+ offset=0,
+ search="search",
+ )
+ assert_matches_type(AsyncContentsPage[ContentListResponse], content, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, async_client: AsyncContextualAI) -> None:
+ response = await async_client.datastores.contents.with_raw_response.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ content = await response.parse()
+ assert_matches_type(AsyncContentsPage[ContentListResponse], content, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_list(self, async_client: AsyncContextualAI) -> None:
+ async with async_client.datastores.contents.with_streaming_response.list(
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ content = await response.parse()
+ assert_matches_type(AsyncContentsPage[ContentListResponse], content, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_path_params_list(self, async_client: AsyncContextualAI) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `datastore_id` but received ''"):
+ await async_client.datastores.contents.with_raw_response.list(
+ datastore_id="",
+ )
+
+ @parametrize
+ async def test_method_metadata(self, async_client: AsyncContextualAI) -> None:
+ content = await async_client.datastores.contents.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ @parametrize
+ async def test_method_metadata_with_all_params(self, async_client: AsyncContextualAI) -> None:
+ content = await async_client.datastores.contents.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ cursor="cursor",
+ )
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ @parametrize
+ async def test_raw_response_metadata(self, async_client: AsyncContextualAI) -> None:
+ response = await async_client.datastores.contents.with_raw_response.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ content = await response.parse()
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_metadata(self, async_client: AsyncContextualAI) -> None:
+ async with async_client.datastores.contents.with_streaming_response.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ content = await response.parse()
+ assert_matches_type(ContentMetadataResponse, content, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_path_params_metadata(self, async_client: AsyncContextualAI) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `datastore_id` but received ''"):
+ await async_client.datastores.contents.with_raw_response.metadata(
+ content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ datastore_id="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `content_id` but received ''"):
+ await async_client.datastores.contents.with_raw_response.metadata(
+ content_id="",
+ datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+ )
diff --git a/tests/test_client.py b/tests/test_client.py
index 60b3a98..af288e8 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -59,51 +59,49 @@ def _get_open_connections(client: ContextualAI | AsyncContextualAI) -> int:
class TestContextualAI:
- client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- def test_raw_response(self, respx_mock: MockRouter) -> None:
+ def test_raw_response(self, respx_mock: MockRouter, client: ContextualAI) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ def test_raw_response_for_binary(self, respx_mock: MockRouter, client: ContextualAI) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, client: ContextualAI) -> None:
+ copied = client.copy()
+ assert id(copied) != id(client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, client: ContextualAI) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(client.timeout, httpx.Timeout)
+ copied = client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(client.timeout, httpx.Timeout)
def test_copy_default_headers(self) -> None:
client = ContextualAI(
@@ -138,6 +136,7 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ client.close()
def test_copy_default_query(self) -> None:
client = ContextualAI(
@@ -175,13 +174,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ client.close()
+
+ def test_copy_signature(self, client: ContextualAI) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -192,12 +193,12 @@ def test_copy_signature(self) -> None:
assert copy_param is not None, f"copy() signature is missing the {name} param"
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
- def test_copy_build_request(self) -> None:
+ def test_copy_build_request(self, client: ContextualAI) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -254,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ def test_request_timeout(self, client: ContextualAI) -> None:
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
- FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
- )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(100.0)
@@ -274,6 +273,8 @@ def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ client.close()
+
def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
with httpx.Client(timeout=None) as http_client:
@@ -285,6 +286,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ client.close()
+
# no timeout given to the httpx client should not use the httpx default
with httpx.Client() as http_client:
client = ContextualAI(
@@ -295,6 +298,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ client.close()
+
# explicitly passing the default timeout currently results in it being ignored
with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = ContextualAI(
@@ -305,6 +310,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ client.close()
+
async def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
async with httpx.AsyncClient() as http_client:
@@ -316,14 +323,14 @@ async def test_invalid_http_client(self) -> None:
)
def test_default_headers_option(self) -> None:
- client = ContextualAI(
+ test_client = ContextualAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = ContextualAI(
+ test_client2 = ContextualAI(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -332,10 +339,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ test_client.close()
+ test_client2.close()
+
def test_validate_headers(self) -> None:
client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -364,8 +374,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ client.close()
+
+ def test_request_extra_json(self, client: ContextualAI) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -376,7 +388,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -387,7 +399,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -398,8 +410,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: ContextualAI) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -409,7 +421,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -420,8 +432,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: ContextualAI) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -434,7 +446,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -448,7 +460,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -491,7 +503,7 @@ def test_multipart_repeating_array(self, client: ContextualAI) -> None:
]
@pytest.mark.respx(base_url=base_url)
- def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ def test_basic_union_response(self, respx_mock: MockRouter, client: ContextualAI) -> None:
class Model1(BaseModel):
name: str
@@ -500,12 +512,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ def test_union_response_different_types(self, respx_mock: MockRouter, client: ContextualAI) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -516,18 +528,20 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ def test_non_application_json_content_type_for_json_data(
+ self, respx_mock: MockRouter, client: ContextualAI
+ ) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -543,7 +557,7 @@ class Model(BaseModel):
)
)
- response = self.client.get("/foo", cast_to=Model)
+ response = client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
@@ -557,6 +571,8 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
+ client.close()
+
def test_base_url_env(self) -> None:
with update_env(CONTEXTUAL_AI_BASE_URL="http://localhost:5000/from/env"):
client = ContextualAI(api_key=api_key, _strict_response_validation=True)
@@ -586,6 +602,7 @@ def test_base_url_trailing_slash(self, client: ContextualAI) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -611,6 +628,7 @@ def test_base_url_no_trailing_slash(self, client: ContextualAI) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -636,35 +654,36 @@ def test_absolute_request_url(self, client: ContextualAI) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ client.close()
def test_copied_client_does_not_close_http(self) -> None:
- client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
- assert not client.is_closed()
+ assert not test_client.is_closed()
def test_client_context_manager(self) -> None:
- client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- with client as c2:
- assert c2 is client
+ test_client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ def test_client_response_validation_error(self, respx_mock: MockRouter, client: ContextualAI) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- self.client.get("/foo", cast_to=Model)
+ client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -686,11 +705,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
strict_client.get("/foo", cast_to=Model)
- client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = client.get("/foo", cast_to=Model)
+ response = non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ strict_client.close()
+ non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -713,9 +735,9 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = ContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, client: ContextualAI
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
@@ -729,7 +751,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien
with pytest.raises(APITimeoutError):
client.agents.with_streaming_response.create(name="xxx").__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@@ -738,7 +760,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client
with pytest.raises(APIStatusError):
client.agents.with_streaming_response.create(name="xxx").__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -840,83 +862,77 @@ def test_default_client_creation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ def test_follow_redirects(self, respx_mock: MockRouter, client: ContextualAI) -> None:
# Test that the default follow_redirects=True allows following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
- response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.respx(base_url=base_url)
- def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: ContextualAI) -> None:
# Test that follow_redirects=False prevents following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
with pytest.raises(APIStatusError) as exc_info:
- self.client.post(
- "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
- )
+ client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
assert exc_info.value.response.status_code == 302
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
class TestAsyncContextualAI:
- client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncContextualAI) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncContextualAI) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, async_client: AsyncContextualAI) -> None:
+ copied = async_client.copy()
+ assert id(copied) != id(async_client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = async_client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert async_client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, async_client: AsyncContextualAI) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = async_client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert async_client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(async_client.timeout, httpx.Timeout)
+ copied = async_client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(async_client.timeout, httpx.Timeout)
- def test_copy_default_headers(self) -> None:
+ async def test_copy_default_headers(self) -> None:
client = AsyncContextualAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
@@ -949,8 +965,9 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ await client.close()
- def test_copy_default_query(self) -> None:
+ async def test_copy_default_query(self) -> None:
client = AsyncContextualAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
)
@@ -986,13 +1003,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ await client.close()
+
+ def test_copy_signature(self, async_client: AsyncContextualAI) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ async_client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(async_client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -1003,12 +1022,12 @@ def test_copy_signature(self) -> None:
assert copy_param is not None, f"copy() signature is missing the {name} param"
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
- def test_copy_build_request(self) -> None:
+ def test_copy_build_request(self, async_client: AsyncContextualAI) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = async_client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -1065,12 +1084,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- async def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ async def test_request_timeout(self, async_client: AsyncContextualAI) -> None:
+ request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
+ request = async_client._build_request(
FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
)
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
@@ -1085,6 +1104,8 @@ async def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ await client.close()
+
async def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
async with httpx.AsyncClient(timeout=None) as http_client:
@@ -1096,6 +1117,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ await client.close()
+
# no timeout given to the httpx client should not use the httpx default
async with httpx.AsyncClient() as http_client:
client = AsyncContextualAI(
@@ -1106,6 +1129,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ await client.close()
+
# explicitly passing the default timeout currently results in it being ignored
async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = AsyncContextualAI(
@@ -1116,6 +1141,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ await client.close()
+
def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
with httpx.Client() as http_client:
@@ -1126,15 +1153,15 @@ def test_invalid_http_client(self) -> None:
http_client=cast(Any, http_client),
)
- def test_default_headers_option(self) -> None:
- client = AsyncContextualAI(
+ async def test_default_headers_option(self) -> None:
+ test_client = AsyncContextualAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = AsyncContextualAI(
+ test_client2 = AsyncContextualAI(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -1143,10 +1170,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ await test_client.close()
+ await test_client2.close()
+
def test_validate_headers(self) -> None:
client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -1157,7 +1187,7 @@ def test_validate_headers(self) -> None:
client2 = AsyncContextualAI(base_url=base_url, api_key=None, _strict_response_validation=True)
_ = client2
- def test_default_query_option(self) -> None:
+ async def test_default_query_option(self) -> None:
client = AsyncContextualAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
)
@@ -1175,8 +1205,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ await client.close()
+
+ def test_request_extra_json(self, client: ContextualAI) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1187,7 +1219,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1198,7 +1230,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1209,8 +1241,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: ContextualAI) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1220,7 +1252,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1231,8 +1263,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: ContextualAI) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1245,7 +1277,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1259,7 +1291,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1302,7 +1334,7 @@ def test_multipart_repeating_array(self, async_client: AsyncContextualAI) -> Non
]
@pytest.mark.respx(base_url=base_url)
- async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncContextualAI) -> None:
class Model1(BaseModel):
name: str
@@ -1311,12 +1343,14 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ async def test_union_response_different_types(
+ self, respx_mock: MockRouter, async_client: AsyncContextualAI
+ ) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -1327,18 +1361,20 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ async def test_non_application_json_content_type_for_json_data(
+ self, respx_mock: MockRouter, async_client: AsyncContextualAI
+ ) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -1354,11 +1390,11 @@ class Model(BaseModel):
)
)
- response = await self.client.get("/foo", cast_to=Model)
+ response = await async_client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
- def test_base_url_setter(self) -> None:
+ async def test_base_url_setter(self) -> None:
client = AsyncContextualAI(
base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
)
@@ -1368,7 +1404,9 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
- def test_base_url_env(self) -> None:
+ await client.close()
+
+ async def test_base_url_env(self) -> None:
with update_env(CONTEXTUAL_AI_BASE_URL="http://localhost:5000/from/env"):
client = AsyncContextualAI(api_key=api_key, _strict_response_validation=True)
assert client.base_url == "http://localhost:5000/from/env/"
@@ -1388,7 +1426,7 @@ def test_base_url_env(self) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_trailing_slash(self, client: AsyncContextualAI) -> None:
+ async def test_base_url_trailing_slash(self, client: AsyncContextualAI) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1397,6 +1435,7 @@ def test_base_url_trailing_slash(self, client: AsyncContextualAI) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1413,7 +1452,7 @@ def test_base_url_trailing_slash(self, client: AsyncContextualAI) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_no_trailing_slash(self, client: AsyncContextualAI) -> None:
+ async def test_base_url_no_trailing_slash(self, client: AsyncContextualAI) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1422,6 +1461,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncContextualAI) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1438,7 +1478,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncContextualAI) -> None:
],
ids=["standard", "custom http client"],
)
- def test_absolute_request_url(self, client: AsyncContextualAI) -> None:
+ async def test_absolute_request_url(self, client: AsyncContextualAI) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1447,37 +1487,39 @@ def test_absolute_request_url(self, client: AsyncContextualAI) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ await client.close()
async def test_copied_client_does_not_close_http(self) -> None:
- client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
await asyncio.sleep(0.2)
- assert not client.is_closed()
+ assert not test_client.is_closed()
async def test_client_context_manager(self) -> None:
- client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- async with client as c2:
- assert c2 is client
+ test_client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ async with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ async def test_client_response_validation_error(
+ self, respx_mock: MockRouter, async_client: AsyncContextualAI
+ ) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- await self.client.get("/foo", cast_to=Model)
+ await async_client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -1488,7 +1530,6 @@ async def test_client_max_retries_validation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
name: str
@@ -1500,11 +1541,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
await strict_client.get("/foo", cast_to=Model)
- client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = await client.get("/foo", cast_to=Model)
+ response = await non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ await strict_client.close()
+ await non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -1527,13 +1571,12 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- @pytest.mark.asyncio
- async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = AsyncContextualAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ async def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncContextualAI
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
- calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
+ calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -1546,7 +1589,7 @@ async def test_retrying_timeout_errors_doesnt_leak(
with pytest.raises(APITimeoutError):
await async_client.agents.with_streaming_response.create(name="xxx").__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@@ -1557,12 +1600,11 @@ async def test_retrying_status_errors_doesnt_leak(
with pytest.raises(APIStatusError):
await async_client.agents.with_streaming_response.create(name="xxx").__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
async def test_retries_taken(
self,
@@ -1594,7 +1636,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_omit_retry_count_header(
self, async_client: AsyncContextualAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1620,7 +1661,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("contextual._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_overwrite_retry_count_header(
self, async_client: AsyncContextualAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1670,26 +1710,26 @@ async def test_default_client_creation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncContextualAI) -> None:
# Test that the default follow_redirects=True allows following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
- response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.respx(base_url=base_url)
- async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncContextualAI) -> None:
# Test that follow_redirects=False prevents following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
with pytest.raises(APIStatusError) as exc_info:
- await self.client.post(
+ await async_client.post(
"/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
)
diff --git a/tests/test_models.py b/tests/test_models.py
index 45f5759..fd2ff9f 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -9,7 +9,7 @@
from contextual._utils import PropertyInfo
from contextual._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
-from contextual._models import BaseModel, construct_type
+from contextual._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
class BasicModel(BaseModel):
@@ -809,7 +809,7 @@ class B(BaseModel):
UnionType = cast(Any, Union[A, B])
- assert not hasattr(UnionType, "__discriminator__")
+ assert not DISCRIMINATOR_CACHE.get(UnionType)
m = construct_type(
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -818,7 +818,7 @@ class B(BaseModel):
assert m.type == "b"
assert m.data == "foo" # type: ignore[comparison-overlap]
- discriminator = UnionType.__discriminator__
+ discriminator = DISCRIMINATOR_CACHE.get(UnionType)
assert discriminator is not None
m = construct_type(
@@ -830,7 +830,7 @@ class B(BaseModel):
# if the discriminator details object stays the same between invocations then
# we hit the cache
- assert UnionType.__discriminator__ is discriminator
+ assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")