diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 49f7b7a..7ff77e3 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -4,7 +4,6 @@ import uuid from rich.panel import Panel from rich.console import Console from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import create_react_agent from ra_aid.env import validate_environment from ra_aid.tools.memory import _global_memory from ra_aid.tools.human import ask_human @@ -15,7 +14,8 @@ from ra_aid.agent_utils import ( AgentInterrupt, run_agent_with_retry, run_research_agent, - run_planning_agent + run_planning_agent, + create_agent ) from ra_aid.prompts import ( CHAT_PROMPT, @@ -177,7 +177,7 @@ def main(): initial_request = ask_human.invoke({"question": "What would you like help with?"}) # Create chat agent with appropriate tools - chat_agent = create_react_agent( + chat_agent = create_agent( model, get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled), checkpointer=MemorySaver() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index d281185..90b0650 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -10,8 +10,10 @@ import threading import time from typing import Optional -from langgraph.prebuilt import create_react_agent +from langgraph.prebuilt import create_react_agent from ra_aid.console.formatting import print_stage_header, print_error +from langchain_core.language_models import BaseChatModel +from typing import List, Any from ra_aid.console.output import print_agent_output from ra_aid.logging_config import get_logger from ra_aid.exceptions import AgentInterrupt @@ -41,7 +43,6 @@ from ra_aid.prompts import ( from langgraph.checkpoint.memory import MemorySaver from langchain_core.messages import HumanMessage -from langchain_core.messages import BaseMessage from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError from rich.console import Console from rich.markdown import Markdown @@ -65,6 +66,24 @@ console = Console() logger = get_logger(__name__) +def create_agent( + model: BaseChatModel, + tools: List[Any], + *, + checkpointer: Any = None +) -> Any: + """Create a react agent with the given configuration. + + Args: + model: The LLM model to use + tools: List of tools to provide to the agent + checkpointer: Optional memory checkpointer + + Returns: + The created agent instance + """ + return create_react_agent(model, tools, checkpointer=checkpointer) + def run_research_agent( base_task_or_query: str, model, @@ -125,7 +144,7 @@ def run_research_agent( ) # Create agent - agent = create_react_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory) # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" @@ -223,7 +242,7 @@ def run_web_research_agent( tools = get_web_research_tools(expert_enabled=expert_enabled) # Create agent - agent = create_react_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory) # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" @@ -306,7 +325,7 @@ def run_planning_agent( tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) # Create agent - agent = create_react_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory) # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" @@ -393,7 +412,7 @@ def run_task_implementation_agent( tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) # Create agent - agent = create_react_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory) # Build prompt prompt = IMPLEMENTATION_PROMPT.format(