diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8874ca0..e2b52dc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [Unreleased]
+
+- Add CiaynAgent to support models that do not have, or are not good at, agentic function calling.
+
## [0.10.3] - 2024-12-27
- Fix logging on interrupt.
diff --git a/experiment/llm_test.py b/experiment/llm_test.py
new file mode 100644
index 0000000..90115b0
--- /dev/null
+++ b/experiment/llm_test.py
@@ -0,0 +1,57 @@
+import os
+import uuid
+from dotenv import load_dotenv
+from ra_aid.agent_utils import run_agent_with_retry
+from langchain_openai import ChatOpenAI
+from langchain_core.tools import tool
+from ra_aid.tools.list_directory import list_directory_tree
+from ra_aid.tool_configs import get_read_only_tools
+from rich.panel import Panel
+from rich.markdown import Markdown
+from rich.console import Console
+from ra_aid.agents.ciayn_agent import CiaynAgent
+
+console = Console()
+
+# Load environment variables
+load_dotenv()
+
+@tool
+def check_weather(location: str) -> str:
+ """Gets the weather at the given location."""
+ return f"The weather in {location} is sunny!"
+
+@tool
+def output_message(message: str, prompt_user_input: bool = False) -> str:
+ """Outputs a message to the user, optionally prompting for input."""
+ console.print(Panel(Markdown(message.strip())))
+ if prompt_user_input:
+ user_input = input("\n> ").strip()
+ print()
+ return user_input
+ return ""
+
+if __name__ == "__main__":
+ # Initialize the chat model
+ chat = ChatOpenAI(
+ api_key=os.getenv("OPENROUTER_API_KEY"),
+ temperature=0.7,
+ base_url="https://openrouter.ai/api/v1",
+ model="qwen/qwen-2.5-coder-32b-instruct"
+ )
+
+ # Get tools
+ tools = get_read_only_tools(True, True)
+ tools.append(output_message)
+
+ # Initialize agent
+ agent = CiaynAgent(chat, tools)
+
+ # Test chat prompt
+ test_prompt = "Find the tests in this codebase."
+
+ # Run the agent using run_agent_with_retry
+ result = run_agent_with_retry(agent, test_prompt, {"configurable": {"thread_id": str(uuid.uuid4())}})
+
+ # Initial greeting
+ print("Welcome to the Chat Interface! (Type 'quit' to exit)")
diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 46dc647..a958843 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -4,7 +4,6 @@ import uuid
from rich.panel import Panel
from rich.console import Console
from langgraph.checkpoint.memory import MemorySaver
-from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment
from ra_aid.tools.memory import _global_memory
from ra_aid.tools.human import ask_human
@@ -15,7 +14,8 @@ from ra_aid.agent_utils import (
AgentInterrupt,
run_agent_with_retry,
run_research_agent,
- run_planning_agent
+ run_planning_agent,
+ create_agent
)
from ra_aid.prompts import (
CHAT_PROMPT,
@@ -192,7 +192,7 @@ def main():
initial_request = ask_human.invoke({"question": "What would you like help with?"})
# Create chat agent with appropriate tools
- chat_agent = create_react_agent(
+ chat_agent = create_agent(
model,
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
checkpointer=MemorySaver()
diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 33608dd..8d80769 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -11,7 +11,12 @@ import time
from typing import Optional
from langgraph.prebuilt import create_react_agent
+from ra_aid.agents.ciayn_agent import CiaynAgent
+from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.console.formatting import print_stage_header, print_error
+from langchain_core.language_models import BaseChatModel
+from langchain_core.tools import tool
+from typing import List, Any
from ra_aid.console.output import print_agent_output
from ra_aid.logging_config import get_logger
from ra_aid.exceptions import AgentInterrupt
@@ -41,7 +46,6 @@ from ra_aid.prompts import (
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage
-from langchain_core.messages import BaseMessage
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from rich.console import Console
from rich.markdown import Markdown
@@ -65,6 +69,50 @@ console = Console()
logger = get_logger(__name__)
+@tool
+def output_markdown_message(message: str) -> str:
+ """Outputs a message to the user, optionally prompting for input."""
+ console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
+ return "Message output."
+
+def create_agent(
+ model: BaseChatModel,
+ tools: List[Any],
+ *,
+ checkpointer: Any = None
+) -> Any:
+ """Create a react agent with the given configuration.
+
+ Args:
+ model: The LLM model to use
+ tools: List of tools to provide to the agent
+ checkpointer: Optional memory checkpointer
+
+ Returns:
+ The created agent instance
+ """
+ try:
+ # Extract model info from module path
+ module_path = model.__class__.__module__.split('.')
+ if len(module_path) > 1:
+ provider = module_path[1] # e.g. anthropic from langchain_anthropic
+ else:
+ provider = None
+
+ # Get model name if available
+ model_name = getattr(model, 'model_name', '').lower()
+
+ # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
+ if provider == 'anthropic' and 'claude' in model_name:
+ return create_react_agent(model, tools, checkpointer=checkpointer)
+ else:
+ return CiaynAgent(model, tools)
+
+ except Exception as e:
+ # Default to REACT agent if provider/model detection fails
+ logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
+ return create_react_agent(model, tools, checkpointer=checkpointer)
+
def run_research_agent(
base_task_or_query: str,
model,
@@ -125,7 +173,7 @@ def run_research_agent(
)
# Create agent
- agent = create_react_agent(model, tools, checkpointer=memory)
+ agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@@ -238,7 +286,7 @@ def run_web_research_agent(
tools = get_web_research_tools(expert_enabled=expert_enabled)
# Create agent
- agent = create_react_agent(model, tools, checkpointer=memory)
+ agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@@ -351,7 +399,7 @@ def run_planning_agent(
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
# Create agent
- agent = create_react_agent(model, tools, checkpointer=memory)
+ agent = create_agent(model, tools, checkpointer=memory)
# Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
@@ -438,7 +486,7 @@ def run_task_implementation_agent(
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
# Create agent
- agent = create_react_agent(model, tools, checkpointer=memory)
+ agent = create_agent(model, tools, checkpointer=memory)
# Build prompt
prompt = IMPLEMENTATION_PROMPT.format(
@@ -467,6 +515,8 @@ def run_task_implementation_agent(
try:
logger.debug("Implementation agent completed successfully")
return run_agent_with_retry(agent, prompt, run_config)
+ except (KeyboardInterrupt, AgentInterrupt):
+ raise
except Exception as e:
logger.error("Implementation agent failed: %s", str(e), exc_info=True)
raise
@@ -523,7 +573,11 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
- logger.debug("Agent run completed successfully")
+ if _global_memory['task_completed']:
+ _global_memory['task_completed'] = False
+ _global_memory['completion_message'] = ''
+ break
+ logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except (KeyboardInterrupt, AgentInterrupt):
raise
diff --git a/ra_aid/agents/__init__.py b/ra_aid/agents/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
new file mode 100644
index 0000000..5f021d7
--- /dev/null
+++ b/ra_aid/agents/ciayn_agent.py
@@ -0,0 +1,216 @@
+import inspect
+from dataclasses import dataclass
+from typing import Dict, Any, Generator, List, Optional, Union
+from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
+from ra_aid.exceptions import ToolExecutionError
+@dataclass
+class ChunkMessage:
+ content: str
+ status: str
+
+class CiaynAgent:
+ """Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
+
+ The CIAYN philosophy emphasizes direct code generation and execution over structured APIs:
+ - Language model generates executable Python code snippets
+ - Tools are invoked through natural Python code rather than fixed schemas
+ - Flexible and adaptable approach to tool usage through dynamic code
+ - Complex workflows emerge from composing code segments
+
+ Code Generation & Function Calling:
+ - Dynamic generation of Python code for tool invocation
+ - Handles complex nested function calls and argument structures
+ - Natural integration of tool outputs into Python data flow
+ - Runtime code composition for multi-step operations
+
+ ReAct Pattern Implementation:
+ - Observation: Captures tool execution results
+ - Reasoning: Analyzes outputs to determine next steps
+ - Action: Generates and executes appropriate code
+ - Reflection: Updates state and plans next iteration
+ - Maintains conversation context across iterations
+
+ Core Capabilities:
+ - Dynamic tool registration with automatic documentation
+ - Sandboxed code execution environment
+ - Token-aware chat history management
+ - Comprehensive error handling and recovery
+ - Streaming interface for real-time interaction
+ - Memory management with configurable limits
+ """
+
+ def _get_function_info(self, func):
+ """
+ Returns a well-formatted string containing the function signature and docstring,
+ designed to be easily readable by both humans and LLMs.
+ """
+ signature = inspect.signature(func)
+ docstring = inspect.getdoc(func)
+ if docstring is None:
+ docstring = "No docstring provided"
+ full_signature = f"{func.__name__}{signature}"
+ info = f"""{full_signature}
+\"\"\"
+{docstring}
+\"\"\""""
+ return info
+
+ def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
+ """Initialize the agent with a model and list of tools.
+
+ Args:
+ model: The language model to use
+ tools: List of tools available to the agent
+ max_history_messages: Maximum number of messages to keep in chat history
+ max_tokens: Maximum number of tokens allowed in message history (None for no limit)
+ """
+ self.model = model
+ self.tools = tools
+ self.max_history_messages = max_history_messages
+ self.max_tokens = max_tokens
+ self.available_functions = []
+ for t in tools:
+ self.available_functions.append(self._get_function_info(t.func))
+
+ def _build_prompt(self, last_result: Optional[str] = None) -> str:
+ """Build the prompt for the agent including available tools and context."""
+ base_prompt = ""
+ if last_result is not None:
+ base_prompt += f"\n{last_result}"
+
+ base_prompt += f"""
+
+
+You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration.
+If the current query does not require a function call, just use output_message to say what you would normally say.
+The result of that function call will be given to you in the next message.
+Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed.
+The user cannot see the results of function calls, so you have to explicitly call output_message if you want them to see something.
+You must always respond with a single line of python that calls one of the available tools.
+Use as many steps as you need to in order to fully complete the task.
+Start by asking the user what they want.
+
+
+You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal.
+
+You must ONLY use ONE of the following functions (these are the ONLY functions that exist):
+
+
+{"\n\n".join(self.available_functions)}
+
+
+Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
+ return base_prompt
+
+ def _execute_tool(self, code: str) -> str:
+ """Execute a tool call and return its result."""
+ globals_dict = {
+ tool.func.__name__: tool.func
+ for tool in self.tools
+ }
+
+ try:
+ result = eval(code.strip(), globals_dict)
+ return result
+ except Exception as e:
+ error_msg = f"Error executing code: {str(e)}"
+ raise ToolExecutionError(error_msg)
+
+ def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
+ """Create an agent chunk in the format expected by print_agent_output."""
+ return {
+ "agent": {
+ "messages": [AIMessage(content=content)]
+ }
+ }
+
+ def _create_error_chunk(self, content: str) -> Dict[str, Any]:
+ """Create an error chunk in the format expected by print_agent_output."""
+ message = ChunkMessage(content=content, status="error")
+ return {
+ "tools": {
+ "messages": [message]
+ }
+ }
+
+ @staticmethod
+ def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
+ """Estimate number of tokens in content using simple byte length heuristic.
+
+ Estimates 1 token per 4 bytes of content. For messages, uses the content field.
+
+ Args:
+ content: String content or Message object to estimate tokens for
+
+ Returns:
+ int: Estimated number of tokens, 0 if content is None/empty
+ """
+ if content is None:
+ return 0
+
+ if isinstance(content, BaseMessage):
+ text = content.content
+ else:
+ text = content
+
+ if not text:
+ return 0
+
+ return len(text.encode('utf-8')) // 4
+
+ def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
+ """Trim chat history based on message count and token limits while preserving initial messages.
+
+ Applies both message count and token limits (if configured) to chat_history,
+ while preserving all initial_messages. Returns concatenated result.
+
+ Args:
+ initial_messages: List of initial messages to preserve
+ chat_history: List of chat messages that may be trimmed
+
+ Returns:
+ List[Any]: Concatenated initial_messages + trimmed chat_history
+ """
+ # First apply message count limit
+ if len(chat_history) > self.max_history_messages:
+ chat_history = chat_history[-self.max_history_messages:]
+
+ # Skip token limiting if max_tokens is None
+ if self.max_tokens is None:
+ return initial_messages + chat_history
+
+ # Calculate initial messages token count
+ initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
+
+ # Remove messages from start of chat_history until under token limit
+ while chat_history:
+ total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history)
+ if total_tokens <= self.max_tokens:
+ break
+ chat_history.pop(0)
+
+ return initial_messages + chat_history
+
+ def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
+ """Stream agent responses in a format compatible with print_agent_output."""
+ initial_messages = messages_dict.get("messages", [])
+ chat_history = []
+ last_result = None
+ first_iteration = True
+
+ while True:
+ base_prompt = self._build_prompt(None if first_iteration else last_result)
+ chat_history.append(HumanMessage(content=base_prompt))
+
+ full_history = self._trim_chat_history(initial_messages, chat_history)
+ response = self.model.invoke(full_history)
+
+ try:
+ last_result = self._execute_tool(response.content)
+ chat_history.append(response)
+ first_iteration = False
+ yield {}
+
+ except ToolExecutionError as e:
+ chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
+ yield self._create_error_chunk(str(e))
diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py
index 2b9c0b7..9831a24 100644
--- a/ra_aid/exceptions.py
+++ b/ra_aid/exceptions.py
@@ -7,3 +7,12 @@ class AgentInterrupt(Exception):
separate from KeyboardInterrupt which is reserved for top-level handling.
"""
pass
+
+
+class ToolExecutionError(Exception):
+ """Exception raised when a tool execution fails.
+
+ This exception is used to distinguish tool execution failures
+ from other types of errors in the agent system.
+ """
+ pass
diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py
index e5262e9..f87ce0e 100644
--- a/ra_aid/tools/memory.py
+++ b/ra_aid/tools/memory.py
@@ -192,6 +192,8 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
"""Store multiple key source code snippets in global memory.
Automatically adds the filepaths of the snippets to related files.
+ This is for **existing**, or **just-written** files, not for things to be created in the future.
+
Args:
snippets: List of snippet information dictionaries containing:
- filepath: Path to the source file
diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/agents/test_ciayn_agent.py
new file mode 100644
index 0000000..71bc1a8
--- /dev/null
+++ b/tests/ra_aid/agents/test_ciayn_agent.py
@@ -0,0 +1,156 @@
+import pytest
+from unittest.mock import Mock, patch
+from langchain_core.messages import HumanMessage, AIMessage
+from ra_aid.agents.ciayn_agent import CiaynAgent
+
+@pytest.fixture
+def mock_model():
+ """Create a mock language model."""
+ model = Mock()
+ model.invoke = Mock()
+ return model
+
+@pytest.fixture
+def agent(mock_model):
+ """Create a CiaynAgent instance with mock model."""
+ tools = [] # Empty tools list for testing trimming functionality
+ return CiaynAgent(mock_model, tools, max_history_messages=3)
+
+def test_trim_chat_history_preserves_initial_messages(agent):
+ """Test that initial messages are preserved during trimming."""
+ initial_messages = [
+ HumanMessage(content="Initial 1"),
+ AIMessage(content="Initial 2")
+ ]
+ chat_history = [
+ HumanMessage(content="Chat 1"),
+ AIMessage(content="Chat 2"),
+ HumanMessage(content="Chat 3"),
+ AIMessage(content="Chat 4")
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Verify initial messages are preserved
+ assert result[:2] == initial_messages
+ # Verify only last 3 chat messages are kept (due to max_history_messages=3)
+ assert len(result[2:]) == 3
+ assert result[2:] == chat_history[-3:]
+
+def test_trim_chat_history_under_limit(agent):
+ """Test trimming when chat history is under the maximum limit."""
+ initial_messages = [HumanMessage(content="Initial")]
+ chat_history = [
+ HumanMessage(content="Chat 1"),
+ AIMessage(content="Chat 2")
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Verify no trimming occurred
+ assert len(result) == 3
+ assert result == initial_messages + chat_history
+
+def test_trim_chat_history_over_limit(agent):
+ """Test trimming when chat history exceeds the maximum limit."""
+ initial_messages = [HumanMessage(content="Initial")]
+ chat_history = [
+ HumanMessage(content="Chat 1"),
+ AIMessage(content="Chat 2"),
+ HumanMessage(content="Chat 3"),
+ AIMessage(content="Chat 4"),
+ HumanMessage(content="Chat 5")
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Verify correct trimming
+ assert len(result) == 4 # initial + max_history_messages
+ assert result[0] == initial_messages[0] # Initial message preserved
+ assert result[1:] == chat_history[-3:] # Last 3 chat messages kept
+
+def test_trim_chat_history_empty_initial(agent):
+ """Test trimming with empty initial messages."""
+ initial_messages = []
+ chat_history = [
+ HumanMessage(content="Chat 1"),
+ AIMessage(content="Chat 2"),
+ HumanMessage(content="Chat 3"),
+ AIMessage(content="Chat 4")
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Verify only last 3 messages are kept
+ assert len(result) == 3
+ assert result == chat_history[-3:]
+
+def test_trim_chat_history_empty_chat(agent):
+ """Test trimming with empty chat history."""
+ initial_messages = [
+ HumanMessage(content="Initial 1"),
+ AIMessage(content="Initial 2")
+ ]
+ chat_history = []
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Verify initial messages are preserved and no trimming occurred
+ assert result == initial_messages
+ assert len(result) == 2
+
+def test_trim_chat_history_token_limit():
+ """Test trimming based on token limit."""
+ agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20)
+
+ initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
+ chat_history = [
+ HumanMessage(content="A" * 40), # ~10 tokens
+ AIMessage(content="B" * 40), # ~10 tokens
+ HumanMessage(content="C" * 40) # ~10 tokens
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Should keep initial message (~2 tokens) and last message (~10 tokens)
+ assert len(result) == 2
+ assert result[0] == initial_messages[0]
+ assert result[1] == chat_history[-1]
+
+def test_trim_chat_history_no_token_limit():
+ """Test trimming with no token limit set."""
+ agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
+
+ initial_messages = [HumanMessage(content="Initial")]
+ chat_history = [
+ HumanMessage(content="A" * 1000),
+ AIMessage(content="B" * 1000),
+ HumanMessage(content="C" * 1000)
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Should keep initial message and last 2 messages (max_history_messages=2)
+ assert len(result) == 3
+ assert result[0] == initial_messages[0]
+ assert result[1:] == chat_history[-2:]
+
+def test_trim_chat_history_both_limits():
+ """Test trimming with both message count and token limits."""
+ agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15)
+
+ initial_messages = [HumanMessage(content="Init")] # ~1 token
+ chat_history = [
+ HumanMessage(content="A" * 40), # ~10 tokens
+ AIMessage(content="B" * 40), # ~10 tokens
+ HumanMessage(content="C" * 40), # ~10 tokens
+ AIMessage(content="D" * 40) # ~10 tokens
+ ]
+
+ result = agent._trim_chat_history(initial_messages, chat_history)
+
+ # Should first apply message limit (keeping last 3)
+ # Then token limit should further reduce to fit under 15 tokens
+ assert len(result) == 2 # Initial message + 1 message under token limit
+ assert result[0] == initial_messages[0]
+ assert result[1] == chat_history[-1]
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index 9519ef0..b6aa8bc 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -3,7 +3,9 @@ import pytest
from unittest.mock import patch, Mock
from langchain_openai.chat_models import ChatOpenAI
from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_core.messages import HumanMessage
from dataclasses import dataclass
+from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.env import validate_environment
from ra_aid.llm import initialize_llm, initialize_expert_llm
@@ -87,6 +89,21 @@ def test_initialize_expert_unsupported_provider(clean_env):
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
initialize_expert_llm("unknown", "model")
+def test_estimate_tokens():
+ """Test token estimation functionality."""
+ # Test empty/None cases
+ assert CiaynAgent._estimate_tokens(None) == 0
+ assert CiaynAgent._estimate_tokens('') == 0
+
+ # Test string content
+ assert CiaynAgent._estimate_tokens('test') == 1 # 4 bytes
+ assert CiaynAgent._estimate_tokens('hello world') == 2 # 11 bytes
+ assert CiaynAgent._estimate_tokens('🚀') == 1 # 4 bytes
+
+ # Test message content
+ msg = HumanMessage(content='test message')
+ assert CiaynAgent._estimate_tokens(msg) == 3 # 11 bytes
+
def test_initialize_openai(clean_env, mock_openai):
"""Test OpenAI provider initialization"""
os.environ["OPENAI_API_KEY"] = "test-key"