support think tag
This commit is contained in:
parent
26ecc05de8
commit
88538d92fc
|
|
@ -13,13 +13,16 @@ from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||||
from ra_aid.prompts.ciayn_prompts import CIAYN_AGENT_SYSTEM_PROMPT, CIAYN_AGENT_HUMAN_PROMPT, EXTRACT_TOOL_CALL_PROMPT, NO_TOOL_CALL_PROMPT
|
from ra_aid.prompts.ciayn_prompts import CIAYN_AGENT_SYSTEM_PROMPT, CIAYN_AGENT_HUMAN_PROMPT, EXTRACT_TOOL_CALL_PROMPT, NO_TOOL_CALL_PROMPT
|
||||||
from ra_aid.tools.expert import get_model
|
from ra_aid.tools.expert import get_model
|
||||||
from ra_aid.tools.reflection import get_function_info
|
from ra_aid.tools.reflection import get_function_info
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
from ra_aid.console.formatting import print_warning, print_error
|
from ra_aid.console.formatting import print_warning, print_error, console
|
||||||
from ra_aid.agent_context import should_exit
|
from ra_aid.agent_context import should_exit
|
||||||
|
from ra_aid.text import extract_think_tag
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -620,8 +623,20 @@ class CiaynAgent:
|
||||||
self.chat_history.append(HumanMessage(content=base_prompt))
|
self.chat_history.append(HumanMessage(content=base_prompt))
|
||||||
full_history = self._trim_chat_history(initial_messages, self.chat_history)
|
full_history = self._trim_chat_history(initial_messages, self.chat_history)
|
||||||
response = self.model.invoke([self.sys_message] + full_history)
|
response = self.model.invoke([self.sys_message] + full_history)
|
||||||
print("RESPONSE")
|
|
||||||
print(response.content)
|
# Check if model supports think tags
|
||||||
|
provider = self.config.get("provider", "")
|
||||||
|
model_name = self.config.get("model", "")
|
||||||
|
model_config = models_params.get(provider, {}).get(model_name, {})
|
||||||
|
supports_think_tag = model_config.get("supports_think_tag", False)
|
||||||
|
supports_thinking = model_config.get("supports_thinking", False)
|
||||||
|
|
||||||
|
# Extract think tags if supported
|
||||||
|
if supports_think_tag or supports_thinking:
|
||||||
|
think_content, remaining_text = extract_think_tag(response.content)
|
||||||
|
if think_content:
|
||||||
|
# console.print(Panel(Markdown(think_content), title="💭 Thoughts"))
|
||||||
|
response.content = remaining_text
|
||||||
|
|
||||||
# Check if the response is empty or doesn't contain a valid tool call
|
# Check if the response is empty or doesn't contain a valid tool call
|
||||||
if not response.content or not response.content.strip():
|
if not response.content or not response.content.strip():
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ models_params = {
|
||||||
"openai-compatible": {
|
"openai-compatible": {
|
||||||
"qwen-qwq-32b": {
|
"qwen-qwq-32b": {
|
||||||
"token_limit": 130000,
|
"token_limit": 130000,
|
||||||
"think_tag": True,
|
"supports_think_tag": True,
|
||||||
"supports_temperature": True,
|
"supports_temperature": True,
|
||||||
"latency_coefficient": DEFAULT_BASE_LATENCY,
|
"latency_coefficient": DEFAULT_BASE_LATENCY,
|
||||||
"max_tokens": 64000,
|
"max_tokens": 64000,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,132 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_model():
|
||||||
|
"""Mock the get_model function to avoid database connection issues."""
|
||||||
|
with patch("ra_aid.agent_backends.ciayn_agent.get_model") as mock:
|
||||||
|
mock.return_value = MagicMock()
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_supports_think_tag(mock_get_model):
|
||||||
|
"""Test that CiaynAgent.stream extracts think tags when the model supports them."""
|
||||||
|
# Setup mock model
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response")
|
||||||
|
mock_model.invoke.return_value = mock_response
|
||||||
|
|
||||||
|
# Setup agent with config that supports think tags
|
||||||
|
config = {
|
||||||
|
"provider": "openai-compatible",
|
||||||
|
"model": "qwen-qwq-32b"
|
||||||
|
}
|
||||||
|
agent = CiaynAgent(mock_model, [], config=config)
|
||||||
|
|
||||||
|
# Mock print_warning and print_error to avoid unwanted console output
|
||||||
|
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
|
||||||
|
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
|
||||||
|
|
||||||
|
# We're not patching console.print to verify it's called with the panel
|
||||||
|
# Mock _execute_tool to avoid actually executing tools
|
||||||
|
with patch.object(agent, "_execute_tool") as mock_execute:
|
||||||
|
mock_execute.return_value = "Tool result"
|
||||||
|
|
||||||
|
# For console.print, we want to verify it's called, but not actually print anything
|
||||||
|
with patch("rich.console.Console.print") as mock_console_print:
|
||||||
|
# Call stream method
|
||||||
|
next(agent.stream({"messages": []}, {}))
|
||||||
|
|
||||||
|
# Verify console.print was called
|
||||||
|
mock_console_print.assert_called()
|
||||||
|
|
||||||
|
# Check if the response content was updated to remove the think tag
|
||||||
|
assert "Actual response" in mock_execute.call_args[0][0].content
|
||||||
|
assert "<think>" not in mock_execute.call_args[0][0].content
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_no_think_tag_support(mock_get_model):
|
||||||
|
"""Test that CiaynAgent.stream doesn't extract think tags when not supported."""
|
||||||
|
# Setup mock model
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_response = AIMessage(content="<think>These are my thoughts</think>Actual response")
|
||||||
|
mock_model.invoke.return_value = mock_response
|
||||||
|
|
||||||
|
# Setup agent with config that doesn't support think tags
|
||||||
|
config = {
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "gpt-4"
|
||||||
|
}
|
||||||
|
agent = CiaynAgent(mock_model, [], config=config)
|
||||||
|
|
||||||
|
# Mock print_warning and print_error to avoid unwanted console output
|
||||||
|
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
|
||||||
|
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
|
||||||
|
|
||||||
|
# Mock _execute_tool to avoid actually executing tools
|
||||||
|
with patch.object(agent, "_execute_tool") as mock_execute:
|
||||||
|
mock_execute.return_value = "Tool result"
|
||||||
|
|
||||||
|
# For console.print, we want to patch it to verify Panel with title="💭 Thoughts" is not used
|
||||||
|
with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel:
|
||||||
|
# Call stream method
|
||||||
|
next(agent.stream({"messages": []}, {}))
|
||||||
|
|
||||||
|
# Verify panel was not created with '💭 Thoughts' title
|
||||||
|
thoughts_panel_call = None
|
||||||
|
for call in mock_panel.call_args_list:
|
||||||
|
args, kwargs = call
|
||||||
|
if kwargs.get("title") == "💭 Thoughts":
|
||||||
|
thoughts_panel_call = call
|
||||||
|
break
|
||||||
|
|
||||||
|
assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been"
|
||||||
|
|
||||||
|
# Check that the response content was not modified
|
||||||
|
assert "<think>These are my thoughts</think>Actual response" in mock_execute.call_args[0][0].content
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_with_no_think_tags(mock_get_model):
|
||||||
|
"""Test that CiaynAgent.stream works properly when no think tags are present."""
|
||||||
|
# Setup mock model
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_response = AIMessage(content="Actual response without tags")
|
||||||
|
mock_model.invoke.return_value = mock_response
|
||||||
|
|
||||||
|
# Setup agent with config that supports think tags
|
||||||
|
config = {
|
||||||
|
"provider": "openai-compatible",
|
||||||
|
"model": "qwen-qwq-32b"
|
||||||
|
}
|
||||||
|
agent = CiaynAgent(mock_model, [], config=config)
|
||||||
|
|
||||||
|
# Mock print_warning and print_error to avoid unwanted console output
|
||||||
|
with patch("ra_aid.agent_backends.ciayn_agent.print_warning"), \
|
||||||
|
patch("ra_aid.agent_backends.ciayn_agent.print_error"):
|
||||||
|
|
||||||
|
# Mock _execute_tool to avoid actually executing tools
|
||||||
|
with patch.object(agent, "_execute_tool") as mock_execute:
|
||||||
|
mock_execute.return_value = "Tool result"
|
||||||
|
|
||||||
|
# For console.print, we want to verify it's not called with a thoughts panel
|
||||||
|
with patch("ra_aid.agent_backends.ciayn_agent.Panel") as mock_panel:
|
||||||
|
# Call stream method
|
||||||
|
next(agent.stream({"messages": []}, {}))
|
||||||
|
|
||||||
|
# Verify panel was not created with '💭 Thoughts' title
|
||||||
|
thoughts_panel_call = None
|
||||||
|
for call in mock_panel.call_args_list:
|
||||||
|
args, kwargs = call
|
||||||
|
if kwargs.get("title") == "💭 Thoughts":
|
||||||
|
thoughts_panel_call = call
|
||||||
|
break
|
||||||
|
|
||||||
|
assert thoughts_panel_call is None, "A panel with title '💭 Thoughts' was created but should not have been"
|
||||||
|
|
||||||
|
# Check that the response content was not modified
|
||||||
|
assert "Actual response without tags" in mock_execute.call_args[0][0].content
|
||||||
Loading…
Reference in New Issue