commit 6f10db811e
parent 4cb98370c2

    ciayn
@@ -1,6 +1,6 @@
 import inspect
-from typing import Dict, Any, Generator, List, Optional
+from typing import Dict, Any, Generator, List, Optional, Union
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
 from ra_aid.exceptions import ToolExecutionError

 class CiaynAgent:
@@ -20,17 +20,19 @@ class CiaynAgent:
         """
         return info

-    def __init__(self, model, tools: list, max_history_messages: int = 50):
+    def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = None):
         """Initialize the agent with a model and list of tools.

         Args:
             model: The language model to use
             tools: List of tools available to the agent
             max_history_messages: Maximum number of messages to keep in chat history
+            max_tokens: Maximum number of tokens allowed in message history (None for no limit)
         """
         self.model = model
         self.tools = tools
         self.max_history_messages = max_history_messages
+        self.max_tokens = max_tokens
         self.available_functions = []
         for t in tools:
             self.available_functions.append(self._get_function_info(t.func))
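Note: the new max_tokens parameter defaults to None, so existing callers keep the old message-count-only behavior. A minimal construction sketch (the Mock stand-in and empty tool list mirror the test fixtures below; the 4096 value is only illustrative):

    from unittest.mock import Mock
    from ra_aid.agents.ciayn_agent import CiaynAgent

    # Cap history at 50 messages and roughly 4096 estimated tokens.
    agent = CiaynAgent(Mock(), [], max_history_messages=50, max_tokens=4096)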
@@ -98,11 +100,36 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
             }
         }

-    def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
-        """Trim chat history to maximum length while preserving initial messages.
-
-        Only trims the chat_history portion while preserving all initial messages.
-        Returns the concatenated list of initial_messages + trimmed chat_history.
-
+    @staticmethod
+    def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
+        """Estimate number of tokens in content using simple byte length heuristic.
+
+        Estimates 1 token per 4 bytes of content. For messages, uses the content field.
+
+        Args:
+            content: String content or Message object to estimate tokens for
+
+        Returns:
+            int: Estimated number of tokens, 0 if content is None/empty
+        """
+        if content is None:
+            return 0
+
+        if isinstance(content, BaseMessage):
+            text = content.content
+        else:
+            text = content
+
+        if not text:
+            return 0
+
+        return len(text.encode('utf-8')) // 4
+
+    def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
+        """Trim chat history based on message count and token limits while preserving initial messages.
+
+        Applies both message count and token limits (if configured) to chat_history,
+        while preserving all initial_messages. Returns concatenated result.
+
         Args:
             initial_messages: List of initial messages to preserve
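The heuristic deliberately avoids a tokenizer dependency: it assumes roughly one token per 4 bytes of UTF-8, a common rule of thumb rather than an exact count, and integer division rounds down. A quick worked check (values match the test_estimate_tokens assertions further down):

    CiaynAgent._estimate_tokens('hello world')  # 11 bytes // 4 -> 2
    CiaynAgent._estimate_tokens('🚀')           # 4 bytes in UTF-8 // 4 -> 1
    CiaynAgent._estimate_tokens(None)           # 0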
@@ -111,11 +138,25 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
         Returns:
             List[Any]: Concatenated initial_messages + trimmed chat_history
         """
-        if len(chat_history) <= self.max_history_messages:
-            return initial_messages + chat_history
-
-        # Keep last max_history_messages from chat_history
-        return initial_messages + chat_history[-self.max_history_messages:]
+        # First apply message count limit
+        if len(chat_history) > self.max_history_messages:
+            chat_history = chat_history[-self.max_history_messages:]
+
+        # Skip token limiting if max_tokens is None
+        if self.max_tokens is None:
+            return initial_messages + chat_history
+
+        # Calculate initial messages token count
+        initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
+
+        # Remove messages from start of chat_history until under token limit
+        while chat_history:
+            total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history)
+            if total_tokens <= self.max_tokens:
+                break
+            chat_history.pop(0)
+
+        return initial_messages + chat_history

     def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
         """Stream agent responses in a format compatible with print_agent_output."""
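The trimming is two-stage: the message-count cap runs first, then messages are popped from the oldest end of chat_history until the combined estimate fits under max_tokens. initial_messages are never trimmed, so the result can still exceed max_tokens if the initial messages alone do. A small behavioral sketch (same values as test_trim_chat_history_token_limit below):

    agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
    initial = [HumanMessage(content="Initial")]  # ~1 token
    history = [
        HumanMessage(content="A" * 40),          # ~10 tokens each
        AIMessage(content="B" * 40),
        HumanMessage(content="C" * 40),
    ]
    # 1 + 30 > 20, so the two oldest history messages are dropped.
    result = agent._trim_chat_history(initial, history)
    assert result == initial + history[-1:]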
@@ -98,3 +98,59 @@ def test_trim_chat_history_empty_chat(agent):
     # Verify initial messages are preserved and no trimming occurred
     assert result == initial_messages
     assert len(result) == 2
+
+def test_trim_chat_history_token_limit():
+    """Test trimming based on token limit."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
+
+    initial_messages = [HumanMessage(content="Initial")]  # ~2 tokens
+    chat_history = [
+        HumanMessage(content="A" * 40),  # ~10 tokens
+        AIMessage(content="B" * 40),  # ~10 tokens
+        HumanMessage(content="C" * 40)  # ~10 tokens
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should keep initial message (~2 tokens) and last message (~10 tokens)
+    assert len(result) == 2
+    assert result[0] == initial_messages[0]
+    assert result[1] == chat_history[-1]
+
+def test_trim_chat_history_no_token_limit():
+    """Test trimming with no token limit set."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
+
+    initial_messages = [HumanMessage(content="Initial")]
+    chat_history = [
+        HumanMessage(content="A" * 1000),
+        AIMessage(content="B" * 1000),
+        HumanMessage(content="C" * 1000)
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should keep initial message and last 2 messages (max_history_messages=2)
+    assert len(result) == 3
+    assert result[0] == initial_messages[0]
+    assert result[1:] == chat_history[-2:]
+
+def test_trim_chat_history_both_limits():
+    """Test trimming with both message count and token limits."""
+    agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15)
+
+    initial_messages = [HumanMessage(content="Init")]  # ~1 token
+    chat_history = [
+        HumanMessage(content="A" * 40),  # ~10 tokens
+        AIMessage(content="B" * 40),  # ~10 tokens
+        HumanMessage(content="C" * 40),  # ~10 tokens
+        AIMessage(content="D" * 40)  # ~10 tokens
+    ]
+
+    result = agent._trim_chat_history(initial_messages, chat_history)
+
+    # Should first apply message limit (keeping last 3)
+    # Then token limit should further reduce to fit under 15 tokens
+    assert len(result) == 2  # Initial message + 1 message under token limit
+    assert result[0] == initial_messages[0]
+    assert result[1] == chat_history[-1]
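These tests pin down each limit in isolation and their interaction. They should be runnable on their own with pytest's keyword filter (assuming pytest is the repository's test runner, as the fixtures above suggest):

    python -m pytest -k "trim_chat_history or estimate_tokens"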
@@ -3,7 +3,9 @@ import pytest
 from unittest.mock import patch, Mock
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_core.messages import HumanMessage
 from dataclasses import dataclass
+from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm, initialize_expert_llm
@@ -87,6 +89,21 @@ def test_initialize_expert_unsupported_provider(clean_env):
     with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
         initialize_expert_llm("unknown", "model")

+def test_estimate_tokens():
+    """Test token estimation functionality."""
+    # Test empty/None cases
+    assert CiaynAgent._estimate_tokens(None) == 0
+    assert CiaynAgent._estimate_tokens('') == 0
+
+    # Test string content
+    assert CiaynAgent._estimate_tokens('test') == 1  # 4 bytes
+    assert CiaynAgent._estimate_tokens('hello world') == 2  # 11 bytes
+    assert CiaynAgent._estimate_tokens('🚀') == 1  # 4 bytes
+
+    # Test message content
+    msg = HumanMessage(content='test message')
+    assert CiaynAgent._estimate_tokens(msg) == 3  # 12 bytes
+
 def test_initialize_openai(clean_env, mock_openai):
     """Test OpenAI provider initialization"""
     os.environ["OPENAI_API_KEY"] = "test-key"