extract creat agent method

This commit is contained in:
AI Christianson 2024-12-28 15:39:33 -05:00
parent fd664c0886
commit 13b953bf7f
2 changed files with 28 additions and 9 deletions

View File

@ -4,7 +4,6 @@ import uuid
from rich.panel import Panel
from rich.console import Console
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment
from ra_aid.tools.memory import _global_memory
from ra_aid.tools.human import ask_human
@ -15,7 +14,8 @@ from ra_aid.agent_utils import (
AgentInterrupt,
run_agent_with_retry,
run_research_agent,
run_planning_agent
run_planning_agent,
create_agent
)
from ra_aid.prompts import (
CHAT_PROMPT,
@ -177,7 +177,7 @@ def main():
initial_request = ask_human.invoke({"question": "What would you like help with?"})
# Create chat agent with appropriate tools
chat_agent = create_react_agent(
chat_agent = create_agent(
model,
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
checkpointer=MemorySaver()

View File

@ -10,8 +10,10 @@ import threading
import time
from typing import Optional
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt import create_react_agent
from ra_aid.console.formatting import print_stage_header, print_error
from langchain_core.language_models import BaseChatModel
from typing import List, Any
from ra_aid.console.output import print_agent_output
from ra_aid.logging_config import get_logger
from ra_aid.exceptions import AgentInterrupt
@ -41,7 +43,6 @@ from ra_aid.prompts import (
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from rich.console import Console
from rich.markdown import Markdown
@ -65,6 +66,24 @@ console = Console()
logger = get_logger(__name__)
def create_agent(
model: BaseChatModel,
tools: List[Any],
*,
checkpointer: Any = None
) -> Any:
"""Create a react agent with the given configuration.
Args:
model: The LLM model to use
tools: List of tools to provide to the agent
checkpointer: Optional memory checkpointer
Returns:
The created agent instance
"""
return create_react_agent(model, tools, checkpointer=checkpointer)
def run_research_agent(
base_task_or_query: str,
model,
@ -125,7 +144,7 @@ def run_research_agent(
)
# Create agent
agent = create_react_agent(model, tools, checkpointer=memory)
agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@ -223,7 +242,7 @@ def run_web_research_agent(
tools = get_web_research_tools(expert_enabled=expert_enabled)
# Create agent
agent = create_react_agent(model, tools, checkpointer=memory)
agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@ -306,7 +325,7 @@ def run_planning_agent(
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
# Create agent
agent = create_react_agent(model, tools, checkpointer=memory)
agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
@ -393,7 +412,7 @@ def run_task_implementation_agent(
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
# Create agent
agent = create_react_agent(model, tools, checkpointer=memory)
agent = create_agent(model, tools, checkpointer=memory)
# Build prompt
prompt = IMPLEMENTATION_PROMPT.format(