diff --git a/posthog/ai/anthropic/anthropic.py b/posthog/ai/anthropic/anthropic.py
index 49f5c4f8..d8c9d0f2 100644
--- a/posthog/ai/anthropic/anthropic.py
+++ b/posthog/ai/anthropic/anthropic.py
@@ -8,7 +8,13 @@ import uuid
 
 from typing import Any, Dict, Optional
 
-from posthog.ai.utils import call_llm_and_track_usage, get_model_params, merge_system_prompt, with_privacy_mode
+from posthog.ai.utils import (
+    call_llm_and_track_usage,
+    extract_core_model_params,
+    get_model_params,
+    merge_system_prompt,
+    with_privacy_mode,
+)
 from posthog.client import Client as PostHogClient
 
 
@@ -187,6 +193,7 @@ def _capture_streaming_event(
             "$ai_latency": latency,
             "$ai_trace_id": posthog_trace_id,
             "$ai_base_url": str(self._client.base_url),
+            **extract_core_model_params(kwargs, "anthropic"),
             **(posthog_properties or {}),
         }
 
diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py
index 7a513b21..1fdb6e68 100644
--- a/posthog/ai/langchain/callbacks.py
+++ b/posthog/ai/langchain/callbacks.py
@@ -23,7 +23,7 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 
 from pydantic import BaseModel
 
-from posthog.ai.utils import get_model_params, with_privacy_mode
+from posthog.ai.utils import extract_core_model_params, get_model_params, with_privacy_mode
 from posthog.client import Client
 
 log = logging.getLogger("posthog")
@@ -178,6 +178,7 @@ def on_llm_end(
             "$ai_latency": latency,
             "$ai_trace_id": trace_id,
             "$ai_base_url": run.get("base_url"),
+            **extract_core_model_params(run.get("model_params") or {}, run.get("provider")),
             **self._properties,
         }
         if self._distinct_id is None:
@@ -224,6 +225,7 @@ def on_llm_error(
             "$ai_latency": latency,
             "$ai_trace_id": trace_id,
             "$ai_base_url": run.get("base_url"),
+            **extract_core_model_params(run.get("model_params") or {}, run.get("provider")),
             **self._properties,
         }
         if self._distinct_id is None:
diff --git a/posthog/ai/openai/openai.py b/posthog/ai/openai/openai.py
index 5987ff4e..4cb80f85 100644
--- a/posthog/ai/openai/openai.py
+++ b/posthog/ai/openai/openai.py
@@ -8,7 +8,7 @@
 except ImportError:
     raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'")
 
-from posthog.ai.utils import call_llm_and_track_usage, get_model_params, with_privacy_mode
+from posthog.ai.utils import call_llm_and_track_usage, extract_core_model_params, get_model_params, with_privacy_mode
 from posthog.client import Client as PostHogClient
 
 
@@ -167,6 +167,7 @@ def _capture_streaming_event(
             "$ai_latency": latency,
             "$ai_trace_id": posthog_trace_id,
             "$ai_base_url": str(self._client.base_url),
+            **extract_core_model_params(kwargs, "openai"),
             **posthog_properties,
         }
 
diff --git a/posthog/ai/utils.py b/posthog/ai/utils.py
index 6a902a2d..a36e11e2 100644
--- a/posthog/ai/utils.py
+++ b/posthog/ai/utils.py
@@ -29,6 +29,35 @@ def get_model_params(kwargs: Dict[str, Any]) -> Dict[str, Any]:
     return model_params
 
 
+def extract_core_model_params(kwargs: Dict[str, Any], provider: str) -> Dict[str, Any]:
+    """
+    Extracts core model parameters (temperature, max tokens, stream) from kwargs, mapping provider-specific names to $ai_* properties.
+    """
+    output = {}
+    if provider == "anthropic":
+        if "temperature" in kwargs:
+            output["$ai_temperature"] = kwargs.get("temperature")
+        if "max_tokens" in kwargs:
+            output["$ai_max_tokens"] = kwargs.get("max_tokens")
+        if "stream" in kwargs:
+            output["$ai_stream"] = kwargs.get("stream")
+    elif provider == "openai":
+        if "temperature" in kwargs:
+            output["$ai_temperature"] = kwargs.get("temperature")
+        if "max_completion_tokens" in kwargs:
+            output["$ai_max_tokens"] = kwargs.get("max_completion_tokens")
+        if "stream" in kwargs:
+            output["$ai_stream"] = kwargs.get("stream")
+    else:  # fall back to the common parameter names
+        if "temperature" in kwargs:
+            output["$ai_temperature"] = kwargs.get("temperature")
+        if "max_tokens" in kwargs:
+            output["$ai_max_tokens"] = kwargs.get("max_tokens")
+        if "stream" in kwargs:
+            output["$ai_stream"] = kwargs.get("stream")
+    return output
+
+
 def get_usage(response, provider: str) -> Dict[str, Any]:
     if provider == "anthropic":
         return {
@@ -148,6 +177,7 @@ def call_llm_and_track_usage(
         "$ai_latency": latency,
         "$ai_trace_id": posthog_trace_id,
         "$ai_base_url": str(base_url),
+        **extract_core_model_params(kwargs, provider),
         **(posthog_properties or {}),
     }
 
@@ -218,6 +248,7 @@ async def call_llm_and_track_usage_async(
         "$ai_latency": latency,
         "$ai_trace_id": posthog_trace_id,
         "$ai_base_url": str(base_url),
+        **extract_core_model_params(kwargs, provider),
         **(posthog_properties or {}),
     }
 
diff --git a/posthog/test/ai/anthropic/test_anthropic.py b/posthog/test/ai/anthropic/test_anthropic.py
index a2a3dd91..6931c1ee 100644
--- a/posthog/test/ai/anthropic/test_anthropic.py
+++ b/posthog/test/ai/anthropic/test_anthropic.py
@@ -296,7 +296,6 @@ def test_streaming_system_prompt(mock_client, mock_anthropic_stream):
 
     call_args = mock_client.capture.call_args[1]
     props = call_args["properties"]
-
     assert props["$ai_input"] == [{"role": "system", "content": "Foo"}, {"role": "user", "content": "Bar"}]
 
 
@@ -325,3 +324,25 @@ async def test_async_streaming_system_prompt(mock_client, mock_anthropic_stream)
         {"role": "system", "content": "You must always answer with 'Bar'."},
         {"role": "user", "content": "Foo"},
     ]
+
+
+def test_core_model_params(mock_client, mock_anthropic_response):
+    with patch("anthropic.resources.Messages.create", return_value=mock_anthropic_response):
+        client = Anthropic(api_key="test-key", posthog_client=mock_client)
+        response = client.messages.create(
+            model="claude-3-opus-20240229",
+            temperature=0.5,
+            max_tokens=100,
+            stream=False,
+            messages=[{"role": "user", "content": "Hello"}],
+            posthog_distinct_id="test-id",
+            posthog_properties={"foo": "bar"},
+        )
+
+    assert response == mock_anthropic_response
+    props = mock_client.capture.call_args[1]["properties"]
+    assert props["$ai_model_parameters"] == {"temperature": 0.5, "max_tokens": 100, "stream": False}
+    assert props["$ai_temperature"] == 0.5
+    assert props["$ai_max_tokens"] == 100
+    assert props["$ai_stream"] is False
+    assert props["foo"] == "bar"
diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py
index 697f5f54..9d46989b 100644
--- a/posthog/test/ai/langchain/test_callbacks.py
+++ b/posthog/test/ai/langchain/test_callbacks.py
@@ -771,3 +771,25 @@ def test_tool_calls(mock_client):
         }
     ]
     assert "additional_kwargs" not in call["properties"]["$ai_output_choices"][0]
+
+
+@pytest.mark.skipif(not OPENAI_API_KEY, reason="OPENAI_API_KEY is not set")
+def test_core_model_params(mock_client):
+    prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
+    chain = prompt | ChatOpenAI(
+        api_key=OPENAI_API_KEY,
+        model="gpt-4",
+        temperature=0.5,
+        max_tokens=100,
+        stream=False,
+    )
+    callbacks = CallbackHandler(mock_client, properties={"foo": "bar"})
+    chain.invoke({}, config={"callbacks": [callbacks]})
+
+    assert mock_client.capture.call_count == 1
+    call = mock_client.capture.call_args[1]
+    assert call["properties"]["$ai_model_parameters"] == {"temperature": 0.5, "max_tokens": 100, "stream": False}
+    assert call["properties"]["$ai_temperature"] == 0.5
+    assert call["properties"]["$ai_max_tokens"] == 100
+    assert call["properties"]["$ai_stream"] is False
+    assert call["properties"]["foo"] == "bar"
diff --git a/posthog/test/ai/openai/test_openai.py b/posthog/test/ai/openai/test_openai.py
index d2dbc06f..aee0c645 100644
--- a/posthog/test/ai/openai/test_openai.py
+++ b/posthog/test/ai/openai/test_openai.py
@@ -173,3 +173,28 @@ def test_privacy_mode_global(mock_client, mock_openai_response):
     props = call_args["properties"]
     assert props["$ai_input"] is None
     assert props["$ai_output_choices"] is None
+
+
+def test_core_model_params(mock_client, mock_openai_response):
+    with patch("openai.resources.chat.completions.Completions.create", return_value=mock_openai_response):
+        client = OpenAI(api_key="test-key", posthog_client=mock_client)
+        response = client.chat.completions.create(
+            model="gpt-4",
+            temperature=0.5,
+            max_completion_tokens=100,
+            stream=False,
+            messages=[{"role": "user", "content": "Hello"}],
+            posthog_distinct_id="test-id",
+            posthog_properties={"foo": "bar"},
+        )
+
+    assert response == mock_openai_response
+    assert mock_client.capture.call_count == 1
+
+    call_args = mock_client.capture.call_args[1]
+    props = call_args["properties"]
+    assert props["$ai_model_parameters"] == {"temperature": 0.5, "max_completion_tokens": 100, "stream": False}
+    assert props["$ai_temperature"] == 0.5
+    assert props["$ai_max_tokens"] == 100
+    assert props["$ai_stream"] is False
+    assert props["foo"] == "bar"
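
Reviewer note, not part of the patch: the snippet below is a minimal sketch of the normalization the new extract_core_model_params helper performs, assuming the function as added to posthog/ai/utils.py above; the kwargs dicts and the "mistral" provider string are illustrative.

    from posthog.ai.utils import extract_core_model_params

    # Anthropic callers pass max_tokens, which maps directly to $ai_max_tokens.
    assert extract_core_model_params(
        {"temperature": 0.5, "max_tokens": 100, "stream": False}, "anthropic"
    ) == {"$ai_temperature": 0.5, "$ai_max_tokens": 100, "$ai_stream": False}

    # OpenAI callers pass max_completion_tokens; it is normalized to the same
    # $ai_max_tokens property so events stay comparable across providers.
    assert extract_core_model_params(
        {"temperature": 0.5, "max_completion_tokens": 100, "stream": False}, "openai"
    ) == {"$ai_temperature": 0.5, "$ai_max_tokens": 100, "$ai_stream": False}

    # Unrecognized providers fall back to the common max_tokens name, and
    # parameters the caller never set are omitted rather than captured as None.
    assert extract_core_model_params({"max_tokens": 100}, "mistral") == {"$ai_max_tokens": 100}
    assert extract_core_model_params({}, "openai") == {}

The helper emits flat $ai_* properties alongside the existing $ai_model_parameters blob, presumably so that temperature, token limits, and streaming can be filtered and aggregated in PostHog without digging into the nested parameter object.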