diff --git a/src/xai_sdk/aio/collections.py b/src/xai_sdk/aio/collections.py index 2d25af8..129d908 100644 --- a/src/xai_sdk/aio/collections.py +++ b/src/xai_sdk/aio/collections.py @@ -2,6 +2,8 @@ import datetime from typing import Optional, Sequence, Union +from opentelemetry.trace import SpanKind + from ..collections import ( DEFAULT_INDEXING_POLL_INTERVAL, DEFAULT_INDEXING_TIMEOUT, @@ -23,6 +25,9 @@ from ..files import _async_chunk_file_data from ..poll_timer import PollTimer from ..proto import collections_pb2, documents_pb2, shared_pb2, types_pb2 +from ..telemetry import get_tracer + +tracer = get_tracer(__name__) class Client(BaseClient): @@ -74,15 +79,26 @@ async def create( else: field_definitions_pb.append(field_definition) - return await self._collections_stub.CreateCollection( - collections_pb2.CreateCollectionRequest( - collection_name=name, - index_configuration=types_pb2.IndexConfiguration(model_name=model_name) if model_name else None, - chunk_configuration=chunk_configuration_pb, - metric_space=metric_space_pb, - field_definitions=field_definitions_pb, + with tracer.start_as_current_span( + name="collections.create_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "create_collection", + "provider.name": "xai", + }, + ) as span: + collection = await self._collections_stub.CreateCollection( + collections_pb2.CreateCollectionRequest( + collection_name=name, + index_configuration=types_pb2.IndexConfiguration(model_name=model_name) if model_name else None, + chunk_configuration=chunk_configuration_pb, + metric_space=metric_space_pb, + field_definitions=field_definitions_pb, + ) ) - ) + span.set_attribute("collection.id", collection.collection_id) + span.set_attribute("collection.name", collection.collection_name) + return collection async def list( self, @@ -167,14 +183,24 @@ async def update( chunk_configuration_pb = _chunk_configuration_to_pb(chunk_configuration) else: chunk_configuration_pb = chunk_configuration - - return await self._collections_stub.UpdateCollection( - collections_pb2.UpdateCollectionRequest( - collection_id=collection_id, - collection_name=name, - chunk_configuration=chunk_configuration_pb, + with tracer.start_as_current_span( + name="collections.update_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "update_collection", + "provider.name": "xai", + }, + ) as span: + collection = await self._collections_stub.UpdateCollection( + collections_pb2.UpdateCollectionRequest( + collection_id=collection_id, + collection_name=name, + chunk_configuration=chunk_configuration_pb, + ) ) - ) + span.set_attribute("collection.id", collection.collection_id) + span.set_attribute("collection.name", collection.collection_name) + return collection async def delete(self, collection_id: str) -> None: """Deletes a collection. @@ -182,9 +208,17 @@ async def delete(self, collection_id: str) -> None: Args: collection_id: The ID of the collection to delete. """ - return await self._collections_stub.DeleteCollection( - collections_pb2.DeleteCollectionRequest(collection_id=collection_id) - ) + with tracer.start_as_current_span( + name="collections.delete_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "delete_collection", + "provider.name": "xai", + }, + ) as _span: + return await self._collections_stub.DeleteCollection( + collections_pb2.DeleteCollectionRequest(collection_id=collection_id) + ) async def search( self, @@ -286,8 +320,17 @@ async def upload_document( """ # Upload the raw bytes via the streaming Files API, then attach to the collection. upload_chunks = _async_chunk_file_data(filename=name, data=data) - - uploaded_file = await self._files_stub.UploadFile(upload_chunks) + with tracer.start_as_current_span( + name="collections.upload_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "upload_document", + "provider.name": "xai", + }, + ) as span: + uploaded_file = await self._files_stub.UploadFile(upload_chunks) + span.set_attribute("file.id", uploaded_file.id) + span.set_attribute("file.name", uploaded_file.filename) # Attach the uploaded file to the target collection as a document. await self._collections_stub.AddDocumentToCollection( @@ -364,13 +407,21 @@ async def add_existing_document( file_id: The ID of the file (document) to add. fields: Additional metadata fields to store with the document in this collection. """ - return await self._collections_stub.AddDocumentToCollection( - collections_pb2.AddDocumentToCollectionRequest( - collection_id=collection_id, - file_id=file_id, - fields=fields, + with tracer.start_as_current_span( + name="collections.add_existing_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "add_existing_document", + "provider.name": "xai", + }, + ) as _span: + await self._collections_stub.AddDocumentToCollection( + collections_pb2.AddDocumentToCollectionRequest( + collection_id=collection_id, + file_id=file_id, + fields=fields, + ) ) - ) async def list_documents( self, @@ -456,9 +507,17 @@ async def remove_document(self, collection_id: str, file_id: str) -> None: collection_id: The ID of the collection to remove the document from. file_id: The ID of the file (document) to remove. """ - return await self._collections_stub.RemoveDocumentFromCollection( - collections_pb2.RemoveDocumentFromCollectionRequest(collection_id=collection_id, file_id=file_id) - ) + with tracer.start_as_current_span( + name="collections.remove_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "remove_document", + "provider.name": "xai", + }, + ) as _span: + return await self._collections_stub.RemoveDocumentFromCollection( + collections_pb2.RemoveDocumentFromCollectionRequest(collection_id=collection_id, file_id=file_id) + ) async def update_document( self, @@ -482,16 +541,27 @@ async def update_document( Returns: The updated metadata for the document. """ - return await self._collections_stub.UpdateDocument( - collections_pb2.UpdateDocumentRequest( - collection_id=collection_id, - file_id=file_id, - name=name, - data=data, - content_type=content_type, - fields=fields, + with tracer.start_as_current_span( + name="collections.update_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "update_document", + "provider.name": "xai", + }, + ) as span: + document = await self._collections_stub.UpdateDocument( + collections_pb2.UpdateDocumentRequest( + collection_id=collection_id, + file_id=file_id, + name=name, + data=data, + content_type=content_type, + fields=fields, + ) ) - ) + span.set_attribute("document.id", document.file_metadata.file_id) + span.set_attribute("document.name", document.file_metadata.name) + return document async def reindex_document(self, collection_id: str, file_id: str) -> None: """Regenerates indices for a document. diff --git a/src/xai_sdk/sync/collections.py b/src/xai_sdk/sync/collections.py index ddf4a22..5bcdd82 100644 --- a/src/xai_sdk/sync/collections.py +++ b/src/xai_sdk/sync/collections.py @@ -2,6 +2,8 @@ import time from typing import Optional, Sequence, Union +from opentelemetry.trace import SpanKind + from ..collections import ( DEFAULT_INDEXING_POLL_INTERVAL, DEFAULT_INDEXING_TIMEOUT, @@ -23,6 +25,9 @@ from ..files import _chunk_file_data from ..poll_timer import PollTimer from ..proto import collections_pb2, documents_pb2, shared_pb2, types_pb2 +from ..telemetry import get_tracer + +tracer = get_tracer(__name__) class Client(BaseClient): @@ -77,15 +82,26 @@ def create( else: field_definitions_pb.append(field_definition) - return self._collections_stub.CreateCollection( - collections_pb2.CreateCollectionRequest( - collection_name=name, - index_configuration=types_pb2.IndexConfiguration(model_name=model_name) if model_name else None, - chunk_configuration=chunk_configuration_pb, - metric_space=metric_space_pb, - field_definitions=field_definitions_pb, + with tracer.start_as_current_span( + name="collections.create_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "create_collection", + "provider.name": "xai", + }, + ) as span: + collection = self._collections_stub.CreateCollection( + collections_pb2.CreateCollectionRequest( + collection_name=name, + index_configuration=types_pb2.IndexConfiguration(model_name=model_name) if model_name else None, + chunk_configuration=chunk_configuration_pb, + metric_space=metric_space_pb, + field_definitions=field_definitions_pb, + ) ) - ) + span.set_attribute("collection.id", collection.collection_id) + span.set_attribute("collection.name", collection.collection_name) + return collection def list( self, @@ -170,14 +186,24 @@ def update( chunk_configuration_pb = _chunk_configuration_to_pb(chunk_configuration) else: chunk_configuration_pb = chunk_configuration - - return self._collections_stub.UpdateCollection( - collections_pb2.UpdateCollectionRequest( - collection_id=collection_id, - collection_name=name, - chunk_configuration=chunk_configuration_pb, + with tracer.start_as_current_span( + name="collections.update_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "update_collection", + "provider.name": "xai", + }, + ) as span: + collection = self._collections_stub.UpdateCollection( + collections_pb2.UpdateCollectionRequest( + collection_id=collection_id, + collection_name=name, + chunk_configuration=chunk_configuration_pb, + ) ) - ) + span.set_attribute("collection.id", collection.collection_id) + span.set_attribute("collection.name", collection.collection_name) + return collection def delete(self, collection_id: str) -> None: """Deletes a collection. @@ -185,9 +211,17 @@ def delete(self, collection_id: str) -> None: Args: collection_id: The ID of the collection to delete. """ - return self._collections_stub.DeleteCollection( - collections_pb2.DeleteCollectionRequest(collection_id=collection_id) - ) + with tracer.start_as_current_span( + name="collections.delete_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "delete_collection", + "provider.name": "xai", + }, + ) as _span: + return self._collections_stub.DeleteCollection( + collections_pb2.DeleteCollectionRequest(collection_id=collection_id) + ) def search( self, @@ -289,8 +323,17 @@ def upload_document( """ # Upload the raw bytes via the streaming Files API, then attach to the collection. upload_chunks = _chunk_file_data(filename=name, data=data) - - uploaded_file = self._files_stub.UploadFile(upload_chunks) + with tracer.start_as_current_span( + name="collections.upload_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "upload_document", + "provider.name": "xai", + }, + ) as span: + uploaded_file = self._files_stub.UploadFile(upload_chunks) + span.set_attribute("file.id", uploaded_file.id) + span.set_attribute("file.name", uploaded_file.filename) # Attach the uploaded file to the target collection as a document. self._collections_stub.AddDocumentToCollection( @@ -367,13 +410,21 @@ def add_existing_document( file_id: The ID of the file (document) to add. fields: Additional metadata fields to store with the document in this collection. """ - return self._collections_stub.AddDocumentToCollection( - collections_pb2.AddDocumentToCollectionRequest( - collection_id=collection_id, - file_id=file_id, - fields=fields, + with tracer.start_as_current_span( + name="collections.add_existing_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "add_existing_document", + "provider.name": "xai", + }, + ) as _span: + return self._collections_stub.AddDocumentToCollection( + collections_pb2.AddDocumentToCollectionRequest( + collection_id=collection_id, + file_id=file_id, + fields=fields, + ) ) - ) def list_documents( self, @@ -459,9 +510,17 @@ def remove_document(self, collection_id: str, file_id: str) -> None: collection_id: The ID of the collection to remove the document from. file_id: The ID of the file (document) to remove. """ - return self._collections_stub.RemoveDocumentFromCollection( - collections_pb2.RemoveDocumentFromCollectionRequest(collection_id=collection_id, file_id=file_id) - ) + with tracer.start_as_current_span( + name="collections.remove_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "remove_document", + "provider.name": "xai", + }, + ) as _span: + return self._collections_stub.RemoveDocumentFromCollection( + collections_pb2.RemoveDocumentFromCollectionRequest(collection_id=collection_id, file_id=file_id) + ) def update_document( self, @@ -485,16 +544,27 @@ def update_document( Returns: The updated metadata for the document. """ - return self._collections_stub.UpdateDocument( - collections_pb2.UpdateDocumentRequest( - collection_id=collection_id, - file_id=file_id, - name=name, - data=data, - content_type=content_type, - fields=fields, + with tracer.start_as_current_span( + name="collections.update_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "update_document", + "provider.name": "xai", + }, + ) as span: + document = self._collections_stub.UpdateDocument( + collections_pb2.UpdateDocumentRequest( + collection_id=collection_id, + file_id=file_id, + name=name, + data=data, + content_type=content_type, + fields=fields, + ) ) - ) + span.set_attribute("document.id", document.file_metadata.file_id) + span.set_attribute("document.name", document.file_metadata.name) + return document def reindex_document(self, collection_id: str, file_id: str) -> None: """Regenerates indices for a document. diff --git a/tests/aio/collections_test.py b/tests/aio/collections_test.py index 48c80b5..aa3b61d 100644 --- a/tests/aio/collections_test.py +++ b/tests/aio/collections_test.py @@ -3,10 +3,12 @@ import uuid from typing import Union +from unittest import mock import grpc import pytest import pytest_asyncio +from opentelemetry.trace import SpanKind from pydantic import ValidationError from xai_sdk import AsyncClient @@ -797,3 +799,89 @@ async def test_update_document(client: AsyncClient): assert response.file_metadata.size_bytes == len(new_data) assert response.file_metadata.content_type == new_content_type assert response.fields == new_fields + + +@mock.patch("xai_sdk.aio.collections.tracer") +@pytest.mark.asyncio(loop_scope="session") +async def test_upload_document_creates_span_with_attributes(mock_tracer: mock.MagicMock, client: AsyncClient): + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + collection_metadata = await client.collections.create(f"test-collection-{uuid.uuid4()}") + assert collection_metadata.collection_id is not None + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.create_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "create_collection", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("collection.id", collection_metadata.collection_id) + mock_span.set_attribute.assert_any_call("collection.name", collection_metadata.collection_name) + + name = "trace-document.txt" + data = b"Tracing test" + fields = {"key": "value"} + + await client.collections.upload_document(collection_metadata.collection_id, name, data, fields) + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.upload_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "upload_document", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("file.id", mock.ANY) + mock_span.set_attribute.assert_any_call("file.name", name) + + +@mock.patch("xai_sdk.aio.collections.tracer") +@pytest.mark.asyncio(loop_scope="session") +async def test_update_document_creates_span_with_attributes(mock_tracer: mock.MagicMock, client: AsyncClient): + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + collection_metadata = await client.collections.create(f"test-collection-{uuid.uuid4()}") + assert collection_metadata.collection_id is not None + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.create_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "create_collection", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("collection.id", collection_metadata.collection_id) + mock_span.set_attribute.assert_any_call("collection.name", collection_metadata.collection_name) + + document_metadata = await client.collections.upload_document( + collection_metadata.collection_id, + "test-document.txt", + b"Hello, world!", + {"key": "value"}, + ) + assert document_metadata.file_metadata.file_id is not None + + new_name = "test-document-2.txt" + + await client.collections.update_document( + collection_metadata.collection_id, + document_metadata.file_metadata.file_id, + name=new_name, + ) + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.update_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "update_document", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("document.id", document_metadata.file_metadata.file_id) + mock_span.set_attribute.assert_any_call("document.name", new_name) diff --git a/tests/sync/collections_test.py b/tests/sync/collections_test.py index 755193a..2ce2271 100644 --- a/tests/sync/collections_test.py +++ b/tests/sync/collections_test.py @@ -1,9 +1,11 @@ import datetime import uuid from typing import Union +from unittest import mock import grpc import pytest +from opentelemetry.trace import SpanKind from pydantic import ValidationError from xai_sdk import Client @@ -888,3 +890,87 @@ def test_update_document(client: Client): assert response.file_metadata.size_bytes == len(new_data) assert response.file_metadata.content_type == new_content_type assert response.fields == new_fields + + +@mock.patch("xai_sdk.sync.collections.tracer") +def test_upload_document_creates_span_with_attributes(mock_tracer: mock.MagicMock, client: Client): + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + collection_metadata = client.collections.create(f"test-collection-{uuid.uuid4()}") + assert collection_metadata.collection_id is not None + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.create_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "create_collection", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("collection.id", collection_metadata.collection_id) + mock_span.set_attribute.assert_any_call("collection.name", collection_metadata.collection_name) + + name = "trace-document.txt" + data = b"Tracing test" + fields = {"key": "value"} + + client.collections.upload_document(collection_metadata.collection_id, name, data, fields) + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.upload_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "upload_document", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("file.id", mock.ANY) + mock_span.set_attribute.assert_any_call("file.name", name) + + +@mock.patch("xai_sdk.sync.collections.tracer") +def test_update_document_creates_span_with_attributes(mock_tracer: mock.MagicMock, client: Client): + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + collection_metadata = client.collections.create(f"test-collection-{uuid.uuid4()}") + assert collection_metadata.collection_id is not None + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.create_collection", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "create_collection", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("collection.id", collection_metadata.collection_id) + mock_span.set_attribute.assert_any_call("collection.name", collection_metadata.collection_name) + + document_metadata = client.collections.upload_document( + collection_metadata.collection_id, + "test-document.txt", + b"Hello, world!", + {"key": "value"}, + ) + assert document_metadata.file_metadata.file_id is not None + + new_name = "test-document-2.txt" + + client.collections.update_document( + collection_metadata.collection_id, + document_metadata.file_metadata.file_id, + name=new_name, + ) + + mock_tracer.start_as_current_span.assert_any_call( + name="collections.update_document", + kind=SpanKind.CLIENT, + attributes={ + "operation.name": "update_document", + "provider.name": "xai", + }, + ) + mock_span.set_attribute.assert_any_call("document.id", document_metadata.file_metadata.file_id) + mock_span.set_attribute.assert_any_call("document.name", new_name)