extract creat agent method

This commit is contained in:
AI Christianson 2024-12-28 15:39:33 -05:00
parent fd664c0886
commit 13b953bf7f
2 changed files with 28 additions and 9 deletions

View File

@ -4,7 +4,6 @@ import uuid
from rich.panel import Panel from rich.panel import Panel
from rich.console import Console from rich.console import Console
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment from ra_aid.env import validate_environment
from ra_aid.tools.memory import _global_memory from ra_aid.tools.memory import _global_memory
from ra_aid.tools.human import ask_human from ra_aid.tools.human import ask_human
@ -15,7 +14,8 @@ from ra_aid.agent_utils import (
AgentInterrupt, AgentInterrupt,
run_agent_with_retry, run_agent_with_retry,
run_research_agent, run_research_agent,
run_planning_agent run_planning_agent,
create_agent
) )
from ra_aid.prompts import ( from ra_aid.prompts import (
CHAT_PROMPT, CHAT_PROMPT,
@ -177,7 +177,7 @@ def main():
initial_request = ask_human.invoke({"question": "What would you like help with?"}) initial_request = ask_human.invoke({"question": "What would you like help with?"})
# Create chat agent with appropriate tools # Create chat agent with appropriate tools
chat_agent = create_react_agent( chat_agent = create_agent(
model, model,
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled), get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
checkpointer=MemorySaver() checkpointer=MemorySaver()

View File

@ -10,8 +10,10 @@ import threading
import time import time
from typing import Optional from typing import Optional
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from ra_aid.console.formatting import print_stage_header, print_error from ra_aid.console.formatting import print_stage_header, print_error
from langchain_core.language_models import BaseChatModel
from typing import List, Any
from ra_aid.console.output import print_agent_output from ra_aid.console.output import print_agent_output
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
from ra_aid.exceptions import AgentInterrupt from ra_aid.exceptions import AgentInterrupt
@ -41,7 +43,6 @@ from ra_aid.prompts import (
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
@ -65,6 +66,24 @@ console = Console()
logger = get_logger(__name__) logger = get_logger(__name__)
def create_agent(
model: BaseChatModel,
tools: List[Any],
*,
checkpointer: Any = None
) -> Any:
"""Create a react agent with the given configuration.
Args:
model: The LLM model to use
tools: List of tools to provide to the agent
checkpointer: Optional memory checkpointer
Returns:
The created agent instance
"""
return create_react_agent(model, tools, checkpointer=checkpointer)
def run_research_agent( def run_research_agent(
base_task_or_query: str, base_task_or_query: str,
model, model,
@ -125,7 +144,7 @@ def run_research_agent(
) )
# Create agent # Create agent
agent = create_react_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections # Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@ -223,7 +242,7 @@ def run_web_research_agent(
tools = get_web_research_tools(expert_enabled=expert_enabled) tools = get_web_research_tools(expert_enabled=expert_enabled)
# Create agent # Create agent
agent = create_react_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections # Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@ -306,7 +325,7 @@ def run_planning_agent(
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
# Create agent # Create agent
agent = create_react_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections # Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
@ -393,7 +412,7 @@ def run_task_implementation_agent(
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
# Create agent # Create agent
agent = create_react_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory)
# Build prompt # Build prompt
prompt = IMPLEMENTATION_PROMPT.format( prompt = IMPLEMENTATION_PROMPT.format(