Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ ignore_missing_imports = false

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"]
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]
exclude = ["**/*.md"]

[tool.ruff.lint]
Expand All @@ -80,6 +80,7 @@ select = [
"G", # logging format
"I", # isort
"LOG", # logging
"UP", # pyupgrade
]

[tool.ruff.lint.per-file-ignores]
Expand Down
7 changes: 3 additions & 4 deletions scripts/bump_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple


def get_current_version() -> str:
Expand All @@ -16,7 +15,7 @@ def get_current_version() -> str:
return match.group(1)


def parse_version(version: str) -> Tuple[int, int, int, Optional[str]]:
def parse_version(version: str) -> tuple[int, int, int, str | None]:
"""Parse semantic version string."""
match = re.match(r"(\d+)\.(\d+)\.(\d+)(?:-(.+))?", version)
if not match:
Expand Down Expand Up @@ -135,7 +134,7 @@ def format_git_log(git_log: str) -> str:
return "\n\n".join(sections)


def get_git_log(since_tag: Optional[str] = None) -> str:
def get_git_log(since_tag: str | None = None) -> str:
"""Get git commit messages since last tag."""
cmd = ["git", "log", "--pretty=format:- %s (%h)"]
if since_tag:
Expand Down Expand Up @@ -219,7 +218,7 @@ def main():
print(f"\n✓ Version bumped from {current} to {new}")
print("\nNext steps:")
print("1. Review changes: git diff")
print("2. Commit: git add -A && git commit -m 'chore: bump version to {}'".format(new))
print(f"2. Commit: git add -A && git commit -m 'chore: bump version to {new}'")
print("3. Create PR or push to trigger release workflow")

except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import logging
from typing import Any, List, Optional
from typing import Any

import boto3
from botocore.config import Config
Expand Down Expand Up @@ -40,7 +40,7 @@ def _validate_spans(spans):
return hasattr(first_span, "context") and hasattr(first_span, "instrumentation_scope")


def _is_adot_format(spans: List[Any]) -> bool:
def _is_adot_format(spans: list[Any]) -> bool:
"""Check if spans are already in ADOT format.

ADOT format is detected by presence of 'scope' dict with 'name' field.
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
evaluator_id: str,
region: str = DEFAULT_REGION,
test_pass_score: float = 0.7,
config: Optional[Config] = None,
config: Config | None = None,
):
"""Initialize the evaluator.

Expand All @@ -109,7 +109,7 @@ def _get_default_config() -> Config:
read_timeout=300,
)

def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[EvaluationOutput]:
def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
"""Evaluate agent output using AgentCore Evaluation API.

Args:
Expand Down Expand Up @@ -153,7 +153,7 @@ def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[Eva
for r in response["evaluationResults"]
]

async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[EvaluationOutput]:
async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
"""Evaluate agent output asynchronously using AgentCore Evaluation API.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +27,7 @@ class SpanMetadata:

trace_id: str
span_id: str
parent_span_id: Optional[str]
parent_span_id: str | None
name: str
start_time: int
end_time: int
Expand All @@ -41,7 +41,7 @@ class SpanMetadata:
class ResourceInfo:
"""Span resource and scope information."""

resource_attributes: Dict[str, Any]
resource_attributes: dict[str, Any]
scope_name: str
scope_version: str

Expand All @@ -51,8 +51,8 @@ class ConversationTurn:
"""A single user-assistant conversation turn."""

user_message: str
assistant_messages: List[Dict[str, Any]]
tool_results: List[str]
assistant_messages: list[dict[str, Any]]
tool_results: list[str]


@dataclass
Expand Down Expand Up @@ -115,7 +115,7 @@ def extract_resource_info(span) -> ResourceInfo:
)

@staticmethod
def get_span_attributes(span) -> Dict[str, Any]:
def get_span_attributes(span) -> dict[str, Any]:
"""Safely extract span attributes."""
return dict(span.attributes) if hasattr(span, "attributes") and span.attributes else {}

