diff --git a/src/scorecard_ai/lib/__init__.py b/src/scorecard_ai/lib/__init__.py index b0b981e..2e9aabf 100644 --- a/src/scorecard_ai/lib/__init__.py +++ b/src/scorecard_ai/lib/__init__.py @@ -1,4 +1,5 @@ from ._helpers import ( + SystemOptions, run_and_evaluate, async_run_and_evaluate, ) @@ -18,6 +19,7 @@ __all__ = [ "run_and_evaluate", "async_run_and_evaluate", + "SystemOptions", "StopCheck", "StopChecks", "ChatMessage", diff --git a/src/scorecard_ai/lib/_helpers.py b/src/scorecard_ai/lib/_helpers.py index bca11ff..6341e3a 100644 --- a/src/scorecard_ai/lib/_helpers.py +++ b/src/scorecard_ai/lib/_helpers.py @@ -4,7 +4,9 @@ from __future__ import annotations +import uuid import asyncio +import inspect from typing import Any, Dict, List, TypeVar, Callable, Coroutine from collections.abc import Generator, AsyncGenerator from typing_extensions import TypedDict @@ -20,6 +22,31 @@ _T = TypeVar("_T") +class SystemOptions(TypedDict): + """Options passed to the system function for each testcase execution.""" + + otel_link_id: str + """A unique ID for linking this execution with its OpenTelemetry trace. + Set this as an attribute on your OTel span (e.g. ``scorecard.otel_link_id``) + to deduplicate SDK records with trace-created records.""" + + +def _system_accepts_options(system: Callable[..., Any]) -> bool: + """Check if the system function accepts a third 'options' parameter.""" + try: + sig = inspect.signature(system) + params = list(sig.parameters.values()) + # Accepts options if it has 3+ positional params + positional_kinds = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + positional = [p for p in params if p.kind in positional_kinds] + return len(positional) >= 3 + except (ValueError, TypeError): + return False + + def _omit_if_not_given(value: _T | NotGiven) -> _T | Omit: """ Converts NotGiven sentinel to Omit sentinel for API calls. @@ -82,7 +109,7 @@ def run_and_evaluate( testset_id: str | NotGiven = NOT_GIVEN, testcases: List[Testcase] | List[SimpleTestcase] | NotGiven = NOT_GIVEN, system_version_id: str | NotGiven = NOT_GIVEN, - system: Callable[[SystemInput, SystemVersion | None], SystemOutput], + system: Callable[..., SystemOutput], trials: int = 1, ) -> RunResponse: """ @@ -103,7 +130,8 @@ def run_and_evaluate( system_version_id: The ID of the SystemVersion to use for the run. - system: The system to run on the Testset. + system: The system to run on the Testset. Receives the testcase input and system version (or None). + Optionally accepts a third ``SystemOptions`` argument containing ``otel_link_id`` for trace deduplication. trials: The number of times to run the system on each Testcase. """ @@ -131,16 +159,24 @@ def run_and_evaluate( client.systems.versions.get(system_version_id) if not isinstance(system_version_id, NotGiven) else None ) + accepts_options = _system_accepts_options(system) + # Run each Testcase sequentially for testcase in testcase_iter: for _ in range(trials): - model_response = system(testcase["inputs"], system_version) + otel_link_id = str(uuid.uuid4()) + options = SystemOptions(otel_link_id=otel_link_id) + if accepts_options: + model_response = system(testcase["inputs"], system_version, options) + else: + model_response = system(testcase["inputs"], system_version) client.records.create( run_id=run.id, testcase_id=_omit_if_not_given(testcase["id"]), inputs=testcase["inputs"], expected=testcase["expected"], outputs=model_response, + extra_body={"otelLinkId": otel_link_id}, ) return RunResponse(id=run.id, url=_get_run_url(client, project_id, run.id)) @@ -154,7 +190,7 @@ async def async_run_and_evaluate( testset_id: str | NotGiven = NOT_GIVEN, testcases: List[Testcase] | List[SimpleTestcase] | NotGiven = NOT_GIVEN, system_version_id: str | NotGiven = NOT_GIVEN, - system: Callable[[SystemInput, SystemVersion | None], SystemOutput], + system: Callable[..., SystemOutput], trials: int = 1, ) -> RunResponse: """ @@ -175,7 +211,8 @@ async def async_run_and_evaluate( system_version_id: The ID of the SystemVersion to use for the run. - system: The system to run on the Testset. + system: The system to run on the Testset. Receives the testcase input and system version (or None). + Optionally accepts a third ``SystemOptions`` argument containing ``otel_link_id`` for trace deduplication. trials: The number of times to run the system on each Testcase. """ @@ -203,16 +240,24 @@ async def async_run_and_evaluate( await client.systems.versions.get(system_version_id) if not isinstance(system_version_id, NotGiven) else None ) + accepts_options = _system_accepts_options(system) + def run_testcase( testcase: _SimpleTestcaseWithId, ) -> Coroutine[Any, Any, Record]: - model_response = system(testcase["inputs"], system_version) + otel_link_id = str(uuid.uuid4()) + options = SystemOptions(otel_link_id=otel_link_id) + if accepts_options: + model_response = system(testcase["inputs"], system_version, options) + else: + model_response = system(testcase["inputs"], system_version) return client.records.create( run_id=run.id, testcase_id=_omit_if_not_given(testcase["id"]), inputs=testcase["inputs"], expected=testcase["expected"], outputs=model_response, + extra_body={"otelLinkId": otel_link_id}, ) # Create a Record for each Testcase