From e0d0678b2cd445d37f8e05ef0d214c4bfacaab3e Mon Sep 17 00:00:00 2001 From: shegx01 Date: Wed, 10 Dec 2025 17:11:32 +0200 Subject: [PATCH 1/6] feat: Amazon Bedrock adapter support -- :json_schema --- .vscode/settings.json | 3 + lib/instructor/adapters/bedrock.ex | 522 ++++++++++++++++++++++ lib/instructor/aws_event_stream_parser.ex | 148 ++++++ test/aws_event_stream_parser_test.exs | 121 +++++ test/instructor_test.exs | 26 +- test/support/test_helpers.ex | 79 ++++ test/test_helper.exs | 7 +- 7 files changed, 900 insertions(+), 6 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 lib/instructor/adapters/bedrock.ex create mode 100644 lib/instructor/aws_event_stream_parser.ex create mode 100644 test/aws_event_stream_parser_test.exs diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..a8b67ce --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "cSpell.enabled": false +} diff --git a/lib/instructor/adapters/bedrock.ex b/lib/instructor/adapters/bedrock.ex new file mode 100644 index 0000000..f4febad --- /dev/null +++ b/lib/instructor/adapters/bedrock.ex @@ -0,0 +1,522 @@ +defmodule Instructor.Adapters.Bedrock do + @moduledoc """ + AWS Bedrock adapter for Instructor using the Converse API. + + Uses the unified Converse API which provides a consistent interface across + all Bedrock models (Claude, Llama, Titan, Mistral, etc.). + + ## Configuration + + Configure the Bedrock adapter with bearer token authentication: + + config :instructor, + adapter: Instructor.Adapters.Bedrock, + bedrock: [ + api_key: "your_bearer_token", # defaults to AWS_BEARER_TOKEN_BEDROCK env var + auth_mode: :bearer, # authentication mode (default: :bearer) + http_options: [receive_timeout: 60_000] + ] + + The region is **auto-detected from the bearer token**. You can override it: + + config :instructor, + adapter: Instructor.Adapters.Bedrock, + bedrock: [ + region: "us-west-2", # optional: overrides auto-detection + api_key: "your_bearer_token" + ] + + Region priority: explicit config > AWS_REGION env var > auto-detected from token > "us-east-1" + + You can also provide a custom runtime URL or dynamic token: + + config :instructor, + adapter: Instructor.Adapters.Bedrock, + bedrock: [ + runtime_url: "https://custom-bedrock-endpoint.example.com", + api_key: fn -> get_dynamic_token() end + ] + + ## Usage + + Instructor.chat_completion( + model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + response_model: MySchema, + messages: [%{role: "user", content: "..."}] + ) + + ## Supported Models + + All models that support the Bedrock Converse API: + - **Anthropic Claude**: `anthropic.claude-*` + - **Meta Llama**: `meta.llama*` + - **Amazon Titan**: `amazon.titan-*` + - **Cohere**: `cohere.*` + - **Mistral**: `mistral.*` + """ + + @behaviour Instructor.Adapter + + alias Instructor.AWSEventStreamParser + + @supported_modes [:tools, :json, :md_json] + @default_max_tokens 4096 + + @impl true + def chat_completion(params, user_config \\ nil) do + config = config(user_config) + model_id = Keyword.fetch!(params, :model) + messages = Keyword.fetch!(params, :messages) + mode = Keyword.get(params, :mode, :tools) + max_tokens = Keyword.get(params, :max_tokens, @default_max_tokens) + temperature = Keyword.get(params, :temperature, 1.0) + tools = Keyword.get(params, :tools, []) + stream = Keyword.get(params, :stream, false) + + if mode not in @supported_modes do + raise "Unsupported Bedrock mode #{mode}. Supported modes: #{inspect(@supported_modes)}" + end + + body = build_converse_body(messages, max_tokens, temperature, tools) + + if stream do + do_streaming_chat_completion(mode, model_id, body, config) + else + do_chat_completion(mode, model_id, body, tools, config) + end + end + + defp do_chat_completion(mode, model_id, body, tools, config) do + case converse(model_id, body, config) do + {:ok, response} -> + parse_response(mode, response, tools) + + {:error, _reason} = error -> + error + end + end + + defp do_streaming_chat_completion(mode, model_id, body, config) do + pid = self() + ref = make_ref() + url = build_stream_url(model_id, config) + options = build_stream_request_options(body, config) + + Stream.resource( + fn -> + Task.async(fn -> + options = + Keyword.merge(options, + into: fn {:data, data}, {req, resp} -> + send(pid, {ref, data}) + {:cont, {req, resp}} + end + ) + + Req.post(url, options) + send(pid, {ref, :done}) + end) + end, + fn task -> + receive do + {^ref, :done} -> + {:halt, task} + + {^ref, data} -> + {[data], task} + after + 30_000 -> + raise "Timeout waiting for Bedrock streaming response" + end + end, + fn _ -> nil end + ) + |> AWSEventStreamParser.parse() + |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) + end + + @impl true + def reask_messages(raw_response, params, _config) do + reask_messages_for_mode(params[:mode], raw_response) + end + + defp reask_messages_for_mode(:tools, %{ + "choices" => [ + %{ + "message" => %{ + "tool_calls" => [ + %{ + "id" => tool_call_id, + "function" => %{"name" => _name, "arguments" => args} + } + ] + } + } + ] + }) do + assistant_message = %{ + role: "assistant", + __bedrock_tool_use__: %{ + "toolUseId" => tool_call_id, + "input" => args + } + } + + tool_result_message = %{ + role: "user", + __bedrock_tool_result__: %{ + "toolUseId" => tool_call_id, + "content" => args + } + } + + [assistant_message, tool_result_message] + end + + defp reask_messages_for_mode(_mode, _raw_response), do: [] + + # Converse API Request Building + defp build_converse_body(messages, max_tokens, temperature, tools) do + {system_messages, user_messages} = extract_system_messages(messages) + + %{ + "messages" => format_messages(user_messages), + "inferenceConfig" => %{ + "maxTokens" => max_tokens, + "temperature" => temperature + } + } + |> maybe_add_system(system_messages) + |> maybe_add_tools(tools) + end + + defp extract_system_messages(messages) do + Enum.split_with(messages, &is_system_message?/1) + end + + defp is_system_message?(%{role: role}), do: role in ["system", :system] + defp is_system_message?(%{"role" => "system"}), do: true + defp is_system_message?(_), do: false + + defp format_messages(messages) do + Enum.map(messages, &format_message/1) + end + + defp format_message(%{__bedrock_tool_use__: tool_use} = msg) do + input = + case tool_use["input"] do + input when is_binary(input) -> Jason.decode!(input) + input when is_map(input) -> input + end + + %{ + "role" => to_string(msg.role), + "content" => [ + %{ + "toolUse" => %{ + "toolUseId" => tool_use["toolUseId"], + "name" => "Schema", + "input" => input + } + } + ] + } + end + + defp format_message(%{__bedrock_tool_result__: tool_result} = msg) do + %{ + "role" => to_string(msg.role), + "content" => [ + %{ + "toolResult" => %{ + "toolUseId" => tool_result["toolUseId"], + "content" => [%{"text" => tool_result["content"]}] + } + } + ] + } + end + + defp format_message(msg) do + %{ + "role" => get_role(msg), + "content" => [%{"text" => get_content(msg)}] + } + end + + defp get_role(%{role: role}), do: to_string(role) + defp get_role(%{"role" => role}), do: role + + defp get_content(%{content: content}), do: content + defp get_content(%{"content" => content}), do: content + + defp maybe_add_system(body, []), do: body + + defp maybe_add_system(body, system_messages) do + system_messages + |> Enum.map(&get_content/1) + |> Enum.map(&%{"text" => &1}) + |> then(&Map.put(body, "system", &1)) + end + + defp maybe_add_tools(body, []), do: body + + defp maybe_add_tools(body, tools) do + formatted_tools = Enum.map(tools, &format_tool/1) + + first_tool_name = + case formatted_tools do + [%{"toolSpec" => %{"name" => name}} | _] -> name + _ -> nil + end + + tool_config = %{ + "tools" => formatted_tools + } + + tool_config = + if first_tool_name do + Map.put(tool_config, "toolChoice", %{"tool" => %{"name" => first_tool_name}}) + else + tool_config + end + + Map.put(body, "toolConfig", tool_config) + end + + defp format_tool(tool) do + function = tool[:function] || tool["function"] + + %{ + "toolSpec" => %{ + "name" => function[:name] || function["name"], + "description" => function[:description] || function["description"], + "inputSchema" => %{ + "json" => function[:parameters] || function["parameters"] + } + } + } + end + + defp parse_response(mode, response, tools) do + raw_response = build_raw_response(response, tools) + + case parse_content_for_mode(mode, raw_response) do + {:ok, parsed} -> {:ok, raw_response, parsed} + {:error, _} = error -> error + end + end + + defp build_raw_response(response, tools) do + output = response["output"] || %{} + message = output["message"] || %{} + content = message["content"] || [] + stop_reason = response["stopReason"] + + tool_use = Enum.find(content, &(&1["toolUse"] != nil)) + + if tool_use && tools != [] do + tool_use_block = tool_use["toolUse"] + + %{ + "choices" => [ + %{ + "message" => %{ + "tool_calls" => [ + %{ + "id" => tool_use_block["toolUseId"], + "function" => %{ + "name" => tool_use_block["name"], + "arguments" => Jason.encode!(tool_use_block["input"]) + } + } + ] + }, + "finish_reason" => normalize_stop_reason(stop_reason) + } + ] + } + else + text_content = + content + |> Enum.filter(&(&1["text"] != nil)) + |> Enum.map(& &1["text"]) + |> Enum.join("") + + %{ + "choices" => [ + %{ + "message" => %{ + "content" => text_content + }, + "finish_reason" => normalize_stop_reason(stop_reason) + } + ] + } + end + end + + defp normalize_stop_reason("end_turn"), do: "stop" + defp normalize_stop_reason("tool_use"), do: "tool_calls" + defp normalize_stop_reason("max_tokens"), do: "length" + defp normalize_stop_reason(other), do: other + + defp parse_content_for_mode(:tools, %{ + "choices" => [ + %{"message" => %{"tool_calls" => [%{"function" => %{"arguments" => args}}]}} + ] + }) do + Jason.decode(args) + end + + defp parse_content_for_mode(:json, %{"choices" => [%{"message" => %{"content" => content}}]}) do + Jason.decode(content) + end + + defp parse_content_for_mode(:md_json, %{"choices" => [%{"message" => %{"content" => content}}]}) do + extract_json_from_markdown(content) + end + + defp parse_content_for_mode(mode, response) do + {:error, "Unsupported mode #{mode} with response #{inspect(response)}"} + end + + defp extract_json_from_markdown(content) do + case Regex.run(~r/```(?:json)?\s*([\s\S]*?)\s*```/, content) do + [_, json] -> Jason.decode(json) + nil -> Jason.decode(content) + end + end + + # --------------------------------------------------------- + # Streaming Response Parsing + # --------------------------------------------------------- + + # Tool use streaming - input comes as string chunks + defp parse_stream_chunk_for_mode(:tools, %{"delta" => %{"toolUse" => %{"input" => chunk}}}) do + chunk + end + + # Text streaming + defp parse_stream_chunk_for_mode(_mode, %{"delta" => %{"text" => chunk}}) do + chunk + end + + # Skip non-content events + defp parse_stream_chunk_for_mode(_mode, %{"role" => _}), do: "" + defp parse_stream_chunk_for_mode(_mode, %{"stopReason" => _}), do: "" + defp parse_stream_chunk_for_mode(_mode, %{"usage" => _}), do: "" + defp parse_stream_chunk_for_mode(_mode, %{"metrics" => _}), do: "" + defp parse_stream_chunk_for_mode(_mode, _), do: "" + + # --------------------------------------------------------- + # HTTP / Converse API + # --------------------------------------------------------- + + @doc false + def converse(model_id, body, config) do + url = build_url(model_id, config) + options = build_request_options(body, config) + + case Req.post(url, options) do + {:ok, %Req.Response{status: 200, body: response_body}} -> + {:ok, response_body} + + {:ok, %Req.Response{status: status, body: error_body}} -> + {:error, "Unexpected HTTP response code: #{status}\n#{inspect(error_body)}"} + + {:error, reason} -> + {:error, "Bedrock request failed: #{inspect(reason)}"} + end + end + + @doc false + def build_url(model_id, config) do + base_url = config[:runtime_url] || "https://bedrock-runtime.#{config[:region]}.amazonaws.com" + path = "/model/#{URI.encode(model_id)}/converse" + base_url <> path + end + + defp build_stream_url(model_id, config) do + base_url = config[:runtime_url] || "https://bedrock-runtime.#{config[:region]}.amazonaws.com" + path = "/model/#{URI.encode(model_id)}/converse-stream" + base_url <> path + end + + defp build_base_request_options(body, config, accept, extra_opts \\ []) do + http_options = Keyword.get(config, :http_options, []) + + Keyword.merge( + http_options, + [ + headers: %{ + "content-type" => "application/json", + "accept" => accept + }, + auth: auth_header(config), + json: body + ] ++ extra_opts + ) + end + + defp build_request_options(body, config) do + build_base_request_options(body, config, "application/json") + end + + defp build_stream_request_options(body, config) do + build_base_request_options(body, config, "application/vnd.amazon.eventstream", + decode_body: false + ) + end + + defp auth_header(config) do + case Keyword.get(config, :auth_mode, :bearer) do + :bearer -> {:bearer, api_key(config)} + end + end + + defp api_key(config) do + case Keyword.get(config, :api_key) do + fun when is_function(fun, 0) -> fun.() + key -> key + end + end + + defp config(nil), do: config(Application.get_env(:instructor, :bedrock, [])) + + defp config(base_config) do + api_key = Keyword.get(base_config, :api_key) || System.get_env("AWS_BEARER_TOKEN_BEDROCK") + + # Auto-detect region from token if not explicitly provided + region = + Keyword.get(base_config, :region) || + System.get_env("AWS_REGION") || + extract_region_from_token(api_key) + + default_config = [ + region: region, + runtime_url: nil, + api_key: api_key, + auth_mode: :bearer, + http_options: [receive_timeout: 60_000] + ] + + Keyword.merge(default_config, base_config) + end + + defp extract_region_from_token(nil), do: nil + + defp extract_region_from_token(token) when is_binary(token) do + # Token format: bedrock-api-key-{base64_encoded_content} + # The base64 content contains X-Amz-Credential with region info + with "bedrock-api-key-" <> base64_part <- token, + {:ok, decoded} <- Base.decode64(base64_part), + [_, region] <- Regex.run(~r/X-Amz-Credential=[^%]+%2F\d+%2F([^%]+)%2F/, decoded) do + URI.decode(region) + else + _ -> nil + end + end + + defp extract_region_from_token(_), do: nil +end diff --git a/lib/instructor/aws_event_stream_parser.ex b/lib/instructor/aws_event_stream_parser.ex new file mode 100644 index 0000000..c8022c9 --- /dev/null +++ b/lib/instructor/aws_event_stream_parser.ex @@ -0,0 +1,148 @@ +defmodule Instructor.AWSEventStreamParser do + @moduledoc """ + Parser for AWS Event Stream binary format (application/vnd.amazon.eventstream). + + Used by AWS Bedrock ConverseStream and other AWS streaming APIs. + + ## Event Stream Format + + Each message in the stream has the following structure: + - 4 bytes: total message length + - 4 bytes: headers length + - 4 bytes: prelude CRC32 + - Headers section (variable length) + - Payload section (variable length) + - 4 bytes: message CRC32 + + ## Header Format + + Each header consists of: + - 1 byte: header name length + - N bytes: header name + - 1 byte: header value type (7 = string) + - For string type: 2 bytes length + N bytes value + """ + + # AWS Event Stream message structure constants + @total_length_bytes 4 + @headers_length_bytes 4 + @prelude_crc_bytes 4 + @message_crc_bytes 4 + + @prelude_bytes @total_length_bytes + @headers_length_bytes + @prelude_crc_bytes + @fixed_overhead @prelude_bytes + @message_crc_bytes + + @doc """ + Parses a stream of binary chunks into decoded JSON events. + + Returns a stream of parsed JSON maps from the AWS Event Stream format. + """ + @spec parse(stream :: Enumerable.t()) :: Enumerable.t() + def parse(stream) do + Stream.transform( + stream, + fn -> <<>> end, + fn chunk, buf -> + parse_events(buf <> chunk, []) + end, + fn buf -> + case parse_events(buf, []) do + {events, _rest} -> {events, <<>>} + end + end, + fn _ -> nil end + ) + end + + defp parse_events(<<>>, acc), do: {Enum.reverse(acc), <<>>} + + defp parse_events(data, acc) when byte_size(data) < @prelude_bytes do + {Enum.reverse(acc), data} + end + + defp parse_events( + <> = data, + acc + ) do + payload_length = total_length - headers_length - @fixed_overhead + + if byte_size(rest) >= headers_length + payload_length + @message_crc_bytes do + <> = rest + + case parse_event_payload(headers_data, payload) do + {:ok, event} -> + parse_events(remaining, [event | acc]) + + :skip -> + parse_events(remaining, acc) + + {:error, _reason} -> + parse_events(remaining, acc) + end + else + {Enum.reverse(acc), data} + end + end + + defp parse_events(data, acc), do: {Enum.reverse(acc), data} + + defp parse_event_payload(headers_data, payload) do + headers = parse_headers(headers_data, %{}) + event_type = headers[":event-type"] + message_type = headers[":message-type"] + + cond do + message_type == "exception" -> + {:error, payload} + + # Accept all Bedrock Converse stream event types + message_type == "event" and byte_size(payload) > 0 -> + case Jason.decode(payload) do + {:ok, %{"bytes" => base64_bytes}} -> + {:ok, Jason.decode!(Base.decode64!(base64_bytes))} + + {:ok, json} -> + {:ok, json} + + {:error, _} -> + :skip + end + + # Legacy: also accept "chunk" event type for backwards compatibility + event_type in ["chunk", nil] and byte_size(payload) > 0 -> + case Jason.decode(payload) do + {:ok, %{"bytes" => base64_bytes}} -> + {:ok, Jason.decode!(Base.decode64!(base64_bytes))} + + {:ok, json} -> + {:ok, json} + + {:error, _} -> + :skip + end + + true -> + :skip + end + end + + defp parse_headers(<<>>, acc), do: acc + + defp parse_headers(<>, acc) do + <> = rest + + {value, rest} = + case type do + # Type 7 = string + 7 -> + <> = rest + {value, rest} + + _ -> + {nil, rest} + end + + parse_headers(rest, Map.put(acc, name, value)) + end +end diff --git a/test/aws_event_stream_parser_test.exs b/test/aws_event_stream_parser_test.exs new file mode 100644 index 0000000..5a6b3ef --- /dev/null +++ b/test/aws_event_stream_parser_test.exs @@ -0,0 +1,121 @@ +defmodule Instructor.AWSEventStreamParserTest do + use ExUnit.Case, async: true + + alias Instructor.AWSEventStreamParser + + # AWS Event Stream message structure (all fields are 32-bit / 4 bytes) + @total_length_bytes 4 + @headers_length_bytes 4 + @prelude_crc_bytes 4 + @message_crc_bytes 4 + @prelude_bytes @total_length_bytes + @headers_length_bytes + @prelude_crc_bytes + + describe "parse/1" do + test "parses a single event" do + payload = %{"text" => "Lorem ipsum"} + event_data = build_content_block_data(payload) + + result = + event_data + |> build_event_stream_message() + |> List.wrap() + |> Stream.concat([]) + |> AWSEventStreamParser.parse() + |> Enum.to_list() + + assert [%{"contentBlockDelta" => %{"delta" => ^payload}}] = result + end + + test "parses multiple events" do + event1 = build_event_stream_message(%{"messageStart" => %{"role" => "assistant"}}) + + event2 = + build_content_block_data() + |> build_event_stream_message() + + event3 = build_event_stream_message(%{"messageStop" => %{"stopReason" => "end_turn"}}) + + result = + [event1 <> event2 <> event3] + |> Stream.concat([]) + |> AWSEventStreamParser.parse() + |> Enum.to_list() + + assert length(result) == 3 + assert [%{"messageStart" => _}, %{"contentBlockDelta" => _}, %{"messageStop" => _}] = result + end + + test "handles chunked data across multiple stream elements" do + payload = %{"text" => "Lorem ipsum dolor sit amet"} + + event = + build_content_block_data(payload) + |> build_event_stream_message() + + {part1, part2} = String.split_at(event, div(byte_size(event), 2)) + + result = + [part1, part2] + |> Stream.concat([]) + |> AWSEventStreamParser.parse() + |> Enum.to_list() + + assert [%{"contentBlockDelta" => %{"delta" => ^payload}}] = result + end + + test "parses tool use events" do + event = + build_content_block_data(%{ + "toolUse" => %{"input" => ~s|{name: test}|} + }) + |> put_in(["contentBlockDelta", "delta", "contentBlockIndex"], 0) + |> build_event_stream_message() + + result = + [event] + |> Stream.concat([]) + |> AWSEventStreamParser.parse() + |> Enum.to_list() + + assert [%{"contentBlockDelta" => %{"delta" => %{"toolUse" => %{"input" => _}}}}] = result + end + end + + # Private Helper functions + defp build_event_stream_message(payload) do + json_payload = Jason.encode!(payload) + + headers = + encode_headers([ + {":event-type", "chunk"}, + {":content-type", "application/json"}, + {":message-type", "event"} + ]) + + headers_length = byte_size(headers) + payload_length = byte_size(json_payload) + total_length = @prelude_bytes + headers_length + payload_length + @message_crc_bytes + + # Build the message (using 0 for CRCs since our parser ignores them) + <> <> + headers <> + json_payload <> + <<0::32>> + end + + # AWS Event Stream binary format encoding + defp encode_headers(header_list) do + Enum.reduce(header_list, <<>>, fn {name, value}, acc -> + acc <> + <> <> + name <> + <<7::8>> <> + <> <> + value + end) + end + + defp build_content_block_data(payload \\ %{"text" => "Hello World"}) do + %{"contentBlockDelta" => %{"delta" => payload}} + end +end diff --git a/test/instructor_test.exs b/test/instructor_test.exs index 37f05c4..a8e6fdf 100644 --- a/test/instructor_test.exs +++ b/test/instructor_test.exs @@ -43,6 +43,16 @@ defmodule InstructorTest do :openai_mock -> Application.put_env(:instructor, :adapter, InstructorTest.MockOpenAI) + + :bedrock -> + Application.put_env(:instructor, :adapter, Instructor.Adapters.Bedrock) + + Application.put_env(:instructor, :bedrock, + api_key: System.fetch_env!("AWS_BEARER_TOKEN_BEDROCK") + ) + + :bedrock_mock -> + Application.put_env(:instructor, :adapter, InstructorTest.MockBedrock) end end @@ -50,16 +60,25 @@ defmodule InstructorTest do TestHelpers.mock_openai_response(mode, expected) end + def mock_response(:bedrock_mock, mode, expected) do + TestHelpers.mock_bedrock_response(mode, expected) + end + def mock_response(_, _, _), do: nil def mock_stream_response(:openai_mock, mode, expected) do TestHelpers.mock_openai_response_stream(mode, expected) end + def mock_stream_response(:bedrock_mock, mode, expected) do + TestHelpers.mock_bedrock_response_stream(mode, expected) + end + def mock_stream_response(_, _, _), do: nil for {adapter, params} <- [ {:openai_mock, [mode: :tools, model: "gpt-4.1-mini"]}, + {:bedrock_mock, [mode: :tools, model: "anthropic.claude-3-5-sonnet-20240620-v1:0"]}, {:openai, [mode: :tools, model: "gpt-4.1-mini"]}, {:openai, [mode: :json, model: "gpt-4.1-mini"]}, {:openai, [mode: :json_schema, model: "gpt-4.1-mini"]}, @@ -69,7 +88,9 @@ defmodule InstructorTest do {:xai, [mode: :tools, model: "grok-2-latest"]}, {:xai, [mode: :json_schema, model: "grok-2-latest"]}, {:ollama, [mode: :tools, model: "llama3.1"]}, - {:anthropic, [mode: :tools, model: "claude-3-5-sonnet-20240620", max_tokens: 1024]} + {:anthropic, [mode: :tools, model: "claude-3-5-sonnet-20240620", max_tokens: 1024]}, + {:bedrock, + [mode: :tools, model: "anthropic.claude-3-5-sonnet-20240620-v1:0", max_tokens: 1024]} ] do describe "#{inspect(adapter)} #{params[:mode]} #{params[:model]}" do @tag adapter: adapter @@ -369,7 +390,6 @@ defmodule InstructorTest do field(:number, :integer) end - def validate_changeset(changeset) do changeset |> Ecto.Changeset.validate_change(:number, fn :number, number -> @@ -386,7 +406,6 @@ defmodule InstructorTest do test "reask" do mock_response(unquote(adapter), :tools, %{number: 11}) - result = Instructor.chat_completion( Keyword.merge(unquote(params), @@ -400,7 +419,6 @@ defmodule InstructorTest do assert {:ok, %{number: number}} = result assert number >= 10 - end end end diff --git a/test/support/test_helpers.ex b/test/support/test_helpers.ex index a27bffa..b3934e2 100644 --- a/test/support/test_helpers.ex +++ b/test/support/test_helpers.ex @@ -111,6 +111,85 @@ defmodule Instructor.TestHelpers do end) end + # --------------------------------------------------------- + # Bedrock Mock Helpers (Converse API format) + # --------------------------------------------------------- + + def mock_bedrock_response(:tools, result) do + InstructorTest.MockBedrock + |> expect(:chat_completion, fn _params, _config -> + {:ok, + %{ + "choices" => [ + %{ + "finish_reason" => "tool_calls", + "message" => %{ + "tool_calls" => [ + %{ + "id" => "tooluse_e8Civ0HoDy", + "function" => %{ + "arguments" => Jason.encode!(result), + "name" => "schema" + } + } + ] + } + } + ] + }, result} + end) + end + + def mock_bedrock_response(mode, result) when mode in [:json, :md_json] do + InstructorTest.MockBedrock + |> expect(:chat_completion, fn _params, _config -> + {:ok, + %{ + "choices" => [ + %{ + "finish_reason" => "stop", + "message" => %{ + "content" => Jason.encode!(result) + } + } + ] + }, result} + end) + end + + def mock_bedrock_response_stream(:tools, result) do + chunks = + Jason.encode!(%{value: result}) + |> String.graphemes() + |> Enum.chunk_every(12) + |> Enum.map(&Enum.join(&1, "")) + + InstructorTest.MockBedrock + |> expect(:chat_completion, fn _params, _config -> + chunks + end) + end + + def mock_bedrock_response_stream(mode, result) when mode in [:json, :md_json] do + chunks = + Jason.encode!(%{value: result}) + |> String.graphemes() + |> Enum.chunk_every(12) + |> Enum.map(&Enum.join(&1, "")) + + InstructorTest.MockBedrock + |> expect(:chat_completion, fn _params, _config -> + chunks + end) + end + + def mock_bedrock_reask_messages do + InstructorTest.MockBedrock + |> expect(:reask_messages, fn _raw_response, _params, _config -> + [] + end) + end + def is_stream?(variable) do case variable do %Stream{} -> diff --git a/test/test_helper.exs b/test/test_helper.exs index d83e8ad..ceeccaf 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -1,13 +1,15 @@ Mox.defmock(InstructorTest.MockOpenAI, for: Instructor.Adapter) +Mox.defmock(InstructorTest.MockBedrock, for: Instructor.Adapter) # Exclude the unmocked tests by default, to run them use: # # mix test --only adapter:llamacpp # mix test --only adapter:openai +# mix test --only adapter:bedrock # # to run all the non-local models, use: # -# mix test --include adapter:gemini --include adapter:anthropic --include adapter:openai +# mix test --include adapter:gemini --include adapter:anthropic --include adapter:openai --include adapter:bedrock # # ExUnit.configure( @@ -18,7 +20,8 @@ ExUnit.configure( adapter: :gemini, adapter: :xai, adapter: :llamacpp, - adapter: :ollama + adapter: :ollama, + adapter: :bedrock ] ) From 070c57aab5085a3403b292a4e26f37d004004eec Mon Sep 17 00:00:00 2001 From: shegx01 Date: Fri, 12 Dec 2025 09:56:48 +0200 Subject: [PATCH 2/6] Remove VSCode configuration files --- .vscode/launch.json | 31 ------------------------------- .vscode/settings.json | 3 --- 2 files changed, 34 deletions(-) delete mode 100644 .vscode/launch.json delete mode 100644 .vscode/settings.json diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 79f5763..0000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "type": "mix_task", - "name": "mix (Default task)", - "request": "launch", - "projectDir": "${workspaceRoot}" - }, - { - "type": "mix_task", - "name": "mix test", - "request": "launch", - "task": "test", - "taskArgs": [ - "--trace" - ], - "debugAutoInterpretAllModules": false, - "debugInterpretModulesPatterns": ["Instructor.*"], - "startApps": true, - "projectDir": "${workspaceRoot}", - "requireFiles": [ - "test/**/test_helper.exs", - "test/**/*_test.exs" - ] - } - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index a8b67ce..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "cSpell.enabled": false -} From dd2e73957999b48cf134295a97925b6c3010e8da Mon Sep 17 00:00:00 2001 From: shegx01 Date: Fri, 12 Dec 2025 09:56:56 +0200 Subject: [PATCH 3/6] Remove VSCode configuration from version control --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 30c4278..116ca30 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ instructor-*.tar *.DS_Store llama.log +.vscode From f2736609227806fb09bfda4b108322de4ab5e912 Mon Sep 17 00:00:00 2001 From: shegx01 Date: Fri, 12 Dec 2025 12:27:30 +0200 Subject: [PATCH 4/6] Handle all AWS event stream header types in parser --- lib/instructor/aws_event_stream_parser.ex | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lib/instructor/aws_event_stream_parser.ex b/lib/instructor/aws_event_stream_parser.ex index c8022c9..724df0b 100644 --- a/lib/instructor/aws_event_stream_parser.ex +++ b/lib/instructor/aws_event_stream_parser.ex @@ -134,13 +134,15 @@ defmodule Instructor.AWSEventStreamParser do {value, rest} = case type do - # Type 7 = string - 7 -> - <> = rest - {value, rest} - - _ -> - {nil, rest} + type when type in [0, 1] -> {nil, rest} + 2 -> <<_::8, rest::binary>> = rest; {nil, rest} + 3 -> <<_::16, rest::binary>> = rest; {nil, rest} + 4 -> <<_::32, rest::binary>> = rest; {nil, rest} + 5 -> <<_::64, rest::binary>> = rest; {nil, rest} + 6 -> <> = rest; {nil, rest} + 7 -> <> = rest; {value, rest} + 8 -> <<_::64, rest::binary>> = rest; {nil, rest} + 9 -> <<_::binary-size(16), rest::binary>> = rest; {nil, rest} end parse_headers(rest, Map.put(acc, name, value)) From 886b3773850288b8674cceb85d98764d1b1bd05f Mon Sep 17 00:00:00 2001 From: shegx01 Date: Fri, 12 Dec 2025 12:38:28 +0200 Subject: [PATCH 5/6] Format code with consistent line breaks and indentation --- lib/instructor/adapters/anthropic.ex | 4 ++- lib/instructor/aws_event_stream_parser.ex | 43 ++++++++++++++++++----- lib/instructor/json_schema.ex | 24 +++++++++---- lib/instructor/types/duration.ex | 3 +- pages/cookbook/o1_cot_ui.exs | 5 ++- pages/cookbook/streaming_ui.exs | 33 ++++++++++++----- test/json_schema_test.exs | 9 +++-- 7 files changed, 91 insertions(+), 30 deletions(-) diff --git a/lib/instructor/adapters/anthropic.ex b/lib/instructor/adapters/anthropic.ex index 83a210e..445d386 100644 --- a/lib/instructor/adapters/anthropic.ex +++ b/lib/instructor/adapters/anthropic.ex @@ -67,7 +67,9 @@ defmodule Instructor.Adapters.Anthropic do reask_messages_for_mode(params[:mode], raw_response) end - defp reask_messages_for_mode(:tools, %{"content" => [%{"input" => args, "type" => "tool_use", "id" => id, "name" => name}]}) do + defp reask_messages_for_mode(:tools, %{ + "content" => [%{"input" => args, "type" => "tool_use", "id" => id, "name" => name}] + }) do [ %{ role: "assistant", diff --git a/lib/instructor/aws_event_stream_parser.ex b/lib/instructor/aws_event_stream_parser.ex index 724df0b..0e3c47b 100644 --- a/lib/instructor/aws_event_stream_parser.ex +++ b/lib/instructor/aws_event_stream_parser.ex @@ -134,15 +134,40 @@ defmodule Instructor.AWSEventStreamParser do {value, rest} = case type do - type when type in [0, 1] -> {nil, rest} - 2 -> <<_::8, rest::binary>> = rest; {nil, rest} - 3 -> <<_::16, rest::binary>> = rest; {nil, rest} - 4 -> <<_::32, rest::binary>> = rest; {nil, rest} - 5 -> <<_::64, rest::binary>> = rest; {nil, rest} - 6 -> <> = rest; {nil, rest} - 7 -> <> = rest; {value, rest} - 8 -> <<_::64, rest::binary>> = rest; {nil, rest} - 9 -> <<_::binary-size(16), rest::binary>> = rest; {nil, rest} + type when type in [0, 1] -> + {nil, rest} + + 2 -> + <<_::8, rest::binary>> = rest + {nil, rest} + + 3 -> + <<_::16, rest::binary>> = rest + {nil, rest} + + 4 -> + <<_::32, rest::binary>> = rest + {nil, rest} + + 5 -> + <<_::64, rest::binary>> = rest + {nil, rest} + + 6 -> + <> = rest + {nil, rest} + + 7 -> + <> = rest + {value, rest} + + 8 -> + <<_::64, rest::binary>> = rest + {nil, rest} + + 9 -> + <<_::binary-size(16), rest::binary>> = rest + {nil, rest} end parse_headers(rest, Map.put(acc, name, value)) diff --git a/lib/instructor/json_schema.ex b/lib/instructor/json_schema.ex index 4655299..3c53954 100644 --- a/lib/instructor/json_schema.ex +++ b/lib/instructor/json_schema.ex @@ -290,7 +290,11 @@ defmodule Instructor.JSONSchema do defp for_type(:decimal), do: %{type: "number", format: "float"} defp for_type(:date), - do: %{type: "string", description: "ISO8601 Date, [yyyy]-[mm]-[dd], e.g. \"2024-07-20\"", format: "date"} + do: %{ + type: "string", + description: "ISO8601 Date, [yyyy]-[mm]-[dd], e.g. \"2024-07-20\"", + format: "date" + } defp for_type(:time), do: %{ @@ -302,35 +306,40 @@ defmodule Instructor.JSONSchema do defp for_type(:time_usec), do: %{ type: "string", - description: "ISO8601 Time with microseconds, [hh]:[mm]:[ss].[microseconds], e.g. \"12:00:00.000000\"", + description: + "ISO8601 Time with microseconds, [hh]:[mm]:[ss].[microseconds], e.g. \"12:00:00.000000\"", pattern: "^[0-9]{2}:?[0-9]{2}:?[0-9]{2}.[0-9]{6}$" } defp for_type(:naive_datetime), do: %{ type: "string", - description: "ISO8601 DateTime, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss], e.g. \"2024-07-20T12:00:00\"", + description: + "ISO8601 DateTime, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss], e.g. \"2024-07-20T12:00:00\"", format: "date-time" } defp for_type(:naive_datetime_usec), do: %{ type: "string", - description: "ISO8601 DateTime with microseconds, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss].[microseconds], e.g. \"2024-07-20T12:00:00.000000\"", + description: + "ISO8601 DateTime with microseconds, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss].[microseconds], e.g. \"2024-07-20T12:00:00.000000\"", format: "date-time" } defp for_type(:utc_datetime), do: %{ type: "string", - description: "ISO8601 DateTime, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss]Z, e.g. \"2024-07-20T12:00:00Z\"", + description: + "ISO8601 DateTime, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss]Z, e.g. \"2024-07-20T12:00:00Z\"", format: "date-time" } defp for_type(:utc_datetime_usec), do: %{ type: "string", - description: "ISO8601 DateTime with microseconds, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss].[microseconds]Z, e.g. \"2024-07-20T12:00:00.000000Z\"", + description: + "ISO8601 DateTime with microseconds, [yyyy]-[mm]-[dd]T[hh]:[mm]:[ss].[microseconds]Z, e.g. \"2024-07-20T12:00:00.000000Z\"", format: "date-time" } @@ -451,7 +460,8 @@ defmodule Instructor.JSONSchema do |> maybe_call_with_path(fun, path, opts) end - defp do_traverse_and_update(tree, fun, path, opts), do: maybe_call_with_path(tree, fun, path, opts) + defp do_traverse_and_update(tree, fun, path, opts), + do: maybe_call_with_path(tree, fun, path, opts) defp maybe_call_with_path(value, fun, path, opts) do if Keyword.get(opts, :include_path, false) do diff --git a/lib/instructor/types/duration.ex b/lib/instructor/types/duration.ex index ac879f8..3901c25 100644 --- a/lib/instructor/types/duration.ex +++ b/lib/instructor/types/duration.ex @@ -16,7 +16,8 @@ defmodule Instructor.Types.Duration do type: "string", description: "A valid ISO8601 duration, e.g. PT3M14S", format: "duration", - pattern: "^P(?:(\\d+)Y)?(?:(\\d+)M)?(?:(\\d+)D)?(?:T(?:(\\d+)H)?(?:(\\d+)M)?(?:(\\d+(?:\\.\\d+)?)S)?)?$" + pattern: + "^P(?:(\\d+)Y)?(?:(\\d+)M)?(?:(\\d+)D)?(?:T(?:(\\d+)H)?(?:(\\d+)M)?(?:(\\d+(?:\\.\\d+)?)S)?)?$" } end diff --git a/pages/cookbook/o1_cot_ui.exs b/pages/cookbook/o1_cot_ui.exs index 8e63228..d9ae9c0 100644 --- a/pages/cookbook/o1_cot_ui.exs +++ b/pages/cookbook/o1_cot_ui.exs @@ -47,6 +47,7 @@ defmodule DemoLive do pid = self() selected_model = socket.assigns.selected_model + response_model = socket.assigns.output_schema |> Enum.map(fn {key, type} -> {String.to_atom(key), type} end) @@ -161,7 +162,9 @@ defmodule DemoLive do
🍓 Elixir - Structured Outputs w/ Reasoning + + Structured Outputs w/ Reasoning +