diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py
index 8b691b3..8688828 100644
--- a/src/controlflow/agents/agent.py
+++ b/src/controlflow/agents/agent.py
@@ -83,7 +83,7 @@ class Agent(ControlFlowModel, abc.ABC):
         default=False,
         description="If True, the agent is given tools for interacting with a human user.",
     )
-    memories: list[Memory] | list[AsyncMemory] = Field(
+    memories: list[Union[Memory, AsyncMemory]] = Field(
         default=[],
         description="A list of memory modules for the agent to use.",
     )
@@ -345,7 +345,7 @@ def _run_model(

         create_markdown_artifact(
             markdown=f"""
-{response.content or '(No content)'}
+{response.content or "(No content)"}

 #### Payload
 ```json
@@ -409,7 +409,7 @@ async def _run_model_async(

         create_markdown_artifact(
             markdown=f"""
-{response.content or '(No content)'}
+{response.content or "(No content)"}

 #### Payload
 ```json
diff --git a/src/controlflow/defaults.py b/src/controlflow/defaults.py
index b5836e9..0de18f5 100644
--- a/src/controlflow/defaults.py
+++ b/src/controlflow/defaults.py
@@ -40,9 +40,7 @@ class Defaults(ControlFlowModel):
     model: Optional[Any]
     history: History
     agent: Agent
-    memory_provider: (
-        Optional[Union[MemoryProvider, str]] | Optional[Union[AsyncMemoryProvider, str]]
-    )
+    memory_provider: Optional[Union[MemoryProvider, AsyncMemoryProvider, str]]
     # add more defaults here

     def __repr__(self) -> str:
diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py
index 11a66c8..7b0af5b 100644
--- a/src/controlflow/llm/models.py
+++ b/src/controlflow/llm/models.py
@@ -52,6 +52,8 @@ def get_model(
                 "To use Google as an LLM provider, please install the `langchain_google_genai` package."
             )
         cls = ChatGoogleGenerativeAI
+        if temperature is None:
+            temperature = 0.7
     elif provider == "groq":
         try:
             from langchain_groq import ChatGroq
@@ -60,6 +62,8 @@ def get_model(
                 "To use Groq as an LLM provider, please install the `langchain_groq` package."
            )
         cls = ChatGroq
+        if temperature is None:
+            temperature = 0.7
     elif provider == "ollama":
         try:
             from langchain_ollama import ChatOllama
@@ -73,7 +77,9 @@ def get_model(
             f"Could not load provider `{provider}` automatically. Please provide the LLM class manually."
         )

-    return cls(model=model, temperature=temperature, **kwargs)
+    if temperature is not None:
+        kwargs["temperature"] = temperature
+    return cls(model=model, **kwargs)


 def _get_initial_default_model() -> BaseChatModel:
diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py
index 03a09de..e949f73 100644
--- a/src/controlflow/orchestration/orchestrator.py
+++ b/src/controlflow/orchestration/orchestrator.py
@@ -188,7 +188,7 @@ def get_tools(self) -> list[Tool]:
         tools = as_tools(tools)
         return tools

-    def get_memories(self) -> list[Memory] | list[AsyncMemory]:
+    def get_memories(self) -> list[Union[Memory, AsyncMemory]]:
         memories = set()

         memories.update(self.agent.memories)
@@ -525,7 +525,7 @@ def compile_prompt(self) -> str:
         ]

         prompt = "\n\n".join([p for p in prompts if p])
-        logger.debug(f"{'='*10}\nCompiled prompt: {prompt}\n{'='*10}")
+        logger.debug(f"{'=' * 10}\nCompiled prompt: {prompt}\n{'=' * 10}")
         return prompt

     def compile_messages(self) -> list[BaseMessage]:
diff --git a/src/controlflow/orchestration/prompt_templates.py b/src/controlflow/orchestration/prompt_templates.py
index 4a74df1..9434c6b 100644
--- a/src/controlflow/orchestration/prompt_templates.py
+++ b/src/controlflow/orchestration/prompt_templates.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import model_validator

@@ -98,7 +98,7 @@ def should_render(self) -> bool:

 class MemoryTemplate(Template):
     template_path: str = "memories.jinja"
-    memories: list[Memory] | list[AsyncMemory]
+    memories: list[Union[Memory, AsyncMemory]]

     def should_render(self) -> bool:
         return bool(self.memories)
diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py
index 387753b..6a00c3e 100644
--- a/src/controlflow/settings.py
+++ b/src/controlflow/settings.py
@@ -109,7 +109,9 @@ def _validate_pretty_print_agent_events(cls, data: dict) -> dict:
         default="openai/gpt-4o",
         description="The default LLM model for agents.",
     )
-    llm_temperature: float = Field(0.7, description="The temperature for LLM sampling.")
+    llm_temperature: Union[float, None] = Field(
+        None, description="The temperature for LLM sampling."
+    )
     max_input_tokens: int = Field(
         100_000, description="The maximum number of tokens to send to an LLM."
     )
diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py
index b08a548..6683318 100644
--- a/src/controlflow/tasks/task.py
+++ b/src/controlflow/tasks/task.py
@@ -75,7 +75,7 @@ def __getitem__(self, item):
         return self.root[item]

     def __repr__(self) -> str:
-        return f'Labels: {", ".join(self.root)}'
+        return f"Labels: {', '.join(self.root)}"


 class TaskStatus(Enum):
@@ -162,7 +162,7 @@ class Task(ControlFlowModel):
         description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.",
     )
     interactive: bool = False
-    memories: list[Memory] | list[AsyncMemory] = Field(
+    memories: list[Union[Memory, AsyncMemory]] = Field(
         default=[],
         description="A list of memory modules for the task to use.",
     )
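A minimal sketch of how `get_model` behaves after this change. The import path and the `provider/model` string format follow the diff; the Groq model name below is illustrative only, and whether each LangChain client applies its own sampling default when `temperature` is omitted is the provider's behavior, not verified here.

```python
from controlflow.llm.models import get_model

# No temperature given: the kwarg is no longer forwarded at all
# (previously temperature=None was always passed to the chat class),
# so the client's own default applies.
m1 = get_model("openai/gpt-4o")

# Google and Groq are special-cased: when no temperature is given,
# this diff sets 0.7 before constructing the client.
m2 = get_model("groq/llama-3.1-70b-versatile")  # model name illustrative

# An explicit temperature is forwarded unchanged via kwargs.
m3 = get_model("openai/gpt-4o", temperature=0.2)
```

The same convention carries through `settings.py`: `llm_temperature` now defaults to `None` rather than `0.7`, so an unset setting means "defer to the provider" instead of silently overriding it.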