support think tag
This commit is contained in:
parent
26ecc05de8
commit
88538d92fc
|
|
@ -13,13 +13,16 @@ from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
|||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.prompts.ciayn_prompts import CIAYN_AGENT_SYSTEM_PROMPT, CIAYN_AGENT_HUMAN_PROMPT, EXTRACT_TOOL_CALL_PROMPT, NO_TOOL_CALL_PROMPT
|
||||
from ra_aid.tools.expert import get_model
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.console.formatting import print_warning, print_error
|
||||
from ra_aid.console.formatting import print_warning, print_error, console
|
||||
from ra_aid.agent_context import should_exit
|
||||
from ra_aid.text import extract_think_tag
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -620,8 +623,20 @@ class CiaynAgent:
|
|||
self.chat_history.append(HumanMessage(content=base_prompt))
|
||||
full_history = self._trim_chat_history(initial_messages, self.chat_history)
|
||||
response = self.model.invoke([self.sys_message] + full_history)
|
||||
print("RESPONSE")
|
||||
print(response.content)
|
||||
|
||||
# Check if model supports think tags
|
||||
provider = self.config.get("provider", "")
|
||||
model_name = self.config.get("model", "")
|
||||
model_config = models_params.get(provider, {}).get(model_name, {})
|
||||
supports_think_tag = model_config.get("supports_think_tag", False)
|
||||
supports_thinking = model_config.get("supports_thinking", False)
|
||||
|
||||
# Extract think tags if supported
|
||||
if supports_think_tag or supports_thinking:
|
||||
think_content, remaining_text = extract_think_tag(response.content)
|
||||
if think_content:
|
||||
# console.print(Panel(Markdown(think_content), title="💭 Thoughts"))
|
||||
response.content = remaining_text
|
||||
|
||||
# Check if the response is empty or doesn't contain a valid tool call
|
||||
if not response.content or not response.content.strip():
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ models_params = {
|
|||
"openai-compatible": {
|
||||
"qwen-qwq-32b": {
|
||||
"token_limit": 130000,
|
||||
"think_tag": True,
|
||||
"supports_think_tag": True,
|
||||
"supports_temperature": True,
|
||||
"latency_coefficient": DEFAULT_BASE_LATENCY,
|
||||
"max_tokens": 64000,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_model():
|
||||
"""Mock the get_model function to avoid database connection issues."""
|
||||
with patch("ra_aid.agent_backends.ciayn_agent.get_model") as mock:
|
||||
mock.return_value = MagicMock()
|
||||
yield mock
|
||||
|
||||
|
||||
def test_stream_supports_think_tag(mock_get_model):
|
||||
"""Test that CiaynAgent.stream extracts think tags when the model supports them."""
|
||||
# Setup mock model
|
||||
mock_model = MagicMock()
|
||||
mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response")
|
||||
mock_model.invoke.return_value = mock_response
|
||||
|
||||
# Setup agent with config that supports think tags
|
||||
config = {
|
||||
"provider": "openai-compatible",
|
||||
"model": "qwen-qwq-32b"
|
||||
}
|
||||
agent = CiaynAgent(mock_model, [], config=config)
|
||||
|
||||
# Mock print_warning and print_error to avoid unwanted console output
|
||||
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
|
||||
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
|
||||
|
||||
# We're not patching console.print to verify it's called with the panel
|
||||
# Mock _execute_tool to avoid actually executing tools
|
||||
with patch.object(agent, "_execute_tool") as mock_execute:
|
||||
mock_execute.return_value = "Tool result"
|
||||
|
||||
# For console.print, we want to verify it's called, but not actually print anything
|
||||
with patch("rich.console.Console.print") as mock_console_print:
|
||||
# Call stream method
|
||||
next(agent.stream({"messages": []}, {}))
|
||||
|
||||
# Verify console.print was called
|
||||
mock_console_print.assert_called()
|
||||
|
||||
# Check if the response content was updated to remove the think tag
|
||||
assert "Actual response" in mock_execute.call_args[0][0].content
|
||||
assert "<think>" not in mock_execute.call_args[0][0].content
|
||||
|
||||
|
||||
def test_stream_no_think_tag_support(mock_get_model):
|
||||
"""Test that CiaynAgent.stream doesn't extract think tags when not supported."""
|
||||
# Setup mock model
|
||||
mock_model = MagicMock()
|
||||
mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response")
|
||||
mock_model.invoke.return_value = mock_response
|
||||
|
||||
# Setup agent with config that doesn't support think tags
|
||||
config = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4"
|
||||
}
|
||||
agent = CiaynAgent(mock_model, [], config=config)
|
||||
|
||||
# Mock print_warning and print_error to avoid unwanted console output
|
||||
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
|
||||
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
|
||||
|
||||
# Mock _execute_tool to avoid actually executing tools
|
||||
with patch.object(agent, "_execute_tool") as mock_execute:
|
||||
mock_execute.return_value = "Tool result"
|
||||
|
||||
# For console.print, we want to patch it to verify Panel with title="💭 Thoughts" is not used
|
||||
with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel:
|
||||
# Call stream method
|
||||
next(agent.stream({"messages": []}, {}))
|
||||
|
||||
# Verify panel was not created with '💭 Thoughts' title
|
||||
thoughts_panel_call = None
|
||||
for call in mock_panel.call_args_list:
|
||||
args, kwargs = call
|
||||
if kwargs.get("title") == "💭 Thoughts":
|
||||
thoughts_panel_call = call
|
||||
break
|
||||
|
||||
assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been"
|
||||
|
||||
# Check that the response content was not modified
|
||||
assert "<think>These are my thoughts</think>Actual response" in mock_execute.call_args[0][0].content
|
||||
|
||||
|
||||
def test_stream_with_no_think_tags(mock_get_model):
|
||||
"""Test that CiaynAgent.stream works properly when no think tags are present."""
|
||||
# Setup mock model
|
||||
mock_model = MagicMock()
|
||||
mock_response = AIMessage(content="Actual response without tags")
|
||||
mock_model.invoke.return_value = mock_response
|
||||
|
||||
# Setup agent with config that supports think tags
|
||||
config = {
|
||||
"provider": "openai-compatible",
|
||||
"model": "qwen-qwq-32b"
|
||||
}
|
||||
agent = CiaynAgent(mock_model, [], config=config)
|
||||
|
||||
# Mock print_warning and print_error to avoid unwanted console output
|
||||
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
|
||||
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
|
||||
|
||||
# Mock _execute_tool to avoid actually executing tools
|
||||
with patch.object(agent, "_execute_tool") as mock_execute:
|
||||
mock_execute.return_value = "Tool result"
|
||||
|
||||
# For console.print, we want to verify it's not called with a thoughts panel
|
||||
with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel:
|
||||
# Call stream method
|
||||
next(agent.stream({"messages": []}, {}))
|
||||
|
||||
# Verify panel was not created with '💭 Thoughts' title
|
||||
thoughts_panel_call = None
|
||||
for call in mock_panel.call_args_list:
|
||||
args, kwargs = call
|
||||
if kwargs.get("title") == "💭 Thoughts":
|
||||
thoughts_panel_call = call
|
||||
break
|
||||
|
||||
assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been"
|
||||
|
||||
# Check that the response content was not modified
|
||||
assert "Actual response without tags" in mock_execute.call_args[0][0].content
|
||||
Loading…
Reference in New Issue