refactor(agent_utils.py): remove unnecessary debug log statement to clean up code
refactor(fallback_handler.py): improve error handling by raising a specific exception when all fallback models fail test(fallback_handler.py): update tests to reflect changes in the fallback handler's error handling and initialization fix(test_llm.py): update error messages in tests for unsupported providers to be more descriptive and accurate
This commit is contained in:
parent
ac13ce746a
commit
15a3291254
|
|
@ -933,7 +933,6 @@ def run_agent_with_retry(
|
||||||
original_prompt, config, test_attempts, auto_test
|
original_prompt, config, test_attempts, auto_test
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
cpm(f"res:{should_break, prompt, auto_test, test_attempts}")
|
|
||||||
if should_break:
|
if should_break:
|
||||||
break
|
break
|
||||||
if prompt != original_prompt:
|
if prompt != original_prompt:
|
||||||
|
|
|
||||||
|
|
@ -157,10 +157,16 @@ class FallbackHandler:
|
||||||
for fallback_model in self.fallback_tool_models:
|
for fallback_model in self.fallback_tool_models:
|
||||||
result_list = self.invoke_fallback(fallback_model)
|
result_list = self.invoke_fallback(fallback_model)
|
||||||
if result_list:
|
if result_list:
|
||||||
# msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
|
|
||||||
return result_list
|
return result_list
|
||||||
cpm("All fallback models have failed", title="Fallback Failed")
|
|
||||||
return None
|
cpm("All fallback models have failed.", title="Fallback Failed")
|
||||||
|
|
||||||
|
current_failing_tool_name = self.current_failing_tool_name
|
||||||
|
self.reset_fallback_handler()
|
||||||
|
|
||||||
|
raise FallbackToolExecutionError(
|
||||||
|
f"All fallback models have failed for tool: {current_failing_tool_name}"
|
||||||
|
)
|
||||||
|
|
||||||
def reset_fallback_handler(self):
|
def reset_fallback_handler(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -22,16 +22,17 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
self.config = {
|
self.config = {
|
||||||
"max_tool_failures": 2,
|
"max_tool_failures": 2,
|
||||||
"fallback_tool_models": "dummy-fallback-model",
|
"fallback_tool_models": "dummy-fallback-model",
|
||||||
|
"experimental_fallback_handler": True,
|
||||||
}
|
}
|
||||||
self.fallback_handler = FallbackHandler(self.config)
|
self.fallback_handler = FallbackHandler(self.config, [])
|
||||||
self.logger = DummyLogger()
|
self.logger = DummyLogger()
|
||||||
self.agent = DummyAgent()
|
self.agent = DummyAgent()
|
||||||
|
|
||||||
def test_handle_failure_increments_counter(self):
|
def test_handle_failure_increments_counter(self):
|
||||||
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
|
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
|
||||||
self.fallback_handler.handle_failure(
|
error_obj = ToolExecutionError("Test error", base_message="dummy_call()", tool_name="dummy_tool")
|
||||||
"dummy_call()", Exception("Test error"), self.logger, self.agent
|
self.fallback_handler.handle_failure(error_obj, self.agent)
|
||||||
)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.fallback_handler.tool_failure_consecutive_failures,
|
self.fallback_handler.tool_failure_consecutive_failures,
|
||||||
initial_failures + 1,
|
initial_failures + 1,
|
||||||
|
|
@ -62,9 +63,7 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
llm.validate_provider_env = dummy_validate_provider_env
|
llm.validate_provider_env = dummy_validate_provider_env
|
||||||
|
|
||||||
self.fallback_handler.tool_failure_consecutive_failures = 2
|
self.fallback_handler.tool_failure_consecutive_failures = 2
|
||||||
self.fallback_handler.attempt_fallback(
|
self.fallback_handler.attempt_fallback()
|
||||||
"dummy_tool_call()", self.logger, self.agent
|
|
||||||
)
|
|
||||||
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
|
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
|
||||||
|
|
||||||
llm.initialize_llm = original_initialize
|
llm.initialize_llm = original_initialize
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,7 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
|
||||||
|
|
||||||
def test_initialize_expert_unsupported_provider(clean_env):
|
def test_initialize_expert_unsupported_provider(clean_env):
|
||||||
"""Test error handling for unsupported provider in expert mode."""
|
"""Test error handling for unsupported provider in expert mode."""
|
||||||
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
|
with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
|
||||||
initialize_expert_llm("unknown", "model")
|
initialize_expert_llm("unknown", "model")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -197,7 +197,7 @@ def test_initialize_unsupported_provider(clean_env):
|
||||||
"""Test initialization with unsupported provider raises ValueError"""
|
"""Test initialization with unsupported provider raises ValueError"""
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
initialize_llm("unsupported", "model")
|
initialize_llm("unsupported", "model")
|
||||||
assert str(exc_info.value) == "Unsupported provider: unsupported"
|
assert str(exc_info.value) == "Missing required environment variable for provider: unsupported"
|
||||||
|
|
||||||
|
|
||||||
def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):
|
def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue