AI Christianson 2024-12-28 16:44:06 -05:00
parent 13b953bf7f
commit 406d1a5358
3 changed files with 53 additions and 15 deletions

View File

@@ -10,9 +10,12 @@ import threading
import time
from typing import Optional
from langgraph.prebuilt import create_react_agent
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.console.formatting import print_stage_header, print_error
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import tool
from typing import List, Any
from ra_aid.console.output import print_agent_output
from ra_aid.logging_config import get_logger
@@ -66,6 +69,12 @@ console = Console()
logger = get_logger(__name__)
@tool
def output_markdown_message(message: str) -> str:
    """Output a formatted markdown message to the user."""
    console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
    return "Message output."
def create_agent(
    model: BaseChatModel,
    tools: List[Any],
@@ -82,7 +91,27 @@ def create_agent(
    Returns:
        The created agent instance
    """
    return create_react_agent(model, tools, checkpointer=checkpointer)
    try:
        # Extract provider from the model's module path
        module_root = model.__class__.__module__.split('.')[0]
        if '_' in module_root:
            provider = module_root.split('_')[-1]  # e.g. 'anthropic' from 'langchain_anthropic'
        else:
            provider = None

        # Get model name if available
        model_name = getattr(model, 'model_name', '').lower()

        # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
        if provider == 'anthropic' and 'claude' in model_name:
            return create_react_agent(model, tools, checkpointer=checkpointer)
        else:
            return CiaynAgent(model, tools)
    except Exception as e:
        # Default to REACT agent if provider/model detection fails
        logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
        return create_react_agent(model, tools, checkpointer=checkpointer)
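
A quick sketch of what this detection sees for an Anthropic chat model (module layout follows LangChain convention; exact strings may vary by version):

from langchain_anthropic import ChatAnthropic

model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
module_root = model.__class__.__module__.split('.')[0]  # e.g. 'langchain_anthropic'
provider = module_root.split('_')[-1]                   # -> 'anthropic'
# 'anthropic' plus a 'claude' model name selects the REACT agent above
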
def run_research_agent(
    base_task_or_query: str,
@@ -499,7 +528,11 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
                logger.debug("Agent output: %s", chunk)
                check_interrupt()
                print_agent_output(chunk)
            logger.debug("Agent run completed successfully")
                if _global_memory['task_completed']:
                    _global_memory['task_completed'] = False
                    _global_memory['completion_message'] = ''
                    break
            logger.debug("Agent run completed successfully")
            return "Agent run completed successfully"
        except (KeyboardInterrupt, AgentInterrupt):
            raise
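
The new early break assumes some tool flips _global_memory['task_completed'] while the agent streams. A minimal sketch of such a tool, assuming a dict-shaped global store (the tool name and fields here are illustrative, not confirmed by this diff):

from langchain_core.tools import tool

_global_memory = {'task_completed': False, 'completion_message': ''}  # stand-in for the real store

@tool
def task_completed(message: str) -> str:
    """Mark the overall task as finished so the agent loop can stop."""
    _global_memory['task_completed'] = True        # checked after each streamed chunk
    _global_memory['completion_message'] = message
    return "Completion recorded."
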

View File

@@ -1,7 +1,12 @@
import inspect
from dataclasses import dataclass
from typing import Dict, Any, Generator, List, Optional, Union
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from ra_aid.exceptions import ToolExecutionError
@dataclass
class ChunkMessage:
    content: str
    status: str


class CiaynAgent:
    """Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
@@ -74,9 +79,6 @@ class CiaynAgent:
        base_prompt += f"\n<last result>{last_result}</last result>"
        base_prompt += f"""
<available functions>
{"\n\n".join(self.available_functions)}
</available functions>
<agent instructions>
You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration.
@@ -89,13 +91,13 @@ Use as many steps as you need to in order to fully complete the task.
Start by asking the user what they want.
</agent instructions>
<example response>
check_weather("London")
</example response>
<example response>
output_message(\"\"\"How can I help you today?\"\"\", True)
</example response>
You must carefully review the conversation history (which functions were called so far and what they returned) and make sure the very next function call you make is a sensible step toward the original goal.
You must ONLY use ONE of the following functions (these are the ONLY functions that exist):
<available functions>
{"\n\n".join(self.available_functions)}
</available functions>
Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
        return base_prompt
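
Conceptually, this prompt drives a loop in which the model emits exactly one bare function call per turn and the agent executes it. A minimal sketch of one iteration, assuming tools are plain callables (eval is shown only to mirror the code-is-all-you-need idea, not as a hardened implementation):

from langchain_core.messages import HumanMessage

def ciayn_step(model, chat_history, functions):
    # The model replies with bare code such as: check_weather("London")
    code = model.invoke(chat_history).content.strip()
    namespace = {f.__name__: f for f in functions}
    result = eval(code, namespace)  # exactly one call per iteration
    chat_history.append(HumanMessage(content=f"<last result>{result}</last result>"))
    return result
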
@@ -124,9 +126,10 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
    def _create_error_chunk(self, content: str) -> Dict[str, Any]:
        """Create an error chunk in the format expected by print_agent_output."""
        message = ChunkMessage(content=content, status="error")
        return {
            "tools": {
                "messages": [{"status": "error", "content": content}]
                "messages": [message]
            }
        }
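
Illustratively, the error chunk now wraps a ChunkMessage rather than a raw dict while keeping the outer shape print_agent_output expects (import path assumed):

from ra_aid.agents.ciayn_agent import ChunkMessage  # name assumed importable

message = ChunkMessage(content="NameError: name 'check_weather' is not defined", status="error")
chunk = {"tools": {"messages": [message]}}  # same shape _create_error_chunk returns
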
@@ -209,5 +212,5 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
                yield {}
            except ToolExecutionError as e:
                chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
                yield self._create_error_chunk(str(e))
                break
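
For context, ToolExecutionError presumably originates where the generated call is executed; a sketch under that assumption (the helper name is illustrative, not confirmed by this diff):

from ra_aid.exceptions import ToolExecutionError

def execute_generated_call(code: str, namespace: dict) -> str:
    """Run one model-generated call; wrap failures for the corrective retry above."""
    try:
        return str(eval(code, namespace))
    except Exception as e:
        # Fed back to the model as a HumanMessage, then yielded as an error chunk
        raise ToolExecutionError(f"{type(e).__name__}: {e}") from e
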

View File

@@ -192,6 +192,8 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
    """Store multiple key source code snippets in global memory.

    Automatically adds the filepaths of the snippets to related files.

    This is for **existing** or **just-written** files, not for files to be created in the future.

    Args:
        snippets: List of snippet information dictionaries containing:
            - filepath: Path to the source file