diff --git a/dev/README.md b/dev/README.md
new file mode 100644
index 0000000..05be68b
--- /dev/null
+++ b/dev/README.md
@@ -0,0 +1,93 @@
+# Development Tools for Processors
+
+## Processor Trace
+
+Processor traces allow you to record the inputs, outputs, and internal steps of
+a processor during its execution. This is useful for debugging, analysis, and
+understanding processor behavior.
+
+A trace is a timeline of events, where each event represents an input part, an
+output part, or a call to a sub-processor. Events are time stamped and ordered
+chronologically. If a processor calls other processors, sub-traces are created
+and nested within the main trace, providing a hierarchical view of execution.
+
+### Enabling Tracing
+
+To enable trace collection for a processor, use a `Trace` context manager. We
+provide examples here with the `SyncFileTrace` context manager implementation.
+Other approaches could be implemented in the future (e.g. storing in a DB or
+streaming into a file instead of writing the whole trace once it is done).
+
+```python
+import asyncio
+from genai_processors import processor
+from genai_processors.dev import trace_file
+
+@processor.processor_function
+async def my_processor_fn(content):
+ ...
+
+async def main():
+ trace_dir = '/path/to/your/trace/directory'
+ # Any processor call within this context will be traced.
+ # Change `trace_file.SyncFileTrace` with other tracing implementation if
+ # needed.
+ async with trace_file.SyncFileTrace(trace_dir):
+ await processor.apply_async(my_processor_fn, parts)
+```
+
+### Default implementation: write to files
+
+The default implementation of tracing is done with `trace_file.SyncFileTrace`.
+When a processor is called within a `SyncFileTrace`, it records its execution
+and saves it into two files under `trace_dir` provided to the trace scope:
+
+- `{processor_name}_{trace_id}.json` containing a json dictionary that can be
+loaded for further programmatic analysis using `SyncFileTrace.load`:
+
+ ```python
+ import os
+ from genai_processors.dev import trace_file
+
+ trace_dir = '/path/to/your/trace/directory'
+ traces = []
+ for f in os.listdir(trace_dir):
+ if f.endswith('.json'):
+ traces.append(trace_file.SyncFileTrace.load(
+ os.path.join(trace_dir, f)
+ )
+ )
+ ```
+
+- `{processor_name}_{trace_id}.html` containing an HTML representation of the
+trace that can easily be viewed on a web browser. This is the same content as
+the json dictionary.
+
+### Implementing a new tracing
+
+To implement a custom trace sink (e.g., save to a database, stream to a network
+ location), you need to extend the abstract base class `trace.Trace` from
+`genai_processors.dev.trace` and implement its abstract methods. Your new class
+can then be used in place of `SyncFileTrace`.
+
+You must implement the following methods:
+
+* `async def add_input(self, part: content_api.ProcessorPart) -> None`:
+Handles input parts received by the processor.
+* `async def add_output(self, part: content_api.ProcessorPart) -> None`:
+Handles output parts produced by the processor.
+* `async def add_sub_trace(self, name: str) -> Trace`:
+Handles the start of a nested processor call. The returned `trace` should be an
+instance of your custom trace implementation.
+* `async def _finalize(self) -> None`: Called when the trace context is
+exited. Use this to perform final actions like flushing buffers, closing
+connections, or writing data to disk.
+
+**Asynchronous Design**
+
+All event-handling methods (`add_input`, `add_output`, `add_sub_trace`) and
+`_finalize` are `async`. This design prevents tracing from blocking the
+processor's execution thread, which is critical in an asynchronous framework.
+If your tracing implementation needs to perform I/O (e.g., writing to a remote
+database or file system), you can use `await` for these operations without
+blocking the processor.
diff --git a/dev/__init__.py b/dev/__init__.py
new file mode 100644
index 0000000..9505752
--- /dev/null
+++ b/dev/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dev-only tools and features for genai processors."""
diff --git a/dev/trace.py b/dev/trace.py
new file mode 100644
index 0000000..13cd398
--- /dev/null
+++ b/dev/trace.py
@@ -0,0 +1,166 @@
+"""Abstract class of a trace to collect, work with and display processor traces.
+
+A GenAI processor trace is a timeline of input and output events that
+were used in a GenAI processor. It includes the user input and potentially the
+audio and/or video stream in case of a realtime processor. The trace also
+includes the function calls and responses made by the processor. Finally, it
+includes the model output parts and any other arbitrary parts produced by the
+processor. An event can also be a trace itself if a processor calls another one
+internally.
+
+A trace corresponds to a single processor call. If the processor is called
+multiple times, multiple traces will be produced, each containing the input
+used to call the processor and the output produced by the call.
+
+__WARNING__: This is an incubating feature. The trace format is subject to
+changes and we do not guarantee backward compatibility at this stage.
+"""
+
+from __future__ import annotations
+
+import abc
+import contextlib
+import contextvars
+import datetime
+from typing import Any
+
+from absl import logging
+from genai_processors import content_api
+import pydantic
+import shortuuid
+
+
+pydantic_converter = pydantic.TypeAdapter(Any)
+
+
+class Trace(pydantic.BaseModel, abc.ABC):
+ """A trace of a processor call.
+
+ A trace contains some information about when the processor was called and
+ includes methods to log input, output and sub-traces to the trace.
+
+ The finalize method must be called to finalize the trace and release any
+ resources.
+
+  It is up to the implementer to decide how to store the trace.
+
+ The add_sub_trace method should be used to create a new trace.
+ """
+
+ model_config = {'arbitrary_types_allowed': True}
+
+ # Name of the trace.
+ name: str | None = None
+
+ # A description of the processor that produced this trace, i.e. arguments used
+ # to construct the processor.
+ processor_description: str | None = None
+
+ # A unique ID for the trace.
+ trace_id: str = pydantic.Field(default_factory=lambda: str(shortuuid.uuid()))
+
+  # Boolean indicating whether the trace has just been created. This is used to
+  # decide, when a processor is called, whether to create a sub-trace or to
+  # reuse this trace because it has just been created.
+ is_new: bool = False
+
+ # The timestamp when the trace was started (the processor was called).
+ start_time: datetime.datetime = pydantic.Field(
+ default_factory=datetime.datetime.now
+ )
+ # The timestamp when the trace was ended (the processor returned).
+ end_time: datetime.datetime | None = None
+
+ _token: contextvars.Token[Trace | None] | None = pydantic.PrivateAttr(
+ default=None
+ )
+
+ async def __aenter__(self) -> Trace:
+ parent_trace = CURRENT_TRACE.get()
+
+ if parent_trace:
+ logging.warning(
+ 'Cannot enter a trace while another trace is already in scope: %s is'
+ ' ignored in favor of %s',
+ self,
+ parent_trace,
+ )
+
+ self.is_new = True
+ self._token = CURRENT_TRACE.set(self)
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: Any,
+ ) -> None:
+ if self._token is None:
+ return
+
+ self.end_time = datetime.datetime.now()
+ CURRENT_TRACE.reset(self._token)
+ await self._finalize()
+
+ @abc.abstractmethod
+ async def add_input(self, part: content_api.ProcessorPart) -> None:
+ """Adds an input part to the trace events."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def add_output(self, part: content_api.ProcessorPart) -> None:
+ """Adds an output part to the trace events."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def add_sub_trace(self, name: str) -> Trace:
+ """Adds a sub-trace from a nested processor call to the trace events.
+
+ Args:
+ name: The name of the sub-trace.
+
+ Returns:
+ The trace that was added to the trace events.
+ """
+ # TODO(elisseeff, kibergus): consider adding a more generic relationship
+  # between traces, e.g. traces generated one after another (with the + ops)
+ # or traces generated in parallel (with the // ops).
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def _finalize(self) -> None:
+ """Finalize the trace.
+
+ At this stage, the trace is ready to be stored and/or displayed. It is up
+ to the implementer to decide how to store the trace. When this function
+ returns all traces should be considered finalized and stored.
+ """
+ raise NotImplementedError()
+
+
+CURRENT_TRACE: contextvars.ContextVar[Trace | None] = contextvars.ContextVar(
+ 'current_trace', default=None
+)
+
+
+@contextlib.asynccontextmanager
+async def call_scope(processor_name: str):
+ """Context manager for tracing a processor call."""
+ parent_trace = CURRENT_TRACE.get()
+
+ if parent_trace is None:
+ # No tracing in scope - keep things as is.
+ yield None
+ elif parent_trace.is_new:
+ # First call to a processor - re-use the root trace. It has been created
+ # when the trace_scope was entered.
+ parent_trace.name = processor_name
+ parent_trace.is_new = False
+ yield parent_trace
+ else:
+ # Parent is not None and corresponds to an existing trace: adds a new trace.
+ async with await parent_trace.add_sub_trace(
+ name=processor_name
+ ) as new_trace:
+ yield new_trace
diff --git a/dev/trace.tpl.html b/dev/trace.tpl.html
new file mode 100644
index 0000000..41ea991
--- /dev/null
+++ b/dev/trace.tpl.html
@@ -0,0 +1,596 @@
+
+
+
+ Trace Viewer
+
+
+
+
+
+
+
+
+
diff --git a/dev/trace_file.py b/dev/trace_file.py
new file mode 100644
index 0000000..635c57e
--- /dev/null
+++ b/dev/trace_file.py
@@ -0,0 +1,202 @@
+"""Trace implementation that stores traces in two files: JSON and HTML.
+
+__WARNING__: This is an incubating feature. The trace format is subject to
+changes and we do not guarantee backward compatibility at this stage.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import datetime
+import json
+import os
+from typing import Any, override
+
+from genai_processors import content_api
+from genai_processors.dev import trace
+import pydantic
+import shortuuid
+
+HTML_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), 'trace.tpl.html')
+with open(HTML_TEMPLATE_PATH, 'r') as f:
+ HTML_TEMPLATE = f.read()
+
+pydantic_converter = pydantic.TypeAdapter(Any)
+
+
+def _bytes_encoder(o: Any) -> Any:
+ """Encodes bytes in parts based on mime type.
+
+ The dump_python(mode='json') in Pydantic does not encode bytes in utf-8
+ mode and this causes issues when sending ProcessorPart to JS/HTML clients
+ (wrong padding, etc.). This function is used to handle bytes to base64
+ encoding to match the behaviour of the JS/HTML side.
+
+ Args:
+ o: The object to encode.
+
+ Returns:
+ The encoded object.
+ """
+ if isinstance(o, bytes):
+ return base64.b64encode(o).decode('utf-8')
+ else:
+ return pydantic_converter.dump_python(o, mode='json')
+
+
+# TODO(elisseeff): Adjust the logic to make it less brittle. If a new bytes
+# field is added to the ProcessorPart in the future this function will not
+# decode it while it should.
+def _bytes_decoder(dct: dict[str, Any]) -> Any:
+ """Decodes base64 encoded bytes in parts based on mime type."""
+ if 'data' in dct and 'mime_type' in dct and isinstance(dct['data'], str):
+ mime_type = dct['mime_type']
+ if not mime_type.startswith('text/') and not mime_type.startswith(
+ 'application/json'
+ ):
+ try:
+ dct['data'] = base64.b64decode(dct['data'])
+ except (ValueError, TypeError):
+ pass
+ return dct
+
+
+class TraceEvent(pydantic.BaseModel):
+ """A single event in a trace.
+
+ An event represents an input/output part or a sub-trace from a nested
+ processor call.
+
+  This class is not part of the abstract `trace.Trace` interface; it is the
+  recommended event container for concrete trace implementations.
+ """
+
+ model_config = {'arbitrary_types_allowed': True}
+
+ # A unique ID for this event.
+ id: str = pydantic.Field(default_factory=lambda: str(shortuuid.uuid()))
+ # The timestamp when the event was stored in the trace.
+ timestamp: datetime.datetime = pydantic.Field(
+ default_factory=datetime.datetime.now
+ )
+ # Whether the event is an input part to the processor or an output part.
+ is_input: bool = False
+
+ # The part of the event (as dictionary). None if sub_trace is provided.
+ # By serializing the part into a dict we ensure that even if the part is
+ # mutated later, the logged value won't change.
+ part_dict: dict[str, Any] | None = None
+ # If set, this event represents a nested processor call via its trace.
+ sub_trace: SyncFileTrace | None = None
+
+
+class SyncFileTrace(trace.Trace):
+ """A trace storing events in a file.
+
+ The trace collects all events first in memory and then writes them to a file
+ when the finalize method is called.
+ """
+
+  # Where to store the trace. Required only for the root trace.
+ trace_dir: str | None = None
+
+ # The events in the trace. Collected in memory.
+ events: list[TraceEvent] = []
+
+ def to_json_str(self) -> str:
+ """Converts the trace to a JSON string."""
+ try:
+ return json.dumps(
+ self.model_dump(mode='python'),
+ default=_bytes_encoder,
+ )
+ except TypeError as e:
+ raise TypeError(
+ 'Failed to serialize trace to JSON. This might be due to'
+ ' non-serializable types in ProcessorPart metadata. Ensure parts'
+ ' added to traces are JSON-serializable.'
+ ) from e
+
+ @classmethod
+ def from_json_str(cls, json_str: str) -> trace.Trace:
+ """Initializes the trace from a JSON string.
+
+ Args:
+ json_str: The JSON string to initialize the trace from. The bytes field
+ for audio and image parts are expected to be base64 + utf-8 encoded.
+
+ Returns:
+ The trace initialized from the JSON string.
+ """
+ return cls.model_validate(json.loads(json_str, object_hook=_bytes_decoder))
+
+ def _add_part(self, part: content_api.ProcessorPart, is_input: bool) -> None:
+ """Adds an input or output part to the trace events."""
+ event = TraceEvent(
+ part_dict=part.to_dict(mode='python'),
+ is_input=is_input,
+ )
+ self.events.append(event)
+
+ @override
+ async def add_input(self, part: content_api.ProcessorPart) -> None:
+ """Adds an input part to the trace events."""
+ self._add_part(part, is_input=True)
+
+ @override
+ async def add_output(self, part: content_api.ProcessorPart) -> None:
+ """Adds an output part to the trace events."""
+ self._add_part(part, is_input=False)
+
+ @override
+ async def add_sub_trace(self, name: str) -> SyncFileTrace:
+ """Adds a sub-trace from a nested processor call to the trace events."""
+ t = SyncFileTrace(name=name)
+ self.events.append(TraceEvent(sub_trace=t, is_input=False))
+ return t
+
+ @override
+ async def _finalize(self) -> None:
+ """Saves the trace to a file."""
+ if not self.trace_dir:
+ return
+ trace_filename = os.path.join(
+ self.trace_dir, f'{self.name}_{self.trace_id}'
+ )
+ await asyncio.to_thread(self.save, trace_filename + '.json')
+ await asyncio.to_thread(self.save_html, trace_filename + '.html')
+
+  def save_html(self, path: str) -> None:
+    """Saves an HTML rendering of the trace to a file."""
+    html = HTML_TEMPLATE.format(trace_json=self.to_json_str())
+    # No stdout side effects here: callers decide how to surface the path.
+    with open(path, 'w') as html_file:
+      html_file.write(html)
+
+  def save(self, path: str) -> None:
+    """Saves a trace to a file in JSON format.
+
+    Args:
+      path: The path to the file.
+    """
+    with open(path, 'w') as json_file:
+      json_file.write(self.to_json_str())
+
+  @classmethod
+  def load(cls, path: str) -> 'SyncFileTrace':
+    """Reads a trace from a JSON file.
+
+    Args:
+      path: The path to the file.
+
+    Returns:
+      The trace.
+    """
+    with open(path, 'r') as json_file:
+      return cls.model_validate(
+          json.loads(json_file.read(), object_hook=_bytes_decoder)
+      )
+
+
+SyncFileTrace.model_rebuild(force=True)
diff --git a/genai_processors/content_api.py b/genai_processors/content_api.py
index e2c49f5..415772d 100644
--- a/genai_processors/content_api.py
+++ b/genai_processors/content_api.py
@@ -1,4 +1,4 @@
-# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
+# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,6 +45,17 @@ class ProcessorPart:
belongs to, the MIME type of the content, and arbitrary metadata.
"""
+ # GenAI SDK part representation of the ProcessorPart.
+ _part: genai_types.Part | None = None
+ # Metadata about the part. Can store arbitrary key/value pairs.
+ _metadata: dict[str, Any] = {}
+ # Mime type of the part.
+ _mimetype: str | None = None
+ # Role of the part.
+ _role: str = ''
+ # Substream name of the part.
+ _substream_name: str = ''
+
def __init__(
self,
value: 'ProcessorPartTypes',
@@ -144,10 +155,7 @@ def __init__(
if mimetype:
self._mimetype = mimetype
# Otherwise, if MIME type is specified using inline data, use that.
- elif (
- self._part.inline_data
- and self._part.inline_data.mime_type
- ):
+ elif self._part.inline_data and self._part.inline_data.mime_type:
self._mimetype = self._part.inline_data.mime_type
elif self._part.function_call:
# OSS library can't depend on protobuf, so we hardcode literal here.
@@ -348,7 +356,9 @@ def get_proto_message(
) -> pb_message.Message:
"""Returns representation of the Part as a given proto message."""
if not mime_types.is_proto_message(self.mimetype, proto_message):
- raise ValueError('Part is not a proto message.')
+ raise ValueError(
+ f'Part is not a {proto_message.DESCRIPTOR.name} proto message.'
+ )
return proto_message.FromString(self.bytes)
@property
diff --git a/genai_processors/processor.py b/genai_processors/processor.py
index 8653b8f..172cb75 100644
--- a/genai_processors/processor.py
+++ b/genai_processors/processor.py
@@ -32,6 +32,7 @@
from genai_processors import map_processor
from genai_processors import mime_types
from genai_processors import streams
+from genai_processors.dev import trace
# Aliases
context = context_lib.context
@@ -127,7 +128,8 @@ async def __call__(
Yields:
the result of processing the input content.
"""
- content = _normalize_part_stream(content, producer=self)
+ normalized_content = _normalize_part_stream(content, producer=self)
+
# Ensures that the same taskgroup is always added to the context and
# includes the proper way of handling generators, i.e. use a queue inside
# the task group instead of a generator.
@@ -143,30 +145,47 @@ async def __call__(
# always executed within the task group and that the `CancelledError` is
# handled correctly.
tg = context_lib.task_group()
- if tg is None:
- output_queue = asyncio.Queue[ProcessorPart | None]()
-
- async def _with_context():
- async with context():
- try:
- async for p in _normalize_part_stream(
- self.call(content), producer=self.call
- ):
- output_queue.put_nowait(p)
- finally:
- output_queue.put_nowait(None)
-
- task = asyncio.create_task(_with_context())
- try:
- async for p in streams.dequeue(output_queue):
+
+ async with trace.call_scope(self.key_prefix) as current_trace:
+
+ if current_trace:
+
+ async def stream_input() -> AsyncIterable[ProcessorPart]:
+ async for part in normalized_content:
+ await current_trace.add_input(part)
+ yield part
+
+ else:
+ stream_input = lambda: normalized_content
+
+ if tg is None:
+ output_queue = asyncio.Queue[ProcessorPart | None]()
+
+ async def _with_context():
+ async with context():
+ try:
+ async for p in _normalize_part_stream(
+ self.call(stream_input()), producer=self.call
+ ):
+ if current_trace:
+ await current_trace.add_output(p)
+ output_queue.put_nowait(p)
+ finally:
+ output_queue.put_nowait(None)
+
+ task = asyncio.create_task(_with_context())
+ try:
+ async for p in streams.dequeue(output_queue):
+ yield p
+ finally:
+ await task
+ else:
+ async for p in _normalize_part_stream(
+ self.call(stream_input()), producer=self.call
+ ):
+ if current_trace:
+ await current_trace.add_output(p)
yield p
- finally:
- await task
- else:
- async for p in _normalize_part_stream(
- self.call(content), producer=self.call
- ):
- yield p
@abc.abstractmethod
async def call(
diff --git a/genai_processors/tests/content_api_test.py b/genai_processors/tests/content_api_test.py
index 08cf668..86b82e8 100644
--- a/genai_processors/tests/content_api_test.py
+++ b/genai_processors/tests/content_api_test.py
@@ -427,6 +427,28 @@ def test_get_proto_message_raises_error_with_incorrect_mimetype(self):
with self.assertRaises(ValueError):
test_part.get_proto_message(struct_pb2.Struct)
+ def test_get_proto_message_from_bytes(self):
+ test_proto = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(string_value='bar')}
+ )
+ part = content_api.ProcessorPart(
+ test_proto.SerializeToString(),
+ mimetype=mime_types.proto_message_mime_type(struct_pb2.Struct),
+ )
+ self.assertEqual(part.get_proto_message(struct_pb2.Struct), test_proto)
+
+ def test_get_proto_message_raises_error_with_incorrect_proto_type(self):
+ test_proto = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(string_value='bar')}
+ )
+ part = content_api.ProcessorPart.from_proto_message(
+ proto_message=test_proto
+ )
+ with self.assertRaisesRegex(
+ ValueError, 'Part is not a Duration proto message.'
+ ):
+ part.get_proto_message(duration_pb2.Duration)
+
def test_is_text(self):
text_part = content_api.ProcessorPart('hello')
self.assertTrue(mime_types.is_text(text_part.mimetype))
diff --git a/genai_processors/tests/trace_file_test.py b/genai_processors/tests/trace_file_test.py
new file mode 100644
index 0000000..88745a7
--- /dev/null
+++ b/genai_processors/tests/trace_file_test.py
@@ -0,0 +1,197 @@
+import asyncio
+from collections.abc import AsyncIterable
+import io
+import json
+import os
+import shutil
+from typing import cast
+import unittest
+
+from absl.testing import absltest
+from genai_processors import content_api
+from genai_processors import mime_types
+from genai_processors import processor
+from genai_processors import streams
+from genai_processors.dev import trace_file
+import numpy as np
+from PIL import Image
+from scipy.io import wavfile
+
+
+@processor.processor_function
+async def to_upper_fn(
+ content: AsyncIterable[content_api.ProcessorPart],
+) -> AsyncIterable[content_api.ProcessorPartTypes]:
+ async for part in content:
+ await asyncio.sleep(0.01) # to ensure timestamps are different
+ if mime_types.is_text(part.mimetype):
+ yield part.text.upper() + '_sub_trace'
+ else:
+ yield part
+
+
+class SubTraceProcessor(processor.Processor):
+
+ def __init__(self):
+ super().__init__()
+ self.sub_processor = to_upper_fn
+
+ async def call(
+ self, content: AsyncIterable[content_api.ProcessorPart]
+ ) -> AsyncIterable[content_api.ProcessorPartTypes]:
+ async for part in self.sub_processor(content):
+ if isinstance(part, content_api.ProcessorPart) and mime_types.is_text(
+ part.mimetype
+ ):
+ yield part.text + '_outer'
+ else:
+ yield part
+
+
+class TraceTest(unittest.IsolatedAsyncioTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.trace_dir = os.path.join(absltest.get_default_test_tmpdir(), 'traces')
+ os.makedirs(self.trace_dir, exist_ok=True)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.trace_dir)
+
+  async def test_trace_generation_and_timestamps(self):
+    p = SubTraceProcessor()
+    input_parts = [content_api.ProcessorPart('hello')]
+    async with trace_file.SyncFileTrace(trace_dir=self.trace_dir):
+      results = await streams.gather_stream(
+          p(streams.stream_content(input_parts))
+      )
+
+    self.assertEqual(results[0].text, 'HELLO_sub_trace_outer')
+    json_files = [f for f in os.listdir(self.trace_dir) if f.endswith('.json')]
+    self.assertEqual(len(json_files), 1)
+    trace_path = os.path.join(self.trace_dir, json_files[0])
+    self.assertTrue(os.path.exists(trace_path.replace('.json', '.html')))
+
+    trace = trace_file.SyncFileTrace.load(trace_path)
+
+    # First event is a sub-trace for the upper function, which is the first
+    # processor entered within the trace scope.
+    self.assertFalse(trace.events[0].is_input)
+    sub_trace = cast(trace_file.SyncFileTrace, trace.events[0].sub_trace)
+    self.assertIsNotNone(sub_trace)
+    self.assertIn('to_upper_fn', sub_trace.name)
+    self.assertFalse(sub_trace.events[1].is_input)
+    self.assertEqual(
+        sub_trace.events[1].part_dict['part']['text'], 'HELLO_sub_trace'
+    )
+    self.assertIsNotNone(sub_trace.start_time)
+    self.assertIsNotNone(sub_trace.end_time)
+    self.assertLess(sub_trace.start_time, sub_trace.end_time)
+
+    # Second input event is the input part to SubTraceProcessor.
+    self.assertTrue(trace.events[1].is_input)
+    self.assertEqual(trace.events[1].part_dict['part']['text'], 'hello')
+
+    # Third event is the output event of SubTraceProcessor
+    self.assertFalse(trace.events[2].is_input)
+    self.assertEqual(
+        trace.events[2].part_dict['part']['text'], 'HELLO_sub_trace_outer'
+    )
+
+  async def test_trace_references(self):
+    p = SubTraceProcessor()
+    input_part = content_api.ProcessorPart('world')
+    # First call
+    async with trace_file.SyncFileTrace(trace_dir=self.trace_dir):
+      await streams.gather_stream(p(streams.stream_content([input_part])))
+
+    json_files = [f for f in os.listdir(self.trace_dir) if f.endswith('.json')]
+    self.assertEqual(len(json_files), 1)
+    trace1_path = os.path.join(self.trace_dir, json_files[0])
+    trace1 = trace_file.SyncFileTrace.load(trace1_path)
+    self.assertTrue(trace1.events[1].is_input)
+    self.assertEqual(trace1.events[1].part_dict['part']['text'], 'world')
+
+    sub_trace1 = cast(trace_file.SyncFileTrace, trace1.events[0].sub_trace)
+    self.assertIsNotNone(sub_trace1)
+    self.assertTrue(sub_trace1.events[0].is_input)
+    self.assertIsNotNone(sub_trace1.events[0].part_dict)
+
+    # Second call with same part
+    for f in os.listdir(self.trace_dir):
+      os.remove(os.path.join(self.trace_dir, f))
+    async with trace_file.SyncFileTrace(trace_dir=self.trace_dir):
+      await streams.gather_stream(p(streams.stream_content([input_part])))
+    json_files = [f for f in os.listdir(self.trace_dir) if f.endswith('.json')]
+    self.assertEqual(len(json_files), 1)
+    trace2_path = os.path.join(self.trace_dir, json_files[0])
+    trace2 = trace_file.SyncFileTrace.load(trace2_path)
+    self.assertTrue(trace2.events[1].is_input)
+    self.assertIsNotNone(trace2.events[1].part_dict)
+
+ async def test_trace_save_load(self):
+ trace = trace_file.SyncFileTrace(name='test')
+ await trace.add_input(content_api.ProcessorPart('in'))
+ await trace.add_input(
+ content_api.ProcessorPart.from_bytes(
+ data=b'bytes',
+ mimetype='image/jpeg',
+ )
+ )
+ sub_trace = await trace.add_sub_trace(name='sub_test')
+ await sub_trace.add_input(content_api.ProcessorPart('sub_in'))
+ await sub_trace.add_output(content_api.ProcessorPart('sub_out'))
+ await trace.add_output(content_api.ProcessorPart('out'))
+
+ tmpdir = absltest.get_default_test_tmpdir()
+ trace_path = os.path.join(tmpdir, 'trace.json')
+
+ trace.save(trace_path)
+ loaded_trace = trace_file.SyncFileTrace.load(trace_path)
+
+ self.assertEqual(
+ json.loads(trace.model_dump_json()),
+ json.loads(loaded_trace.model_dump_json()),
+ )
+
+  async def test_save_html(self):
+    p = SubTraceProcessor()
+    trace_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR')
+
+    # Create a small green image using PIL
+    img = Image.new('RGB', (10, 10), color='green')
+    img_bytes_io = io.BytesIO()
+    img.save(img_bytes_io, format='PNG')
+    img_part = content_api.ProcessorPart.from_bytes(
+        data=img_bytes_io.getvalue(),
+        mimetype='image/png',
+    )
+
+    # Generate a small random WAV audio part
+    sample_rate = 16000 # samples per second
+    duration = 0.1 # seconds
+    num_samples = int(sample_rate * duration)
+    # Generate random samples between -1 and 1
+    random_samples = np.random.uniform(-1, 1, num_samples)
+    # Scale to int16 range
+    audio_data = (random_samples * 32767).astype(np.int16)
+
+    audio_bytes_io = io.BytesIO()
+    wavfile.write(audio_bytes_io, sample_rate, audio_data)
+    audio_part = content_api.ProcessorPart.from_bytes(
+        data=audio_bytes_io.getvalue(),
+        mimetype='audio/wav',
+    )
+    parts = [img_part, audio_part, content_api.ProcessorPart('hello')]
+    async with trace_file.SyncFileTrace(trace_dir=trace_dir):
+      await processor.apply_async(p, parts)
+
+    html_files = [f for f in os.listdir(trace_dir) if f.endswith('.html')]
+    self.assertEqual(len(html_files), 1)
+    trace_path = os.path.join(trace_dir, html_files[0])
+    self.assertTrue(os.path.exists(trace_path))
+
+
+if __name__ == '__main__':
+ absltest.main()