refactor(agent_utils.py, ciayn_agent.py): remove unused import cpm to clean up code and improve readability
style(tests): format code for better readability and consistency in test files test(tests): update assertions and test cases for better clarity and maintainability
This commit is contained in:
parent
7a2c766824
commit
d5e2e0a9a0
|
|
@ -31,7 +31,7 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.agents_alias import RAgents
|
from ra_aid.agents_alias import RAgents
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.console.formatting import print_error, print_stage_header
|
from ra_aid.console.formatting import print_error, print_stage_header
|
||||||
from ra_aid.console.output import cpm, print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.exceptions import (
|
from ra_aid.exceptions import (
|
||||||
AgentInterrupt,
|
AgentInterrupt,
|
||||||
FallbackToolExecutionError,
|
FallbackToolExecutionError,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
||||||
from ra_aid.console.output import cpm
|
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
|
||||||
import litellm
|
import litellm
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
from ra_aid.agent_utils import (
|
from ra_aid.agent_utils import (
|
||||||
AgentState,
|
AgentState,
|
||||||
|
|
@ -128,9 +128,10 @@ def test_create_agent_openai(mock_model, mock_memory):
|
||||||
|
|
||||||
assert agent == "ciayn_agent"
|
assert agent == "ciayn_agent"
|
||||||
mock_ciayn.assert_called_once_with(
|
mock_ciayn.assert_called_once_with(
|
||||||
mock_model, [],
|
mock_model,
|
||||||
|
[],
|
||||||
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
|
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
|
||||||
config={'provider': 'openai', 'model': 'gpt-4'}
|
config={"provider": "openai", "model": "gpt-4"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -144,9 +145,10 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
|
||||||
|
|
||||||
assert agent == "ciayn_agent"
|
assert agent == "ciayn_agent"
|
||||||
mock_ciayn.assert_called_once_with(
|
mock_ciayn.assert_called_once_with(
|
||||||
mock_model, [],
|
mock_model,
|
||||||
|
[],
|
||||||
max_tokens=DEFAULT_TOKEN_LIMIT,
|
max_tokens=DEFAULT_TOKEN_LIMIT,
|
||||||
config={'provider': 'unknown', 'model': 'unknown-model'}
|
config={"provider": "unknown", "model": "unknown-model"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -163,7 +165,7 @@ def test_create_agent_missing_config(mock_model, mock_memory):
|
||||||
mock_model,
|
mock_model,
|
||||||
[],
|
[],
|
||||||
max_tokens=DEFAULT_TOKEN_LIMIT,
|
max_tokens=DEFAULT_TOKEN_LIMIT,
|
||||||
config={'provider': 'openai'}
|
config={"provider": "openai"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -207,9 +209,10 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
|
||||||
|
|
||||||
assert agent == "ciayn_agent"
|
assert agent == "ciayn_agent"
|
||||||
mock_ciayn.assert_called_once_with(
|
mock_ciayn.assert_called_once_with(
|
||||||
mock_model, [],
|
mock_model,
|
||||||
|
[],
|
||||||
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
|
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
|
||||||
config={'provider': 'openai', 'model': 'gpt-4'}
|
config={"provider": "openai", "model": "gpt-4"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ class DummyTool:
|
||||||
|
|
||||||
class DummyModel:
|
class DummyModel:
|
||||||
def invoke(self, _messages: list[BaseMessage]):
|
def invoke(self, _messages: list[BaseMessage]):
|
||||||
|
|
||||||
return AIMessage("dummy_tool()")
|
return AIMessage("dummy_tool()")
|
||||||
|
|
||||||
def bind_tools(self, tools, tool_choice):
|
def bind_tools(self, tools, tool_choice):
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,7 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
|
|
||||||
def test_load_fallback_tool_models(self):
|
def test_load_fallback_tool_models(self):
|
||||||
import ra_aid.fallback_handler as fh
|
import ra_aid.fallback_handler as fh
|
||||||
|
|
||||||
original_supported = fh.supported_top_tool_models
|
original_supported = fh.supported_top_tool_models
|
||||||
fh.supported_top_tool_models = [
|
fh.supported_top_tool_models = [
|
||||||
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
||||||
|
|
@ -95,13 +96,16 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
fh.supported_top_tool_models = original_supported
|
fh.supported_top_tool_models = original_supported
|
||||||
|
|
||||||
def test_extract_failed_tool_name(self):
|
def test_extract_failed_tool_name(self):
|
||||||
from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError
|
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
|
||||||
|
|
||||||
# Case when tool_name is provided
|
# Case when tool_name is provided
|
||||||
error1 = ToolExecutionError("Error", base_message="dummy", tool_name="dummy_tool")
|
error1 = ToolExecutionError(
|
||||||
|
"Error", base_message="dummy", tool_name="dummy_tool"
|
||||||
|
)
|
||||||
name1 = self.fallback_handler.extract_failed_tool_name(error1)
|
name1 = self.fallback_handler.extract_failed_tool_name(error1)
|
||||||
self.assertEqual(name1, "dummy_tool")
|
self.assertEqual(name1, "dummy_tool")
|
||||||
# Case when tool_name is not provided but regex works
|
# Case when tool_name is not provided but regex works
|
||||||
error2 = ToolExecutionError("error with name=\"test_tool\"")
|
error2 = ToolExecutionError('error with name="test_tool"')
|
||||||
name2 = self.fallback_handler.extract_failed_tool_name(error2)
|
name2 = self.fallback_handler.extract_failed_tool_name(error2)
|
||||||
self.assertEqual(name2, "test_tool")
|
self.assertEqual(name2, "test_tool")
|
||||||
# Case when regex fails and exception is raised
|
# Case when regex fails and exception is raised
|
||||||
|
|
@ -110,16 +114,13 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
self.fallback_handler.extract_failed_tool_name(error3)
|
self.fallback_handler.extract_failed_tool_name(error3)
|
||||||
|
|
||||||
def test_find_tool_to_bind(self):
|
def test_find_tool_to_bind(self):
|
||||||
# Create a dummy tool to be found
|
|
||||||
class DummyTool:
|
|
||||||
def invoke(self, args):
|
|
||||||
return "result"
|
|
||||||
class DummyWrapper:
|
class DummyWrapper:
|
||||||
def __init__(self, func):
|
def __init__(self, func):
|
||||||
self.func = func
|
self.func = func
|
||||||
def dummy_func(args):
|
|
||||||
|
def dummy_func(_args):
|
||||||
return "result"
|
return "result"
|
||||||
dummy_tool = DummyTool()
|
|
||||||
dummy_wrapper = DummyWrapper(dummy_func)
|
dummy_wrapper = DummyWrapper(dummy_func)
|
||||||
self.agent.tools.append(dummy_wrapper)
|
self.agent.tools.append(dummy_wrapper)
|
||||||
tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__)
|
tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__)
|
||||||
|
|
@ -134,53 +135,82 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.tool_choice = tool_choice
|
self.tool_choice = tool_choice
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_retry(self, stop_after_attempt):
|
def with_retry(self, stop_after_attempt):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def invoke(self, msg_list):
|
def invoke(self, msg_list):
|
||||||
return "dummy_response"
|
return "dummy_response"
|
||||||
|
|
||||||
dummy_model = DummyModel()
|
dummy_model = DummyModel()
|
||||||
|
|
||||||
# Set current tool for binding
|
# Set current tool for binding
|
||||||
class DummyTool:
|
class DummyTool:
|
||||||
def invoke(self, args):
|
def invoke(self, args):
|
||||||
return "result"
|
return "result"
|
||||||
|
|
||||||
self.fallback_handler.current_tool_to_bind = DummyTool()
|
self.fallback_handler.current_tool_to_bind = DummyTool()
|
||||||
self.fallback_handler.current_failing_tool_name = "test_tool"
|
self.fallback_handler.current_failing_tool_name = "test_tool"
|
||||||
# Test with force calling ("fc") type
|
# Test with force calling ("fc") type
|
||||||
fallback_model_fc = {"type": "fc"}
|
fallback_model_fc = {"type": "fc"}
|
||||||
bound_model_fc = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_fc)
|
bound_model_fc = self.fallback_handler._bind_tool_model(
|
||||||
|
dummy_model, fallback_model_fc
|
||||||
|
)
|
||||||
self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
|
self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
|
||||||
self.assertEqual(bound_model_fc.tool_choice, "test_tool")
|
self.assertEqual(bound_model_fc.tool_choice, "test_tool")
|
||||||
# Test with prompt type
|
# Test with prompt type
|
||||||
fallback_model_prompt = {"type": "prompt"}
|
fallback_model_prompt = {"type": "prompt"}
|
||||||
bound_model_prompt = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_prompt)
|
bound_model_prompt = self.fallback_handler._bind_tool_model(
|
||||||
|
dummy_model, fallback_model_prompt
|
||||||
|
)
|
||||||
self.assertTrue(bound_model_prompt.tool_choice is None)
|
self.assertTrue(bound_model_prompt.tool_choice is None)
|
||||||
|
|
||||||
def test_invoke_fallback(self):
|
def test_invoke_fallback(self):
|
||||||
from unittest.mock import patch
|
|
||||||
import os
|
import os
|
||||||
import ra_aid.llm as llm
|
from unittest.mock import patch
|
||||||
|
|
||||||
# Successful fallback scenario with proper API key set
|
# Successful fallback scenario with proper API key set
|
||||||
with patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), \
|
with (
|
||||||
patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \
|
patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}),
|
||||||
patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), \
|
patch(
|
||||||
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm:
|
"ra_aid.fallback_handler.supported_top_tool_models",
|
||||||
|
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
|
||||||
|
),
|
||||||
|
patch("ra_aid.fallback_handler.validate_provider_env", return_value=True),
|
||||||
|
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
|
||||||
|
):
|
||||||
|
|
||||||
class DummyModel:
|
class DummyModel:
|
||||||
def bind_tools(self, tools, tool_choice=None):
|
def bind_tools(self, tools, tool_choice=None):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_retry(self, stop_after_attempt):
|
def with_retry(self, stop_after_attempt):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def invoke(self, msg_list):
|
def invoke(self, msg_list):
|
||||||
return DummyResponse()
|
return DummyResponse()
|
||||||
|
|
||||||
class DummyResponse:
|
class DummyResponse:
|
||||||
additional_kwargs = {"tool_calls": [{"id": "1", "type": "test", "function": {"name": "dummy_tool", "arguments": "{\"a\":1}"}}]}
|
additional_kwargs = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"type": "test",
|
||||||
|
"function": {"name": "dummy_tool", "arguments": '{"a":1}'},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
def dummy_initialize_llm(provider, model_name):
|
def dummy_initialize_llm(provider, model_name):
|
||||||
return DummyModel()
|
return DummyModel()
|
||||||
|
|
||||||
mock_init_llm.side_effect = dummy_initialize_llm
|
mock_init_llm.side_effect = dummy_initialize_llm
|
||||||
|
|
||||||
# Set current tool for fallback
|
# Set current tool for fallback
|
||||||
class DummyTool:
|
class DummyTool:
|
||||||
def invoke(self, args):
|
def invoke(self, args):
|
||||||
return "tool_result"
|
return "tool_result"
|
||||||
|
|
||||||
self.fallback_handler.current_tool_to_bind = DummyTool()
|
self.fallback_handler.current_tool_to_bind = DummyTool()
|
||||||
self.fallback_handler.current_failing_tool_name = "dummy_tool"
|
self.fallback_handler.current_failing_tool_name = "dummy_tool"
|
||||||
# Add dummy tool for lookup in invoke_prompt_tool_call
|
# Add dummy tool for lookup in invoke_prompt_tool_call
|
||||||
|
|
@ -194,39 +224,57 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"})
|
result = self.fallback_handler.invoke_fallback(
|
||||||
|
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
||||||
|
)
|
||||||
self.assertIsInstance(result, list)
|
self.assertIsInstance(result, list)
|
||||||
self.assertEqual(result[1], "tool_result")
|
self.assertEqual(result[1], "tool_result")
|
||||||
|
|
||||||
# Failed fallback scenario due to missing API key (simulate by empty environment)
|
# Failed fallback scenario due to missing API key (simulate by empty environment)
|
||||||
with patch.dict(os.environ, {}, clear=True), \
|
with (
|
||||||
patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \
|
patch.dict(os.environ, {}, clear=True),
|
||||||
patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), \
|
patch(
|
||||||
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm:
|
"ra_aid.fallback_handler.supported_top_tool_models",
|
||||||
|
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
|
||||||
|
),
|
||||||
|
patch("ra_aid.fallback_handler.validate_provider_env", return_value=False),
|
||||||
|
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
|
||||||
|
):
|
||||||
|
|
||||||
class FailingDummyModel:
|
class FailingDummyModel:
|
||||||
def bind_tools(self, tools, tool_choice=None):
|
def bind_tools(self, tools, tool_choice=None):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_retry(self, stop_after_attempt):
|
def with_retry(self, stop_after_attempt):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def invoke(self, msg_list):
|
def invoke(self, msg_list):
|
||||||
raise Exception("API key missing")
|
raise Exception("API key missing")
|
||||||
|
|
||||||
def failing_initialize_llm(provider, model_name):
|
def failing_initialize_llm(provider, model_name):
|
||||||
return FailingDummyModel()
|
return FailingDummyModel()
|
||||||
|
|
||||||
mock_init_llm.side_effect = failing_initialize_llm
|
mock_init_llm.side_effect = failing_initialize_llm
|
||||||
fallback_result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"})
|
fallback_result = self.fallback_handler.invoke_fallback(
|
||||||
|
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
||||||
|
)
|
||||||
self.assertIsNone(fallback_result)
|
self.assertIsNone(fallback_result)
|
||||||
|
|
||||||
# Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail
|
# Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail
|
||||||
# Set failure count to trigger the fallback attempt in attempt_fallback
|
# Set failure count to trigger the fallback attempt in attempt_fallback
|
||||||
from ra_aid.exceptions import FallbackToolExecutionError
|
from ra_aid.exceptions import FallbackToolExecutionError
|
||||||
self.fallback_handler.tool_failure_consecutive_failures = self.fallback_handler.max_failures
|
|
||||||
|
self.fallback_handler.tool_failure_consecutive_failures = (
|
||||||
|
self.fallback_handler.max_failures
|
||||||
|
)
|
||||||
with self.assertRaises(FallbackToolExecutionError) as cm:
|
with self.assertRaises(FallbackToolExecutionError) as cm:
|
||||||
self.fallback_handler.attempt_fallback()
|
self.fallback_handler.attempt_fallback()
|
||||||
self.assertIn("All fallback models have failed", str(cm.exception))
|
self.assertIn("All fallback models have failed", str(cm.exception))
|
||||||
|
|
||||||
def test_construct_prompt_msg_list(self):
|
def test_construct_prompt_msg_list(self):
|
||||||
msgs = self.fallback_handler.construct_prompt_msg_list()
|
msgs = self.fallback_handler.construct_prompt_msg_list()
|
||||||
from ra_aid.fallback_handler import SystemMessage, HumanMessage
|
from ra_aid.fallback_handler import HumanMessage, SystemMessage
|
||||||
|
|
||||||
self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
|
self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
|
||||||
self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))
|
self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))
|
||||||
# Test with failed_messages added
|
# Test with failed_messages added
|
||||||
|
|
@ -238,13 +286,17 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
# Create dummy tool function
|
# Create dummy tool function
|
||||||
def dummy_tool_func(args):
|
def dummy_tool_func(args):
|
||||||
return "invoked_result"
|
return "invoked_result"
|
||||||
|
|
||||||
dummy_tool_func.__name__ = "dummy_tool"
|
dummy_tool_func.__name__ = "dummy_tool"
|
||||||
|
|
||||||
# Create wrapper class
|
# Create wrapper class
|
||||||
class DummyToolWrapper:
|
class DummyToolWrapper:
|
||||||
def __init__(self, func):
|
def __init__(self, func):
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
def invoke(self, args):
|
def invoke(self, args):
|
||||||
return self.func(args)
|
return self.func(args)
|
||||||
|
|
||||||
dummy_wrapper = DummyToolWrapper(dummy_tool_func)
|
dummy_wrapper = DummyToolWrapper(dummy_tool_func)
|
||||||
self.fallback_handler.tools = [dummy_wrapper]
|
self.fallback_handler.tools = [dummy_wrapper]
|
||||||
tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
|
tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
|
||||||
|
|
@ -255,9 +307,13 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
dummy_tool_call = {
|
dummy_tool_call = {
|
||||||
"id": "123",
|
"id": "123",
|
||||||
"type": "test",
|
"type": "test",
|
||||||
"function": {"name": "dummy_tool", "arguments": "{\"x\":42}"}
|
"function": {"name": "dummy_tool", "arguments": '{"x":42}'},
|
||||||
}
|
}
|
||||||
DummyResponse = type("DummyResponse", (), {"additional_kwargs": {"tool_calls": [dummy_tool_call]}})
|
DummyResponse = type(
|
||||||
|
"DummyResponse",
|
||||||
|
(),
|
||||||
|
{"additional_kwargs": {"tool_calls": [dummy_tool_call]}},
|
||||||
|
)
|
||||||
result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
|
result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
|
||||||
self.assertEqual(result["id"], "123")
|
self.assertEqual(result["id"], "123")
|
||||||
self.assertEqual(result["name"], "dummy_tool")
|
self.assertEqual(result["name"], "dummy_tool")
|
||||||
|
|
@ -285,22 +341,30 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
|
|
||||||
def test_handle_failure_response(self):
|
def test_handle_failure_response(self):
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
||||||
def dummy_handle_failure(error, agent):
|
def dummy_handle_failure(error, agent):
|
||||||
return ["fallback_response"]
|
return ["fallback_response"]
|
||||||
self.fallback_handler.handle_failure = dummy_handle_failure
|
|
||||||
response = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React")
|
|
||||||
from ra_aid.fallback_handler import SystemMessage
|
|
||||||
self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
|
|
||||||
response_non = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other")
|
|
||||||
self.assertIsNone(response_non)
|
|
||||||
|
|
||||||
|
self.fallback_handler.handle_failure = dummy_handle_failure
|
||||||
|
response = self.fallback_handler.handle_failure_response(
|
||||||
|
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
|
||||||
|
)
|
||||||
|
from ra_aid.fallback_handler import SystemMessage
|
||||||
|
|
||||||
|
self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
|
||||||
|
response_non = self.fallback_handler.handle_failure_response(
|
||||||
|
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
|
||||||
|
)
|
||||||
|
self.assertIsNone(response_non)
|
||||||
|
|
||||||
def test_init_msg_list_non_overlapping(self):
|
def test_init_msg_list_non_overlapping(self):
|
||||||
# Test when the first two and last two messages do not overlap.
|
# Test when the first two and last two messages do not overlap.
|
||||||
full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
|
full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
|
||||||
self.fallback_handler.init_msg_list(full_list)
|
self.fallback_handler.init_msg_list(full_list)
|
||||||
# Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
|
# Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
|
||||||
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"])
|
self.assertEqual(
|
||||||
|
self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
|
||||||
|
)
|
||||||
|
|
||||||
def test_init_msg_list_with_overlap(self):
|
def test_init_msg_list_with_overlap(self):
|
||||||
# Test when the last two messages overlap with the first two.
|
# Test when the last two messages overlap with the first two.
|
||||||
|
|
@ -309,5 +373,6 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
# Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
|
# Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
|
||||||
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
|
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,9 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
|
||||||
|
|
||||||
def test_initialize_expert_unsupported_provider(clean_env):
|
def test_initialize_expert_unsupported_provider(clean_env):
|
||||||
"""Test error handling for unsupported provider in expert mode."""
|
"""Test error handling for unsupported provider in expert mode."""
|
||||||
with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
|
with pytest.raises(
|
||||||
|
ValueError, match=r"Missing required environment variable for provider: unknown"
|
||||||
|
):
|
||||||
initialize_expert_llm("unknown", "model")
|
initialize_expert_llm("unknown", "model")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -197,7 +199,10 @@ def test_initialize_unsupported_provider(clean_env):
|
||||||
"""Test initialization with unsupported provider raises ValueError"""
|
"""Test initialization with unsupported provider raises ValueError"""
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
initialize_llm("unsupported", "model")
|
initialize_llm("unsupported", "model")
|
||||||
assert str(exc_info.value) == "Missing required environment variable for provider: unsupported"
|
assert (
|
||||||
|
str(exc_info.value)
|
||||||
|
== "Missing required environment variable for provider: unsupported"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):
|
def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue