Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from sentry_sdk.tracing import Span
from sentry_sdk._types import TextPart

from openai.types.responses import ResponseInputParam
from openai.types.responses import ResponseInputParam, SequenceNotStr
from openai import Omit

try:
Expand Down Expand Up @@ -220,20 +220,6 @@ def _calculate_token_usage(
)


def _get_input_messages(
kwargs: "dict[str, Any]",
) -> "Optional[Union[Iterable[Any], list[str]]]":
# Input messages (the prompt or data sent to the model)
messages = kwargs.get("messages")
if messages is None:
messages = kwargs.get("input")

if isinstance(messages, str):
messages = [messages]

return messages


def _commmon_set_input_data(
span: "Span",
kwargs: "dict[str, Any]",
Expand Down Expand Up @@ -413,15 +399,47 @@ def _set_embeddings_input_data(
kwargs: "dict[str, Any]",
integration: "OpenAIIntegration",
) -> None:
messages = _get_input_messages(kwargs)
messages: "Union[str, SequenceNotStr[str], Iterable[int], Iterable[Iterable[int]]]" = kwargs.get(
"input"
)

if (
messages is not None
and len(messages) > 0 # type: ignore
and should_send_default_pii()
and integration.include_prompts
not should_send_default_pii()
or not integration.include_prompts
or messages is None
):
normalized_messages = normalize_message_roles(messages) # type: ignore
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)

return

if isinstance(messages, str):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)

normalized_messages = normalize_message_roles([messages]) # type: ignore
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_embedding_inputs(
normalized_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, messages_data, unpack=False
)

return

# dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
if not isinstance(messages, Iterable) or isinstance(messages, dict):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)
return

messages = list(messages)
kwargs["input"] = messages

