From d5e2e0a9a0434bba9c8fbb88c42ac117ea7333af Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Fri, 14 Feb 2025 13:04:41 -0800
Subject: [PATCH] refactor(agent_utils.py, ciayn_agent.py): remove unused
 import cpm to clean up code and improve readability

style(tests): format code for better readability and consistency in test files

test(tests): update assertions and test cases for better clarity and maintainability
---
 ra_aid/agent_utils.py                 |   2 +-
 ra_aid/agents/ciayn_agent.py          |   1 -
 tests/ra_aid/test_agent_utils.py      |  19 ++--
 tests/ra_aid/test_ciayn_agent.py      |   1 -
 tests/ra_aid/test_fallback_handler.py | 135 +++++++++++++++++++-------
 tests/ra_aid/test_llm.py              |   9 +-
 6 files changed, 119 insertions(+), 48 deletions(-)

diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index dd46fe8..8b1708b 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -31,7 +31,7 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.agents_alias import RAgents
 from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
 from ra_aid.console.formatting import print_error, print_stage_header
-from ra_aid.console.output import cpm, print_agent_output
+from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import (
     AgentInterrupt,
     FallbackToolExecutionError,
diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
index 077d01c..18907ce 100644
--- a/ra_aid/agents/ciayn_agent.py
+++ b/ra_aid/agents/ciayn_agent.py
@@ -7,7 +7,6 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
 from langchain_core.tools import BaseTool
 
 from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
-from ra_aid.console.output import cpm
 from ra_aid.exceptions import ToolExecutionError
 from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.logging_config import get_logger
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index 082c140..5346921 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
 import litellm
 import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
 from ra_aid.agent_utils import (
     AgentState,
@@ -128,9 +128,10 @@ def test_create_agent_openai(mock_model, mock_memory):
 
         assert agent == "ciayn_agent"
         mock_ciayn.assert_called_once_with(
-            mock_model, [],
+            mock_model,
+            [],
             max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
-            config={'provider': 'openai', 'model': 'gpt-4'}
+            config={"provider": "openai", "model": "gpt-4"},
         )
 
 
@@ -144,9 +145,10 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
 
         assert agent == "ciayn_agent"
         mock_ciayn.assert_called_once_with(
-            mock_model, [],
+            mock_model,
+            [],
             max_tokens=DEFAULT_TOKEN_LIMIT,
-            config={'provider': 'unknown', 'model': 'unknown-model'}
+            config={"provider": "unknown", "model": "unknown-model"},
         )
 
 
@@ -163,7 +165,7 @@ def test_create_agent_missing_config(mock_model, mock_memory):
             mock_model,
             [],
             max_tokens=DEFAULT_TOKEN_LIMIT,
-            config={'provider': 'openai'}
+            config={"provider": "openai"},
         )
 
 
@@ -207,9 +209,10 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
 
         assert agent == "ciayn_agent"
         mock_ciayn.assert_called_once_with(
-            mock_model, [],
+            mock_model,
+            [],
             max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
-            config={'provider': 'openai', 'model': 'gpt-4'}
+            config={"provider": "openai", "model": "gpt-4"},
         )
 
 
diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py
index 896db33..4cd9dc2 100644
--- a/tests/ra_aid/test_ciayn_agent.py
+++ b/tests/ra_aid/test_ciayn_agent.py
@@ -26,7 +26,6 @@ class DummyTool:
 
 class DummyModel:
     def invoke(self, _messages: list[BaseMessage]):
-
         return AIMessage("dummy_tool()")
 
     def bind_tools(self, tools, tool_choice):
diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py
index 3273391..6fa325d 100644
--- a/tests/ra_aid/test_fallback_handler.py
+++ b/tests/ra_aid/test_fallback_handler.py
@@ -86,6 +86,7 @@ class TestFallbackHandler(unittest.TestCase):
 
     def test_load_fallback_tool_models(self):
         import ra_aid.fallback_handler as fh
+
         original_supported = fh.supported_top_tool_models
         fh.supported_top_tool_models = [
             {"provider": "dummy", "model": "dummy_model", "type": "prompt"}
@@ -95,13 +96,16 @@ class TestFallbackHandler(unittest.TestCase):
         fh.supported_top_tool_models = original_supported
 
     def test_extract_failed_tool_name(self):
-        from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError
+        from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
+
         # Case when tool_name is provided
-        error1 = ToolExecutionError("Error", base_message="dummy", tool_name="dummy_tool")
+        error1 = ToolExecutionError(
+            "Error", base_message="dummy", tool_name="dummy_tool"
+        )
         name1 = self.fallback_handler.extract_failed_tool_name(error1)
         self.assertEqual(name1, "dummy_tool")
         # Case when tool_name is not provided but regex works
-        error2 = ToolExecutionError("error with name=\"test_tool\"")
+        error2 = ToolExecutionError('error with name="test_tool"')
         name2 = self.fallback_handler.extract_failed_tool_name(error2)
         self.assertEqual(name2, "test_tool")
         # Case when regex fails and exception is raised
@@ -110,16 +114,13 @@ class TestFallbackHandler(unittest.TestCase):
         error3 = ToolExecutionError("no tool name")
         with self.assertRaises(FallbackToolExecutionError):
             self.fallback_handler.extract_failed_tool_name(error3)
 
     def test_find_tool_to_bind(self):
-        # Create a dummy tool to be found
-        class DummyTool:
-            def invoke(self, args):
-                return "result"
         class DummyWrapper:
             def __init__(self, func):
                 self.func = func
-        def dummy_func(args):
+
+        def dummy_func(_args):
             return "result"
-        dummy_tool = DummyTool()
+
         dummy_wrapper = DummyWrapper(dummy_func)
         self.agent.tools.append(dummy_wrapper)
         tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__)
@@ -134,53 +135,82 @@ class TestFallbackHandler(unittest.TestCase):
                 self.tools = tools
                 self.tool_choice = tool_choice
                 return self
+
             def with_retry(self, stop_after_attempt):
                 return self
+
             def invoke(self, msg_list):
                 return "dummy_response"
+
         dummy_model = DummyModel()
+
         # Set current tool for binding
         class DummyTool:
             def invoke(self, args):
                 return "result"
+
         self.fallback_handler.current_tool_to_bind = DummyTool()
         self.fallback_handler.current_failing_tool_name = "test_tool"
         # Test with force calling ("fc") type
         fallback_model_fc = {"type": "fc"}
-        bound_model_fc = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_fc)
+        bound_model_fc = self.fallback_handler._bind_tool_model(
+            dummy_model, fallback_model_fc
+        )
         self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
         self.assertEqual(bound_model_fc.tool_choice, "test_tool")
         # Test with prompt type
         fallback_model_prompt = {"type": "prompt"}
-        bound_model_prompt = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_prompt)
+        bound_model_prompt = self.fallback_handler._bind_tool_model(
+            dummy_model, fallback_model_prompt
+        )
         self.assertTrue(bound_model_prompt.tool_choice is None)
 
     def test_invoke_fallback(self):
-        from unittest.mock import patch
         import os
-        import ra_aid.llm as llm
+        from unittest.mock import patch
 
         # Successful fallback scenario with proper API key set
-        with patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), \
-            patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \
-            patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), \
-            patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm:
+        with (
+            patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}),
+            patch(
+                "ra_aid.fallback_handler.supported_top_tool_models",
+                new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
+            ),
+            patch("ra_aid.fallback_handler.validate_provider_env", return_value=True),
+            patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
+        ):
+
             class DummyModel:
                 def bind_tools(self, tools, tool_choice=None):
                     return self
+
                 def with_retry(self, stop_after_attempt):
                     return self
+
                 def invoke(self, msg_list):
                     return DummyResponse()
+
             class DummyResponse:
-                additional_kwargs = {"tool_calls": [{"id": "1", "type": "test", "function": {"name": "dummy_tool", "arguments": "{\"a\":1}"}}]}
+                additional_kwargs = {
+                    "tool_calls": [
+                        {
+                            "id": "1",
+                            "type": "test",
+                            "function": {"name": "dummy_tool", "arguments": '{"a":1}'},
+                        }
+                    ]
+                }
+
             def dummy_initialize_llm(provider, model_name):
                 return DummyModel()
+
             mock_init_llm.side_effect = dummy_initialize_llm
+
             # Set current tool for fallback
             class DummyTool:
                 def invoke(self, args):
                     return "tool_result"
+
             self.fallback_handler.current_tool_to_bind = DummyTool()
             self.fallback_handler.current_failing_tool_name = "dummy_tool"
             # Add dummy tool for lookup in invoke_prompt_tool_call
             self.fallback_handler.tools.append(
                 DummyToolWrapper(
                     dummy_tool_func,
                     {
                         "name": "dummy_tool",
@@ -194,39 +224,57 @@ class TestFallbackHandler(unittest.TestCase):
                     },
                 )
             )
-            result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"})
+            result = self.fallback_handler.invoke_fallback(
+                {"provider": "dummy", "model": "dummy_model", "type": "prompt"}
+            )
             self.assertIsInstance(result, list)
             self.assertEqual(result[1], "tool_result")
 
         # Failed fallback scenario due to missing API key (simulate by empty environment)
-        with patch.dict(os.environ, {}, clear=True), \
-            patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \
-            patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), \
-            patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm:
+        with (
+            patch.dict(os.environ, {}, clear=True),
+            patch(
+                "ra_aid.fallback_handler.supported_top_tool_models",
+                new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
+            ),
+            patch("ra_aid.fallback_handler.validate_provider_env", return_value=False),
+            patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
+        ):
+
             class FailingDummyModel:
                 def bind_tools(self, tools, tool_choice=None):
                     return self
+
                 def with_retry(self, stop_after_attempt):
                     return self
+
                 def invoke(self, msg_list):
                     raise Exception("API key missing")
+
             def failing_initialize_llm(provider, model_name):
                 return FailingDummyModel()
+
             mock_init_llm.side_effect = failing_initialize_llm
-            fallback_result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"})
+            fallback_result = self.fallback_handler.invoke_fallback(
+                {"provider": "dummy", "model": "dummy_model", "type": "prompt"}
+            )
             self.assertIsNone(fallback_result)
 
         # Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail
         # Set failure count to trigger the fallback attempt in attempt_fallback
         from ra_aid.exceptions import FallbackToolExecutionError
-        self.fallback_handler.tool_failure_consecutive_failures = self.fallback_handler.max_failures
+
+        self.fallback_handler.tool_failure_consecutive_failures = (
+            self.fallback_handler.max_failures
+        )
         with self.assertRaises(FallbackToolExecutionError) as cm:
             self.fallback_handler.attempt_fallback()
         self.assertIn("All fallback models have failed", str(cm.exception))
 
     def test_construct_prompt_msg_list(self):
         msgs = self.fallback_handler.construct_prompt_msg_list()
-        from ra_aid.fallback_handler import SystemMessage, HumanMessage
+        from ra_aid.fallback_handler import HumanMessage, SystemMessage
+
         self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
         self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))
         # Test with failed_messages added
@@ -238,13 +286,17 @@ class TestFallbackHandler(unittest.TestCase):
         # Create dummy tool function
        def dummy_tool_func(args):
             return "invoked_result"
+
         dummy_tool_func.__name__ = "dummy_tool"
+
         # Create wrapper class
         class DummyToolWrapper:
             def __init__(self, func):
                 self.func = func
+
             def invoke(self, args):
                 return self.func(args)
+
         dummy_wrapper = DummyToolWrapper(dummy_tool_func)
         self.fallback_handler.tools = [dummy_wrapper]
         tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
@@ -255,9 +307,13 @@ class TestFallbackHandler(unittest.TestCase):
         dummy_tool_call = {
             "id": "123",
             "type": "test",
-            "function": {"name": "dummy_tool", "arguments": "{\"x\":42}"}
+            "function": {"name": "dummy_tool", "arguments": '{"x":42}'},
         }
-        DummyResponse = type("DummyResponse", (), {"additional_kwargs": {"tool_calls": [dummy_tool_call]}})
+        DummyResponse = type(
+            "DummyResponse",
+            (),
+            {"additional_kwargs": {"tool_calls": [dummy_tool_call]}},
+        )
         result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
         self.assertEqual(result["id"], "123")
         self.assertEqual(result["name"], "dummy_tool")
@@ -285,22 +341,30 @@ class TestFallbackHandler(unittest.TestCase):
 
     def test_handle_failure_response(self):
         from ra_aid.exceptions import ToolExecutionError
+
         def dummy_handle_failure(error, agent):
             return ["fallback_response"]
-        self.fallback_handler.handle_failure = dummy_handle_failure
-        response = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React")
-        from ra_aid.fallback_handler import SystemMessage
-        self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
-        response_non = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other")
-        self.assertIsNone(response_non)
+        self.fallback_handler.handle_failure = dummy_handle_failure
+        response = self.fallback_handler.handle_failure_response(
+            ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
+        )
+        from ra_aid.fallback_handler import SystemMessage
+
+        self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
+        response_non = self.fallback_handler.handle_failure_response(
+            ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
+        )
+        self.assertIsNone(response_non)
 
     def test_init_msg_list_non_overlapping(self):
         # Test when the first two and last two messages do not overlap.
         full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
         self.fallback_handler.init_msg_list(full_list)
         # Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
-        self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"])
+        self.assertEqual(
+            self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
+        )
 
     def test_init_msg_list_with_overlap(self):
         # Test when the last two messages overlap with the first two.
         full_list = ["msg1", "msg2", "msg1", "msg3"]
@@ -309,5 +373,6 @@ class TestFallbackHandler(unittest.TestCase):
         self.fallback_handler.init_msg_list(full_list)
         # Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
         self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
 
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index 7f96ed4..4ea9415 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -121,7 +121,9 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
 
 def test_initialize_expert_unsupported_provider(clean_env):
     """Test error handling for unsupported provider in expert mode."""
-    with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
+    with pytest.raises(
+        ValueError, match=r"Missing required environment variable for provider: unknown"
+    ):
         initialize_expert_llm("unknown", "model")
 
 
@@ -197,7 +199,10 @@ def test_initialize_unsupported_provider(clean_env):
     """Test initialization with unsupported provider raises ValueError"""
     with pytest.raises(ValueError) as exc_info:
         initialize_llm("unsupported", "model")
-    assert str(exc_info.value) == "Missing required environment variable for provider: unsupported"
+    assert (
+        str(exc_info.value)
+        == "Missing required environment variable for provider: unsupported"
+    )
 
 
 def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):