From 88538d92fcefc72a69af2665f671bf3cc196b7c0 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Fri, 7 Mar 2025 20:29:37 -0500 Subject: [PATCH] support think tag --- ra_aid/agent_backends/ciayn_agent.py | 23 ++- ra_aid/models_params.py | 2 +- .../test_ciayn_agent_think_tag.py | 132 ++++++++++++++++++ 3 files changed, 152 insertions(+), 5 deletions(-) create mode 100644 tests/ra_aid/agent_backends/test_ciayn_agent_think_tag.py diff --git a/ra_aid/agent_backends/ciayn_agent.py b/ra_aid/agent_backends/ciayn_agent.py index 855eb57..347ac3e 100644 --- a/ra_aid/agent_backends/ciayn_agent.py +++ b/ra_aid/agent_backends/ciayn_agent.py @@ -13,13 +13,16 @@ from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES from ra_aid.exceptions import ToolExecutionError from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger -from ra_aid.models_params import DEFAULT_TOKEN_LIMIT +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params from ra_aid.prompts.ciayn_prompts import CIAYN_AGENT_SYSTEM_PROMPT, CIAYN_AGENT_HUMAN_PROMPT, EXTRACT_TOOL_CALL_PROMPT, NO_TOOL_CALL_PROMPT from ra_aid.tools.expert import get_model from ra_aid.tools.reflection import get_function_info from ra_aid.console.output import cpm -from ra_aid.console.formatting import print_warning, print_error +from ra_aid.console.formatting import print_warning, print_error, console from ra_aid.agent_context import should_exit +from ra_aid.text import extract_think_tag +from rich.panel import Panel +from rich.markdown import Markdown logger = get_logger(__name__) @@ -620,8 +623,20 @@ class CiaynAgent: self.chat_history.append(HumanMessage(content=base_prompt)) full_history = self._trim_chat_history(initial_messages, self.chat_history) response = self.model.invoke([self.sys_message] + full_history) - print("RESPONSE") - print(response.content) + + # Check if model supports think tags + provider = self.config.get("provider", "") + model_name = self.config.get("model", "") + 
model_config = models_params.get(provider, {}).get(model_name, {}) + supports_think_tag = model_config.get("supports_think_tag", False) + supports_thinking = model_config.get("supports_thinking", False) + + # Extract think tags if supported + if supports_think_tag or supports_thinking: + think_content, remaining_text = extract_think_tag(response.content) + if think_content: + console.print(Panel(Markdown(think_content), title="💭 Thoughts")) + response.content = remaining_text # Check if the response is empty or doesn't contain a valid tool call if not response.content or not response.content.strip(): diff --git a/ra_aid/models_params.py b/ra_aid/models_params.py index 3e08fad..8c3652f 100644 --- a/ra_aid/models_params.py +++ b/ra_aid/models_params.py @@ -168,7 +168,7 @@ models_params = { "openai-compatible": { "qwen-qwq-32b": { "token_limit": 130000, - "think_tag": True, + "supports_think_tag": True, "supports_temperature": True, "latency_coefficient": DEFAULT_BASE_LATENCY, "max_tokens": 64000, diff --git a/tests/ra_aid/agent_backends/test_ciayn_agent_think_tag.py b/tests/ra_aid/agent_backends/test_ciayn_agent_think_tag.py new file mode 100644 index 0000000..302d532 --- /dev/null +++ b/tests/ra_aid/agent_backends/test_ciayn_agent_think_tag.py @@ -0,0 +1,132 @@ +import pytest +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent + + +@pytest.fixture +def mock_get_model(): + """Mock the get_model function to avoid database connection issues.""" + with patch("ra_aid.agent_backends.ciayn_agent.get_model") as mock: + mock.return_value = MagicMock() + yield mock + + +def test_stream_supports_think_tag(mock_get_model): + """Test that CiaynAgent.stream extracts think tags when the model supports them.""" + # Setup mock model + mock_model = MagicMock() + mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response") + 
mock_model.invoke.return_value = mock_response + + # Setup agent with config that supports think tags + config = { + "provider": "openai-compatible", + "model": "qwen-qwq-32b" + } + agent = CiaynAgent(mock_model, [], config=config) + + # Mock print_warning and print_error to avoid unwanted console output + with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \ + patch("ra_aid.agent_backends.ciayn_agent.print_error"): + + # We're not patching console.print to verify it's called with the panel + # Mock _execute_tool to avoid actually executing tools + with patch.object(agent, "_execute_tool") as mock_execute: + mock_execute.return_value = "Tool result" + + # For console.print, we want to verify it's called, but not actually print anything + with patch("rich.console.Console.print") as mock_console_print: + # Call stream method + next(agent.stream({"messages": []}, {})) + + # Verify console.print was called + mock_console_print.assert_called() + + # Check if the response content was updated to remove the think tag + assert "Actual response" in mock_execute.call_args[0][0].content + assert "<think>" not in mock_execute.call_args[0][0].content + + +def test_stream_no_think_tag_support(mock_get_model): + """Test that CiaynAgent.stream doesn't extract think tags when not supported.""" + # Setup mock model + mock_model = MagicMock() + mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response") + mock_model.invoke.return_value = mock_response + + # Setup agent with config that doesn't support think tags + config = { + "provider": "openai", + "model": "gpt-4" + } + agent = CiaynAgent(mock_model, [], config=config) + + # Mock print_warning and print_error to avoid unwanted console output + with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \ + patch("ra_aid.agent_backends.ciayn_agent.print_error"): + + # Mock _execute_tool to avoid actually executing tools + with patch.object(agent, "_execute_tool") as mock_execute: + mock_execute.return_value = 
"Tool result" + + # For console.print, we want to patch it to verify Panel with title="💭 Thoughts" is not used + with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel: + # Call stream method + next(agent.stream({"messages": []}, {})) + + # Verify panel was not created with '💭 Thoughts' title + thoughts_panel_call = None + for call in mock_panel.call_args_list: + args, kwargs = call + if kwargs.get("title") == "💭 Thoughts": + thoughts_panel_call = call + break + + assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been" + + # Check that the response content was not modified + assert "<think>These are my thoughts</think>Actual response" in mock_execute.call_args[0][0].content + + +def test_stream_with_no_think_tags(mock_get_model): + """Test that CiaynAgent.stream works properly when no think tags are present.""" + # Setup mock model + mock_model = MagicMock() + mock_response = AIMessage(content="Actual response without tags") + mock_model.invoke.return_value = mock_response + + # Setup agent with config that supports think tags + config = { + "provider": "openai-compatible", + "model": "qwen-qwq-32b" + } + agent = CiaynAgent(mock_model, [], config=config) + + # Mock print_warning and print_error to avoid unwanted console output + with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \ + patch("ra_aid.agent_backends.ciayn_agent.print_error"): + + # Mock _execute_tool to avoid actually executing tools + with patch.object(agent, "_execute_tool") as mock_execute: + mock_execute.return_value = "Tool result" + + # For console.print, we want to verify it's not called with a thoughts panel + with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel: + # Call stream method + next(agent.stream({"messages": []}, {})) + + # Verify panel was not created with '💭 Thoughts' title + thoughts_panel_call = None + for call in mock_panel.call_args_list: + args, kwargs = call + if kwargs.get("title") == "💭 Thoughts": + 
thoughts_panel_call = call + break + + assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been" + + # Check that the response content was not modified + assert "Actual response without tags" in mock_execute.call_args[0][0].content \ No newline at end of file