support think tag

This commit is contained in:
AI Christianson 2025-03-07 20:29:37 -05:00 committed by Will
parent 26ecc05de8
commit 88538d92fc
3 changed files with 152 additions and 5 deletions

View File

@ -13,13 +13,16 @@ from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
from ra_aid.exceptions import ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.prompts.ciayn_prompts import CIAYN_AGENT_SYSTEM_PROMPT, CIAYN_AGENT_HUMAN_PROMPT, EXTRACT_TOOL_CALL_PROMPT, NO_TOOL_CALL_PROMPT
from ra_aid.tools.expert import get_model
from ra_aid.tools.reflection import get_function_info
from ra_aid.console.output import cpm
from ra_aid.console.formatting import print_warning, print_error
from ra_aid.console.formatting import print_warning, print_error, console
from ra_aid.agent_context import should_exit
from ra_aid.text import extract_think_tag
from rich.panel import Panel
from rich.markdown import Markdown
logger = get_logger(__name__)
@ -620,8 +623,20 @@ class CiaynAgent:
self.chat_history.append(HumanMessage(content=base_prompt))
full_history = self._trim_chat_history(initial_messages, self.chat_history)
response = self.model.invoke([self.sys_message] + full_history)
print("RESPONSE")
print(response.content)
# Check if model supports think tags
provider = self.config.get("provider", "")
model_name = self.config.get("model", "")
model_config = models_params.get(provider, {}).get(model_name, {})
supports_think_tag = model_config.get("supports_think_tag", False)
supports_thinking = model_config.get("supports_thinking", False)
# Extract think tags if supported
if supports_think_tag or supports_thinking:
think_content, remaining_text = extract_think_tag(response.content)
if think_content:
# console.print(Panel(Markdown(think_content), title="💭 Thoughts"))
response.content = remaining_text
# Check if the response is empty or doesn't contain a valid tool call
if not response.content or not response.content.strip():

View File

@ -168,7 +168,7 @@ models_params = {
"openai-compatible": {
"qwen-qwq-32b": {
"token_limit": 130000,
"think_tag": True,
"supports_think_tag": True,
"supports_temperature": True,
"latency_coefficient": DEFAULT_BASE_LATENCY,
"max_tokens": 64000,

View File

@ -0,0 +1,132 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
@pytest.fixture
def mock_get_model():
"""Mock the get_model function to avoid database connection issues."""
with patch("ra_aid.agent_backends.ciayn_agent.get_model") as mock:
mock.return_value = MagicMock()
yield mock
def test_stream_supports_think_tag(mock_get_model):
"""Test that CiaynAgent.stream extracts think tags when the model supports them."""
# Setup mock model
mock_model = MagicMock()
mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response")
mock_model.invoke.return_value = mock_response
# Setup agent with config that supports think tags
config = {
"provider": "openai-compatible",
"model": "qwen-qwq-32b"
}
agent = CiaynAgent(mock_model, [], config=config)
# Mock print_warning and print_error to avoid unwanted console output
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
# We're not patching console.print to verify it's called with the panel
# Mock _execute_tool to avoid actually executing tools
with patch.object(agent, "_execute_tool") as mock_execute:
mock_execute.return_value = "Tool result"
# For console.print, we want to verify it's called, but not actually print anything
with patch("rich.console.Console.print") as mock_console_print:
# Call stream method
next(agent.stream({"messages": []}, {}))
# Verify console.print was called
mock_console_print.assert_called()
# Check if the response content was updated to remove the think tag
assert "Actual response" in mock_execute.call_args[0][0].content
assert "<think>" not in mock_execute.call_args[0][0].content
def test_stream_no_think_tag_support(mock_get_model):
"""Test that CiaynAgent.stream doesn't extract think tags when not supported."""
# Setup mock model
mock_model = MagicMock()
mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response")
mock_model.invoke.return_value = mock_response
# Setup agent with config that doesn't support think tags
config = {
"provider": "openai",
"model": "gpt-4"
}
agent = CiaynAgent(mock_model, [], config=config)
# Mock print_warning and print_error to avoid unwanted console output
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
# Mock _execute_tool to avoid actually executing tools
with patch.object(agent, "_execute_tool") as mock_execute:
mock_execute.return_value = "Tool result"
# For console.print, we want to patch it to verify Panel with title="💭 Thoughts" is not used
with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel:
# Call stream method
next(agent.stream({"messages": []}, {}))
# Verify panel was not created with '💭 Thoughts' title
thoughts_panel_call = None
for call in mock_panel.call_args_list:
args, kwargs = call
if kwargs.get("title") == "💭 Thoughts":
thoughts_panel_call = call
break
assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been"
# Check that the response content was not modified
assert "<think>These are my thoughts</think>Actual response" in mock_execute.call_args[0][0].content
def test_stream_with_no_think_tags(mock_get_model):
"""Test that CiaynAgent.stream works properly when no think tags are present."""
# Setup mock model
mock_model = MagicMock()
mock_response = AIMessage(content="Actual response without tags")
mock_model.invoke.return_value = mock_response
# Setup agent with config that supports think tags
config = {
"provider": "openai-compatible",
"model": "qwen-qwq-32b"
}
agent = CiaynAgent(mock_model, [], config=config)
# Mock print_warning and print_error to avoid unwanted console output
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
# Mock _execute_tool to avoid actually executing tools
with patch.object(agent, "_execute_tool") as mock_execute:
mock_execute.return_value = "Tool result"
# For console.print, we want to verify it's not called with a thoughts panel
with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel:
# Call stream method
next(agent.stream({"messages": []}, {}))
# Verify panel was not created with '💭 Thoughts' title
thoughts_panel_call = None
for call in mock_panel.call_args_list:
args, kwargs = call
if kwargs.get("title") == "💭 Thoughts":
thoughts_panel_call = call
break
assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been"
# Check that the response content was not modified
assert "Actual response without tags" in mock_execute.call_args[0][0].content