if len(messages) > 0:
normalized_messages = normalize_message_roles(messages)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_embedding_inputs(
normalized_messages, span, scope
Expand Down
230 changes: 220 additions & 10 deletions tests/integrations/openai/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,9 +930,13 @@ async def test_bad_chat_completion_async(sentry_init, capture_events):

@pytest.mark.parametrize(
"send_default_pii, include_prompts",
[(True, True), (True, False), (False, True), (False, False)],
[
(True, False),
(False, True),
(False, False),
],
)
def test_embeddings_create(
def test_embeddings_create_no_pii(
sentry_init, capture_events, send_default_pii, include_prompts
):
sentry_init(
Expand Down Expand Up @@ -966,10 +970,110 @@ def test_embeddings_create(
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "gen_ai.embeddings"
if send_default_pii and include_prompts:
assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]

assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]

assert span["data"]["gen_ai.usage.input_tokens"] == 20
assert span["data"]["gen_ai.usage.total_tokens"] == 30


@pytest.mark.parametrize(
    "input",
    [
        pytest.param(
            "hello",
            id="string",
        ),
        pytest.param(
            ["First text", "Second text", "Third text"],
            id="string_sequence",
        ),
        pytest.param(
            iter(["First text", "Second text", "Third text"]),
            id="string_iterable",
        ),
        pytest.param(
            [5, 8, 13, 21, 34],
            id="tokens",
        ),
        pytest.param(
            iter(
                [5, 8, 13, 21, 34],
            ),
            id="token_iterable",
        ),
        pytest.param(
            [
                [5, 8, 13, 21, 34],
                [8, 13, 21, 34, 55],
            ],
            id="tokens_sequence",
        ),
        pytest.param(
            iter(
                [
                    [5, 8, 13, 21, 34],
                    [8, 13, 21, 34, 55],
                ]
            ),
            id="tokens_sequence_iterable",
        ),
    ],
)
def test_embeddings_create(sentry_init, capture_events, input, request):
    """Embedding inputs of every supported shape — a bare string, a sequence
    or iterable of strings, a token list, and nested token lists (plain or as
    a consumable iterator) — are recorded on the embeddings span when both
    ``send_default_pii`` and ``include_prompts`` are enabled.

    Note: this is a synchronous test; the previously attached
    ``@pytest.mark.asyncio`` marker was removed since the function does not
    use async/await.
    """
    sentry_init(
        integrations=[OpenAIIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    client = OpenAI(api_key="z")

    returned_embedding = CreateEmbeddingResponse(
        data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
        model="some-model",
        object="list",
        usage=EmbeddingTokenUsage(
            prompt_tokens=20,
            total_tokens=30,
        ),
    )

    client.embeddings._post = mock.Mock(return_value=returned_embedding)
    with start_transaction(name="openai tx"):
        response = client.embeddings.create(input=input, model="text-embedding-3-large")

    assert len(response.data[0].embedding) == 3

    tx = events[0]
    assert tx["type"] == "transaction"
    span = tx["spans"][0]
    assert span["op"] == "gen_ai.embeddings"

    param_id = request.node.callspec.id
    if param_id == "string":
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
    elif param_id == "string_sequence" or param_id == "string_iterable":
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
            "First text",
            "Second text",
            "Third text",
        ]
    elif param_id == "tokens" or param_id == "token_iterable":
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
            5,
            8,
            13,
            21,
            34,
        ]
    else:
        # tokens_sequence / tokens_sequence_iterable: nested token lists are
        # captured as-is. (A stray "not in span['data']" assertion here used
        # to contradict this equality check, making the test unpassable.)
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
            [5, 8, 13, 21, 34],
            [8, 13, 21, 34, 55],
        ]

    assert span["data"]["gen_ai.usage.input_tokens"] == 20
    assert span["data"]["gen_ai.usage.total_tokens"] == 30
Expand All @@ -978,9 +1082,13 @@ def test_embeddings_create(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"send_default_pii, include_prompts",
[(True, True), (True, False), (False, True), (False, False)],
[
(True, False),
(False, True),
(False, False),
],
)
async def test_embeddings_create_async(
async def test_embeddings_create_async_no_pii(
sentry_init, capture_events, send_default_pii, include_prompts
):
sentry_init(
Expand Down Expand Up @@ -1014,10 +1122,112 @@ async def test_embeddings_create_async(
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "gen_ai.embeddings"
if send_default_pii and include_prompts:
assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]

assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]

assert span["data"]["gen_ai.usage.input_tokens"] == 20
assert span["data"]["gen_ai.usage.total_tokens"] == 30


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input",
    [
        pytest.param(
            "hello",
            id="string",
        ),
        pytest.param(
            ["First text", "Second text", "Third text"],
            id="string_sequence",
        ),
        pytest.param(
            iter(["First text", "Second text", "Third text"]),
            id="string_iterable",
        ),
        pytest.param(
            [5, 8, 13, 21, 34],
            id="tokens",
        ),
        pytest.param(
            iter(
                [5, 8, 13, 21, 34],
            ),
            id="token_iterable",
        ),
        pytest.param(
            [
                [5, 8, 13, 21, 34],
                [8, 13, 21, 34, 55],
            ],
            id="tokens_sequence",
        ),
        pytest.param(
            iter(
                [
                    [5, 8, 13, 21, 34],
                    [8, 13, 21, 34, 55],
                ]
            ),
            id="tokens_sequence_iterable",
        ),
    ],
)
async def test_embeddings_create_async(sentry_init, capture_events, input, request):
    """Async variant: embedding inputs of every supported shape are recorded
    on the embeddings span when ``send_default_pii`` and ``include_prompts``
    are both enabled.
    """
    sentry_init(
        integrations=[OpenAIIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    client = AsyncOpenAI(api_key="z")

    returned_embedding = CreateEmbeddingResponse(
        data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
        model="some-model",
        object="list",
        usage=EmbeddingTokenUsage(
            prompt_tokens=20,
            total_tokens=30,
        ),
    )

    client.embeddings._post = AsyncMock(return_value=returned_embedding)
    with start_transaction(name="openai tx"):
        response = await client.embeddings.create(
            input=input, model="text-embedding-3-large"
        )

    assert len(response.data[0].embedding) == 3

    tx = events[0]
    assert tx["type"] == "transaction"
    span = tx["spans"][0]
    assert span["op"] == "gen_ai.embeddings"

    param_id = request.node.callspec.id
    if param_id == "string":
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
    elif param_id == "string_sequence" or param_id == "string_iterable":
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
            "First text",
            "Second text",
            "Third text",
        ]
    elif param_id == "tokens" or param_id == "token_iterable":
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
            5,
            8,
            13,
            21,
            34,
        ]
    else:
        # tokens_sequence / tokens_sequence_iterable: nested token lists are
        # captured as-is. (A stray "not in span['data']" assertion here used
        # to contradict this equality check, making the test unpassable.)
        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
            [5, 8, 13, 21, 34],
            [8, 13, 21, 34, 55],
        ]

    assert span["data"]["gen_ai.usage.input_tokens"] == 20
    assert span["data"]["gen_ai.usage.total_tokens"] == 30
Expand Down
Loading