diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bd2abdbb..d3c96ddd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -36,7 +36,7 @@ jobs:
       - name: Lint with flake8
         run: |
-          flake8 posthog --ignore E501
+          flake8 posthog --ignore E501,W503
       - name: Check import order with isort
         run: |
diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py
index 7a513b21..679182ad 100644
--- a/posthog/ai/langchain/callbacks.py
+++ b/posthog/ai/langchain/callbacks.py
@@ -19,10 +19,12 @@ from uuid import UUID
 from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema.agent import AgentAction, AgentFinish
 from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
 from pydantic import BaseModel
+from posthog import default_client
 from posthog.ai.utils import get_model_params, with_privacy_mode
 from posthog.client import Client
@@ -44,19 +46,30 @@ class RunMetadata(TypedDict, total=False):
 class CallbackHandler(BaseCallbackHandler):
     """
-    A callback handler for LangChain that sends events to PostHog LLM Observability.
+    The PostHog LLM observability callback handler for LangChain.
     """
     _client: Client
     """PostHog client instance."""
+
     _distinct_id: Optional[Union[str, int, float, UUID]]
     """Distinct ID of the user to associate the trace with."""
+
     _trace_id: Optional[Union[str, int, float, UUID]]
     """Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
+
+    _trace_input: Optional[Any]
+    """The input at the start of the trace. Any JSON object."""
+
+    _trace_name: Optional[str]
+    """Name of the trace, exposed in the UI."""
+
     _properties: Optional[Dict[str, Any]]
     """Global properties to be sent with every event."""
+
     _runs: RunStorage
     """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
