Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/scorecard_ai/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._helpers import (
SystemOptions,
run_and_evaluate,
async_run_and_evaluate,
)
Expand All @@ -18,6 +19,7 @@
__all__ = [
"run_and_evaluate",
"async_run_and_evaluate",
"SystemOptions",
"StopCheck",
"StopChecks",
"ChatMessage",
Expand Down
57 changes: 51 additions & 6 deletions src/scorecard_ai/lib/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from __future__ import annotations

import uuid
import asyncio
import inspect
from typing import Any, Dict, List, TypeVar, Callable, Coroutine
from collections.abc import Generator, AsyncGenerator
from typing_extensions import TypedDict
Expand All @@ -20,6 +22,31 @@
_T = TypeVar("_T")


class SystemOptions(TypedDict):
"""Options passed to the system function for each testcase execution."""

otel_link_id: str
"""A unique ID for linking this execution with its OpenTelemetry trace.
Set this as an attribute on your OTel span (e.g. ``scorecard.otel_link_id``)
to deduplicate SDK records with trace-created records."""


def _system_accepts_options(system: Callable[..., Any]) -> bool:
"""Check if the system function accepts a third 'options' parameter."""
try:
sig = inspect.signature(system)
params = list(sig.parameters.values())
# Accepts options if it has 3+ positional params
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
positional = [p for p in params if p.kind in positional_kinds]
return len(positional) >= 3
except (ValueError, TypeError):
return False


def _omit_if_not_given(value: _T | NotGiven) -> _T | Omit:
"""
Converts NotGiven sentinel to Omit sentinel for API calls.
Expand Down Expand Up @@ -82,7 +109,7 @@ def run_and_evaluate(
testset_id: str | NotGiven = NOT_GIVEN,
testcases: List[Testcase] | List[SimpleTestcase] | NotGiven = NOT_GIVEN,
system_version_id: str | NotGiven = NOT_GIVEN,
system: Callable[[SystemInput, SystemVersion | None], SystemOutput],
system: Callable[..., SystemOutput],
trials: int = 1,
) -> RunResponse:
"""
Expand All @@ -103,7 +130,8 @@ def run_and_evaluate(

system_version_id: The ID of the SystemVersion to use for the run.

system: The system to run on the Testset.
system: The system to run on the Testset. Receives the testcase input and system version (or None).
Optionally accepts a third ``SystemOptions`` argument containing ``otel_link_id`` for trace deduplication.

trials: The number of times to run the system on each Testcase.
"""
Expand Down Expand Up @@ -131,16 +159,24 @@ def run_and_evaluate(
client.systems.versions.get(system_version_id) if not isinstance(system_version_id, NotGiven) else None
)

accepts_options = _system_accepts_options(system)

# Run each Testcase sequentially
for testcase in testcase_iter:
for _ in range(trials):
model_response = system(testcase["inputs"], system_version)
otel_link_id = str(uuid.uuid4())
options = SystemOptions(otel_link_id=otel_link_id)
if accepts_options:
model_response = system(testcase["inputs"], system_version, options)
else:
model_response = system(testcase["inputs"], system_version)
client.records.create(
run_id=run.id,
testcase_id=_omit_if_not_given(testcase["id"]),
inputs=testcase["inputs"],
expected=testcase["expected"],
outputs=model_response,
extra_body={"otelLinkId": otel_link_id},
)

return RunResponse(id=run.id, url=_get_run_url(client, project_id, run.id))
Expand All @@ -154,7 +190,7 @@ async def async_run_and_evaluate(
testset_id: str | NotGiven = NOT_GIVEN,
testcases: List[Testcase] | List[SimpleTestcase] | NotGiven = NOT_GIVEN,
system_version_id: str | NotGiven = NOT_GIVEN,
system: Callable[[SystemInput, SystemVersion | None], SystemOutput],
system: Callable[..., SystemOutput],
trials: int = 1,
) -> RunResponse:
"""
Expand All @@ -175,7 +211,8 @@ async def async_run_and_evaluate(

system_version_id: The ID of the SystemVersion to use for the run.

system: The system to run on the Testset.
system: The system to run on the Testset. Receives the testcase input and system version (or None).
Optionally accepts a third ``SystemOptions`` argument containing ``otel_link_id`` for trace deduplication.

trials: The number of times to run the system on each Testcase.
"""
Expand Down Expand Up @@ -203,16 +240,24 @@ async def async_run_and_evaluate(
await client.systems.versions.get(system_version_id) if not isinstance(system_version_id, NotGiven) else None
)

accepts_options = _system_accepts_options(system)

def run_testcase(
testcase: _SimpleTestcaseWithId,
) -> Coroutine[Any, Any, Record]:
model_response = system(testcase["inputs"], system_version)
otel_link_id = str(uuid.uuid4())
options = SystemOptions(otel_link_id=otel_link_id)
if accepts_options:
model_response = system(testcase["inputs"], system_version, options)
else:
model_response = system(testcase["inputs"], system_version)
return client.records.create(
run_id=run.id,
testcase_id=_omit_if_not_given(testcase["id"]),
inputs=testcase["inputs"],
expected=testcase["expected"],
outputs=model_response,
extra_body={"otelLinkId": otel_link_id},
)

# Create a Record for each Testcase
Expand Down