extract creat agent method
This commit is contained in:
parent
fd664c0886
commit
13b953bf7f
|
|
@ -4,7 +4,6 @@ import uuid
|
|||
from rich.panel import Panel
|
||||
from rich.console import Console
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.tools.human import ask_human
|
||||
|
|
@ -15,7 +14,8 @@ from ra_aid.agent_utils import (
|
|||
AgentInterrupt,
|
||||
run_agent_with_retry,
|
||||
run_research_agent,
|
||||
run_planning_agent
|
||||
run_planning_agent,
|
||||
create_agent
|
||||
)
|
||||
from ra_aid.prompts import (
|
||||
CHAT_PROMPT,
|
||||
|
|
@ -177,7 +177,7 @@ def main():
|
|||
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_react_agent(
|
||||
chat_agent = create_agent(
|
||||
model,
|
||||
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
||||
checkpointer=MemorySaver()
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ import threading
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.console.formatting import print_stage_header, print_error
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from typing import List, Any
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
|
|
@ -41,7 +43,6 @@ from ra_aid.prompts import (
|
|||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
|
@ -65,6 +66,24 @@ console = Console()
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def create_agent(
|
||||
model: BaseChatModel,
|
||||
tools: List[Any],
|
||||
*,
|
||||
checkpointer: Any = None
|
||||
) -> Any:
|
||||
"""Create a react agent with the given configuration.
|
||||
|
||||
Args:
|
||||
model: The LLM model to use
|
||||
tools: List of tools to provide to the agent
|
||||
checkpointer: Optional memory checkpointer
|
||||
|
||||
Returns:
|
||||
The created agent instance
|
||||
"""
|
||||
return create_react_agent(model, tools, checkpointer=checkpointer)
|
||||
|
||||
def run_research_agent(
|
||||
base_task_or_query: str,
|
||||
model,
|
||||
|
|
@ -125,7 +144,7 @@ def run_research_agent(
|
|||
)
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
|
|
@ -223,7 +242,7 @@ def run_web_research_agent(
|
|||
tools = get_web_research_tools(expert_enabled=expert_enabled)
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
|
|
@ -306,7 +325,7 @@ def run_planning_agent(
|
|||
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||
|
|
@ -393,7 +412,7 @@ def run_task_implementation_agent(
|
|||
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Build prompt
|
||||
prompt = IMPLEMENTATION_PROMPT.format(
|
||||
|
|
|
|||
Loading…
Reference in New Issue