diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py index 863f146a51..70dcda0384 100644 --- a/sentry_sdk/integrations/openai.py +++ b/sentry_sdk/integrations/openai.py @@ -50,7 +50,7 @@ from sentry_sdk.tracing import Span from sentry_sdk._types import TextPart - from openai.types.responses import ResponseInputParam + from openai.types.responses import ResponseInputParam, SequenceNotStr from openai import Omit try: @@ -220,20 +220,6 @@ def _calculate_token_usage( ) -def _get_input_messages( - kwargs: "dict[str, Any]", -) -> "Optional[Union[Iterable[Any], list[str]]]": - # Input messages (the prompt or data sent to the model) - messages = kwargs.get("messages") - if messages is None: - messages = kwargs.get("input") - - if isinstance(messages, str): - messages = [messages] - - return messages - - def _commmon_set_input_data( span: "Span", kwargs: "dict[str, Any]", @@ -413,15 +399,47 @@ def _set_embeddings_input_data( kwargs: "dict[str, Any]", integration: "OpenAIIntegration", ) -> None: - messages = _get_input_messages(kwargs) + messages: "Union[str, SequenceNotStr[str], Iterable[int], Iterable[Iterable[int]]]" = kwargs.get( + "input" + ) if ( - messages is not None - and len(messages) > 0 # type: ignore - and should_send_default_pii() - and integration.include_prompts + not should_send_default_pii() + or not integration.include_prompts + or messages is None ): - normalized_messages = normalize_message_roles(messages) # type: ignore + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings") + _commmon_set_input_data(span, kwargs) + + return + + if isinstance(messages, str): + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings") + _commmon_set_input_data(span, kwargs) + + normalized_messages = normalize_message_roles([messages]) # type: ignore + scope = sentry_sdk.get_current_scope() + messages_data = truncate_and_annotate_embedding_inputs( + normalized_messages, span, scope + ) + if messages_data is not None: + set_data_normalized( + span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, messages_data, unpack=False + ) + + return + + # dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197 + if not isinstance(messages, Iterable) or isinstance(messages, dict): + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings") + _commmon_set_input_data(span, kwargs) + return + + messages = list(messages) + kwargs["input"] = messages + + if len(messages) > 0: + normalized_messages = normalize_message_roles(messages) scope = sentry_sdk.get_current_scope() messages_data = truncate_and_annotate_embedding_inputs( normalized_messages, span, scope diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py index 094b659b2c..2e806cc426 100644 --- a/tests/integrations/openai/test_openai.py +++ b/tests/integrations/openai/test_openai.py @@ -930,9 +930,13 @@ async def test_bad_chat_completion_async(sentry_init, capture_events): @pytest.mark.parametrize( "send_default_pii, include_prompts", - [(True, True), (True, False), (False, True), (False, False)], + [ + (True, False), + (False, True), + (False, False), + ], ) -def test_embeddings_create( +def test_embeddings_create_no_pii( sentry_init, capture_events, send_default_pii, include_prompts ): sentry_init( @@ -966,10 +970,110 @@ def test_embeddings_create( assert tx["type"] == "transaction" span = tx["spans"][0] assert span["op"] == "gen_ai.embeddings" - if send_default_pii and include_prompts: - assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT] + + assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] + + assert span["data"]["gen_ai.usage.input_tokens"] == 20 + assert span["data"]["gen_ai.usage.total_tokens"] == 30 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input", + [ + pytest.param( + "hello", + id="string", + ), + pytest.param( + ["First text", "Second text", "Third text"], + id="string_sequence", + ), + pytest.param( + iter(["First text", "Second text", "Third text"]), + id="string_iterable", + ), + pytest.param( + [5, 8, 13, 21, 34], + id="tokens", + ), + pytest.param( + iter( + [5, 8, 13, 21, 34], + ), + id="token_iterable", + ), + pytest.param( + [ + [5, 8, 13, 21, 34], + [8, 13, 21, 34, 55], + ], + id="tokens_sequence", + ), + pytest.param( + iter( + [ + [5, 8, 13, 21, 34], + [8, 13, 21, 34, 55], + ] + ), + id="tokens_sequence_iterable", + ), + ], +) +def test_embeddings_create(sentry_init, capture_events, input, request): + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + client = OpenAI(api_key="z") + + returned_embedding = CreateEmbeddingResponse( + data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])], + model="some-model", + object="list", + usage=EmbeddingTokenUsage( + prompt_tokens=20, + total_tokens=30, + ), + ) + + client.embeddings._post = mock.Mock(return_value=returned_embedding) + with start_transaction(name="openai tx"): + response = client.embeddings.create(input=input, model="text-embedding-3-large") + + assert len(response.data[0].embedding) == 3 + + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "gen_ai.embeddings" + + param_id = request.node.callspec.id + if param_id == "string": + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"] + elif param_id == "string_sequence" or param_id == "string_iterable": + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [ + "First text", + "Second text", + "Third text", + ] + elif param_id == "tokens" or param_id == "token_iterable": + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [ + 5, + 8, + 13, + 21, + 34, + ] else: - assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [ + [5, 8, 13, 21, 34], + [8, 13, 21, 34, 55], + ] assert span["data"]["gen_ai.usage.input_tokens"] == 20 assert span["data"]["gen_ai.usage.total_tokens"] == 30 @@ -978,9 +1082,13 @@ def test_embeddings_create( @pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", - [(True, True), (True, False), (False, True), (False, False)], + [ + (True, False), + (False, True), + (False, False), + ], ) -async def test_embeddings_create_async( +async def test_embeddings_create_async_no_pii( sentry_init, capture_events, send_default_pii, include_prompts ): sentry_init( @@ -1014,10 +1122,112 @@ async def test_embeddings_create_async( assert tx["type"] == "transaction" span = tx["spans"][0] assert span["op"] == "gen_ai.embeddings" - if send_default_pii and include_prompts: - assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT] + + assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] + + assert span["data"]["gen_ai.usage.input_tokens"] == 20 + assert span["data"]["gen_ai.usage.total_tokens"] == 30 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input", + [ + pytest.param( + "hello", + id="string", + ), + pytest.param( + ["First text", "Second text", "Third text"], + id="string_sequence", + ), + pytest.param( + iter(["First text", "Second text", "Third text"]), + id="string_iterable", + ), + pytest.param( + [5, 8, 13, 21, 34], + id="tokens", + ), + pytest.param( + iter( + [5, 8, 13, 21, 34], + ), + id="token_iterable", + ), + pytest.param( + [ + [5, 8, 13, 21, 34], + [8, 13, 21, 34, 55], + ], + id="tokens_sequence", + ), + pytest.param( + iter( + [ + [5, 8, 13, 21, 34], + [8, 13, 21, 34, 55], + ] + ), + id="tokens_sequence_iterable", + ), + ], +) +async def test_embeddings_create_async(sentry_init, capture_events, input, request): + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + + returned_embedding = CreateEmbeddingResponse( + data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])], + model="some-model", + object="list", + usage=EmbeddingTokenUsage( + prompt_tokens=20, + total_tokens=30, + ), + ) + + client.embeddings._post = AsyncMock(return_value=returned_embedding) + with start_transaction(name="openai tx"): + response = await client.embeddings.create( + input=input, model="text-embedding-3-large" + ) + + assert len(response.data[0].embedding) == 3 + + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "gen_ai.embeddings" + + param_id = request.node.callspec.id + if param_id == "string": + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"] + elif param_id == "string_sequence" or param_id == "string_iterable": + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [ + "First text", + "Second text", + "Third text", + ] + elif param_id == "tokens" or param_id == "token_iterable": + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [ + 5, + 8, + 13, + 21, + 34, + ] else: - assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] + assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [ + [5, 8, 13, 21, 34], + [8, 13, 21, 34, 55], + ] assert span["data"]["gen_ai.usage.input_tokens"] == 20 assert span["data"]["gen_ai.usage.total_tokens"] == 30