From 45b993cfd02d438b8dfbca8d87d4db8c8d2d29d6 Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Sun, 9 Feb 2025 22:07:23 -0800
Subject: [PATCH 01/45] feat(issue): add LLM Tool Call Fallback Feature
 documentation to outline the new functionality for automatic fallback to
 alternative LLM models after consecutive failures

feat(ciayn_agent): implement fallback mechanism in CiaynAgent to handle tool
call failures and switch to alternative models

feat(__main__): add command line arguments for fallback configuration in the
main application

feat(llm): add validation for required environment variables for LLM
providers and merge chat histories during fallback

fix(config): define default values for maximum tool failures in configuration

test(ciayn_agent): add unit tests for fallback logic and tool call execution
with retries and error handling

test(llm): enhance tests for LLM initialization and environment variable
validation
---
 issue.md                                      | 115 ++++++++++++++++
 ra_aid/__main__.py                            |  34 +++++
 ra_aid/agents/ciayn_agent.py                  |  91 +++++++++++--
 ra_aid/config.py                              |   2 +
 ra_aid/llm.py                                 |  61 +++++++--
 ra_aid/models_params.py                       |   1 -
 ra_aid/tool_configs.py                        |   8 --
 ra_aid/tools/expert.py                        |   6 +-
 ra_aid/tools/programmer.py                    |   4 +-
 tests/ra_aid/{agents => }/test_ciayn_agent.py | 128 +++++++++++++++---
 tests/ra_aid/test_llm.py                      |  13 +-
 11 files changed, 403 insertions(+), 60 deletions(-)
 create mode 100644 issue.md
 rename tests/ra_aid/{agents => }/test_ciayn_agent.py (63%)

diff --git a/issue.md b/issue.md
new file mode 100644
index 0000000..62e246b
--- /dev/null
+++ b/issue.md
@@ -0,0 +1,115 @@
+# LLM Tool Call Fallback Feature
+
+## Overview
+Add functionality to automatically fallback to alternative LLM models when a tool call experiences multiple consecutive failures.
+
+## Background
+Currently, when a tool call fails due to LLM-related errors (e.g., API timeouts, rate limits, context length issues), there is no automatic fallback mechanism. This can lead to interrupted workflows and poor user experience.
+
+## Relevant Files
+- ra_aid/agents/ciayn_agent.py
+- ra_aid/llm.py
+- ra_aid/agent_utils.py
+- ra_aid/__main__.py
+- ra_aid/models_params.py
+
+
+## Implementation Details
+
+### Configuration
+- Add new configuration value `max_tool_failures` (default: 3) to track consecutive failures before triggering fallback
+- Add new command line argument `--no-fallback-tool` to disable fallback behavior (enabled by default)
+- Add new command line argument `--fallback-tool-models` to specify a comma-separated list of fallback tool models (default: "gpt-3.5-turbo,gpt-4")
+  This list defines the fallback model sequence used by forced tool calls (via `bind_tools`) when tool call failures occur.
+- Track failure count per tool call context
+- Reset failure counter on successful tool call
+- Store fallback model sequence per provider
+- Validate that the required ENV vars for a fallback model's provider are set
+  before using it; if they are missing, fall back to the next model in the list
+- Keep a default list of common models: try `claude-3-5-sonnet-20241022` first,
+  followed by several alternative fallback models (see the sketch below).
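+
+A minimal sketch of that ENV validation, assuming a simple provider-to-key
+mapping (`PROVIDER_ENV_VARS` and `pick_fallback_model` are illustrative names,
+not existing code):
+
+```python
+import os
+
+# Assumed mapping from provider name to its required API key variable.
+PROVIDER_ENV_VARS = {
+    "anthropic": "ANTHROPIC_API_KEY",
+    "openai": "OPENAI_API_KEY",
+    "openrouter": "OPENROUTER_API_KEY",
+}
+
+
+def pick_fallback_model(candidates: list[tuple[str, str]]) -> tuple[str, str]:
+    """Return the first (provider, model) pair whose provider is configured."""
+    for provider, model in candidates:
+        if os.environ.get(PROVIDER_ENV_VARS.get(provider, "")):
+            return provider, model
+    raise RuntimeError("No fallback model has a configured provider")
+```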
+ +### Tool Call Wrapper +Create a new wrapper function to handle tool call execution with fallback logic: + +```python +def execute_tool_with_fallback(tool_call_func, *args, **kwargs): + failures = 0 + max_failures = get_config().max_tool_failures + + while failures < max_failures: + try: + return tool_call_func(*args, **kwargs) + except LLMError as e: + failures += 1 + if failures >= max_failures: + # Use forced tool call via bind_tools with retry: + llm_retry = llm_model.with_retry(stop_after_attempt=3) # Try three times + try_fallback_model(force=True, model=llm_retry) + # Merge fallback model chat messages back into the original chat history. + merge_fallback_chat_history() + failures = 0 # Reset counter for new model + else: + raise +``` + +The prompt passed to `try_fallback_model`, should be the failed last few failing tool calls. + +### Model Fallback Sequence +Define fallback sequences for each provider based on model capabilities: + +1. Try same provider's smaller models +2. Try alternative providers' equivalent models +3. Raise final error if all fallbacks fail + +### Provider Strategy Updates +Update provider strategies to support fallback configuration: +- Add provider-specific fallback sequences +- Handle model capability validation during fallback +- Track successful/failed attempts + +## Risks and Mitigations +1. **Performance Impact** + - Risk: Multiple fallback attempts could increase latency + - Mitigation: Set reasonable max_failures limit and timeouts + +2. **Consistency** + - Risk: Different models may give slightly different outputs + - Mitigation: Validate output schema consistency across models + +3. **Cost** + - Risk: Fallback to more expensive models + - Mitigation: Configure cost limits and preferred fallback sequences + +4. **State Management** + - Risk: Loss of context during fallbacks + - Mitigation: Preserve conversation state and tool context + +## Acceptance Criteria +1. Tool calls automatically attempt fallback models after N consecutive failures +2. `--no-fallback-tool` argument successfully disables fallback behavior +3. Fallback sequence respects provider and model capabilities +4. Original error is preserved if all fallbacks fail +5. Unit tests cover fallback scenarios and edge cases +6. README.md updated to reflect new behavior + +## Testing +1. Unit tests for fallback wrapper +2. Integration tests with mock LLM failures +3. Provider strategy fallback tests +4. Command line argument handling +5. Error preservation and reporting +6. Performance impact measurement +7. Edge cases (e.g., partial failures, timeout handling) +8. State preservation during fallbacks + +## Documentation Updates +1. Add fallback feature to main README +2. Document `--no-fallback-tool` in CLI help +3. Document provider-specific fallback sequences + +## Future Considerations +1. Allow custom fallback sequences via configuration +2. Add monitoring and alerting for fallback frequency +3. Optimize fallback selection based on historical success rates +4. Cost-aware fallback routing diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 21868bd..16fd5df 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -149,6 +149,17 @@ Examples: action="store_false", help="Whether to disable token limiting for Anthropic Claude react agents. 
Token limiter removes older messages to prevent maximum token limit API errors.", ) + parser.add_argument( + "--no-fallback-tool", + action="store_true", + help="Disable fallback model switching.", + ) + parser.add_argument( + "--fallback-tool-models", + type=str, + default="gpt-3.5-turbo,gpt-4", + help="Comma-separated list of fallback models to use in order.", + ) parser.add_argument( "--recursion-limit", type=int, @@ -414,12 +425,35 @@ def main(): ) _global_memory["config"]["planner_model"] = args.planner_model or args.model + _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool + _global_memory["config"]["fallback_tool_models"] = ( + [ + model.strip() + for model in args.fallback_tool_models.split(",") + if model.strip() + ] + if args.fallback_tool_models + else [] + ) + # Store research config with fallback to base values _global_memory["config"]["research_provider"] = ( args.research_provider or args.provider ) _global_memory["config"]["research_model"] = args.research_model or args.model + # Store fallback tool configuration + _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool + _global_memory["config"]["fallback_tool_models"] = ( + [ + model.strip() + for model in args.fallback_tool_models.split(",") + if model.strip() + ] + if args.fallback_tool_models + else [] + ) + # Run research stage print_stage_header("Research Stage") diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index c615cc4..cbe52bf 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -68,12 +68,29 @@ class CiaynAgent: - Memory management with configurable limits """ + class ToolCallFailure: + """Tracks consecutive failures and fallback model usage for tool calls. + + Attributes: + consecutive_failures (int): Count of consecutive failures for current model + current_provider (Optional[str]): Current provider being used + current_model (Optional[str]): Current model being used + used_fallbacks (Set[str]): Set of fallback models already attempted + """ + + def __init__(self): + self.consecutive_failures = 0 + self.current_provider = None + self.current_model = None + self.used_fallbacks = set() + def __init__( self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT, + config: Optional[dict] = None, ): """Initialize the agent with a model and list of tools. 
@@ -82,7 +99,17 @@ class CiaynAgent: tools: List of tools available to the agent max_history_messages: Maximum number of messages to keep in chat history max_tokens: Maximum number of tokens allowed in message history (None for no limit) + config: Optional configuration dictionary for fallback settings """ + if config is None: + config = {} + self.config = config + self.provider = config.get("provider", "openai") + self.fallback_enabled = config.get("fallback_tool_enabled", True) + fallback_models_str = config.get("fallback_tool_models", "gpt-3.5-turbo,gpt-4") + self.fallback_tool_models = [ + m.strip() for m in fallback_models_str.split(",") if m.strip() + ] self.model = model self.tools = tools self.max_history_messages = max_history_messages @@ -90,6 +117,7 @@ class CiaynAgent: self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) + self._tool_failure = CiaynAgent.ToolCallFailure() def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -221,23 +249,56 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return base_prompt def _execute_tool(self, code: str) -> str: - """Execute a tool call and return its result.""" - globals_dict = {tool.func.__name__: tool.func for tool in self.tools} + """Execute a tool call with retry and fallback logic and return its result.""" + max_retries = 3 + retries = 0 + last_error = None + while retries < max_retries: + try: + code = code.strip() + if validate_function_call_pattern(code): + functions_list = "\n\n".join(self.available_functions) + code = _extract_tool_call(code, functions_list) + globals_dict = {tool.func.__name__: tool.func for tool in self.tools} + result = eval(code, globals_dict) + self._tool_failure.consecutive_failures = 0 + return result + except Exception as e: + self._handle_tool_failure(code, e) + last_error = e + retries += 1 + raise ToolExecutionError( + f"Error executing code after {max_retries} attempts: {str(last_error)}" + ) + def _handle_tool_failure(self, code: str, error: Exception) -> None: + self._tool_failure.consecutive_failures += 1 + max_failures = self.config.get("max_tool_failures", 3) + if ( + self.fallback_enabled + and self._tool_failure.consecutive_failures >= max_failures + and self.fallback_tool_models + ): + self._attempt_fallback(code) + + def _attempt_fallback(self, code: str) -> None: + new_model = self.fallback_tool_models[0] + failed_tool_call_name = code.split('(')[0].strip() + logger.error( + f"Tool call failed {self._tool_failure.consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" + ) try: - code = code.strip() - # code = code.replace("\n", " ") - - # if the eval fails, try to extract it via a model call - if validate_function_call_pattern(code): - functions_list = "\n\n".join(self.available_functions) - code = _extract_tool_call(code, functions_list) - - result = eval(code.strip(), globals_dict) - return result - except Exception as e: - error_msg = f"Error executing code: {str(e)}" - raise ToolExecutionError(error_msg) + from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env + if not validate_provider_env(self.provider): + logger.error(f"Missing environment configuration for provider {self.provider}. 
Cannot fallback.") + else: + self.model = initialize_llm(self.provider, new_model) + self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name) + self._tool_failure.used_fallbacks.add(new_model) + merge_chat_history() # Assuming merge_chat_history handles merging fallback history + self._tool_failure.consecutive_failures = 0 + except Exception as switch_e: + logger.error(f"Fallback model switching failed: {switch_e}") def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" diff --git a/ra_aid/config.py b/ra_aid/config.py index 7977167..6c12a93 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -2,3 +2,5 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 +DEFAULT_MAX_TOOL_FAILURES = 3 +MAX_TOOL_FAILURES = 3 diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 4a4038a..9516f0b 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -1,8 +1,9 @@ import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage from langchain_google_genai import ChatGoogleGenerativeAI from langchain_openai import ChatOpenAI @@ -47,9 +48,9 @@ def create_deepseek_client( return ChatDeepseekReasoner( api_key=api_key, base_url=base_url, - temperature=0 - if is_expert - else (temperature if temperature is not None else 1), + temperature=( + 0 if is_expert else (temperature if temperature is not None else 1) + ), model=model_name, ) @@ -72,9 +73,9 @@ def create_openrouter_client( return ChatDeepseekReasoner( api_key=api_key, base_url="https://openrouter.ai/api/v1", - temperature=0 - if is_expert - else (temperature if temperature is not None else 1), + temperature=( + 0 if is_expert else (temperature if temperature is not None else 1) + ), model=model_name, ) @@ -114,7 +115,12 @@ def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any "base_url": "https://api.deepseek.com", }, } - return configs.get(provider, {}) + config = configs.get(provider, {}) + if not config or not config.get("api_key"): + raise ValueError( + f"Missing required environment variable for provider: {provider}" + ) + return config def create_llm_client( @@ -219,8 +225,41 @@ def initialize_llm( return create_llm_client(provider, model_name, temperature, is_expert=False) -def initialize_expert_llm( - provider: str, model_name: str -) -> BaseChatModel: +def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel: """Initialize an expert language model client based on the specified provider and model.""" return create_llm_client(provider, model_name, temperature=None, is_expert=True) + + +def validate_provider_env(provider: str) -> bool: + """Check if the required environment variables for a provider are set.""" + required_vars = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + "openai-compatible": "OPENAI_API_KEY", + "gemini": "GEMINI_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + } + key = required_vars.get(provider.lower()) + if key: + return bool(os.getenv(key)) + return False + + +def merge_chat_history( + original_history: List[BaseMessage], fallback_history: List[BaseMessage] +) -> List[BaseMessage]: + """Merge original and fallback chat histories while preserving order. 
+ + Args: + original_history: The original chat message history + fallback_history: Additional messages from fallback attempts + + Returns: + List[BaseMessage]: Combined message history preserving chronological order + + Note: + The function appends fallback messages to maintain context for future + interactions while preserving the original conversation flow. + """ + return original_history + fallback_history diff --git a/ra_aid/models_params.py b/ra_aid/models_params.py index ee83f70..6e0072d 100644 --- a/ra_aid/models_params.py +++ b/ra_aid/models_params.py @@ -27,7 +27,6 @@ models_params = { "gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True}, "o1-preview": {"token_limit": 128000, "supports_temperature": False}, "o1-mini": {"token_limit": 128000, "supports_temperature": False}, - "o1-preview": {"token_limit": 128000, "supports_temperature": False}, "o1": {"token_limit": 200000, "supports_temperature": False}, "o3-mini": {"token_limit": 200000, "supports_temperature": False}, }, diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index 3950ecd..e4042f1 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -1,25 +1,18 @@ from ra_aid.tools import ( ask_expert, ask_human, - delete_key_facts, - delete_key_snippets, - deregister_related_files, emit_expert_context, emit_key_facts, emit_key_snippets, - emit_plan, emit_related_files, emit_research_notes, fuzzy_find_project_files, list_directory_tree, - monorepo_detected, - plan_implementation_completed, read_file_tool, ripgrep_search, run_programming_task, run_shell_command, task_completed, - ui_detected, web_search_tavily, ) from ra_aid.tools.agent import ( @@ -29,7 +22,6 @@ from ra_aid.tools.agent import ( request_task_implementation, request_web_research, ) -from ra_aid.tools.memory import one_shot_completed from ra_aid.tools.write_file import write_file_tool diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index 3b8188f..6bb1a69 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -185,7 +185,11 @@ def ask_expert(question: str) -> str: query_parts.extend(["# Question", question]) query_parts.extend( - ["\n # Addidional Requirements", "**DO NOT OVERTHINK**", "**DO NOT OVERCOMPLICATE**"] + [ + "\n # Addidional Requirements", + "**DO NOT OVERTHINK**", + "**DO NOT OVERCOMPLICATE**", + ] ) # Join all parts diff --git a/ra_aid/tools/programmer.py b/ra_aid/tools/programmer.py index 4f62f92..35d3b2f 100644 --- a/ra_aid/tools/programmer.py +++ b/ra_aid/tools/programmer.py @@ -40,7 +40,7 @@ def run_programming_task( If new files are created, emit them after finishing. They can add/modify files, but not remove. Use run_shell_command to remove files. If referencing files you’ll delete, remove them after they finish. - + Use write_file_tool instead if you need to write the entire contents of file(s). If the programmer wrote files, they actually wrote to disk. You do not need to rewrite the output of what the programmer showed you. 
@@ -117,7 +117,7 @@ def run_programming_task( # Log the programming task log_work_event(f"Executed programming task: {_truncate_for_log(instructions)}") - + # Return structured output return { "output": truncate_output(result[0].decode()) if result[0] else "", diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py similarity index 63% rename from tests/ra_aid/agents/test_ciayn_agent.py rename to tests/ra_aid/test_ciayn_agent.py index b56310a..65794d0 100644 --- a/tests/ra_aid/agents/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -1,11 +1,41 @@ +import unittest from unittest.mock import Mock import pytest from langchain_core.messages import AIMessage, HumanMessage from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern +from ra_aid.exceptions import ToolExecutionError +# Dummy tool function for testing retry and fallback behavior +def dummy_tool(): + dummy_tool.attempt += 1 + if dummy_tool.attempt < 3: + raise Exception("Simulated failure") + return "dummy success" + + +dummy_tool.attempt = 0 + + +class DummyTool: + def __init__(self, func): + self.func = func + + +class DummyModel: + def invoke(self, messages): + # Always return a code snippet that calls dummy_tool() + class Response: + content = "dummy_tool()" + + return Response() + def bind_tools(self, tools, tool_choice): + pass + + +# Fixtures from the source file @pytest.fixture def mock_model(): """Create a mock language model.""" @@ -21,6 +51,7 @@ def agent(mock_model): return CiaynAgent(mock_model, tools, max_history_messages=3) +# Trimming test functions def test_trim_chat_history_preserves_initial_messages(agent): """Test that initial messages are preserved during trimming.""" initial_messages = [ @@ -33,9 +64,7 @@ def test_trim_chat_history_preserves_initial_messages(agent): HumanMessage(content="Chat 3"), AIMessage(content="Chat 4"), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify initial messages are preserved assert result[:2] == initial_messages # Verify only last 3 chat messages are kept (due to max_history_messages=3) @@ -47,9 +76,7 @@ def test_trim_chat_history_under_limit(agent): """Test trimming when chat history is under the maximum limit.""" initial_messages = [HumanMessage(content="Initial")] chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify no trimming occurred assert len(result) == 3 assert result == initial_messages + chat_history @@ -65,9 +92,7 @@ def test_trim_chat_history_over_limit(agent): AIMessage(content="Chat 4"), HumanMessage(content="Chat 5"), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify correct trimming assert len(result) == 4 # initial + max_history_messages assert result[0] == initial_messages[0] # Initial message preserved @@ -83,9 +108,7 @@ def test_trim_chat_history_empty_initial(agent): HumanMessage(content="Chat 3"), AIMessage(content="Chat 4"), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify only last 3 messages are kept assert len(result) == 3 assert result == chat_history[-3:] @@ -98,9 +121,7 @@ def test_trim_chat_history_empty_chat(agent): AIMessage(content="Initial 2"), ] chat_history = [] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify initial messages are preserved and no trimming occurred assert result == initial_messages assert len(result) == 2 @@ -109,16 +130,13 @@ def 
test_trim_chat_history_empty_chat(agent): def test_trim_chat_history_token_limit(): """Test trimming based on token limit.""" agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25) - initial_messages = [HumanMessage(content="Initial")] # ~2 tokens chat_history = [ HumanMessage(content="A" * 40), # ~10 tokens AIMessage(content="B" * 40), # ~10 tokens HumanMessage(content="C" * 40), # ~10 tokens ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Should keep initial message (~2 tokens) and last message (~10 tokens) assert len(result) == 2 assert result[0] == initial_messages[0] @@ -128,16 +146,13 @@ def test_trim_chat_history_token_limit(): def test_trim_chat_history_no_token_limit(): """Test trimming with no token limit set.""" agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None) - initial_messages = [HumanMessage(content="Initial")] chat_history = [ HumanMessage(content="A" * 1000), AIMessage(content="B" * 1000), HumanMessage(content="C" * 1000), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Should keep initial message and last 2 messages (max_history_messages=2) assert len(result) == 3 assert result[0] == initial_messages[0] @@ -147,7 +162,6 @@ def test_trim_chat_history_no_token_limit(): def test_trim_chat_history_both_limits(): """Test trimming with both message count and token limits.""" agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35) - initial_messages = [HumanMessage(content="Init")] # ~1 token chat_history = [ HumanMessage(content="A" * 40), # ~10 tokens @@ -155,9 +169,7 @@ def test_trim_chat_history_both_limits(): HumanMessage(content="C" * 40), # ~10 tokens AIMessage(content="D" * 40), # ~10 tokens ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Should first apply message limit (keeping last 3) # Then token limit should further reduce to fit under 15 tokens assert len(result) == 2 # Initial message + 1 message under token limit @@ -165,6 +177,33 @@ def test_trim_chat_history_both_limits(): assert result[1] == chat_history[-1] +# Fallback tests +class TestCiaynAgentFallback(unittest.TestCase): + def setUp(self): + # Reset dummy_tool attempt counter before each test + dummy_tool.attempt = 0 + self.dummy_tool = DummyTool(dummy_tool) + self.model = DummyModel() + # Create a CiaynAgent with the dummy tool + self.agent = CiaynAgent(self.model, [self.dummy_tool]) + + def test_retry_logic_with_failure_recovery(self): + # Test that _execute_tool retries and eventually returns success + result = self.agent._execute_tool("dummy_tool()") + self.assertEqual(result, "dummy success") + + def test_switch_models_on_fallback(self): + # Test fallback behavior by making dummy_tool always fail + def always_fail(): + raise Exception("Persistent failure") + + always_fail_tool = DummyTool(always_fail) + agent = CiaynAgent(self.model, [always_fail_tool]) + with self.assertRaises(ToolExecutionError): + agent._execute_tool("always_fail()") + + +# Function call validation tests class TestFunctionCallValidation: @pytest.mark.parametrize( "test_input", @@ -221,3 +260,54 @@ class TestFunctionCallValidation: def test_multiline_responses(self, test_input): """Test function calls spanning multiple lines.""" assert not validate_function_call_pattern(test_input) + + +class TestCiaynAgentNewMethods(unittest.TestCase): + def setUp(self): + # Create a dummy tool that always fails for testing fallback + def always_fail(): + raise Exception("Failure for fallback test") + 
self.always_fail_tool = DummyTool(always_fail) + # Create a dummy model that does minimal work for fallback tests + self.dummy_model = DummyModel() + # Initialize CiaynAgent with configuration to trigger fallback quickly + self.agent = CiaynAgent( + self.dummy_model, + [self.always_fail_tool], + config={"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"} + ) + + def test_handle_tool_failure_increments_counter(self): + initial_failures = self.agent._tool_failure.consecutive_failures + self.agent._handle_tool_failure("dummy_call()", Exception("Test error")) + self.assertEqual(self.agent._tool_failure.consecutive_failures, initial_failures + 1) + + def test_attempt_fallback_invokes_fallback_logic(self): + # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env + # to simulate fallback switching without external dependencies. + def dummy_initialize_llm(provider, model_name, temperature=None): + return self.dummy_model + def dummy_merge_chat_history(): + return ["merged"] + def dummy_validate_provider_env(provider): + return True + import ra_aid.llm as llm + original_initialize = llm.initialize_llm + original_merge = llm.merge_chat_history + original_validate = llm.validate_provider_env + llm.initialize_llm = dummy_initialize_llm + llm.merge_chat_history = dummy_merge_chat_history + llm.validate_provider_env = dummy_validate_provider_env + + # Set failure counter high enough to trigger fallback in _handle_tool_failure + self.agent._tool_failure.consecutive_failures = 2 + # Call _attempt_fallback; it should reset the failure counter to 0 on success. + self.agent._attempt_fallback("always_fail_tool()") + self.assertEqual(self.agent._tool_failure.consecutive_failures, 0) + # Restore original functions + llm.initialize_llm = original_initialize + llm.merge_chat_history = original_merge + llm.validate_provider_env = original_validate + +if __name__ == "__main__": + unittest.main() diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 2e7ea10..6789132 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -54,7 +54,9 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch): monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key") _llm = initialize_expert_llm("openai", "o1") - mock_openai.assert_called_once_with(api_key="test-key", model="o1", reasoning_effort="high") + mock_openai.assert_called_once_with( + api_key="test-key", model="o1", reasoning_effort="high" + ) def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): @@ -63,7 +65,10 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): _llm = initialize_expert_llm("openai", "gpt-4-preview") mock_openai.assert_called_once_with( - api_key="test-key", model="gpt-4-preview", temperature=0, reasoning_effort="high" + api_key="test-key", + model="gpt-4-preview", + temperature=0, + reasoning_effort="high", ) @@ -348,7 +353,9 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch): # Test LLM client creation with expert mode _llm = create_llm_client("openai", "o1", is_expert=True) - mock_openai.assert_called_with(api_key="expert-key", model="o1", reasoning_effort="high") + mock_openai.assert_called_with( + api_key="expert-key", model="o1", reasoning_effort="high" + ) # Test environment validation monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "") From d8ee4e04f4c7ccaec4e8f78a6fbf9f0e25135ae3 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Mon, 10 Feb 2025 14:13:19 -0800 Subject: [PATCH 
02/45] feat(fallback): implement automatic fallback
 to alternative LLM models on consecutive failures to enhance user experience
 and prevent infinite error loops

refactor(ciayn_agent): restructure tool failure handling to track consecutive
failures and fallback attempts more effectively

fix(logging): add pretty logging option for improved log readability

chore(config): define valid providers for LLM selection and update fallback
model loading logic

test(ciayn_agent): add unit tests for fallback logic and tool failure
handling to ensure reliability and correctness
---
 issue.md                         |  47 +--
 ra_aid/__main__.py               |  33 +-
 ra_aid/agents/ciayn_agent.py     |  95 ++++--
 ra_aid/config.py                 |  10 +-
 ra_aid/logging_config.py         |  47 ++-
 ra_aid/tool_leaderboard.py       | 529 +++++++++++++++++++++++++++++++
 tests/ra_aid/test_ciayn_agent.py |  24 +-
 7 files changed, 688 insertions(+), 97 deletions(-)
 create mode 100644 ra_aid/tool_leaderboard.py

diff --git a/issue.md b/issue.md
index 62e246b..3bd0988 100644
--- a/issue.md
+++ b/issue.md
@@ -4,15 +4,7 @@
 Add functionality to automatically fallback to alternative LLM models when a tool call experiences multiple consecutive failures.
 
 ## Background
-Currently, when a tool call fails due to LLM-related errors (e.g., API timeouts, rate limits, context length issues), there is no automatic fallback mechanism. This can lead to interrupted workflows and poor user experience.
-
-## Relevant Files
-- ra_aid/agents/ciayn_agent.py
-- ra_aid/llm.py
-- ra_aid/agent_utils.py
-- ra_aid/__main__.py
-- ra_aid/models_params.py
-
+Currently, when a tool call fails due to LLM-related errors (e.g., invalid format), there is no automatic fallback mechanism. This often causes an infinite loop of erroring tool calls.
 
 ## Implementation Details
 
@@ -59,32 +51,25 @@ The prompt passed to `try_fallback_model`, should be the failed last few failing
 Define fallback sequences for each provider based on model capabilities:
 
 1. Try same provider's smaller models
-2. Try alternative providers' equivalent models
+2. Try alternative providers' similar models
 3. Raise final error if all fallbacks fail
 
-### Provider Strategy Updates
-Update provider strategies to support fallback configuration:
-- Add provider-specific fallback sequences
-- Handle model capability validation during fallback
-- Track successful/failed attempts
-
 ## Risks and Mitigations
-1. **Performance Impact**
-   - Risk: Multiple fallback attempts could increase latency
-   - Mitigation: Set reasonable max_failures limit and timeouts
-
-2. **Consistency**
-   - Risk: Different models may give slightly different outputs
-   - Mitigation: Validate output schema consistency across models
-
-3. **Cost**
+1. **Cost**
    - Risk: Fallback to more expensive models
    - Mitigation: Configure cost limits and preferred fallback sequences
 
-4. **State Management**
+2. **State Management**
    - Risk: Loss of context during fallbacks
    - Mitigation: Preserve conversation state and tool context
 
+## Relevant Files
+- ra_aid/agents/ciayn_agent.py
+- ra_aid/llm.py
+- ra_aid/agent_utils.py
+- ra_aid/__main__.py
+- ra_aid/models_params.py
+
 ## Acceptance Criteria
 1. Tool calls automatically attempt fallback models after N consecutive failures
 2. `--no-fallback-tool` argument successfully disables fallback behavior
@@ -93,16 +78,6 @@ Update provider strategies to support fallback configuration:
 5. Unit tests cover fallback scenarios and edge cases
 6. README.md updated to reflect new behavior
 
-## Testing
-1. Unit tests for fallback wrapper
-2. Integration tests with mock LLM failures
-3. 
Provider strategy fallback tests -4. Command line argument handling -5. Error preservation and reporting -6. Performance impact measurement -7. Edge cases (e.g., partial failures, timeout handling) -8. State preservation during fallbacks - ## Documentation Updates 1. Add fallback feature to main README 2. Document `--no-fallback-tool` in CLI help diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 16fd5df..654dd60 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -17,7 +17,11 @@ from ra_aid.agent_utils import ( run_planning_agent, run_research_agent, ) -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT +from ra_aid.config import ( + DEFAULT_MAX_TEST_CMD_RETRIES, + DEFAULT_RECURSION_LIMIT, + VALID_PROVIDERS, +) from ra_aid.dependencies import check_dependencies from ra_aid.env import validate_environment from ra_aid.llm import initialize_llm @@ -40,14 +44,6 @@ def launch_webui(host: str, port: int): def parse_arguments(args=None): - VALID_PROVIDERS = [ - "anthropic", - "openai", - "openrouter", - "openai-compatible", - "deepseek", - "gemini", - ] ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022" OPENAI_DEFAULT_MODEL = "gpt-4o" @@ -80,9 +76,11 @@ Examples: parser.add_argument( "--provider", type=str, - default="openai" - if (os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY")) - else "anthropic", + default=( + "openai" + if (os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY")) + else "anthropic" + ), choices=VALID_PROVIDERS, help="The LLM provider to use", ) @@ -138,6 +136,9 @@ Examples: parser.add_argument( "--verbose", action="store_true", help="Enable verbose logging output" ) + parser.add_argument( + "--pretty-logger", action="store_true", help="Enable pretty logging output" + ) parser.add_argument( "--temperature", type=float, @@ -276,7 +277,7 @@ def is_stage_requested(stage: str) -> bool: def main(): """Main entry point for the ra-aid command line tool.""" args = parse_arguments() - setup_logging(args.verbose) + setup_logging(args.verbose, args.pretty_logger) logger.debug("Starting RA.Aid with arguments: %s", args) # Launch web interface if requested @@ -378,9 +379,9 @@ def main(): chat_agent, CHAT_PROMPT.format( initial_request=initial_request, - web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT - if web_research_enabled - else "", + web_research_section=( + WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "" + ), working_directory=working_directory, current_date=current_date, project_info=formatted_project_info, diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index cbe52bf..a651ab8 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES from ra_aid.exceptions import ToolExecutionError from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT @@ -68,22 +69,6 @@ class CiaynAgent: - Memory management with configurable limits """ - class ToolCallFailure: - """Tracks consecutive failures and fallback model usage for tool calls. 
- - Attributes: - consecutive_failures (int): Count of consecutive failures for current model - current_provider (Optional[str]): Current provider being used - current_model (Optional[str]): Current model being used - used_fallbacks (Set[str]): Set of fallback models already attempted - """ - - def __init__(self): - self.consecutive_failures = 0 - self.current_provider = None - self.current_model = None - self.used_fallbacks = set() - def __init__( self, model, @@ -106,10 +91,8 @@ class CiaynAgent: self.config = config self.provider = config.get("provider", "openai") self.fallback_enabled = config.get("fallback_tool_enabled", True) - fallback_models_str = config.get("fallback_tool_models", "gpt-3.5-turbo,gpt-4") - self.fallback_tool_models = [ - m.strip() for m in fallback_models_str.split(",") if m.strip() - ] + self.fallback_tool_models = self._load_fallback_tool_models(config) + self.model = model self.tools = tools self.max_history_messages = max_history_messages @@ -117,7 +100,18 @@ class CiaynAgent: self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) - self._tool_failure = CiaynAgent.ToolCallFailure() + self.tool_failure_consecutive_failures = 0 + self.tool_failure_current_provider = None + self.tool_failure_current_model = None + self.tool_failure_used_fallbacks = set() + + def _load_fallback_tool_models(self, config: dict) -> list: + fallback_tool_models_config = config.get("fallback_tool_models") + if fallback_tool_models_config: + return [m.strip() for m in fallback_tool_models_config.split(",") if m.strip()] + else: + from ra_aid.tool_leaderboard import supported_top_tool_models + return [item["model"] for item in supported_top_tool_models[:5]] def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -255,48 +249,85 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" last_error = None while retries < max_retries: try: + logger.debug( + f"_execute_tool: attempt {retries+1}, original code: {code}" + ) code = code.strip() if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) code = _extract_tool_call(code, functions_list) globals_dict = {tool.func.__name__: tool.func for tool in self.tools} + logger.debug(f"_execute_tool: evaluating code: {code}") result = eval(code, globals_dict) - self._tool_failure.consecutive_failures = 0 + logger.debug( + f"_execute_tool: tool executed successfully with result: {result}" + ) + self.tool_failure_consecutive_failures = 0 return result except Exception as e: + logger.debug(f"_execute_tool: exception caught: {e}") self._handle_tool_failure(code, e) last_error = e retries += 1 + logger.debug(f"_execute_tool: retrying, new attempt count: {retries}") raise ToolExecutionError( f"Error executing code after {max_retries} attempts: {str(last_error)}" ) def _handle_tool_failure(self, code: str, error: Exception) -> None: - self._tool_failure.consecutive_failures += 1 - max_failures = self.config.get("max_tool_failures", 3) + logger.debug( + f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" + ) + self.tool_failure_consecutive_failures += 1 + max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) + logger.debug( + f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}" + ) if ( self.fallback_enabled - and self._tool_failure.consecutive_failures >= 
max_failures + and self.tool_failure_consecutive_failures >= max_failures and self.fallback_tool_models ): + logger.debug( + "_handle_tool_failure: threshold reached, invoking fallback mechanism." + ) self._attempt_fallback(code) def _attempt_fallback(self, code: str) -> None: + logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") new_model = self.fallback_tool_models[0] - failed_tool_call_name = code.split('(')[0].strip() + failed_tool_call_name = code.split("(")[0].strip() logger.error( - f"Tool call failed {self._tool_failure.consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" + f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" ) try: - from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env + from ra_aid.llm import ( + initialize_llm, + merge_chat_history, + validate_provider_env, + ) + + logger.debug(f"_attempt_fallback: validating provider {self.provider}") if not validate_provider_env(self.provider): - logger.error(f"Missing environment configuration for provider {self.provider}. Cannot fallback.") + logger.error( + f"Missing environment configuration for provider {self.provider}. Cannot fallback." + ) else: + logger.debug( + f"_attempt_fallback: initializing fallback model {new_model}" + ) self.model = initialize_llm(self.provider, new_model) + logger.debug( + f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}" + ) self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name) - self._tool_failure.used_fallbacks.add(new_model) - merge_chat_history() # Assuming merge_chat_history handles merging fallback history - self._tool_failure.consecutive_failures = 0 + self.tool_failure_used_fallbacks.add(new_model) + logger.debug("_attempt_fallback: merging chat history for fallback") + merge_chat_history() + self.tool_failure_consecutive_failures = 0 + logger.debug( + "_attempt_fallback: fallback successful and tool failure counter reset" + ) except Exception as switch_e: logger.error(f"Fallback model switching failed: {switch_e}") diff --git a/ra_aid/config.py b/ra_aid/config.py index 6c12a93..41868dd 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -3,4 +3,12 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 DEFAULT_MAX_TOOL_FAILURES = 3 -MAX_TOOL_FAILURES = 3 + +VALID_PROVIDERS = [ + "anthropic", + "openai", + "openrouter", + "openai-compatible", + "deepseek", + "gemini", +] diff --git a/ra_aid/logging_config.py b/ra_aid/logging_config.py index a40aa3a..fb3bf63 100644 --- a/ra_aid/logging_config.py +++ b/ra_aid/logging_config.py @@ -1,18 +1,53 @@ import logging import sys from typing import Optional +from rich.console import Console +from rich.panel import Panel +from rich.markdown import Markdown -def setup_logging(verbose: bool = False) -> None: +class PrettyHandler(logging.Handler): + def __init__(self, level=logging.NOTSET): + super().__init__(level) + self.console = Console() + + def emit(self, record): + try: + msg = self.format(record) + # Determine title and style based on log level + if record.levelno >= logging.CRITICAL: + title = "🔥 CRITICAL" + style = "bold red" + elif record.levelno >= logging.ERROR: + title = "❌ ERROR" + style = "red" + elif record.levelno >= logging.WARNING: + title = "⚠️ WARNING" + style = "yellow" + elif record.levelno >= logging.INFO: + title = "ℹ️ INFO" + style = "green" + else: + title = "🐞 DEBUG" + 
style = "blue" + self.console.print(Panel(Markdown(msg.strip()), title=title, style=style)) + except Exception: + self.handleError(record) + + +def setup_logging(verbose: bool = False, pretty: bool = False) -> None: logger = logging.getLogger("ra_aid") logger.setLevel(logging.DEBUG if verbose else logging.INFO) if not logger.handlers: - handler = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - handler.setFormatter(formatter) + if pretty: + handler = PrettyHandler() + else: + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) logger.addHandler(handler) diff --git a/ra_aid/tool_leaderboard.py b/ra_aid/tool_leaderboard.py new file mode 100644 index 0000000..04d7573 --- /dev/null +++ b/ra_aid/tool_leaderboard.py @@ -0,0 +1,529 @@ +from ra_aid.config import VALID_PROVIDERS + +# Data extracted at 2/10/2025: +# https://gorilla.cs.berkeley.edu/leaderboard.html +# In order of overall_acc +leaderboard_data = [ + { + "overall_acc": 74.31, + "model": "watt-tool-70B", + "type": "FC", + "link": "https://huggingface.co/watt-ai/watt-tool-70B/", + "cost": "N/A", + "latency": 3.4, + "ast_summary": 84.06, + "exec_summary": 89.39, + "live_ast_acc": 77.74, + "multi_turn_acc": 58.75, + "relevance": 94.44, + "irrelevance": 76.32, + "organization": "Watt AI Lab", + "license": "Apache-2.0", + "provider": "unknown", + }, + { + "overall_acc": 72.08, + "model": "gpt-4o-2024-11-20", + "type": "Prompt", + "link": "https://openai.com/index/hello-gpt-4o/", + "cost": 13.54, + "latency": 0.78, + "ast_summary": 88.1, + "exec_summary": 89.38, + "live_ast_acc": 79.83, + "multi_turn_acc": 47.62, + "relevance": 83.33, + "irrelevance": 83.76, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 69.58, + "model": "gpt-4o-2024-11-20", + "type": "FC", + "link": "https://openai.com/index/hello-gpt-4o/", + "cost": 8.23, + "latency": 1.11, + "ast_summary": 87.42, + "exec_summary": 89.2, + "live_ast_acc": 79.65, + "multi_turn_acc": 41, + "relevance": 83.33, + "irrelevance": 83.15, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 67.98, + "model": "watt-tool-8B", + "type": "FC", + "link": "https://huggingface.co/watt-ai/watt-tool-8B/", + "cost": "N/A", + "latency": 1.31, + "ast_summary": 86.56, + "exec_summary": 89.34, + "live_ast_acc": 76.5, + "multi_turn_acc": 39.12, + "relevance": 83.33, + "irrelevance": 83.15, + "organization": "Watt AI Lab", + "license": "Apache-2.0", + "provider": "unknown", + }, + { + "overall_acc": 67.88, + "model": "GPT-4-turbo-2024-04-09", + "type": "FC", + "link": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", + "cost": 33.22, + "latency": 2.47, + "ast_summary": 84.73, + "exec_summary": 85.21, + "live_ast_acc": 80.5, + "multi_turn_acc": 38.12, + "relevance": 72.22, + "irrelevance": 83.81, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 66.73, + "model": "o1-2024-12-17", + "type": "Prompt", + "link": "https://openai.com/o1/", + "cost": 102.47, + "latency": 5.3, + "ast_summary": 85.67, + "exec_summary": 79.77, + "live_ast_acc": 80.63, + "multi_turn_acc": 36, + "relevance": 72.22, + "irrelevance": 87.78, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 64.1, + "model": 
"GPT-4o-mini-2024-07-18", + "type": "FC", + "link": "https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/", + "cost": 0.51, + "latency": 1.49, + "ast_summary": 85.21, + "exec_summary": 83.57, + "live_ast_acc": 74.41, + "multi_turn_acc": 34.12, + "relevance": 83.33, + "irrelevance": 74.75, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 62.79, + "model": "o1-mini-2024-09-12", + "type": "Prompt", + "link": "https://openai.com/index/openai-o1-mini-advancing-cost-efficient-reasoning/", + "cost": 29.76, + "latency": 8.44, + "ast_summary": 78.92, + "exec_summary": 82.7, + "live_ast_acc": 78.14, + "multi_turn_acc": 28.25, + "relevance": 61.11, + "irrelevance": 89.62, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 62.73, + "model": "Functionary-Medium-v3.1", + "type": "FC", + "link": "https://huggingface.co/meetkai/functionary-medium-v3.1", + "cost": "N/A", + "latency": 14.06, + "ast_summary": 89.88, + "exec_summary": 91.32, + "live_ast_acc": 76.63, + "multi_turn_acc": 21.62, + "relevance": 72.22, + "irrelevance": 76.08, + "organization": "MeetKai", + "license": "MIT", + "provider": "unknown", + }, + { + "overall_acc": 62.19, + "model": "Gemini-1.5-Pro-002", + "type": "Prompt", + "link": "https://deepmind.google/technologies/gemini/pro/", + "cost": 7.05, + "latency": 5.94, + "ast_summary": 88.58, + "exec_summary": 91.27, + "live_ast_acc": 76.72, + "multi_turn_acc": 20.75, + "relevance": 72.22, + "irrelevance": 78.15, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 61.83, + "model": "Hammer2.1-7b", + "type": "FC", + "link": "https://huggingface.co/MadeAgents/Hammer2.1-7b", + "cost": "N/A", + "latency": 2.08, + "ast_summary": 88.65, + "exec_summary": 85.48, + "live_ast_acc": 75.11, + "multi_turn_acc": 23.5, + "relevance": 82.35, + "irrelevance": 78.59, + "organization": "MadeAgents", + "license": "cc-by-nc-4.0", + "provider": "unknown", + }, + { + "overall_acc": 61.74, + "model": "Gemini-2.0-Flash-Exp", + "type": "Prompt", + "link": "https://deepmind.google/technologies/gemini/flash/", + "cost": 0.0, + "latency": 1.18, + "ast_summary": 89.96, + "exec_summary": 79.89, + "live_ast_acc": 82.01, + "multi_turn_acc": 17.88, + "relevance": 77.78, + "irrelevance": 86.44, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 61.38, + "model": "Amazon-Nova-Pro-v1:0", + "type": "FC", + "link": "https://aws.amazon.com/cn/ai/generative-ai/nova/", + "cost": 5.26, + "latency": 2.67, + "ast_summary": 84.46, + "exec_summary": 85.64, + "live_ast_acc": 74.32, + "multi_turn_acc": 26.12, + "relevance": 77.78, + "irrelevance": 70.98, + "organization": "Amazon", + "license": "Proprietary", + "provider": "unknown", + }, + { + "overall_acc": 61.31, + "model": "Qwen2.5-72B-Instruct", + "type": "Prompt", + "link": "https://huggingface.co/Qwen/Qwen2.5-72B-Instruct", + "cost": "N/A", + "latency": 3.72, + "ast_summary": 90.81, + "exec_summary": 92.7, + "live_ast_acc": 75.3, + "multi_turn_acc": 18, + "relevance": 100, + "irrelevance": 72.81, + "organization": "Qwen", + "license": "qwen", + "provider": "unknown", + }, + { + "overall_acc": 60.97, + "model": "Gemini-1.5-Pro-002", + "type": "FC", + "link": "https://deepmind.google/technologies/gemini/pro/", + "cost": 5.39, + "latency": 2.07, + "ast_summary": 87.29, + "exec_summary": 84.61, + "live_ast_acc": 76.28, + "multi_turn_acc": 
21.62, + "relevance": 72.22, + "irrelevance": 76.9, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 60.89, + "model": "GPT-4o-mini-2024-07-18", + "type": "Prompt", + "link": "https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/", + "cost": 0.84, + "latency": 1.31, + "ast_summary": 86.77, + "exec_summary": 80.84, + "live_ast_acc": 76.5, + "multi_turn_acc": 22, + "relevance": 83.33, + "irrelevance": 80.67, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 60.59, + "model": "Gemini-2.0-Flash-Exp", + "type": "FC", + "link": "https://deepmind.google/technologies/gemini/flash/", + "cost": 0.0, + "latency": 0.85, + "ast_summary": 85.1, + "exec_summary": 77.46, + "live_ast_acc": 79.03, + "multi_turn_acc": 20.25, + "relevance": 55.56, + "irrelevance": 91.51, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 60.46, + "model": "Gemini-1.5-Pro-001", + "type": "Prompt", + "link": "https://deepmind.google/technologies/gemini/pro/", + "cost": 7.0, + "latency": 1.54, + "ast_summary": 85.56, + "exec_summary": 85.77, + "live_ast_acc": 76.68, + "multi_turn_acc": 18.88, + "relevance": 55.56, + "irrelevance": 84.81, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 60.38, + "model": "Gemini-Exp-1206", + "type": "FC", + "link": "https://blog.google/feed/gemini-exp-1206/", + "cost": 0.0, + "latency": 3.42, + "ast_summary": 85.17, + "exec_summary": 80.86, + "live_ast_acc": 78.54, + "multi_turn_acc": 20.25, + "relevance": 77.78, + "irrelevance": 79.64, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 59.67, + "model": "Qwen2.5-32B-Instruct", + "type": "Prompt", + "link": "https://huggingface.co/Qwen/Qwen2.5-32B-Instruct", + "cost": "N/A", + "latency": 2.26, + "ast_summary": 85.81, + "exec_summary": 89.79, + "live_ast_acc": 74.23, + "multi_turn_acc": 17.75, + "relevance": 100, + "irrelevance": 73.75, + "organization": "Qwen", + "license": "apache-2.0", + "provider": "unknown", + }, + { + "overall_acc": 59.57, + "model": "GPT-4-turbo-2024-04-09", + "type": "Prompt", + "link": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", + "cost": 58.87, + "latency": 1.24, + "ast_summary": 90.88, + "exec_summary": 89.45, + "live_ast_acc": 63.84, + "multi_turn_acc": 30.25, + "relevance": 100, + "irrelevance": 35.57, + "organization": "OpenAI", + "license": "Proprietary", + "provider": "openai", + }, + { + "overall_acc": 59.42, + "model": "Gemini-1.5-Pro-001", + "type": "FC", + "link": "https://deepmind.google/technologies/gemini/pro/", + "cost": 5.1, + "latency": 1.43, + "ast_summary": 84.33, + "exec_summary": 87.95, + "live_ast_acc": 76.23, + "multi_turn_acc": 16, + "relevance": 50, + "irrelevance": 84.39, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 59.07, + "model": "Hammer2.1-3b", + "type": "FC", + "link": "https://huggingface.co/MadeAgents/Hammer2.1-3b", + "cost": "N/A", + "latency": 1.95, + "ast_summary": 86.85, + "exec_summary": 84.09, + "live_ast_acc": 74.04, + "multi_turn_acc": 17.38, + "relevance": 82.35, + "irrelevance": 81.87, + "organization": "MadeAgents", + "license": "qwen-research", + "provider": "unknown", + }, + { + "overall_acc": 58.45, + "model": "mistral-large-2407", + "type": "FC", + "link": 
"https://mistral.ai/news/mistral-large-2407/", + "cost": 12.68, + "latency": 3.12, + "ast_summary": 86.81, + "exec_summary": 84.38, + "live_ast_acc": 69.88, + "multi_turn_acc": 23.75, + "relevance": 72.22, + "irrelevance": 52.85, + "organization": "Mistral AI", + "license": "Proprietary", + "provider": "mistral", + }, + { + "overall_acc": 58.42, + "model": "ToolACE-8B", + "type": "FC", + "link": "https://huggingface.co/Team-ACE/ToolACE-8B", + "cost": "N/A", + "latency": 5.24, + "ast_summary": 87.54, + "exec_summary": 89.21, + "live_ast_acc": 78.59, + "multi_turn_acc": 7.75, + "relevance": 83.33, + "irrelevance": 87.88, + "organization": "Huawei Noah & USTC", + "license": "Apache-2.0", + "provider": "unknown", + }, + { + "overall_acc": 57.78, + "model": "xLAM-8x22b-r", + "type": "FC", + "link": "https://huggingface.co/Salesforce/xLAM-8x22b-r", + "cost": "N/A", + "latency": 9.26, + "ast_summary": 83.69, + "exec_summary": 87.88, + "live_ast_acc": 72.59, + "multi_turn_acc": 16.25, + "relevance": 88.89, + "irrelevance": 67.81, + "organization": "Salesforce", + "license": "cc-by-nc-4.0", + "provider": "unknown", + }, + { + "overall_acc": 57.68, + "model": "Qwen2.5-14B-Instruct", + "type": "Prompt", + "link": "https://huggingface.co/Qwen/Qwen2.5-14B-Instruct", + "cost": "N/A", + "latency": 2.02, + "ast_summary": 85.69, + "exec_summary": 88.84, + "live_ast_acc": 74.14, + "multi_turn_acc": 12.25, + "relevance": 77.78, + "irrelevance": 77.06, + "organization": "Qwen", + "license": "apache-2.0", + "provider": "unknown", + }, + { + "overall_acc": 57.23, + "model": "DeepSeek-V3", + "type": "FC", + "link": "https://api-docs.deepseek.com/news/news1226", + "cost": "N/A", + "latency": 2.58, + "ast_summary": 89.17, + "exec_summary": 83.39, + "live_ast_acc": 68.41, + "multi_turn_acc": 18.62, + "relevance": 88.89, + "irrelevance": 59.36, + "organization": "DeepSeek", + "license": "DeepSeek License", + "provider": "unknown", + }, + { + "overall_acc": 57.09, + "model": "Gemini-1.5-Flash-001", + "type": "Prompt", + "link": "https://deepmind.google/technologies/gemini/flash/", + "cost": 0.48, + "latency": 0.71, + "ast_summary": 85.69, + "exec_summary": 83.59, + "live_ast_acc": 68.9, + "multi_turn_acc": 19.5, + "relevance": 83.33, + "irrelevance": 62.78, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, + { + "overall_acc": 56.79, + "model": "Gemini-1.5-Flash-002", + "type": "Prompt", + "link": "https://deepmind.google/technologies/gemini/flash/", + "cost": 0.46, + "latency": 0.81, + "ast_summary": 81.65, + "exec_summary": 80.64, + "live_ast_acc": 76.72, + "multi_turn_acc": 12.5, + "relevance": 83.33, + "irrelevance": 78.49, + "organization": "Google", + "license": "Proprietary", + "provider": "google", + }, +] + + +supported_top_tool_models = [ + { + "cost": item["cost"], + "model": item["model"], + "type": item["type"], + "provider": item["provider"], + } + for item in leaderboard_data + if item["provider"] in VALID_PROVIDERS +] diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py index 65794d0..2e39dfa 100644 --- a/tests/ra_aid/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -31,6 +31,7 @@ class DummyModel: content = "dummy_tool()" return Response() + def bind_tools(self, tools, tool_choice): pass @@ -267,6 +268,7 @@ class TestCiaynAgentNewMethods(unittest.TestCase): # Create a dummy tool that always fails for testing fallback def always_fail(): raise Exception("Failure for fallback test") + self.always_fail_tool = 
DummyTool(always_fail) # Create a dummy model that does minimal work for fallback tests self.dummy_model = DummyModel() @@ -274,24 +276,33 @@ class TestCiaynAgentNewMethods(unittest.TestCase): self.agent = CiaynAgent( self.dummy_model, [self.always_fail_tool], - config={"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"} + config={ + "max_tool_failures": 2, + "fallback_tool_models": "dummy-fallback-model", + }, ) def test_handle_tool_failure_increments_counter(self): - initial_failures = self.agent._tool_failure.consecutive_failures + initial_failures = self.agent.tool_failure_consecutive_failures self.agent._handle_tool_failure("dummy_call()", Exception("Test error")) - self.assertEqual(self.agent._tool_failure.consecutive_failures, initial_failures + 1) + self.assertEqual( + self.agent.tool_failure_consecutive_failures, initial_failures + 1 + ) def test_attempt_fallback_invokes_fallback_logic(self): - # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env + # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env # to simulate fallback switching without external dependencies. def dummy_initialize_llm(provider, model_name, temperature=None): return self.dummy_model + def dummy_merge_chat_history(): return ["merged"] + def dummy_validate_provider_env(provider): return True + import ra_aid.llm as llm + original_initialize = llm.initialize_llm original_merge = llm.merge_chat_history original_validate = llm.validate_provider_env @@ -300,14 +311,15 @@ class TestCiaynAgentNewMethods(unittest.TestCase): llm.validate_provider_env = dummy_validate_provider_env # Set failure counter high enough to trigger fallback in _handle_tool_failure - self.agent._tool_failure.consecutive_failures = 2 + self.agent.tool_failure_consecutive_failures = 2 # Call _attempt_fallback; it should reset the failure counter to 0 on success. 
self.agent._attempt_fallback("always_fail_tool()") - self.assertEqual(self.agent._tool_failure.consecutive_failures, 0) + self.assertEqual(self.agent.tool_failure_consecutive_failures, 0) # Restore original functions llm.initialize_llm = original_initialize llm.merge_chat_history = original_merge llm.validate_provider_env = original_validate + if __name__ == "__main__": unittest.main() From 55abf6e5dd0d97144bf930e62d570bc12a6b2fe2 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Mon, 10 Feb 2025 23:37:15 -0800 Subject: [PATCH 03/45] feat(fallback_handler): implement FallbackHandler class to manage tool failures and fallback logic refactor(ciayn_agent): integrate FallbackHandler into CiaynAgent for improved failure handling fix(agent_utils): add missing newline for better readability in run_agent_with_retry function test(fallback_handler): add unit tests for FallbackHandler to ensure correct failure handling and fallback logic --- ra_aid/agent_utils.py | 1 + ra_aid/agents/ciayn_agent.py | 73 ++------------------ ra_aid/config.py | 1 + ra_aid/fallback_handler.py | 99 +++++++++++++++++++++++++++ tests/ra_aid/test_ciayn_agent.py | 56 +-------------- tests/ra_aid/test_fallback_handler.py | 58 ++++++++++++++++ 6 files changed, 164 insertions(+), 124 deletions(-) create mode 100644 ra_aid/fallback_handler.py create mode 100644 tests/ra_aid/test_fallback_handler.py diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index cdeac25..4dcc6a5 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -806,6 +806,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: logger.debug("Agent output: %s", chunk) check_interrupt() print_agent_output(chunk) + if _global_memory["plan_completed"]: _global_memory["plan_completed"] = False _global_memory["task_completed"] = False diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index a651ab8..57c7467 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES +from ra_aid.fallback_handler import FallbackHandler from ra_aid.exceptions import ToolExecutionError from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT @@ -90,8 +90,7 @@ class CiaynAgent: config = {} self.config = config self.provider = config.get("provider", "openai") - self.fallback_enabled = config.get("fallback_tool_enabled", True) - self.fallback_tool_models = self._load_fallback_tool_models(config) + self.fallback_handler = FallbackHandler(config) self.model = model self.tools = tools @@ -100,18 +99,8 @@ class CiaynAgent: self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) - self.tool_failure_consecutive_failures = 0 self.tool_failure_current_provider = None self.tool_failure_current_model = None - self.tool_failure_used_fallbacks = set() - - def _load_fallback_tool_models(self, config: dict) -> list: - fallback_tool_models_config = config.get("fallback_tool_models") - if fallback_tool_models_config: - return [m.strip() for m in fallback_tool_models_config.split(",") if m.strip()] - else: - from ra_aid.tool_leaderboard import supported_top_tool_models - return [item["model"] for item in supported_top_tool_models[:5]] def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt 
for the agent including available tools and context.""" @@ -262,7 +251,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" logger.debug( f"_execute_tool: tool executed successfully with result: {result}" ) - self.tool_failure_consecutive_failures = 0 + self.fallback_handler.reset_fallback_handler() return result except Exception as e: logger.debug(f"_execute_tool: exception caught: {e}") @@ -275,61 +264,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" ) def _handle_tool_failure(self, code: str, error: Exception) -> None: - logger.debug( - f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" - ) - self.tool_failure_consecutive_failures += 1 - max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) - logger.debug( - f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}" - ) - if ( - self.fallback_enabled - and self.tool_failure_consecutive_failures >= max_failures - and self.fallback_tool_models - ): - logger.debug( - "_handle_tool_failure: threshold reached, invoking fallback mechanism." - ) - self._attempt_fallback(code) - - def _attempt_fallback(self, code: str) -> None: - logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") - new_model = self.fallback_tool_models[0] - failed_tool_call_name = code.split("(")[0].strip() - logger.error( - f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" - ) - try: - from ra_aid.llm import ( - initialize_llm, - merge_chat_history, - validate_provider_env, - ) - - logger.debug(f"_attempt_fallback: validating provider {self.provider}") - if not validate_provider_env(self.provider): - logger.error( - f"Missing environment configuration for provider {self.provider}. Cannot fallback." 
- ) - else: - logger.debug( - f"_attempt_fallback: initializing fallback model {new_model}" - ) - self.model = initialize_llm(self.provider, new_model) - logger.debug( - f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}" - ) - self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name) - self.tool_failure_used_fallbacks.add(new_model) - logger.debug("_attempt_fallback: merging chat history for fallback") - merge_chat_history() - self.tool_failure_consecutive_failures = 0 - logger.debug( - "_attempt_fallback: fallback successful and tool failure counter reset" - ) - except Exception as switch_e: - logger.error(f"Fallback model switching failed: {switch_e}") + self.fallback_handler.handle_failure(code, error, logger, self) def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" diff --git a/ra_aid/config.py b/ra_aid/config.py index 41868dd..4c9bfea 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -3,6 +3,7 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 DEFAULT_MAX_TOOL_FAILURES = 3 +FALLBACK_TOOL_MODEL_LIMIT = 5 VALID_PROVIDERS = [ "anthropic", diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py new file mode 100644 index 0000000..0eef7b6 --- /dev/null +++ b/ra_aid/fallback_handler.py @@ -0,0 +1,99 @@ +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT +from ra_aid.tool_leaderboard import supported_top_tool_models +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env + + +class FallbackHandler: + def __init__(self, config): + self.config = config + self.fallback_enabled = config.get("fallback_tool_enabled", True) + self.fallback_tool_models = self._load_fallback_tool_models(config) + self.tool_failure_consecutive_failures = 0 + self.tool_failure_used_fallbacks = set() + + def _load_fallback_tool_models(self, config): + fallback_tool_models_config = config.get("fallback_tool_models") + if fallback_tool_models_config: + return [ + m.strip() for m in fallback_tool_models_config.split(",") if m.strip() + ] + else: + console = Console() + supported = [] + skipped = [] + for item in supported_top_tool_models: + provider = item.get("provider") + model_name = item.get("model") + if validate_provider_env(provider): + supported.append(model_name) + if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: + break + else: + skipped.append(model_name) + final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT] + message = "Fallback models selected: " + ", ".join(final_models) + if skipped: + message += ( + "\nSkipped top tool calling models due to missing provider ENV API keys: " + + ", ".join(skipped) + ) + console.print(Panel(Markdown(message), title="Fallback Models")) + return final_models + + def handle_failure(self, code: str, error: Exception, logger, agent): + logger.debug( + f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" + ) + self.tool_failure_consecutive_failures += 1 + max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) + logger.debug( + f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}" + ) + if ( + self.fallback_enabled + and self.tool_failure_consecutive_failures >= max_failures + and self.fallback_tool_models + ): + logger.debug( + "_handle_tool_failure: 
threshold reached, invoking fallback mechanism." + ) + self.attempt_fallback(code, logger, agent) + + def attempt_fallback(self, code: str, logger, agent): + logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") + new_model = self.fallback_tool_models[0] + failed_tool_call_name = code.split("(")[0].strip() + logger.error( + f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" + ) + try: + logger.debug(f"_attempt_fallback: validating provider {agent.provider}") + if not validate_provider_env(agent.provider): + logger.error( + f"Missing environment configuration for provider {agent.provider}. Cannot fallback." + ) + else: + logger.debug( + f"_attempt_fallback: initializing fallback model {new_model}" + ) + agent.model = initialize_llm(agent.provider, new_model) + logger.debug( + f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}" + ) + agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) + self.tool_failure_used_fallbacks.add(new_model) + logger.debug("_attempt_fallback: merging chat history for fallback") + merge_chat_history() + self.tool_failure_consecutive_failures = 0 + logger.debug( + "_attempt_fallback: fallback successful and tool failure counter reset" + ) + except Exception as switch_e: + logger.error(f"Fallback model switching failed: {switch_e}") + + def reset_fallback_handler(self): + self.tool_failure_consecutive_failures = 0 + self.tool_failure_used_fallbacks.clear() diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py index 2e39dfa..46c8191 100644 --- a/tests/ra_aid/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -264,61 +264,7 @@ class TestFunctionCallValidation: class TestCiaynAgentNewMethods(unittest.TestCase): - def setUp(self): - # Create a dummy tool that always fails for testing fallback - def always_fail(): - raise Exception("Failure for fallback test") - - self.always_fail_tool = DummyTool(always_fail) - # Create a dummy model that does minimal work for fallback tests - self.dummy_model = DummyModel() - # Initialize CiaynAgent with configuration to trigger fallback quickly - self.agent = CiaynAgent( - self.dummy_model, - [self.always_fail_tool], - config={ - "max_tool_failures": 2, - "fallback_tool_models": "dummy-fallback-model", - }, - ) - - def test_handle_tool_failure_increments_counter(self): - initial_failures = self.agent.tool_failure_consecutive_failures - self.agent._handle_tool_failure("dummy_call()", Exception("Test error")) - self.assertEqual( - self.agent.tool_failure_consecutive_failures, initial_failures + 1 - ) - - def test_attempt_fallback_invokes_fallback_logic(self): - # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env - # to simulate fallback switching without external dependencies. 
- def dummy_initialize_llm(provider, model_name, temperature=None): - return self.dummy_model - - def dummy_merge_chat_history(): - return ["merged"] - - def dummy_validate_provider_env(provider): - return True - - import ra_aid.llm as llm - - original_initialize = llm.initialize_llm - original_merge = llm.merge_chat_history - original_validate = llm.validate_provider_env - llm.initialize_llm = dummy_initialize_llm - llm.merge_chat_history = dummy_merge_chat_history - llm.validate_provider_env = dummy_validate_provider_env - - # Set failure counter high enough to trigger fallback in _handle_tool_failure - self.agent.tool_failure_consecutive_failures = 2 - # Call _attempt_fallback; it should reset the failure counter to 0 on success. - self.agent._attempt_fallback("always_fail_tool()") - self.assertEqual(self.agent.tool_failure_consecutive_failures, 0) - # Restore original functions - llm.initialize_llm = original_initialize - llm.merge_chat_history = original_merge - llm.validate_provider_env = original_validate + pass if __name__ == "__main__": diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py new file mode 100644 index 0000000..6f0c285 --- /dev/null +++ b/tests/ra_aid/test_fallback_handler.py @@ -0,0 +1,58 @@ +import unittest +from ra_aid.fallback_handler import FallbackHandler + +class DummyLogger: + def debug(self, msg): + pass + def error(self, msg): + pass + +class DummyAgent: + provider = "openai" + tools = [] + model = None + +class TestFallbackHandler(unittest.TestCase): + def setUp(self): + self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"} + self.fallback_handler = FallbackHandler(self.config) + self.logger = DummyLogger() + self.agent = DummyAgent() + + def test_handle_failure_increments_counter(self): + initial_failures = self.fallback_handler.tool_failure_consecutive_failures + self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent) + self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1) + + def test_attempt_fallback_resets_counter(self): + # Monkey-patch dummy functions for fallback components + def dummy_initialize_llm(provider, model_name, temperature=None): + class DummyModel: + def bind_tools(self, tools, tool_choice): + pass + return DummyModel() + + def dummy_merge_chat_history(): + return ["merged"] + + def dummy_validate_provider_env(provider): + return True + + import ra_aid.llm as llm + original_initialize = llm.initialize_llm + original_merge = llm.merge_chat_history + original_validate = llm.validate_provider_env + llm.initialize_llm = dummy_initialize_llm + llm.merge_chat_history = dummy_merge_chat_history + llm.validate_provider_env = dummy_validate_provider_env + + self.fallback_handler.tool_failure_consecutive_failures = 2 + self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent) + self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) + + llm.initialize_llm = original_initialize + llm.merge_chat_history = original_merge + llm.validate_provider_env = original_validate + +if __name__ == "__main__": + unittest.main() From 0521b3ff9ae9c7cc214cba354ee30319920b51d5 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 00:38:15 -0800 Subject: [PATCH 04/45] feat(config.py): add RETRY_FALLBACK_COUNT and RETRY_FALLBACK_DELAY to configure retry behavior for fallback models refactor(fallback_handler.py): enhance fallback handling logic to 
support both prompt-based and function-calling fallbacks with retries fix(fallback_handler.py): update fallback model selection to return dictionaries for better structure and access to model properties --- ra_aid/config.py | 2 + ra_aid/fallback_handler.py | 83 +++++++++++++++++++++++--------------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/ra_aid/config.py b/ra_aid/config.py index 4c9bfea..e85cb12 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -4,6 +4,8 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 +RETRY_FALLBACK_COUNT = 3 +RETRY_FALLBACK_DELAY = 2 VALID_PROVIDERS = [ "anthropic", diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 0eef7b6..762d697 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,4 +1,4 @@ -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console from rich.markdown import Markdown @@ -17,9 +17,11 @@ class FallbackHandler: def _load_fallback_tool_models(self, config): fallback_tool_models_config = config.get("fallback_tool_models") if fallback_tool_models_config: - return [ - m.strip() for m in fallback_tool_models_config.split(",") if m.strip() - ] + # Assume comma-separated model names; wrap each in a dict with default type "prompt" + models = [] + for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]: + models.append({"model": m, "type": "prompt"}) + return models else: console = Console() supported = [] @@ -28,13 +30,13 @@ class FallbackHandler: provider = item.get("provider") model_name = item.get("model") if validate_provider_env(provider): - supported.append(model_name) + supported.append(item) if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: break else: skipped.append(model_name) - final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT] - message = "Fallback models selected: " + ", ".join(final_models) + final_models = supported # list of dicts + message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models]) if skipped: message += ( "\nSkipped top tool calling models due to missing provider ENV API keys: " @@ -64,36 +66,51 @@ class FallbackHandler: def attempt_fallback(self, code: str, logger, agent): logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") - new_model = self.fallback_tool_models[0] + fallback_model = self.fallback_tool_models[0] failed_tool_call_name = code.split("(")[0].strip() logger.error( - f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" + f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) - try: - logger.debug(f"_attempt_fallback: validating provider {agent.provider}") - if not validate_provider_env(agent.provider): - logger.error( - f"Missing environment configuration for provider {agent.provider}. Cannot fallback." 
- ) - else: - logger.debug( - f"_attempt_fallback: initializing fallback model {new_model}" - ) - agent.model = initialize_llm(agent.provider, new_model) - logger.debug( - f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}" - ) - agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) - self.tool_failure_used_fallbacks.add(new_model) - logger.debug("_attempt_fallback: merging chat history for fallback") - merge_chat_history() - self.tool_failure_consecutive_failures = 0 - logger.debug( - "_attempt_fallback: fallback successful and tool failure counter reset" - ) - except Exception as switch_e: - logger.error(f"Fallback model switching failed: {switch_e}") + if fallback_model.get("type", "prompt").lower() == "fc": + self.attempt_fallback_function(code, logger, agent) + else: + self.attempt_fallback_prompt(code, logger, agent) def reset_fallback_handler(self): self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() + def attempt_fallback_prompt(self, code: str, logger, agent): + logger.debug("Attempting prompt-based fallback using fallback models") + failed_tool_call_name = code.split("(")[0].strip() + for fallback_model in self.fallback_tool_models: + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) + model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) + response = model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model['model']) + agent.model = model + self.reset_fallback_handler() + logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model']) + return response + except Exception as e: + logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}") + raise Exception("All prompt-based fallback models failed") + + def attempt_fallback_function(self, code: str, logger, agent): + logger.debug("Attempting function-calling fallback using fallback models") + failed_tool_call_name = code.split("(")[0].strip() + for fallback_model in self.fallback_tool_models: + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) + model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) + response = model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model['model']) + agent.model = model + self.reset_fallback_handler() + logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model']) + return response + except Exception as e: + logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}") + raise Exception("All function-calling fallback models failed") From 3d622911a66382a1159efa5725ad916275e310e5 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 00:40:02 -0800 Subject: [PATCH 05/45] feat(fallback_handler.py): add console notification for tool fallback activation to improve user feedback during failures --- ra_aid/fallback_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 762d697..0133888 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -71,6 +71,7 @@ class FallbackHandler: logger.error( f"Tool call failed 
{self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) + Console().print(Panel(Markdown(f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."), title="Fallback Notification")) if fallback_model.get("type", "prompt").lower() == "fc": self.attempt_fallback_function(code, logger, agent) else: From d39be05e39650007767914c7294fc7a626c53487 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 00:44:39 -0800 Subject: [PATCH 06/45] docs(fallback_handler.py): add detailed docstrings to FallbackHandler methods to improve code documentation and clarity on functionality --- ra_aid/fallback_handler.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 0133888..ef5c73d 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -7,7 +7,21 @@ from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env class FallbackHandler: + """ + FallbackHandler manages fallback logic when tool execution fails. + + It loads fallback models from configuration and validated provider settings, + maintains failure counts, and triggers appropriate fallback methods for both + prompt-based and function-calling tool invocations. It also resets internal + counters when a tool call succeeds. + """ def __init__(self, config): + """ + Initialize the FallbackHandler with the given configuration. + + Args: + config (dict): Configuration dictionary that may include fallback settings. + """ self.config = config self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_tool_models = self._load_fallback_tool_models(config) @@ -15,6 +29,19 @@ class FallbackHandler: self.tool_failure_used_fallbacks = set() def _load_fallback_tool_models(self, config): + """ + Load and return fallback tool models based on the provided configuration. + + If the config specifies 'fallback_tool_models', those are used (assuming comma-separated names). + Otherwise, this method filters the supported_top_tool_models based on provider environment validation, + selecting up to FALLBACK_TOOL_MODEL_LIMIT models. + + Args: + config (dict): Configuration dictionary. + + Returns: + list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model. + """ fallback_tool_models_config = config.get("fallback_tool_models") if fallback_tool_models_config: # Assume comma-separated model names; wrap each in a dict with default type "prompt" @@ -46,6 +73,15 @@ class FallbackHandler: return final_models def handle_failure(self, code: str, error: Exception, logger, agent): + """ + Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. + + Args: + code (str): The code that failed to execute. + error (Exception): The exception raised during execution. + logger: Logger instance for logging. + agent: The agent instance on which fallback may be executed. + """ logger.debug( f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" ) @@ -65,6 +101,14 @@ class FallbackHandler: self.attempt_fallback(code, logger, agent) def attempt_fallback(self, code: str, logger, agent): + """ + Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method. + + Args: + code (str): The tool code that triggered the fallback. 
+ logger: Logger instance for logging messages. + agent: The agent for which fallback is being executed. + """ logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") fallback_model = self.fallback_tool_models[0] failed_tool_call_name = code.split("(")[0].strip() @@ -78,9 +122,28 @@ class FallbackHandler: self.attempt_fallback_prompt(code, logger, agent) def reset_fallback_handler(self): + """ + Reset the fallback handler's internal failure counters and clear the record of used fallback models. + """ self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() def attempt_fallback_prompt(self, code: str, logger, agent): + """ + Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. + + This method tries each fallback model (with retry logic configured) until one successfully executes the code. + + Args: + code (str): The tool code to invoke via fallback. + logger: Logger instance for logging messages. + agent: The agent instance to update with the new model upon success. + + Returns: + The response from the fallback model invocation. + + Raises: + Exception: If all prompt-based fallback models fail. + """ logger.debug("Attempting prompt-based fallback using fallback models") failed_tool_call_name = code.split("(")[0].strip() for fallback_model in self.fallback_tool_models: @@ -99,6 +162,22 @@ class FallbackHandler: raise Exception("All prompt-based fallback models failed") def attempt_fallback_function(self, code: str, logger, agent): + """ + Attempt a function-calling fallback by iterating over fallback models and invoking the provided code. + + This method tries each fallback model (with retry logic configured) until one successfully executes the code. + + Args: + code (str): The tool code to invoke via fallback. + logger: Logger instance for logging messages. + agent: The agent instance to update with the new model upon success. + + Returns: + The response from the fallback model invocation. + + Raises: + Exception: If all function-calling fallback models fail. + """ logger.debug("Attempting function-calling fallback using fallback models") failed_tool_call_name = code.split("(")[0].strip() for fallback_model in self.fallback_tool_models: From de489584e5b20d68c03c4282cd15be293075ea86 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 01:10:22 -0800 Subject: [PATCH 07/45] refactor(fallback_handler.py): improve code readability by formatting imports and restructuring for loops fix(fallback_handler.py): ensure fallback models have a default type of "prompt" and handle exceptions properly during fallback attempts --- ra_aid/fallback_handler.py | 89 +++++++++++++++++++++++++++++--------- 1 file changed, 69 insertions(+), 20 deletions(-) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index ef5c73d..ebc3cb6 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,4 +1,8 @@ -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY +from ra_aid.config import ( + DEFAULT_MAX_TOOL_FAILURES, + FALLBACK_TOOL_MODEL_LIMIT, + RETRY_FALLBACK_COUNT, +) from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console from rich.markdown import Markdown @@ -9,12 +13,13 @@ from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env class FallbackHandler: """ FallbackHandler manages fallback logic when tool execution fails. 
- + It loads fallback models from configuration and validated provider settings, maintains failure counts, and triggers appropriate fallback methods for both prompt-based and function-calling tool invocations. It also resets internal counters when a tool call succeeds. """ + def __init__(self, config): """ Initialize the FallbackHandler with the given configuration. @@ -46,7 +51,9 @@ class FallbackHandler: if fallback_tool_models_config: # Assume comma-separated model names; wrap each in a dict with default type "prompt" models = [] - for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]: + for m in [ + x.strip() for x in fallback_tool_models_config.split(",") if x.strip() + ]: models.append({"model": m, "type": "prompt"}) return models else: @@ -62,8 +69,14 @@ class FallbackHandler: break else: skipped.append(model_name) - final_models = supported # list of dicts - message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models]) + final_models = [] + for item in supported: + if "type" not in item: + item["type"] = "prompt" + final_models.append(item) + message = "Fallback models selected: " + ", ".join( + [m["model"] for m in final_models] + ) if skipped: message += ( "\nSkipped top tool calling models due to missing provider ENV API keys: " @@ -115,7 +128,14 @@ class FallbackHandler: logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) - Console().print(Panel(Markdown(f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."), title="Fallback Notification")) + Console().print( + Panel( + Markdown( + f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}." + ), + title="Fallback Notification", + ) + ) if fallback_model.get("type", "prompt").lower() == "fc": self.attempt_fallback_function(code, logger, agent) else: @@ -127,6 +147,7 @@ class FallbackHandler: """ self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() + def attempt_fallback_prompt(self, code: str, logger, agent): """ Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. 
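The hunk below replaces the single chained call with an explicit initialize, bind_tools, with_retry pipeline. As a reading aid, here is a minimal standalone sketch of that pattern, assuming LangChain's `bind_tools`/`with_retry` semantics; `ChatOpenAI`, the model name, and the function name are illustrative stand-ins for `initialize_llm` and the handler method, not part of the patch:

```python
# Minimal sketch of the bind-then-retry pattern, under the assumptions above.
from langchain_openai import ChatOpenAI


def invoke_with_forced_tool(tools: list, failed_tool_name: str, payload: str):
    model = ChatOpenAI(model="gpt-4o-mini")  # stand-in for initialize_llm(...)
    # tool_choice forces the model to emit a call to the named tool.
    bound_model = model.bind_tools(tools, tool_choice=failed_tool_name)
    # Absorb transient provider errors before giving up on this model.
    retrying_model = bound_model.with_retry(stop_after_attempt=3)
    return retrying_model.invoke(payload)
```

Note that `with_retry` takes keyword-only arguments such as `stop_after_attempt`; the `retries=`/`delay=` keywords used in the earlier revision of this file do not exist on LangChain runnables, which is part of what this hunk corrects.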
@@ -149,16 +170,30 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") - model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) - model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) - response = model.invoke(code) - self.tool_failure_used_fallbacks.add(fallback_model['model']) - agent.model = model + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + agent.tools, tool_choice=failed_tool_call_name + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + agent.model = retry_model self.reset_fallback_handler() - logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model']) + logger.debug( + "Prompt-based fallback executed successfully with model: " + + fallback_model["model"] + ) return response except Exception as e: - logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}") + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" + ) raise Exception("All prompt-based fallback models failed") def attempt_fallback_function(self, code: str, logger, agent): @@ -183,14 +218,28 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") - model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) - model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) - response = model.invoke(code) - self.tool_failure_used_fallbacks.add(fallback_model['model']) - agent.model = model + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + agent.tools, tool_choice=failed_tool_call_name + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + agent.model = retry_model self.reset_fallback_handler() - logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model']) + logger.debug( + "Function-calling fallback executed successfully with model: " + + fallback_model["model"] + ) return response except Exception as e: - logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}") + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Function-calling fallback with model {fallback_model['model']} failed: {e}" + ) raise Exception("All function-calling fallback models failed") From 13880677694ea3d0133a7bc300879fdef08a32c1 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 12:16:04 -0800 Subject: [PATCH 08/45] refactor(agent_utils.py): refactor run_agent_with_retry function for better readability and maintainability by extracting helper functions feat(agent_utils.py): add new helper functions for handling API errors and managing interrupt signals fix(agent_utils.py): improve error handling in tool execution and retry logic feat(fallback_handler.py): enhance fallback handling by binding tools correctly 
during retries test(tests): add unit tests for new helper functions and refactored logic in agent_utils.py --- ra_aid/agent_utils.py | 148 ++++++++++++++++++------------- ra_aid/agents/ciayn_agent.py | 58 +++++------- ra_aid/fallback_handler.py | 39 +++++++- tests/ra_aid/test_agent_utils.py | 78 ++++++++++++++++ 4 files changed, 223 insertions(+), 100 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 4dcc6a5..678ce73 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -9,10 +9,16 @@ import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Sequence +from langgraph.graph.graph import CompiledGraph import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage, HumanMessage, trim_messages +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + InvalidToolCall, + trim_messages, +) from langchain_core.tools import tool from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent @@ -26,7 +32,8 @@ from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output -from ra_aid.exceptions import AgentInterrupt +from ra_aid.exceptions import AgentInterrupt, ToolExecutionError +from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params from ra_aid.project_info import ( @@ -238,7 +245,7 @@ def create_agent( *, checkpointer: Any = None, agent_type: str = "default", -) -> Any: +): """Create a react agent with the given configuration. 
Args: @@ -775,61 +782,98 @@ def check_interrupt(): raise AgentInterrupt("Interrupt requested") -def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: - """Run an agent with retry logic for API errors.""" - logger.debug("Running agent with prompt length: %d", len(prompt)) - original_handler = None +# New helper functions for run_agent_with_retry refactoring +def _setup_interrupt_handling(): if threading.current_thread() is threading.main_thread(): original_handler = signal.getsignal(signal.SIGINT) signal.signal(signal.SIGINT, _request_interrupt) + return original_handler + return None + +def _restore_interrupt_handling(original_handler): + if original_handler and threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGINT, original_handler) + + +def _increment_agent_depth(): + current_depth = _global_memory.get("agent_depth", 0) + _global_memory["agent_depth"] = current_depth + 1 + + +def _decrement_agent_depth(): + _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 + + +def _run_agent_stream(agent, prompt, config): + for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): + logger.debug("Agent output: %s", chunk) + check_interrupt() + print_agent_output(chunk) + if _global_memory["plan_completed"] or _global_memory["task_completed"]: + _global_memory["plan_completed"] = False + _global_memory["task_completed"] = False + _global_memory["completion_message"] = "" + break + + +def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test): + return execute_test_command(config, original_prompt, test_attempts, auto_test) + + +def _handle_api_error(e, attempt, max_retries, base_delay): + if isinstance(e, ValueError): + error_str = str(e).lower() + if "code" not in error_str or "429" not in error_str: + raise e + if attempt == max_retries - 1: + logger.error("Max retries reached, failing: %s", str(e)) + raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}") + logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e)) + delay = base_delay * (2**attempt) + print_error( + f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... 
(Attempt {attempt+1}/{max_retries})" + ) + start = time.monotonic() + while time.monotonic() - start < delay: + check_interrupt() + time.sleep(0.1) + + +def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: + """Run an agent with retry logic for API errors.""" + logger.debug("Running agent with prompt length: %d", len(prompt)) + original_handler = _setup_interrupt_handling() max_retries = 20 base_delay = 1 test_attempts = 0 _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt + fallback_handler = FallbackHandler(config) with InterruptibleSection(): try: - # Track agent execution depth - current_depth = _global_memory.get("agent_depth", 0) - _global_memory["agent_depth"] = current_depth + 1 - + _increment_agent_depth() for attempt in range(max_retries): logger.debug("Attempt %d/%d", attempt + 1, max_retries) check_interrupt() try: - for chunk in agent.stream( - {"messages": [HumanMessage(content=prompt)]}, config - ): - logger.debug("Agent output: %s", chunk) - check_interrupt() - print_agent_output(chunk) - - if _global_memory["plan_completed"]: - _global_memory["plan_completed"] = False - _global_memory["task_completed"] = False - _global_memory["completion_message"] = "" - break - if _global_memory["task_completed"]: - _global_memory["task_completed"] = False - _global_memory["completion_message"] = "" - break - - # Execute test command if configured + _run_agent_stream(agent, prompt, config) + fallback_handler.reset_fallback_handler() should_break, prompt, auto_test, test_attempts = ( - execute_test_command( - config, original_prompt, test_attempts, auto_test + _execute_test_command_wrapper( + original_prompt, config, test_attempts, auto_test ) ) if should_break: break if prompt != original_prompt: continue - logger.debug("Agent run completed successfully") return "Agent run completed successfully" + except (ToolExecutionError, InvalidToolCall) as e: + _handle_tool_execution_error(fallback_handler, agent, e) except (KeyboardInterrupt, AgentInterrupt): raise except ( @@ -839,35 +883,15 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: APIError, ValueError, ) as e: - if isinstance(e, ValueError): - error_str = str(e).lower() - if "code" not in error_str or "429" not in error_str: - raise # Re-raise ValueError if it's not a Lambda 429 - if attempt == max_retries - 1: - logger.error("Max retries reached, failing: %s", str(e)) - raise RuntimeError( - f"Max retries ({max_retries}) exceeded. Last error: {e}" - ) - logger.warning( - "API error (attempt %d/%d): %s", - attempt + 1, - max_retries, - str(e), - ) - delay = base_delay * (2**attempt) - print_error( - f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... 
(Attempt {attempt+1}/{max_retries})" - ) - start = time.monotonic() - while time.monotonic() - start < delay: - check_interrupt() - time.sleep(0.1) + _handle_api_error(e, attempt, max_retries, base_delay) finally: - # Reset depth tracking - _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 + _decrement_agent_depth() + _restore_interrupt_handling(original_handler) - if ( - original_handler - and threading.current_thread() is threading.main_thread() - ): - signal.signal(signal.SIGINT, original_handler) + +def _handle_tool_execution_error( + fallback_handler: FallbackHandler, + agent: CiaynAgent | CompiledGraph, + error: ToolExecutionError | InvalidToolCall, +): + fallback_handler.handle_failure("Tool execution error", error, agent) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 57c7467..4060684 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Generator, List, Optional, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from ra_aid.fallback_handler import FallbackHandler from ra_aid.exceptions import ToolExecutionError from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT @@ -84,13 +83,12 @@ class CiaynAgent: tools: List of tools available to the agent max_history_messages: Maximum number of messages to keep in chat history max_tokens: Maximum number of tokens allowed in message history (None for no limit) - config: Optional configuration dictionary for fallback settings + config: Optional configuration dictionary """ if config is None: config = {} self.config = config self.provider = config.get("provider", "openai") - self.fallback_handler = FallbackHandler(config) self.model = model self.tools = tools @@ -232,39 +230,29 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return base_prompt def _execute_tool(self, code: str) -> str: - """Execute a tool call with retry and fallback logic and return its result.""" - max_retries = 3 - retries = 0 - last_error = None - while retries < max_retries: - try: - logger.debug( - f"_execute_tool: attempt {retries+1}, original code: {code}" - ) - code = code.strip() - if validate_function_call_pattern(code): - functions_list = "\n\n".join(self.available_functions) - code = _extract_tool_call(code, functions_list) - globals_dict = {tool.func.__name__: tool.func for tool in self.tools} - logger.debug(f"_execute_tool: evaluating code: {code}") - result = eval(code, globals_dict) - logger.debug( - f"_execute_tool: tool executed successfully with result: {result}" - ) - self.fallback_handler.reset_fallback_handler() - return result - except Exception as e: - logger.debug(f"_execute_tool: exception caught: {e}") - self._handle_tool_failure(code, e) - last_error = e - retries += 1 - logger.debug(f"_execute_tool: retrying, new attempt count: {retries}") - raise ToolExecutionError( - f"Error executing code after {max_retries} attempts: {str(last_error)}" - ) + """Execute a tool call and return its result.""" + globals_dict = {tool.func.__name__: tool.func for tool in self.tools} - def _handle_tool_failure(self, code: str, error: Exception) -> None: - self.fallback_handler.handle_failure(code, error, logger, self) + try: + code = code.strip() + logger.debug(f"_execute_tool: stripped code: {code}") + + # if the eval fails, try to extract it via a model call + if validate_function_call_pattern(code): + functions_list = 
"\n\n".join(self.available_functions) + logger.debug(f"_execute_tool: code before extraction: {code}") + code = _extract_tool_call(code, functions_list) + logger.debug(f"_execute_tool: code after extraction: {code}") + + logger.debug( + f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" + ) + result = eval(code.strip(), globals_dict) + logger.debug(f"_execute_tool: result: {result}") + return result + except Exception as e: + error_msg = f"Error executing code: {str(e)}" + raise ToolExecutionError(error_msg) def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index ebc3cb6..2b248a5 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -9,6 +9,8 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env +logger = get_logger(__name__) + class FallbackHandler: """ @@ -73,6 +75,7 @@ class FallbackHandler: for item in supported: if "type" not in item: item["type"] = "prompt" + item["model"] = item["model"].lower() final_models.append(item) message = "Fallback models selected: " + ", ".join( [m["model"] for m in final_models] @@ -85,7 +88,7 @@ class FallbackHandler: console.print(Panel(Markdown(message), title="Fallback Models")) return final_models - def handle_failure(self, code: str, error: Exception, logger, agent): + def handle_failure(self, code: str, error: Exception, agent): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -173,8 +176,23 @@ class FallbackHandler: simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) + tool_to_bind = next( + ( + t + for t in agent.tools + if t.func.__name__ == failed_tool_call_name + ), + None, + ) + if tool_to_bind is None: + logger.debug( + f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}" + ) + raise Exception( + f"Tool {failed_tool_call_name} not found in agent.tools" + ) binded_model = simple_model.bind_tools( - agent.tools, tool_choice=failed_tool_call_name + [tool_to_bind], tool_choice=failed_tool_call_name ) retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT @@ -221,8 +239,23 @@ class FallbackHandler: simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) + tool_to_bind = next( + ( + t + for t in agent.tools + if t.func.__name__ == failed_tool_call_name + ), + None, + ) + if tool_to_bind is None: + logger.debug( + f"Failed to find tool: {failed_tool_call_name}. 
Available tools: {[t.func.__name__ for t in agent.tools]}" + ) + raise Exception( + f"Tool {failed_tool_call_name} not found in agent.tools" + ) binded_model = simple_model.bind_tools( - agent.tools, tool_choice=failed_tool_call_name + [tool_to_bind], tool_choice=failed_tool_call_name ) retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 5e935ed..7a02a85 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -275,3 +275,81 @@ def test_get_model_token_limit_planner(mock_memory): mock_get_info.return_value = {"max_input_tokens": 120000} token_limit = get_model_token_limit(config, "planner") assert token_limit == 120000 + +# New tests for private helper methods in agent_utils.py + +def test_setup_and_restore_interrupt_handling(): + import signal, threading + from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt + original_handler = signal.getsignal(signal.SIGINT) + handler = _setup_interrupt_handling() + # Verify the SIGINT handler is set to _request_interrupt + assert signal.getsignal(signal.SIGINT) == _request_interrupt + _restore_interrupt_handling(handler) + # Verify the SIGINT handler is restored to the original + assert signal.getsignal(signal.SIGINT) == original_handler + +def test_increment_and_decrement_agent_depth(): + from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory + _global_memory["agent_depth"] = 10 + _increment_agent_depth() + assert _global_memory["agent_depth"] == 11 + _decrement_agent_depth() + assert _global_memory["agent_depth"] == 10 + +def test_run_agent_stream(monkeypatch): + from ra_aid.agent_utils import _run_agent_stream, _global_memory + # Create a dummy agent that yields one chunk + class DummyAgent: + def stream(self, msg, cfg): + yield {"content": "chunk1"} + dummy_agent = DummyAgent() + # Set flags so that _run_agent_stream will reset them + _global_memory["plan_completed"] = True + _global_memory["task_completed"] = True + _global_memory["completion_message"] = "existing" + call_flag = {"called": False} + def fake_print_agent_output(chunk): + call_flag["called"] = True + monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output) + _run_agent_stream(dummy_agent, "dummy prompt", {}) + assert call_flag["called"] + assert _global_memory["plan_completed"] is False + assert _global_memory["task_completed"] is False + assert _global_memory["completion_message"] == "" + +def test_execute_test_command_wrapper(monkeypatch): + from ra_aid.agent_utils import _execute_test_command_wrapper + # Patch execute_test_command to return a testable tuple + def fake_execute(config, orig, tests, auto): + return (True, "new prompt", auto, tests + 1) + monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute) + result = _execute_test_command_wrapper("orig", {}, 0, False) + assert result == (True, "new prompt", False, 1) + +def test_handle_api_error_valueerror(): + from ra_aid.agent_utils import _handle_api_error + import pytest + # ValueError not containing "code" or "429" should be re-raised + with pytest.raises(ValueError): + _handle_api_error(ValueError("some error"), 0, 5, 1) + +def test_handle_api_error_max_retries(): + from ra_aid.agent_utils import _handle_api_error + import pytest + # When attempt reaches max retries, a RuntimeError should be raised + with pytest.raises(RuntimeError): + 
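+        # attempt == max_retries - 1 here (4 == 5 - 1), so _handle_api_error
+        # raises RuntimeError instead of sleeping and returning.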
_handle_api_error(Exception("error code 429"), 4, 5, 1) + +def test_handle_api_error_retry(monkeypatch): + from ra_aid.agent_utils import _handle_api_error + import time + # Patch time.monotonic and time.sleep to simulate immediate delay expiration + fake_time = [0] + def fake_monotonic(): + fake_time[0] += 0.5 + return fake_time[0] + monkeypatch.setattr(time, "monotonic", fake_monotonic) + monkeypatch.setattr(time, "sleep", lambda s: None) + # Should not raise error when attempt is lower than max retries + _handle_api_error(Exception("error code 429"), 0, 5, 1) From 67ecf72a6c15c89f4ac4199d5a8710d80c1eb10b Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 18:35:34 -0800 Subject: [PATCH 09/45] feat(fallback): implement fallback handler for tool execution errors to enhance error resilience and user experience refactor(fallback): streamline fallback model selection and invocation process for improved maintainability fix(config): reduce maximum tool failures from 3 to 2 to tighten error handling thresholds style(console): improve error message formatting and logging for better clarity and debugging chore(main): remove redundant fallback tool model handling from main function to simplify configuration management --- ra_aid/__main__.py | 18 --- ra_aid/agent_utils.py | 69 +++++++--- ra_aid/config.py | 2 +- ra_aid/console/output.py | 22 +++- ra_aid/exceptions.py | 5 +- ra_aid/fallback_handler.py | 261 +++++++++++++++++++++++-------------- ra_aid/tool_configs.py | 19 ++- 7 files changed, 260 insertions(+), 136 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 654dd60..e027a08 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -427,15 +427,6 @@ def main(): _global_memory["config"]["planner_model"] = args.planner_model or args.model _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool - _global_memory["config"]["fallback_tool_models"] = ( - [ - model.strip() - for model in args.fallback_tool_models.split(",") - if model.strip() - ] - if args.fallback_tool_models - else [] - ) # Store research config with fallback to base values _global_memory["config"]["research_provider"] = ( @@ -445,15 +436,6 @@ def main(): # Store fallback tool configuration _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool - _global_memory["config"]["fallback_tool_models"] = ( - [ - model.strip() - for model in args.fallback_tool_models.split(",") - if model.strip() - ] - if args.fallback_tool_models - else [] - ) # Run research stage print_stage_header("Research Stage") diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 678ce73..cf85a0c 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -16,7 +16,6 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( BaseMessage, HumanMessage, - InvalidToolCall, trim_messages, ) from langchain_core.tools import tool @@ -339,9 +338,6 @@ def run_research_agent( if memory is None: memory = MemorySaver() - if thread_id is None: - thread_id = str(uuid.uuid4()) - tools = get_research_tools( research_only=research_only, expert_enabled=expert_enabled, @@ -413,7 +409,8 @@ def run_research_agent( if agent is not None: logger.debug("Research agent completed successfully") - _result = run_agent_with_retry(agent, prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: # Log research completion log_work_event(f"Completed research phase for: 
{base_task_or_query}") @@ -529,7 +526,8 @@ def run_web_research_agent( console.print(Panel(Markdown(console_message), title="🔬 Researching...")) logger.debug("Web research agent completed successfully") - _result = run_agent_with_retry(agent, prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: # Log web research completion log_work_event(f"Completed web research phase for: {query}") @@ -634,7 +632,10 @@ def run_planning_agent( try: print_stage_header("Planning Stage") logger.debug("Planning agent completed successfully") - _result = run_agent_with_retry(agent, planning_prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry( + agent, planning_prompt, run_config, fallback_handler + ) if _result: # Log planning completion log_work_event(f"Completed planning phase for: {base_task}") @@ -739,7 +740,8 @@ def run_task_implementation_agent( try: logger.debug("Implementation agent completed successfully") - _result = run_agent_with_retry(agent, prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: # Log task implementation completion log_work_event(f"Completed implementation of task: {task}") @@ -805,7 +807,7 @@ def _decrement_agent_depth(): _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 -def _run_agent_stream(agent, prompt, config): +def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict): for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): logger.debug("Agent output: %s", chunk) check_interrupt() @@ -840,7 +842,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay): time.sleep(0.1) -def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: +def run_agent_with_retry( + agent, prompt: str, config: dict, fallback_handler: FallbackHandler +) -> Optional[str]: """Run an agent with retry logic for API errors.""" logger.debug("Running agent with prompt length: %d", len(prompt)) original_handler = _setup_interrupt_handling() @@ -850,7 +854,6 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt - fallback_handler = FallbackHandler(config) with InterruptibleSection(): try: @@ -872,8 +875,13 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: continue logger.debug("Agent run completed successfully") return "Agent run completed successfully" - except (ToolExecutionError, InvalidToolCall) as e: - _handle_tool_execution_error(fallback_handler, agent, e) + except ToolExecutionError as e: + fallback_response = _handle_tool_execution_error( + fallback_handler, agent, e + ) + if fallback_response: + prompt = original_prompt + "\n" + fallback_response + continue except (KeyboardInterrupt, AgentInterrupt): raise except ( @@ -892,6 +900,37 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: def _handle_tool_execution_error( fallback_handler: FallbackHandler, agent: CiaynAgent | CompiledGraph, - error: ToolExecutionError | InvalidToolCall, + error: ToolExecutionError, ): - fallback_handler.handle_failure("Tool execution error", error, agent) + logger.debug("Entering _handle_tool_execution_error with error: %s", error) + if 
error.tool_name: + failed_tool_call_name = error.tool_name + logger.debug( + "Extracted failed_tool_call_name from error.tool_name: %s", + failed_tool_call_name, + ) + else: + import re + + msg = str(error) + logger.debug("Error message: %s", msg) + match = re.search(r"name=['\"](\w+)['\"]", msg) + if match: + failed_tool_call_name = match.group(1) + logger.debug( + "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name + ) + else: + failed_tool_call_name = "Tool execution error" + logger.debug( + "Defaulting failed_tool_call_name to: %s", failed_tool_call_name + ) + logger.debug( + "Calling fallback_handler.handle_failure with failed_tool_call_name: %s", + failed_tool_call_name, + ) + fallback_response = fallback_handler.handle_failure( + failed_tool_call_name, error, agent + ) + logger.debug("Fallback response received: %s", fallback_response) + return fallback_response diff --git a/ra_aid/config.py b/ra_aid/config.py index e85cb12..54d7995 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -2,7 +2,7 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 -DEFAULT_MAX_TOOL_FAILURES = 3 +DEFAULT_MAX_TOOL_FAILURES = 2 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 RETRY_FALLBACK_DELAY = 2 diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 8b64142..aad96e6 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,9 +1,11 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from langchain_core.messages import AIMessage from rich.markdown import Markdown from rich.panel import Panel +from ra_aid.exceptions import ToolExecutionError + # Import shared console instance from .formatting import console @@ -33,10 +35,26 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: elif "tools" in chunk and "messages" in chunk["tools"]: for msg in chunk["tools"]["messages"]: if msg.status == "error" and msg.content: + err_msg = msg.content.strip() console.print( Panel( - Markdown(msg.content.strip()), + Markdown(err_msg), title="❌ Tool Error", border_style="red bold", ) ) + tool_name = getattr(msg, "name", None) + raise ToolExecutionError(err_msg, tool_name=tool_name) + + +def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None: + """ + Print a message using a Panel with Markdown formatting. + + Args: + message (str): The message content to display. + title (Optional[str]): An optional title for the panel. + border_style (str): Border style for the panel. + """ + + console.print(Panel(Markdown(message), title=title, border_style=border_style)) diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index 696b47e..d8bc532 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -17,5 +17,6 @@ class ToolExecutionError(Exception): This exception is used to distinguish tool execution failures from other types of errors in the agent system. 
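    When available, the tool_name attribute set in __init__ below records which tool raised the failure.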
""" - - pass + def __init__(self, message, tool_name=None): + super().__init__(message) + self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 2b248a5..8dd459c 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,13 +1,24 @@ +from typing import Dict +from langchain_core.tools import BaseTool +from langgraph.graph.graph import CompiledGraph +from langgraph.graph.message import BaseMessage + +from ra_aid.console.output import cpm +import json + +from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, ) +from ra_aid.logging_config import get_logger from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console -from rich.markdown import Markdown -from rich.panel import Panel -from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env +from ra_aid.llm import initialize_llm, validate_provider_env + +# from langgraph.graph.message import BaseMessage, BaseMessageChunk +# from langgraph.prebuilt import ToolNode logger = get_logger(__name__) @@ -22,18 +33,21 @@ class FallbackHandler: counters when a tool call succeeds. """ - def __init__(self, config): + def __init__(self, config, tools): """ - Initialize the FallbackHandler with the given configuration. + Initialize the FallbackHandler with the given configuration and tools. Args: config (dict): Configuration dictionary that may include fallback settings. + tools (list): List of available tools. """ self.config = config + self.tools: list[BaseTool] = tools self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_tool_models = self._load_fallback_tool_models(config) self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks = set() + self.console = Console() def _load_fallback_tool_models(self, config): """ @@ -49,46 +63,37 @@ class FallbackHandler: Returns: list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model. 
""" - fallback_tool_models_config = config.get("fallback_tool_models") - if fallback_tool_models_config: - # Assume comma-separated model names; wrap each in a dict with default type "prompt" - models = [] - for m in [ - x.strip() for x in fallback_tool_models_config.split(",") if x.strip() - ]: - models.append({"model": m, "type": "prompt"}) - return models - else: - console = Console() - supported = [] - skipped = [] - for item in supported_top_tool_models: - provider = item.get("provider") - model_name = item.get("model") - if validate_provider_env(provider): - supported.append(item) - if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: - break - else: - skipped.append(model_name) - final_models = [] - for item in supported: - if "type" not in item: - item["type"] = "prompt" - item["model"] = item["model"].lower() - final_models.append(item) - message = "Fallback models selected: " + ", ".join( - [m["model"] for m in final_models] + supported = [] + skipped = [] + for item in supported_top_tool_models: + provider = item.get("provider") + model_name = item.get("model") + if validate_provider_env(provider): + supported.append(item) + if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: + break + else: + skipped.append(model_name) + final_models = [] + for item in supported: + if "type" not in item: + item["type"] = "prompt" + item["model"] = item["model"].lower() + final_models.append(item) + message = "Fallback models selected: " + ", ".join( + [m["model"] for m in final_models] + ) + if skipped: + message += ( + "\nSkipped top tool calling models due to missing provider ENV API keys: " + + ", ".join(skipped) ) - if skipped: - message += ( - "\nSkipped top tool calling models due to missing provider ENV API keys: " - + ", ".join(skipped) - ) - console.print(Panel(Markdown(message), title="Fallback Models")) - return final_models + cpm(message, title="Fallback Models") + return final_models - def handle_failure(self, code: str, error: Exception, agent): + def handle_failure( + self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph + ): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -114,7 +119,7 @@ class FallbackHandler: logger.debug( "_handle_tool_failure: threshold reached, invoking fallback mechanism." ) - self.attempt_fallback(code, logger, agent) + return self.attempt_fallback(code, logger, agent) def attempt_fallback(self, code: str, logger, agent): """ @@ -127,17 +132,13 @@ class FallbackHandler: """ logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") fallback_model = self.fallback_tool_models[0] - failed_tool_call_name = code.split("(")[0].strip() + failed_tool_call_name = code logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) - Console().print( - Panel( - Markdown( - f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}." 
- ), - title="Fallback Notification", - ) + cpm( + f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.", + title="Fallback Notification", ) if fallback_model.get("type", "prompt").lower() == "fc": self.attempt_fallback_function(code, logger, agent) @@ -151,6 +152,30 @@ class FallbackHandler: self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() + def _find_tool_to_bind(self, agent, failed_tool_call_name): + logger.debug(f"failed_tool_call_name={failed_tool_call_name}") + tool_to_bind = None + if hasattr(agent, "tools"): + tool_to_bind = next( + (t for t in agent.tools if t.func.__name__ == failed_tool_call_name), + None, + ) + if tool_to_bind is None: + from ra_aid.tool_configs import get_all_tools + + all_tools = get_all_tools() + tool_to_bind = next( + (t for t in all_tools if t.func.__name__ == failed_tool_call_name), + None, + ) + if tool_to_bind is None: + available = [t.func.__name__ for t in get_all_tools()] + logger.debug( + f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}" + ) + raise Exception(f"Tool {failed_tool_call_name} not found in all tools.") + return tool_to_bind + def attempt_fallback_prompt(self, code: str, logger, agent): """ Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. @@ -169,43 +194,41 @@ class FallbackHandler: Exception: If all prompt-based fallback models fail. """ logger.debug("Attempting prompt-based fallback using fallback models") - failed_tool_call_name = code.split("(")[0].strip() + failed_tool_call_name = code for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) - tool_to_bind = next( - ( - t - for t in agent.tools - if t.func.__name__ == failed_tool_call_name - ), - None, - ) - if tool_to_bind is None: - logger.debug( - f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}" - ) - raise Exception( - f"Tool {failed_tool_call_name} not found in agent.tools" - ) + tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) binded_model = simple_model.bind_tools( [tool_to_bind], tool_choice=failed_tool_call_name ) - retry_model = binded_model.with_retry( - stop_after_attempt=RETRY_FALLBACK_COUNT - ) - response = retry_model.invoke(code) + # retry_model = binded_model.with_retry( + # stop_after_attempt=RETRY_FALLBACK_COUNT + # ) + response = binded_model.invoke(code) + cpm(f"response={response}") + self.tool_failure_used_fallbacks.add(fallback_model["model"]) - agent.model = retry_model - self.reset_fallback_handler() - logger.debug( - "Prompt-based fallback executed successfully with model: " - + fallback_model["model"] - ) - return response + + tool_call = self.base_message_to_tool_call_dict(response) + if tool_call: + result = self.invoke_prompt_tool_call(tool_call) + cpm(f"result={result}") + logger.debug( + "Prompt-based fallback executed successfully with model: " + + fallback_model["model"] + ) + self.reset_fallback_handler() + return result + else: + cpm( + response.content if hasattr(response, "content") else response, + title="Fallback Model Response: " + fallback_model["model"], + ) + return response except Exception as e: if isinstance(e, KeyboardInterrupt): raise @@ -232,28 +255,14 @@ class FallbackHandler: Exception: If all function-calling fallback models fail. 
""" logger.debug("Attempting function-calling fallback using fallback models") - failed_tool_call_name = code.split("(")[0].strip() + failed_tool_call_name = code for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) - tool_to_bind = next( - ( - t - for t in agent.tools - if t.func.__name__ == failed_tool_call_name - ), - None, - ) - if tool_to_bind is None: - logger.debug( - f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}" - ) - raise Exception( - f"Tool {failed_tool_call_name} not found in agent.tools" - ) + tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) binded_model = simple_model.bind_tools( [tool_to_bind], tool_choice=failed_tool_call_name ) @@ -261,13 +270,18 @@ class FallbackHandler: stop_after_attempt=RETRY_FALLBACK_COUNT ) response = retry_model.invoke(code) + cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) - agent.model = retry_model self.reset_fallback_handler() logger.debug( "Function-calling fallback executed successfully with model: " + fallback_model["model"] ) + + cpm( + response.content if hasattr(response, "content") else response, + title="Fallback Model Response: " + fallback_model["model"], + ) return response except Exception as e: if isinstance(e, KeyboardInterrupt): @@ -276,3 +290,58 @@ class FallbackHandler: f"Function-calling fallback with model {fallback_model['model']} failed: {e}" ) raise Exception("All function-calling fallback models failed") + + def invoke_prompt_tool_call(self, tool_call_request: dict): + """ + Invoke a tool call from a prompt-based fallback response. + + Args: + tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'. + + Returns: + The result of invoking the tool. + """ + tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} + name = tool_call_request["name"] + arguments = tool_call_request["arguments"] + # return tool_name_to_tool[name].invoke(arguments) + # tool_call_dict = {"arguments": arguments} + return tool_name_to_tool[name].invoke(arguments) + + def base_message_to_tool_call_dict(self, response: BaseMessage): + """ + Extracts a tool call dictionary from a fallback response. + + Args: + response: The response object containing tool call data. + + Returns: + A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found, + otherwise None. 
+ """ + tool_calls = None + if hasattr(response, "additional_kwargs") and response.additional_kwargs.get( + "tool_calls" + ): + tool_calls = response.additional_kwargs.get("tool_calls") + elif hasattr(response, "tool_calls"): + tool_calls = response.tool_calls + elif isinstance(response, dict) and response.get("additional_kwargs", {}).get( + "tool_calls" + ): + tool_calls = response.get("additional_kwargs").get("tool_calls") + if tool_calls: + if len(tool_calls) > 1: + logger.warning("Multiple tool calls detected, using the first one") + tool_call = tool_calls[0] + return { + "id": tool_call["id"], + "type": tool_call["type"], + "name": tool_call["function"]["name"], + "arguments": ( + json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"] + ), + } + return None diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index e4042f1..8fce691 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -28,7 +28,7 @@ from ra_aid.tools.write_file import write_file_tool # Read-only tools that don't modify system state def get_read_only_tools( human_interaction: bool = False, web_research_enabled: bool = False -) -> list: +): """Get the list of read-only tools, optionally including human interaction tools. Args: @@ -61,6 +61,21 @@ def get_read_only_tools( return tools +def get_all_tools_simple(): + """Return a list containing all available tools using existing group methods.""" + return get_all_tools() + +def get_all_tools(): + """Return a list containing all available tools from different groups.""" + all_tools = [] + all_tools.extend(get_read_only_tools()) + all_tools.extend(MODIFICATION_TOOLS) + all_tools.extend(EXPERT_TOOLS) + all_tools.extend(RESEARCH_TOOLS) + all_tools.extend(get_web_research_tools()) + all_tools.extend(get_chat_tools()) + return all_tools + # Define constant tool groups READ_ONLY_TOOLS = get_read_only_tools() @@ -81,7 +96,7 @@ def get_research_tools( expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False, -) -> list: +): """Get the list of research tools based on mode and whether expert is enabled. 
Args: From a7322eaef23fd0cc121aed05d7565e0871af7964 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 18:38:52 -0800 Subject: [PATCH 10/45] refactor(fallback_handler.py): clean up code by removing unused imports and comments to enhance readability refactor(fallback_handler.py): extract tool call extraction logic into a separate method for better organization and maintainability refactor(fallback_handler.py): introduce _parse_tool_arguments method to handle argument parsing, improving code clarity and reusability --- ra_aid/fallback_handler.py | 59 ++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 8dd459c..31df7d5 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,4 +1,3 @@ -from typing import Dict from langchain_core.tools import BaseTool from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage @@ -17,9 +16,6 @@ from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console from ra_aid.llm import initialize_llm, validate_provider_env -# from langgraph.graph.message import BaseMessage, BaseMessageChunk -# from langgraph.prebuilt import ToolNode - logger = get_logger(__name__) @@ -304,13 +300,11 @@ class FallbackHandler: tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} name = tool_call_request["name"] arguments = tool_call_request["arguments"] - # return tool_name_to_tool[name].invoke(arguments) - # tool_call_dict = {"arguments": arguments} return tool_name_to_tool[name].invoke(arguments) def base_message_to_tool_call_dict(self, response: BaseMessage): """ - Extracts a tool call dictionary from a fallback response. + Extracts a tool call dictionary from a BaseMessage. Args: response: The response object containing tool call data. @@ -319,6 +313,41 @@ class FallbackHandler: A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found, otherwise None. """ + tool_calls = self.get_tool_calls(response) + if tool_calls: + if len(tool_calls) > 1: + logger.warning("Multiple tool calls detected, using the first one") + tool_call = tool_calls[0] + return { + "id": tool_call["id"], + "type": tool_call["type"], + "name": tool_call["function"]["name"], + "arguments": self._parse_tool_arguments( + tool_call["function"]["arguments"] + ), + } + return None + + def _parse_tool_arguments(self, tool_arguments): + """ + Helper method to parse tool call arguments. + If tool_arguments is a string, it returns the JSON-parsed dictionary. + Otherwise, returns tool_arguments as is. + """ + if isinstance(tool_arguments, str): + return json.loads(tool_arguments) + return tool_arguments + + def get_tool_calls(self, response: BaseMessage): + """ + Extracts tool calls list from a fallback response. + + Args: + response: The response object containing tool call data. + + Returns: + The tool calls list if present, otherwise None. 
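+            Example (OpenAI-style payload, illustrative):
+            [{"id": "call_0", "type": "function", "function": {"name": "example_tool", "arguments": "{...}"}}]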
+ """ tool_calls = None if hasattr(response, "additional_kwargs") and response.additional_kwargs.get( "tool_calls" @@ -330,18 +359,4 @@ class FallbackHandler: "tool_calls" ): tool_calls = response.get("additional_kwargs").get("tool_calls") - if tool_calls: - if len(tool_calls) > 1: - logger.warning("Multiple tool calls detected, using the first one") - tool_call = tool_calls[0] - return { - "id": tool_call["id"], - "type": tool_call["type"], - "name": tool_call["function"]["name"], - "arguments": ( - json.loads(tool_call["function"]["arguments"]) - if isinstance(tool_call["function"]["arguments"], str) - else tool_call["function"]["arguments"] - ), - } - return None + return tool_calls From af9f95ceb1905b7a30a4f6c4fcccb0a29e39e1b0 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Feb 2025 13:07:12 -0800 Subject: [PATCH 11/45] refactor(agent_utils.py): remove the _handle_tool_execution_error function and simplify error handling in run_agent_with_retry feat(fallback_handler.py): enhance handle_failure method to extract tool name from ToolExecutionError and improve fallback logic fix(exceptions.py): update ToolExecutionError to include base_message for better error context feat(output.py): add base_message to ToolExecutionError for improved debugging chore(tool_configs.py): update get_all_tools function to specify return type style(logging_config.py): reorder imports for consistency test(tests): add tests for new error handling and fallback logic in agent_utils and fallback_handler --- ra_aid/agent_utils.py | 45 +--- ra_aid/console/output.py | 4 +- ra_aid/exceptions.py | 13 +- ra_aid/fallback_handler.py | 306 ++++++++++++++------------ ra_aid/logging_config.py | 3 +- ra_aid/tool_configs.py | 12 +- tests/ra_aid/test_agent_utils.py | 51 ++++- tests/ra_aid/test_fallback_handler.py | 26 ++- 8 files changed, 252 insertions(+), 208 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index cf85a0c..3aea6a8 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -9,7 +9,6 @@ import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Sequence -from langgraph.graph.graph import CompiledGraph import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from langchain_core.language_models import BaseChatModel @@ -20,6 +19,7 @@ from langchain_core.messages import ( ) from langchain_core.tools import tool from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph.graph import CompiledGraph from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import get_model_info @@ -876,9 +876,7 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - fallback_response = _handle_tool_execution_error( - fallback_handler, agent, e - ) + fallback_response = fallback_handler.handle_failure(e, agent) if fallback_response: prompt = original_prompt + "\n" + fallback_response continue @@ -895,42 +893,3 @@ def run_agent_with_retry( finally: _decrement_agent_depth() _restore_interrupt_handling(original_handler) - - -def _handle_tool_execution_error( - fallback_handler: FallbackHandler, - agent: CiaynAgent | CompiledGraph, - error: ToolExecutionError, -): - logger.debug("Entering _handle_tool_execution_error with error: %s", error) - if error.tool_name: - failed_tool_call_name = error.tool_name - logger.debug( - "Extracted 
failed_tool_call_name from error.tool_name: %s", - failed_tool_call_name, - ) - else: - import re - - msg = str(error) - logger.debug("Error message: %s", msg) - match = re.search(r"name=['\"](\w+)['\"]", msg) - if match: - failed_tool_call_name = match.group(1) - logger.debug( - "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name - ) - else: - failed_tool_call_name = "Tool execution error" - logger.debug( - "Defaulting failed_tool_call_name to: %s", failed_tool_call_name - ) - logger.debug( - "Calling fallback_handler.handle_failure with failed_tool_call_name: %s", - failed_tool_call_name, - ) - fallback_response = fallback_handler.handle_failure( - failed_tool_call_name, error, agent - ) - logger.debug("Fallback response received: %s", fallback_response) - return fallback_response diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index aad96e6..7e62e1b 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -44,7 +44,9 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: ) ) tool_name = getattr(msg, "name", None) - raise ToolExecutionError(err_msg, tool_name=tool_name) + cpm(f"type(msg): {type(msg)}") + cpm(f"msg: {msg}") + raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg) def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None: diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index d8bc532..34710d9 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -1,5 +1,9 @@ """Custom exceptions for RA.Aid.""" +from typing import Optional + +from langchain_core.messages import BaseMessage + class AgentInterrupt(Exception): """Exception raised when an agent's execution is interrupted. @@ -17,6 +21,13 @@ class ToolExecutionError(Exception): This exception is used to distinguish tool execution failures from other types of errors in the agent system. 
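    The optional base_message preserves the original failing BaseMessage so fallback handling can reconstruct context, and tool_name identifies the failing tool.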
""" - def __init__(self, message, tool_name=None): + + def __init__( + self, + message: str, + base_message: Optional[BaseMessage] = None, + tool_name: Optional[str] = None, + ): super().__init__(message) + self.base_message = base_message self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 31df7d5..5e80826 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,20 +1,22 @@ +import json +import re + from langchain_core.tools import BaseTool from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage -from ra_aid.console.output import cpm -import json - from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, ) -from ra_aid.logging_config import get_logger -from ra_aid.tool_leaderboard import supported_top_tool_models -from rich.console import Console +from ra_aid.console.output import cpm +from ra_aid.exceptions import ToolExecutionError from ra_aid.llm import initialize_llm, validate_provider_env +from ra_aid.logging_config import get_logger +from ra_aid.tool_configs import get_all_tools +from ra_aid.tool_leaderboard import supported_top_tool_models logger = get_logger(__name__) @@ -41,9 +43,12 @@ class FallbackHandler: self.tools: list[BaseTool] = tools self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_tool_models = self._load_fallback_tool_models(config) + self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 + self.failed_messages = set() self.tool_failure_used_fallbacks = set() - self.console = Console() + self.current_failing_tool_name = "" + self.current_tool_to_bind = None def _load_fallback_tool_models(self, config): """ @@ -87,66 +92,104 @@ class FallbackHandler: cpm(message, title="Fallback Models") return final_models - def handle_failure( - self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph - ): + def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. Args: - code (str): The code that failed to execute. - error (Exception): The exception raised during execution. - logger: Logger instance for logging. + error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded. agent: The agent instance on which fallback may be executed. """ - logger.debug( - f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" - ) - self.tool_failure_consecutive_failures += 1 - max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) - logger.debug( - f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}" - ) + if not self.fallback_enabled: + return None + + failed_tool_call_name = self.extract_failed_tool_name(error) if ( - self.fallback_enabled - and self.tool_failure_consecutive_failures >= max_failures + self.current_failing_tool_name + and failed_tool_call_name != self.current_failing_tool_name + ): + logger.debug( + "New failing tool name identified. Resetting consecutive tool failures." 
+ ) + self.reset_fallback_handler() + + logger.debug( + f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}" + ) + + self.current_failing_tool_name = failed_tool_call_name + self.current_tool_to_bind = self._find_tool_to_bind( + agent, failed_tool_call_name + ) + + if hasattr(error, "base_message") and error.base_message: + self.failed_messages.add(str(error.base_message)) + + self.tool_failure_consecutive_failures += 1 + logger.debug( + f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}" + ) + + if ( + self.tool_failure_consecutive_failures >= self.max_failures and self.fallback_tool_models ): logger.debug( "_handle_tool_failure: threshold reached, invoking fallback mechanism." ) - return self.attempt_fallback(code, logger, agent) + return self.attempt_fallback() - def attempt_fallback(self, code: str, logger, agent): + def attempt_fallback(self): """ - Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method. + Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method. - Args: - code (str): The tool code that triggered the fallback. - logger: Logger instance for logging messages. - agent: The agent for which fallback is being executed. + Returns: + The response from a fallback model if any, otherwise None. """ - logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") - fallback_model = self.fallback_tool_models[0] - failed_tool_call_name = code logger.error( - f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" + f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" ) cpm( - f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.", + f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.", title="Fallback Notification", ) - if fallback_model.get("type", "prompt").lower() == "fc": - self.attempt_fallback_function(code, logger, agent) - else: - self.attempt_fallback_prompt(code, logger, agent) + for fallback_model in self.fallback_tool_models: + if fallback_model.get("type", "prompt").lower() == "fc": + response = self.attempt_fallback_function(fallback_model) + else: + response = self.attempt_fallback_prompt(fallback_model) + if response: + return response + cpm("All fallback models have failed", title="Fallback Failed") + return None def reset_fallback_handler(self): """ Reset the fallback handler's internal failure counters and clear the record of used fallback models. 
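        Recorded failed messages are cleared as well, and the fallback model list is reloaded from config.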
""" self.tool_failure_consecutive_failures = 0 + self.failed_messages.clear() self.tool_failure_used_fallbacks.clear() + self.fallback_tool_models = self._load_fallback_tool_models(self.config) + + def extract_failed_tool_name(self, error: ToolExecutionError): + if error.tool_name: + failed_tool_call_name = error.tool_name + else: + msg = str(error) + logger.debug("Error message: %s", msg) + match = re.search(r"name=['\"](\w+)['\"]", msg) + if match: + failed_tool_call_name = str(match.group(1)) + logger.debug( + "Extracted failed_tool_call_name using regex: %s", + failed_tool_call_name, + ) + else: + failed_tool_call_name = "Tool execution error" + raise Exception("Fallback failed: Could not extract failed tool name.") + + return failed_tool_call_name def _find_tool_to_bind(self, agent, failed_tool_call_name): logger.debug(f"failed_tool_call_name={failed_tool_call_name}") @@ -157,135 +200,108 @@ class FallbackHandler: None, ) if tool_to_bind is None: - from ra_aid.tool_configs import get_all_tools - all_tools = get_all_tools() tool_to_bind = next( (t for t in all_tools if t.func.__name__ == failed_tool_call_name), None, ) if tool_to_bind is None: - available = [t.func.__name__ for t in get_all_tools()] - logger.debug( - f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}" + # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name + raise Exception( + f"Fallback failed: {failed_tool_call_name} not found in all tools." ) - raise Exception(f"Tool {failed_tool_call_name} not found in all tools.") return tool_to_bind - def attempt_fallback_prompt(self, code: str, logger, agent): + def attempt_fallback_prompt(self, fallback_model): """ - Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. - - This method tries each fallback model (with retry logic configured) until one successfully executes the code. + Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model. Args: - code (str): The tool code to invoke via fallback. - logger: Logger instance for logging messages. - agent: The agent instance to update with the new model upon success. + fallback_model (dict): The fallback model to use. Returns: - The response from the fallback model invocation. - - Raises: - Exception: If all prompt-based fallback models fail. + The response from the fallback model invocation, or None if failed. 
""" - logger.debug("Attempting prompt-based fallback using fallback models") - failed_tool_call_name = code - for fallback_model in self.fallback_tool_models: - try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") - simple_model = initialize_llm( - fallback_model["provider"], fallback_model["model"] - ) - tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) - binded_model = simple_model.bind_tools( - [tool_to_bind], tool_choice=failed_tool_call_name - ) - # retry_model = binded_model.with_retry( - # stop_after_attempt=RETRY_FALLBACK_COUNT - # ) - response = binded_model.invoke(code) - cpm(f"response={response}") - - self.tool_failure_used_fallbacks.add(fallback_model["model"]) - - tool_call = self.base_message_to_tool_call_dict(response) - if tool_call: - result = self.invoke_prompt_tool_call(tool_call) - cpm(f"result={result}") - logger.debug( - "Prompt-based fallback executed successfully with model: " - + fallback_model["model"] - ) - self.reset_fallback_handler() - return result - else: - cpm( - response.content if hasattr(response, "content") else response, - title="Fallback Model Response: " + fallback_model["model"], - ) - return response - except Exception as e: - if isinstance(e, KeyboardInterrupt): - raise - logger.error( - f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" - ) - raise Exception("All prompt-based fallback models failed") - - def attempt_fallback_function(self, code: str, logger, agent): - """ - Attempt a function-calling fallback by iterating over fallback models and invoking the provided code. - - This method tries each fallback model (with retry logic configured) until one successfully executes the code. - - Args: - code (str): The tool code to invoke via fallback. - logger: Logger instance for logging messages. - agent: The agent instance to update with the new model upon success. - - Returns: - The response from the fallback model invocation. - - Raises: - Exception: If all function-calling fallback models fail. 
- """ - logger.debug("Attempting function-calling fallback using fallback models") - failed_tool_call_name = code - for fallback_model in self.fallback_tool_models: - try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") - simple_model = initialize_llm( - fallback_model["provider"], fallback_model["model"] - ) - tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) - binded_model = simple_model.bind_tools( - [tool_to_bind], tool_choice=failed_tool_call_name - ) - retry_model = binded_model.with_retry( - stop_after_attempt=RETRY_FALLBACK_COUNT - ) - response = retry_model.invoke(code) - cpm(f"response={response}") - self.tool_failure_used_fallbacks.add(fallback_model["model"]) - self.reset_fallback_handler() + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + [self.current_tool_to_bind], + tool_choice=self.current_failing_tool_name, + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(self.current_failing_tool_name) + cpm(f"response={response}") + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + tool_call = self.base_message_to_tool_call_dict(response) + if tool_call: + result = self.invoke_prompt_tool_call(tool_call) + cpm(f"result={result}") logger.debug( - "Function-calling fallback executed successfully with model: " + "Prompt-based fallback executed successfully with model: " + fallback_model["model"] ) - + self.reset_fallback_handler() + return result + else: cpm( response.content if hasattr(response, "content") else response, title="Fallback Model Response: " + fallback_model["model"], ) return response - except Exception as e: - if isinstance(e, KeyboardInterrupt): - raise - logger.error( - f"Function-calling fallback with model {fallback_model['model']} failed: {e}" - ) - raise Exception("All function-calling fallback models failed") + except Exception as e: + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" + ) + return None + + def attempt_fallback_function(self, fallback_model): + """ + Attempt a function-calling fallback by invoking the current failing tool with the given fallback model. + + Args: + fallback_model (dict): The fallback model to use. + + Returns: + The response from the fallback model invocation, or None if failed. 
+ """ + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + [self.current_tool_to_bind], + tool_choice=self.current_failing_tool_name, + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(self.current_failing_tool_name) + cpm(f"response={response}") + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + logger.debug( + "Function-calling fallback executed successfully with model: " + + fallback_model["model"] + ) + cpm( + response.content if hasattr(response, "content") else response, + title="Fallback Model Response: " + fallback_model["model"], + ) + return response + except Exception as e: + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Function-calling fallback with model {fallback_model['model']} failed: {e}" + ) + return None def invoke_prompt_tool_call(self, tool_call_request: dict): """ diff --git a/ra_aid/logging_config.py b/ra_aid/logging_config.py index fb3bf63..ba4609f 100644 --- a/ra_aid/logging_config.py +++ b/ra_aid/logging_config.py @@ -1,9 +1,10 @@ import logging import sys from typing import Optional + from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel class PrettyHandler(logging.Handler): diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index 8fce691..b982613 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -1,3 +1,5 @@ +from langchain_core.tools import BaseTool + from ra_aid.tools import ( ask_expert, ask_human, @@ -61,11 +63,13 @@ def get_read_only_tools( return tools + def get_all_tools_simple(): """Return a list containing all available tools using existing group methods.""" return get_all_tools() -def get_all_tools(): + +def get_all_tools() -> list[BaseTool]: """Return a list containing all available tools from different groups.""" all_tools = [] all_tools.extend(get_read_only_tools()) @@ -176,7 +180,7 @@ def get_implementation_tools( return tools -def get_web_research_tools(expert_enabled: bool = True) -> list: +def get_web_research_tools(expert_enabled: bool = True): """Get the list of tools available for web research. Args: @@ -196,9 +200,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list: return tools -def get_chat_tools( - expert_enabled: bool = True, web_research_enabled: bool = False -) -> list: +def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False): """Get the list of tools available in chat mode. 
Chat mode includes research and implementation capabilities but excludes diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 7a02a85..2e54f5a 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -276,11 +276,19 @@ def test_get_model_token_limit_planner(mock_memory): token_limit = get_model_token_limit(config, "planner") assert token_limit == 120000 + # New tests for private helper methods in agent_utils.py + def test_setup_and_restore_interrupt_handling(): - import signal, threading - from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt + import signal + + from ra_aid.agent_utils import ( + _request_interrupt, + _restore_interrupt_handling, + _setup_interrupt_handling, + ) + original_handler = signal.getsignal(signal.SIGINT) handler = _setup_interrupt_handling() # Verify the SIGINT handler is set to _request_interrupt @@ -289,66 +297,93 @@ def test_setup_and_restore_interrupt_handling(): # Verify the SIGINT handler is restored to the original assert signal.getsignal(signal.SIGINT) == original_handler + def test_increment_and_decrement_agent_depth(): - from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory + from ra_aid.agent_utils import ( + _decrement_agent_depth, + _global_memory, + _increment_agent_depth, + ) + _global_memory["agent_depth"] = 10 _increment_agent_depth() assert _global_memory["agent_depth"] == 11 _decrement_agent_depth() assert _global_memory["agent_depth"] == 10 + def test_run_agent_stream(monkeypatch): - from ra_aid.agent_utils import _run_agent_stream, _global_memory + from ra_aid.agent_utils import _global_memory, _run_agent_stream + # Create a dummy agent that yields one chunk class DummyAgent: def stream(self, msg, cfg): yield {"content": "chunk1"} + dummy_agent = DummyAgent() # Set flags so that _run_agent_stream will reset them _global_memory["plan_completed"] = True _global_memory["task_completed"] = True _global_memory["completion_message"] = "existing" call_flag = {"called": False} + def fake_print_agent_output(chunk): call_flag["called"] = True - monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output) + + monkeypatch.setattr( + "ra_aid.agent_utils.print_agent_output", fake_print_agent_output + ) _run_agent_stream(dummy_agent, "dummy prompt", {}) assert call_flag["called"] assert _global_memory["plan_completed"] is False assert _global_memory["task_completed"] is False assert _global_memory["completion_message"] == "" + def test_execute_test_command_wrapper(monkeypatch): from ra_aid.agent_utils import _execute_test_command_wrapper + # Patch execute_test_command to return a testable tuple def fake_execute(config, orig, tests, auto): return (True, "new prompt", auto, tests + 1) + monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute) result = _execute_test_command_wrapper("orig", {}, 0, False) assert result == (True, "new prompt", False, 1) + def test_handle_api_error_valueerror(): - from ra_aid.agent_utils import _handle_api_error import pytest + + from ra_aid.agent_utils import _handle_api_error + # ValueError not containing "code" or "429" should be re-raised with pytest.raises(ValueError): _handle_api_error(ValueError("some error"), 0, 5, 1) + def test_handle_api_error_max_retries(): - from ra_aid.agent_utils import _handle_api_error import pytest + + from ra_aid.agent_utils import _handle_api_error + # When attempt reaches max retries, a 
RuntimeError should be raised with pytest.raises(RuntimeError): _handle_api_error(Exception("error code 429"), 4, 5, 1) + def test_handle_api_error_retry(monkeypatch): - from ra_aid.agent_utils import _handle_api_error import time + + from ra_aid.agent_utils import _handle_api_error + # Patch time.monotonic and time.sleep to simulate immediate delay expiration fake_time = [0] + def fake_monotonic(): fake_time[0] += 0.5 return fake_time[0] + monkeypatch.setattr(time, "monotonic", fake_monotonic) monkeypatch.setattr(time, "sleep", lambda s: None) # Should not raise error when attempt is lower than max retries diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 6f0c285..c400a19 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -1,28 +1,41 @@ import unittest + from ra_aid.fallback_handler import FallbackHandler + class DummyLogger: def debug(self, msg): pass + def error(self, msg): pass + class DummyAgent: provider = "openai" tools = [] model = None + class TestFallbackHandler(unittest.TestCase): def setUp(self): - self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"} + self.config = { + "max_tool_failures": 2, + "fallback_tool_models": "dummy-fallback-model", + } self.fallback_handler = FallbackHandler(self.config) self.logger = DummyLogger() self.agent = DummyAgent() def test_handle_failure_increments_counter(self): initial_failures = self.fallback_handler.tool_failure_consecutive_failures - self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent) - self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1) + self.fallback_handler.handle_failure( + "dummy_call()", Exception("Test error"), self.logger, self.agent + ) + self.assertEqual( + self.fallback_handler.tool_failure_consecutive_failures, + initial_failures + 1, + ) def test_attempt_fallback_resets_counter(self): # Monkey-patch dummy functions for fallback components @@ -30,6 +43,7 @@ class TestFallbackHandler(unittest.TestCase): class DummyModel: def bind_tools(self, tools, tool_choice): pass + return DummyModel() def dummy_merge_chat_history(): @@ -39,6 +53,7 @@ class TestFallbackHandler(unittest.TestCase): return True import ra_aid.llm as llm + original_initialize = llm.initialize_llm original_merge = llm.merge_chat_history original_validate = llm.validate_provider_env @@ -47,12 +62,15 @@ class TestFallbackHandler(unittest.TestCase): llm.validate_provider_env = dummy_validate_provider_env self.fallback_handler.tool_failure_consecutive_failures = 2 - self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent) + self.fallback_handler.attempt_fallback( + "dummy_tool_call()", self.logger, self.agent + ) self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) llm.initialize_llm = original_initialize llm.merge_chat_history = original_merge llm.validate_provider_env = original_validate + if __name__ == "__main__": unittest.main() From 803acc616686f650c45a3ae9a1e2e481d8683f1b Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Feb 2025 13:39:25 -0800 Subject: [PATCH 12/45] feat(agent_utils.py): convert fallback response to string for prompt concatenation to ensure proper formatting refactor(fallback_handler.py): change failed_messages from set to list for ordered message handling refactor(fallback_handler.py): update handle_failure method to accept ToolExecutionError type for better 
type safety refactor(fallback_handler.py): implement _reset_on_new_failure method to encapsulate failure reset logic feat(fallback_handler.py): add construct_prompt_msg_list method to create structured message list for fallback tool calls --- ra_aid/agent_utils.py | 2 +- ra_aid/fallback_handler.py | 61 ++++++++++++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 3aea6a8..ad1adf5 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -878,7 +878,7 @@ def run_agent_with_retry( except ToolExecutionError as e: fallback_response = fallback_handler.handle_failure(e, agent) if fallback_response: - prompt = original_prompt + "\n" + fallback_response + prompt = original_prompt + "\n" + str(fallback_response) continue except (KeyboardInterrupt, AgentInterrupt): raise diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 5e80826..4ae2016 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -4,6 +4,7 @@ import re from langchain_core.tools import BaseTool from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage +from langchain_core.messages import SystemMessage, HumanMessage from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( @@ -45,7 +46,7 @@ class FallbackHandler: self.fallback_tool_models = self._load_fallback_tool_models(config) self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 - self.failed_messages = set() + self.failed_messages: list[BaseMessage] = [] self.tool_failure_used_fallbacks = set() self.current_failing_tool_name = "" self.current_tool_to_bind = None @@ -92,7 +93,9 @@ class FallbackHandler: cpm(message, title="Fallback Models") return final_models - def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph): + def handle_failure( + self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph + ): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -104,14 +107,7 @@ class FallbackHandler: return None failed_tool_call_name = self.extract_failed_tool_name(error) - if ( - self.current_failing_tool_name - and failed_tool_call_name != self.current_failing_tool_name - ): - logger.debug( - "New failing tool name identified. Resetting consecutive tool failures." - ) - self.reset_fallback_handler() + self._reset_on_new_failure(failed_tool_call_name) logger.debug( f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}" @@ -123,7 +119,7 @@ class FallbackHandler: ) if hasattr(error, "base_message") and error.base_message: - self.failed_messages.add(str(error.base_message)) + self.failed_messages.append(error.base_message) self.tool_failure_consecutive_failures += 1 logger.debug( @@ -172,6 +168,16 @@ class FallbackHandler: self.tool_failure_used_fallbacks.clear() self.fallback_tool_models = self._load_fallback_tool_models(self.config) + def _reset_on_new_failure(self, failed_tool_call_name): + if ( + self.current_failing_tool_name + and failed_tool_call_name != self.current_failing_tool_name + ): + logger.debug( + "New failing tool name identified. Resetting consecutive tool failures." 
+ ) + self.reset_fallback_handler() + def extract_failed_tool_name(self, error: ToolExecutionError): if error.tool_name: failed_tool_call_name = error.tool_name @@ -234,7 +240,11 @@ class FallbackHandler: retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) - response = retry_model.invoke(self.current_failing_tool_name) + # msg_list = [] + msg_list = self.construct_prompt_msg_list() + # response = retry_model.invoke(self.current_failing_tool_name) + response = retry_model.invoke(msg_list) + cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) tool_call = self.base_message_to_tool_call_dict(response) @@ -303,6 +313,33 @@ class FallbackHandler: ) return None + def construct_prompt_msg_list(self): + """ + Construct a list of chat messages for the fallback prompt. + The initial message instructs the assistant that it is a fallback tool caller. + Then includes the failed tool call messages from self.failed_messages. + Finally, it appends a human message asking it to retry calling the tool with correct valid arguments. + + Returns: + list: A list of chat messages. + """ + msg_list: list[BaseMessage] = [] + msg_list.append( + SystemMessage( + content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages." + ) + ) + if self.failed_messages: + # Convert to system messages to avoid API errors asking for correct msg structure + msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages]) + + msg_list.append( + HumanMessage( + content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments." + ) + ) + return msg_list + def invoke_prompt_tool_call(self, tool_call_request: dict): """ Invoke a tool call from a prompt-based fallback response. 
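The tool-call extraction and dispatch path that patch 12 settles on (`get_tool_calls`, then `base_message_to_tool_call_dict`, then `invoke_prompt_tool_call`) can be exercised in isolation. Below is a minimal sketch, not part of the patch series; `DummyResponse` and `example_tool` are hypothetical stand-ins, and only the OpenAI-style `additional_kwargs["tool_calls"]` case is mirrored:

```python
import json
from dataclasses import dataclass, field


@dataclass
class DummyResponse:
    """Stand-in for a BaseMessage carrying an OpenAI-style tool call."""

    additional_kwargs: dict = field(default_factory=dict)


def extract_tool_call(response):
    # Mirrors base_message_to_tool_call_dict for the additional_kwargs case:
    # take the first tool call and normalize it to the handler's dict shape.
    tool_calls = response.additional_kwargs.get("tool_calls")
    if not tool_calls:
        return None
    call = tool_calls[0]
    args = call["function"]["arguments"]
    return {
        "id": call["id"],
        "type": call["type"],
        "name": call["function"]["name"],
        "arguments": json.loads(args) if isinstance(args, str) else args,
    }


def example_tool(arguments):
    # Hypothetical tool; the real handler maps tool.func.__name__ -> tool.
    return f"echo: {arguments['text']}"


tools = {"example_tool": example_tool}

resp = DummyResponse(
    additional_kwargs={
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {"name": "example_tool", "arguments": '{"text": "hi"}'},
            }
        ]
    }
)

call = extract_tool_call(resp)
if call:
    print(tools[call["name"]](call["arguments"]))  # -> echo: hi
```

The real handler additionally checks `response.tool_calls` and plain-dict responses, and warns before discarding all but the first of multiple tool calls.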
From 6e8b0f2e42b18640decedf6a4bfa28ae79f295e6 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Feb 2025 13:50:37 -0800 Subject: [PATCH 13/45] chore(output.py): remove debug print statements for cleaner code and improved readability --- ra_aid/console/output.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 7e62e1b..ee6fe08 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -44,8 +44,6 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: ) ) tool_name = getattr(msg, "name", None) - cpm(f"type(msg): {type(msg)}") - cpm(f"msg: {msg}") raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg) From 96b41458a1835a9d1cde4d29dba4114a91ef3223 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Feb 2025 15:35:31 -0800 Subject: [PATCH 14/45] feat(agent_utils.py): refactor agent stream handling to improve clarity and maintainability by introducing reset_agent_completion_flags function feat(fallback_handler.py): enhance fallback handling by allowing RAgents type and improving error handling fix(config.py): update RAgents type definition to include both CompiledGraph and CiaynAgent for better type safety refactor(fallback_handler.py): streamline fallback model invocation and response handling for improved readability and functionality --- ra_aid/agent_utils.py | 41 ++++++--- ra_aid/config.py | 5 ++ ra_aid/fallback_handler.py | 171 +++++++++++++++---------------------- 3 files changed, 99 insertions(+), 118 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index ad1adf5..75b95b7 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -28,7 +28,7 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agents.ciayn_agent import CiaynAgent -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output from ra_aid.exceptions import AgentInterrupt, ToolExecutionError @@ -807,16 +807,10 @@ def _decrement_agent_depth(): _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 -def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict): - for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): - logger.debug("Agent output: %s", chunk) - check_interrupt() - print_agent_output(chunk) - if _global_memory["plan_completed"] or _global_memory["task_completed"]: - _global_memory["plan_completed"] = False - _global_memory["task_completed"] = False - _global_memory["completion_message"] = "" - break +def reset_agent_completion_flags(): + _global_memory["plan_completed"] = False + _global_memory["task_completed"] = False + _global_memory["completion_message"] = "" def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test): @@ -842,8 +836,26 @@ def _handle_api_error(e, attempt, max_retries, base_delay): time.sleep(0.1) +def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): + for chunk in agent.stream({"messages": msg_list}, config): + logger.debug("Agent output: %s", chunk) + check_interrupt() + print_agent_output(chunk) + if _global_memory["plan_completed"] or _global_memory["task_completed"]: + reset_agent_completion_flags() + break + check_interrupt() + print_agent_output(chunk) + if 
_global_memory["plan_completed"] or _global_memory["task_completed"]: + reset_agent_completion_flags() + break + + def run_agent_with_retry( - agent, prompt: str, config: dict, fallback_handler: FallbackHandler + agent: RAgents, + prompt: str, + config: dict, + fallback_handler: FallbackHandler, ) -> Optional[str]: """Run an agent with retry logic for API errors.""" logger.debug("Running agent with prompt length: %d", len(prompt)) @@ -854,6 +866,7 @@ def run_agent_with_retry( _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt + msg_list = [HumanMessage(content=prompt)] with InterruptibleSection(): try: @@ -862,7 +875,7 @@ def run_agent_with_retry( logger.debug("Attempt %d/%d", attempt + 1, max_retries) check_interrupt() try: - _run_agent_stream(agent, prompt, config) + _run_agent_stream(agent, msg_list, config) fallback_handler.reset_fallback_handler() should_break, prompt, auto_test, test_attempts = ( _execute_test_command_wrapper( @@ -878,7 +891,7 @@ def run_agent_with_retry( except ToolExecutionError as e: fallback_response = fallback_handler.handle_failure(e, agent) if fallback_response: - prompt = original_prompt + "\n" + str(fallback_response) + msg_list.extend(fallback_response) continue except (KeyboardInterrupt, AgentInterrupt): raise diff --git a/ra_aid/config.py b/ra_aid/config.py index 54d7995..8bfaf5c 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -15,3 +15,8 @@ VALID_PROVIDERS = [ "deepseek", "gemini", ] + +from ra_aid.agents.ciayn_agent import CiaynAgent +from langgraph.graph.graph import CompiledGraph + +RAgents = CompiledGraph | CiaynAgent diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 4ae2016..c1c5ea2 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,16 +1,16 @@ import json import re +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import BaseTool -from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage -from langchain_core.messages import SystemMessage, HumanMessage -from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, + RAgents, ) from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError @@ -47,11 +47,18 @@ class FallbackHandler: self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 self.failed_messages: list[BaseMessage] = [] - self.tool_failure_used_fallbacks = set() self.current_failing_tool_name = "" - self.current_tool_to_bind = None + self.current_tool_to_bind: None | BaseTool = None - def _load_fallback_tool_models(self, config): + cpm( + "Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]), + title="Fallback Models", + ) + + def _format_model(self, m: dict) -> str: + return f"{m.get('model', '')} ({m.get('type', 'prompt')})" + + def _load_fallback_tool_models(self, _config): """ Load and return fallback tool models based on the provided configuration. 
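The loader's env gating can be pictured as follows, a sketch under the assumption that each provider maps to a single API-key variable (the mapping below is illustrative; the real check is `validate_provider_env`):

```python
import os

# Hypothetical provider -> env var mapping; illustrative only.
PROVIDER_ENV = {"openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY"}

def filter_fallback_models(models: list[dict], limit: int) -> list[dict]:
    kept, skipped = [], []
    for m in models:
        if os.environ.get(PROVIDER_ENV.get(m["provider"], ""), ""):
            kept.append(m)
        else:
            skipped.append(m["model"])  # no API key -> skip this fallback
    if skipped:
        print("Skipped due to missing provider ENV API keys:", ", ".join(skipped))
    return kept[:limit]
```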
@@ -90,12 +97,9 @@ class FallbackHandler: "\nSkipped top tool calling models due to missing provider ENV API keys: " + ", ".join(skipped) ) - cpm(message, title="Fallback Models") return final_models - def handle_failure( - self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph - ): + def handle_failure(self, error: ToolExecutionError, agent: RAgents): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -137,10 +141,10 @@ class FallbackHandler: def attempt_fallback(self): """ - Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method. + Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call. Returns: - The response from a fallback model if any, otherwise None. + List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None. """ logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" @@ -150,12 +154,10 @@ class FallbackHandler: title="Fallback Notification", ) for fallback_model in self.fallback_tool_models: - if fallback_model.get("type", "prompt").lower() == "fc": - response = self.attempt_fallback_function(fallback_model) - else: - response = self.attempt_fallback_prompt(fallback_model) - if response: - return response + result_list = self.invoke_fallback(fallback_model) + if result_list: + msg_list_response = [SystemMessage(str(msg)) for msg in result_list] + return msg_list_response cpm("All fallback models have failed", title="Fallback Failed") return None @@ -165,8 +167,9 @@ class FallbackHandler: """ self.tool_failure_consecutive_failures = 0 self.failed_messages.clear() - self.tool_failure_used_fallbacks.clear() self.fallback_tool_models = self._load_fallback_tool_models(self.config) + self.current_failing_tool_name = "" + self.current_tool_to_bind = None def _reset_on_new_failure(self, failed_tool_call_name): if ( @@ -218,9 +221,21 @@ class FallbackHandler: ) return tool_to_bind - def attempt_fallback_prompt(self, fallback_model): + def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model): + if fallback_model.get("type", "prompt").lower() == "fc": + # Force tool calling with tool_choice param. + bound_model = simple_model.bind_tools( + [self.current_tool_to_bind], + tool_choice=self.current_failing_tool_name, + ) + else: + # Do not force tool calling (Prompt method) + bound_model = simple_model.bind_tools([self.current_tool_to_bind]) + return bound_model + + def invoke_fallback(self, fallback_model): """ - Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model. + Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model. Args: fallback_model (dict): The fallback model to use. @@ -229,88 +244,33 @@ class FallbackHandler: The response from the fallback model invocation, or None if failed. 
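The two invocation modes handled by `_bind_tool_model` reduce to whether `tool_choice` is passed, as in the sketch below (assuming a chat model whose `bind_tools()` accepts `tool_choice`, as the OpenAI- and Anthropic-backed langchain models do; "fc" forces the named tool, "prompt" leaves the choice to the model):

```python
def bind_for_fallback(chat_model, tool, tool_name: str, mode: str = "prompt"):
    if mode == "fc":
        # Force the model to call this specific tool.
        bound = chat_model.bind_tools([tool], tool_choice=tool_name)
    else:
        # Prompt method: the tool is available but not forced.
        bound = chat_model.bind_tools([tool])
    # Retry transient provider failures a fixed number of times.
    return bound.with_retry(stop_after_attempt=3)
```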
""" try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") + logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) - binded_model = simple_model.bind_tools( - [self.current_tool_to_bind], - tool_choice=self.current_failing_tool_name, - ) - retry_model = binded_model.with_retry( + + bound_model = self._bind_tool_model(simple_model, fallback_model) + + retry_model = bound_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) - # msg_list = [] + msg_list = self.construct_prompt_msg_list() - # response = retry_model.invoke(self.current_failing_tool_name) response = retry_model.invoke(msg_list) - cpm(f"response={response}") - self.tool_failure_used_fallbacks.add(fallback_model["model"]) + logger.debug(f"raw llm response={response}") tool_call = self.base_message_to_tool_call_dict(response) - if tool_call: - result = self.invoke_prompt_tool_call(tool_call) - cpm(f"result={result}") - logger.debug( - "Prompt-based fallback executed successfully with model: " - + fallback_model["model"] - ) - self.reset_fallback_handler() - return result - else: - cpm( - response.content if hasattr(response, "content") else response, - title="Fallback Model Response: " + fallback_model["model"], - ) - return response + + tool_call_result = self.invoke_prompt_tool_call(tool_call) + cpm(str(tool_call_result), title="Fallback Tool Call Result") + logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}") + + self.reset_fallback_handler() + return [response, tool_call_result] except Exception as e: if isinstance(e, KeyboardInterrupt): raise - logger.error( - f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" - ) - return None - - def attempt_fallback_function(self, fallback_model): - """ - Attempt a function-calling fallback by invoking the current failing tool with the given fallback model. - - Args: - fallback_model (dict): The fallback model to use. - - Returns: - The response from the fallback model invocation, or None if failed. - """ - try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") - simple_model = initialize_llm( - fallback_model["provider"], fallback_model["model"] - ) - binded_model = simple_model.bind_tools( - [self.current_tool_to_bind], - tool_choice=self.current_failing_tool_name, - ) - retry_model = binded_model.with_retry( - stop_after_attempt=RETRY_FALLBACK_COUNT - ) - response = retry_model.invoke(self.current_failing_tool_name) - cpm(f"response={response}") - self.tool_failure_used_fallbacks.add(fallback_model["model"]) - logger.debug( - "Function-calling fallback executed successfully with model: " - + fallback_model["model"] - ) - cpm( - response.content if hasattr(response, "content") else response, - title="Fallback Model Response: " + fallback_model["model"], - ) - return response - except Exception as e: - if isinstance(e, KeyboardInterrupt): - raise - logger.error( - f"Function-calling fallback with model {fallback_model['model']} failed: {e}" - ) + logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}") return None def construct_prompt_msg_list(self): @@ -367,19 +327,22 @@ class FallbackHandler: otherwise None. 
""" tool_calls = self.get_tool_calls(response) - if tool_calls: - if len(tool_calls) > 1: - logger.warning("Multiple tool calls detected, using the first one") - tool_call = tool_calls[0] - return { - "id": tool_call["id"], - "type": tool_call["type"], - "name": tool_call["function"]["name"], - "arguments": self._parse_tool_arguments( - tool_call["function"]["arguments"] - ), - } - return None + + if not tool_calls: + raise Exception( + f"Could not extract tool_call_dict from response: {response}" + ) + + if len(tool_calls) > 1: + logger.warning("Multiple tool calls detected, using the first one") + + tool_call = tool_calls[0] + return { + "id": tool_call["id"], + "type": tool_call["type"], + "name": tool_call["function"]["name"], + "arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]), + } def _parse_tool_arguments(self, tool_arguments): """ From e508e4d1f2abda10c3e8002940fd6fc7048d5bb8 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Feb 2025 17:55:43 -0800 Subject: [PATCH 15/45] feat(agent_utils.py): introduce get_agent_type function to determine agent type and improve code clarity refactor(agent_utils.py): update _run_agent_stream to utilize agent type for output printing fix(ciayn_agent.py): modify _execute_tool to handle BaseMessage and improve error reporting feat(ciayn_agent.py): add extract_tool_name method to identify tool names from code chore(agents_alias.py): create agents_alias module to avoid circular imports and define RAgents type refactor(config.py): remove direct import of CiaynAgent and update RAgents definition fix(output.py): update print_agent_output to accept agent type for better error handling fix(exceptions.py): add CiaynToolExecutionError for distinguishing tool execution failures refactor(fallback_handler.py): improve logging and error handling in fallback mechanism --- ra_aid/agent_utils.py | 35 ++++++++++++++++-------- ra_aid/agents/ciayn_agent.py | 53 +++++++++++++++++++++++++++--------- ra_aid/agents_alias.py | 10 +++++++ ra_aid/config.py | 5 ---- ra_aid/console/output.py | 11 ++++++-- ra_aid/exceptions.py | 18 ++++++++++++ ra_aid/fallback_handler.py | 13 ++++++--- 7 files changed, 108 insertions(+), 37 deletions(-) create mode 100644 ra_aid/agents_alias.py diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 75b95b7..0f67623 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -19,7 +19,6 @@ from langchain_core.messages import ( ) from langchain_core.tools import tool from langgraph.checkpoint.memory import MemorySaver -from langgraph.graph.graph import CompiledGraph from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import get_model_info @@ -28,7 +27,8 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agents.ciayn_agent import CiaynAgent -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents +from ra_aid.agents_alias import RAgents +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output from ra_aid.exceptions import AgentInterrupt, ToolExecutionError @@ -836,16 +836,24 @@ def _handle_api_error(e, attempt, max_retries, base_delay): time.sleep(0.1) +def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]: + """ + Determines the type of the agent. 
+ Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React". + """ + + if isinstance(agent, CiaynAgent): + return "CiaynAgent" + else: + return "React" + + def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): for chunk in agent.stream({"messages": msg_list}, config): logger.debug("Agent output: %s", chunk) check_interrupt() - print_agent_output(chunk) - if _global_memory["plan_completed"] or _global_memory["task_completed"]: - reset_agent_completion_flags() - break - check_interrupt() - print_agent_output(chunk) + agent_type = get_agent_type(agent) + print_agent_output(chunk, agent_type) if _global_memory["plan_completed"] or _global_memory["task_completed"]: reset_agent_completion_flags() break @@ -889,10 +897,13 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - fallback_response = fallback_handler.handle_failure(e, agent) - if fallback_response: - msg_list.extend(fallback_response) - continue + print("except ToolExecutionError in AGENT UTILS") + if not isinstance(agent, CiaynAgent): + logger.debug("AGENT UTILS ToolExecutionError called!") + fallback_response = fallback_handler.handle_failure(e, agent) + if fallback_response: + msg_list.extend(fallback_response) + continue except (KeyboardInterrupt, AgentInterrupt): raise except ( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 4060684..59e4e2d 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -2,9 +2,13 @@ import re from dataclasses import dataclass from typing import Any, Dict, Generator, List, Optional, Union +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.tools import BaseTool +from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError +from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.reflection import get_function_info @@ -70,8 +74,8 @@ class CiaynAgent: def __init__( self, - model, - tools: list, + model: BaseChatModel, + tools: list[BaseTool], max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT, config: Optional[dict] = None, @@ -97,8 +101,10 @@ class CiaynAgent: self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) + self.tool_failure_current_provider = None self.tool_failure_current_model = None + self.fallback_handler = FallbackHandler(config, tools) def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -229,8 +235,11 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return base_prompt - def _execute_tool(self, code: str) -> str: + def _execute_tool(self, msg: BaseMessage) -> str: """Execute a tool call and return its result.""" + + cpm(f"execute_tool msg: { msg }") + code = msg.content globals_dict = {tool.func.__name__: tool.func for tool in self.tools} try: @@ -240,9 +249,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" # if the eval fails, try to extract it via a model call if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) - logger.debug(f"_execute_tool: code before extraction: {code}") code = 
_extract_tool_call(code, functions_list) - logger.debug(f"_execute_tool: code after extraction: {code}") logger.debug( f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" @@ -251,8 +258,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" logger.debug(f"_execute_tool: result: {result}") return result except Exception as e: - error_msg = f"Error executing code: {str(e)}" - raise ToolExecutionError(error_msg) + error_msg = f"Error: {str(e)} \n Could not execute code: {code}" + tool_name = self.extract_tool_name(code) + raise ToolExecutionError(error_msg, base_message=msg, tool_name=tool_name) + + def extract_tool_name(self, code: str) -> str: + match = re.match(r"\s*([\w_\-]+)\s*\(", code) + if match: + return match.group(1) + return "" def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" @@ -354,18 +368,31 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" try: logger.debug(f"Code generated by agent: {response.content}") - last_result = self._execute_tool(response.content) + last_result = self._execute_tool(response) chat_history.append(response) first_iteration = False yield {} except ToolExecutionError as e: - chat_history.append( - HumanMessage( - content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." + fallback_response = self.fallback_handler.handle_failure(e, self) + print(f"fallback_response={fallback_response}") + if fallback_response: + hm = HumanMessage( + content="The fallback handler has fixed your tool call results are in the last System message." ) - ) - yield self._create_error_chunk(str(e)) + chat_history.extend(fallback_response) + chat_history.append(hm) + logger.debug("Appended fallback response to chat history.") + yield {} + else: + yield self._create_error_chunk(str(e)) + # yield {"messages": [fallback_response[-1]]} + + # chat_history.append( + # HumanMessage( + # content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
+ # ) + # ) def _extract_tool_call(code: str, functions_list: str) -> str: diff --git a/ra_aid/agents_alias.py b/ra_aid/agents_alias.py new file mode 100644 index 0000000..d3e74c0 --- /dev/null +++ b/ra_aid/agents_alias.py @@ -0,0 +1,10 @@ +from langgraph.graph.graph import CompiledGraph +from typing import TYPE_CHECKING + +# Unfortunately need this to avoid Circular Imports +if TYPE_CHECKING: + from ra_aid.agents.ciayn_agent import CiaynAgent + + RAgents = CompiledGraph | CiaynAgent +else: + RAgents = CompiledGraph diff --git a/ra_aid/config.py b/ra_aid/config.py index 8bfaf5c..54d7995 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -15,8 +15,3 @@ VALID_PROVIDERS = [ "deepseek", "gemini", ] - -from ra_aid.agents.ciayn_agent import CiaynAgent -from langgraph.graph.graph import CompiledGraph - -RAgents = CompiledGraph | CiaynAgent diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index ee6fe08..9b06508 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional from langchain_core.messages import AIMessage from rich.markdown import Markdown @@ -10,7 +10,9 @@ from ra_aid.exceptions import ToolExecutionError from .formatting import console -def print_agent_output(chunk: Dict[str, Any]) -> None: +def print_agent_output( + chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"] +) -> None: """Print only the agent's message content, not tool calls. Args: @@ -44,7 +46,10 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: ) ) tool_name = getattr(msg, "name", None) - raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg) + if agent_type == "React": + raise ToolExecutionError( + err_msg, tool_name=tool_name, base_message=msg + ) def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None: diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index 34710d9..b7c714b 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -31,3 +31,21 @@ class ToolExecutionError(Exception): super().__init__(message) self.base_message = base_message self.tool_name = tool_name + + +class CiaynToolExecutionError(Exception): + """Exception raised when a tool execution fails. + + This exception is used to distinguish tool execution failures + from other types of errors in the agent system. 
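A quick illustration of how an exception that carries tool metadata is meant to be consumed (a standalone stand-in class is used here so the snippet runs on its own; names are illustrative):

```python
from typing import Optional

class ToolError(Exception):  # stand-in for the exception classes above
    def __init__(self, message: str, tool_name: Optional[str] = None):
        super().__init__(message)
        self.tool_name = tool_name

try:
    raise ToolError("invalid arguments for tool call", tool_name="run_shell")
except ToolError as err:
    # The handler can route on the failing tool's name.
    print(f"Tool '{err.tool_name}' failed: {err}")
```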
+ """ + + def __init__( + self, + message: str, + base_message: Optional[BaseMessage] = None, + tool_name: Optional[str] = None, + ): + super().__init__(message) + self.base_message = base_message + self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index c1c5ea2..e9bb9f3 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -6,11 +6,11 @@ from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import BaseTool from langgraph.graph.message import BaseMessage +from ra_aid.agents_alias import RAgents from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, - RAgents, ) from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError @@ -51,7 +51,8 @@ class FallbackHandler: self.current_tool_to_bind: None | BaseTool = None cpm( - "Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]), + "Fallback models selected: " + + ", ".join([self._format_model(m) for m in self.fallback_tool_models]), title="Fallback Models", ) @@ -263,14 +264,18 @@ class FallbackHandler: tool_call_result = self.invoke_prompt_tool_call(tool_call) cpm(str(tool_call_result), title="Fallback Tool Call Result") - logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}") + logger.debug( + f"Fallback call successful with model: {self._format_model(fallback_model)}" + ) self.reset_fallback_handler() return [response, tool_call_result] except Exception as e: if isinstance(e, KeyboardInterrupt): raise - logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}") + logger.error( + f"Fallback with model {self._format_model(fallback_model)} failed: {e}" + ) return None def construct_prompt_msg_list(self): From 646d509c221de27eeed356789cf2a92fd2baa0b8 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 00:24:31 -0800 Subject: [PATCH 16/45] feat(agent_utils.py): add agent_type retrieval to enhance fallback handling logic feat(ciayn_agent.py): implement chat_history in CiaynAgent for improved context management during tool execution refactor(ciayn_agent.py): streamline fallback response handling and logging for better clarity and maintainability --- ra_aid/agent_utils.py | 18 ++++++++---- ra_aid/agents/ciayn_agent.py | 54 ++++++++++++++++++------------------ 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 0f67623..f467e5f 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -874,6 +874,7 @@ def run_agent_with_retry( _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt + agent_type = get_agent_type(agent) msg_list = [HumanMessage(content=prompt)] with InterruptibleSection(): @@ -898,12 +899,19 @@ def run_agent_with_retry( return "Agent run completed successfully" except ToolExecutionError as e: print("except ToolExecutionError in AGENT UTILS") - if not isinstance(agent, CiaynAgent): - logger.debug("AGENT UTILS ToolExecutionError called!") - fallback_response = fallback_handler.handle_failure(e, agent) - if fallback_response: + logger.debug("AGENT UTILS ToolExecutionError called!") + fallback_response = fallback_handler.handle_failure(e, agent) + if fallback_response: + if agent_type == "React": msg_list.extend(fallback_response) - continue + else: + 
agent.chat_history.extend(fallback_response) + agent.chat_history.append( + HumanMessage( + content="Fallback tool handler successfully ran your tool call. See last message for result." + ) + ) + continue except (KeyboardInterrupt, AgentInterrupt): raise except ( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 59e4e2d..211877c 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -98,6 +98,7 @@ class CiaynAgent: self.tools = tools self.max_history_messages = max_history_messages self.max_tokens = max_tokens + self.chat_history = [] self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) @@ -105,6 +106,9 @@ class CiaynAgent: self.tool_failure_current_provider = None self.tool_failure_current_model = None self.fallback_handler = FallbackHandler(config, tools) + self.sys_message = SystemMessage( + "Execute efficiently yet completely as a fully autonomous agent." + ) def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -348,51 +352,47 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" ) -> Generator[Dict[str, Any], None, None]: """Stream agent responses in a format compatible with print_agent_output.""" initial_messages = messages_dict.get("messages", []) - chat_history = [] + # self.chat_history = [] last_result = None first_iteration = True while True: base_prompt = self._build_prompt(None if first_iteration else last_result) - chat_history.append(HumanMessage(content=base_prompt)) + self.chat_history.append(HumanMessage(content=base_prompt)) - full_history = self._trim_chat_history(initial_messages, chat_history) - response = self.model.invoke( - [ - SystemMessage( - "Execute efficiently yet completely as a fully autonomous agent." - ) - ] - + full_history - ) + full_history = self._trim_chat_history(initial_messages, self.chat_history) + response = self.model.invoke([self.sys_message] + full_history) try: logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response) - chat_history.append(response) + self.chat_history.append(response) first_iteration = False yield {} except ToolExecutionError as e: - fallback_response = self.fallback_handler.handle_failure(e, self) - print(f"fallback_response={fallback_response}") - if fallback_response: - hm = HumanMessage( - content="The fallback handler has fixed your tool call results are in the last System message." - ) - chat_history.extend(fallback_response) - chat_history.append(hm) - logger.debug("Appended fallback response to chat history.") - yield {} - else: - yield self._create_error_chunk(str(e)) - # yield {"messages": [fallback_response[-1]]} - - # chat_history.append( + # self.chat_history.append( # HumanMessage( # content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." # ) # ) + raise e + # yield self._create_error_chunk(str(e)) + yield {} + + # fallback_response = self.fallback_handler.handle_failure(e, self) + # print(f"fallback_response={fallback_response}") + # if fallback_response: + # hm = HumanMessage( + # content="The fallback handler has fixed your tool call results are in the last System message." 
+ # ) + # self.chat_history.extend(fallback_response) + # self.chat_history.append(hm) + # logger.debug("Appended fallback response to chat history.") + # yield {} + # else: + # yield self._create_error_chunk(str(e)) + # yield {"messages": [fallback_response[-1]]} def _extract_tool_call(code: str, functions_list: str) -> str: From e2cd51c66d1d4e68a131f17e70e81b2e2b206c95 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 16:14:24 -0800 Subject: [PATCH 17/45] feat(agent_utils.py): add SystemMessage import and improve logging messages for clarity fix(agent_utils.py): handle fallback responses more effectively and ensure fallback handler is optional refactor(ciayn_agent.py): streamline prompt building and extract tool call logic into a separate method chore(ciayn_agent.py): remove commented-out code and improve fallback response handling chore(exceptions.py): remove unused CiaynToolExecutionError class to clean up code chore(fallback_handler.py): simplify fallback response handling logic chore(logging_config.py): add debug print statement for logging handler usage chore(prompts.py): update prompts for clarity and maintainability --- ra_aid/agent_utils.py | 25 ++-- ra_aid/agents/ciayn_agent.py | 221 +++++++---------------------------- ra_aid/exceptions.py | 18 --- ra_aid/fallback_handler.py | 4 +- ra_aid/logging_config.py | 1 + ra_aid/prompts.py | 127 ++++++++++++++++++++ 6 files changed, 188 insertions(+), 208 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index f467e5f..5df995f 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -15,6 +15,7 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( BaseMessage, HumanMessage, + SystemMessage, trim_messages, ) from langchain_core.tools import tool @@ -408,7 +409,7 @@ def run_research_agent( display_project_status(project_info) if agent is not None: - logger.debug("Research agent completed successfully") + logger.debug("Research agent created successfully") fallback_handler = FallbackHandler(config, tools) _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: @@ -863,7 +864,7 @@ def run_agent_with_retry( agent: RAgents, prompt: str, config: dict, - fallback_handler: FallbackHandler, + fallback_handler: Optional[FallbackHandler], ) -> Optional[str]: """Run an agent with retry logic for API errors.""" logger.debug("Running agent with prompt length: %d", len(prompt)) @@ -885,7 +886,8 @@ def run_agent_with_retry( check_interrupt() try: _run_agent_stream(agent, msg_list, config) - fallback_handler.reset_fallback_handler() + if fallback_handler: + fallback_handler.reset_fallback_handler() should_break, prompt, auto_test, test_attempts = ( _execute_test_command_wrapper( original_prompt, config, test_attempts, auto_test @@ -900,18 +902,19 @@ def run_agent_with_retry( except ToolExecutionError as e: print("except ToolExecutionError in AGENT UTILS") logger.debug("AGENT UTILS ToolExecutionError called!") + if not fallback_handler: + continue + fallback_response = fallback_handler.handle_failure(e, agent) if fallback_response: if agent_type == "React": - msg_list.extend(fallback_response) + msg_list_response = [ + SystemMessage(str(msg)) for msg in fallback_response + ] + msg_list.extend(msg_list_response) else: - agent.chat_history.extend(fallback_response) - agent.chat_history.append( - HumanMessage( - content="Fallback tool handler successfully ran your tool call. See last message for result." 
- ) - ) - continue + pass + continue except (KeyboardInterrupt, AgentInterrupt): raise except ( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 211877c..f870316 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -6,12 +6,15 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool +from ra_aid.logging_config import get_logger +from ra_aid.tools.expert import get_model +from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError from ra_aid.fallback_handler import FallbackHandler -from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.reflection import get_function_info +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES logger = get_logger(__name__) @@ -113,6 +116,7 @@ class CiaynAgent: def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" base_prompt = "" + if last_result is not None: base_prompt += f"\n{last_result}" @@ -120,122 +124,9 @@ class CiaynAgent: functions_list = "\n\n".join(self.available_functions) # Build the complete prompt without f-strings for the static parts - base_prompt += ( - """ + base_prompt += CIAYN_AGENT_BASE_PROMPT.format(functions_list=functions_list) - -You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration, but you will be called in a loop, so you will be able to accomplish the task over many iterations. -The result of that function call will be given to you in the next message. -Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed. -The user cannot see the results of function calls, so you have to explicitly use a tool like ask_human if you want them to see something. -You must always respond with a single line of python that calls one of the available tools. -Use as many steps as you need to in order to fully complete the task. -Start by asking the user what they want. - -You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal. -You are expected to use as many steps as necessary to completely achieve the user's request, making many tool calls along the way. -Think hard about what the best *next* tool call is, knowing that you can make as many calls as you need to after that. -You typically don't want to keep calling the same function over and over with the same parameters. - - -You must ONLY use ONE of the following functions (these are the ONLY functions that exist): - -""" - + functions_list - + """ - - -You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier. -But you MUST NOT assume tools exist that are not in the above list, e.g. write_file_tool. -Consider your task done only once you have taken *ALL* the steps required to complete it. - ---- EXAMPLE BAD OUTPUTS --- - -This tool is not in available functions, so this is a bad tool call: - - -write_file_tool(...) 
- - -This tool call has a syntax error (unclosed parenthesis, quotes), so it is bad: - - -write_file_tool("asdf - - -This tool call is bad because it includes a message as well as backticks: - - -Sure, I'll make the following tool call to accomplish what you asked me: - -``` -list_directory_tree('.') -``` - - -This tool call is bad because the output code is surrounded with backticks: - - -``` -list_directory_tree('.') -``` - - -The following is bad becasue it makes the same tool call multiple times in a row with the exact same parameters, for no reason, getting stuck in a loop: - - - -list_directory_tree('.') - - -list_directory_tree('.') - - - -The following is bad because it makes more than one tool call in one response: - - -list_directory_tree('.') -read_file_tool('README.md') # Now we've made - -request_research_and_implementation(\"\"\" -Example query. -\"\"\") - - -This is good output because it uses a multiple line string when needed and properly calls the tool, does not output backticks or extra information: - -run_programming_task(\"\"\" -# Example Programming Task - -Implement a widget factory satisfying the following requirements: - -- Requirement A -- Requirement B - -... -\"\"\") - - -As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal. - -You will make as many tool calls as you feel necessary in order to fully complete the task. - -We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up. - -You have often been criticized for: - -- Making the same function calls over and over, getting stuck in a loop. - -DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE! -Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" - ) + # base_prompt += "\n\nYou must reply with ONLY ONE of the functions given in available functions." return base_prompt @@ -253,7 +144,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" # if the eval fails, try to extract it via a model call if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) - code = _extract_tool_call(code, functions_list) + code = self._extract_tool_call(code, functions_list) logger.debug( f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" @@ -272,6 +163,22 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return match.group(1) return "" + def handle_fallback_response( + self, fallback_response: list[Any], e: ToolExecutionError + ) -> str: + err_msg = HumanMessage(content=self.error_message_template.format(e=e)) + self.chat_history.append(err_msg) + + if not fallback_response: + return "" + + msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n" + # Passing the fallback invocation may confuse our llm, as invocation methods may differ. + # msg += f"{fallback_response[0]}\n" + msg += f"{e.tool_name}" + msg += f"{fallback_response[1]}" + return msg + def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" return {"agent": {"messages": [AIMessage(content=content)]}} @@ -284,7 +191,6 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" @staticmethod def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int: """Estimate number of tokens in content using simple byte length heuristic. - Estimates 1 token per 2.0 bytes of content. 
For messages, uses the content field. Args: @@ -310,6 +216,22 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return len(text.encode("utf-8")) // 2.0 + def _extract_tool_call(self, code: str, functions_list: str) -> str: + model = get_model() + prompt = EXTRACT_TOOL_CALL_PROMPT.format( + functions_list=functions_list, code=code + ) + response = model.invoke(prompt) + response = response.content + + pattern = r"([\w_\-]+)\((.*?)\)" + matches = re.findall(pattern, response, re.DOTALL) + if len(matches) == 0: + raise ToolExecutionError("Failed to extract tool call") + ma = matches[0][0].strip() + mb = matches[0][1].strip().replace("\n", " ") + return f"{ma}({mb})" + def _trim_chat_history( self, initial_messages: List[Any], chat_history: List[Any] ) -> List[Any]: @@ -352,14 +274,12 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" ) -> Generator[Dict[str, Any], None, None]: """Stream agent responses in a format compatible with print_agent_output.""" initial_messages = messages_dict.get("messages", []) - # self.chat_history = [] + self.chat_history = [] last_result = None - first_iteration = True while True: - base_prompt = self._build_prompt(None if first_iteration else last_result) + base_prompt = self._build_prompt(last_result) self.chat_history.append(HumanMessage(content=base_prompt)) - full_history = self._trim_chat_history(initial_messages, self.chat_history) response = self.model.invoke([self.sys_message] + full_history) @@ -367,62 +287,9 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response) self.chat_history.append(response) - first_iteration = False yield {} except ToolExecutionError as e: - # self.chat_history.append( - # HumanMessage( - # content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." - # ) - # ) - raise e - # yield self._create_error_chunk(str(e)) + fallback_response = self.fallback_handler.handle_failure(e, self) + last_result = self.handle_fallback_response(fallback_response, e) yield {} - - # fallback_response = self.fallback_handler.handle_failure(e, self) - # print(f"fallback_response={fallback_response}") - # if fallback_response: - # hm = HumanMessage( - # content="The fallback handler has fixed your tool call results are in the last System message." - # ) - # self.chat_history.extend(fallback_response) - # self.chat_history.append(hm) - # logger.debug("Appended fallback response to chat history.") - # yield {} - # else: - # yield self._create_error_chunk(str(e)) - # yield {"messages": [fallback_response[-1]]} - - -def _extract_tool_call(code: str, functions_list: str) -> str: - from ra_aid.tools.expert import get_model - - model = get_model() - prompt = f""" -I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example: - -``` -run_programming_task("blah \" blah\" blah") -``` - -The following tasks are allowed: - -{functions_list} - -I got this invalid response from the model, can you format it so it becomes a correct function call? 
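The extraction step relies on a single regex pass over the model reply; a minimal sketch of that behavior (the input string is illustrative):

```python
import re

# Pull a single `name(args)` call out of a reply that may contain
# extra prose, backticks, or newlines.
def extract_call(response_text: str) -> str:
    matches = re.findall(r"([\w_\-]+)\((.*?)\)", response_text, re.DOTALL)
    if not matches:
        raise ValueError("Failed to extract tool call")
    name = matches[0][0].strip()
    args = matches[0][1].strip().replace("\n", " ")
    return f"{name}({args})"

print(extract_call("Sure:\n```\nread_file_tool('README.md')\n```"))
# -> read_file_tool('README.md')
```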
- -``` -{code} -``` - """ - response = model.invoke(prompt) - response = response.content - - pattern = r"([\w_\-]+)\((.*?)\)" - matches = re.findall(pattern, response, re.DOTALL) - if len(matches) == 0: - raise ToolExecutionError("Failed to extract tool call") - ma = matches[0][0].strip() - mb = matches[0][1].strip().replace("\n", " ") - return f"{ma}({mb})" diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index b7c714b..34710d9 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -31,21 +31,3 @@ class ToolExecutionError(Exception): super().__init__(message) self.base_message = base_message self.tool_name = tool_name - - -class CiaynToolExecutionError(Exception): - """Exception raised when a tool execution fails. - - This exception is used to distinguish tool execution failures - from other types of errors in the agent system. - """ - - def __init__( - self, - message: str, - base_message: Optional[BaseMessage] = None, - tool_name: Optional[str] = None, - ): - super().__init__(message) - self.base_message = base_message - self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index e9bb9f3..470c9b0 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -157,8 +157,8 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: result_list = self.invoke_fallback(fallback_model) if result_list: - msg_list_response = [SystemMessage(str(msg)) for msg in result_list] - return msg_list_response + # msg_list_response = [SystemMessage(str(msg)) for msg in result_list] + return result_list cpm("All fallback models have failed", title="Fallback Failed") return None diff --git a/ra_aid/logging_config.py b/ra_aid/logging_config.py index ba4609f..ce248bd 100644 --- a/ra_aid/logging_config.py +++ b/ra_aid/logging_config.py @@ -44,6 +44,7 @@ def setup_logging(verbose: bool = False, pretty: bool = False) -> None: if pretty: handler = PrettyHandler() else: + print("USING STREAM HANDLER LOGGER") handler = logging.StreamHandler(sys.stdout) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) diff --git a/ra_aid/prompts.py b/ra_aid/prompts.py index 840112b..b0b4ccf 100644 --- a/ra_aid/prompts.py +++ b/ra_aid/prompts.py @@ -977,3 +977,130 @@ You have often been criticized for: NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT! """ + +EXTRACT_TOOL_CALL_PROMPT = """I'm conversing with an AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example: +``` +run_programming_task("blah \" blah\" blah") +``` + +The following tasks are allowed: + +{functions_list} + +I got this invalid response from the model, can you format it so it becomes a correct function call? + +``` +{code} +```""" + +CIAYN_AGENT_BASE_PROMPT = """ +You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration, but you will be called in a loop, so you will be able to accomplish the task over many iterations. +The result of that function call will be given to you in the next message. +Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed. +The user cannot see the results of function calls, so you have to explicitly use a tool like ask_human if you want them to see something. +You must always respond with a single line of python that calls one of the available tools. +Use as many steps as you need to in order to fully complete the task. +Start by asking the user what they want.
+ +You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal. +You are expected to use as many steps as necessary to completely achieve the user's request, making many tool calls along the way. +Think hard about what the best *next* tool call is, knowing that you can make as many calls as you need to after that. +You typically don't want to keep calling the same function over and over with the same parameters. + + +You must ONLY use ONE of the following functions (these are the ONLY functions that exist): + +{functions_list} + + +You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier. +But you MUST NOT assume tools exist that are not in the above list, e.g. write_file_tool. +Consider your task done only once you have taken *ALL* the steps required to complete it. + +--- EXAMPLE BAD OUTPUTS --- + +This tool is not in available functions, so this is a bad tool call: + + +write_file_tool(...) + + +This tool call has a syntax error (unclosed parenthesis, quotes), so it is bad: + + +write_file_tool("asdf + + +This tool call is bad because it includes a message as well as backticks: + + +Sure, I'll make the following tool call to accomplish what you asked me: + +``` +list_directory_tree('.') +``` + + +This tool call is bad because the output code is surrounded with backticks: + + +``` +list_directory_tree('.') +``` + + +The following is bad because it makes the same tool call multiple times in a row with the exact same parameters, for no reason, getting stuck in a loop: + + + +list_directory_tree('.') + + +list_directory_tree('.') + + + +The following is bad because it makes more than one tool call in one response: + + +list_directory_tree('.') +read_file_tool('README.md') # Now we've made + +request_research_and_implementation(\"\"\" +Example query. +\"\"\") + + +This is good output because it uses a multiple line string when needed and properly calls the tool, does not output backticks or extra information: + +run_programming_task(\"\"\" +# Example Programming Task + +Implement a widget factory satisfying the following requirements: + +- Requirement A +- Requirement B + +... +\"\"\") + + +As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal. + +You will make as many tool calls as you feel necessary in order to fully complete the task. + +We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up. + +You have often been criticized for: + +- Making the same function calls over and over, getting stuck in a loop. + +DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE!
+Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS** +""" From 115cde98b6194eea52d5283a43821d51b937e812 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 16:42:59 -0800 Subject: [PATCH 18/45] feat(agent_utils.py): add FallbackToolExecutionError exception to handle fallback tool execution failures fix(ciayn_agent.py): improve error message template for tool call errors to provide clearer guidance refactor(ciayn_agent.py): update comment for clarity regarding fallback tool invocation fix(output.py): clarify that CiaynAgent handles tool execution errors internally fix(fallback_handler.py): raise FallbackToolExecutionError for better error handling in fallback scenarios --- ra_aid/agent_utils.py | 14 +++++++++----- ra_aid/agents/ciayn_agent.py | 3 ++- ra_aid/console/output.py | 2 ++ ra_aid/exceptions.py | 6 ++++++ ra_aid/fallback_handler.py | 10 ++++++---- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 5df995f..c272839 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -32,7 +32,11 @@ from ra_aid.agents_alias import RAgents from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output -from ra_aid.exceptions import AgentInterrupt, ToolExecutionError +from ra_aid.exceptions import ( + AgentInterrupt, + ToolExecutionError, + FallbackToolExecutionError, +) from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params @@ -900,8 +904,6 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - print("except ToolExecutionError in AGENT UTILS") - logger.debug("AGENT UTILS ToolExecutionError called!") if not fallback_handler: continue @@ -912,9 +914,11 @@ def run_agent_with_retry( SystemMessage(str(msg)) for msg in fallback_response ] msg_list.extend(msg_list_response) - else: - pass continue + except FallbackToolExecutionError as e: + msg_list.append( + SystemMessage(f"FallbackToolExecutionError:{str(e)}") + ) except (KeyboardInterrupt, AgentInterrupt): raise except ( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index f870316..1f2cb46 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -112,6 +112,7 @@ class CiaynAgent: self.sys_message = SystemMessage( "Execute efficiently yet completely as a fully autonomous agent." ) + self.error_message_template = "Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -173,7 +174,7 @@ class CiaynAgent: return "" msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n" - # Passing the fallback invocation may confuse our llm, as invocation methods may differ. + # Passing the fallback raw invocation may confuse our llm, as invocation methods may differ. 
# msg += f"{fallback_response[0]}\n" msg += f"{e.tool_name}" msg += f"{fallback_response[1]}" diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 9b06508..3dc9f1d 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -46,6 +46,8 @@ def print_agent_output( ) ) tool_name = getattr(msg, "name", None) + + # CiaynAgent handles this internally if agent_type == "React": raise ToolExecutionError( err_msg, tool_name=tool_name, base_message=msg ) diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index 34710d9..09638e6 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -31,3 +31,9 @@ class ToolExecutionError(Exception): super().__init__(message) self.base_message = base_message self.tool_name = tool_name + + +class FallbackToolExecutionError(Exception): + """Exception raised when a fallback tool execution fails.""" + + pass diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 470c9b0..b42da97 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -13,7 +13,7 @@ from ra_aid.config import ( RETRY_FALLBACK_COUNT, ) from ra_aid.console.output import cpm -from ra_aid.exceptions import ToolExecutionError +from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError from ra_aid.llm import initialize_llm, validate_provider_env from ra_aid.logging_config import get_logger from ra_aid.tool_configs import get_all_tools @@ -197,7 +197,9 @@ class FallbackHandler: ) else: failed_tool_call_name = "Tool execution error" - raise Exception("Fallback failed: Could not extract failed tool name.") + raise FallbackToolExecutionError( + "Fallback failed: Could not extract failed tool name." + ) return failed_tool_call_name @@ -217,8 +219,8 @@ ) if tool_to_bind is None: # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name - raise Exception( - f"Fallback failed: {failed_tool_call_name} not found in all tools." + raise FallbackToolExecutionError( + f"Fallback failed: tool '{failed_tool_call_name}' not found in any available tools."
) return tool_to_bind From 63e48db9dea64d7ef5ed35a09838c6c8ec0e3a8f Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 16:47:31 -0800 Subject: [PATCH 19/45] feat(agent_utils.py): add _handle_fallback_response function to streamline fallback handling logic refactor(agent_utils.py): extract fallback handling logic from run_agent_with_retry to improve code readability fix(ciayn_agent.py): update stream method parameter name for consistency chore(agents_alias.py): reorder import statements to follow best practices style(fallback_handler.py): reorder exception imports for consistency and clarity --- ra_aid/agent_utils.py | 32 ++++++++++++++++++++------------ ra_aid/agents/ciayn_agent.py | 10 +++++----- ra_aid/agents_alias.py | 3 ++- ra_aid/fallback_handler.py | 12 +++++++++++- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index c272839..18ed975 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -34,8 +34,8 @@ from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output from ra_aid.exceptions import ( AgentInterrupt, - ToolExecutionError, FallbackToolExecutionError, + ToolExecutionError, ) from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger @@ -846,12 +846,29 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]: Determines the type of the agent. Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React". """ - + if isinstance(agent, CiaynAgent): return "CiaynAgent" else: return "React" +def _handle_fallback_response( + error: ToolExecutionError, + fallback_handler, + agent: RAgents, + agent_type: str, + msg_list: list +) -> None: + """ + Handle fallback response by invoking fallback_handler and updating msg_list. 
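The React-side handling reduces to re-wrapping whatever the fallback returned as `SystemMessage`s on the running conversation, as in this minimal sketch (assuming the `langchain_core` message classes):

```python
from langchain_core.messages import BaseMessage, SystemMessage

def merge_fallback_messages(
    msg_list: list[BaseMessage], fallback_response: list
) -> None:
    # Stringify each fallback artifact so the next stream() call sees the
    # tool result as plain system context rather than a provider tool frame.
    msg_list.extend(SystemMessage(content=str(m)) for m in fallback_response)
```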
+ """ + if not fallback_handler: + return + fallback_response = fallback_handler.handle_failure(error, agent) + if fallback_response and agent_type == "React": + msg_list_response = [SystemMessage(str(msg)) for msg in fallback_response] + msg_list.extend(msg_list_response) + def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): for chunk in agent.stream({"messages": msg_list}, config): @@ -904,16 +921,7 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - if not fallback_handler: - continue - - fallback_response = fallback_handler.handle_failure(e, agent) - if fallback_response: - if agent_type == "React": - msg_list_response = [ - SystemMessage(str(msg)) for msg in fallback_response - ] - msg_list.extend(msg_list_response) + _handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list) continue except FallbackToolExecutionError as e: msg_list.append( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 1f2cb46..aa2b3ce 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -6,15 +6,15 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool -from ra_aid.logging_config import get_logger -from ra_aid.tools.expert import get_model -from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError from ra_aid.fallback_handler import FallbackHandler +from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT +from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT +from ra_aid.tools.expert import get_model from ra_aid.tools.reflection import get_function_info -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES logger = get_logger(__name__) @@ -271,7 +271,7 @@ class CiaynAgent: return initial_messages + chat_history def stream( - self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None + self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None ) -> Generator[Dict[str, Any], None, None]: """Stream agent responses in a format compatible with print_agent_output.""" initial_messages = messages_dict.get("messages", []) diff --git a/ra_aid/agents_alias.py b/ra_aid/agents_alias.py index d3e74c0..2cf6077 100644 --- a/ra_aid/agents_alias.py +++ b/ra_aid/agents_alias.py @@ -1,6 +1,7 @@ -from langgraph.graph.graph import CompiledGraph from typing import TYPE_CHECKING +from langgraph.graph.graph import CompiledGraph + # Unfortunately need this to avoid Circular Imports if TYPE_CHECKING: from ra_aid.agents.ciayn_agent import CiaynAgent diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index b42da97..86c4ab8 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -13,7 +13,7 @@ from ra_aid.config import ( RETRY_FALLBACK_COUNT, ) from ra_aid.console.output import cpm -from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError +from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError from ra_aid.llm import initialize_llm, validate_provider_env from ra_aid.logging_config import get_logger from ra_aid.tool_configs import get_all_tools @@ -383,3 +383,13 @@ class 
FallbackHandler: ): tool_calls = response.get("additional_kwargs").get("tool_calls") return tool_calls + + def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str): + """ + Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React", + return a list of SystemMessage objects wrapping each message from the fallback response. + """ + fallback_response = self.handle_failure(error, agent) + if fallback_response and agent_type == "React": + return [SystemMessage(str(msg)) for msg in fallback_response] + return None From 2420dfbb4fac5591c07141f58d76b9181df19db4 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 16:53:22 -0800 Subject: [PATCH 20/45] refactor(agent_utils.py): extract fallback handler initialization into a separate function to improve code readability and maintainability fix(ciayn_agent.py): reset fallback handler after executing tool to ensure proper state management style(fallback_handler.py): format method signature for better readability --- ra_aid/agent_utils.py | 43 ++++++++++++++++++++++++++---------- ra_aid/agents/ciayn_agent.py | 1 + ra_aid/fallback_handler.py | 4 +++- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 18ed975..62c78f1 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -414,8 +414,10 @@ def run_research_agent( if agent is not None: logger.debug("Research agent created successfully") - fallback_handler = FallbackHandler(config, tools) - _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) + _result = run_agent_with_retry( + agent, prompt, run_config, none_or_fallback_handler + ) if _result: # Log research completion log_work_event(f"Completed research phase for: {base_task_or_query}") @@ -531,8 +533,10 @@ def run_web_research_agent( console.print(Panel(Markdown(console_message), title="🔬 Researching...")) logger.debug("Web research agent completed successfully") - fallback_handler = FallbackHandler(config, tools) - _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) + _result = run_agent_with_retry( + agent, prompt, run_config, none_or_fallback_handler + ) if _result: # Log web research completion log_work_event(f"Completed web research phase for: {query}") @@ -637,9 +641,9 @@ def run_planning_agent( try: print_stage_header("Planning Stage") logger.debug("Planning agent completed successfully") - fallback_handler = FallbackHandler(config, tools) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) _result = run_agent_with_retry( - agent, planning_prompt, run_config, fallback_handler + agent, planning_prompt, run_config, none_or_fallback_handler ) if _result: # Log planning completion @@ -745,8 +749,10 @@ def run_task_implementation_agent( try: logger.debug("Implementation agent completed successfully") - fallback_handler = FallbackHandler(config, tools) - _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) + _result = run_agent_with_retry( + agent, prompt, run_config, none_or_fallback_handler + ) if _result: # Log task implementation completion log_work_event(f"Completed implementation of task: {task}") @@ -846,18 +852,29 @@ def get_agent_type(agent: RAgents) -> 
Literal["CiaynAgent", "React"]: Determines the type of the agent. Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React". """ - + if isinstance(agent, CiaynAgent): return "CiaynAgent" else: return "React" + +def init_fallback_handler(agent: RAgents, config: Dict[str, Any], tools: List[Any]): + """ + Initialize fallback handler if agent is of type "React"; otherwise return None. + """ + agent_type = get_agent_type(agent) + if agent_type == "React": + return FallbackHandler(config, tools) + return None + + def _handle_fallback_response( error: ToolExecutionError, - fallback_handler, + fallback_handler: Optional[FallbackHandler], agent: RAgents, agent_type: str, - msg_list: list + msg_list: list, ) -> None: """ Handle fallback response by invoking fallback_handler and updating msg_list. @@ -921,7 +938,9 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - _handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list) + _handle_fallback_response( + e, fallback_handler, agent, agent_type, msg_list + ) continue except FallbackToolExecutionError as e: msg_list.append( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index aa2b3ce..df12722 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -288,6 +288,7 @@ class CiaynAgent: logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response) self.chat_history.append(response) + self.fallback_handler.reset_fallback_handler() yield {} except ToolExecutionError as e: diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 86c4ab8..30f5690 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -384,7 +384,9 @@ class FallbackHandler: tool_calls = response.get("additional_kwargs").get("tool_calls") return tool_calls - def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str): + def handle_failure_response( + self, error: ToolExecutionError, agent, agent_type: str + ): """ Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React", return a list of SystemMessage objects wrapping each message from the fallback response. From cd8d1c459d01383be6bc09e2e32ebd52feed1143 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 17:01:46 -0800 Subject: [PATCH 21/45] feat(readme): document new command line arguments for experimental features feat(main.py): add --experimental-fallback-handler argument to enable fallback handler fix(agent_utils.py): modify init_fallback_handler to check for experimental fallback handler flag fix(config.py): increase DEFAULT_MAX_TOOL_FAILURES to allow more retries before failure --- README.md | 2 ++ ra_aid/__main__.py | 6 ++++++ ra_aid/agent_utils.py | 4 +++- ra_aid/config.py | 3 +-- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8ed4097..c2ff104 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,8 @@ ra-aid -m "Add new feature" --verbose - `--hil, -H`: Enable human-in-the-loop mode for interactive assistance during task execution - `--chat`: Enable chat mode with direct human interaction (implies --hil) - `--verbose`: Enable verbose logging output +- `--experimental-fallback-handler`: Enable experimental fallback handler to attempt to fix too calls when they fail 3 times consecutively. 
+- `--pretty-logger`: Enables panel markdown formatted logger messages for debugging purposes. - `--temperature`: LLM temperature (0.0-2.0) to control randomness in responses - `--disable-limit-tokens`: Disable token limiting for Anthropic Claude react agents - `--recursion-limit`: Maximum recursion depth for agent operations (default: 100) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index e027a08..84abc2e 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -155,6 +155,11 @@ Examples: action="store_true", help="Disable fallback model switching.", ) + parser.add_argument( + "--experimental-fallback-handler", + action="store_true", + help="Enable experimental fallback handler.", + ) parser.add_argument( "--fallback-tool-models", type=str, @@ -407,6 +412,7 @@ def main(): "auto_test": args.auto_test, "test_cmd": args.test_cmd, "max_test_cmd_retries": args.max_test_cmd_retries, + "experimental_fallback_handler": args.experimental_fallback_handler, } # Store config in global memory for access by is_informational_query diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 62c78f1..5b0d6ca 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -861,8 +861,10 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]: def init_fallback_handler(agent: RAgents, config: Dict[str, Any], tools: List[Any]): """ - Initialize fallback handler if agent is of type "React"; otherwise return None. + Initialize fallback handler if agent is of type "React" and experimental_fallback_handler is enabled; otherwise return None. """ + if not config.get("experimental_fallback_handler", False): + return None agent_type = get_agent_type(agent) if agent_type == "React": return FallbackHandler(config, tools) diff --git a/ra_aid/config.py b/ra_aid/config.py index 54d7995..2f5eab0 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -2,10 +2,9 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 -DEFAULT_MAX_TOOL_FAILURES = 2 +DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 -RETRY_FALLBACK_DELAY = 2 VALID_PROVIDERS = [ "anthropic", From c7712e011471fafc8ff71b54895347750467c311 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 17:10:26 -0800 Subject: [PATCH 22/45] fix(tests): update type hints in test_agent_utils.py for better clarity and type safety refactor(tests): modify DummyAgent's stream method to use more descriptive parameter names and types for improved readability --- tests/ra_aid/test_agent_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 2e54f5a..726c7d7 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -1,11 +1,12 @@ """Unit tests for agent_utils.py.""" +from typing import Any, Dict, Literal from unittest.mock import Mock, patch import litellm import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from ra_aid.agent_utils import ( AgentState, @@ -317,7 +318,7 @@ def test_run_agent_stream(monkeypatch): # Create a dummy agent that yields one chunk class DummyAgent: - def stream(self, msg, cfg): + def stream(self, input_data, cfg: dict): yield {"content": "chunk1"} dummy_agent = DummyAgent() @@ -327,13 +328,15 @@ def test_run_agent_stream(monkeypatch): 
_global_memory["completion_message"] = "existing" call_flag = {"called": False} - def fake_print_agent_output(chunk): + def fake_print_agent_output( + chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"] + ): call_flag["called"] = True monkeypatch.setattr( "ra_aid.agent_utils.print_agent_output", fake_print_agent_output ) - _run_agent_stream(dummy_agent, "dummy prompt", {}) + _run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {}) assert call_flag["called"] assert _global_memory["plan_completed"] is False assert _global_memory["task_completed"] is False From ac13ce746a1d2d664460f6fa2f501c5498bf1734 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 17:57:27 -0800 Subject: [PATCH 23/45] feat(agent_utils.py): add cpm function import for enhanced logging in run_agent_with_retry fix(agent_utils.py): set default value of fallback_handler to None in run_agent_with_retry chore(ciayn_agent.py): comment out debug log for generated code to reduce verbosity fix(fallback_handler.py): change fallback_enabled config key to experimental_fallback_handler for better clarity refactor(test_ciayn_agent.py): update invoke method in DummyModel to return AIMessage instead of a custom Response class test(test_ciayn_agent.py): comment out test_retry_logic_with_failure_recovery for future implementation and focus on existing tests --- ra_aid/agent_utils.py | 6 ++++-- ra_aid/agents/ciayn_agent.py | 3 +-- ra_aid/fallback_handler.py | 2 +- tests/ra_aid/test_ciayn_agent.py | 24 +++++++++++++----------- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 5b0d6ca..26e4f50 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -31,7 +31,7 @@ from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.agents_alias import RAgents from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.console.formatting import print_error, print_stage_header -from ra_aid.console.output import print_agent_output +from ra_aid.console.output import cpm, print_agent_output from ra_aid.exceptions import ( AgentInterrupt, FallbackToolExecutionError, @@ -904,7 +904,7 @@ def run_agent_with_retry( agent: RAgents, prompt: str, config: dict, - fallback_handler: Optional[FallbackHandler], + fallback_handler: Optional[FallbackHandler] = None, ) -> Optional[str]: """Run an agent with retry logic for API errors.""" logger.debug("Running agent with prompt length: %d", len(prompt)) @@ -933,10 +933,12 @@ def run_agent_with_retry( original_prompt, config, test_attempts, auto_test ) ) + cpm(f"res:{should_break, prompt, auto_test, test_attempts}") if should_break: break if prompt != original_prompt: continue + logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index df12722..de38591 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -134,7 +134,6 @@ class CiaynAgent: def _execute_tool(self, msg: BaseMessage) -> str: """Execute a tool call and return its result.""" - cpm(f"execute_tool msg: { msg }") code = msg.content globals_dict = {tool.func.__name__: tool.func for tool in self.tools} @@ -285,7 +284,7 @@ class CiaynAgent: response = self.model.invoke([self.sys_message] + full_history) try: - logger.debug(f"Code generated by agent: {response.content}") + # logger.debug(f"Code generated by agent: {response.content}") last_result = 
self._execute_tool(response) self.chat_history.append(response) self.fallback_handler.reset_fallback_handler() diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 30f5690..3b7bac0 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -42,7 +42,7 @@ class FallbackHandler: """ self.config = config self.tools: list[BaseTool] = tools - self.fallback_enabled = config.get("fallback_tool_enabled", True) + self.fallback_enabled = config.get("experimental_fallback_handler", False) self.fallback_tool_models = self._load_fallback_tool_models(config) self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py index 46c8191..896db33 100644 --- a/tests/ra_aid/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import Mock import pytest -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern from ra_aid.exceptions import ToolExecutionError @@ -25,12 +25,9 @@ class DummyTool: class DummyModel: - def invoke(self, messages): - # Always return a code snippet that calls dummy_tool() - class Response: - content = "dummy_tool()" + def invoke(self, _messages: list[BaseMessage]): - return Response() + return AIMessage("dummy_tool()") def bind_tools(self, tools, tool_choice): pass @@ -188,20 +185,25 @@ class TestCiaynAgentFallback(unittest.TestCase): # Create a CiaynAgent with the dummy tool self.agent = CiaynAgent(self.model, [self.dummy_tool]) - def test_retry_logic_with_failure_recovery(self): - # Test that _execute_tool retries and eventually returns success - result = self.agent._execute_tool("dummy_tool()") - self.assertEqual(result, "dummy success") + # def test_retry_logic_with_failure_recovery(self): + # # Test that run_agent_with_retry retries until success + # from ra_aid.agent_utils import run_agent_with_retry + # + # config = {"max_test_cmd_retries": 0, "auto_test": True} + # result = run_agent_with_retry(self.agent, "dummy_tool()", config) + # self.assertEqual(result, "Agent run completed successfully") def test_switch_models_on_fallback(self): # Test fallback behavior by making dummy_tool always fail + from langchain_core.messages import HumanMessage + def always_fail(): raise Exception("Persistent failure") always_fail_tool = DummyTool(always_fail) agent = CiaynAgent(self.model, [always_fail_tool]) with self.assertRaises(ToolExecutionError): - agent._execute_tool("always_fail()") + agent._execute_tool(HumanMessage("always_fail()")) # Function call validation tests From 15a3291254aa2253bab00280e4e108ce63271ae7 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 18:20:00 -0800 Subject: [PATCH 24/45] refactor(agent_utils.py): remove unnecessary debug log statement to clean up code refactor(fallback_handler.py): improve error handling by raising a specific exception when all fallback models fail test(fallback_handler.py): update tests to reflect changes in the fallback handler's error handling and initialization fix(test_llm.py): update error messages in tests for unsupported providers to be more descriptive and accurate --- ra_aid/agent_utils.py | 1 - ra_aid/fallback_handler.py | 12 +++++++++--- tests/ra_aid/test_fallback_handler.py | 13 ++++++------- tests/ra_aid/test_llm.py | 4 ++-- 4 
files changed, 17 insertions(+), 13 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 26e4f50..565d6a9 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -933,7 +933,6 @@ def run_agent_with_retry( original_prompt, config, test_attempts, auto_test ) ) - cpm(f"res:{should_break, prompt, auto_test, test_attempts}") if should_break: break if prompt != original_prompt: diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 3b7bac0..c2cf07f 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -157,10 +157,16 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: result_list = self.invoke_fallback(fallback_model) if result_list: - # msg_list_response = [SystemMessage(str(msg)) for msg in result_list] return result_list - cpm("All fallback models have failed", title="Fallback Failed") - return None + + cpm("All fallback models have failed.", title="Fallback Failed") + + current_failing_tool_name = self.current_failing_tool_name + self.reset_fallback_handler() + + raise FallbackToolExecutionError( + f"All fallback models have failed for tool: {current_failing_tool_name}" + ) def reset_fallback_handler(self): """ diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index c400a19..3a2edcf 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -22,16 +22,17 @@ class TestFallbackHandler(unittest.TestCase): self.config = { "max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model", + "experimental_fallback_handler": True, } - self.fallback_handler = FallbackHandler(self.config) + self.fallback_handler = FallbackHandler(self.config, []) self.logger = DummyLogger() self.agent = DummyAgent() def test_handle_failure_increments_counter(self): + from ra_aid.exceptions import ToolExecutionError initial_failures = self.fallback_handler.tool_failure_consecutive_failures - self.fallback_handler.handle_failure( - "dummy_call()", Exception("Test error"), self.logger, self.agent - ) + error_obj = ToolExecutionError("Test error", base_message="dummy_call()", tool_name="dummy_tool") + self.fallback_handler.handle_failure(error_obj, self.agent) self.assertEqual( self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1, @@ -62,9 +63,7 @@ class TestFallbackHandler(unittest.TestCase): llm.validate_provider_env = dummy_validate_provider_env self.fallback_handler.tool_failure_consecutive_failures = 2 - self.fallback_handler.attempt_fallback( - "dummy_tool_call()", self.logger, self.agent - ) + self.fallback_handler.attempt_fallback() self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) llm.initialize_llm = original_initialize diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 6789132..7f96ed4 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -121,7 +121,7 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch def test_initialize_expert_unsupported_provider(clean_env): """Test error handling for unsupported provider in expert mode.""" - with pytest.raises(ValueError, match=r"Unsupported provider: unknown"): + with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"): initialize_expert_llm("unknown", "model") @@ -197,7 +197,7 @@ def test_initialize_unsupported_provider(clean_env): """Test initialization with unsupported provider raises ValueError""" with 
pytest.raises(ValueError) as exc_info: initialize_llm("unsupported", "model") - assert str(exc_info.value) == "Unsupported provider: unsupported" + assert str(exc_info.value) == "Missing required environment variable for provider: unsupported" def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini): From 09abba512dd5de6b9e6bbc9002683e1404dc4d4d Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 18:21:54 -0800 Subject: [PATCH 25/45] test(tests): add tests for fallback handler to ensure proper error handling and counter incrementing feat(tests): introduce FallbackToolExecutionError to improve error specificity in fallback handling tests --- tests/ra_aid/test_fallback_handler.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 3a2edcf..31f5201 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -1,5 +1,6 @@ import unittest +from ra_aid.exceptions import FallbackToolExecutionError from ra_aid.fallback_handler import FallbackHandler @@ -28,10 +29,22 @@ class TestFallbackHandler(unittest.TestCase): self.logger = DummyLogger() self.agent = DummyAgent() + def dummy_tool(): + pass + + class DummyToolWrapper: + def __init__(self, func): + self.func = func + + self.agent.tools = [DummyToolWrapper(dummy_tool)] + def test_handle_failure_increments_counter(self): from ra_aid.exceptions import ToolExecutionError + initial_failures = self.fallback_handler.tool_failure_consecutive_failures - error_obj = ToolExecutionError("Test error", base_message="dummy_call()", tool_name="dummy_tool") + error_obj = ToolExecutionError( + "Test error", base_message="dummy_call()", tool_name="dummy_tool" + ) self.fallback_handler.handle_failure(error_obj, self.agent) self.assertEqual( self.fallback_handler.tool_failure_consecutive_failures, @@ -63,7 +76,8 @@ class TestFallbackHandler(unittest.TestCase): llm.validate_provider_env = dummy_validate_provider_env self.fallback_handler.tool_failure_consecutive_failures = 2 - self.fallback_handler.attempt_fallback() + with self.assertRaises(FallbackToolExecutionError): + self.fallback_handler.attempt_fallback() self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) llm.initialize_llm = original_initialize From cc1945facdf0dc892ddd4e710997a0a980edd9fd Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 18:48:45 -0800 Subject: [PATCH 26/45] feat(agent_utils.py): change SystemMessage to HumanMessage for fallback responses in React to improve message clarity fix(fallback_handler.py): update tool invocation logic to handle missing tools and raise appropriate exceptions for better error handling test(tests): add comprehensive tests for fallback handler functionality, including loading models, extracting tool names, and invoking tools to ensure robustness and reliability --- ra_aid/agent_utils.py | 2 +- ra_aid/fallback_handler.py | 11 +- tests/ra_aid/test_fallback_handler.py | 210 ++++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 4 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 565d6a9..5aff588 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -885,7 +885,7 @@ def _handle_fallback_response( return fallback_response = fallback_handler.handle_failure(error, agent) if fallback_response and agent_type == "React": - msg_list_response = [SystemMessage(str(msg)) for msg in 
fallback_response] + msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response] msg_list.extend(msg_list_response) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index c2cf07f..3d61098 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -147,7 +147,7 @@ class FallbackHandler: Returns: List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None. """ - logger.error( + logger.debug( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" ) cpm( @@ -323,10 +323,15 @@ class FallbackHandler: Returns: The result of invoking the tool. """ - tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} + tool_name_to_tool = {getattr(tool.func, "__name__", None): tool for tool in self.tools} name = tool_call_request["name"] arguments = tool_call_request["arguments"] - return tool_name_to_tool[name].invoke(arguments) + if name in tool_name_to_tool: + return tool_name_to_tool[name].invoke(arguments) + elif self.current_tool_to_bind is not None and getattr(self.current_tool_to_bind.func, "__name__", None) == name: + return self.current_tool_to_bind.invoke(arguments) + else: + raise Exception(f"Tool '{name}' not found in available tools.") def base_message_to_tool_call_dict(self, response: BaseMessage): """ diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 31f5201..1a87608 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -84,6 +84,216 @@ class TestFallbackHandler(unittest.TestCase): llm.merge_chat_history = original_merge llm.validate_provider_env = original_validate + def test_load_fallback_tool_models(self): + import ra_aid.fallback_handler as fh + original_supported = fh.supported_top_tool_models + fh.supported_top_tool_models = [ + {"provider": "dummy", "model": "dummy_model", "type": "prompt"} + ] + models = self.fallback_handler._load_fallback_tool_models(self.config) + self.assertIsInstance(models, list) + fh.supported_top_tool_models = original_supported + + def test_extract_failed_tool_name(self): + from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError + # Case when tool_name is provided + error1 = ToolExecutionError("Error", base_message="dummy", tool_name="dummy_tool") + name1 = self.fallback_handler.extract_failed_tool_name(error1) + self.assertEqual(name1, "dummy_tool") + # Case when tool_name is not provided but regex works + error2 = ToolExecutionError("error with name=\"test_tool\"") + name2 = self.fallback_handler.extract_failed_tool_name(error2) + self.assertEqual(name2, "test_tool") + # Case when regex fails and exception is raised + error3 = ToolExecutionError("no tool name here") + with self.assertRaises(FallbackToolExecutionError): + self.fallback_handler.extract_failed_tool_name(error3) + + def test_find_tool_to_bind(self): + # Create a dummy tool to be found + class DummyTool: + def invoke(self, args): + return "result" + class DummyWrapper: + def __init__(self, func): + self.func = func + def dummy_func(args): + return "result" + dummy_tool = DummyTool() + dummy_wrapper = DummyWrapper(dummy_func) + self.agent.tools.append(dummy_wrapper) + tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__) + self.assertIsNotNone(tool) + self.assertEqual(tool.func.__name__, dummy_func.__name__) + + def test_bind_tool_model(self): + # Setup a dummy simple_model with bind_tools method + 
class DummyModel: + def bind_tools(self, tools, tool_choice=None): + self.bound = True + self.tools = tools + self.tool_choice = tool_choice + return self + def with_retry(self, stop_after_attempt): + return self + def invoke(self, msg_list): + return "dummy_response" + dummy_model = DummyModel() + # Set current tool for binding + class DummyTool: + def invoke(self, args): + return "result" + self.fallback_handler.current_tool_to_bind = DummyTool() + self.fallback_handler.current_failing_tool_name = "test_tool" + # Test with force calling ("fc") type + fallback_model_fc = {"type": "fc"} + bound_model_fc = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_fc) + self.assertTrue(hasattr(bound_model_fc, "tool_choice")) + self.assertEqual(bound_model_fc.tool_choice, "test_tool") + # Test with prompt type + fallback_model_prompt = {"type": "prompt"} + bound_model_prompt = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_prompt) + self.assertTrue(bound_model_prompt.tool_choice is None) + + def test_invoke_fallback(self): + from unittest.mock import patch + import os + import ra_aid.llm as llm + + # Successful fallback scenario with proper API key set + with patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), \ + patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \ + patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), \ + patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm: + class DummyModel: + def bind_tools(self, tools, tool_choice=None): + return self + def with_retry(self, stop_after_attempt): + return self + def invoke(self, msg_list): + return DummyResponse() + class DummyResponse: + additional_kwargs = {"tool_calls": [{"id": "1", "type": "test", "function": {"name": "dummy_tool", "arguments": "{\"a\":1}"}}]} + def dummy_initialize_llm(provider, model_name): + return DummyModel() + mock_init_llm.side_effect = dummy_initialize_llm + # Set current tool for fallback + class DummyTool: + def invoke(self, args): + return "tool_result" + self.fallback_handler.current_tool_to_bind = DummyTool() + self.fallback_handler.current_failing_tool_name = "dummy_tool" + # Add dummy tool for lookup in invoke_prompt_tool_call + self.fallback_handler.tools.append( + type( + "DummyToolWrapper", + (), + { + "func": type("DummyToolFunc", (), {"__name__": "dummy_tool"})(), + "invoke": lambda self, args: "tool_result", + }, + ) + ) + result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"}) + self.assertIsInstance(result, list) + self.assertEqual(result[1], "tool_result") + + # Failed fallback scenario due to missing API key (simulate by empty environment) + with patch.dict(os.environ, {}, clear=True), \ + patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \ + patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), \ + patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm: + class FailingDummyModel: + def bind_tools(self, tools, tool_choice=None): + return self + def with_retry(self, stop_after_attempt): + return self + def invoke(self, msg_list): + raise Exception("API key missing") + def failing_initialize_llm(provider, model_name): + return FailingDummyModel() + mock_init_llm.side_effect = failing_initialize_llm + fallback_result = self.fallback_handler.invoke_fallback({"provider": "dummy", 
"model": "dummy_model", "type": "prompt"}) + self.assertIsNone(fallback_result) + + # Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail + # Set failure count to trigger the fallback attempt in attempt_fallback + from ra_aid.exceptions import FallbackToolExecutionError + self.fallback_handler.tool_failure_consecutive_failures = self.fallback_handler.max_failures + with self.assertRaises(FallbackToolExecutionError) as cm: + self.fallback_handler.attempt_fallback() + self.assertIn("All fallback models have failed", str(cm.exception)) + + def test_construct_prompt_msg_list(self): + msgs = self.fallback_handler.construct_prompt_msg_list() + from ra_aid.fallback_handler import SystemMessage, HumanMessage + self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs)) + self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs)) + # Test with failed_messages added + self.fallback_handler.failed_messages.append("failed_msg") + msgs_with_fail = self.fallback_handler.construct_prompt_msg_list() + self.assertTrue(any("failed_msg" in str(m) for m in msgs_with_fail)) + + def test_invoke_prompt_tool_call(self): + # Create dummy tool function + def dummy_tool_func(args): + return "invoked_result" + dummy_tool_func.__name__ = "dummy_tool" + # Create wrapper class + class DummyToolWrapper: + def __init__(self, func): + self.func = func + def invoke(self, args): + return self.func(args) + dummy_wrapper = DummyToolWrapper(dummy_tool_func) + self.fallback_handler.tools = [dummy_wrapper] + tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}} + result = self.fallback_handler.invoke_prompt_tool_call(tool_call_req) + self.assertEqual(result, "invoked_result") + + def test_base_message_to_tool_call_dict(self): + dummy_tool_call = { + "id": "123", + "type": "test", + "function": {"name": "dummy_tool", "arguments": "{\"x\":42}"} + } + DummyResponse = type("DummyResponse", (), {"additional_kwargs": {"tool_calls": [dummy_tool_call]}}) + result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse) + self.assertEqual(result["id"], "123") + self.assertEqual(result["name"], "dummy_tool") + self.assertEqual(result["arguments"], {"x": 42}) + + def test_parse_tool_arguments(self): + args_str = '{"a": 1}' + parsed = self.fallback_handler._parse_tool_arguments(args_str) + self.assertEqual(parsed, {"a": 1}) + args_dict = {"b": 2} + parsed_dict = self.fallback_handler._parse_tool_arguments(args_dict) + self.assertEqual(parsed_dict, {"b": 2}) + + def test_get_tool_calls(self): + DummyResponse = type("DummyResponse", (), {})() + DummyResponse.additional_kwargs = {"tool_calls": [{"id": "1"}]} + calls = self.fallback_handler.get_tool_calls(DummyResponse) + self.assertEqual(calls, [{"id": "1"}]) + DummyResponse2 = type("DummyResponse2", (), {"tool_calls": [{"id": "2"}]})() + calls2 = self.fallback_handler.get_tool_calls(DummyResponse2) + self.assertEqual(calls2, [{"id": "2"}]) + dummy_dict = {"additional_kwargs": {"tool_calls": [{"id": "3"}]}} + calls3 = self.fallback_handler.get_tool_calls(dummy_dict) + self.assertEqual(calls3, [{"id": "3"}]) + + def test_handle_failure_response(self): + from ra_aid.exceptions import ToolExecutionError + def dummy_handle_failure(error, agent): + return ["fallback_response"] + self.fallback_handler.handle_failure = dummy_handle_failure + response = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React") + from ra_aid.fallback_handler import SystemMessage + 
self.assertTrue(all(isinstance(m, SystemMessage) for m in response)) + response_non = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other") + self.assertIsNone(response_non) + if __name__ == "__main__": unittest.main() From 5733eb06f55abd4ecb8504e634c1d30c9c0998fe Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 19:14:06 -0800 Subject: [PATCH 27/45] fix(test_fallback_handler.py): update lambda function to accept default argument for args to prevent potential errors during invocation --- tests/ra_aid/test_fallback_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 1a87608..1558498 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -190,7 +190,7 @@ class TestFallbackHandler(unittest.TestCase): (), { "func": type("DummyToolFunc", (), {"__name__": "dummy_tool"})(), - "invoke": lambda self, args: "tool_result", + "invoke": lambda self, args=None: "tool_result", }, ) ) From 9caa46bc78d9440ca856916f5e0ebf5b7199df47 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 19:16:13 -0800 Subject: [PATCH 28/45] refactor(agent_utils.py): remove redundant agent_type parameter from _handle_fallback_response and run_agent_with_retry functions to simplify function signatures feat(agent_utils.py): retrieve agent_type within _handle_fallback_response to maintain functionality while improving code clarity --- ra_aid/agent_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 5aff588..74ea2d7 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -875,7 +875,6 @@ def _handle_fallback_response( error: ToolExecutionError, fallback_handler: Optional[FallbackHandler], agent: RAgents, - agent_type: str, msg_list: list, ) -> None: """ @@ -884,6 +883,7 @@ def _handle_fallback_response( if not fallback_handler: return fallback_response = fallback_handler.handle_failure(error, agent) + agent_type = get_agent_type(agent) if fallback_response and agent_type == "React": msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response] msg_list.extend(msg_list_response) @@ -915,7 +915,6 @@ def run_agent_with_retry( _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt - agent_type = get_agent_type(agent) msg_list = [HumanMessage(content=prompt)] with InterruptibleSection(): @@ -941,9 +940,7 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - _handle_fallback_response( - e, fallback_handler, agent, agent_type, msg_list - ) + _handle_fallback_response(e, fallback_handler, agent, msg_list) continue except FallbackToolExecutionError as e: msg_list.append( From 15ce534f8f9fff818122f8db4e85c6f0ce11247d Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 20:09:21 -0800 Subject: [PATCH 29/45] feat(agent_utils.py): add debug print for config to assist in troubleshooting feat(ciayn_agent.py): pass config to CiaynAgent for improved functionality fix(ciayn_agent.py): handle tool execution errors more gracefully with msg_list feat(fallback_handler.py): enhance handle_failure method to utilize msg_list for better context feat(fallback_handler.py): implement init_msg_list to 
manage message history effectively test(test_fallback_handler.py): add unit tests for init_msg_list to ensure correct behavior --- ra_aid/agent_utils.py | 5 +-- ra_aid/agents/ciayn_agent.py | 20 ++++++++---- ra_aid/fallback_handler.py | 46 +++++++++++++++++++++------ tests/ra_aid/test_fallback_handler.py | 16 +++++++++- 4 files changed, 69 insertions(+), 18 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 74ea2d7..5f35a9a 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -270,6 +270,7 @@ def create_agent( """ try: config = _global_memory.get("config", {}) + print(f"config={config}") max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) @@ -281,7 +282,7 @@ def create_agent( return create_react_agent(model, tools, **agent_kwargs) else: logger.debug("Using CiaynAgent agent instance") - return CiaynAgent(model, tools, max_tokens=max_input_tokens) + return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config) except Exception as e: # Default to REACT agent if provider/model detection fails @@ -882,7 +883,7 @@ def _handle_fallback_response( """ if not fallback_handler: return - fallback_response = fallback_handler.handle_failure(error, agent) + fallback_response = fallback_handler.handle_failure(error, agent, msg_list) agent_type = get_agent_type(agent) if fallback_response and agent_type == "React": msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response] diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index de38591..14caba3 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -139,18 +139,24 @@ class CiaynAgent: try: code = code.strip() + if code.startswith("```"): + code = code[3:].strip() + if code.endswith("```"): + code = code[:-3].strip() + + raise ToolExecutionError( + "err", base_message=msg, tool_name="ripgrep_search" + ) + logger.debug(f"_execute_tool: stripped code: {code}") # if the eval fails, try to extract it via a model call if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) code = self._extract_tool_call(code, functions_list) + pass - logger.debug( - f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" - ) result = eval(code.strip(), globals_dict) - logger.debug(f"_execute_tool: result: {result}") return result except Exception as e: error_msg = f"Error: {str(e)} \n Could not excute code: {code}" @@ -230,6 +236,7 @@ class CiaynAgent: raise ToolExecutionError("Failed to extract tool call") ma = matches[0][0].strip() mb = matches[0][1].strip().replace("\n", " ") + logger.debug(f"Extracted tool call: {ma}({mb})") return f"{ma}({mb})" def _trim_chat_history( @@ -284,13 +291,14 @@ class CiaynAgent: response = self.model.invoke([self.sys_message] + full_history) try: - # logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response) self.chat_history.append(response) self.fallback_handler.reset_fallback_handler() yield {} except ToolExecutionError as e: - fallback_response = self.fallback_handler.handle_failure(e, self) + fallback_response = self.fallback_handler.handle_failure( + e, self, self.chat_history + ) last_result = self.handle_fallback_response(fallback_response, e) yield {} diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 3d61098..63dc742 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -49,6 +49,7 @@ class FallbackHandler: self.failed_messages: 
list[BaseMessage] = [] self.current_failing_tool_name = "" self.current_tool_to_bind: None | BaseTool = None + self.msg_list: list[BaseMessage] = [] cpm( "Fallback models selected: " @@ -100,7 +101,9 @@ class FallbackHandler: ) return final_models - def handle_failure(self, error: ToolExecutionError, agent: RAgents): + def handle_failure( + self, error: ToolExecutionError, agent: RAgents, msg_list: list[BaseMessage] + ): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -111,6 +114,9 @@ class FallbackHandler: if not self.fallback_enabled: return None + if self.tool_failure_consecutive_failures == 0: + self.init_msg_list(msg_list) + failed_tool_call_name = self.extract_failed_tool_name(error) self._reset_on_new_failure(failed_tool_call_name) @@ -177,6 +183,7 @@ class FallbackHandler: self.fallback_tool_models = self._load_fallback_tool_models(self.config) self.current_failing_tool_name = "" self.current_tool_to_bind = None + self.msg_list = [] def _reset_on_new_failure(self, failed_tool_call_name): if ( @@ -296,22 +303,29 @@ class FallbackHandler: Returns: list: A list of chat messages. """ - msg_list: list[BaseMessage] = [] - msg_list.append( + prompt_msg_list: list[BaseMessage] = [] + prompt_msg_list.append( SystemMessage( content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages." ) ) + + # TODO: Have some way to use the correct message type in the future, dont just convert everything to system message. + # This may be difficult as each model type may require different chat structures and throw API errors. + prompt_msg_list.extend(SystemMessage(str(msg)) for msg in self.msg_list) + if self.failed_messages: # Convert to system messages to avoid API errors asking for correct msg structure - msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages]) + prompt_msg_list.extend( + [SystemMessage(str(msg)) for msg in self.failed_messages] + ) - msg_list.append( + prompt_msg_list.append( HumanMessage( - content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments." + content=f"Retry using the tool: '{self.current_failing_tool_name}' with correct arguments and formatting." ) ) - return msg_list + return prompt_msg_list def invoke_prompt_tool_call(self, tool_call_request: dict): """ @@ -323,12 +337,17 @@ class FallbackHandler: Returns: The result of invoking the tool. 
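         Raises:
             Exception: If the tool name matches neither the available tools nor the currently bound fallback tool.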
""" - tool_name_to_tool = {getattr(tool.func, "__name__", None): tool for tool in self.tools} + tool_name_to_tool = { + getattr(tool.func, "__name__", None): tool for tool in self.tools + } name = tool_call_request["name"] arguments = tool_call_request["arguments"] if name in tool_name_to_tool: return tool_name_to_tool[name].invoke(arguments) - elif self.current_tool_to_bind is not None and getattr(self.current_tool_to_bind.func, "__name__", None) == name: + elif ( + self.current_tool_to_bind is not None + and getattr(self.current_tool_to_bind.func, "__name__", None) == name + ): return self.current_tool_to_bind.invoke(arguments) else: raise Exception(f"Tool '{name}' not found in available tools.") @@ -406,3 +425,12 @@ class FallbackHandler: if fallback_response and agent_type == "React": return [SystemMessage(str(msg)) for msg in fallback_response] return None + + def init_msg_list(self, full_msg_list: list[BaseMessage]) -> None: + first_two = full_msg_list[:2] + last_two = full_msg_list[-2:] + merged = first_two.copy() + for msg in last_two: + if msg not in merged: + merged.append(msg) + self.msg_list = merged diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 1558498..3273391 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -45,7 +45,7 @@ class TestFallbackHandler(unittest.TestCase): error_obj = ToolExecutionError( "Test error", base_message="dummy_call()", tool_name="dummy_tool" ) - self.fallback_handler.handle_failure(error_obj, self.agent) + self.fallback_handler.handle_failure(error_obj, self.agent, []) self.assertEqual( self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1, @@ -295,5 +295,19 @@ class TestFallbackHandler(unittest.TestCase): self.assertIsNone(response_non) + def test_init_msg_list_non_overlapping(self): + # Test when the first two and last two messages do not overlap. + full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"] + self.fallback_handler.init_msg_list(full_list) + # Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5") + self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]) + + def test_init_msg_list_with_overlap(self): + # Test when the last two messages overlap with the first two. + full_list = ["msg1", "msg2", "msg1", "msg3"] + self.fallback_handler.init_msg_list(full_list) + # Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present. 
+ self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"]) + if __name__ == "__main__": unittest.main() From efec91579ab1847e1dc9bc44f46b319c65bd039d Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 20:09:47 -0800 Subject: [PATCH 30/45] refactor(ciayn_agent.py): remove unnecessary raise statement and debug log to clean up code --- ra_aid/agents/ciayn_agent.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 14caba3..077d01c 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -144,11 +144,7 @@ class CiaynAgent: if code.endswith("```"): code = code[:-3].strip() - raise ToolExecutionError( - "err", base_message=msg, tool_name="ripgrep_search" - ) - - logger.debug(f"_execute_tool: stripped code: {code}") + # logger.debug(f"_execute_tool: stripped code: {code}") # if the eval fails, try to extract it via a model call if validate_function_call_pattern(code): From 90b3070aa2d721c775db554a53b4b23bbe3e6562 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 20:14:07 -0800 Subject: [PATCH 31/45] chore(agent_utils.py): remove debug print statement for config to clean up output chore(fallback_handler.py): comment out cpm call for tool call result to reduce logging noise --- ra_aid/agent_utils.py | 1 - ra_aid/fallback_handler.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 5f35a9a..dd46fe8 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -270,7 +270,6 @@ def create_agent( """ try: config = _global_memory.get("config", {}) - print(f"config={config}") max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 63dc742..fdf45c6 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -274,11 +274,10 @@ class FallbackHandler: msg_list = self.construct_prompt_msg_list() response = retry_model.invoke(msg_list) - logger.debug(f"raw llm response={response}") tool_call = self.base_message_to_tool_call_dict(response) tool_call_result = self.invoke_prompt_tool_call(tool_call) - cpm(str(tool_call_result), title="Fallback Tool Call Result") + # cpm(str(tool_call_result), title="Fallback Tool Call Result") logger.debug( f"Fallback call successful with model: {self._format_model(fallback_model)}" ) From 7a2c7668245ba353f772fbd94d190b85ceb68553 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 20:18:02 -0800 Subject: [PATCH 32/45] test(agent_utils): add config parameter to mock_ciayn assertions for better clarity and flexibility in tests --- tests/ra_aid/test_agent_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 726c7d7..082c140 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -128,7 +128,9 @@ def test_create_agent_openai(mock_model, mock_memory): assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( - mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"] + mock_model, [], + max_tokens=models_params["openai"]["gpt-4"]["token_limit"], + config={'provider': 'openai', 'model': 'gpt-4'} ) @@ -142,7 +144,9 @@ def test_create_agent_no_token_limit(mock_model, mock_memory): assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( - 
mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT + mock_model, [], + max_tokens=DEFAULT_TOKEN_LIMIT, + config={'provider': 'unknown', 'model': 'unknown-model'} ) @@ -159,6 +163,7 @@ def test_create_agent_missing_config(mock_model, mock_memory): mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT, + config={'provider': 'openai'} ) @@ -202,7 +207,9 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory): assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( - mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"] + mock_model, [], + max_tokens=models_params["openai"]["gpt-4"]["token_limit"], + config={'provider': 'openai', 'model': 'gpt-4'} ) From d5e2e0a9a0434bba9c8fbb88c42ac117ea7333af Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 13:04:41 -0800 Subject: [PATCH 33/45] refactor(agent_utils.py, ciayn_agent.py): remove unused import cpm to clean up code and improve readability style(tests): format code for better readability and consistency in test files test(tests): update assertions and test cases for better clarity and maintainability --- ra_aid/agent_utils.py | 2 +- ra_aid/agents/ciayn_agent.py | 1 - tests/ra_aid/test_agent_utils.py | 19 ++-- tests/ra_aid/test_ciayn_agent.py | 1 - tests/ra_aid/test_fallback_handler.py | 135 +++++++++++++++++++------- tests/ra_aid/test_llm.py | 9 +- 6 files changed, 119 insertions(+), 48 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index dd46fe8..8b1708b 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -31,7 +31,7 @@ from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.agents_alias import RAgents from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.console.formatting import print_error, print_stage_header -from ra_aid.console.output import cpm, print_agent_output +from ra_aid.console.output import print_agent_output from ra_aid.exceptions import ( AgentInterrupt, FallbackToolExecutionError, diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 077d01c..18907ce 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -7,7 +7,6 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System from langchain_core.tools import BaseTool from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES -from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 082c140..5346921 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -6,7 +6,7 @@ from unittest.mock import Mock, patch import litellm import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from ra_aid.agent_utils import ( AgentState, @@ -128,9 +128,10 @@ def test_create_agent_openai(mock_model, mock_memory): assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( - mock_model, [], + mock_model, + [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"], - config={'provider': 'openai', 'model': 'gpt-4'} + config={"provider": "openai", "model": "gpt-4"}, ) @@ -144,9 +145,10 @@ def test_create_agent_no_token_limit(mock_model, mock_memory): assert agent == "ciayn_agent" 
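    # The CiaynAgent constructor is now expected to receive the provider/model
    # config dict alongside the tools list and token limit.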
mock_ciayn.assert_called_once_with( - mock_model, [], + mock_model, + [], max_tokens=DEFAULT_TOKEN_LIMIT, - config={'provider': 'unknown', 'model': 'unknown-model'} + config={"provider": "unknown", "model": "unknown-model"}, ) @@ -163,7 +165,7 @@ def test_create_agent_missing_config(mock_model, mock_memory): mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT, - config={'provider': 'openai'} + config={"provider": "openai"}, ) @@ -207,9 +209,10 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory): assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( - mock_model, [], + mock_model, + [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"], - config={'provider': 'openai', 'model': 'gpt-4'} + config={"provider": "openai", "model": "gpt-4"}, ) diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py index 896db33..4cd9dc2 100644 --- a/tests/ra_aid/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -26,7 +26,6 @@ class DummyTool: class DummyModel: def invoke(self, _messages: list[BaseMessage]): - return AIMessage("dummy_tool()") def bind_tools(self, tools, tool_choice): diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 3273391..6fa325d 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -86,6 +86,7 @@ class TestFallbackHandler(unittest.TestCase): def test_load_fallback_tool_models(self): import ra_aid.fallback_handler as fh + original_supported = fh.supported_top_tool_models fh.supported_top_tool_models = [ {"provider": "dummy", "model": "dummy_model", "type": "prompt"} @@ -95,13 +96,16 @@ class TestFallbackHandler(unittest.TestCase): fh.supported_top_tool_models = original_supported def test_extract_failed_tool_name(self): - from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError + from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError + # Case when tool_name is provided - error1 = ToolExecutionError("Error", base_message="dummy", tool_name="dummy_tool") + error1 = ToolExecutionError( + "Error", base_message="dummy", tool_name="dummy_tool" + ) name1 = self.fallback_handler.extract_failed_tool_name(error1) self.assertEqual(name1, "dummy_tool") # Case when tool_name is not provided but regex works - error2 = ToolExecutionError("error with name=\"test_tool\"") + error2 = ToolExecutionError('error with name="test_tool"') name2 = self.fallback_handler.extract_failed_tool_name(error2) self.assertEqual(name2, "test_tool") # Case when regex fails and exception is raised @@ -110,16 +114,13 @@ class TestFallbackHandler(unittest.TestCase): self.fallback_handler.extract_failed_tool_name(error3) def test_find_tool_to_bind(self): - # Create a dummy tool to be found - class DummyTool: - def invoke(self, args): - return "result" class DummyWrapper: def __init__(self, func): self.func = func - def dummy_func(args): + + def dummy_func(_args): return "result" - dummy_tool = DummyTool() + dummy_wrapper = DummyWrapper(dummy_func) self.agent.tools.append(dummy_wrapper) tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__) @@ -134,53 +135,82 @@ class TestFallbackHandler(unittest.TestCase): self.tools = tools self.tool_choice = tool_choice return self + def with_retry(self, stop_after_attempt): return self + def invoke(self, msg_list): return "dummy_response" + dummy_model = DummyModel() + # Set current tool for binding class DummyTool: def invoke(self, args): return "result" + 
self.fallback_handler.current_tool_to_bind = DummyTool() self.fallback_handler.current_failing_tool_name = "test_tool" # Test with force calling ("fc") type fallback_model_fc = {"type": "fc"} - bound_model_fc = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_fc) + bound_model_fc = self.fallback_handler._bind_tool_model( + dummy_model, fallback_model_fc + ) self.assertTrue(hasattr(bound_model_fc, "tool_choice")) self.assertEqual(bound_model_fc.tool_choice, "test_tool") # Test with prompt type fallback_model_prompt = {"type": "prompt"} - bound_model_prompt = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_prompt) + bound_model_prompt = self.fallback_handler._bind_tool_model( + dummy_model, fallback_model_prompt + ) self.assertTrue(bound_model_prompt.tool_choice is None) def test_invoke_fallback(self): - from unittest.mock import patch import os - import ra_aid.llm as llm + from unittest.mock import patch # Successful fallback scenario with proper API key set - with patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), \ - patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \ - patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), \ - patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm: + with ( + patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), + patch( + "ra_aid.fallback_handler.supported_top_tool_models", + new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}], + ), + patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), + patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm, + ): + class DummyModel: def bind_tools(self, tools, tool_choice=None): return self + def with_retry(self, stop_after_attempt): return self + def invoke(self, msg_list): return DummyResponse() + class DummyResponse: - additional_kwargs = {"tool_calls": [{"id": "1", "type": "test", "function": {"name": "dummy_tool", "arguments": "{\"a\":1}"}}]} + additional_kwargs = { + "tool_calls": [ + { + "id": "1", + "type": "test", + "function": {"name": "dummy_tool", "arguments": '{"a":1}'}, + } + ] + } + def dummy_initialize_llm(provider, model_name): return DummyModel() + mock_init_llm.side_effect = dummy_initialize_llm + # Set current tool for fallback class DummyTool: def invoke(self, args): return "tool_result" + self.fallback_handler.current_tool_to_bind = DummyTool() self.fallback_handler.current_failing_tool_name = "dummy_tool" # Add dummy tool for lookup in invoke_prompt_tool_call @@ -194,39 +224,57 @@ class TestFallbackHandler(unittest.TestCase): }, ) ) - result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"}) + result = self.fallback_handler.invoke_fallback( + {"provider": "dummy", "model": "dummy_model", "type": "prompt"} + ) self.assertIsInstance(result, list) self.assertEqual(result[1], "tool_result") # Failed fallback scenario due to missing API key (simulate by empty environment) - with patch.dict(os.environ, {}, clear=True), \ - patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \ - patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), \ - patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm: + with ( + patch.dict(os.environ, {}, clear=True), + patch( + "ra_aid.fallback_handler.supported_top_tool_models", + new=[{"provider": "dummy", 
"model": "dummy_model", "type": "prompt"}], + ), + patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), + patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm, + ): + class FailingDummyModel: def bind_tools(self, tools, tool_choice=None): return self + def with_retry(self, stop_after_attempt): return self + def invoke(self, msg_list): raise Exception("API key missing") + def failing_initialize_llm(provider, model_name): return FailingDummyModel() + mock_init_llm.side_effect = failing_initialize_llm - fallback_result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"}) + fallback_result = self.fallback_handler.invoke_fallback( + {"provider": "dummy", "model": "dummy_model", "type": "prompt"} + ) self.assertIsNone(fallback_result) # Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail # Set failure count to trigger the fallback attempt in attempt_fallback from ra_aid.exceptions import FallbackToolExecutionError - self.fallback_handler.tool_failure_consecutive_failures = self.fallback_handler.max_failures + + self.fallback_handler.tool_failure_consecutive_failures = ( + self.fallback_handler.max_failures + ) with self.assertRaises(FallbackToolExecutionError) as cm: self.fallback_handler.attempt_fallback() self.assertIn("All fallback models have failed", str(cm.exception)) def test_construct_prompt_msg_list(self): msgs = self.fallback_handler.construct_prompt_msg_list() - from ra_aid.fallback_handler import SystemMessage, HumanMessage + from ra_aid.fallback_handler import HumanMessage, SystemMessage + self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs)) self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs)) # Test with failed_messages added @@ -238,13 +286,17 @@ class TestFallbackHandler(unittest.TestCase): # Create dummy tool function def dummy_tool_func(args): return "invoked_result" + dummy_tool_func.__name__ = "dummy_tool" + # Create wrapper class class DummyToolWrapper: def __init__(self, func): self.func = func + def invoke(self, args): return self.func(args) + dummy_wrapper = DummyToolWrapper(dummy_tool_func) self.fallback_handler.tools = [dummy_wrapper] tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}} @@ -255,9 +307,13 @@ class TestFallbackHandler(unittest.TestCase): dummy_tool_call = { "id": "123", "type": "test", - "function": {"name": "dummy_tool", "arguments": "{\"x\":42}"} + "function": {"name": "dummy_tool", "arguments": '{"x":42}'}, } - DummyResponse = type("DummyResponse", (), {"additional_kwargs": {"tool_calls": [dummy_tool_call]}}) + DummyResponse = type( + "DummyResponse", + (), + {"additional_kwargs": {"tool_calls": [dummy_tool_call]}}, + ) result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse) self.assertEqual(result["id"], "123") self.assertEqual(result["name"], "dummy_tool") @@ -285,22 +341,30 @@ class TestFallbackHandler(unittest.TestCase): def test_handle_failure_response(self): from ra_aid.exceptions import ToolExecutionError + def dummy_handle_failure(error, agent): return ["fallback_response"] - self.fallback_handler.handle_failure = dummy_handle_failure - response = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React") - from ra_aid.fallback_handler import SystemMessage - self.assertTrue(all(isinstance(m, SystemMessage) for m in response)) - response_non = 
self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other")
- self.assertIsNone(response_non)
+ self.fallback_handler.handle_failure = dummy_handle_failure
+ response = self.fallback_handler.handle_failure_response(
+ ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
+ )
+ from ra_aid.fallback_handler import SystemMessage
+
+ self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
+ response_non = self.fallback_handler.handle_failure_response(
+ ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
+ )
+ self.assertIsNone(response_non)

 def test_init_msg_list_non_overlapping(self):
 # Test when the first two and last two messages do not overlap.
 full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
 self.fallback_handler.init_msg_list(full_list)
 # Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
- self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"])
+ self.assertEqual(
+ self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
+ )

 def test_init_msg_list_with_overlap(self):
 # Test when the last two messages overlap with the first two.
@@ -309,5 +373,6 @@
 # Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
 self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])

+
 if __name__ == "__main__":
 unittest.main()

diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index 7f96ed4..4ea9415 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -121,7 +121,9 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch

 def test_initialize_expert_unsupported_provider(clean_env):
 """Test error handling for unsupported provider in expert mode."""
- with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
+ with pytest.raises(
+ ValueError, match=r"Missing required environment variable for provider: unknown"
+ ):
 initialize_expert_llm("unknown", "model")

@@ -197,7 +199,10 @@ def test_initialize_unsupported_provider(clean_env):
 """Test initialization with unsupported provider raises ValueError"""
 with pytest.raises(ValueError) as exc_info:
 initialize_llm("unsupported", "model")
- assert str(exc_info.value) == "Missing required environment variable for provider: unsupported"
+ assert (
+ str(exc_info.value)
+ == "Missing required environment variable for provider: unsupported"
+ )

 def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):

From 0df5d4333350de26295bc96c1505f22dcaec7292 Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Fri, 14 Feb 2025 13:50:32 -0800
Subject: [PATCH 34/45] feat(main.py): import models_params and set default temperature for models that support it to improve user experience

fix(ciayn_agent.py): update fallback tool error messages to use FallbackToolExecutionError for better error handling
fix(config.py): add a blank line to maintain code style consistency
fix(fallback_handler.py): raise FallbackToolExecutionError for better error clarity when tools are not found
fix(llm.py): set default temperature to 0.7 and notify user when not provided for models that support it
test(test_llm.py): update tests to check for default temperature behavior and improve error messages for unsupported providers

---
 ra_aid/__main__.py | 10 ++++++----
 ra_aid/agents/ciayn_agent.py | 6 ++----
 ra_aid/config.py | 1 +
 ra_aid/fallback_handler.py | 4 ++--
 ra_aid/llm.py | 6 ++++--
 tests/ra_aid/test_llm.py | 34 ++++++++++++++++++++++++++--------
 6 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 2a15872..46bd819 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -8,6 +8,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
+from ra_aid.models_params import models_params

 from ra_aid import print_error, print_stage_header
 from ra_aid.__version__ import __version__
@@ -23,10 +24,12 @@ from ra_aid.config import (
 DEFAULT_RECURSION_LIMIT,
 VALID_PROVIDERS,
 )
+from ra_aid.console.output import cpm
 from ra_aid.dependencies import check_dependencies
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm
 from ra_aid.logging_config import get_logger, setup_logging
+from ra_aid.models_params import DEFAULT_TEMPERATURE
 from ra_aid.project_info import format_project_info, get_project_info
 from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
 from ra_aid.tool_configs import get_chat_tools
@@ -309,7 +312,6 @@ def main():
 logger.debug("Environment validation successful")

 # Validate model configuration early
- from ra_aid.models_params import models_params

 model_config = models_params.get(args.provider, {}).get(args.model or "", {})
 supports_temperature = model_config.get(
@@ -321,10 +323,10 @@ def main():
 if supports_temperature and args.temperature is None:
 args.temperature = model_config.get("default_temperature")
 if args.temperature is None:
- print_error(
- f"Temperature must be provided for model {args.model} which supports temperature"
+ cpm(
+ f"This model supports a temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
 )
- sys.exit(1)
+ args.temperature = DEFAULT_TEMPERATURE
 logger.debug(
 f"Using default temperature {args.temperature} for model {args.model}"
 )

diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
index 18907ce..ca5a8db 100644
--- a/ra_aid/agents/ciayn_agent.py
+++ b/ra_aid/agents/ciayn_agent.py
@@ -143,8 +143,6 @@ class CiaynAgent:
 if code.endswith("```"):
 code = code[:-3].strip()

- # logger.debug(f"_execute_tool: stripped code: {code}")
-
 # if the eval fails, try to extract it via a model call
 if validate_function_call_pattern(code):
 functions_list = "\n\n".join(self.available_functions)
@@ -176,8 +174,8 @@ class CiaynAgent:
 msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
 # Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
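 # fallback_response is expected to hold the raw fallback invocation at index 0 and the tool result at index 1, so only the tool name and its result are surfaced to the model below.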
# msg += f"{fallback_response[0]}\n" - msg += f"{e.tool_name}" - msg += f"{fallback_response[1]}" + msg += f"{e.tool_name}\n" + msg += f"\n{fallback_response[1]}\n\n" return msg def _create_agent_chunk(self, content: str) -> Dict[str, Any]: diff --git a/ra_aid/config.py b/ra_aid/config.py index 2f5eab0..9393ba0 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 + VALID_PROVIDERS = [ "anthropic", "openai", diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index fdf45c6..1958a00 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -349,7 +349,7 @@ class FallbackHandler: ): return self.current_tool_to_bind.invoke(arguments) else: - raise Exception(f"Tool '{name}' not found in available tools.") + raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.") def base_message_to_tool_call_dict(self, response: BaseMessage): """ @@ -365,7 +365,7 @@ class FallbackHandler: tool_calls = self.get_tool_calls(response) if not tool_calls: - raise Exception( + raise FallbackToolExecutionError( f"Could not extract tool_call_dict from response: {response}" ) diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 4558beb..506028c 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -9,6 +9,7 @@ from langchain_openai import ChatOpenAI from openai import OpenAI from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner +from ra_aid.console.output import cpm from ra_aid.logging_config import get_logger from .models_params import models_params @@ -228,8 +229,9 @@ def create_llm_client( temp_kwargs = {"temperature": 0} if supports_temperature else {} elif supports_temperature: if temperature is None: - raise ValueError( - f"Temperature must be provided for model {model_name} which supports temperature" + temperature = 0.7 + cpm( + "This model supports temperature argument but none was given. Setting default temperature to 0.7." 
) temp_kwargs = {"temperature": temperature} else: diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 8314367..853be14 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -237,7 +237,7 @@ def test_initialize_openai_compatible(clean_env, mock_openai): def test_initialize_unsupported_provider(clean_env): """Test initialization with unsupported provider raises ValueError""" - with pytest.raises(ValueError, match=r"Unsupported provider: unknown"): + with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"): initialize_llm("unknown", "model") @@ -259,15 +259,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin max_retries=5, ) - # Test error when no temperature provided for models that support it - with pytest.raises(ValueError, match="Temperature must be provided for model"): - initialize_llm("openai", "test-model") + # Test default temperature when none is provided for models that support it + initialize_llm("openai", "test-model") + mock_openai.assert_called_with( + api_key="test-key", + model="test-model", + temperature=0.7, + timeout=180, + max_retries=5, + ) - with pytest.raises(ValueError, match="Temperature must be provided for model"): - initialize_llm("anthropic", "test-model") + initialize_llm("anthropic", "test-model") + mock_anthropic.assert_called_with( + api_key="test-key", + model_name="test-model", + temperature=0.7, + timeout=180, + max_retries=5, + ) - with pytest.raises(ValueError, match="Temperature must be provided for model"): - initialize_llm("gemini", "test-model") + initialize_llm("gemini", "test-model") + mock_gemini.assert_called_with( + api_key="test-key", + model="test-model", + temperature=0.7, + timeout=180, + max_retries=5, + ) # Test expert models don't require temperature initialize_expert_llm("openai", "o1") From f3a5ce6d8e2cae42b4dbfa4f9278ccbc6c7b1b96 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 13:50:43 -0800 Subject: [PATCH 35/45] chore(ciayn_agent.py): remove debug logging for extracted tool call to clean up logs and reduce verbosity --- ra_aid/agents/ciayn_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index ca5a8db..a167b25 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -229,7 +229,6 @@ class CiaynAgent: raise ToolExecutionError("Failed to extract tool call") ma = matches[0][0].strip() mb = matches[0][1].strip().replace("\n", " ") - logger.debug(f"Extracted tool call: {ma}({mb})") return f"{ma}({mb})" def _trim_chat_history( From 27400d62250888a18fe8980b92f2845f263c4fb7 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 13:54:06 -0800 Subject: [PATCH 36/45] feat(ciayn_agent.py): add fallback_fixed_msg to inform users about fallback tool handling fix(ciayn_agent.py): ensure error message is logged in chat history when fallback response is empty --- ra_aid/agents/ciayn_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index a167b25..b16fff1 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -112,6 +112,9 @@ class CiaynAgent: "Execute efficiently yet completely as a fully autonomous agent." ) self.error_message_template = "Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." 
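 # error_message_template is echoed back to the model after each failed tool call so it can retry before the fallback handler engages.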
+ self.fallback_fixed_msg = HumanMessage(
+ "Fallback tool handler has fixed the tool call; see the following message for the output."
+ )

 def _build_prompt(self, last_result: Optional[str] = None) -> str:
 """Build the prompt for the agent including available tools and context."""
@@ -166,11 +169,12 @@ class CiaynAgent:
 self, fallback_response: list[Any], e: ToolExecutionError
 ) -> str:
 err_msg = HumanMessage(content=self.error_message_template.format(e=e))
- self.chat_history.append(err_msg)

 if not fallback_response:
+ self.chat_history.append(err_msg)
 return ""
+ self.chat_history.append(self.fallback_fixed_msg)
 msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
 # Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
 # msg += f"{fallback_response[0]}\n"

From 56ddd967c0612f34eb46279383dd6c16f8ab3faa Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Fri, 14 Feb 2025 13:54:36 -0800
Subject: [PATCH 37/45] fix(imports): remove redundant import of models_params in __main__.py for cleaner code

style(fallback_handler.py): format error message for better readability
style(test_llm.py): format exception assertion for better readability

---
 ra_aid/__main__.py | 3 +--
 ra_aid/fallback_handler.py | 4 +++-
 tests/ra_aid/test_llm.py | 4 +++-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 46bd819..cf787fe 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -8,7 +8,6 @@ from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
-from ra_aid.models_params import models_params

 from ra_aid import print_error, print_stage_header
 from ra_aid.__version__ import __version__
@@ -29,7 +28,7 @@ from ra_aid.dependencies import check_dependencies
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm
 from ra_aid.logging_config import get_logger, setup_logging
-from ra_aid.models_params import DEFAULT_TEMPERATURE
+from ra_aid.models_params import DEFAULT_TEMPERATURE, models_params
 from ra_aid.project_info import format_project_info, get_project_info
 from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
 from ra_aid.tool_configs import get_chat_tools

diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py
index 1958a00..3f7750c 100644
--- a/ra_aid/fallback_handler.py
+++ b/ra_aid/fallback_handler.py
@@ -349,7 +349,9 @@ class FallbackHandler:
 ):
 return self.current_tool_to_bind.invoke(arguments)
 else:
- raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.")
+ raise FallbackToolExecutionError(
+ f"Tool '{name}' not found in available tools."
+ ) def base_message_to_tool_call_dict(self, response: BaseMessage): """ diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 853be14..2193d8b 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -237,7 +237,9 @@ def test_initialize_openai_compatible(clean_env, mock_openai): def test_initialize_unsupported_provider(clean_env): """Test initialization with unsupported provider raises ValueError""" - with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"): + with pytest.raises( + ValueError, match=r"Missing required environment variable for provider: unknown" + ): initialize_llm("unknown", "model") From 9f78e7d36cce6338c9c21cc19c6cd9b6d032eda2 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:07:51 -0800 Subject: [PATCH 38/45] chore(issue.md): remove outdated issue documentation for LLM Tool Call Fallback Feature as it is no longer relevant --- issue.md | 90 -------------------------------------------------------- 1 file changed, 90 deletions(-) delete mode 100644 issue.md diff --git a/issue.md b/issue.md deleted file mode 100644 index 3bd0988..0000000 --- a/issue.md +++ /dev/null @@ -1,90 +0,0 @@ -# LLM Tool Call Fallback Feature - -## Overview -Add functionality to automatically fallback to alternative LLM models when a tool call experiences multiple consecutive failures. - -## Background -Currently, when a tool call fails due to LLM-related errors (e.g., invalid format), there is no automatic fallback mechanism. This often causes infinite loop of erroring tool calls. - -## Implementation Details - -### Configuration -- Add new configuration value `max_tool_failures` (default: 3) to track consecutive failures before triggering fallback -- Add new command line argument `--no-fallback-tool` to disable fallback behavior (enabled by default) -- **Add new command line argument** `--fallback-tool-models` to specify a comma-separated list of fallback tool models (default: "gpt-3.5-turbo,gpt-4") - This list defines the fallback model sequence used by forced tool calls (via `bind_tools`) when tool call failures occur. -- Track failure count per tool call context -- Reset failure counter on successful tool call -- Store fallback model sequence per provider -- Need to validate if ENV vars are set for provider usage of that fallback model - before usage, if that fallback ENV is not available then fallback to the next model -- Have default list of common models, first try `claude-3-5-sonnet-20241022` but - have many alternative fallback models. - -### Tool Call Wrapper -Create a new wrapper function to handle tool call execution with fallback logic: - -```python -def execute_tool_with_fallback(tool_call_func, *args, **kwargs): - failures = 0 - max_failures = get_config().max_tool_failures - - while failures < max_failures: - try: - return tool_call_func(*args, **kwargs) - except LLMError as e: - failures += 1 - if failures >= max_failures: - # Use forced tool call via bind_tools with retry: - llm_retry = llm_model.with_retry(stop_after_attempt=3) # Try three times - try_fallback_model(force=True, model=llm_retry) - # Merge fallback model chat messages back into the original chat history. - merge_fallback_chat_history() - failures = 0 # Reset counter for new model - else: - raise -``` - -The prompt passed to `try_fallback_model`, should be the failed last few failing tool calls. - -### Model Fallback Sequence -Define fallback sequences for each provider based on model capabilities: - -1. 
Try same provider's smaller models -2. Try alternative providers' similar models -3. Raise final error if all fallbacks fail - -## Risks and Mitigations -1. **Cost** - - Risk: Fallback to more expensive models - - Mitigation: Configure cost limits and preferred fallback sequences - -2. **State Management** - - Risk: Loss of context during fallbacks - - Mitigation: Preserve conversation state and tool context - -## Relevant Files -- ra_aid/agents/ciayn_agent.py -- ra_aid/llm.py -- ra_aid/agent_utils.py -- ra_aid/__main__.py -- ra_aid/models_params.py - -## Acceptance Criteria -1. Tool calls automatically attempt fallback models after N consecutive failures -2. `--no-fallback-tool` argument successfully disables fallback behavior -3. Fallback sequence respects provider and model capabilities -4. Original error is preserved if all fallbacks fail -5. Unit tests cover fallback scenarios and edge cases -6. README.md updated to reflect new behavior - -## Documentation Updates -1. Add fallback feature to main README -2. Document `--no-fallback-tool` in CLI help -3. Document provider-specific fallback sequences - -## Future Considerations -1. Allow custom fallback sequences via configuration -2. Add monitoring and alerting for fallback frequency -3. Optimize fallback selection based on historical success rates -4. Cost-aware fallback routing From f65918cfd39f7eb446af82c9c6860a09495aaa13 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:11:33 -0800 Subject: [PATCH 39/45] refactor(ciayn_agent.py): remove unused variables tool_failure_current_provider and tool_failure_current_model to clean up the code and improve readability --- ra_aid/agents/ciayn_agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index b16fff1..8d1393f 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -105,8 +105,6 @@ class CiaynAgent: for t in tools: self.available_functions.append(get_function_info(t.func)) - self.tool_failure_current_provider = None - self.tool_failure_current_model = None self.fallback_handler = FallbackHandler(config, tools) self.sys_message = SystemMessage( "Execute efficiently yet completely as a fully autonomous agent." From 69281c31db0ad8e07fc7fb763a277e576fa304c7 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:20:11 -0800 Subject: [PATCH 40/45] chore(llm.py): remove unused merge_chat_history function to clean up codebase chore(test_fallback_handler.py): remove references to the removed merge_chat_history function in tests to maintain consistency --- ra_aid/llm.py | 19 ------------------- tests/ra_aid/test_fallback_handler.py | 6 ------ 2 files changed, 25 deletions(-) diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 506028c..a505638 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -322,22 +322,3 @@ def validate_provider_env(provider: str) -> bool: if key: return bool(os.getenv(key)) return False - - -def merge_chat_history( - original_history: List[BaseMessage], fallback_history: List[BaseMessage] -) -> List[BaseMessage]: - """Merge original and fallback chat histories while preserving order. - - Args: - original_history: The original chat message history - fallback_history: Additional messages from fallback attempts - - Returns: - List[BaseMessage]: Combined message history preserving chronological order - - Note: - The function appends fallback messages to maintain context for future - interactions while preserving the original conversation flow. 
- """ - return original_history + fallback_history diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 6fa325d..47f9ebc 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -60,19 +60,14 @@ class TestFallbackHandler(unittest.TestCase): return DummyModel() - def dummy_merge_chat_history(): - return ["merged"] - def dummy_validate_provider_env(provider): return True import ra_aid.llm as llm original_initialize = llm.initialize_llm - original_merge = llm.merge_chat_history original_validate = llm.validate_provider_env llm.initialize_llm = dummy_initialize_llm - llm.merge_chat_history = dummy_merge_chat_history llm.validate_provider_env = dummy_validate_provider_env self.fallback_handler.tool_failure_consecutive_failures = 2 @@ -81,7 +76,6 @@ class TestFallbackHandler(unittest.TestCase): self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) llm.initialize_llm = original_initialize - llm.merge_chat_history = original_merge llm.validate_provider_env = original_validate def test_load_fallback_tool_models(self): From 7c828053d33f233be08a9b7332ed984fc8c08ae0 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:40:46 -0800 Subject: [PATCH 41/45] chore(main.py): remove deprecated fallback tool arguments to simplify configuration and reduce complexity --- ra_aid/__main__.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index cf787fe..caa9f10 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -153,22 +153,11 @@ Examples: action="store_false", help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.", ) - parser.add_argument( - "--no-fallback-tool", - action="store_true", - help="Disable fallback model switching.", - ) parser.add_argument( "--experimental-fallback-handler", action="store_true", help="Enable experimental fallback handler.", ) - parser.add_argument( - "--fallback-tool-models", - type=str, - default="gpt-3.5-turbo,gpt-4", - help="Comma-separated list of fallback models to use in order.", - ) parser.add_argument( "--recursion-limit", type=int, @@ -471,16 +460,12 @@ def main(): ) _global_memory["config"]["planner_model"] = args.planner_model or args.model - _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool - # Store research config with fallback to base values _global_memory["config"]["research_provider"] = ( args.research_provider or args.provider ) _global_memory["config"]["research_model"] = args.research_model or args.model - # Store fallback tool configuration - _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool # Store temperature in global config _global_memory["config"]["temperature"] = args.temperature From 119afd8600add41d462c121a36c4b92e7e2a5be6 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:41:42 -0800 Subject: [PATCH 42/45] refactor(tool_configs.py): remove unused get_all_tools_simple function to clean up code and improve maintainability --- ra_aid/tool_configs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index 729b886..ba657cb 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -64,11 +64,6 @@ def get_read_only_tools( return tools -def get_all_tools_simple(): - """Return a list containing all available tools using existing group methods.""" - return 
get_all_tools() - - def get_all_tools() -> list[BaseTool]: """Return a list containing all available tools from different groups.""" all_tools = [] From 4a2a0b691cf351894112567baf7bfac58cf73fc3 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:42:56 -0800 Subject: [PATCH 43/45] chore(llm.py): remove unused import of BaseMessage to clean up code and improve readability --- ra_aid/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ra_aid/llm.py b/ra_aid/llm.py index a505638..32d842d 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage from langchain_google_genai import ChatGoogleGenerativeAI from langchain_openai import ChatOpenAI from openai import OpenAI From 81354df48b2ff32998bdc8b5421acf044223fe92 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Fri, 14 Feb 2025 14:45:33 -0800 Subject: [PATCH 44/45] fix(ciayn_agent.py): correct spelling of "execute" in error message for clarity refactor(ciayn_agent.py): improve error handling by chaining exceptions for better debugging docs(output.py): update docstring to include agent_type parameter for clarity on agent behavior --- ra_aid/agents/ciayn_agent.py | 6 ++++-- ra_aid/console/output.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 8d1393f..af93458 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -153,9 +153,11 @@ class CiaynAgent: result = eval(code.strip(), globals_dict) return result except Exception as e: - error_msg = f"Error: {str(e)} \n Could not excute code: {code}" + error_msg = f"Error: {str(e)} \n Could not execute code: {code}" tool_name = self.extract_tool_name(code) - raise ToolExecutionError(error_msg, base_message=msg, tool_name=tool_name) + raise ToolExecutionError( + error_msg, base_message=msg, tool_name=tool_name + ) from e def extract_tool_name(self, code: str) -> str: match = re.match(r"\s*([\w_\-]+)\s*\(", code) diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 3dc9f1d..dfba0a8 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -16,7 +16,8 @@ def print_agent_output( """Print only the agent's message content, not tool calls. Args: - chunk: A dictionary containing agent or tool messages + chunk: A dictionary containing agent or tool messages. + agent_type: Specifies the type of agent. 'CiaynAgent' handles tool errors internally, while 'React' raises a ToolExecutionError. 
""" if "agent" in chunk and "messages" in chunk["agent"]: messages = chunk["agent"]["messages"] From 9ff09b9b93feb12186d7dfc22735f8beabf8970c Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Mon, 17 Feb 2025 10:45:49 -0800 Subject: [PATCH 45/45] refactor(shell.py): rename expected_runtime_seconds parameter to timeout for clarity and consistency in function signature --- ra_aid/tools/shell.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ra_aid/tools/shell.py b/ra_aid/tools/shell.py index 3065d13..6bcaa3f 100644 --- a/ra_aid/tools/shell.py +++ b/ra_aid/tools/shell.py @@ -22,13 +22,13 @@ def _truncate_for_log(text: str, max_length: int = 300) -> str: @tool def run_shell_command( - command: str, expected_runtime_seconds: int = 30 + command: str, timeout: int = 30 ) -> Dict[str, Union[str, int, bool]]: """Execute a shell command and return its output. Args: command: The shell command to execute - expected_runtime_seconds: Expected runtime in seconds, defaults to 30. + timeout: Expected runtime in seconds, defaults to 30. If process exceeds 2x this value, it will be terminated gracefully. If process exceeds 3x this value, it will be killed forcefully. @@ -83,7 +83,7 @@ def run_shell_command( print() output, return_code = run_interactive_command( ["/bin/bash", "-c", command], - expected_runtime_seconds=expected_runtime_seconds, + expected_runtime_seconds=timeout, ) print() result = {