diff --git a/issue.md b/issue.md new file mode 100644 index 0000000..62e246b --- /dev/null +++ b/issue.md @@ -0,0 +1,115 @@ +# LLM Tool Call Fallback Feature + +## Overview +Add functionality to automatically fallback to alternative LLM models when a tool call experiences multiple consecutive failures. + +## Background +Currently, when a tool call fails due to LLM-related errors (e.g., API timeouts, rate limits, context length issues), there is no automatic fallback mechanism. This can lead to interrupted workflows and poor user experience. + +## Relevant Files +- ra_aid/agents/ciayn_agent.py +- ra_aid/llm.py +- ra_aid/agent_utils.py +- ra_aid/__main__.py +- ra_aid/models_params.py + + +## Implementation Details + +### Configuration +- Add new configuration value `max_tool_failures` (default: 3) to track consecutive failures before triggering fallback +- Add new command line argument `--no-fallback-tool` to disable fallback behavior (enabled by default) +- **Add new command line argument** `--fallback-tool-models` to specify a comma-separated list of fallback tool models (default: "gpt-3.5-turbo,gpt-4") + This list defines the fallback model sequence used by forced tool calls (via `bind_tools`) when tool call failures occur. +- Track failure count per tool call context +- Reset failure counter on successful tool call +- Store fallback model sequence per provider +- Need to validate if ENV vars are set for provider usage of that fallback model + before usage, if that fallback ENV is not available then fallback to the next model +- Have default list of common models, first try `claude-3-5-sonnet-20241022` but + have many alternative fallback models. + +### Tool Call Wrapper +Create a new wrapper function to handle tool call execution with fallback logic: + +```python +def execute_tool_with_fallback(tool_call_func, *args, **kwargs): + failures = 0 + max_failures = get_config().max_tool_failures + + while failures < max_failures: + try: + return tool_call_func(*args, **kwargs) + except LLMError as e: + failures += 1 + if failures >= max_failures: + # Use forced tool call via bind_tools with retry: + llm_retry = llm_model.with_retry(stop_after_attempt=3) # Try three times + try_fallback_model(force=True, model=llm_retry) + # Merge fallback model chat messages back into the original chat history. + merge_fallback_chat_history() + failures = 0 # Reset counter for new model + else: + raise +``` + +The prompt passed to `try_fallback_model`, should be the failed last few failing tool calls. + +### Model Fallback Sequence +Define fallback sequences for each provider based on model capabilities: + +1. Try same provider's smaller models +2. Try alternative providers' equivalent models +3. Raise final error if all fallbacks fail + +### Provider Strategy Updates +Update provider strategies to support fallback configuration: +- Add provider-specific fallback sequences +- Handle model capability validation during fallback +- Track successful/failed attempts + +## Risks and Mitigations +1. **Performance Impact** + - Risk: Multiple fallback attempts could increase latency + - Mitigation: Set reasonable max_failures limit and timeouts + +2. **Consistency** + - Risk: Different models may give slightly different outputs + - Mitigation: Validate output schema consistency across models + +3. **Cost** + - Risk: Fallback to more expensive models + - Mitigation: Configure cost limits and preferred fallback sequences + +4. **State Management** + - Risk: Loss of context during fallbacks + - Mitigation: Preserve conversation state and tool context + +## Acceptance Criteria +1. Tool calls automatically attempt fallback models after N consecutive failures +2. `--no-fallback-tool` argument successfully disables fallback behavior +3. Fallback sequence respects provider and model capabilities +4. Original error is preserved if all fallbacks fail +5. Unit tests cover fallback scenarios and edge cases +6. README.md updated to reflect new behavior + +## Testing +1. Unit tests for fallback wrapper +2. Integration tests with mock LLM failures +3. Provider strategy fallback tests +4. Command line argument handling +5. Error preservation and reporting +6. Performance impact measurement +7. Edge cases (e.g., partial failures, timeout handling) +8. State preservation during fallbacks + +## Documentation Updates +1. Add fallback feature to main README +2. Document `--no-fallback-tool` in CLI help +3. Document provider-specific fallback sequences + +## Future Considerations +1. Allow custom fallback sequences via configuration +2. Add monitoring and alerting for fallback frequency +3. Optimize fallback selection based on historical success rates +4. Cost-aware fallback routing diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 21868bd..16fd5df 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -149,6 +149,17 @@ Examples: action="store_false", help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.", ) + parser.add_argument( + "--no-fallback-tool", + action="store_true", + help="Disable fallback model switching.", + ) + parser.add_argument( + "--fallback-tool-models", + type=str, + default="gpt-3.5-turbo,gpt-4", + help="Comma-separated list of fallback models to use in order.", + ) parser.add_argument( "--recursion-limit", type=int, @@ -414,12 +425,35 @@ def main(): ) _global_memory["config"]["planner_model"] = args.planner_model or args.model + _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool + _global_memory["config"]["fallback_tool_models"] = ( + [ + model.strip() + for model in args.fallback_tool_models.split(",") + if model.strip() + ] + if args.fallback_tool_models + else [] + ) + # Store research config with fallback to base values _global_memory["config"]["research_provider"] = ( args.research_provider or args.provider ) _global_memory["config"]["research_model"] = args.research_model or args.model + # Store fallback tool configuration + _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool + _global_memory["config"]["fallback_tool_models"] = ( + [ + model.strip() + for model in args.fallback_tool_models.split(",") + if model.strip() + ] + if args.fallback_tool_models + else [] + ) + # Run research stage print_stage_header("Research Stage") diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index c615cc4..cbe52bf 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -68,12 +68,29 @@ class CiaynAgent: - Memory management with configurable limits """ + class ToolCallFailure: + """Tracks consecutive failures and fallback model usage for tool calls. + + Attributes: + consecutive_failures (int): Count of consecutive failures for current model + current_provider (Optional[str]): Current provider being used + current_model (Optional[str]): Current model being used + used_fallbacks (Set[str]): Set of fallback models already attempted + """ + + def __init__(self): + self.consecutive_failures = 0 + self.current_provider = None + self.current_model = None + self.used_fallbacks = set() + def __init__( self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT, + config: Optional[dict] = None, ): """Initialize the agent with a model and list of tools. @@ -82,7 +99,17 @@ class CiaynAgent: tools: List of tools available to the agent max_history_messages: Maximum number of messages to keep in chat history max_tokens: Maximum number of tokens allowed in message history (None for no limit) + config: Optional configuration dictionary for fallback settings """ + if config is None: + config = {} + self.config = config + self.provider = config.get("provider", "openai") + self.fallback_enabled = config.get("fallback_tool_enabled", True) + fallback_models_str = config.get("fallback_tool_models", "gpt-3.5-turbo,gpt-4") + self.fallback_tool_models = [ + m.strip() for m in fallback_models_str.split(",") if m.strip() + ] self.model = model self.tools = tools self.max_history_messages = max_history_messages @@ -90,6 +117,7 @@ class CiaynAgent: self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) + self._tool_failure = CiaynAgent.ToolCallFailure() def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -221,23 +249,56 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return base_prompt def _execute_tool(self, code: str) -> str: - """Execute a tool call and return its result.""" - globals_dict = {tool.func.__name__: tool.func for tool in self.tools} + """Execute a tool call with retry and fallback logic and return its result.""" + max_retries = 3 + retries = 0 + last_error = None + while retries < max_retries: + try: + code = code.strip() + if validate_function_call_pattern(code): + functions_list = "\n\n".join(self.available_functions) + code = _extract_tool_call(code, functions_list) + globals_dict = {tool.func.__name__: tool.func for tool in self.tools} + result = eval(code, globals_dict) + self._tool_failure.consecutive_failures = 0 + return result + except Exception as e: + self._handle_tool_failure(code, e) + last_error = e + retries += 1 + raise ToolExecutionError( + f"Error executing code after {max_retries} attempts: {str(last_error)}" + ) + def _handle_tool_failure(self, code: str, error: Exception) -> None: + self._tool_failure.consecutive_failures += 1 + max_failures = self.config.get("max_tool_failures", 3) + if ( + self.fallback_enabled + and self._tool_failure.consecutive_failures >= max_failures + and self.fallback_tool_models + ): + self._attempt_fallback(code) + + def _attempt_fallback(self, code: str) -> None: + new_model = self.fallback_tool_models[0] + failed_tool_call_name = code.split('(')[0].strip() + logger.error( + f"Tool call failed {self._tool_failure.consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" + ) try: - code = code.strip() - # code = code.replace("\n", " ") - - # if the eval fails, try to extract it via a model call - if validate_function_call_pattern(code): - functions_list = "\n\n".join(self.available_functions) - code = _extract_tool_call(code, functions_list) - - result = eval(code.strip(), globals_dict) - return result - except Exception as e: - error_msg = f"Error executing code: {str(e)}" - raise ToolExecutionError(error_msg) + from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env + if not validate_provider_env(self.provider): + logger.error(f"Missing environment configuration for provider {self.provider}. Cannot fallback.") + else: + self.model = initialize_llm(self.provider, new_model) + self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name) + self._tool_failure.used_fallbacks.add(new_model) + merge_chat_history() # Assuming merge_chat_history handles merging fallback history + self._tool_failure.consecutive_failures = 0 + except Exception as switch_e: + logger.error(f"Fallback model switching failed: {switch_e}") def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" diff --git a/ra_aid/config.py b/ra_aid/config.py index 7977167..6c12a93 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -2,3 +2,5 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 +DEFAULT_MAX_TOOL_FAILURES = 3 +MAX_TOOL_FAILURES = 3 diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 4a4038a..9516f0b 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -1,8 +1,9 @@ import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage from langchain_google_genai import ChatGoogleGenerativeAI from langchain_openai import ChatOpenAI @@ -47,9 +48,9 @@ def create_deepseek_client( return ChatDeepseekReasoner( api_key=api_key, base_url=base_url, - temperature=0 - if is_expert - else (temperature if temperature is not None else 1), + temperature=( + 0 if is_expert else (temperature if temperature is not None else 1) + ), model=model_name, ) @@ -72,9 +73,9 @@ def create_openrouter_client( return ChatDeepseekReasoner( api_key=api_key, base_url="https://openrouter.ai/api/v1", - temperature=0 - if is_expert - else (temperature if temperature is not None else 1), + temperature=( + 0 if is_expert else (temperature if temperature is not None else 1) + ), model=model_name, ) @@ -114,7 +115,12 @@ def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any "base_url": "https://api.deepseek.com", }, } - return configs.get(provider, {}) + config = configs.get(provider, {}) + if not config or not config.get("api_key"): + raise ValueError( + f"Missing required environment variable for provider: {provider}" + ) + return config def create_llm_client( @@ -219,8 +225,41 @@ def initialize_llm( return create_llm_client(provider, model_name, temperature, is_expert=False) -def initialize_expert_llm( - provider: str, model_name: str -) -> BaseChatModel: +def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel: """Initialize an expert language model client based on the specified provider and model.""" return create_llm_client(provider, model_name, temperature=None, is_expert=True) + + +def validate_provider_env(provider: str) -> bool: + """Check if the required environment variables for a provider are set.""" + required_vars = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + "openai-compatible": "OPENAI_API_KEY", + "gemini": "GEMINI_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + } + key = required_vars.get(provider.lower()) + if key: + return bool(os.getenv(key)) + return False + + +def merge_chat_history( + original_history: List[BaseMessage], fallback_history: List[BaseMessage] +) -> List[BaseMessage]: + """Merge original and fallback chat histories while preserving order. + + Args: + original_history: The original chat message history + fallback_history: Additional messages from fallback attempts + + Returns: + List[BaseMessage]: Combined message history preserving chronological order + + Note: + The function appends fallback messages to maintain context for future + interactions while preserving the original conversation flow. + """ + return original_history + fallback_history diff --git a/ra_aid/models_params.py b/ra_aid/models_params.py index ee83f70..6e0072d 100644 --- a/ra_aid/models_params.py +++ b/ra_aid/models_params.py @@ -27,7 +27,6 @@ models_params = { "gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True}, "o1-preview": {"token_limit": 128000, "supports_temperature": False}, "o1-mini": {"token_limit": 128000, "supports_temperature": False}, - "o1-preview": {"token_limit": 128000, "supports_temperature": False}, "o1": {"token_limit": 200000, "supports_temperature": False}, "o3-mini": {"token_limit": 200000, "supports_temperature": False}, }, diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index 3950ecd..e4042f1 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -1,25 +1,18 @@ from ra_aid.tools import ( ask_expert, ask_human, - delete_key_facts, - delete_key_snippets, - deregister_related_files, emit_expert_context, emit_key_facts, emit_key_snippets, - emit_plan, emit_related_files, emit_research_notes, fuzzy_find_project_files, list_directory_tree, - monorepo_detected, - plan_implementation_completed, read_file_tool, ripgrep_search, run_programming_task, run_shell_command, task_completed, - ui_detected, web_search_tavily, ) from ra_aid.tools.agent import ( @@ -29,7 +22,6 @@ from ra_aid.tools.agent import ( request_task_implementation, request_web_research, ) -from ra_aid.tools.memory import one_shot_completed from ra_aid.tools.write_file import write_file_tool diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index 3b8188f..6bb1a69 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -185,7 +185,11 @@ def ask_expert(question: str) -> str: query_parts.extend(["# Question", question]) query_parts.extend( - ["\n # Addidional Requirements", "**DO NOT OVERTHINK**", "**DO NOT OVERCOMPLICATE**"] + [ + "\n # Addidional Requirements", + "**DO NOT OVERTHINK**", + "**DO NOT OVERCOMPLICATE**", + ] ) # Join all parts diff --git a/ra_aid/tools/programmer.py b/ra_aid/tools/programmer.py index 4f62f92..35d3b2f 100644 --- a/ra_aid/tools/programmer.py +++ b/ra_aid/tools/programmer.py @@ -40,7 +40,7 @@ def run_programming_task( If new files are created, emit them after finishing. They can add/modify files, but not remove. Use run_shell_command to remove files. If referencing files you’ll delete, remove them after they finish. - + Use write_file_tool instead if you need to write the entire contents of file(s). If the programmer wrote files, they actually wrote to disk. You do not need to rewrite the output of what the programmer showed you. @@ -117,7 +117,7 @@ def run_programming_task( # Log the programming task log_work_event(f"Executed programming task: {_truncate_for_log(instructions)}") - + # Return structured output return { "output": truncate_output(result[0].decode()) if result[0] else "", diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py similarity index 63% rename from tests/ra_aid/agents/test_ciayn_agent.py rename to tests/ra_aid/test_ciayn_agent.py index b56310a..65794d0 100644 --- a/tests/ra_aid/agents/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -1,11 +1,41 @@ +import unittest from unittest.mock import Mock import pytest from langchain_core.messages import AIMessage, HumanMessage from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern +from ra_aid.exceptions import ToolExecutionError +# Dummy tool function for testing retry and fallback behavior +def dummy_tool(): + dummy_tool.attempt += 1 + if dummy_tool.attempt < 3: + raise Exception("Simulated failure") + return "dummy success" + + +dummy_tool.attempt = 0 + + +class DummyTool: + def __init__(self, func): + self.func = func + + +class DummyModel: + def invoke(self, messages): + # Always return a code snippet that calls dummy_tool() + class Response: + content = "dummy_tool()" + + return Response() + def bind_tools(self, tools, tool_choice): + pass + + +# Fixtures from the source file @pytest.fixture def mock_model(): """Create a mock language model.""" @@ -21,6 +51,7 @@ def agent(mock_model): return CiaynAgent(mock_model, tools, max_history_messages=3) +# Trimming test functions def test_trim_chat_history_preserves_initial_messages(agent): """Test that initial messages are preserved during trimming.""" initial_messages = [ @@ -33,9 +64,7 @@ def test_trim_chat_history_preserves_initial_messages(agent): HumanMessage(content="Chat 3"), AIMessage(content="Chat 4"), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify initial messages are preserved assert result[:2] == initial_messages # Verify only last 3 chat messages are kept (due to max_history_messages=3) @@ -47,9 +76,7 @@ def test_trim_chat_history_under_limit(agent): """Test trimming when chat history is under the maximum limit.""" initial_messages = [HumanMessage(content="Initial")] chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify no trimming occurred assert len(result) == 3 assert result == initial_messages + chat_history @@ -65,9 +92,7 @@ def test_trim_chat_history_over_limit(agent): AIMessage(content="Chat 4"), HumanMessage(content="Chat 5"), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify correct trimming assert len(result) == 4 # initial + max_history_messages assert result[0] == initial_messages[0] # Initial message preserved @@ -83,9 +108,7 @@ def test_trim_chat_history_empty_initial(agent): HumanMessage(content="Chat 3"), AIMessage(content="Chat 4"), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify only last 3 messages are kept assert len(result) == 3 assert result == chat_history[-3:] @@ -98,9 +121,7 @@ def test_trim_chat_history_empty_chat(agent): AIMessage(content="Initial 2"), ] chat_history = [] - result = agent._trim_chat_history(initial_messages, chat_history) - # Verify initial messages are preserved and no trimming occurred assert result == initial_messages assert len(result) == 2 @@ -109,16 +130,13 @@ def test_trim_chat_history_empty_chat(agent): def test_trim_chat_history_token_limit(): """Test trimming based on token limit.""" agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25) - initial_messages = [HumanMessage(content="Initial")] # ~2 tokens chat_history = [ HumanMessage(content="A" * 40), # ~10 tokens AIMessage(content="B" * 40), # ~10 tokens HumanMessage(content="C" * 40), # ~10 tokens ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Should keep initial message (~2 tokens) and last message (~10 tokens) assert len(result) == 2 assert result[0] == initial_messages[0] @@ -128,16 +146,13 @@ def test_trim_chat_history_token_limit(): def test_trim_chat_history_no_token_limit(): """Test trimming with no token limit set.""" agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None) - initial_messages = [HumanMessage(content="Initial")] chat_history = [ HumanMessage(content="A" * 1000), AIMessage(content="B" * 1000), HumanMessage(content="C" * 1000), ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Should keep initial message and last 2 messages (max_history_messages=2) assert len(result) == 3 assert result[0] == initial_messages[0] @@ -147,7 +162,6 @@ def test_trim_chat_history_no_token_limit(): def test_trim_chat_history_both_limits(): """Test trimming with both message count and token limits.""" agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35) - initial_messages = [HumanMessage(content="Init")] # ~1 token chat_history = [ HumanMessage(content="A" * 40), # ~10 tokens @@ -155,9 +169,7 @@ def test_trim_chat_history_both_limits(): HumanMessage(content="C" * 40), # ~10 tokens AIMessage(content="D" * 40), # ~10 tokens ] - result = agent._trim_chat_history(initial_messages, chat_history) - # Should first apply message limit (keeping last 3) # Then token limit should further reduce to fit under 15 tokens assert len(result) == 2 # Initial message + 1 message under token limit @@ -165,6 +177,33 @@ def test_trim_chat_history_both_limits(): assert result[1] == chat_history[-1] +# Fallback tests +class TestCiaynAgentFallback(unittest.TestCase): + def setUp(self): + # Reset dummy_tool attempt counter before each test + dummy_tool.attempt = 0 + self.dummy_tool = DummyTool(dummy_tool) + self.model = DummyModel() + # Create a CiaynAgent with the dummy tool + self.agent = CiaynAgent(self.model, [self.dummy_tool]) + + def test_retry_logic_with_failure_recovery(self): + # Test that _execute_tool retries and eventually returns success + result = self.agent._execute_tool("dummy_tool()") + self.assertEqual(result, "dummy success") + + def test_switch_models_on_fallback(self): + # Test fallback behavior by making dummy_tool always fail + def always_fail(): + raise Exception("Persistent failure") + + always_fail_tool = DummyTool(always_fail) + agent = CiaynAgent(self.model, [always_fail_tool]) + with self.assertRaises(ToolExecutionError): + agent._execute_tool("always_fail()") + + +# Function call validation tests class TestFunctionCallValidation: @pytest.mark.parametrize( "test_input", @@ -221,3 +260,54 @@ class TestFunctionCallValidation: def test_multiline_responses(self, test_input): """Test function calls spanning multiple lines.""" assert not validate_function_call_pattern(test_input) + + +class TestCiaynAgentNewMethods(unittest.TestCase): + def setUp(self): + # Create a dummy tool that always fails for testing fallback + def always_fail(): + raise Exception("Failure for fallback test") + self.always_fail_tool = DummyTool(always_fail) + # Create a dummy model that does minimal work for fallback tests + self.dummy_model = DummyModel() + # Initialize CiaynAgent with configuration to trigger fallback quickly + self.agent = CiaynAgent( + self.dummy_model, + [self.always_fail_tool], + config={"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"} + ) + + def test_handle_tool_failure_increments_counter(self): + initial_failures = self.agent._tool_failure.consecutive_failures + self.agent._handle_tool_failure("dummy_call()", Exception("Test error")) + self.assertEqual(self.agent._tool_failure.consecutive_failures, initial_failures + 1) + + def test_attempt_fallback_invokes_fallback_logic(self): + # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env + # to simulate fallback switching without external dependencies. + def dummy_initialize_llm(provider, model_name, temperature=None): + return self.dummy_model + def dummy_merge_chat_history(): + return ["merged"] + def dummy_validate_provider_env(provider): + return True + import ra_aid.llm as llm + original_initialize = llm.initialize_llm + original_merge = llm.merge_chat_history + original_validate = llm.validate_provider_env + llm.initialize_llm = dummy_initialize_llm + llm.merge_chat_history = dummy_merge_chat_history + llm.validate_provider_env = dummy_validate_provider_env + + # Set failure counter high enough to trigger fallback in _handle_tool_failure + self.agent._tool_failure.consecutive_failures = 2 + # Call _attempt_fallback; it should reset the failure counter to 0 on success. + self.agent._attempt_fallback("always_fail_tool()") + self.assertEqual(self.agent._tool_failure.consecutive_failures, 0) + # Restore original functions + llm.initialize_llm = original_initialize + llm.merge_chat_history = original_merge + llm.validate_provider_env = original_validate + +if __name__ == "__main__": + unittest.main() diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 2e7ea10..6789132 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -54,7 +54,9 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch): monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key") _llm = initialize_expert_llm("openai", "o1") - mock_openai.assert_called_once_with(api_key="test-key", model="o1", reasoning_effort="high") + mock_openai.assert_called_once_with( + api_key="test-key", model="o1", reasoning_effort="high" + ) def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): @@ -63,7 +65,10 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): _llm = initialize_expert_llm("openai", "gpt-4-preview") mock_openai.assert_called_once_with( - api_key="test-key", model="gpt-4-preview", temperature=0, reasoning_effort="high" + api_key="test-key", + model="gpt-4-preview", + temperature=0, + reasoning_effort="high", ) @@ -348,7 +353,9 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch): # Test LLM client creation with expert mode _llm = create_llm_client("openai", "o1", is_expert=True) - mock_openai.assert_called_with(api_key="expert-key", model="o1", reasoning_effort="high") + mock_openai.assert_called_with( + api_key="expert-key", model="o1", reasoning_effort="high" + ) # Test environment validation monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")