From 0384b8c7d76a31e9133930dc599584d397ba1a61 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Mon, 20 Jan 2025 18:54:23 +0100 Subject: [PATCH 1/8] Format callbacks.py --- posthog/ai/langchain/callbacks.py | 64 +++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index 7a513b21..129bf7cd 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -1,7 +1,9 @@ try: import langchain # noqa: F401 except ImportError: - raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'") + raise ModuleNotFoundError( + "Please install LangChain to use this feature: 'pip install langchain'" + ) import logging import time @@ -19,7 +21,14 @@ from uuid import UUID from langchain.callbacks.base import BaseCallbackHandler -from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.messages import ( + AIMessage, + BaseMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_core.outputs import ChatGeneration, LLMResult from pydantic import BaseModel @@ -111,7 +120,9 @@ def on_chat_model_start( **kwargs, ): self._set_parent_of_run(run_id, parent_run_id) - input = [_convert_message_to_dict(message) for row in messages for message in row] + input = [ + _convert_message_to_dict(message) for row in messages for message in row + ] self._set_run_metadata(serialized, run_id, input, **kwargs) def on_llm_start( @@ -161,17 +172,24 @@ def on_llm_end( generation_result = response.generations[-1] if isinstance(generation_result[-1], ChatGeneration): output = [ - _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result + _convert_message_to_dict(cast(ChatGeneration, generation).message) + for generation in generation_result ] else: - output = [_extract_raw_esponse(generation) for generation in generation_result] + output = [ + _extract_raw_esponse(generation) for generation in generation_result + ] event_properties = { "$ai_provider": run.get("provider"), "$ai_model": run.get("model"), "$ai_model_parameters": run.get("model_params"), - "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")), - "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output), + "$ai_input": with_privacy_mode( + self._client, self._privacy_mode, run.get("messages") + ), + "$ai_output_choices": with_privacy_mode( + self._client, self._privacy_mode, output + ), "$ai_http_status": 200, "$ai_input_tokens": input_tokens, "$ai_output_tokens": output_tokens, @@ -219,7 +237,9 @@ def on_llm_error( "$ai_provider": run.get("provider"), "$ai_model": run.get("model"), "$ai_model_parameters": run.get("model_params"), - "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")), + "$ai_input": with_privacy_mode( + self._client, self._privacy_mode, run.get("messages") + ), "$ai_http_status": _get_http_status(error), "$ai_latency": latency, "$ai_trace_id": trace_id, @@ -339,7 +359,9 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]: return message_dict -def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None], Union[int, None]]: +def _parse_usage_model( + usage: Union[BaseModel, Dict], +) -> Tuple[Union[int, None], Union[int, None]]: if isinstance(usage, BaseModel): usage = usage.__dict__ @@ -363,7 
+385,9 @@ def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None], if model_key in usage: captured_count = usage[model_key] final_count = ( - sum(captured_count) if isinstance(captured_count, list) else captured_count + sum(captured_count) + if isinstance(captured_count, list) + else captured_count ) # For Bedrock, the token count is a list when streamed parsed_usage[type_key] = final_count @@ -384,8 +408,12 @@ def _parse_usage(response: LLMResult): if hasattr(response, "generations"): for generation in response.generations: for generation_chunk in generation: - if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info): - llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"]) + if generation_chunk.generation_info and ( + "usage_metadata" in generation_chunk.generation_info + ): + llm_usage = _parse_usage_model( + generation_chunk.generation_info["usage_metadata"] + ) break message_chunk = getattr(generation_chunk, "message", {}) @@ -397,13 +425,19 @@ def _parse_usage(response: LLMResult): else None ) bedrock_titan_usage = ( - response_metadata.get("amazon-bedrock-invocationMetrics", None) # for Bedrock-Titan + response_metadata.get( + "amazon-bedrock-invocationMetrics", None + ) # for Bedrock-Titan if isinstance(response_metadata, dict) else None ) - ollama_usage = getattr(message_chunk, "usage_metadata", None) # for Ollama + ollama_usage = getattr( + message_chunk, "usage_metadata", None + ) # for Ollama - chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage + chunk_usage = ( + bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage + ) if chunk_usage: llm_usage = _parse_usage_model(chunk_usage) break From 199dfd25c2cac44857f78a82660d79df1bda4ea0 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 17:32:04 +0100 Subject: [PATCH 2/8] LangChain tracing, with LangGraph tests --- posthog/ai/langchain/callbacks.py | 186 +++++++- posthog/test/ai/langchain/__init__.py | 1 + posthog/test/ai/langchain/test_callbacks.py | 492 ++++++++++++++------ setup.py | 1 + 4 files changed, 533 insertions(+), 147 deletions(-) diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index 129bf7cd..4ea0306a 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -5,6 +5,7 @@ "Please install LangChain to use this feature: 'pip install langchain'" ) +import json import logging import time import uuid @@ -30,10 +31,12 @@ ToolMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult +from langchain.schema.agent import AgentAction, AgentFinish from pydantic import BaseModel from posthog.ai.utils import get_model_params, with_privacy_mode from posthog.client import Client +from posthog import default_client log = logging.getLogger("posthog") @@ -53,7 +56,7 @@ class RunMetadata(TypedDict, total=False): class CallbackHandler(BaseCallbackHandler): """ - A callback handler for LangChain that sends events to PostHog LLM Observability. + The PostHog LLM observability callback handler for LangChain. """ _client: Client @@ -74,7 +77,8 @@ class CallbackHandler(BaseCallbackHandler): def __init__( self, - client: Client, + client: Optional[Client] = None, + *, distinct_id: Optional[Union[str, int, float, UUID]] = None, trace_id: Optional[Union[str, int, float, UUID]] = None, properties: Optional[Dict[str, Any]] = None, @@ -90,7 +94,7 @@ def __init__( privacy_mode: Whether to redact the input and output of the trace. 
groups: Optional additional PostHog groups to use for the trace. """ - self._client = client + self._client = client or default_client self._distinct_id = distinct_id self._trace_id = trace_id self._properties = properties or {} @@ -106,9 +110,12 @@ def on_chain_start( *, run_id: UUID, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs, ): + self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs) self._set_parent_of_run(run_id, parent_run_id) + self._set_run_metadata(serialized, run_id, inputs, metadata, **kwargs) def on_chat_model_start( self, @@ -119,6 +126,9 @@ def on_chat_model_start( parent_run_id: Optional[UUID] = None, **kwargs, ): + self._log_debug_event( + "on_chat_model_start", run_id, parent_run_id, messages=messages + ) self._set_parent_of_run(run_id, parent_run_id) input = [ _convert_message_to_dict(message) for row in messages for message in row @@ -134,9 +144,58 @@ def on_llm_start( parent_run_id: Optional[UUID] = None, **kwargs: Any, ): + self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts) self._set_parent_of_run(run_id, parent_run_id) self._set_run_metadata(serialized, run_id, prompts, **kwargs) + def on_llm_new_token( + self, + token: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on new LLM token. Only available when streaming is enabled.""" + self.log.debug( + f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}" + ) + + def on_tool_start( + self, + serialized: Optional[Dict[str, Any]], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event( + "on_tool_start", run_id, parent_run_id, input_str=input_str + ) + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) + + def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) + def on_chain_end( self, outputs: Dict[str, Any], @@ -146,7 +205,35 @@ def on_chain_end( tags: Optional[List[str]] = None, **kwargs: Any, ): + self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs) self._pop_parent_of_run(run_id) + run_metadata = self._pop_run_metadata(run_id) + + if parent_run_id is None: + self._end_trace( + self._get_trace_id(run_id), + inputs=run_metadata.get("messages") if run_metadata else None, + outputs=outputs, + ) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ): + self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) + self._pop_parent_of_run(run_id) + run_metadata = self._pop_run_metadata(run_id) + + if parent_run_id is None: + self._end_trace( + self._get_trace_id(run_id), + inputs=run_metadata.get("messages") if run_metadata else None, + outputs=None, + ) def on_llm_end( self, @@ -160,6 +247,9 @@ def on_llm_end( """ The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM. 
""" + self._log_debug_event( + "on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs + ) trace_id = self._get_trace_id(run_id) self._pop_parent_of_run(run_id) run = self._pop_run_metadata(run_id) @@ -207,16 +297,6 @@ def on_llm_end( groups=self._groups, ) - def on_chain_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ): - self._pop_parent_of_run(run_id) - def on_llm_error( self, error: BaseException, @@ -226,6 +306,7 @@ def on_llm_error( tags: Optional[List[str]] = None, **kwargs: Any, ): + self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) trace_id = self._get_trace_id(run_id) self._pop_parent_of_run(run_id) run = self._pop_run_metadata(run_id) @@ -255,6 +336,51 @@ def on_llm_error( groups=self._groups, ) + def on_retriever_start( + self, + serialized: Optional[Dict[str, Any]], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query) + + def on_retriever_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever errors.""" + self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error) + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action) + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish) + def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None): """ Set the parent run ID for a chain run. If there is no parent, the run is the root. 
@@ -324,6 +450,40 @@ def _get_trace_id(self, run_id: UUID): trace_id = uuid.uuid4() return trace_id + def _end_trace( + self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]] + ): + event_properties = { + "$ai_trace_id": trace_id, + "$ai_input_state": with_privacy_mode( + self._client, self._privacy_mode, inputs + ), + **self._properties, + } + if outputs is not None: + event_properties["$ai_output_state"] = with_privacy_mode( + self._client, self._privacy_mode, outputs + ) + if self._distinct_id is None: + event_properties["$process_person_profile"] = False + self._client.capture( + distinct_id=self._distinct_id or trace_id, + event="$ai_trace", + properties=event_properties, + groups=self._groups, + ) + + def _log_debug_event( + self, + event_name: str, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs, + ): + log.debug( + f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}" + ) + def _extract_raw_esponse(last_response): """Extract the response from the last response of the LLM call.""" diff --git a/posthog/test/ai/langchain/__init__.py b/posthog/test/ai/langchain/__init__.py index 17075ef5..00f6a7cf 100644 --- a/posthog/test/ai/langchain/__init__.py +++ b/posthog/test/ai/langchain/__init__.py @@ -2,3 +2,4 @@ pytest.importorskip("langchain") pytest.importorskip("langchain_community") +pytest.importorskip("langgraph") diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index 697f5f54..84803c31 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -1,17 +1,21 @@ +import logging import math import os +from pyexpat.errors import messages import time +from typing import Optional, TypedDict import uuid -from unittest.mock import patch +from unittest.mock import patch, ANY import pytest from langchain_anthropic.chat_models import ChatAnthropic from langchain_community.chat_models.fake import FakeMessagesListChatModel from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM -from langchain_core.messages import AIMessage +from langchain_core.messages import HumanMessage, AIMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda from langchain_openai.chat_models import ChatOpenAI +from langgraph.graph.state import StateGraph, START, END from posthog.ai.langchain import CallbackHandler @@ -23,6 +27,7 @@ def mock_client(): with patch("posthog.client.Client") as mock_client: mock_client.privacy_mode = False + logging.getLogger("posthog").setLevel(logging.DEBUG) yield mock_client @@ -98,7 +103,11 @@ def test_basic_chat_chain(mock_client, stream): responses=[ AIMessage( content="The Los Angeles Dodgers won the World Series in 2020.", - usage_metadata={"input_tokens": 10, "output_tokens": 10, "total_tokens": 20}, + usage_metadata={ + "input_tokens": 10, + "output_tokens": 10, + "total_tokens": 20, + }, ) ] ) @@ -110,26 +119,31 @@ def test_basic_chat_chain(mock_client, stream): result = chain.invoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." 
- assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args[1] - props = args["properties"] - - assert args["event"] == "$ai_generation" - assert "distinct_id" in args - assert "$ai_model" in props - assert "$ai_provider" in props - assert props["$ai_input"] == [ + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + generation_props = generation_args["properties"] + trace_args = mock_client.capture.call_args_list[1][1] + + assert generation_args["event"] == "$ai_generation" + assert "distinct_id" in generation_args + assert "$ai_model" in generation_props + assert "$ai_provider" in generation_props + assert generation_props["$ai_input"] == [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}, ] - assert props["$ai_output_choices"] == [ - {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."} + assert generation_props["$ai_output_choices"] == [ + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + } ] - assert props["$ai_input_tokens"] == 10 - assert props["$ai_output_tokens"] == 10 - assert props["$ai_http_status"] == 200 - assert props["$ai_trace_id"] is not None - assert isinstance(props["$ai_latency"], float) + assert generation_props["$ai_input_tokens"] == 10 + assert generation_props["$ai_output_tokens"] == 10 + assert generation_props["$ai_http_status"] == 200 + assert generation_props["$ai_trace_id"] is not None + assert isinstance(generation_props["$ai_latency"], float) + assert trace_args["event"] == "$ai_trace" @pytest.mark.parametrize("stream", [True, False]) @@ -144,42 +158,63 @@ async def test_async_basic_chat_chain(mock_client, stream): responses=[ AIMessage( content="The Los Angeles Dodgers won the World Series in 2020.", - usage_metadata={"input_tokens": 10, "output_tokens": 10, "total_tokens": 20}, + usage_metadata={ + "input_tokens": 10, + "output_tokens": 10, + "total_tokens": 20, + }, ) ] ) callbacks = [CallbackHandler(mock_client)] chain = prompt | model if stream: - result = [m async for m in chain.astream({}, config={"callbacks": callbacks})][0] + result = [m async for m in chain.astream({}, config={"callbacks": callbacks})][ + 0 + ] else: result = await chain.ainvoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." 
- assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - args = mock_client.capture.call_args[1] - props = args["properties"] - assert args["event"] == "$ai_generation" - assert "distinct_id" in args - assert "$ai_model" in props - assert "$ai_provider" in props - assert props["$ai_input"] == [ + generation_args = mock_client.capture.call_args_list[0][1] + generation_props = generation_args["properties"] + trace_args = mock_client.capture.call_args_list[1][1] + trace_props = trace_args["properties"] + + assert generation_args["event"] == "$ai_generation" + assert "distinct_id" in generation_args + assert "$ai_model" in generation_props + assert "$ai_provider" in generation_props + assert generation_props["$ai_input"] == [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}, ] - assert props["$ai_output_choices"] == [ - {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."} + assert generation_props["$ai_output_choices"] == [ + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + } ] - assert props["$ai_input_tokens"] == 10 - assert props["$ai_output_tokens"] == 10 - assert props["$ai_http_status"] == 200 - assert props["$ai_trace_id"] is not None - assert isinstance(props["$ai_latency"], float) + assert generation_props["$ai_input_tokens"] == 10 + assert generation_props["$ai_output_tokens"] == 10 + assert generation_props["$ai_http_status"] == 200 + assert generation_props["$ai_trace_id"] is not None + assert isinstance(generation_props["$ai_latency"], float) + + assert trace_args["event"] == "$ai_trace" + assert "distinct_id" in generation_args + assert trace_props["$ai_trace_id"] == generation_props["$ai_trace_id"] @pytest.mark.parametrize( "Model,stream", - [(FakeListLLM, True), (FakeListLLM, False), (FakeStreamingListLLM, True), (FakeStreamingListLLM, False)], + [ + (FakeListLLM, True), + (FakeListLLM, False), + (FakeStreamingListLLM, True), + (FakeStreamingListLLM, False), + ], ) def test_basic_llm_chain(mock_client, Model, stream): model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."]) @@ -187,14 +222,21 @@ def test_basic_llm_chain(mock_client, Model, stream): if stream: result = "".join( - [m for m in model.stream("Who won the world series in 2020?", config={"callbacks": callbacks})] + [ + m + for m in model.stream( + "Who won the world series in 2020?", config={"callbacks": callbacks} + ) + ] ) else: - result = model.invoke("Who won the world series in 2020?", config={"callbacks": callbacks}) + result = model.invoke( + "Who won the world series in 2020?", config={"callbacks": callbacks} + ) assert result == "The Los Angeles Dodgers won the World Series in 2020." assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args[1] + args = mock_client.capture.call_args_list[0][1] props = args["properties"] assert args["event"] == "$ai_generation" @@ -202,7 +244,9 @@ def test_basic_llm_chain(mock_client, Model, stream): assert "$ai_model" in props assert "$ai_provider" in props assert props["$ai_input"] == ["Who won the world series in 2020?"] - assert props["$ai_output_choices"] == ["The Los Angeles Dodgers won the World Series in 2020."] + assert props["$ai_output_choices"] == [ + "The Los Angeles Dodgers won the World Series in 2020." 
+ ] assert props["$ai_http_status"] == 200 assert props["$ai_trace_id"] is not None assert isinstance(props["$ai_latency"], float) @@ -210,7 +254,12 @@ def test_basic_llm_chain(mock_client, Model, stream): @pytest.mark.parametrize( "Model,stream", - [(FakeListLLM, True), (FakeListLLM, False), (FakeStreamingListLLM, True), (FakeStreamingListLLM, False)], + [ + (FakeListLLM, True), + (FakeListLLM, False), + (FakeStreamingListLLM, True), + (FakeStreamingListLLM, False), + ], ) async def test_async_basic_llm_chain(mock_client, Model, stream): model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."]) @@ -218,14 +267,21 @@ async def test_async_basic_llm_chain(mock_client, Model, stream): if stream: result = "".join( - [m async for m in model.astream("Who won the world series in 2020?", config={"callbacks": callbacks})] + [ + m + async for m in model.astream( + "Who won the world series in 2020?", config={"callbacks": callbacks} + ) + ] ) else: - result = await model.ainvoke("Who won the world series in 2020?", config={"callbacks": callbacks}) + result = await model.ainvoke( + "Who won the world series in 2020?", config={"callbacks": callbacks} + ) assert result == "The Los Angeles Dodgers won the World Series in 2020." assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args[1] + args = mock_client.capture.call_args_list[0][1] props = args["properties"] assert args["event"] == "$ai_generation" @@ -233,7 +289,9 @@ async def test_async_basic_llm_chain(mock_client, Model, stream): assert "$ai_model" in props assert "$ai_provider" in props assert props["$ai_input"] == ["Who won the world series in 2020?"] - assert props["$ai_output_choices"] == ["The Los Angeles Dodgers won the World Series in 2020."] + assert props["$ai_output_choices"] == [ + "The Los Angeles Dodgers won the World Series in 2020." 
+ ] assert props["$ai_http_status"] == 200 assert props["$ai_trace_id"] is not None assert isinstance(props["$ai_latency"], float) @@ -251,50 +309,77 @@ def test_trace_id_for_multiple_chains(mock_client): result = chain.invoke({}, config={"callbacks": callbacks}) assert result.content == "Bar" - assert mock_client.capture.call_count == 2 - - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - assert first_call_args["event"] == "$ai_generation" - assert "distinct_id" in first_call_args - assert "$ai_model" in first_call_props - assert "$ai_provider" in first_call_props - assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert first_call_props["$ai_http_status"] == 200 - assert first_call_props["$ai_trace_id"] is not None - assert isinstance(first_call_props["$ai_latency"], float) - - second_call_args = mock_client.capture.call_args_list[1][1] - second_call_props = second_call_args["properties"] - assert second_call_args["event"] == "$ai_generation" - assert "distinct_id" in second_call_args - assert "$ai_model" in second_call_props - assert "$ai_provider" in second_call_props - assert second_call_props["$ai_input"] == [{"role": "assistant", "content": "Bar"}] - assert second_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert second_call_props["$ai_http_status"] == 200 - assert second_call_props["$ai_trace_id"] is not None - assert isinstance(second_call_props["$ai_latency"], float) + assert mock_client.capture.call_count == 3 + + first_generation_args = mock_client.capture.call_args_list[0][1] + first_generation_props = first_generation_args["properties"] + assert first_generation_args["event"] == "$ai_generation" + assert "distinct_id" in first_generation_args + assert "$ai_model" in first_generation_props + assert "$ai_provider" in first_generation_props + assert first_generation_props["$ai_input"] == [{"role": "user", "content": "Foo"}] + assert first_generation_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] + assert first_generation_props["$ai_http_status"] == 200 + assert first_generation_props["$ai_trace_id"] is not None + assert isinstance(first_generation_props["$ai_latency"], float) + + second_generation_args = mock_client.capture.call_args_list[1][1] + second_generation_props = second_generation_args["properties"] + assert second_generation_args["event"] == "$ai_generation" + assert "distinct_id" in second_generation_args + assert "$ai_model" in second_generation_props + assert "$ai_provider" in second_generation_props + assert second_generation_props["$ai_input"] == [ + {"role": "assistant", "content": "Bar"} + ] + assert second_generation_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] + assert second_generation_props["$ai_http_status"] == 200 + assert second_generation_props["$ai_trace_id"] is not None + assert isinstance(second_generation_props["$ai_latency"], float) + + trace_args = mock_client.capture.call_args_list[2][1] + trace_props = trace_args["properties"] + assert trace_args["event"] == "$ai_trace" + assert "distinct_id" in trace_args + assert trace_props["$ai_input_state"] == [{"role": "assistant", "content": "Bar"}] + assert trace_props["$ai_output_state"] == [{"role": "assistant", "content": "Bar"}] + assert trace_props["$ai_trace_id"] is not None # Check that the trace_id is the same as 
the first call - assert first_call_props["$ai_trace_id"] == second_call_props["$ai_trace_id"] + assert ( + first_generation_props["$ai_trace_id"] + == second_generation_props["$ai_trace_id"] + ) + assert first_generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"] def test_personless_mode(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) chain = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")]) chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client)]}) - assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args_list[0][1] - assert args["properties"]["$process_person_profile"] is False + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + trace_args = mock_client.capture.call_args_list[1][1] + assert generation_args["event"] == "$ai_generation" + assert generation_args["properties"]["$process_person_profile"] is False + assert trace_args["event"] == "$ai_trace" + assert trace_args["properties"]["$process_person_profile"] is False id = uuid.uuid4() - chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]}) - assert mock_client.capture.call_count == 2 - args = mock_client.capture.call_args_list[1][1] - assert "$process_person_profile" not in args["properties"] - assert args["distinct_id"] == id + chain.invoke( + {}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]} + ) + assert mock_client.capture.call_count == 4 + generation_args = mock_client.capture.call_args_list[2][1] + trace_args = mock_client.capture.call_args_list[3][1] + assert "$process_person_profile" not in generation_args["properties"] + assert generation_args["distinct_id"] == id + assert "$process_person_profile" not in trace_args["properties"] + assert trace_args["distinct_id"] == id def test_personless_mode_exception(mock_client): @@ -303,17 +388,26 @@ def test_personless_mode_exception(mock_client): callbacks = CallbackHandler(mock_client) with pytest.raises(Exception): chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args_list[0][1] - assert args["properties"]["$process_person_profile"] is False + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + trace_args = mock_client.capture.call_args_list[1][1] + assert generation_args["event"] == "$ai_generation" + assert generation_args["properties"]["$process_person_profile"] is False + assert trace_args["event"] == "$ai_trace" + assert trace_args["properties"]["$process_person_profile"] is False id = uuid.uuid4() with pytest.raises(Exception): - chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]}) - assert mock_client.capture.call_count == 2 - args = mock_client.capture.call_args_list[1][1] - assert "$process_person_profile" not in args["properties"] - assert args["distinct_id"] == id + chain.invoke( + {}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]} + ) + assert mock_client.capture.call_count == 4 + generation_args = mock_client.capture.call_args_list[2][1] + trace_args = mock_client.capture.call_args_list[3][1] + assert "$process_person_profile" not in generation_args["properties"] + assert generation_args["distinct_id"] == id + assert "$process_person_profile" not in trace_args["properties"] + assert trace_args["distinct_id"] == id def test_metadata(mock_client): @@ -324,31 +418,118 @@ def 
test_metadata(mock_client): ) model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")]) callbacks = [ - CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) ] chain = prompt | model result = chain.invoke({}, config={"callbacks": callbacks}) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 - - first_call_args = mock_client.capture.call_args[1] - assert first_call_args["distinct_id"] == "test_id" + assert mock_client.capture.call_count == 2 + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + assert first_call_args["distinct_id"] == "test_id" assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_trace_id"] == "test-trace-id" assert first_call_props["foo"] == "bar" assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_call_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert second_call_args["distinct_id"] == "test_id" + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_trace_id"] == "test-trace-id" + assert second_call_props["foo"] == "bar" + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + assert second_call_props["$ai_output_state"].content == "Bar" + assert second_call_props["$ai_input_state"] == "Foo" + + +class FakeGraphState(TypedDict): + messages: list[HumanMessage | AIMessage] + xyz: Optional[str] + + +def test_graph_state(mock_client): + config = {"callbacks": [CallbackHandler(mock_client)]} + + graph = StateGraph(FakeGraphState) + graph.add_node( + "fake_plain", + lambda state: ( + { + "messages": [ + *state["messages"], + AIMessage(content="Let's explore bar."), + ], + "xyz": "abc", + } + ), + ) + graph.add_node( + "fake_llm", + lambda state: ( + ChatPromptTemplate.from_messages([("user", "Foo")]) + | FakeMessagesListChatModel( + responses=[ + *state["messages"], + AIMessage(content="It's a type of greeble."), + ] + ) + ).invoke( + state, + config=config, + ), + ) + graph.add_edge(START, "fake_plain") + graph.add_edge("fake_plain", "fake_llm") + graph.add_edge("fake_llm", END) + + result = graph.compile().invoke( + {"messages": [HumanMessage(content="What's a bar?")], "xyz": None}, + config=config, + ) + + assert len(result["messages"]) == 2 + assert isinstance(result["messages"][0], HumanMessage) + assert result["messages"][0].content == "What's a bar?" + assert isinstance(result["messages"][1], AIMessage) + assert result["messages"][1].content == "Let's explore bar." 
# TODO + + assert mock_client.capture.call_count == 3 + generation_args = mock_client.capture.call_args_list[0][1] + trace_args = mock_client.capture.call_args_list[2][1] + assert generation_args["event"] == "$ai_generation" + assert trace_args["event"] == "$ai_trace" + assert trace_args["properties"]["$ai_input_state"] == { + "messages": [HumanMessage(content="What's a bar?")], + "xyz": None, + } + assert trace_args["properties"]["$ai_output_state"] == { + "messages": [AIMessage(content="Let's explore bar.")], + "xyz": "abc", + } + def test_callbacks_logic(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")]) - callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + callbacks = CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) chain = prompt | model chain.invoke({}, config={"callbacks": [callbacks]}) @@ -360,7 +541,9 @@ def assert_intermediary_run(m): assert len(callbacks._parent_tree.items()) == 1 return [m] - (chain | RunnableLambda(assert_intermediary_run) | model).invoke({}, config={"callbacks": [callbacks]}) + (chain | RunnableLambda(assert_intermediary_run) | model).invoke( + {}, config={"callbacks": [callbacks]} + ) assert callbacks._runs == {} assert callbacks._parent_tree == {} @@ -375,7 +558,9 @@ def runnable(_): assert callbacks._runs == {} assert callbacks._parent_tree == {} - assert mock_client.capture.call_count == 0 + assert mock_client.capture.call_count == 1 + trace_call_args = mock_client.capture.call_args_list[0][1] + assert trace_call_args["event"] == "$ai_trace" def test_openai_error(mock_client): @@ -389,9 +574,9 @@ def test_openai_error(mock_client): assert callbacks._runs == {} assert callbacks._parent_tree == {} - assert mock_client.capture.call_count == 1 - args = mock_client.capture.call_args[1] - props = args["properties"] + assert mock_client.capture.call_count == 2 + generation_args = mock_client.capture.call_args_list[0][1] + props = generation_args["properties"] assert props["$ai_http_status"] == 401 assert props["$ai_input"] == [{"role": "user", "content": "Foo"}] assert "$ai_output_choices" not in props @@ -411,7 +596,12 @@ def test_openai_chain(mock_client): temperature=0, max_tokens=1, ) - callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + callbacks = CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) start_time = time.time() result = chain.invoke({}, config={"callbacks": [callbacks]}) approximate_latency = math.floor(time.time() - start_time) @@ -419,7 +609,7 @@ def test_openai_chain(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 1 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_trace_id"] == "test-trace-id" @@ -454,7 +644,11 @@ def test_openai_chain(mock_client): ] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) - assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency + assert ( + min(approximate_latency - 1, 0) + <= 
math.floor(first_call_props["$ai_latency"]) + <= approximate_latency + ) assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 @@ -480,7 +674,7 @@ def test_openai_captures_multiple_generations(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 1 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -525,7 +719,12 @@ def test_openai_streaming(mock_client): ] ) chain = prompt | ChatOpenAI( - api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True + api_key=OPENAI_API_KEY, + model="gpt-4o-mini", + temperature=0, + max_tokens=1, + stream=True, + stream_usage=True, ) callbacks = CallbackHandler(mock_client) result = [m for m in chain.stream({}, config={"callbacks": [callbacks]})] @@ -534,7 +733,7 @@ def test_openai_streaming(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 1 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_props["$ai_model_parameters"]["stream"] @@ -542,7 +741,9 @@ def test_openai_streaming(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_call_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 @@ -557,7 +758,12 @@ async def test_async_openai_streaming(mock_client): ] ) chain = prompt | ChatOpenAI( - api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True + api_key=OPENAI_API_KEY, + model="gpt-4o-mini", + temperature=0, + max_tokens=1, + stream=True, + stream_usage=True, ) callbacks = CallbackHandler(mock_client) result = [m async for m in chain.astream({}, config={"callbacks": [callbacks]})] @@ -566,7 +772,7 @@ async def test_async_openai_streaming(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 1 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_props["$ai_model_parameters"]["stream"] @@ -574,7 +780,9 @@ async def test_async_openai_streaming(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_call_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 @@ -591,9 +799,9 @@ def test_base_url_retrieval(mock_client): with pytest.raises(Exception): chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_base_url"] == "https://test.posthog.com" + 
assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_base_url"] == "https://test.posthog.com" def test_groups(mock_client): @@ -608,9 +816,9 @@ def test_groups(mock_client): callbacks = CallbackHandler(mock_client, groups={"company": "test_company"}) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["groups"] == {"company": "test_company"} + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["groups"] == {"company": "test_company"} def test_privacy_mode_local(mock_client): @@ -625,10 +833,10 @@ def test_privacy_mode_local(mock_client): callbacks = CallbackHandler(mock_client, privacy_mode=True) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_input"] is None - assert call["properties"]["$ai_output_choices"] is None + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_input"] is None + assert generation_call["properties"]["$ai_output_choices"] is None def test_privacy_mode_global(mock_client): @@ -644,10 +852,10 @@ def test_privacy_mode_global(mock_client): callbacks = CallbackHandler(mock_client) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_input"] is None - assert call["properties"]["$ai_output_choices"] is None + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_input"] is None + assert generation_call["properties"]["$ai_output_choices"] is None @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") @@ -664,7 +872,12 @@ def test_anthropic_chain(mock_client): temperature=0, max_tokens=1, ) - callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}) + callbacks = CallbackHandler( + mock_client, + trace_id="test-trace-id", + distinct_id="test_id", + properties={"foo": "bar"}, + ) start_time = time.time() result = chain.invoke({}, config={"callbacks": [callbacks]}) approximate_latency = math.floor(time.time() - start_time) @@ -672,7 +885,7 @@ def test_anthropic_chain(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 1 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_trace_id"] == "test-trace-id" @@ -689,10 +902,16 @@ def test_anthropic_chain(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_call_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) - assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency + assert ( 
+ min(approximate_latency - 1, 0) + <= math.floor(first_call_props["$ai_latency"]) + <= approximate_latency + ) assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] == 1 @@ -720,14 +939,16 @@ async def test_async_anthropic_streaming(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 1 - first_call_args = mock_client.capture.call_args[1] + first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] assert first_call_props["$ai_model_parameters"]["streaming"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_call_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar"} + ] assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] is not None @@ -758,9 +979,9 @@ def test_tool_calls(mock_client): callbacks = CallbackHandler(mock_client) chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 1 - call = mock_client.capture.call_args[1] - assert call["properties"]["$ai_output_choices"][0]["tool_calls"] == [ + assert mock_client.capture.call_count == 2 + generation_call = mock_client.capture.call_args_list[0][1] + assert generation_call["properties"]["$ai_output_choices"][0]["tool_calls"] == [ { "type": "function", "id": "123", @@ -770,4 +991,7 @@ def test_tool_calls(mock_client): }, } ] - assert "additional_kwargs" not in call["properties"]["$ai_output_choices"][0] + assert ( + "additional_kwargs" + not in generation_call["properties"]["$ai_output_choices"][0] + ) diff --git a/setup.py b/setup.py index 06fb4580..c6815c7e 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "django", "openai", "anthropic", + "langgraph", "langchain-community>=0.2.0", "langchain-openai>=0.2.0", "langchain-anthropic>=0.2.0", From e780da08fc89b2d1002623e099d48d99a7764034 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 20:39:20 +0100 Subject: [PATCH 3/8] Fix formatting --- posthog/ai/langchain/callbacks.py | 79 ++++---------- posthog/test/ai/langchain/test_callbacks.py | 110 +++++--------------- 2 files changed, 47 insertions(+), 142 deletions(-) diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index 4ea0306a..bc665917 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -1,9 +1,7 @@ try: import langchain # noqa: F401 except ImportError: - raise ModuleNotFoundError( - "Please install LangChain to use this feature: 'pip install langchain'" - ) + raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'") import json import logging @@ -126,13 +124,9 @@ def on_chat_model_start( parent_run_id: Optional[UUID] = None, **kwargs, ): - self._log_debug_event( - "on_chat_model_start", run_id, parent_run_id, messages=messages - ) + self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages) self._set_parent_of_run(run_id, parent_run_id) - input = [ - _convert_message_to_dict(message) for row in messages for message in row - ] + input = [_convert_message_to_dict(message) for row in messages for message in row] self._set_run_metadata(serialized, run_id, input, **kwargs) def on_llm_start( @@ -157,9 +151,7 
@@ def on_llm_new_token( **kwargs: Any, ) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" - self.log.debug( - f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}" - ) + self.log.debug(f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}") def on_tool_start( self, @@ -172,9 +164,7 @@ def on_tool_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._log_debug_event( - "on_tool_start", run_id, parent_run_id, input_str=input_str - ) + self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str) def on_tool_end( self, @@ -247,9 +237,7 @@ def on_llm_end( """ The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM. """ - self._log_debug_event( - "on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs - ) + self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs) trace_id = self._get_trace_id(run_id) self._pop_parent_of_run(run_id) run = self._pop_run_metadata(run_id) @@ -262,24 +250,17 @@ def on_llm_end( generation_result = response.generations[-1] if isinstance(generation_result[-1], ChatGeneration): output = [ - _convert_message_to_dict(cast(ChatGeneration, generation).message) - for generation in generation_result + _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result ] else: - output = [ - _extract_raw_esponse(generation) for generation in generation_result - ] + output = [_extract_raw_esponse(generation) for generation in generation_result] event_properties = { "$ai_provider": run.get("provider"), "$ai_model": run.get("model"), "$ai_model_parameters": run.get("model_params"), - "$ai_input": with_privacy_mode( - self._client, self._privacy_mode, run.get("messages") - ), - "$ai_output_choices": with_privacy_mode( - self._client, self._privacy_mode, output - ), + "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")), + "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output), "$ai_http_status": 200, "$ai_input_tokens": input_tokens, "$ai_output_tokens": output_tokens, @@ -318,9 +299,7 @@ def on_llm_error( "$ai_provider": run.get("provider"), "$ai_model": run.get("model"), "$ai_model_parameters": run.get("model_params"), - "$ai_input": with_privacy_mode( - self._client, self._privacy_mode, run.get("messages") - ), + "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")), "$ai_http_status": _get_http_status(error), "$ai_latency": latency, "$ai_trace_id": trace_id, @@ -450,20 +429,14 @@ def _get_trace_id(self, run_id: UUID): trace_id = uuid.uuid4() return trace_id - def _end_trace( - self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]] - ): + def _end_trace(self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]]): event_properties = { "$ai_trace_id": trace_id, - "$ai_input_state": with_privacy_mode( - self._client, self._privacy_mode, inputs - ), + "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, inputs), **self._properties, } if outputs is not None: - event_properties["$ai_output_state"] = with_privacy_mode( - self._client, self._privacy_mode, outputs - ) + event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs) if self._distinct_id is None: event_properties["$process_person_profile"] = False 
self._client.capture( @@ -545,9 +518,7 @@ def _parse_usage_model( if model_key in usage: captured_count = usage[model_key] final_count = ( - sum(captured_count) - if isinstance(captured_count, list) - else captured_count + sum(captured_count) if isinstance(captured_count, list) else captured_count ) # For Bedrock, the token count is a list when streamed parsed_usage[type_key] = final_count @@ -568,12 +539,8 @@ def _parse_usage(response: LLMResult): if hasattr(response, "generations"): for generation in response.generations: for generation_chunk in generation: - if generation_chunk.generation_info and ( - "usage_metadata" in generation_chunk.generation_info - ): - llm_usage = _parse_usage_model( - generation_chunk.generation_info["usage_metadata"] - ) + if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info): + llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"]) break message_chunk = getattr(generation_chunk, "message", {}) @@ -585,19 +552,13 @@ def _parse_usage(response: LLMResult): else None ) bedrock_titan_usage = ( - response_metadata.get( - "amazon-bedrock-invocationMetrics", None - ) # for Bedrock-Titan + response_metadata.get("amazon-bedrock-invocationMetrics", None) # for Bedrock-Titan if isinstance(response_metadata, dict) else None ) - ollama_usage = getattr( - message_chunk, "usage_metadata", None - ) # for Ollama + ollama_usage = getattr(message_chunk, "usage_metadata", None) # for Ollama - chunk_usage = ( - bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage - ) + chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage if chunk_usage: llm_usage = _parse_usage_model(chunk_usage) break diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index 84803c31..3eb3d97d 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -3,7 +3,7 @@ import os from pyexpat.errors import messages import time -from typing import Optional, TypedDict +from typing import List, Optional, TypedDict, Union import uuid from unittest.mock import patch, ANY @@ -169,9 +169,7 @@ async def test_async_basic_chat_chain(mock_client, stream): callbacks = [CallbackHandler(mock_client)] chain = prompt | model if stream: - result = [m async for m in chain.astream({}, config={"callbacks": callbacks})][ - 0 - ] + result = [m async for m in chain.astream({}, config={"callbacks": callbacks})][0] else: result = await chain.ainvoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." @@ -218,21 +216,14 @@ async def test_async_basic_chat_chain(mock_client, stream): ) def test_basic_llm_chain(mock_client, Model, stream): model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."]) - callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)] + callbacks: List[CallbackHandler] = [CallbackHandler(mock_client)] if stream: result = "".join( - [ - m - for m in model.stream( - "Who won the world series in 2020?", config={"callbacks": callbacks} - ) - ] + [m for m in model.stream("Who won the world series in 2020?", config={"callbacks": callbacks})] ) else: - result = model.invoke( - "Who won the world series in 2020?", config={"callbacks": callbacks} - ) + result = model.invoke("Who won the world series in 2020?", config={"callbacks": callbacks}) assert result == "The Los Angeles Dodgers won the World Series in 2020." 
assert mock_client.capture.call_count == 1 @@ -244,9 +235,7 @@ def test_basic_llm_chain(mock_client, Model, stream): assert "$ai_model" in props assert "$ai_provider" in props assert props["$ai_input"] == ["Who won the world series in 2020?"] - assert props["$ai_output_choices"] == [ - "The Los Angeles Dodgers won the World Series in 2020." - ] + assert props["$ai_output_choices"] == ["The Los Angeles Dodgers won the World Series in 2020."] assert props["$ai_http_status"] == 200 assert props["$ai_trace_id"] is not None assert isinstance(props["$ai_latency"], float) @@ -263,21 +252,14 @@ def test_basic_llm_chain(mock_client, Model, stream): ) async def test_async_basic_llm_chain(mock_client, Model, stream): model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."]) - callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)] + callbacks: List[CallbackHandler] = [CallbackHandler(mock_client)] if stream: result = "".join( - [ - m - async for m in model.astream( - "Who won the world series in 2020?", config={"callbacks": callbacks} - ) - ] + [m async for m in model.astream("Who won the world series in 2020?", config={"callbacks": callbacks})] ) else: - result = await model.ainvoke( - "Who won the world series in 2020?", config={"callbacks": callbacks} - ) + result = await model.ainvoke("Who won the world series in 2020?", config={"callbacks": callbacks}) assert result == "The Los Angeles Dodgers won the World Series in 2020." assert mock_client.capture.call_count == 1 @@ -289,9 +271,7 @@ async def test_async_basic_llm_chain(mock_client, Model, stream): assert "$ai_model" in props assert "$ai_provider" in props assert props["$ai_input"] == ["Who won the world series in 2020?"] - assert props["$ai_output_choices"] == [ - "The Los Angeles Dodgers won the World Series in 2020." 
- ] + assert props["$ai_output_choices"] == ["The Los Angeles Dodgers won the World Series in 2020."] assert props["$ai_http_status"] == 200 assert props["$ai_trace_id"] is not None assert isinstance(props["$ai_latency"], float) @@ -318,9 +298,7 @@ def test_trace_id_for_multiple_chains(mock_client): assert "$ai_model" in first_generation_props assert "$ai_provider" in first_generation_props assert first_generation_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_generation_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert first_generation_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert first_generation_props["$ai_http_status"] == 200 assert first_generation_props["$ai_trace_id"] is not None assert isinstance(first_generation_props["$ai_latency"], float) @@ -331,12 +309,8 @@ def test_trace_id_for_multiple_chains(mock_client): assert "distinct_id" in second_generation_args assert "$ai_model" in second_generation_props assert "$ai_provider" in second_generation_props - assert second_generation_props["$ai_input"] == [ - {"role": "assistant", "content": "Bar"} - ] - assert second_generation_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert second_generation_props["$ai_input"] == [{"role": "assistant", "content": "Bar"}] + assert second_generation_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert second_generation_props["$ai_http_status"] == 200 assert second_generation_props["$ai_trace_id"] is not None assert isinstance(second_generation_props["$ai_latency"], float) @@ -350,10 +324,7 @@ def test_trace_id_for_multiple_chains(mock_client): assert trace_props["$ai_trace_id"] is not None # Check that the trace_id is the same as the first call - assert ( - first_generation_props["$ai_trace_id"] - == second_generation_props["$ai_trace_id"] - ) + assert first_generation_props["$ai_trace_id"] == second_generation_props["$ai_trace_id"] assert first_generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"] @@ -370,9 +341,7 @@ def test_personless_mode(mock_client): assert trace_args["properties"]["$process_person_profile"] is False id = uuid.uuid4() - chain.invoke( - {}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]} - ) + chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]}) assert mock_client.capture.call_count == 4 generation_args = mock_client.capture.call_args_list[2][1] trace_args = mock_client.capture.call_args_list[3][1] @@ -398,9 +367,7 @@ def test_personless_mode_exception(mock_client): id = uuid.uuid4() with pytest.raises(Exception): - chain.invoke( - {}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]} - ) + chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]}) assert mock_client.capture.call_count == 4 generation_args = mock_client.capture.call_args_list[2][1] trace_args = mock_client.capture.call_args_list[3][1] @@ -438,9 +405,7 @@ def test_metadata(mock_client): assert first_call_props["$ai_trace_id"] == "test-trace-id" assert first_call_props["foo"] == "bar" assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_call_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert first_call_props["$ai_http_status"] == 200 assert 
isinstance(first_call_props["$ai_latency"], float) @@ -456,7 +421,7 @@ def test_metadata(mock_client): class FakeGraphState(TypedDict): - messages: list[HumanMessage | AIMessage] + messages: List[Union[HumanMessage, AIMessage]] xyz: Optional[str] @@ -504,7 +469,7 @@ def test_graph_state(mock_client): assert isinstance(result["messages"][0], HumanMessage) assert result["messages"][0].content == "What's a bar?" assert isinstance(result["messages"][1], AIMessage) - assert result["messages"][1].content == "Let's explore bar." # TODO + assert result["messages"][1].content == "Let's explore bar." # TODO assert mock_client.capture.call_count == 3 generation_args = mock_client.capture.call_args_list[0][1] @@ -541,9 +506,7 @@ def assert_intermediary_run(m): assert len(callbacks._parent_tree.items()) == 1 return [m] - (chain | RunnableLambda(assert_intermediary_run) | model).invoke( - {}, config={"callbacks": [callbacks]} - ) + (chain | RunnableLambda(assert_intermediary_run) | model).invoke({}, config={"callbacks": [callbacks]}) assert callbacks._runs == {} assert callbacks._parent_tree == {} @@ -644,11 +607,7 @@ def test_openai_chain(mock_client): ] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) - assert ( - min(approximate_latency - 1, 0) - <= math.floor(first_call_props["$ai_latency"]) - <= approximate_latency - ) + assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 @@ -741,9 +700,7 @@ def test_openai_streaming(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 @@ -780,9 +737,7 @@ async def test_async_openai_streaming(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 @@ -902,16 +857,10 @@ def test_anthropic_chain(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) - assert ( - min(approximate_latency - 1, 0) - <= math.floor(first_call_props["$ai_latency"]) - <= approximate_latency - ) + assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] == 1 @@ -946,9 +895,7 @@ async def test_async_anthropic_streaming(mock_client): {"role": "system", "content": 'You must 
always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - {"role": "assistant", "content": "Bar"} - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] assert first_call_props["$ai_http_status"] == 200 assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] is not None @@ -991,7 +938,4 @@ def test_tool_calls(mock_client): }, } ] - assert ( - "additional_kwargs" - not in generation_call["properties"]["$ai_output_choices"][0] - ) + assert "additional_kwargs" not in generation_call["properties"]["$ai_output_choices"][0] From e81d61d30db61820034d99f3fab2b9b4f4175542 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 21:53:36 +0100 Subject: [PATCH 4/8] Fix input capture --- posthog/ai/langchain/callbacks.py | 74 +++++++--- posthog/test/ai/langchain/test_callbacks.py | 149 ++++++++++++-------- 2 files changed, 143 insertions(+), 80 deletions(-) diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index bc665917..941d012c 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -3,7 +3,6 @@ except ImportError: raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'") -import json import logging import time import uuid @@ -59,14 +58,25 @@ class CallbackHandler(BaseCallbackHandler): _client: Client """PostHog client instance.""" + _distinct_id: Optional[Union[str, int, float, UUID]] """Distinct ID of the user to associate the trace with.""" + _trace_id: Optional[Union[str, int, float, UUID]] """Global trace ID to be sent with every event. Otherwise, the top-level run ID is used.""" + + _trace_input: Optional[Any] + """The input at the start of the trace. Any JSON object.""" + + _trace_name: Optional[str] + """Name of the trace, exposed in the UI.""" + _properties: Optional[Dict[str, Any]] """Global properties to be sent with every event.""" + _runs: RunStorage """Mapping of run IDs to run metadata as run metadata is only available on the start of generation.""" + _parent_tree: Dict[UUID, UUID] """ A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree), @@ -95,6 +105,8 @@ def __init__( self._client = client or default_client self._distinct_id = distinct_id self._trace_id = trace_id + self._trace_name = None + self._trace_input = None self._properties = properties or {} self._privacy_mode = privacy_mode self._groups = groups or {} @@ -113,7 +125,9 @@ def on_chain_start( ): self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs) self._set_parent_of_run(run_id, parent_run_id) - self._set_run_metadata(serialized, run_id, inputs, metadata, **kwargs) + if parent_run_id is None and self._trace_name is None: + self._trace_name = self._get_langchain_run_name(serialized, **kwargs) + self._trace_input = inputs def on_chat_model_start( self, @@ -151,7 +165,7 @@ def on_llm_new_token( **kwargs: Any, ) -> Any: """Run on new LLM token. 
Only available when streaming is enabled.""" - self.log.debug(f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}") + self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token) def on_tool_start( self, @@ -160,7 +174,6 @@ def on_tool_start( *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: @@ -192,19 +205,13 @@ def on_chain_end( *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, **kwargs: Any, ): self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs) self._pop_parent_of_run(run_id) - run_metadata = self._pop_run_metadata(run_id) if parent_run_id is None: - self._end_trace( - self._get_trace_id(run_id), - inputs=run_metadata.get("messages") if run_metadata else None, - outputs=outputs, - ) + self._capture_trace(run_id, outputs=outputs) def on_chain_error( self, @@ -216,14 +223,9 @@ def on_chain_error( ): self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) self._pop_parent_of_run(run_id) - run_metadata = self._pop_run_metadata(run_id) if parent_run_id is None: - self._end_trace( - self._get_trace_id(run_id), - inputs=run_metadata.get("messages") if run_metadata else None, - outputs=None, - ) + self._capture_trace(run_id, outputs=None) def on_llm_end( self, @@ -231,7 +233,6 @@ def on_llm_end( *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, **kwargs: Any, ): """ @@ -284,7 +285,6 @@ def on_llm_error( *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, **kwargs: Any, ): self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) @@ -322,7 +322,6 @@ def on_retriever_start( *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: @@ -429,10 +428,41 @@ def _get_trace_id(self, run_id: UUID): trace_id = uuid.uuid4() return trace_id - def _end_trace(self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]]): + def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str: + """Retrieve the name of a serialized LangChain runnable. + + The prioritization for the determination of the run name is as follows: + - The value assigned to the "name" key in `kwargs`. + - The value assigned to the "name" key in `serialized`. + - The last entry of the value assigned to the "id" key in `serialized`. + - "". + + Args: + serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data. + **kwargs (Any): Additional keyword arguments, potentially including the 'name' override. + + Returns: + str: The determined name of the Langchain runnable. 
+ """ + if "name" in kwargs and kwargs["name"] is not None: + return kwargs["name"] + + try: + return serialized["name"] + except (KeyError, TypeError): + pass + + try: + return serialized["id"][-1] + except (KeyError, TypeError): + pass + + def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]): + trace_id = self._get_trace_id(run_id) event_properties = { + "$ai_trace_name": self._trace_name, "$ai_trace_id": trace_id, - "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, inputs), + "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input), **self._properties, } if outputs is not None: diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index 3eb3d97d..add7c5c9 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -319,8 +319,9 @@ def test_trace_id_for_multiple_chains(mock_client): trace_props = trace_args["properties"] assert trace_args["event"] == "$ai_trace" assert "distinct_id" in trace_args - assert trace_props["$ai_input_state"] == [{"role": "assistant", "content": "Bar"}] - assert trace_props["$ai_output_state"] == [{"role": "assistant", "content": "Bar"}] + assert trace_props["$ai_input_state"] == {} + assert isinstance(trace_props["$ai_output_state"], AIMessage) + assert trace_props["$ai_output_state"].content == "Bar" assert trace_props["$ai_trace_id"] is not None # Check that the trace_id is the same as the first call @@ -393,31 +394,31 @@ def test_metadata(mock_client): ) ] chain = prompt | model - result = chain.invoke({}, config={"callbacks": callbacks}) + result = chain.invoke({"plan": None}, config={"callbacks": callbacks}) assert result.content == "Bar" assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - assert first_call_args["distinct_id"] == "test_id" - assert first_call_args["event"] == "$ai_generation" - assert first_call_props["$ai_trace_id"] == "test-trace-id" - assert first_call_props["foo"] == "bar" - assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert first_call_props["$ai_http_status"] == 200 - assert isinstance(first_call_props["$ai_latency"], float) - - second_call_args = mock_client.capture.call_args_list[1][1] - second_call_props = second_call_args["properties"] - assert second_call_args["distinct_id"] == "test_id" - assert second_call_args["event"] == "$ai_trace" - assert second_call_props["$ai_trace_id"] == "test-trace-id" - assert second_call_props["foo"] == "bar" - assert isinstance(second_call_props["$ai_output_state"], AIMessage) - assert second_call_props["$ai_output_state"].content == "Bar" - assert second_call_props["$ai_input_state"] == "Foo" + generation_call_args = mock_client.capture.call_args_list[0][1] + generation_call_props = generation_call_args["properties"] + assert generation_call_args["distinct_id"] == "test_id" + assert generation_call_args["event"] == "$ai_generation" + assert generation_call_props["$ai_trace_id"] == "test-trace-id" + assert generation_call_props["foo"] == "bar" + assert generation_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] + assert generation_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert generation_call_props["$ai_http_status"] == 200 + assert 
isinstance(generation_call_props["$ai_latency"], float) + + trace_call_args = mock_client.capture.call_args_list[1][1] + trace_call_props = trace_call_args["properties"] + assert trace_call_args["distinct_id"] == "test_id" + assert trace_call_args["event"] == "$ai_trace" + assert trace_call_props["$ai_trace_id"] == "test-trace-id" + assert trace_call_props["foo"] == "bar" + assert trace_call_props["$ai_input_state"] == {"plan": None} + assert isinstance(trace_call_props["$ai_output_state"], AIMessage) + assert trace_call_props["$ai_output_state"].content == "Bar" class FakeGraphState(TypedDict): @@ -469,21 +470,24 @@ def test_graph_state(mock_client): assert isinstance(result["messages"][0], HumanMessage) assert result["messages"][0].content == "What's a bar?" assert isinstance(result["messages"][1], AIMessage) - assert result["messages"][1].content == "Let's explore bar." # TODO + assert result["messages"][1].content == "Let's explore bar." assert mock_client.capture.call_count == 3 generation_args = mock_client.capture.call_args_list[0][1] trace_args = mock_client.capture.call_args_list[2][1] assert generation_args["event"] == "$ai_generation" assert trace_args["event"] == "$ai_trace" - assert trace_args["properties"]["$ai_input_state"] == { - "messages": [HumanMessage(content="What's a bar?")], - "xyz": None, - } - assert trace_args["properties"]["$ai_output_state"] == { - "messages": [AIMessage(content="Let's explore bar.")], - "xyz": "abc", - } + assert len(trace_args["properties"]["$ai_input_state"]["messages"]) == 1 + assert isinstance(trace_args["properties"]["$ai_input_state"]["messages"][0], HumanMessage) + assert trace_args["properties"]["$ai_input_state"]["messages"][0].content == "What's a bar?" + assert trace_args["properties"]["$ai_input_state"]["messages"][0].type == "human" + assert trace_args["properties"]["$ai_input_state"]["xyz"] is None + assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 2 + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage) # FIXME + assert trace_args["properties"]["$ai_output_state"]["messages"][0].content == "What's a bar?" + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][1], AIMessage) + assert trace_args["properties"]["$ai_output_state"]["messages"][1].content == "Let's explore bar." 
+ assert trace_args["properties"]["$ai_output_state"]["xyz"] == "abc" def test_callbacks_logic(mock_client): @@ -570,7 +574,7 @@ def test_openai_chain(mock_client): approximate_latency = math.floor(time.time() - start_time) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] @@ -598,13 +602,7 @@ def test_openai_chain(mock_client): {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - { - "role": "assistant", - "content": "Bar", - "additional_kwargs": {"refusal": None}, - } - ] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar", "refusal": None}] assert first_call_props["$ai_http_status"] == 200 assert isinstance(first_call_props["$ai_latency"], float) assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency @@ -631,20 +629,20 @@ def test_openai_captures_multiple_generations(mock_client): result = chain.invoke({}, config={"callbacks": [callbacks]}) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - assert first_call_props["$ai_input"] == [ + generation_call_args = mock_client.capture.call_args_list[0][1] + generation_call_props = generation_call_args["properties"] + trace_call_args = mock_client.capture.call_args_list[1][1] + trace_call_props = trace_call_args["properties"] + + assert generation_call_args["event"] == "$ai_generation" + assert generation_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ - { - "role": "assistant", - "content": "Bar", - "additional_kwargs": {"refusal": None}, - }, + assert generation_call_props["$ai_output_choices"] == [ + {"role": "assistant", "content": "Bar", "refusal": None}, { "role": "assistant", "content": "Bar", @@ -652,21 +650,25 @@ def test_openai_captures_multiple_generations(mock_client): ] # langchain-openai for langchain v3 - if "max_completion_tokens" in first_call_props["$ai_model_parameters"]: - assert first_call_props["$ai_model_parameters"] == { + if "max_completion_tokens" in generation_call_props["$ai_model_parameters"]: + assert generation_call_props["$ai_model_parameters"] == { "temperature": 0.0, "max_completion_tokens": 1, "stream": False, "n": 2, } else: - assert first_call_props["$ai_model_parameters"] == { + assert generation_call_props["$ai_model_parameters"] == { "temperature": 0.0, "max_tokens": 1, "stream": False, "n": 2, } - assert first_call_props["$ai_http_status"] == 200 + assert generation_call_props["$ai_http_status"] == 200 + + assert trace_call_args["event"] == "$ai_trace" + assert trace_call_props["$ai_input_state"] == {} + assert isinstance(trace_call_props["$ai_output_state"], AIMessage) @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") @@ -690,11 +692,14 @@ def test_openai_streaming(mock_client): result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 first_call_args = mock_client.capture.call_args_list[0][1] 
first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_model_parameters"]["stream"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -705,6 +710,10 @@ def test_openai_streaming(mock_client): assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {"input": ""} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") async def test_async_openai_streaming(mock_client): @@ -727,11 +736,14 @@ async def test_async_openai_streaming(mock_client): result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_model_parameters"]["stream"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -742,6 +754,10 @@ async def test_async_openai_streaming(mock_client): assert first_call_props["$ai_input_tokens"] == 20 assert first_call_props["$ai_output_tokens"] == 1 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {"input": ""} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + def test_base_url_retrieval(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) @@ -838,10 +854,13 @@ def test_anthropic_chain(mock_client): approximate_latency = math.floor(time.time() - start_time) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_trace_id"] == "test-trace-id" assert first_call_props["$ai_provider"] == "anthropic" @@ -864,6 +883,10 @@ def test_anthropic_chain(mock_client): assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] == 1 + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") async def test_async_anthropic_streaming(mock_client): @@ -886,10 +909,14 @@ async def test_async_anthropic_streaming(mock_client): result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 1 + assert mock_client.capture.call_count == 2 first_call_args = mock_client.capture.call_args_list[0][1] first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = 
second_call_args["properties"] + + assert first_call_args["event"] == "$ai_generation" assert first_call_props["$ai_model_parameters"]["streaming"] assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, @@ -900,6 +927,12 @@ async def test_async_anthropic_streaming(mock_client): assert first_call_props["$ai_input_tokens"] == 17 assert first_call_props["$ai_output_tokens"] is not None + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == { + "input": "", + } + assert isinstance(second_call_props["$ai_output_state"], AIMessage) + def test_tool_calls(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) From 7d77b5830a6b453a452c166ced9d7cc80373d866 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 21:55:43 +0100 Subject: [PATCH 5/8] Keep some vars in tests unchanged --- posthog/test/ai/langchain/test_callbacks.py | 54 ++++++++++----------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index add7c5c9..e0839713 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -291,17 +291,17 @@ def test_trace_id_for_multiple_chains(mock_client): assert result.content == "Bar" assert mock_client.capture.call_count == 3 - first_generation_args = mock_client.capture.call_args_list[0][1] - first_generation_props = first_generation_args["properties"] - assert first_generation_args["event"] == "$ai_generation" - assert "distinct_id" in first_generation_args - assert "$ai_model" in first_generation_props - assert "$ai_provider" in first_generation_props - assert first_generation_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_generation_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert first_generation_props["$ai_http_status"] == 200 - assert first_generation_props["$ai_trace_id"] is not None - assert isinstance(first_generation_props["$ai_latency"], float) + first_call_args = mock_client.capture.call_args_list[0][1] + first_call_props = first_call_args["properties"] + assert first_call_args["event"] == "$ai_generation" + assert "distinct_id" in first_call_args + assert "$ai_model" in first_call_props + assert "$ai_provider" in first_call_props + assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] + assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_call_props["$ai_http_status"] == 200 + assert first_call_props["$ai_trace_id"] is not None + assert isinstance(first_call_props["$ai_latency"], float) second_generation_args = mock_client.capture.call_args_list[1][1] second_generation_props = second_generation_args["properties"] @@ -325,8 +325,8 @@ def test_trace_id_for_multiple_chains(mock_client): assert trace_props["$ai_trace_id"] is not None # Check that the trace_id is the same as the first call - assert first_generation_props["$ai_trace_id"] == second_generation_props["$ai_trace_id"] - assert first_generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + assert first_call_props["$ai_trace_id"] == second_generation_props["$ai_trace_id"] + assert first_call_props["$ai_trace_id"] == trace_props["$ai_trace_id"] def test_personless_mode(mock_client): @@ -631,17 +631,17 @@ def test_openai_captures_multiple_generations(mock_client): assert result.content == "Bar" assert 
mock_client.capture.call_count == 2 - generation_call_args = mock_client.capture.call_args_list[0][1] - generation_call_props = generation_call_args["properties"] - trace_call_args = mock_client.capture.call_args_list[1][1] - trace_call_props = trace_call_args["properties"] + first_call_args = mock_client.capture.call_args_list[0][1] + first_call_props = first_call_args["properties"] + second_call_args = mock_client.capture.call_args_list[1][1] + second_call_props = second_call_args["properties"] - assert generation_call_args["event"] == "$ai_generation" - assert generation_call_props["$ai_input"] == [ + assert first_call_args["event"] == "$ai_generation" + assert first_call_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert generation_call_props["$ai_output_choices"] == [ + assert first_call_props["$ai_output_choices"] == [ {"role": "assistant", "content": "Bar", "refusal": None}, { "role": "assistant", @@ -650,25 +650,25 @@ def test_openai_captures_multiple_generations(mock_client): ] # langchain-openai for langchain v3 - if "max_completion_tokens" in generation_call_props["$ai_model_parameters"]: - assert generation_call_props["$ai_model_parameters"] == { + if "max_completion_tokens" in first_call_props["$ai_model_parameters"]: + assert first_call_props["$ai_model_parameters"] == { "temperature": 0.0, "max_completion_tokens": 1, "stream": False, "n": 2, } else: - assert generation_call_props["$ai_model_parameters"] == { + assert first_call_props["$ai_model_parameters"] == { "temperature": 0.0, "max_tokens": 1, "stream": False, "n": 2, } - assert generation_call_props["$ai_http_status"] == 200 + assert first_call_props["$ai_http_status"] == 200 - assert trace_call_args["event"] == "$ai_trace" - assert trace_call_props["$ai_input_state"] == {} - assert isinstance(trace_call_props["$ai_output_state"], AIMessage) + assert second_call_args["event"] == "$ai_trace" + assert second_call_props["$ai_input_state"] == {} + assert isinstance(second_call_props["$ai_output_state"], AIMessage) @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") From d57e3f2f198072ce5d385430e96bdf50e6b90079 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 22:09:20 +0100 Subject: [PATCH 6/8] Add $ai_trace_name assertions --- posthog/test/ai/langchain/test_callbacks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index e0839713..4e9d74bf 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -323,6 +323,7 @@ def test_trace_id_for_multiple_chains(mock_client): assert isinstance(trace_props["$ai_output_state"], AIMessage) assert trace_props["$ai_output_state"].content == "Bar" assert trace_props["$ai_trace_id"] is not None + assert trace_props["$ai_trace_name"] == "RunnableSequence" # Check that the trace_id is the same as the first call assert first_call_props["$ai_trace_id"] == second_generation_props["$ai_trace_id"] @@ -415,6 +416,7 @@ def test_metadata(mock_client): assert trace_call_args["distinct_id"] == "test_id" assert trace_call_args["event"] == "$ai_trace" assert trace_call_props["$ai_trace_id"] == "test-trace-id" + assert trace_call_props["$ai_trace_name"] == "RunnableSequence" assert trace_call_props["foo"] == "bar" assert trace_call_props["$ai_input_state"] == {"plan": None} assert isinstance(trace_call_props["$ai_output_state"], 
AIMessage) @@ -477,6 +479,7 @@ def test_graph_state(mock_client): trace_args = mock_client.capture.call_args_list[2][1] assert generation_args["event"] == "$ai_generation" assert trace_args["event"] == "$ai_trace" + assert trace_args["properties"]["$ai_trace_name"] == "LangGraph" assert len(trace_args["properties"]["$ai_input_state"]["messages"]) == 1 assert isinstance(trace_args["properties"]["$ai_input_state"]["messages"][0], HumanMessage) assert trace_args["properties"]["$ai_input_state"]["messages"][0].content == "What's a bar?" @@ -528,6 +531,7 @@ def runnable(_): assert mock_client.capture.call_count == 1 trace_call_args = mock_client.capture.call_args_list[0][1] assert trace_call_args["event"] == "$ai_trace" + assert trace_call_args["properties"]["$ai_trace_name"] == "runnable" def test_openai_error(mock_client): From 93977f9a9871b9ec712e37e8194cc16398f861c6 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 22:12:48 +0100 Subject: [PATCH 7/8] Make flake8 happy --- .github/workflows/ci.yml | 2 +- posthog/ai/langchain/callbacks.py | 13 +++---------- posthog/test/ai/langchain/test_callbacks.py | 9 ++++----- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bd2abdbb..d3c96ddd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: - name: Lint with flake8 run: | - flake8 posthog --ignore E501 + flake8 posthog --ignore E501,W503 - name: Check import order with isort run: | diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index 941d012c..679182ad 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -19,21 +19,14 @@ from uuid import UUID from langchain.callbacks.base import BaseCallbackHandler -from langchain_core.messages import ( - AIMessage, - BaseMessage, - FunctionMessage, - HumanMessage, - SystemMessage, - ToolMessage, -) -from langchain_core.outputs import ChatGeneration, LLMResult from langchain.schema.agent import AgentAction, AgentFinish +from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.outputs import ChatGeneration, LLMResult from pydantic import BaseModel +from posthog import default_client from posthog.ai.utils import get_model_params, with_privacy_mode from posthog.client import Client -from posthog import default_client log = logging.getLogger("posthog") diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index 4e9d74bf..e036e257 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -1,21 +1,20 @@ import logging import math import os -from pyexpat.errors import messages import time -from typing import List, Optional, TypedDict, Union import uuid -from unittest.mock import patch, ANY +from typing import List, Optional, TypedDict, Union +from unittest.mock import patch import pytest from langchain_anthropic.chat_models import ChatAnthropic from langchain_community.chat_models.fake import FakeMessagesListChatModel from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM -from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda from langchain_openai.chat_models import ChatOpenAI -from 
langgraph.graph.state import StateGraph, START, END +from langgraph.graph.state import END, START, StateGraph from posthog.ai.langchain import CallbackHandler From d7da8f47221e676c871ebd78dec7a9971831fe75 Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Tue, 21 Jan 2025 22:38:39 +0100 Subject: [PATCH 8/8] Fix graph test --- posthog/test/ai/langchain/test_callbacks.py | 58 ++++++++++++--------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index e036e257..87517e1c 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -433,30 +433,30 @@ def test_graph_state(mock_client): graph = StateGraph(FakeGraphState) graph.add_node( "fake_plain", - lambda state: ( - { - "messages": [ - *state["messages"], - AIMessage(content="Let's explore bar."), - ], - "xyz": "abc", - } - ), + lambda state: { + "messages": [ + *state["messages"], + AIMessage(content="Let's explore bar."), + ], + "xyz": "abc", + }, + ) + intermediate_chain = ChatPromptTemplate.from_messages( + [("user", "Question: What's a bar?")] + ) | FakeMessagesListChatModel( + responses=[ + AIMessage(content="It's a type of greeble."), + ] ) graph.add_node( "fake_llm", - lambda state: ( - ChatPromptTemplate.from_messages([("user", "Foo")]) - | FakeMessagesListChatModel( - responses=[ - *state["messages"], - AIMessage(content="It's a type of greeble."), - ] - ) - ).invoke( - state, - config=config, - ), + lambda state: { + "messages": [ + *state["messages"], + intermediate_chain.invoke(state), + ], + "xyz": state["xyz"], + }, ) graph.add_edge(START, "fake_plain") graph.add_edge("fake_plain", "fake_llm") @@ -467,28 +467,34 @@ def test_graph_state(mock_client): config=config, ) - assert len(result["messages"]) == 2 + assert len(result["messages"]) == 3 assert isinstance(result["messages"][0], HumanMessage) assert result["messages"][0].content == "What's a bar?" assert isinstance(result["messages"][1], AIMessage) assert result["messages"][1].content == "Let's explore bar." + assert isinstance(result["messages"][2], AIMessage) + assert result["messages"][2].content == "It's a type of greeble." - assert mock_client.capture.call_count == 3 + assert mock_client.capture.call_count == 2 generation_args = mock_client.capture.call_args_list[0][1] - trace_args = mock_client.capture.call_args_list[2][1] + trace_args = mock_client.capture.call_args_list[1][1] assert generation_args["event"] == "$ai_generation" assert trace_args["event"] == "$ai_trace" assert trace_args["properties"]["$ai_trace_name"] == "LangGraph" + assert len(trace_args["properties"]["$ai_input_state"]["messages"]) == 1 assert isinstance(trace_args["properties"]["$ai_input_state"]["messages"][0], HumanMessage) assert trace_args["properties"]["$ai_input_state"]["messages"][0].content == "What's a bar?" assert trace_args["properties"]["$ai_input_state"]["messages"][0].type == "human" assert trace_args["properties"]["$ai_input_state"]["xyz"] is None - assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 2 - assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage) # FIXME + assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 3 + + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage) assert trace_args["properties"]["$ai_output_state"]["messages"][0].content == "What's a bar?" 
assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][1], AIMessage) assert trace_args["properties"]["$ai_output_state"]["messages"][1].content == "Let's explore bar." + assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][2], AIMessage) + assert trace_args["properties"]["$ai_output_state"]["messages"][2].content == "It's a type of greeble." assert trace_args["properties"]["$ai_output_state"]["xyz"] == "abc"
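
As a rough usage sketch of the handler these patches flesh out, the snippet below mirrors the fake-model pattern used throughout the tests; the PostHog project API key, host, distinct ID, trace ID, and extra properties are illustrative placeholders rather than values prescribed by the patches.

from langchain_community.chat_models.fake import FakeMessagesListChatModel
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate

from posthog import Posthog
from posthog.ai.langchain import CallbackHandler

# Placeholder credentials; substitute a real project API key and host.
posthog_client = Posthog("<project-api-key>", host="https://us.i.posthog.com")

handler = CallbackHandler(
    posthog_client,               # may be omitted to fall back to the default client
    distinct_id="user-123",       # optional: ties the captured events to a person
    trace_id="example-trace-id",  # optional: otherwise the top-level run ID is used
    properties={"team": "growth"},  # merged into every captured event
)

prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
chain = prompt | model

result = chain.invoke({}, config={"callbacks": [handler]})
assert result.content == "Bar"

# As the updated tests assert, this captures one $ai_generation event for the model
# call and, when the top-level run ends, one $ai_trace event carrying $ai_trace_id,
# $ai_trace_name ("RunnableSequence" for a prompt | model chain), $ai_input_state,
# and $ai_output_state.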
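
The LangGraph test fixed in the last patch also shows what lands in $ai_input_state and $ai_output_state: the state dict passed to invoke and the final state the graph returns. Below is a condensed sketch of that shape, with the graph reduced to a single plain node and the client construction assumed rather than taken from the patches.

from typing import List, Optional, TypedDict, Union

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.state import END, START, StateGraph

from posthog import Posthog
from posthog.ai.langchain import CallbackHandler


class State(TypedDict):
    messages: List[Union[HumanMessage, AIMessage]]
    xyz: Optional[str]


def fake_plain(state: State) -> State:
    # Append a canned reply and set the extra field, as in the test's first node.
    return {
        "messages": [*state["messages"], AIMessage(content="Let's explore bar.")],
        "xyz": "abc",
    }


graph = StateGraph(State)
graph.add_node("fake_plain", fake_plain)
graph.add_edge(START, "fake_plain")
graph.add_edge("fake_plain", END)

# Placeholder client, as above.
handler = CallbackHandler(Posthog("<project-api-key>", host="https://us.i.posthog.com"))

result = graph.compile().invoke(
    {"messages": [HumanMessage(content="What's a bar?")], "xyz": None},
    config={"callbacks": [handler]},
)

# With no model node in this sketch there is no $ai_generation event; the $ai_trace
# event reports $ai_trace_name == "LangGraph", $ai_input_state == the dict passed to
# invoke, and $ai_output_state == the final graph state (both messages plus xyz="abc").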