+
     _parent_tree: Dict[UUID, UUID]
     """
     A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
@@ -65,7 +78,8 @@ class CallbackHandler(BaseCallbackHandler):
     def __init__(
         self,
-        client: Client,
+        client: Optional[Client] = None,
+        *,
         distinct_id: Optional[Union[str, int, float, UUID]] = None,
         trace_id: Optional[Union[str, int, float, UUID]] = None,
         properties: Optional[Dict[str, Any]] = None,
@@ -81,9 +95,11 @@ def __init__(
             privacy_mode: Whether to redact the input and output of the trace.
             groups: Optional additional PostHog groups to use for the trace.
""" - self._client = client + self._client = client or default_client self._distinct_id = distinct_id self._trace_id = trace_id + self._trace_name = None + self._trace_input = None self._properties = properties or {} self._privacy_mode = privacy_mode self._groups = groups or {} @@ -97,9 +113,14 @@ def on_chain_start( *, run_id: UUID, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs, ): + self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs) self._set_parent_of_run(run_id, parent_run_id) + if parent_run_id is None and self._trace_name is None: + self._trace_name = self._get_langchain_run_name(serialized, **kwargs) + self._trace_input = inputs def on_chat_model_start( self, @@ -110,6 +131,7 @@ def on_chat_model_start( parent_run_id: Optional[UUID] = None, **kwargs, ): + self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages) self._set_parent_of_run(run_id, parent_run_id) input = [_convert_message_to_dict(message) for row in messages for message in row] self._set_run_metadata(serialized, run_id, input, **kwargs) @@ -123,32 +145,93 @@ def on_llm_start( parent_run_id: Optional[UUID] = None, **kwargs: Any, ): + self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts) self._set_parent_of_run(run_id, parent_run_id) self._set_run_metadata(serialized, run_id, prompts, **kwargs) + def on_llm_new_token( + self, + token: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on new LLM token. Only available when streaming is enabled.""" + self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token) + + def on_tool_start( + self, + serialized: Optional[Dict[str, Any]], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str) + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) + + def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) + def on_chain_end( self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, **kwargs: Any, ): + self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs) + self._pop_parent_of_run(run_id) + + if parent_run_id is None: + self._capture_trace(run_id, outputs=outputs) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ): + self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) self._pop_parent_of_run(run_id) + if parent_run_id is None: + self._capture_trace(run_id, outputs=None) + def on_llm_end( self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, **kwargs: Any, ): """ The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM. 
""" + self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs) trace_id = self._get_trace_id(run_id) self._pop_parent_of_run(run_id) run = self._pop_run_metadata(run_id) @@ -189,25 +272,15 @@ def on_llm_end( groups=self._groups, ) - def on_chain_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ): - self._pop_parent_of_run(run_id) - def on_llm_error( self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, **kwargs: Any, ): + self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) trace_id = self._get_trace_id(run_id) self._pop_parent_of_run(run_id) run = self._pop_run_metadata(run_id) @@ -235,6 +308,50 @@ def on_llm_error( groups=self._groups, ) + def on_retriever_start( + self, + serialized: Optional[Dict[str, Any]], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query) + + def on_retriever_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever errors.""" + self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error) + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action) + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish) + def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None): """ Set the parent run ID for a chain run. If there is no parent, the run is the root. @@ -304,6 +421,65 @@ def _get_trace_id(self, run_id: UUID): trace_id = uuid.uuid4() return trace_id + def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str: + """Retrieve the name of a serialized LangChain runnable. + + The prioritization for the determination of the run name is as follows: + - The value assigned to the "name" key in `kwargs`. + - The value assigned to the "name" key in `serialized`. + - The last entry of the value assigned to the "id" key in `serialized`. + - "". + + Args: + serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data. + **kwargs (Any): Additional keyword arguments, potentially including the 'name' override. + + Returns: + str: The determined name of the Langchain runnable. 
+ """ + if "name" in kwargs and kwargs["name"] is not None: + return kwargs["name"] + + try: + return serialized["name"] + except (KeyError, TypeError): + pass + + try: + return serialized["id"][-1] + except (KeyError, TypeError): + pass + + def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]): + trace_id = self._get_trace_id(run_id) + event_properties = { + "$ai_trace_name": self._trace_name, + "$ai_trace_id": trace_id, + "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input), + **self._properties, + } + if outputs is not None: + event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs) + if self._distinct_id is None: + event_properties["$process_person_profile"] = False + self._client.capture( + distinct_id=self._distinct_id or trace_id, + event="$ai_trace", + properties=event_properties, + groups=self._groups, + ) + + def _log_debug_event( + self, + event_name: str, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs, + ): + log.debug( + f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}" + ) + def _extract_raw_esponse(last_response): """Extract the response from the last response of the LLM call.""" @@ -339,7 +515,9 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]: return message_dict -def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None], Union[int, None]]: +def _parse_usage_model( + usage: Union[BaseModel, Dict], +) -> Tuple[Union[int, None], Union[int, None]]: if isinstance(usage, BaseModel): usage = usage.__dict__ diff --git a/posthog/test/ai/langchain/__init__.py b/posthog/test/ai/langchain/__init__.py index 17075ef5..00f6a7cf 100644 --- a/posthog/test/ai/langchain/__init__.py +++ b/posthog/test/ai/langchain/__init__.py @@ -2,3 +2,4 @@ pytest.importorskip("langchain") pytest.importorskip("langchain_community") +pytest.importorskip("langgraph") diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index 697f5f54..87517e1c 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -1,17 +1,20 @@ +import logging import math import os import time import uuid +from typing import List, Optional, TypedDict, Union from unittest.mock import patch import pytest from langchain_anthropic.chat_models import ChatAnthropic from langchain_community.chat_models.fake import FakeMessagesListChatModel from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda from langchain_openai.chat_models import ChatOpenAI +from langgraph.graph.state import END, START, StateGraph from posthog.ai.langchain import CallbackHandler @@ -23,6 +26,7 @@ def mock_client(): with patch("posthog.client.Client") as mock_client: mock_client.privacy_mode = False + logging.getLogger("posthog").setLevel(logging.DEBUG) yield mock_client @@ -98,7 +102,11 @@ def test_basic_chat_chain(mock_client, stream): responses=[ AIMessage( content="The Los Angeles Dodgers won the World Series in 2020.", - usage_metadata={"input_tokens": 10, "output_tokens": 10, "total_tokens": 20}, + usage_metadata={ + "input_tokens": 10, + "output_tokens": 10, + "total_tokens": 20, + }, ) ] ) @@ -110,26 +118,31 @@ def 
test_basic_chat_chain(mock_client, stream): result = chain.invoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." - assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args[1] - props = args["properties"] - - assert args["event"] == "$ai_generation" - assert "distinct_id" in args - assert "$ai_model" in props - assert "$ai_provider" in props - assert props["$ai_input"] == [ + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + generation_props = generation_args["properties"] + trace_args = mock_client.capture.call_args_list[1][1] + + assert generation_args["event"] == "$ai_generation" + assert "distinct_id" in generation_args + assert "$ai_model" in generation_props + assert "$ai_provider" in generation_props + assert generation_props["$ai_input"] == [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}, ] - assert props["$ai_output_choices"] == [ - {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."} + assert generation_props["$ai_output_choices"] == [ + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + } ] - assert props["$ai_input_tokens"] == 10 - assert props["$ai_output_tokens"] == 10 - assert props["$ai_http_status"] == 200 - assert props["$ai_trace_id"] is not None - assert isinstance(props["$ai_latency"], float) + assert generation_props["$ai_input_tokens"] == 10 + assert generation_props["$ai_output_tokens"] == 10 + assert generation_props["$ai_http_status"] == 200 + assert generation_props["$ai_trace_id"] is not None + assert isinstance(generation_props["$ai_latency"], float) + assert trace_args["event"] == "$ai_trace" @pytest.mark.parametrize("stream", [True, False]) @@ -144,7 +157,11 @@ async def test_async_basic_chat_chain(mock_client, stream): responses=[ AIMessage( content="The Los Angeles Dodgers won the World Series in 2020.", - usage_metadata={"input_tokens": 10, "output_tokens": 10, "total_tokens": 20}, + usage_metadata={ + "input_tokens": 10, + "output_tokens": 10, + "total_tokens": 20, + }, ) ] ) @@ -155,35 +172,50 @@ async def test_async_basic_chat_chain(mock_client, stream): else: result = await chain.ainvoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." 
-    assert mock_client.capture.call_count == 1
+    assert mock_client.capture.call_count == 2
-    args = mock_client.capture.call_args[1]
-    props = args["properties"]
-    assert args["event"] == "$ai_generation"
-    assert "distinct_id" in args
-    assert "$ai_model" in props
-    assert "$ai_provider" in props
-    assert props["$ai_input"] == [
+    generation_args = mock_client.capture.call_args_list[0][1]
+    generation_props = generation_args["properties"]
+    trace_args = mock_client.capture.call_args_list[1][1]
+    trace_props = trace_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert "distinct_id" in generation_args
+    assert "$ai_model" in generation_props
+    assert "$ai_provider" in generation_props
+    assert generation_props["$ai_input"] == [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Who won the world series in 2020?"},
     ]
-    assert props["$ai_output_choices"] == [
-        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}
+    assert generation_props["$ai_output_choices"] == [
+        {
+            "role": "assistant",
+            "content": "The Los Angeles Dodgers won the World Series in 2020.",
+        }
     ]
-    assert props["$ai_input_tokens"] == 10
-    assert props["$ai_output_tokens"] == 10
-    assert props["$ai_http_status"] == 200
-    assert props["$ai_trace_id"] is not None
-    assert isinstance(props["$ai_latency"], float)
+    assert generation_props["$ai_input_tokens"] == 10
+    assert generation_props["$ai_output_tokens"] == 10
+    assert generation_props["$ai_http_status"] == 200
+    assert generation_props["$ai_trace_id"] is not None
+    assert isinstance(generation_props["$ai_latency"], float)
+
+    assert trace_args["event"] == "$ai_trace"
+    assert "distinct_id" in generation_args
+    assert trace_props["$ai_trace_id"] == generation_props["$ai_trace_id"]
 @pytest.mark.parametrize(
     "Model,stream",
-    [(FakeListLLM, True), (FakeListLLM, False), (FakeStreamingListLLM, True), (FakeStreamingListLLM, False)],
+    [
+        (FakeListLLM, True),
+        (FakeListLLM, False),
+        (FakeStreamingListLLM, True),
+        (FakeStreamingListLLM, False),
+    ],
 )
 def test_basic_llm_chain(mock_client, Model, stream):
     model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."])
-    callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)]
+    callbacks: List[CallbackHandler] = [CallbackHandler(mock_client)]
     if stream:
         result = "".join(
@@ -194,7 +226,7 @@ def test_basic_llm_chain(mock_client, Model, stream):
     assert result == "The Los Angeles Dodgers won the World Series in 2020."
     assert mock_client.capture.call_count == 1
-    args = mock_client.capture.call_args[1]
+    args = mock_client.capture.call_args_list[0][1]
     props = args["properties"]
     assert args["event"] == "$ai_generation"
@@ -210,11 +242,16 @@
 @pytest.mark.parametrize(
     "Model,stream",
-    [(FakeListLLM, True), (FakeListLLM, False), (FakeStreamingListLLM, True), (FakeStreamingListLLM, False)],
+    [
+        (FakeListLLM, True),
+        (FakeListLLM, False),
+        (FakeStreamingListLLM, True),
+        (FakeStreamingListLLM, False),
+    ],
 )
 async def test_async_basic_llm_chain(mock_client, Model, stream):
     model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."])
-    callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)]
+    callbacks: List[CallbackHandler] = [CallbackHandler(mock_client)]
     if stream:
         result = "".join(
@@ -225,7 +262,7 @@ async def test_async_basic_llm_chain(mock_client, Model, stream):
     assert result == "The Los Angeles Dodgers won the World Series in 2020."
     assert mock_client.capture.call_count == 1
-    args = mock_client.capture.call_args[1]
+    args = mock_client.capture.call_args_list[0][1]
     props = args["properties"]
     assert args["event"] == "$ai_generation"
@@ -251,7 +288,7 @@ def test_trace_id_for_multiple_chains(mock_client):
     result = chain.invoke({}, config={"callbacks": callbacks})
     assert result.content == "Bar"
-    assert mock_client.capture.call_count == 2
+    assert mock_client.capture.call_count == 3
     first_call_args = mock_client.capture.call_args_list[0][1]
     first_call_props = first_call_args["properties"]
@@ -265,36 +302,54 @@
     assert first_call_props["$ai_trace_id"] is not None
     assert isinstance(first_call_props["$ai_latency"], float)
-    second_call_args = mock_client.capture.call_args_list[1][1]
-    second_call_props = second_call_args["properties"]
-    assert second_call_args["event"] == "$ai_generation"
-    assert "distinct_id" in second_call_args
-    assert "$ai_model" in second_call_props
-    assert "$ai_provider" in second_call_props
-    assert second_call_props["$ai_input"] == [{"role": "assistant", "content": "Bar"}]
-    assert second_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
-    assert second_call_props["$ai_http_status"] == 200
-    assert second_call_props["$ai_trace_id"] is not None
-    assert isinstance(second_call_props["$ai_latency"], float)
+    second_generation_args = mock_client.capture.call_args_list[1][1]
+    second_generation_props = second_generation_args["properties"]
+    assert second_generation_args["event"] == "$ai_generation"
+    assert "distinct_id" in second_generation_args
+    assert "$ai_model" in second_generation_props
+    assert "$ai_provider" in second_generation_props
+    assert second_generation_props["$ai_input"] == [{"role": "assistant", "content": "Bar"}]
+    assert second_generation_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
+    assert second_generation_props["$ai_http_status"] == 200
+    assert second_generation_props["$ai_trace_id"] is not None
+    assert isinstance(second_generation_props["$ai_latency"], float)
+
+    trace_args = mock_client.capture.call_args_list[2][1]
+    trace_props = trace_args["properties"]
+    assert trace_args["event"] == "$ai_trace"
+    assert "distinct_id" in trace_args
+    assert trace_props["$ai_input_state"] == {}
+    assert isinstance(trace_props["$ai_output_state"], AIMessage)
+    assert trace_props["$ai_output_state"].content == "Bar"
+    assert trace_props["$ai_trace_id"] is not None
+    assert trace_props["$ai_trace_name"] == "RunnableSequence"
     # Check that the trace_id is the same as the first call
-    assert first_call_props["$ai_trace_id"] == second_call_props["$ai_trace_id"]
+    assert first_call_props["$ai_trace_id"] == second_generation_props["$ai_trace_id"]
+    assert first_call_props["$ai_trace_id"] == trace_props["$ai_trace_id"]
 def test_personless_mode(mock_client):
     prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
     chain = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
     chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client)]})
-    assert mock_client.capture.call_count == 1
-    args = mock_client.capture.call_args_list[0][1]
-    assert args["properties"]["$process_person_profile"] is False
+    assert mock_client.capture.call_count == 2
+    generation_args = mock_client.capture.call_args_list[0][1]
+    trace_args = mock_client.capture.call_args_list[1][1]
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_args["properties"]["$process_person_profile"] is False
+    assert trace_args["event"] == "$ai_trace"
+    assert trace_args["properties"]["$process_person_profile"] is False
     id = uuid.uuid4()
     chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]})
-    assert mock_client.capture.call_count == 2
-    args = mock_client.capture.call_args_list[1][1]
-    assert "$process_person_profile" not in args["properties"]
-    assert args["distinct_id"] == id
+    assert mock_client.capture.call_count == 4
+    generation_args = mock_client.capture.call_args_list[2][1]
+    trace_args = mock_client.capture.call_args_list[3][1]
+    assert "$process_person_profile" not in generation_args["properties"]
+    assert generation_args["distinct_id"] == id
+    assert "$process_person_profile" not in trace_args["properties"]
+    assert trace_args["distinct_id"] == id
 def test_personless_mode_exception(mock_client):
@@ -303,17 +358,24 @@
     callbacks = CallbackHandler(mock_client)
     with pytest.raises(Exception):
         chain.invoke({}, config={"callbacks": [callbacks]})
-    assert mock_client.capture.call_count == 1
-    args = mock_client.capture.call_args_list[0][1]
-    assert args["properties"]["$process_person_profile"] is False
+    assert mock_client.capture.call_count == 2
+    generation_args = mock_client.capture.call_args_list[0][1]
+    trace_args = mock_client.capture.call_args_list[1][1]
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_args["properties"]["$process_person_profile"] is False
+    assert trace_args["event"] == "$ai_trace"
+    assert trace_args["properties"]["$process_person_profile"] is False
     id = uuid.uuid4()
     with pytest.raises(Exception):
         chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]})
-    assert mock_client.capture.call_count == 2
-    args = mock_client.capture.call_args_list[1][1]
-    assert "$process_person_profile" not in args["properties"]
-    assert args["distinct_id"] == id
+    assert mock_client.capture.call_count == 4
+    generation_args = mock_client.capture.call_args_list[2][1]
+    trace_args = mock_client.capture.call_args_list[3][1]
+    assert "$process_person_profile" not in generation_args["properties"]
+    assert generation_args["distinct_id"] == id
+    assert "$process_person_profile" not in trace_args["properties"]
+    assert trace_args["distinct_id"] == id
 def test_metadata(mock_client):
@@ -324,31 +386,127 @@
     )
     model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
     callbacks = [
-        CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
+        CallbackHandler(
+            mock_client,
+            trace_id="test-trace-id",
+            distinct_id="test_id",
+            properties={"foo": "bar"},
+        )
     ]
     chain = prompt | model
-    result = chain.invoke({}, config={"callbacks": callbacks})
+    result = chain.invoke({"plan": None}, config={"callbacks": callbacks})
     assert result.content == "Bar"
-    assert mock_client.capture.call_count == 1
+    assert mock_client.capture.call_count == 2
+
+    generation_call_args = mock_client.capture.call_args_list[0][1]
+    generation_call_props = generation_call_args["properties"]
+    assert generation_call_args["distinct_id"] == "test_id"
+    assert generation_call_args["event"] == "$ai_generation"
+    assert generation_call_props["$ai_trace_id"] == "test-trace-id"
+    assert generation_call_props["foo"] == "bar"
+    assert generation_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}]
+    assert generation_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
+    assert generation_call_props["$ai_http_status"] == 200
+    assert isinstance(generation_call_props["$ai_latency"], float)
+
+    trace_call_args = mock_client.capture.call_args_list[1][1]
+    trace_call_props = trace_call_args["properties"]
+    assert trace_call_args["distinct_id"] == "test_id"
+    assert trace_call_args["event"] == "$ai_trace"
+    assert trace_call_props["$ai_trace_id"] == "test-trace-id"
+    assert trace_call_props["$ai_trace_name"] == "RunnableSequence"
+    assert trace_call_props["foo"] == "bar"
+    assert trace_call_props["$ai_input_state"] == {"plan": None}
+    assert isinstance(trace_call_props["$ai_output_state"], AIMessage)
+    assert trace_call_props["$ai_output_state"].content == "Bar"
+
+
+class FakeGraphState(TypedDict):
+    messages: List[Union[HumanMessage, AIMessage]]
+    xyz: Optional[str]
+
+
+def test_graph_state(mock_client):
+    config = {"callbacks": [CallbackHandler(mock_client)]}
+
+    graph = StateGraph(FakeGraphState)
+    graph.add_node(
+        "fake_plain",
+        lambda state: {
+            "messages": [
+                *state["messages"],
+                AIMessage(content="Let's explore bar."),
+            ],
+            "xyz": "abc",
+        },
+    )
+    intermediate_chain = ChatPromptTemplate.from_messages(
+        [("user", "Question: What's a bar?")]
+    ) | FakeMessagesListChatModel(
+        responses=[
+            AIMessage(content="It's a type of greeble."),
+        ]
+    )
+    graph.add_node(
+        "fake_llm",
+        lambda state: {
+            "messages": [
+                *state["messages"],
+                intermediate_chain.invoke(state),
+            ],
+            "xyz": state["xyz"],
+        },
+    )
+    graph.add_edge(START, "fake_plain")
+    graph.add_edge("fake_plain", "fake_llm")
+    graph.add_edge("fake_llm", END)
-    first_call_args = mock_client.capture.call_args[1]
-    assert first_call_args["distinct_id"] == "test_id"
+    result = graph.compile().invoke(
+        {"messages": [HumanMessage(content="What's a bar?")], "xyz": None},
+        config=config,
+    )
-    first_call_props = first_call_args["properties"]
-    assert first_call_args["event"] == "$ai_generation"
-    assert first_call_props["$ai_trace_id"] == "test-trace-id"
-    assert first_call_props["foo"] == "bar"
-    assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}]
-    assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
-    assert first_call_props["$ai_http_status"] == 200
-    assert isinstance(first_call_props["$ai_latency"], float)
+    assert len(result["messages"]) == 3
+    assert isinstance(result["messages"][0], HumanMessage)
+    assert result["messages"][0].content == "What's a bar?"
+ assert isinstance(result["messages"][1], AIMessage) + assert result["messages"][1].content == "Let's explore bar." + assert isinstance(result["messages"][2], AIMessage) + assert result["messages"][2].content == "It's a type of greeble." + + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + trace_args = mock_client.capture.call_args_list[1][1] + assert generation_args["event"] == "$ai_generation" + assert trace_args["event"] == "$ai_trace" + assert trace_args["properties"]["$ai_trace_name"] == "LangGraph" + + assert len(trace_args["properties"]["$ai_input_state"]["messages"]) == 1 + assert isinstance(trace_args["properties"]["$ai_input_state"]["messages"][0], HumanMessage) + assert trace_args["properties"]["$ai_input_state"]["messages"][0].content == "What's a bar?" + assert trace_args["properties"]["$ai_input_state"]["messages"][0].type == "human" + assert trace_args["properties"]["$ai_input_state"]["xyz"] is None + assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 3 + + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage) + assert trace_args["properties"]["$ai_output_state"]["messages"][0].content == "What's a bar?" + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][1], AIMessage) + assert trace_args["properties"]["$ai_output_state"]["messages"][1].content == "Let's explore bar." + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][2], AIMessage) + assert trace_args["properties"]["$ai_output_state"]["messages"][2].content == "It's a type of greeble." + assert trace_args["properties"]["$ai_output_state"]["xyz"] == "abc" def test_callbacks_logic(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")]) - callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + callbacks = CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) chain = prompt | model chain.invoke({}, config={"callbacks": [callbacks]}) @@ -375,7 +533,10 @@ def runnable(_): assert callbacks._runs == {} assert callbacks._parent_tree == {} - assert mock_client.capture.call_count == 0 + assert mock_client.capture.call_count == 1 + trace_call_args = mock_client.capture.call_args_list[0][1] + assert trace_call_args["event"] == "$ai_trace" + assert trace_call_args["properties"]["$ai_trace_name"] == "runnable" def test_openai_error(mock_client): @@ -389,9 +550,9 @@ def test_openai_error(mock_client): assert callbacks._runs == {} assert callbacks._parent_tree == {} - assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args[1] - props = args["properties"] + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + props = generation_args["properties"] assert props["$ai_http_status"] == 401 assert props["$ai_input"] == [{"role": "user", "content": "Foo"}] assert "$ai_output_choices" not in props @@ -411,15 +572,20 @@ def test_openai_chain(mock_client): temperature=0, max_tokens=1, ) - callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + callbacks = CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) start_time = time.time() result = chain.invoke({}, 
config={"callbacks": [callbacks]}) approximate_latency = math.floor(time.time() - start_time) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_trace_id"] == "test-trace-id" @@ -445,13 +611,7 @@ def test_openai_chain(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - { - "role": "assistant", - "content": "Bar", - "additional_kwargs": {"refusal": None}, - } - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar", "refusal": None}] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency @@ -478,20 +638,20 @@ def test_openai_captures_multiple_generations(mock_client): result = chain.invoke({}, config={"callbacks": [callbacks]}) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] assert first_call_props["$ai_output_choices"] == [ - { - "role": "assistant", - "content": "Bar", - "additional_kwargs": {"refusal": None}, - }, + {"role": "assistant", "content": "Bar", "refusal": None}, { "role": "assistant", "content": "Bar", @@ -515,6 +675,10 @@ def test_openai_captures_multiple_generations(mock_client): } assert first_call_props["$ai_http_status"] == 200 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") def test_openai_streaming(mock_client): @@ -525,18 +689,26 @@ def test_openai_streaming(mock_client): ] ) chain = prompt | ChatOpenAI( - api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True + api_key=OPENAI_API_KEY, + model="gpt-4o-mini", + temperature=0, + max_tokens=1, + stream=True, + stream_usage=True, ) callbacks = CallbackHandler(mock_client) result = [m for m in chain.stream({}, config={"callbacks": [callbacks]})] result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_model_parameters"]["stream"] assert 
first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -547,6 +719,10 @@ def test_openai_streaming(mock_client): assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {"input": ""} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") async def test_async_openai_streaming(mock_client): @@ -557,18 +733,26 @@ async def test_async_openai_streaming(mock_client): ] ) chain = prompt | ChatOpenAI( - api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True + api_key=OPENAI_API_KEY, + model="gpt-4o-mini", + temperature=0, + max_tokens=1, + stream=True, + stream_usage=True, ) callbacks = CallbackHandler(mock_client) result = [m async for m in chain.astream({}, config={"callbacks": [callbacks]})] result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_model_parameters"]["stream"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -579,6 +763,10 @@ async def test_async_openai_streaming(mock_client): assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {"input": ""} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + def test_base_url_retrieval(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) @@ -591,9 +779,9 @@ def test_base_url_retrieval(mock_client): with pytest.raises(Exception): chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_base_url"] == "https://test.posthog.com" + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_base_url"] == "https://test.posthog.com" def test_groups(mock_client): @@ -608,9 +796,9 @@ def test_groups(mock_client): callbacks = CallbackHandler(mock_client, groups={"company": "test_company"}) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["groups"] == {"company": "test_company"} + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["groups"] == {"company": "test_company"} def test_privacy_mode_local(mock_client): @@ -625,10 +813,10 @@ def test_privacy_mode_local(mock_client): callbacks = CallbackHandler(mock_client, privacy_mode=True) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_input"] is None - assert 
call["properties"]["$ai_output_choices"] is None + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_input"] is None + assert generation_call["properties"]["$ai_output_choices"] is None def test_privacy_mode_global(mock_client): @@ -644,10 +832,10 @@ def test_privacy_mode_global(mock_client): callbacks = CallbackHandler(mock_client) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_input"] is None - assert call["properties"]["$ai_output_choices"] is None + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_input"] is None + assert generation_call["properties"]["$ai_output_choices"] is None @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") @@ -664,16 +852,24 @@ def test_anthropic_chain(mock_client): temperature=0, max_tokens=1, ) - callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + callbacks = CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) start_time = time.time() result = chain.invoke({}, config={"callbacks": [callbacks]}) approximate_latency = math.floor(time.time() - start_time) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_trace_id"] == "test-trace-id" assert first_call_props["$ai_provider"] == "anthropic" @@ -696,6 +892,10 @@ def test_anthropic_chain(mock_client): assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] == 1 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") async def test_async_anthropic_streaming(mock_client): @@ -718,10 +918,14 @@ async def test_async_anthropic_streaming(mock_client): result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_model_parameters"]["streaming"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -732,6 +936,12 @@ async def test_async_anthropic_streaming(mock_client): assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] is not None + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == { + 
"input": "", + } + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + def test_tool_calls(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) @@ -758,9 +968,9 @@ def test_tool_calls(mock_client): callbacks = CallbackHandler(mock_client) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_output_choices"][0]["tool_calls"] == [ + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_output_choices"][0]["tool_calls"] == [ { "type": "function", "id": "123", @@ -770,4 +980,4 @@ def test_tool_calls(mock_client): }, } ] - assert "additional_kwargs" not in call["properties"]["$ai_output_choices"][0] + assert "additional_kwargs" not in generation_call["properties"]["$ai_output_choices"][0] diff --git a/setup.py b/setup.py index 06fb4580..c6815c7e 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "django", "openai", "anthropic", + "langgraph", "langchain-community>=0.2.0", "langchain-openai>=0.2.0", "langchain-anthropic>=0.2.0",