From 6f10db811e128b3213d38d1f015a7ddb995ef974 Mon Sep 17 00:00:00 2001
From: AI Christianson
Date: Sat, 28 Dec 2024 15:09:32 -0500
Subject: [PATCH] ciayn

---
 ra_aid/agents/ciayn_agent.py            | 61 +++++++++++++++++++++----
 tests/ra_aid/agents/test_ciayn_agent.py | 56 +++++++++++++++++++++++
 tests/ra_aid/test_llm.py                | 17 +++++++
 3 files changed, 124 insertions(+), 10 deletions(-)

diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
index a9cde81..8f21f35 100644
--- a/ra_aid/agents/ciayn_agent.py
+++ b/ra_aid/agents/ciayn_agent.py
@@ -1,6 +1,6 @@
 import inspect
-from typing import Dict, Any, Generator, List, Optional
-from langchain_core.messages import AIMessage, HumanMessage
+from typing import Dict, Any, Generator, List, Optional, Union
+from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
 from ra_aid.exceptions import ToolExecutionError
 
 class CiaynAgent:
@@ -20,17 +20,19 @@ class CiaynAgent:
 \"\"\""""
         return info
 
-    def __init__(self, model, tools: list, max_history_messages: int = 50):
+    def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = None):
         """Initialize the agent with a model and list of tools.
 
         Args:
             model: The language model to use
             tools: List of tools available to the agent
             max_history_messages: Maximum number of messages to keep in chat history
+            max_tokens: Maximum number of tokens allowed in message history (None for no limit)
         """
         self.model = model
         self.tools = tools
         self.max_history_messages = max_history_messages
+        self.max_tokens = max_tokens
         self.available_functions = []
         for t in tools:
             self.available_functions.append(self._get_function_info(t.func))
@@ -98,11 +100,36 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
             }
         }
 
-    def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
-        """Trim chat history to maximum length while preserving initial messages.
+    @staticmethod
+    def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
+        """Estimate the number of tokens in content using a simple byte-length heuristic.
 
-        Only trims the chat_history portion while preserving all initial messages.
-        Returns the concatenated list of initial_messages + trimmed chat_history.
+        Estimates 1 token per 4 bytes of content. For messages, uses the content field.
+
+        Args:
+            content: String content or Message object to estimate tokens for
+
+        Returns:
+            int: Estimated number of tokens, 0 if content is None/empty
+        """
+        if content is None:
+            return 0
+
+        if isinstance(content, BaseMessage):
+            text = content.content
+        else:
+            text = content
+
+        if not text:
+            return 0
+
+        return len(text.encode('utf-8')) // 4
+
+    def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
+        """Trim chat history based on message count and token limits while preserving initial messages.
+
+        Applies both the message count and token limits (if configured) to chat_history,
+        while preserving all initial_messages. Returns the concatenated result.
 
         Args:
             initial_messages: List of initial messages to preserve
@@ -111,11 +138,25 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
 
         Returns:
             List[Any]: Concatenated initial_messages + trimmed chat_history
         """
-        if len(chat_history) <= self.max_history_messages:
+        # First apply the message count limit
+        if len(chat_history) > self.max_history_messages:
+            chat_history = chat_history[-self.max_history_messages:]
+
+        # Skip token limiting if max_tokens is None
+        if self.max_tokens is None:
             return initial_messages + chat_history
 
-        # Keep last max_history_messages from chat_history
-        return initial_messages + chat_history[-self.max_history_messages:]
+        # Calculate the token count of the initial messages
+        initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
+
+        # Remove messages from the start of chat_history until under the token limit
+        while chat_history:
+            total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history)
+            if total_tokens <= self.max_tokens:
+                break
+            chat_history.pop(0)
+
+        return initial_messages + chat_history
 
     def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
         """Stream agent responses in a format compatible with print_agent_output."""
diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/agents/test_ciayn_agent.py
index 2a3bae1..71bc1a8 100644
--- a/tests/ra_aid/agents/test_ciayn_agent.py
+++ b/tests/ra_aid/agents/test_ciayn_agent.py
@@ -98,3 +98,59 @@ def test_trim_chat_history_empty_chat(agent):
     # Verify initial messages are preserved and no trimming occurred
     assert result == initial_messages
     assert len(result) == 2
+
+def test_trim_chat_history_token_limit():
+    """Test trimming based on the token limit."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
+
+    initial_messages = [HumanMessage(content="Initial")]  # ~1 token
+    chat_history = [
+        HumanMessage(content="A" * 40),  # ~10 tokens
+        AIMessage(content="B" * 40),  # ~10 tokens
+        HumanMessage(content="C" * 40)  # ~10 tokens
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should keep the initial message (~1 token) and the last message (~10 tokens)
+    assert len(result) == 2
+    assert result[0] == initial_messages[0]
+    assert result[1] == chat_history[-1]
+
+def test_trim_chat_history_no_token_limit():
+    """Test trimming with no token limit set."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
+
+    initial_messages = [HumanMessage(content="Initial")]
+    chat_history = [
+        HumanMessage(content="A" * 1000),
+        AIMessage(content="B" * 1000),
+        HumanMessage(content="C" * 1000)
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should keep the initial message and the last 2 messages (max_history_messages=2)
+    assert len(result) == 3
+    assert result[0] == initial_messages[0]
+    assert result[1:] == chat_history[-2:]
+
+def test_trim_chat_history_both_limits():
+    """Test trimming with both message count and token limits."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15)
+
+    initial_messages = [HumanMessage(content="Init")]  # ~1 token
+    chat_history = [
+        HumanMessage(content="A" * 40),  # ~10 tokens
+        AIMessage(content="B" * 40),  # ~10 tokens
+        HumanMessage(content="C" * 40),  # ~10 tokens
+        AIMessage(content="D" * 40)  # ~10 tokens
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should first apply the message limit (keeping the last 3),
+    # then the token limit should further reduce the history to fit under 15 tokens
+    assert len(result) == 2  # Initial message + 1 history message under the token limit
+    assert result[0] == initial_messages[0]
+    assert result[1] == chat_history[-1]
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index ee5197d..9ff011f 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -3,7 +3,9 @@ import pytest
 from unittest.mock import patch, Mock
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_core.messages import HumanMessage
 from dataclasses import dataclass
+from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm, initialize_expert_llm
 
@@ -87,6 +89,21 @@ def test_initialize_expert_unsupported_provider(clean_env):
     with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
         initialize_expert_llm("unknown", "model")
 
+def test_estimate_tokens():
+    """Test token estimation functionality."""
+    # Test empty/None cases
+    assert CiaynAgent._estimate_tokens(None) == 0
+    assert CiaynAgent._estimate_tokens('') == 0
+
+    # Test string content
+    assert CiaynAgent._estimate_tokens('test') == 1  # 4 bytes
+    assert CiaynAgent._estimate_tokens('hello world') == 2  # 11 bytes
+    assert CiaynAgent._estimate_tokens('🚀') == 1  # 4 bytes in UTF-8
+
+    # Test message content
+    msg = HumanMessage(content='test message')
+    assert CiaynAgent._estimate_tokens(msg) == 3  # 12 bytes
+
 def test_initialize_openai(clean_env, mock_openai):
     """Test OpenAI provider initialization"""
     os.environ["OPENAI_API_KEY"] = "test-key"
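
Note for reviewers: below is a minimal usage sketch (not part of the patch) showing how
the new max_tokens budget behaves end to end. Mock() stands in for a real model, the
empty tool list mirrors the unit tests above, and the token counts follow the patch's
4-bytes-per-token heuristic.

    from unittest.mock import Mock
    from langchain_core.messages import AIMessage, HumanMessage
    from ra_aid.agents.ciayn_agent import CiaynAgent

    # Allow roughly 25 estimated tokens of history; the message-count cap is left loose.
    agent = CiaynAgent(Mock(), [], max_history_messages=50, max_tokens=25)

    initial = [HumanMessage(content="base prompt")]  # 11 bytes -> ~2 tokens
    history = [
        HumanMessage(content="x" * 80),  # ~20 tokens
        AIMessage(content="y" * 80),     # ~20 tokens
        HumanMessage(content="z" * 80),  # ~20 tokens
    ]

    # The oldest history entries are dropped until initial + history fits the budget:
    # 2 + 60 > 25 -> drop; 2 + 40 > 25 -> drop; 2 + 20 <= 25 -> stop.
    trimmed = agent._trim_chat_history(initial, history)
    assert trimmed == initial + history[-1:]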