From 7ae123ec3c7c3e763496a216b7f1438e063b66ea Mon Sep 17 00:00:00 2001 From: Alexey Guseynov Date: Wed, 21 Jan 2026 08:53:05 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 859122394 --- dev/README.md | 93 ++++ dev/__init__.py | 15 + dev/trace.py | 166 ++++++ dev/trace.tpl.html | 596 +++++++++++++++++++++ dev/trace_file.py | 202 +++++++ genai_processors/content_api.py | 22 +- genai_processors/processor.py | 67 ++- genai_processors/tests/content_api_test.py | 22 + genai_processors/tests/trace_file_test.py | 197 +++++++ 9 files changed, 1350 insertions(+), 30 deletions(-) create mode 100644 dev/README.md create mode 100644 dev/__init__.py create mode 100644 dev/trace.py create mode 100644 dev/trace.tpl.html create mode 100644 dev/trace_file.py create mode 100644 genai_processors/tests/trace_file_test.py diff --git a/dev/README.md b/dev/README.md new file mode 100644 index 0000000..05be68b --- /dev/null +++ b/dev/README.md @@ -0,0 +1,93 @@ +# Development Tools for Processors + +## Processor Trace + +Processor traces allow you to record the inputs, outputs, and internal steps of +a processor during its execution. This is useful for debugging, analysis, and +understanding processor behavior. + +A trace is a timeline of events, where each event represents an input part, an +output part, or a call to a sub-processor. Events are time stamped and ordered +chronologically. If a processor calls other processors, sub-traces are created +and nested within the main trace, providing a hierarchical view of execution. + +### Enabling Tracing + +To enable trace collection for a processor, use a `Trace` context manager. We +provide examples here with the `SyncFileTrace` context manager implementation. +Other approaches could be implemented in the future (e.g. stored in a DB or +streaming into a file instead for writing when the trace is done). 
+ +```python +import asyncio +from genai_processors import processor +from genai_processors.dev import trace_file + +@processor.processor_function +async def my_processor_fn(content): + ... + +async def main(): + trace_dir = '/path/to/your/trace/directory' + # Any processor call within this context will be traced. + # Change `trace_file.SyncFileTrace` with other tracing implementation if + # needed. + async with trace_file.SyncFileTrace(trace_dir): + await processor.apply_async(my_processor_fn, parts) +``` + +### Default implementation: write to files + +The default implementation of tracing is done with `trace_file.SyncFileTrace`. +When a processor is called within a `SyncFileTrace`, it records its execution +and saves it into two files under `trace_dir` provided to the trace scope: + +- `{processor_name}_{trace_id}.json` containing a json dictionary that can be +loaded for further programmatic analysis using `SyncFileTrace.load`: + + ```python + import os + from genai_processors.dev import trace_file + + trace_dir = '/path/to/your/trace/directory' + traces = [] + for f in os.listdir(trace_dir): + if f.endswith('.json'): + traces.append(trace_file.SyncFileTrace.load( + os.path.join(trace_dir, f) + ) + ) + ``` + +- `{processor_name}_{trace_id}.html` containing an HTML representation of the +trace that can easily be viewed on a web browser. This is the same content as +the json dictionary. + +### Implementing a new tracing + +To implement a custom trace sink (e.g., save to a database, stream to a network + location), you need to extend the abstract base class `trace.Trace` from +`genai_processors.dev.trace` and implement its abstract methods. Your new class +can then be used in place of `SyncFileTrace`. + +You must implement the following methods: + +* `async def add_input(self, part: content_api.ProcessorPart) -> None`: +Handles input parts received by the processor. 
+* `async def add_output(self, part: content_api.ProcessorPart) -> None`: +Handles output parts produced by the processor. +* `async def add_sub_trace(self) -> Trace`: +Handles the start of a nested processor call. The returned `trace` should be an +instance of your custom trace implementation. +* `async def _finalize(self) -> None:`: Called when the trace context is +exited. Use this to perform final actions like flushing buffers, closing +connections, or writing data to disk. + +**Asynchronous Design** + +All event-handling methods (`add_input`, `add_output`, `add_sub_trace`) and +`_finalize` are `async`. This design prevents tracing from blocking the +processor's execution thread, which is critical in an asynchronous framework. +If your tracing implementation needs to perform I/O (e.g., writing to a remote +database or file system), you can use `await` for these operations without +blocking the processor. diff --git a/dev/__init__.py b/dev/__init__.py new file mode 100644 index 0000000..9505752 --- /dev/null +++ b/dev/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Dev-only tools and features for genai processors.""" diff --git a/dev/trace.py b/dev/trace.py new file mode 100644 index 0000000..13cd398 --- /dev/null +++ b/dev/trace.py @@ -0,0 +1,166 @@ +"""Abstract class of a trace to collect, work with and display processor traces. + +A GenAI processor trace is a timeline of input and output events that +were used in a GenAI processor. It includes the user input and potentially the +audio and/or video stream in case of a realtime processor. The trace also +includes the function calls and responses made by the processor. Finally, it +includes the model output parts and any other arbitrary parts produced by the +processor. An event can also be a trace itself if a processor calls another one +internally. + +A trace corresponds to a single processor call. If the processor is called +multiple times, multiple traces will be produced, each containing the input +used to call the processor and the output produced by the call. + +__WARNING__: This is an incubating feature. The trace format is subject to +changes and we do not guarantee backward compatibility at this stage. +""" + +from __future__ import annotations + +import abc +import contextlib +import contextvars +import datetime +from typing import Any + +from absl import logging +from genai_processors import content_api +import pydantic +import shortuuid + + +pydantic_converter = pydantic.TypeAdapter(Any) + + +class Trace(pydantic.BaseModel, abc.ABC): + """A trace of a processor call. + + A trace contains some information about when the processor was called and + includes methods to log input, output and sub-traces to the trace. + + The finalize method must be called to finalize the trace and release any + resources. + + It is up to the implementer to decide how to store the trace. + + The add_sub_trace method should be used to create a new trace. 
+ """ + + model_config = {'arbitrary_types_allowed': True} + + # Name of the trace. + name: str | None = None + + # A description of the processor that produced this trace, i.e. arguments used + # to construct the processor. + processor_description: str | None = None + + # A unique ID for the trace. + trace_id: str = pydantic.Field(default_factory=lambda: str(shortuuid.uuid())) + + # Boolean indicating whether the trace has just been created. This is used to + # determine whether to create a subtrace when a processor is called or using + # the existing trace when it's just been created. + is_new: bool = False + + # The timestamp when the trace was started (the processor was called). + start_time: datetime.datetime = pydantic.Field( + default_factory=datetime.datetime.now + ) + # The timestamp when the trace was ended (the processor returned). + end_time: datetime.datetime | None = None + + _token: contextvars.Token[Trace | None] | None = pydantic.PrivateAttr( + default=None + ) + + async def __aenter__(self) -> Trace: + parent_trace = CURRENT_TRACE.get() + + if parent_trace: + logging.warning( + 'Cannot enter a trace while another trace is already in scope: %s is' + ' ignored in favor of %s', + self, + parent_trace, + ) + + self.is_new = True + self._token = CURRENT_TRACE.set(self) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + if self._token is None: + return + + self.end_time = datetime.datetime.now() + CURRENT_TRACE.reset(self._token) + await self._finalize() + + @abc.abstractmethod + async def add_input(self, part: content_api.ProcessorPart) -> None: + """Adds an input part to the trace events.""" + raise NotImplementedError() + + @abc.abstractmethod + async def add_output(self, part: content_api.ProcessorPart) -> None: + """Adds an output part to the trace events.""" + raise NotImplementedError() + + @abc.abstractmethod + async def add_sub_trace(self, name: 
str) -> Trace: + """Adds a sub-trace from a nested processor call to the trace events. + + Args: + name: The name of the sub-trace. + + Returns: + The trace that was added to the trace events. + """ + # TODO(elisseeff, kibergus): consider adding a more generic relationship + # between traces, e.g. traces generated one after another (with the + ops) + # or traces generated in parallel (with the // ops). + raise NotImplementedError() + + @abc.abstractmethod + async def _finalize(self) -> None: + """Finalize the trace. + + At this stage, the trace is ready to be stored and/or displayed. It is up + to the implementer to decide how to store the trace. When this function + returns all traces should be considered finalized and stored. + """ + raise NotImplementedError() + + +CURRENT_TRACE: contextvars.ContextVar[Trace | None] = contextvars.ContextVar( + 'current_trace', default=None +) + + +@contextlib.asynccontextmanager +async def call_scope(processor_name: str): + """Context manager for tracing a processor call.""" + parent_trace = CURRENT_TRACE.get() + + if parent_trace is None: + # No tracing in scope - keep things as is. + yield None + elif parent_trace.is_new: + # First call to a processor - re-use the root trace. It has been created + # when the trace_scope was entered. + parent_trace.name = processor_name + parent_trace.is_new = False + yield parent_trace + else: + # Parent is not None and corresponds to an existing trace: adds a new trace. + async with await parent_trace.add_sub_trace( + name=processor_name + ) as new_trace: + yield new_trace diff --git a/dev/trace.tpl.html b/dev/trace.tpl.html new file mode 100644 index 0000000..41ea991 --- /dev/null +++ b/dev/trace.tpl.html @@ -0,0 +1,596 @@ + + + + Trace Viewer + + + + 
+
+ + + diff --git a/dev/trace_file.py b/dev/trace_file.py new file mode 100644 index 0000000..635c57e --- /dev/null +++ b/dev/trace_file.py @@ -0,0 +1,202 @@ +"""Trace implementation that stores traces in two files: JSON and HTML. + +__WARNING__: This is an incubating feature. The trace format is subject to +changes and we do not guarantee backward compatibility at this stage. +""" + +from __future__ import annotations + +import asyncio +import base64 +import datetime +import json +import os +from typing import Any, override + +from genai_processors import content_api +from genai_processors.dev import trace +import pydantic +import shortuuid + +HTML_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), 'trace.tpl.html') +with open(HTML_TEMPLATE_PATH, 'r') as f: + HTML_TEMPLATE = f.read() + +pydantic_converter = pydantic.TypeAdapter(Any) + + +def _bytes_encoder(o: Any) -> Any: + """Encodes bytes in parts based on mime type. + + The dump_python(mode='json') in Pydantic does not encode bytes in utf-8 + mode and this causes issues when sending ProcessorPart to JS/HTML clients + (wrong padding, etc.). This function is used to handle bytes to base64 + encoding to match the behaviour of the JS/HTML side. + + Args: + o: The object to encode. + + Returns: + The encoded object. + """ + if isinstance(o, bytes): + return base64.b64encode(o).decode('utf-8') + else: + return pydantic_converter.dump_python(o, mode='json') + + +# TODO(elisseeff): Adjust the logic to make it less brittle. If a new bytes +# field is added to the ProcessorPart in the future this function will not +# decode it while it should. 
+def _bytes_decoder(dct: dict[str, Any]) -> Any: + """Decodes base64 encoded bytes in parts based on mime type.""" + if 'data' in dct and 'mime_type' in dct and isinstance(dct['data'], str): + mime_type = dct['mime_type'] + if not mime_type.startswith('text/') and not mime_type.startswith( + 'application/json' + ): + try: + dct['data'] = base64.b64decode(dct['data']) + except (ValueError, TypeError): + pass + return dct + + +class TraceEvent(pydantic.BaseModel): + """A single event in a trace. + + An event represents an input/output part or a sub-trace from a nested + processor call. + + This class is not used in this abstract base class, but is recommended to be + used in the implementations of the trace. + """ + + model_config = {'arbitrary_types_allowed': True} + + # A unique ID for this event. + id: str = pydantic.Field(default_factory=lambda: str(shortuuid.uuid())) + # The timestamp when the event was stored in the trace. + timestamp: datetime.datetime = pydantic.Field( + default_factory=datetime.datetime.now + ) + # Whether the event is an input part to the processor or an output part. + is_input: bool = False + + # The part of the event (as dictionary). None if sub_trace is provided. + # By serializing the part into a dict we ensure that even if the part is + # mutated later, the logged value won't change. + part_dict: dict[str, Any] | None = None + # If set, this event represents a nested processor call via its trace. + sub_trace: SyncFileTrace | None = None + + +class SyncFileTrace(trace.Trace): + """A trace storing events in a file. + + The trace collects all events first in memory and then writes them to a file + when the finalize method is called. + """ + + # Where to store the trace. Required only for the root trace. + trace_dir: str | None = None + + # The events in the trace. Collected in memory. 
+ events: list[TraceEvent] = [] + + def to_json_str(self) -> str: + """Converts the trace to a JSON string.""" + try: + return json.dumps( + self.model_dump(mode='python'), + default=_bytes_encoder, + ) + except TypeError as e: + raise TypeError( + 'Failed to serialize trace to JSON. This might be due to' + ' non-serializable types in ProcessorPart metadata. Ensure parts' + ' added to traces are JSON-serializable.' + ) from e + + @classmethod + def from_json_str(cls, json_str: str) -> trace.Trace: + """Initializes the trace from a JSON string. + + Args: + json_str: The JSON string to initialize the trace from. The bytes field + for audio and image parts are expected to be base64 + utf-8 encoded. + + Returns: + The trace initialized from the JSON string. + """ + return cls.model_validate(json.loads(json_str, object_hook=_bytes_decoder)) + + def _add_part(self, part: content_api.ProcessorPart, is_input: bool) -> None: + """Adds an input or output part to the trace events.""" + event = TraceEvent( + part_dict=part.to_dict(mode='python'), + is_input=is_input, + ) + self.events.append(event) + + @override + async def add_input(self, part: content_api.ProcessorPart) -> None: + """Adds an input part to the trace events.""" + self._add_part(part, is_input=True) + + @override + async def add_output(self, part: content_api.ProcessorPart) -> None: + """Adds an output part to the trace events.""" + self._add_part(part, is_input=False) + + @override + async def add_sub_trace(self, name: str) -> SyncFileTrace: + """Adds a sub-trace from a nested processor call to the trace events.""" + t = SyncFileTrace(name=name) + self.events.append(TraceEvent(sub_trace=t, is_input=False)) + return t + + @override + async def _finalize(self) -> None: + """Saves the trace to a file.""" + if not self.trace_dir: + return + trace_filename = os.path.join( + self.trace_dir, f'{self.name}_{self.trace_id}' + ) + await asyncio.to_thread(self.save, trace_filename + '.json') + await 
asyncio.to_thread(self.save_html, trace_filename + '.html') + + def save_html(self, path: str) -> None: + """Saves an HTML rendering of the trace to a file.""" + html = HTML_TEMPLATE.format(trace_json=self.to_json_str()) + print(f'html: {path}') + with open(path, 'w') as html_file: + html_file.write(html) + + def save(self, path: str) -> None: + """Saves a trace to a file in JSON format. + + Args: + path: The path to the file. + """ + with open(path, 'w') as html_file: + html_file.write(self.to_json_str()) + + @classmethod + def load(cls, path: str) -> 'SyncFileTrace': + """Reads a trace from a JSON file. + + Args: + path: The path to the file. + + Returns: + The trace. + """ + with open(path, 'r') as html_file: + return SyncFileTrace.model_validate( + json.loads(html_file.read(), object_hook=_bytes_decoder) + ) + + +SyncFileTrace.model_rebuild(force=True) diff --git a/genai_processors/content_api.py b/genai_processors/content_api.py index e2c49f5..415772d 100644 --- a/genai_processors/content_api.py +++ b/genai_processors/content_api.py @@ -1,4 +1,4 @@ -# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,17 @@ class ProcessorPart: belongs to, the MIME type of the content, and arbitrary metadata. """ + # GenAI SDK part representation of the ProcessorPart. + _part: genai_types.Part | None = None + # Metadata about the part. Can store arbitrary key/value pairs. + _metadata: dict[str, Any] = {} + # Mime type of the part. + _mimetype: str | None = None + # Role of the part. + _role: str = '' + # Substream name of the part. 
+ _substream_name: str = '' + def __init__( self, value: 'ProcessorPartTypes', @@ -144,10 +155,7 @@ def __init__( if mimetype: self._mimetype = mimetype # Otherwise, if MIME type is specified using inline data, use that. - elif ( - self._part.inline_data - and self._part.inline_data.mime_type - ): + elif self._part.inline_data and self._part.inline_data.mime_type: self._mimetype = self._part.inline_data.mime_type elif self._part.function_call: # OSS library can't depend on protobuf, so we hardcode literal here. @@ -348,7 +356,9 @@ def get_proto_message( ) -> pb_message.Message: """Returns representation of the Part as a given proto message.""" if not mime_types.is_proto_message(self.mimetype, proto_message): - raise ValueError('Part is not a proto message.') + raise ValueError( + f'Part is not a {proto_message.DESCRIPTOR.name} proto message.' + ) return proto_message.FromString(self.bytes) @property diff --git a/genai_processors/processor.py b/genai_processors/processor.py index 8653b8f..172cb75 100644 --- a/genai_processors/processor.py +++ b/genai_processors/processor.py @@ -32,6 +32,7 @@ from genai_processors import map_processor from genai_processors import mime_types from genai_processors import streams +from genai_processors.dev import trace # Aliases context = context_lib.context @@ -127,7 +128,8 @@ async def __call__( Yields: the result of processing the input content. """ - content = _normalize_part_stream(content, producer=self) + normalized_content = _normalize_part_stream(content, producer=self) + # Ensures that the same taskgroup is always added to the context and # includes the proper way of handling generators, i.e. use a queue inside # the task group instead of a generator. @@ -143,30 +145,47 @@ async def __call__( # always executed within the task group and that the `CancelledError` is # handled correctly. 
tg = context_lib.task_group() - if tg is None: - output_queue = asyncio.Queue[ProcessorPart | None]() - - async def _with_context(): - async with context(): - try: - async for p in _normalize_part_stream( - self.call(content), producer=self.call - ): - output_queue.put_nowait(p) - finally: - output_queue.put_nowait(None) - - task = asyncio.create_task(_with_context()) - try: - async for p in streams.dequeue(output_queue): + + async with trace.call_scope(self.key_prefix) as current_trace: + + if current_trace: + + async def stream_input() -> AsyncIterable[ProcessorPart]: + async for part in normalized_content: + await current_trace.add_input(part) + yield part + + else: + stream_input = lambda: normalized_content + + if tg is None: + output_queue = asyncio.Queue[ProcessorPart | None]() + + async def _with_context(): + async with context(): + try: + async for p in _normalize_part_stream( + self.call(stream_input()), producer=self.call + ): + if current_trace: + await current_trace.add_output(p) + output_queue.put_nowait(p) + finally: + output_queue.put_nowait(None) + + task = asyncio.create_task(_with_context()) + try: + async for p in streams.dequeue(output_queue): + yield p + finally: + await task + else: + async for p in _normalize_part_stream( + self.call(stream_input()), producer=self.call + ): + if current_trace: + await current_trace.add_output(p) yield p - finally: - await task - else: - async for p in _normalize_part_stream( - self.call(content), producer=self.call - ): - yield p @abc.abstractmethod async def call( diff --git a/genai_processors/tests/content_api_test.py b/genai_processors/tests/content_api_test.py index 08cf668..86b82e8 100644 --- a/genai_processors/tests/content_api_test.py +++ b/genai_processors/tests/content_api_test.py @@ -427,6 +427,28 @@ def test_get_proto_message_raises_error_with_incorrect_mimetype(self): with self.assertRaises(ValueError): test_part.get_proto_message(struct_pb2.Struct) + def test_get_proto_message_from_bytes(self): 
+ test_proto = struct_pb2.Struct( + fields={'foo': struct_pb2.Value(string_value='bar')} + ) + part = content_api.ProcessorPart( + test_proto.SerializeToString(), + mimetype=mime_types.proto_message_mime_type(struct_pb2.Struct), + ) + self.assertEqual(part.get_proto_message(struct_pb2.Struct), test_proto) + + def test_get_proto_message_raises_error_with_incorrect_proto_type(self): + test_proto = struct_pb2.Struct( + fields={'foo': struct_pb2.Value(string_value='bar')} + ) + part = content_api.ProcessorPart.from_proto_message( + proto_message=test_proto + ) + with self.assertRaisesRegex( + ValueError, 'Part is not a Duration proto message.' + ): + part.get_proto_message(duration_pb2.Duration) + def test_is_text(self): text_part = content_api.ProcessorPart('hello') self.assertTrue(mime_types.is_text(text_part.mimetype)) diff --git a/genai_processors/tests/trace_file_test.py b/genai_processors/tests/trace_file_test.py new file mode 100644 index 0000000..88745a7 --- /dev/null +++ b/genai_processors/tests/trace_file_test.py @@ -0,0 +1,197 @@ +import asyncio +from collections.abc import AsyncIterable +import io +import json +import os +import shutil +from typing import cast +import unittest + +from absl.testing import absltest +from genai_processors import content_api +from genai_processors import mime_types +from genai_processors import processor +from genai_processors import streams +from genai_processors.dev import trace_file +import numpy as np +from PIL import Image +from scipy.io import wavfile + + +@processor.processor_function +async def to_upper_fn( + content: AsyncIterable[content_api.ProcessorPart], +) -> AsyncIterable[content_api.ProcessorPartTypes]: + async for part in content: + await asyncio.sleep(0.01) # to ensure timestamps are different + if mime_types.is_text(part.mimetype): + yield part.text.upper() + '_sub_trace' + else: + yield part + + +class SubTraceProcessor(processor.Processor): + + def __init__(self): + super().__init__() + self.sub_processor = 
to_upper_fn + + async def call( + self, content: AsyncIterable[content_api.ProcessorPart] + ) -> AsyncIterable[content_api.ProcessorPartTypes]: + async for part in self.sub_processor(content): + if isinstance(part, content_api.ProcessorPart) and mime_types.is_text( + part.mimetype + ): + yield part.text + '_outer' + else: + yield part + + +class TraceTest(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + super().setUp() + self.trace_dir = os.path.join(absltest.get_default_test_tmpdir(), 'traces') + os.makedirs(self.trace_dir, exist_ok=True) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.trace_dir) + + async def test_trace_generation_and_timestamps(self): + p = SubTraceProcessor() + input_parts = [content_api.ProcessorPart('hello')] + async with trace_file.SyncFileTrace(trace_dir=self.trace_dir): + results = await streams.gather_stream( + p(streams.stream_content(input_parts)) + ) + + self.assertEqual(results[0].text, 'HELLO_sub_trace_outer') + json_files = [f for f in os.listdir(self.trace_dir) if f.endswith('.json')] + self.assertTrue(len(json_files), 1) + trace_path = os.path.join(self.trace_dir, json_files[0]) + self.assertTrue(os.path.exists(trace_path.replace('.json', '.html'))) + + trace = trace_file.SyncFileTrace.load(trace_path) + + # First event is a subtrace for the upper function. This is what is first + # entered in the trace scope. + self.assertFalse(trace.events[0].is_input) + sub_trace = cast(trace_file.SyncFileTrace, trace.events[0].sub_trace) + self.assertIsNotNone(sub_trace) + self.assertIn('to_upper_fn', sub_trace.name) + self.assertFalse(sub_trace.events[1].is_input) + self.assertEqual( + sub_trace.events[1].part_dict['part']['text'], 'HELLO_sub_trace' + ) + self.assertIsNotNone(sub_trace.start_time) + self.assertIsNotNone(sub_trace.end_time) + self.assertLess(sub_trace.start_time, sub_trace.end_time) + + # Second input event is the input part to SubTraceProcessor. 
+ self.assertTrue(trace.events[1].is_input) + self.assertEqual(trace.events[1].part_dict['part']['text'], 'hello') + + # Third event is the output event of SubTraceProcessor + self.assertFalse(trace.events[2].is_input) + self.assertEqual( + trace.events[2].part_dict['part']['text'], 'HELLO_sub_trace_outer' + ) + + async def test_trace_references(self): + p = SubTraceProcessor() + input_part = content_api.ProcessorPart('world') + # First call + async with trace_file.SyncFileTrace(trace_dir=self.trace_dir): + await streams.gather_stream(p(streams.stream_content([input_part]))) + + json_files = [f for f in os.listdir(self.trace_dir) if f.endswith('.json')] + self.assertTrue(len(json_files), 1) + trace1_path = os.path.join(self.trace_dir, json_files[0]) + trace1 = trace_file.SyncFileTrace.load(trace1_path) + self.assertTrue(trace1.events[1].is_input) + self.assertEqual(trace1.events[1].part_dict['part']['text'], 'world') + + sub_trace1 = cast(trace_file.SyncFileTrace, trace1.events[0].sub_trace) + self.assertIsNotNone(sub_trace1) + self.assertTrue(sub_trace1.events[0].is_input) + self.assertIsNotNone(sub_trace1.events[0].part_dict) + + # Second call with same part + for f in os.listdir(self.trace_dir): + os.remove(os.path.join(self.trace_dir, f)) + async with trace_file.SyncFileTrace(trace_dir=self.trace_dir): + await streams.gather_stream(p(streams.stream_content([input_part]))) + json_files = [f for f in os.listdir(self.trace_dir) if f.endswith('.json')] + self.assertTrue(len(json_files), 1) + trace2_path = os.path.join(self.trace_dir, json_files[0]) + trace2 = trace_file.SyncFileTrace.load(trace2_path) + self.assertTrue(trace2.events[1].is_input) + self.assertIsNotNone(trace2.events[1].part_dict) + + async def test_trace_save_load(self): + trace = trace_file.SyncFileTrace(name='test') + await trace.add_input(content_api.ProcessorPart('in')) + await trace.add_input( + content_api.ProcessorPart.from_bytes( + data=b'bytes', + mimetype='image/jpeg', + ) + ) + sub_trace 
= await trace.add_sub_trace(name='sub_test') + await sub_trace.add_input(content_api.ProcessorPart('sub_in')) + await sub_trace.add_output(content_api.ProcessorPart('sub_out')) + await trace.add_output(content_api.ProcessorPart('out')) + + tmpdir = absltest.get_default_test_tmpdir() + trace_path = os.path.join(tmpdir, 'trace.json') + + trace.save(trace_path) + loaded_trace = trace_file.SyncFileTrace.load(trace_path) + + self.assertEqual( + json.loads(trace.model_dump_json()), + json.loads(loaded_trace.model_dump_json()), + ) + + async def test_save_html(self): + p = SubTraceProcessor() + trace_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR') + + # Create a small green image using PIL + img = Image.new('RGB', (10, 10), color='green') + img_bytes_io = io.BytesIO() + img.save(img_bytes_io, format='PNG') + img_part = content_api.ProcessorPart.from_bytes( + data=img_bytes_io.getvalue(), + mimetype='image/png', + ) + + # Generate a small random WAV audio part + sample_rate = 16000 # samples per second + duration = 0.1 # seconds + num_samples = int(sample_rate * duration) + # Generate random samples between -1 and 1 + random_samples = np.random.uniform(-1, 1, num_samples) + # Scale to int16 range + audio_data = (random_samples * 32767).astype(np.int16) + + audio_bytes_io = io.BytesIO() + wavfile.write(audio_bytes_io, sample_rate, audio_data) + audio_part = content_api.ProcessorPart.from_bytes( + data=audio_bytes_io.getvalue(), + mimetype='audio/wav', + ) + parts = [img_part, audio_part, content_api.ProcessorPart('hello')] + async with trace_file.SyncFileTrace(trace_dir=trace_dir): + await processor.apply_async(p, parts) + + html_files = [f for f in os.listdir(trace_dir) if f.endswith('.html')] + self.assertTrue(len(html_files), 1) + trace_path = os.path.join(trace_dir, html_files[0]) + self.assertTrue(os.path.exists(trace_path)) + + +if __name__ == '__main__': + absltest.main()