extract creat agent method
This commit is contained in:
parent
fd664c0886
commit
13b953bf7f
|
|
@ -4,7 +4,6 @@ import uuid
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from langgraph.prebuilt import create_react_agent
|
|
||||||
from ra_aid.env import validate_environment
|
from ra_aid.env import validate_environment
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
from ra_aid.tools.human import ask_human
|
from ra_aid.tools.human import ask_human
|
||||||
|
|
@ -15,7 +14,8 @@ from ra_aid.agent_utils import (
|
||||||
AgentInterrupt,
|
AgentInterrupt,
|
||||||
run_agent_with_retry,
|
run_agent_with_retry,
|
||||||
run_research_agent,
|
run_research_agent,
|
||||||
run_planning_agent
|
run_planning_agent,
|
||||||
|
create_agent
|
||||||
)
|
)
|
||||||
from ra_aid.prompts import (
|
from ra_aid.prompts import (
|
||||||
CHAT_PROMPT,
|
CHAT_PROMPT,
|
||||||
|
|
@ -177,7 +177,7 @@ def main():
|
||||||
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
||||||
|
|
||||||
# Create chat agent with appropriate tools
|
# Create chat agent with appropriate tools
|
||||||
chat_agent = create_react_agent(
|
chat_agent = create_agent(
|
||||||
model,
|
model,
|
||||||
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
||||||
checkpointer=MemorySaver()
|
checkpointer=MemorySaver()
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,10 @@ import threading
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from ra_aid.console.formatting import print_stage_header, print_error
|
from ra_aid.console.formatting import print_stage_header, print_error
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from typing import List, Any
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
|
|
@ -41,7 +43,6 @@ from ra_aid.prompts import (
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
@ -65,6 +66,24 @@ console = Console()
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
def create_agent(
|
||||||
|
model: BaseChatModel,
|
||||||
|
tools: List[Any],
|
||||||
|
*,
|
||||||
|
checkpointer: Any = None
|
||||||
|
) -> Any:
|
||||||
|
"""Create a react agent with the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The LLM model to use
|
||||||
|
tools: List of tools to provide to the agent
|
||||||
|
checkpointer: Optional memory checkpointer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created agent instance
|
||||||
|
"""
|
||||||
|
return create_react_agent(model, tools, checkpointer=checkpointer)
|
||||||
|
|
||||||
def run_research_agent(
|
def run_research_agent(
|
||||||
base_task_or_query: str,
|
base_task_or_query: str,
|
||||||
model,
|
model,
|
||||||
|
|
@ -125,7 +144,7 @@ def run_research_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Format prompt sections
|
# Format prompt sections
|
||||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||||
|
|
@ -223,7 +242,7 @@ def run_web_research_agent(
|
||||||
tools = get_web_research_tools(expert_enabled=expert_enabled)
|
tools = get_web_research_tools(expert_enabled=expert_enabled)
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Format prompt sections
|
# Format prompt sections
|
||||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||||
|
|
@ -306,7 +325,7 @@ def run_planning_agent(
|
||||||
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Format prompt sections
|
# Format prompt sections
|
||||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||||
|
|
@ -393,7 +412,7 @@ def run_task_implementation_agent(
|
||||||
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Build prompt
|
# Build prompt
|
||||||
prompt = IMPLEMENTATION_PROMPT.format(
|
prompt = IMPLEMENTATION_PROMPT.format(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue