refactor(agent_utils.py): remove unnecessary debug log statement to clean up code

refactor(fallback_handler.py): improve error handling by raising a specific exception when all fallback models fail
test(fallback_handler.py): update tests to reflect changes in the fallback handler's error handling and initialization
fix(test_llm.py): update error messages in tests for unsupported providers to be more descriptive and accurate
This commit is contained in:
Ariel Frischer 2025-02-13 18:20:00 -08:00
parent ac13ce746a
commit 15a3291254
4 changed files with 17 additions and 13 deletions

View File

@ -933,7 +933,6 @@ def run_agent_with_retry(
original_prompt, config, test_attempts, auto_test
)
)
cpm(f"res:{should_break, prompt, auto_test, test_attempts}")
if should_break:
break
if prompt != original_prompt:

View File

@ -157,10 +157,16 @@ class FallbackHandler:
for fallback_model in self.fallback_tool_models:
result_list = self.invoke_fallback(fallback_model)
if result_list:
# msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
return result_list
cpm("All fallback models have failed", title="Fallback Failed")
return None
cpm("All fallback models have failed.", title="Fallback Failed")
current_failing_tool_name = self.current_failing_tool_name
self.reset_fallback_handler()
raise FallbackToolExecutionError(
f"All fallback models have failed for tool: {current_failing_tool_name}"
)
def reset_fallback_handler(self):
"""

View File

@ -22,16 +22,17 @@ class TestFallbackHandler(unittest.TestCase):
self.config = {
"max_tool_failures": 2,
"fallback_tool_models": "dummy-fallback-model",
"experimental_fallback_handler": True,
}
self.fallback_handler = FallbackHandler(self.config)
self.fallback_handler = FallbackHandler(self.config, [])
self.logger = DummyLogger()
self.agent = DummyAgent()
def test_handle_failure_increments_counter(self):
from ra_aid.exceptions import ToolExecutionError
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
self.fallback_handler.handle_failure(
"dummy_call()", Exception("Test error"), self.logger, self.agent
)
error_obj = ToolExecutionError("Test error", base_message="dummy_call()", tool_name="dummy_tool")
self.fallback_handler.handle_failure(error_obj, self.agent)
self.assertEqual(
self.fallback_handler.tool_failure_consecutive_failures,
initial_failures + 1,
@ -62,9 +63,7 @@ class TestFallbackHandler(unittest.TestCase):
llm.validate_provider_env = dummy_validate_provider_env
self.fallback_handler.tool_failure_consecutive_failures = 2
self.fallback_handler.attempt_fallback(
"dummy_tool_call()", self.logger, self.agent
)
self.fallback_handler.attempt_fallback()
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
llm.initialize_llm = original_initialize

View File

@ -121,7 +121,7 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
def test_initialize_expert_unsupported_provider(clean_env):
"""Test error handling for unsupported provider in expert mode."""
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
initialize_expert_llm("unknown", "model")
@ -197,7 +197,7 @@ def test_initialize_unsupported_provider(clean_env):
"""Test initialization with unsupported provider raises ValueError"""
with pytest.raises(ValueError) as exc_info:
initialize_llm("unsupported", "model")
assert str(exc_info.value) == "Unsupported provider: unsupported"
assert str(exc_info.value) == "Missing required environment variable for provider: unsupported"
def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):