Expand All @@ -140,8 +140,8 @@ class ADOTDocumentBuilder:
def build_span_document(
metadata: SpanMetadata,
resource_info: ResourceInfo,
attributes: Dict[str, Any],
) -> Dict[str, Any]:
attributes: dict[str, Any],
) -> dict[str, Any]:
"""Build ADOT span document."""
return {
"resource": {"attributes": resource_info.resource_attributes},
Expand All @@ -167,8 +167,8 @@ def _build_log_record_base(
cls,
metadata: SpanMetadata,
resource_info: ResourceInfo,
body: Dict[str, Any],
) -> Dict[str, Any]:
body: dict[str, Any],
) -> dict[str, Any]:
"""Build base ADOT log record structure shared by all log types."""
return {
"resource": {"attributes": resource_info.resource_attributes},
Expand All @@ -190,7 +190,7 @@ def build_conversation_log_record(
conversation: ConversationTurn,
metadata: SpanMetadata,
resource_info: ResourceInfo,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Build ADOT log record for conversation turn."""
output_messages = []
for i, msg in enumerate(conversation.assistant_messages):
Expand All @@ -217,7 +217,7 @@ def build_tool_log_record(
tool_exec: ToolExecution,
metadata: SpanMetadata,
resource_info: ResourceInfo,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Build ADOT log record for tool execution."""
body = {
"output": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

import logging
from typing import Any, Dict, List, Optional
from typing import Any

from .adot_models import (
ADOTDocumentBuilder,
Expand All @@ -36,7 +36,7 @@ class StrandsEventParser:
EVENT_TOOL_MESSAGE = "gen_ai.tool.message"

@classmethod
def extract_conversation_turn(cls, events: List[Any]) -> Optional[ConversationTurn]:
def extract_conversation_turn(cls, events: list[Any]) -> ConversationTurn | None:
"""Extract conversation turn from Strands span events."""
user_message = None
assistant_messages = []
Expand Down Expand Up @@ -83,7 +83,7 @@ def extract_conversation_turn(cls, events: List[Any]) -> Optional[ConversationTu
return None

@classmethod
def extract_tool_execution(cls, events: List[Any]) -> Optional[ToolExecution]:
def extract_tool_execution(cls, events: list[Any]) -> ToolExecution | None:
"""Extract tool execution from Strands span events."""
tool_input = ""
tool_output = ""
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self):
self.event_parser = StrandsEventParser()
self.doc_builder = ADOTDocumentBuilder()

def convert_span(self, span) -> List[Dict[str, Any]]:
def convert_span(self, span) -> list[dict[str, Any]]:
"""Convert a single span to ADOT documents."""
documents = []

Expand Down Expand Up @@ -160,7 +160,7 @@ def convert_span(self, span) -> List[Dict[str, Any]]:

return documents

def convert(self, raw_spans: List[Any]) -> List[Dict[str, Any]]:
def convert(self, raw_spans: list[Any]) -> list[dict[str, Any]]:
"""Convert list of Strands OTel spans to ADOT documents."""
documents = []
for span in raw_spans:
Expand All @@ -174,7 +174,7 @@ def convert(self, raw_spans: List[Any]) -> List[Dict[str, Any]]:
# ==============================================================================


def convert_strands_to_adot(raw_spans: List[Any]) -> List[Dict[str, Any]]:
def convert_strands_to_adot(raw_spans: list[Any]) -> list[dict[str, Any]]:
"""Convert Strands OTel spans to ADOT format for AgentCore evaluation.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import time
from datetime import datetime
from typing import Any, List
from typing import Any

import boto3

Expand Down Expand Up @@ -43,7 +43,7 @@ def query_log_group(
session_id: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
) -> list[dict]:
"""Query a single CloudWatch log group for session data.

Args:
Expand Down Expand Up @@ -133,7 +133,7 @@ def fetch_spans(
session_id: str,
event_log_group: str,
start_time: datetime,
) -> List[dict]:
) -> list[dict]:
"""Fetch ADOT spans from CloudWatch with configurable event log group.

ADOT spans are always fetched from aws/spans. Event logs can be fetched from
Expand Down Expand Up @@ -185,7 +185,7 @@ def fetch_spans_from_cloudwatch(
event_log_group: str,
start_time: datetime,
region: str = DEFAULT_REGION,
) -> List[dict]:
) -> list[dict]:
"""Fetch ADOT spans from CloudWatch with configurable event log group.

Convenience function that creates a CloudWatchSpanFetcher and fetches spans.
Expand Down
23 changes: 12 additions & 11 deletions src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import contextvars
import logging
import os
from collections.abc import Callable
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional
from typing import Any, Literal

import boto3
from botocore.exceptions import ClientError
Expand All @@ -23,14 +24,14 @@ def requires_access_token(
*,
provider_name: str,
into: str = "access_token",
scopes: List[str],
on_auth_url: Optional[Callable[[str], Any]] = None,
scopes: list[str],
on_auth_url: Callable[[str], Any] | None = None,
auth_flow: Literal["M2M", "USER_FEDERATION"],
callback_url: Optional[str] = None,
callback_url: str | None = None,
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
custom_state: Optional[str] = None,
custom_parameters: Optional[Dict[str, str]] = None,
token_poller: TokenPoller | None = None,
custom_state: str | None = None,
custom_parameters: dict[str, str] | None = None,
) -> Callable:
"""Decorator that fetches an OAuth2 access token before calling the decorated function.

Expand Down Expand Up @@ -103,10 +104,10 @@ def sync_wrapper(*args: Any, **kwargs_func: Any) -> Any:

def requires_iam_access_token(
*,
audience: List[str],
audience: list[str],
signing_algorithm: str = "ES384",
duration_seconds: int = 300,
tags: Optional[List[Dict[str, str]]] = None,
tags: list[dict[str, str]] | None = None,
into: str = "access_token",
) -> Callable:
"""Decorator that fetches an AWS IAM JWT token before calling the decorated function.
Expand Down Expand Up @@ -268,7 +269,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator


def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]):
def _get_oauth2_callback_url(user_provided_oauth2_callback_url: str | None):
if user_provided_oauth2_callback_url:
return user_provided_oauth2_callback_url

Expand Down Expand Up @@ -301,7 +302,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:
config = {}
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as file:
with open(config_path, encoding="utf-8") as file:
config = json.load(file) or {}
except Exception:
print("Could not find existing workload identity and user id")
Expand Down
Loading