This commit is contained in:
AI Christianson 2024-12-28 15:09:32 -05:00
parent 4cb98370c2
commit 6f10db811e
3 changed files with 124 additions and 10 deletions

View File

@ -1,6 +1,6 @@
import inspect
from typing import Dict, Any, Generator, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from typing import Dict, Any, Generator, List, Optional, Union
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from ra_aid.exceptions import ToolExecutionError
class CiaynAgent:
@ -20,17 +20,19 @@ class CiaynAgent:
\"\"\""""
return info
def __init__(self, model, tools: list, max_history_messages: int = 50):
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = None):
"""Initialize the agent with a model and list of tools.
Args:
model: The language model to use
tools: List of tools available to the agent
max_history_messages: Maximum number of messages to keep in chat history
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
"""
self.model = model
self.tools = tools
self.max_history_messages = max_history_messages
self.max_tokens = max_tokens
self.available_functions = []
for t in tools:
self.available_functions.append(self._get_function_info(t.func))
@ -98,11 +100,36 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
}
}
def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
"""Trim chat history to maximum length while preserving initial messages.
@staticmethod
def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
"""Estimate number of tokens in content using simple byte length heuristic.
Only trims the chat_history portion while preserving all initial messages.
Returns the concatenated list of initial_messages + trimmed chat_history.
Estimates 1 token per 4 bytes of content. For messages, uses the content field.
Args:
content: String content or Message object to estimate tokens for
Returns:
int: Estimated number of tokens, 0 if content is None/empty
"""
if content is None:
return 0
if isinstance(content, BaseMessage):
text = content.content
else:
text = content
if not text:
return 0
return len(text.encode('utf-8')) // 4
def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
"""Trim chat history based on message count and token limits while preserving initial messages.
Applies both message count and token limits (if configured) to chat_history,
while preserving all initial_messages. Returns concatenated result.
Args:
initial_messages: List of initial messages to preserve
@ -111,11 +138,25 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
Returns:
List[Any]: Concatenated initial_messages + trimmed chat_history
"""
if len(chat_history) <= self.max_history_messages:
# First apply message count limit
if len(chat_history) > self.max_history_messages:
chat_history = chat_history[-self.max_history_messages:]
# Skip token limiting if max_tokens is None
if self.max_tokens is None:
return initial_messages + chat_history
# Keep last max_history_messages from chat_history
return initial_messages + chat_history[-self.max_history_messages:]
# Calculate initial messages token count
initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
# Remove messages from start of chat_history until under token limit
while chat_history:
total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history)
if total_tokens <= self.max_tokens:
break
chat_history.pop(0)
return initial_messages + chat_history
def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
"""Stream agent responses in a format compatible with print_agent_output."""

View File

@ -98,3 +98,59 @@ def test_trim_chat_history_empty_chat(agent):
# Verify initial messages are preserved and no trimming occurred
assert result == initial_messages
assert len(result) == 2
def test_trim_chat_history_token_limit():
"""Test trimming based on token limit."""
agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
chat_history = [
HumanMessage(content="A" * 40), # ~10 tokens
AIMessage(content="B" * 40), # ~10 tokens
HumanMessage(content="C" * 40) # ~10 tokens
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Should keep initial message (~2 tokens) and last message (~10 tokens)
assert len(result) == 2
assert result[0] == initial_messages[0]
assert result[1] == chat_history[-1]
def test_trim_chat_history_no_token_limit():
"""Test trimming with no token limit set."""
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
initial_messages = [HumanMessage(content="Initial")]
chat_history = [
HumanMessage(content="A" * 1000),
AIMessage(content="B" * 1000),
HumanMessage(content="C" * 1000)
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Should keep initial message and last 2 messages (max_history_messages=2)
assert len(result) == 3
assert result[0] == initial_messages[0]
assert result[1:] == chat_history[-2:]
def test_trim_chat_history_both_limits():
"""Test trimming with both message count and token limits."""
agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15)
initial_messages = [HumanMessage(content="Init")] # ~1 token
chat_history = [
HumanMessage(content="A" * 40), # ~10 tokens
AIMessage(content="B" * 40), # ~10 tokens
HumanMessage(content="C" * 40), # ~10 tokens
AIMessage(content="D" * 40) # ~10 tokens
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Should first apply message limit (keeping last 3)
# Then token limit should further reduce to fit under 15 tokens
assert len(result) == 2 # Initial message + 1 message under token limit
assert result[0] == initial_messages[0]
assert result[1] == chat_history[-1]

View File

@ -3,7 +3,9 @@ import pytest
from unittest.mock import patch, Mock
from langchain_openai.chat_models import ChatOpenAI
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_core.messages import HumanMessage
from dataclasses import dataclass
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.env import validate_environment
from ra_aid.llm import initialize_llm, initialize_expert_llm
@ -87,6 +89,21 @@ def test_initialize_expert_unsupported_provider(clean_env):
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
initialize_expert_llm("unknown", "model")
def test_estimate_tokens():
"""Test token estimation functionality."""
# Test empty/None cases
assert CiaynAgent._estimate_tokens(None) == 0
assert CiaynAgent._estimate_tokens('') == 0
# Test string content
assert CiaynAgent._estimate_tokens('test') == 1 # 4 bytes
assert CiaynAgent._estimate_tokens('hello world') == 2 # 11 bytes
assert CiaynAgent._estimate_tokens('🚀') == 1 # 4 bytes
# Test message content
msg = HumanMessage(content='test message')
assert CiaynAgent._estimate_tokens(msg) == 3 # 11 bytes
def test_initialize_openai(clean_env, mock_openai):
"""Test OpenAI provider initialization"""
os.environ["OPENAI_API_KEY"] = "test-key"