From 406d1a5358db5a51791d0eb6706f089ad93c2a5a Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Sat, 28 Dec 2024 16:44:06 -0500 Subject: [PATCH] ciayn --- ra_aid/agent_utils.py | 39 +++++++++++++++++++++++++++++++++--- ra_aid/agents/ciayn_agent.py | 27 ++++++++++++++----------- ra_aid/tools/memory.py | 2 ++ 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 90b0650..a0210ab 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -10,9 +10,12 @@ import threading import time from typing import Optional -from langgraph.prebuilt import create_react_agent +from langgraph.prebuilt import create_react_agent +from ra_aid.agents.ciayn_agent import CiaynAgent +from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.console.formatting import print_stage_header, print_error from langchain_core.language_models import BaseChatModel +from langchain_core.tools import tool from typing import List, Any from ra_aid.console.output import print_agent_output from ra_aid.logging_config import get_logger @@ -66,6 +69,12 @@ console = Console() logger = get_logger(__name__) +@tool +def output_markdown_message(message: str) -> str: + """Outputs a message to the user, optionally prompting for input.""" + console.print(Panel(Markdown(message.strip()), title="🤖 Assistant")) + return "Message output." + def create_agent( model: BaseChatModel, tools: List[Any], @@ -82,7 +91,27 @@ def create_agent( Returns: The created agent instance """ - return create_react_agent(model, tools, checkpointer=checkpointer) + try: + # Extract model info from module path + module_path = model.__class__.__module__.split('.') + if len(module_path) > 1: + provider = module_path[1] # e.g. anthropic from langchain_anthropic + else: + provider = None + + # Get model name if available + model_name = getattr(model, 'model_name', '').lower() + + # Use REACT agent for Anthropic Claude models, otherwise use CIAYN + if provider == 'anthropic' and 'claude' in model_name: + return create_react_agent(model, tools, checkpointer=checkpointer) + else: + return CiaynAgent(model, tools) + + except Exception as e: + # Default to REACT agent if provider/model detection fails + logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") + return create_react_agent(model, tools, checkpointer=checkpointer) def run_research_agent( base_task_or_query: str, @@ -499,7 +528,11 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: logger.debug("Agent output: %s", chunk) check_interrupt() print_agent_output(chunk) - logger.debug("Agent run completed successfully") + if _global_memory['task_completed']: + _global_memory['task_completed'] = False + _global_memory['completion_message'] = '' + break + logger.debug("Agent run completed successfully") return "Agent run completed successfully" except (KeyboardInterrupt, AgentInterrupt): raise diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 7d4dde7..5f021d7 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -1,7 +1,12 @@ import inspect +from dataclasses import dataclass from typing import Dict, Any, Generator, List, Optional, Union from langchain_core.messages import AIMessage, HumanMessage, BaseMessage from ra_aid.exceptions import ToolExecutionError +@dataclass +class ChunkMessage: + content: str + status: str class CiaynAgent: """Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction. @@ -74,9 +79,6 @@ class CiaynAgent: base_prompt += f"\n{last_result}" base_prompt += f""" - -{"\n\n".join(self.available_functions)} - You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration. @@ -89,13 +91,13 @@ Use as many steps as you need to in order to fully complete the task. Start by asking the user what they want. - -check_weather("London") - - - -output_message(\"\"\"How can I help you today?\"\"\", True) - +You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal. + +You must ONLY use ONE of the following functions (these are the ONLY functions that exist): + + +{"\n\n".join(self.available_functions)} + Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return base_prompt @@ -124,9 +126,10 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" def _create_error_chunk(self, content: str) -> Dict[str, Any]: """Create an error chunk in the format expected by print_agent_output.""" + message = ChunkMessage(content=content, status="error") return { "tools": { - "messages": [{"status": "error", "content": content}] + "messages": [message] } } @@ -209,5 +212,5 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" yield {} except ToolExecutionError as e: + chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again.")) yield self._create_error_chunk(str(e)) - break diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index e5262e9..f87ce0e 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -192,6 +192,8 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str: """Store multiple key source code snippets in global memory. Automatically adds the filepaths of the snippets to related files. + This is for **existing**, or **just-written** files, not for things to be created in the future. + Args: snippets: List of snippet information dictionaries containing: - filepath: Path to the source file