This commit is contained in:
AI Christianson 2024-12-28 14:54:31 -05:00
parent 9074cae2f5
commit ac2bdfd69b
1 changed files with 29 additions and 3 deletions

View File

@ -20,10 +20,17 @@ class CiaynAgent:
\"\"\"""" \"\"\""""
return info return info
def __init__(self, model, tools: list): def __init__(self, model, tools: list, max_history_messages: int = 50):
"""Initialize the agent with a model and list of tools.""" """Initialize the agent with a model and list of tools.
Args:
model: The language model to use
tools: List of tools available to the agent
max_history_messages: Maximum number of messages to keep in chat history
"""
self.model = model self.model = model
self.tools = tools self.tools = tools
self.max_history_messages = max_history_messages
self.available_functions = [] self.available_functions = []
for t in tools: for t in tools:
self.available_functions.append(self._get_function_info(t.func)) self.available_functions.append(self._get_function_info(t.func))
@ -91,6 +98,25 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
} }
} }
def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]:
"""Trim chat history to maximum length while preserving initial messages.
Only trims the chat_history portion while preserving all initial messages.
Returns the concatenated list of initial_messages + trimmed chat_history.
Args:
initial_messages: List of initial messages to preserve
chat_history: List of chat messages that may be trimmed
Returns:
List[Any]: Concatenated initial_messages + trimmed chat_history
"""
if len(chat_history) <= self.max_history_messages:
return initial_messages + chat_history
# Keep last max_history_messages from chat_history
return initial_messages + chat_history[-self.max_history_messages:]
def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]: def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]:
"""Stream agent responses in a format compatible with print_agent_output.""" """Stream agent responses in a format compatible with print_agent_output."""
initial_messages = messages_dict.get("messages", []) initial_messages = messages_dict.get("messages", [])
@ -102,7 +128,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
base_prompt = self._build_prompt(None if first_iteration else last_result) base_prompt = self._build_prompt(None if first_iteration else last_result)
chat_history.append(HumanMessage(content=base_prompt)) chat_history.append(HumanMessage(content=base_prompt))
full_history = initial_messages + chat_history full_history = self._trim_chat_history(initial_messages, chat_history)
response = self.model.invoke(full_history) response = self.model.invoke(full_history)
try: try: