diff --git a/.changeset/fortunate-friendly-avocet.md b/.changeset/fortunate-friendly-avocet.md new file mode 100644 index 00000000..5f1ee443 --- /dev/null +++ b/.changeset/fortunate-friendly-avocet.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +Fix ability to pass raw JSON to Extract schema diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index 8af621c9..5fa24212 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -151,9 +151,7 @@ async def extract( processed_data_payload = raw_data_dict # Default to the raw dictionary - if schema and isinstance( - raw_data_dict, dict - ): # schema is the Pydantic model type + if schema and isinstance(schema, type) and issubclass(schema, BaseModel): # Try direct validation first try: validated_model_instance = schema.model_validate(raw_data_dict) diff --git a/stagehand/llm/inference.py b/stagehand/llm/inference.py index b438883b..5c9d122c 100644 --- a/stagehand/llm/inference.py +++ b/stagehand/llm/inference.py @@ -169,11 +169,19 @@ async def extract( start_time = time.time() # Determine if we need to use schema-based response format - # TODO: if schema is json, return json response_format = {"type": "json_object"} if schema: - # If schema is a Pydantic model, use it directly - response_format = schema + if isinstance(schema, dict): + response_format = { + "type": "json_schema", + "json_schema": { + "name": "extraction_schema", + "strict": False, + "schema": schema, + }, + } + else: + response_format = schema # Call the LLM with appropriate parameters try: