ciayn
This commit is contained in: commit 6f10db811e (parent 4cb98370c2)
@@ -1,6 +1,6 @@
 import inspect
-from typing import Dict, Any, Generator, List, Optional
-from langchain_core.messages import AIMessage, HumanMessage
+from typing import Dict, Any, Generator, List, Optional, Union
+from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
 from ra_aid.exceptions import ToolExecutionError

 class CiaynAgent:
@@ -20,17 +20,19 @@ class CiaynAgent:
         """
         return info

-    def __init__(self, model, tools: list, max_history_messages: int = 50):
+    def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = None):
         """Initialize the agent with a model and list of tools.

         Args:
             model: The language model to use
             tools: List of tools available to the agent
             max_history_messages: Maximum number of messages to keep in chat history
+            max_tokens: Maximum number of tokens allowed in message history (None for no limit)
         """
         self.model = model
         self.tools = tools
         self.max_history_messages = max_history_messages
+        self.max_tokens = max_tokens
         self.available_functions = []
         for t in tools:
             self.available_functions.append(self._get_function_info(t.func))
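For orientation, a minimal sketch of how the widened constructor would be called; the Mock() model and empty tool list are illustrative stand-ins, not part of this commit:

    from unittest.mock import Mock
    from ra_aid.agents.ciayn_agent import CiaynAgent

    # max_tokens=None (the default) keeps the old message-count-only behavior;
    # a numeric value additionally caps the estimated token total of the history.
    agent = CiaynAgent(Mock(), [], max_history_messages=50, max_tokens=4096)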
@@ -98,11 +100,36 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
             }
         }

-    def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
-        """Trim chat history to maximum length while preserving initial messages.
-
-        Only trims the chat_history portion while preserving all initial messages.
-        Returns the concatenated list of initial_messages + trimmed chat_history.
-
-        Args:
-            initial_messages: List of initial messages to preserve
+    @staticmethod
+    def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
+        """Estimate number of tokens in content using simple byte length heuristic.
+
+        Estimates 1 token per 4 bytes of content. For messages, uses the content field.
+
+        Args:
+            content: String content or Message object to estimate tokens for
+
+        Returns:
+            int: Estimated number of tokens, 0 if content is None/empty
+        """
+        if content is None:
+            return 0
+
+        if isinstance(content, BaseMessage):
+            text = content.content
+        else:
+            text = content
+
+        if not text:
+            return 0
+
+        return len(text.encode('utf-8')) // 4
+
+    def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
+        """Trim chat history based on message count and token limits while preserving initial messages.
+
+        Applies both message count and token limits (if configured) to chat_history,
+        while preserving all initial_messages. Returns concatenated result.
+
+        Args:
+            initial_messages: List of initial messages to preserve
@@ -111,11 +138,25 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
         Returns:
             List[Any]: Concatenated initial_messages + trimmed chat_history
         """
-        if len(chat_history) <= self.max_history_messages:
-            return initial_messages + chat_history
-
-        # Keep last max_history_messages from chat_history
-        return initial_messages + chat_history[-self.max_history_messages:]
+        # First apply message count limit
+        if len(chat_history) > self.max_history_messages:
+            chat_history = chat_history[-self.max_history_messages:]
+
+        # Skip token limiting if max_tokens is None
+        if self.max_tokens is None:
+            return initial_messages + chat_history
+
+        # Calculate initial messages token count
+        initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
+
+        # Remove messages from start of chat_history until under token limit
+        while chat_history:
+            total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history)
+            if total_tokens <= self.max_tokens:
+                break
+            chat_history.pop(0)
+
+        return initial_messages + chat_history

     def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
         """Stream agent responses in a format compatible with print_agent_output."""
@@ -98,3 +98,59 @@ def test_trim_chat_history_empty_chat(agent):
     # Verify initial messages are preserved and no trimming occurred
     assert result == initial_messages
     assert len(result) == 2
+
+def test_trim_chat_history_token_limit():
+    """Test trimming based on token limit."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
+
+    initial_messages = [HumanMessage(content="Initial")]  # ~1 token
+    chat_history = [
+        HumanMessage(content="A" * 40),  # ~10 tokens
+        AIMessage(content="B" * 40),  # ~10 tokens
+        HumanMessage(content="C" * 40)  # ~10 tokens
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should keep the initial message (~1 token) and only the last message (~10 tokens)
+    assert len(result) == 2
+    assert result[0] == initial_messages[0]
+    assert result[1] == chat_history[-1]
+
+def test_trim_chat_history_no_token_limit():
+    """Test trimming with no token limit set."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
+
+    initial_messages = [HumanMessage(content="Initial")]
+    chat_history = [
+        HumanMessage(content="A" * 1000),
+        AIMessage(content="B" * 1000),
+        HumanMessage(content="C" * 1000)
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should keep the initial message and the last 2 messages (max_history_messages=2)
+    assert len(result) == 3
+    assert result[0] == initial_messages[0]
+    assert result[1:] == chat_history[-2:]
+
+def test_trim_chat_history_both_limits():
+    """Test trimming with both message count and token limits."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15)
+
+    initial_messages = [HumanMessage(content="Init")]  # ~1 token
+    chat_history = [
+        HumanMessage(content="A" * 40),  # ~10 tokens
+        AIMessage(content="B" * 40),  # ~10 tokens
+        HumanMessage(content="C" * 40),  # ~10 tokens
+        AIMessage(content="D" * 40)  # ~10 tokens
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # The message limit first keeps the last 3 messages;
+    # the token limit then drops from the front until the total fits under 15 tokens
+    assert len(result) == 2  # Initial message + 1 message under token limit
+    assert result[0] == initial_messages[0]
+    assert result[1] == chat_history[-1]
@@ -3,7 +3,9 @@ import pytest
 from unittest.mock import patch, Mock
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_core.messages import HumanMessage
 from dataclasses import dataclass
+from ra_aid.agents.ciayn_agent import CiaynAgent

 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm, initialize_expert_llm
@@ -87,6 +89,21 @@ def test_initialize_expert_unsupported_provider(clean_env):
     with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
         initialize_expert_llm("unknown", "model")

+def test_estimate_tokens():
+    """Test token estimation functionality."""
+    # Test empty/None cases
+    assert CiaynAgent._estimate_tokens(None) == 0
+    assert CiaynAgent._estimate_tokens('') == 0
+
+    # Test string content
+    assert CiaynAgent._estimate_tokens('test') == 1  # 4 bytes
+    assert CiaynAgent._estimate_tokens('hello world') == 2  # 11 bytes
+    assert CiaynAgent._estimate_tokens('🚀') == 1  # 4 bytes
+
+    # Test message content
+    msg = HumanMessage(content='test message')
+    assert CiaynAgent._estimate_tokens(msg) == 3  # 12 bytes
+
 def test_initialize_openai(clean_env, mock_openai):
     """Test OpenAI provider initialization"""
     os.environ["OPENAI_API_KEY"] = "test-key"
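As a quick sanity check of the byte arithmetic behind these assertions (plain Python, no project imports needed):

    # UTF-8 byte counts underlying the asserted token estimates (len(bytes) // 4).
    print(len('test'.encode('utf-8')))          # 4  -> 1 token
    print(len('hello world'.encode('utf-8')))   # 11 -> 2 tokens
    print(len('🚀'.encode('utf-8')))            # 4  -> 1 token
    print(len('test message'.encode('utf-8')))  # 12 -> 3 tokens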