2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -36,7 +36,7 @@ jobs:

- name: Lint with flake8
run: |
flake8 posthog --ignore E501
flake8 posthog --ignore E501,W503

- name: Check import order with isort
run: |
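For context: flake8's W503 warns on a line break placed before a binary operator. PEP 8 now prefers that break style (and formatters such as Black produce it), which is the usual reason to suppress W503; the PR itself does not state the rationale. A minimal sketch of code that would otherwise trip the warning:

    first_value, second_value = 1, 2
    total = (first_value
             + second_value)  # W503: line break before binary operator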
212 changes: 195 additions & 17 deletions posthog/ai/langchain/callbacks.py
@@ -19,10 +19,12 @@
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.agent import AgentAction, AgentFinish
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from pydantic import BaseModel

from posthog import default_client
from posthog.ai.utils import get_model_params, with_privacy_mode
from posthog.client import Client

@@ -44,19 +46,30 @@ class RunMetadata(TypedDict, total=False):

class CallbackHandler(BaseCallbackHandler):
"""
A callback handler for LangChain that sends events to PostHog LLM Observability.
The PostHog LLM observability callback handler for LangChain.
"""

_client: Client
"""PostHog client instance."""

_distinct_id: Optional[Union[str, int, float, UUID]]
"""Distinct ID of the user to associate the trace with."""

_trace_id: Optional[Union[str, int, float, UUID]]
"""Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""

_trace_input: Optional[Any]
"""The input at the start of the trace. Any JSON object."""

_trace_name: Optional[str]
"""Name of the trace, exposed in the UI."""

_properties: Optional[Dict[str, Any]]
"""Global properties to be sent with every event."""

_runs: RunStorage
"""Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""

_parent_tree: Dict[UUID, UUID]
"""
A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
@@ -65,7 +78,8 @@ class CallbackHandler(BaseCallbackHandler):

def __init__(
self,
client: Client,
client: Optional[Client] = None,
*,
distinct_id: Optional[Union[str, int, float, UUID]] = None,
trace_id: Optional[Union[str, int, float, UUID]] = None,
properties: Optional[Dict[str, Any]] = None,
@@ -81,9 +95,11 @@ def __init__(
privacy_mode: Whether to redact the input and output of the trace.
groups: Optional additional PostHog groups to use for the trace.
"""
self._client = client
self._client = client or default_client
self._distinct_id = distinct_id
self._trace_id = trace_id
self._trace_name = None
self._trace_input = None
self._properties = properties or {}
self._privacy_mode = privacy_mode
self._groups = groups or {}
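For illustration, a minimal sketch of wiring the handler into a chain. The key, model, and IDs are hypothetical, and the import assumes `CallbackHandler` is re-exported from `posthog.ai.langchain`:

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI
    from posthog import Posthog
    from posthog.ai.langchain import CallbackHandler

    posthog_client = Posthog("phc_project_key", host="https://us.i.posthog.com")  # hypothetical key
    handler = CallbackHandler(
        client=posthog_client,  # now optional: omit it to fall back to the default client
        distinct_id="user_123",  # hypothetical user
        properties={"conversation_id": "abc-123"},  # merged into every captured event
    )

    chain = ChatPromptTemplate.from_template("Answer briefly: {question}") | ChatOpenAI(model="gpt-4o-mini")
    chain.invoke({"question": "What is a trace?"}, config={"callbacks": [handler]})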
@@ -97,9 +113,14 @@ def on_chain_start(
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
self._set_parent_of_run(run_id, parent_run_id)
if parent_run_id is None and self._trace_name is None:
self._trace_name = self._get_langchain_run_name(serialized, **kwargs)
self._trace_input = inputs

def on_chat_model_start(
self,
@@ -110,6 +131,7 @@ def on_chat_model_start(
parent_run_id: Optional[UUID] = None,
**kwargs,
):
self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
self._set_parent_of_run(run_id, parent_run_id)
input = [_convert_message_to_dict(message) for row in messages for message in row]
self._set_run_metadata(serialized, run_id, input, **kwargs)
@@ -123,32 +145,93 @@ def on_llm_start(
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
self._set_parent_of_run(run_id, parent_run_id)
self._set_run_metadata(serialized, run_id, prompts, **kwargs)

def on_llm_new_token(
self,
token: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token)

def on_tool_start(
self,
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)

def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)

def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)

def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
):
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
self._pop_parent_of_run(run_id)

if parent_run_id is None:
self._capture_trace(run_id, outputs=outputs)

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
self._pop_parent_of_run(run_id)

if parent_run_id is None:
self._capture_trace(run_id, outputs=None)

def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
):
"""
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
"""
self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
trace_id = self._get_trace_id(run_id)
self._pop_parent_of_run(run_id)
run = self._pop_run_metadata(run_id)
@@ -189,25 +272,15 @@ def on_llm_end(
groups=self._groups,
)
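As the docstring above notes, a streaming run only reports token usage if the model is configured to emit it. With `langchain_openai`, for example, that is a constructor flag (a sketch under that assumption):

    from langchain_openai import ChatOpenAI

    # Without stream_usage=True, a streamed run ends with no usage metadata,
    # so the captured generation event cannot include token counts.
    llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)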

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._pop_parent_of_run(run_id)

def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
):
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
trace_id = self._get_trace_id(run_id)
self._pop_parent_of_run(run_id)
run = self._pop_run_metadata(run_id)
@@ -235,6 +308,50 @@ def on_llm_error(
groups=self._groups,
)

def on_retriever_start(
self,
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)

def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)

def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)

def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)

def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
"""
Set the parent run ID for a chain run. If there is no parent, the run is the root.
@@ -304,6 +421,65 @@ def _get_trace_id(self, run_id: UUID):
trace_id = uuid.uuid4()
return trace_id

def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
"""Retrieve the name of a serialized LangChain runnable.

The prioritization for the determination of the run name is as follows:
- The value assigned to the "name" key in `kwargs`.
- The value assigned to the "name" key in `serialized`.
- The last entry of the value assigned to the "id" key in `serialized`.
- "<unknown>".

Args:
serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
**kwargs (Any): Additional keyword arguments, potentially including the 'name' override.

Returns:
str: The determined name of the LangChain runnable.
"""
if "name" in kwargs and kwargs["name"] is not None:
return kwargs["name"]

try:
return serialized["name"]
except (KeyError, TypeError):
pass

try:
return serialized["id"][-1]
except (KeyError, TypeError):
pass

return "<unknown>"
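To make the fallback order concrete, a few illustrative calls (values are hypothetical):

    handler._get_langchain_run_name({"name": "RetrievalQA"})                      # -> "RetrievalQA"
    handler._get_langchain_run_name({"id": ["langchain", "chains", "LLMChain"]})  # -> "LLMChain"
    handler._get_langchain_run_name({"name": "ignored"}, name="my_chain")         # -> "my_chain"
    handler._get_langchain_run_name(None)                                         # -> "<unknown>"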

def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
trace_id = self._get_trace_id(run_id)
event_properties = {
"$ai_trace_name": self._trace_name,
"$ai_trace_id": trace_id,
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input),
**self._properties,
}
if outputs is not None:
event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
if self._distinct_id is None:
event_properties["$process_person_profile"] = False
self._client.capture(
distinct_id=self._distinct_id or trace_id,
event="$ai_trace",
properties=event_properties,
groups=self._groups,
)
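Putting the pieces together, a successful root run yields a single `$ai_trace` capture shaped roughly like this (illustrative values only; the property names come from the code above):

    # Hypothetical payload captured for a root chain run.
    event = "$ai_trace"
    properties = {
        "$ai_trace_name": "RetrievalQA",
        "$ai_trace_id": "7f9e1b2c-...",  # root run ID unless a trace_id override was set
        "$ai_input_state": {"question": "What is a trace?"},
        "$ai_output_state": {"answer": "..."},
    }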

def _log_debug_event(
self,
event_name: str,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs,
):
log.debug(
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}"
)


def _extract_raw_esponse(last_response):
"""Extract the response from the last response of the LLM call."""
@@ -339,7 +515,9 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
return message_dict


def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None], Union[int, None]]:
def _parse_usage_model(
usage: Union[BaseModel, Dict],
) -> Tuple[Union[int, None], Union[int, None]]:
if isinstance(usage, BaseModel):
usage = usage.__dict__

1 change: 1 addition & 0 deletions posthog/test/ai/langchain/__init__.py
@@ -2,3 +2,4 @@

pytest.importorskip("langchain")
pytest.importorskip("langchain_community")
pytest.importorskip("langgraph")