Merge llm-fn-call: resolve changelog conflicts
This commit is contained in:
commit
cf485bd96c
|
|
@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
- Add CiaynAgent to support models that do not have, or are not good at, agentic function calling.
|
||||
|
||||
## [0.10.3] - 2024-12-27
|
||||
|
||||
- Fix logging on interrupt.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.tools import tool
|
||||
from ra_aid.tools.list_directory import list_directory_tree
|
||||
from ra_aid.tool_configs import get_read_only_tools
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.console import Console
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
|
||||
console = Console()
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
@tool
|
||||
def check_weather(location: str) -> str:
|
||||
"""Gets the weather at the given location."""
|
||||
return f"The weather in {location} is sunny!"
|
||||
|
||||
@tool
|
||||
def output_message(message: str, prompt_user_input: bool = False) -> str:
|
||||
"""Outputs a message to the user, optionally prompting for input."""
|
||||
console.print(Panel(Markdown(message.strip())))
|
||||
if prompt_user_input:
|
||||
user_input = input("\n> ").strip()
|
||||
print()
|
||||
return user_input
|
||||
return ""
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize the chat model
|
||||
chat = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
temperature=0.7,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="qwen/qwen-2.5-coder-32b-instruct"
|
||||
)
|
||||
|
||||
# Get tools
|
||||
tools = get_read_only_tools(True, True)
|
||||
tools.append(output_message)
|
||||
|
||||
# Initialize agent
|
||||
agent = CiaynAgent(chat, tools)
|
||||
|
||||
# Test chat prompt
|
||||
test_prompt = "Find the tests in this codebase."
|
||||
|
||||
# Run the agent using run_agent_with_retry
|
||||
result = run_agent_with_retry(agent, test_prompt, {"configurable": {"thread_id": str(uuid.uuid4())}})
|
||||
|
||||
# Initial greeting
|
||||
print("Welcome to the Chat Interface! (Type 'quit' to exit)")
|
||||
|
|
@ -4,7 +4,6 @@ import uuid
|
|||
from rich.panel import Panel
|
||||
from rich.console import Console
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.tools.human import ask_human
|
||||
|
|
@ -15,7 +14,8 @@ from ra_aid.agent_utils import (
|
|||
AgentInterrupt,
|
||||
run_agent_with_retry,
|
||||
run_research_agent,
|
||||
run_planning_agent
|
||||
run_planning_agent,
|
||||
create_agent
|
||||
)
|
||||
from ra_aid.prompts import (
|
||||
CHAT_PROMPT,
|
||||
|
|
@ -192,7 +192,7 @@ def main():
|
|||
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_react_agent(
|
||||
chat_agent = create_agent(
|
||||
model,
|
||||
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
||||
checkpointer=MemorySaver()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,12 @@ import time
|
|||
from typing import Optional
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.console.formatting import print_stage_header, print_error
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import tool
|
||||
from typing import List, Any
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
|
|
@ -41,7 +46,6 @@ from ra_aid.prompts import (
|
|||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
|
@ -65,6 +69,50 @@ console = Console()
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@tool
|
||||
def output_markdown_message(message: str) -> str:
|
||||
"""Outputs a message to the user, optionally prompting for input."""
|
||||
console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
|
||||
return "Message output."
|
||||
|
||||
def create_agent(
|
||||
model: BaseChatModel,
|
||||
tools: List[Any],
|
||||
*,
|
||||
checkpointer: Any = None
|
||||
) -> Any:
|
||||
"""Create a react agent with the given configuration.
|
||||
|
||||
Args:
|
||||
model: The LLM model to use
|
||||
tools: List of tools to provide to the agent
|
||||
checkpointer: Optional memory checkpointer
|
||||
|
||||
Returns:
|
||||
The created agent instance
|
||||
"""
|
||||
try:
|
||||
# Extract model info from module path
|
||||
module_path = model.__class__.__module__.split('.')
|
||||
if len(module_path) > 1:
|
||||
provider = module_path[1] # e.g. anthropic from langchain_anthropic
|
||||
else:
|
||||
provider = None
|
||||
|
||||
# Get model name if available
|
||||
model_name = getattr(model, 'model_name', '').lower()
|
||||
|
||||
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
||||
if provider == 'anthropic' and 'claude' in model_name:
|
||||
return create_react_agent(model, tools, checkpointer=checkpointer)
|
||||
else:
|
||||
return CiaynAgent(model, tools)
|
||||
|
||||
except Exception as e:
|
||||
# Default to REACT agent if provider/model detection fails
|
||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||
return create_react_agent(model, tools, checkpointer=checkpointer)
|
||||
|
||||
def run_research_agent(
|
||||
base_task_or_query: str,
|
||||
model,
|
||||
|
|
@ -125,7 +173,7 @@ def run_research_agent(
|
|||
)
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
|
|
@ -238,7 +286,7 @@ def run_web_research_agent(
|
|||
tools = get_web_research_tools(expert_enabled=expert_enabled)
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
|
|
@ -351,7 +399,7 @@ def run_planning_agent(
|
|||
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||
|
|
@ -438,7 +486,7 @@ def run_task_implementation_agent(
|
|||
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||
|
||||
# Create agent
|
||||
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
||||
# Build prompt
|
||||
prompt = IMPLEMENTATION_PROMPT.format(
|
||||
|
|
@ -467,6 +515,8 @@ def run_task_implementation_agent(
|
|||
try:
|
||||
logger.debug("Implementation agent completed successfully")
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Implementation agent failed: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
|
@ -523,7 +573,11 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
|||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
print_agent_output(chunk)
|
||||
logger.debug("Agent run completed successfully")
|
||||
if _global_memory['task_completed']:
|
||||
_global_memory['task_completed'] = False
|
||||
_global_memory['completion_message'] = ''
|
||||
break
|
||||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -0,0 +1,216 @@
|
|||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Generator, List, Optional, Union
|
||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
@dataclass
|
||||
class ChunkMessage:
|
||||
content: str
|
||||
status: str
|
||||
|
||||
class CiaynAgent:
|
||||
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
|
||||
|
||||
The CIAYN philosophy emphasizes direct code generation and execution over structured APIs:
|
||||
- Language model generates executable Python code snippets
|
||||
- Tools are invoked through natural Python code rather than fixed schemas
|
||||
- Flexible and adaptable approach to tool usage through dynamic code
|
||||
- Complex workflows emerge from composing code segments
|
||||
|
||||
Code Generation & Function Calling:
|
||||
- Dynamic generation of Python code for tool invocation
|
||||
- Handles complex nested function calls and argument structures
|
||||
- Natural integration of tool outputs into Python data flow
|
||||
- Runtime code composition for multi-step operations
|
||||
|
||||
ReAct Pattern Implementation:
|
||||
- Observation: Captures tool execution results
|
||||
- Reasoning: Analyzes outputs to determine next steps
|
||||
- Action: Generates and executes appropriate code
|
||||
- Reflection: Updates state and plans next iteration
|
||||
- Maintains conversation context across iterations
|
||||
|
||||
Core Capabilities:
|
||||
- Dynamic tool registration with automatic documentation
|
||||
- Sandboxed code execution environment
|
||||
- Token-aware chat history management
|
||||
- Comprehensive error handling and recovery
|
||||
- Streaming interface for real-time interaction
|
||||
- Memory management with configurable limits
|
||||
"""
|
||||
|
||||
def _get_function_info(self, func):
|
||||
"""
|
||||
Returns a well-formatted string containing the function signature and docstring,
|
||||
designed to be easily readable by both humans and LLMs.
|
||||
"""
|
||||
signature = inspect.signature(func)
|
||||
docstring = inspect.getdoc(func)
|
||||
if docstring is None:
|
||||
docstring = "No docstring provided"
|
||||
full_signature = f"{func.__name__}{signature}"
|
||||
info = f"""{full_signature}
|
||||
\"\"\"
|
||||
{docstring}
|
||||
\"\"\""""
|
||||
return info
|
||||
|
||||
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
|
||||
"""Initialize the agent with a model and list of tools.
|
||||
|
||||
Args:
|
||||
model: The language model to use
|
||||
tools: List of tools available to the agent
|
||||
max_history_messages: Maximum number of messages to keep in chat history
|
||||
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
|
||||
"""
|
||||
self.model = model
|
||||
self.tools = tools
|
||||
self.max_history_messages = max_history_messages
|
||||
self.max_tokens = max_tokens
|
||||
self.available_functions = []
|
||||
for t in tools:
|
||||
self.available_functions.append(self._get_function_info(t.func))
|
||||
|
||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||
"""Build the prompt for the agent including available tools and context."""
|
||||
base_prompt = ""
|
||||
if last_result is not None:
|
||||
base_prompt += f"\n<last result>{last_result}</last result>"
|
||||
|
||||
base_prompt += f"""
|
||||
|
||||
<agent instructions>
|
||||
You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration.
|
||||
If the current query does not require a function call, just use output_message to say what you would normally say.
|
||||
The result of that function call will be given to you in the next message.
|
||||
Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed.
|
||||
The user cannot see the results of function calls, so you have to explicitly call output_message if you want them to see something.
|
||||
You must always respond with a single line of python that calls one of the available tools.
|
||||
Use as many steps as you need to in order to fully complete the task.
|
||||
Start by asking the user what they want.
|
||||
</agent instructions>
|
||||
|
||||
You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal.
|
||||
|
||||
You must ONLY use ONE of the following functions (these are the ONLY functions that exist):
|
||||
|
||||
<available functions>
|
||||
{"\n\n".join(self.available_functions)}
|
||||
</available functions>
|
||||
|
||||
Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||
return base_prompt
|
||||
|
||||
def _execute_tool(self, code: str) -> str:
|
||||
"""Execute a tool call and return its result."""
|
||||
globals_dict = {
|
||||
tool.func.__name__: tool.func
|
||||
for tool in self.tools
|
||||
}
|
||||
|
||||
try:
|
||||
result = eval(code.strip(), globals_dict)
|
||||
return result
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing code: {str(e)}"
|
||||
raise ToolExecutionError(error_msg)
|
||||
|
||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||
return {
|
||||
"agent": {
|
||||
"messages": [AIMessage(content=content)]
|
||||
}
|
||||
}
|
||||
|
||||
def _create_error_chunk(self, content: str) -> Dict[str, Any]:
|
||||
"""Create an error chunk in the format expected by print_agent_output."""
|
||||
message = ChunkMessage(content=content, status="error")
|
||||
return {
|
||||
"tools": {
|
||||
"messages": [message]
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
|
||||
"""Estimate number of tokens in content using simple byte length heuristic.
|
||||
|
||||
Estimates 1 token per 4 bytes of content. For messages, uses the content field.
|
||||
|
||||
Args:
|
||||
content: String content or Message object to estimate tokens for
|
||||
|
||||
Returns:
|
||||
int: Estimated number of tokens, 0 if content is None/empty
|
||||
"""
|
||||
if content is None:
|
||||
return 0
|
||||
|
||||
if isinstance(content, BaseMessage):
|
||||
text = content.content
|
||||
else:
|
||||
text = content
|
||||
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
return len(text.encode('utf-8')) // 4
|
||||
|
||||
def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
|
||||
"""Trim chat history based on message count and token limits while preserving initial messages.
|
||||
|
||||
Applies both message count and token limits (if configured) to chat_history,
|
||||
while preserving all initial_messages. Returns concatenated result.
|
||||
|
||||
Args:
|
||||
initial_messages: List of initial messages to preserve
|
||||
chat_history: List of chat messages that may be trimmed
|
||||
|
||||
Returns:
|
||||
List[Any]: Concatenated initial_messages + trimmed chat_history
|
||||
"""
|
||||
# First apply message count limit
|
||||
if len(chat_history) > self.max_history_messages:
|
||||
chat_history = chat_history[-self.max_history_messages:]
|
||||
|
||||
# Skip token limiting if max_tokens is None
|
||||
if self.max_tokens is None:
|
||||
return initial_messages + chat_history
|
||||
|
||||
# Calculate initial messages token count
|
||||
initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
|
||||
|
||||
# Remove messages from start of chat_history until under token limit
|
||||
while chat_history:
|
||||
total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history)
|
||||
if total_tokens <= self.max_tokens:
|
||||
break
|
||||
chat_history.pop(0)
|
||||
|
||||
return initial_messages + chat_history
|
||||
|
||||
def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
|
||||
"""Stream agent responses in a format compatible with print_agent_output."""
|
||||
initial_messages = messages_dict.get("messages", [])
|
||||
chat_history = []
|
||||
last_result = None
|
||||
first_iteration = True
|
||||
|
||||
while True:
|
||||
base_prompt = self._build_prompt(None if first_iteration else last_result)
|
||||
chat_history.append(HumanMessage(content=base_prompt))
|
||||
|
||||
full_history = self._trim_chat_history(initial_messages, chat_history)
|
||||
response = self.model.invoke(full_history)
|
||||
|
||||
try:
|
||||
last_result = self._execute_tool(response.content)
|
||||
chat_history.append(response)
|
||||
first_iteration = False
|
||||
yield {}
|
||||
|
||||
except ToolExecutionError as e:
|
||||
chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
|
||||
yield self._create_error_chunk(str(e))
|
||||
|
|
@ -7,3 +7,12 @@ class AgentInterrupt(Exception):
|
|||
separate from KeyboardInterrupt which is reserved for top-level handling.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ToolExecutionError(Exception):
|
||||
"""Exception raised when a tool execution fails.
|
||||
|
||||
This exception is used to distinguish tool execution failures
|
||||
from other types of errors in the agent system.
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -192,6 +192,8 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
|
|||
"""Store multiple key source code snippets in global memory.
|
||||
Automatically adds the filepaths of the snippets to related files.
|
||||
|
||||
This is for **existing**, or **just-written** files, not for things to be created in the future.
|
||||
|
||||
Args:
|
||||
snippets: List of snippet information dictionaries containing:
|
||||
- filepath: Path to the source file
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
"""Create a mock language model."""
|
||||
model = Mock()
|
||||
model.invoke = Mock()
|
||||
return model
|
||||
|
||||
@pytest.fixture
|
||||
def agent(mock_model):
|
||||
"""Create a CiaynAgent instance with mock model."""
|
||||
tools = [] # Empty tools list for testing trimming functionality
|
||||
return CiaynAgent(mock_model, tools, max_history_messages=3)
|
||||
|
||||
def test_trim_chat_history_preserves_initial_messages(agent):
|
||||
"""Test that initial messages are preserved during trimming."""
|
||||
initial_messages = [
|
||||
HumanMessage(content="Initial 1"),
|
||||
AIMessage(content="Initial 2")
|
||||
]
|
||||
chat_history = [
|
||||
HumanMessage(content="Chat 1"),
|
||||
AIMessage(content="Chat 2"),
|
||||
HumanMessage(content="Chat 3"),
|
||||
AIMessage(content="Chat 4")
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Verify initial messages are preserved
|
||||
assert result[:2] == initial_messages
|
||||
# Verify only last 3 chat messages are kept (due to max_history_messages=3)
|
||||
assert len(result[2:]) == 3
|
||||
assert result[2:] == chat_history[-3:]
|
||||
|
||||
def test_trim_chat_history_under_limit(agent):
|
||||
"""Test trimming when chat history is under the maximum limit."""
|
||||
initial_messages = [HumanMessage(content="Initial")]
|
||||
chat_history = [
|
||||
HumanMessage(content="Chat 1"),
|
||||
AIMessage(content="Chat 2")
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Verify no trimming occurred
|
||||
assert len(result) == 3
|
||||
assert result == initial_messages + chat_history
|
||||
|
||||
def test_trim_chat_history_over_limit(agent):
|
||||
"""Test trimming when chat history exceeds the maximum limit."""
|
||||
initial_messages = [HumanMessage(content="Initial")]
|
||||
chat_history = [
|
||||
HumanMessage(content="Chat 1"),
|
||||
AIMessage(content="Chat 2"),
|
||||
HumanMessage(content="Chat 3"),
|
||||
AIMessage(content="Chat 4"),
|
||||
HumanMessage(content="Chat 5")
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Verify correct trimming
|
||||
assert len(result) == 4 # initial + max_history_messages
|
||||
assert result[0] == initial_messages[0] # Initial message preserved
|
||||
assert result[1:] == chat_history[-3:] # Last 3 chat messages kept
|
||||
|
||||
def test_trim_chat_history_empty_initial(agent):
|
||||
"""Test trimming with empty initial messages."""
|
||||
initial_messages = []
|
||||
chat_history = [
|
||||
HumanMessage(content="Chat 1"),
|
||||
AIMessage(content="Chat 2"),
|
||||
HumanMessage(content="Chat 3"),
|
||||
AIMessage(content="Chat 4")
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Verify only last 3 messages are kept
|
||||
assert len(result) == 3
|
||||
assert result == chat_history[-3:]
|
||||
|
||||
def test_trim_chat_history_empty_chat(agent):
|
||||
"""Test trimming with empty chat history."""
|
||||
initial_messages = [
|
||||
HumanMessage(content="Initial 1"),
|
||||
AIMessage(content="Initial 2")
|
||||
]
|
||||
chat_history = []
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Verify initial messages are preserved and no trimming occurred
|
||||
assert result == initial_messages
|
||||
assert len(result) == 2
|
||||
|
||||
def test_trim_chat_history_token_limit():
|
||||
"""Test trimming based on token limit."""
|
||||
agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
|
||||
|
||||
initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
|
||||
chat_history = [
|
||||
HumanMessage(content="A" * 40), # ~10 tokens
|
||||
AIMessage(content="B" * 40), # ~10 tokens
|
||||
HumanMessage(content="C" * 40) # ~10 tokens
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Should keep initial message (~2 tokens) and last message (~10 tokens)
|
||||
assert len(result) == 2
|
||||
assert result[0] == initial_messages[0]
|
||||
assert result[1] == chat_history[-1]
|
||||
|
||||
def test_trim_chat_history_no_token_limit():
|
||||
"""Test trimming with no token limit set."""
|
||||
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
|
||||
|
||||
initial_messages = [HumanMessage(content="Initial")]
|
||||
chat_history = [
|
||||
HumanMessage(content="A" * 1000),
|
||||
AIMessage(content="B" * 1000),
|
||||
HumanMessage(content="C" * 1000)
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Should keep initial message and last 2 messages (max_history_messages=2)
|
||||
assert len(result) == 3
|
||||
assert result[0] == initial_messages[0]
|
||||
assert result[1:] == chat_history[-2:]
|
||||
|
||||
def test_trim_chat_history_both_limits():
|
||||
"""Test trimming with both message count and token limits."""
|
||||
agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15)
|
||||
|
||||
initial_messages = [HumanMessage(content="Init")] # ~1 token
|
||||
chat_history = [
|
||||
HumanMessage(content="A" * 40), # ~10 tokens
|
||||
AIMessage(content="B" * 40), # ~10 tokens
|
||||
HumanMessage(content="C" * 40), # ~10 tokens
|
||||
AIMessage(content="D" * 40) # ~10 tokens
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Should first apply message limit (keeping last 3)
|
||||
# Then token limit should further reduce to fit under 15 tokens
|
||||
assert len(result) == 2 # Initial message + 1 message under token limit
|
||||
assert result[0] == initial_messages[0]
|
||||
assert result[1] == chat_history[-1]
|
||||
|
|
@ -3,7 +3,9 @@ import pytest
|
|||
from unittest.mock import patch, Mock
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
from langchain_core.messages import HumanMessage
|
||||
from dataclasses import dataclass
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.llm import initialize_llm, initialize_expert_llm
|
||||
|
|
@ -87,6 +89,21 @@ def test_initialize_expert_unsupported_provider(clean_env):
|
|||
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
|
||||
initialize_expert_llm("unknown", "model")
|
||||
|
||||
def test_estimate_tokens():
|
||||
"""Test token estimation functionality."""
|
||||
# Test empty/None cases
|
||||
assert CiaynAgent._estimate_tokens(None) == 0
|
||||
assert CiaynAgent._estimate_tokens('') == 0
|
||||
|
||||
# Test string content
|
||||
assert CiaynAgent._estimate_tokens('test') == 1 # 4 bytes
|
||||
assert CiaynAgent._estimate_tokens('hello world') == 2 # 11 bytes
|
||||
assert CiaynAgent._estimate_tokens('🚀') == 1 # 4 bytes
|
||||
|
||||
# Test message content
|
||||
msg = HumanMessage(content='test message')
|
||||
assert CiaynAgent._estimate_tokens(msg) == 3 # 11 bytes
|
||||
|
||||
def test_initialize_openai(clean_env, mock_openai):
|
||||
"""Test OpenAI provider initialization"""
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
|
|
|
|||
Loading…
Reference in New Issue