837 lines
40 KiB
Python
837 lines
40 KiB
Python
import re
|
|
import ast
|
|
import string
|
|
import random
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Generator, List, Optional, Union, Tuple
|
|
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
from langchain_core.tools import BaseTool
|
|
|
|
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
|
from ra_aid.exceptions import ToolExecutionError
|
|
from ra_aid.fallback_handler import FallbackHandler
|
|
from ra_aid.logging_config import get_logger
|
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
|
from ra_aid.prompts.ciayn_prompts import CIAYN_AGENT_SYSTEM_PROMPT, CIAYN_AGENT_HUMAN_PROMPT, EXTRACT_TOOL_CALL_PROMPT, NO_TOOL_CALL_PROMPT
|
|
from ra_aid.tools.expert import get_model
|
|
from ra_aid.tools.reflection import get_function_info
|
|
from ra_aid.console.output import cpm
|
|
from ra_aid.console.formatting import print_warning, print_error, console
|
|
from ra_aid.agent_context import should_exit
|
|
from ra_aid.text.processing import extract_think_tag, process_thinking_content
|
|
from rich.panel import Panel
|
|
from rich.markdown import Markdown
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class ChunkMessage:
    """Pair a content string with a status string."""

    # Text carried by this message.
    content: str
    # Status label stored alongside the content.
    status: str
|
|
|
|
|
|
def validate_function_call_pattern(s: str) -> bool:
    """Report whether a string FAILS to be a single Python function call.

    Surrounding whitespace and a leading and/or trailing triple-backtick
    fence are removed, then the remainder is parsed with ``ast``.

    Note the inverted sense of the result: this answers "is it invalid?".

    Args:
        s: Candidate source text.

    Returns:
        bool: False when the text is exactly one expression statement whose
        value is a call; True for anything else, including text that cannot
        be parsed at all.
    """
    fence = "```"
    text = s.strip()

    # Decide about both fences up front, then peel them off one at a time.
    opens_with_fence = text.startswith(fence)
    closes_with_fence = text.endswith(fence)
    if opens_with_fence:
        text = text[len(fence):]
    if closes_with_fence:
        text = text[:-len(fence)]
    if opens_with_fence or closes_with_fence:
        text = text.strip()

    try:
        statements = ast.parse(text).body
    except Exception:
        # Anything the parser rejects is, by definition, not a valid call.
        return True

    if len(statements) != 1:
        return True

    only = statements[0]
    is_call = isinstance(only, ast.Expr) and isinstance(only.value, ast.Call)
    return not is_call
|
|
|
|
|
|
class CiaynAgent:
|
|
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
|
|
|
|
The CIAYN philosophy emphasizes direct code generation and execution over structured APIs:
|
|
- Language model generates executable Python code snippets
|
|
- Tools are invoked through natural Python code rather than fixed schemas
|
|
- Flexible and adaptable approach to tool usage through dynamic code
|
|
- Complex workflows emerge from composing code segments
|
|
|
|
Code Generation & Function Calling:
|
|
- Dynamic generation of Python code for tool invocation
|
|
- Handles complex nested function calls and argument structures
|
|
- Natural integration of tool outputs into Python data flow
|
|
- Runtime code composition for multi-step operations
|
|
|
|
ReAct Pattern Implementation:
|
|
- Observation: Captures tool execution results
|
|
- Reasoning: Analyzes outputs to determine next steps
|
|
- Action: Generates and executes appropriate code
|
|
- Reflection: Updates state and plans next iteration
|
|
- Maintains conversation context across iterations
|
|
|
|
Core Capabilities:
|
|
- Dynamic tool registration with automatic documentation
|
|
- Sandboxed code execution environment
|
|
- Token-aware chat history management
|
|
- Comprehensive error handling and recovery
|
|
- Streaming interface for real-time interaction
|
|
- Memory management with configurable limits
|
|
"""
|
|
|
|
# List of tools that can be bundled together in a single response.
# _detect_multiple_tool_calls only splits a multi-statement response into
# separate calls when every named call it inspects is in this list; a single
# call to a tool outside the list makes it return the code unsplit.
BUNDLEABLE_TOOLS = [
    "emit_expert_context",
    "ask_expert",
    "emit_key_facts",
    "emit_key_snippet",
    "request_implementation",
    "read_file_tool",
    "emit_research_notes",
    "ripgrep_search",
    "plan_implementation_completed",
    "request_research_and_implementation",
    "run_shell_command",
]

# List of tools that should not be called repeatedly with the same parameters.
# This prevents the agent from getting stuck in a loop calling the same tool
# with the same arguments multiple times. _execute_tool compares each call to
# one of these tools against the fingerprint stored in self.last_tool_call.
# NOTE(review): the entries currently match BUNDLEABLE_TOOLS exactly, but the
# two lists are kept separate.
NO_REPEAT_TOOLS = [
    "emit_expert_context",
    "ask_expert",
    "emit_key_facts",
    "emit_key_snippet",
    "request_implementation",
    "read_file_tool",
    "emit_research_notes",
    "ripgrep_search",
    "plan_implementation_completed",
    "request_research_and_implementation",
    "run_shell_command",
]
|
|
|
|
def __init__(
    self,
    model: BaseChatModel,
    tools: list[BaseTool],
    max_history_messages: int = 50,
    max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
    config: Optional[dict] = None,
):
    """Set up the agent's model, tools, prompts and history limits.

    Args:
        model: Chat model invoked on every turn.
        tools: Tools the agent may call; each tool's ``func`` is described
            with ``get_function_info`` for the system prompt.
        max_history_messages: Upper bound on retained chat history messages.
        max_tokens: Token budget for the message history (None disables
            token-based trimming).
        config: Optional configuration dictionary; an empty dict is used
            when omitted.
    """
    config = {} if config is None else config
    self.config = config
    self.provider = config.get("provider", "openai")

    self.model = model
    self.tools = tools
    self.max_history_messages = max_history_messages
    self.max_tokens = max_tokens
    self.chat_history = []

    # One description per tool, in the order the tools were supplied.
    self.available_functions = [get_function_info(tool.func) for tool in tools]

    self.fallback_handler = FallbackHandler(config, tools)

    # The system prompt carries the full list of callable functions.
    self.sys_message = SystemMessage(
        CIAYN_AGENT_SYSTEM_PROMPT.format(
            functions_list="\n\n".join(self.available_functions)
        )
    )

    self.error_message_template = "Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
    self.fallback_fixed_msg = HumanMessage(
        "Fallback tool handler has fixed the tool call see: <fallback tool call result> for the output."
    )

    # Fingerprint of the most recent tool call, used to reject an immediate
    # repeat of the same tool with the same parameters.
    self.last_tool_call = None
    self.last_tool_params = None
|
|
|
|
def _build_prompt(self, last_result: Optional[str] = None) -> str:
    """Render the human-turn prompt, embedding the previous result if given.

    Args:
        last_result: Output of the previous tool call, or None when there
            is nothing to report.

    Returns:
        CIAYN_AGENT_HUMAN_PROMPT formatted with a ``<last result>`` section,
        which is empty when ``last_result`` is None.
    """
    if last_result is None:
        section = ""
    else:
        section = f"\n<last result>{last_result}</last result>"

    # The function list lives in the system prompt, not here.
    return CIAYN_AGENT_HUMAN_PROMPT.format(last_result_section=section)
|
|
|
|
def _detect_multiple_tool_calls(self, code: str) -> List[str]:
|
|
"""Detect if there are multiple tool calls in the code using AST parsing.
|
|
|
|
Args:
|
|
code: The code string to analyze
|
|
|
|
Returns:
|
|
List of individual tool call strings if bundleable, or just the original code as a single element
|
|
"""
|
|
try:
|
|
# Clean up the code for parsing
|
|
code = code.strip()
|
|
if code.startswith("```"):
|
|
code = code[3:].strip()
|
|
if code.endswith("```"):
|
|
code = code[:-3].strip()
|
|
|
|
# Try to parse the code as a sequence of expressions
|
|
parsed = ast.parse(code)
|
|
|
|
# Check if we have multiple expressions and they are all valid function calls
|
|
if isinstance(parsed.body, list) and len(parsed.body) > 1:
|
|
calls = []
|
|
for node in parsed.body:
|
|
# Only process expressions that are function calls
|
|
if (isinstance(node, ast.Expr) and
|
|
isinstance(node.value, ast.Call) and
|
|
isinstance(node.value.func, ast.Name)):
|
|
|
|
func_name = node.value.func.id
|
|
|
|
# Only consider this a bundleable call if the function is in our allowed list
|
|
if func_name in self.BUNDLEABLE_TOOLS:
|
|
# Extract the exact call text from the original code
|
|
call_str = ast.unparse(node)
|
|
calls.append(call_str)
|
|
else:
|
|
# If any function is not bundleable, return just the original code
|
|
logger.debug(f"Found multiple tool calls, but {func_name} is not bundleable.")
|
|
return [code]
|
|
|
|
if calls:
|
|
logger.debug(f"Detected {len(calls)} bundleable tool calls.")
|
|
return calls
|
|
|
|
# Default case: just return the original code as a single element
|
|
return [code]
|
|
|
|
except SyntaxError:
|
|
# If we can't parse the code with AST, just return the original
|
|
return [code]
|
|
|
|
def _is_repeat_call(self, tool_name: str, call: str, label: str) -> bool:
    """Return True when ``call`` repeats the previously recorded tool call.

    Only tools listed in NO_REPEAT_TOOLS are checked. The call text is
    parsed with ``ast`` and reduced to a fingerprint: the tool name plus the
    string form of its sorted (name, value) argument pairs. Positional
    arguments are named ``arg0``, ``arg1``, ... and outer quotes are removed
    from each unparsed value.

    When the fingerprint differs from ``self.last_tool_call`` it is stored
    there for the next comparison. Fingerprinting is best-effort: any
    failure is logged at debug level and reported as "not a repeat" so it
    can never block execution.

    Args:
        tool_name: Name extracted from the call text.
        call: Source text of the tool call.
        label: Word used in debug messages ("bundled" or "single").

    Returns:
        True if the fingerprint equals ``self.last_tool_call``, else False.
    """
    if tool_name not in self.NO_REPEAT_TOOLS:
        return False

    def unquote(text: str) -> str:
        # Normalize string literals by removing outer quotes.
        if (text.startswith("'") and text.endswith("'")) or (
            text.startswith('"') and text.endswith('"')
        ):
            return text[1:-1]
        return text

    try:
        node = ast.parse(call).body[0]
        if not (isinstance(node, ast.Expr) and isinstance(node.value, ast.Call)):
            return False
        call_node = node.value

        logger.debug(f"AST structure for {label} call: {ast.dump(call_node)}")

        param_pairs = []
        if call_node.args:
            logger.debug(
                f"Found positional args in {label} call: "
                f"{[ast.unparse(arg) for arg in call_node.args]}"
            )
        for i, arg in enumerate(call_node.args):
            param_pairs.append((f"arg{i}", unquote(ast.unparse(arg))))

        for keyword in call_node.keywords:
            param_value = ast.unparse(keyword.value)
            logger.debug(f"Processing parameter: {keyword.arg} = {param_value}")
            param_pairs.append((keyword.arg, unquote(param_value)))

        logger.debug(f"Extracted parameters: {param_pairs}")

        current_call = (tool_name, str(sorted(param_pairs)))
        logger.debug(
            f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\n"
            f"Last call fingerprint: {self.last_tool_call}"
        )

        if current_call == self.last_tool_call:
            logger.info(f"Detected repeat call of {tool_name} with the same parameters.")
            return True

        # Remember this call so the next one can be compared against it.
        self.last_tool_call = current_call
        return False
    except Exception as e:
        # Deliberately broad: complex or malformed calls must still execute.
        logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
        return False

def _execute_tool(self, msg: BaseMessage) -> str:
    """Execute the tool call(s) contained in a model message.

    The message content is stripped of whitespace and triple-backtick
    fences. If it holds several bundleable calls they are run in order and
    their results are returned together, each wrapped in a
    ``<result-ID>`` tag with a random ID. Otherwise the content is treated
    as one call. Text that does not validate as a single call is first
    repaired through ``_extract_tool_call``. A call that repeats the
    previously recorded call of a NO_REPEAT_TOOLS tool is not executed; a
    rejection message is produced in its place.

    Args:
        msg: Message whose ``content`` holds the code to run.

    Returns:
        The tool's result for a single call, the tagged result string for
        bundled calls, or an explanatory string when execution is aborted
        by ``should_exit()`` or rejected as a repeat.

    Raises:
        ToolExecutionError: If anything fails while preparing or running the
            call. The failure is also recorded to the trajectory store on a
            best-effort basis and shown as a warning.
    """
    # Check for should_exit before executing tool calls
    if should_exit():
        logger.debug("Agent should exit flag detected in _execute_tool")
        return "Tool execution aborted - agent should exit flag is set"

    code = msg.content
    globals_dict = {tool.func.__name__: tool.func for tool in self.tools}

    try:
        code = code.strip()
        if code.startswith("```"):
            code = code[3:].strip()
        if code.endswith("```"):
            code = code[:-3].strip()

        # Check for multiple tool calls that can be bundled
        tool_calls = self._detect_multiple_tool_calls(code)

        # If we have multiple valid bundleable calls, execute them in sequence
        if len(tool_calls) > 1:
            if should_exit():
                logger.debug("Agent should exit flag detected before executing bundled tool calls")
                return "Bundled tool execution aborted - agent should exit flag is set"

            result_strings = []

            for call in tool_calls:
                if should_exit():
                    logger.debug("Agent should exit flag detected during bundled tool execution")
                    return "Tool execution interrupted: agent_should_exit flag is set."

                # Validate and fix each call if needed
                if validate_function_call_pattern(call):
                    functions_list = "\n\n".join(self.available_functions)
                    call = self._extract_tool_call(call, functions_list)

                tool_name = self.extract_tool_name(call)

                if self._is_repeat_call(tool_name, call, "bundled"):
                    result = f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!"
                else:
                    # NOTE(security): eval runs model-generated code. Only the
                    # tool functions are supplied as globals, but eval adds
                    # Python's builtins to that namespace, so this is not a
                    # sandbox.
                    result = eval(call.strip(), globals_dict)

                # Tag each result with a random ID so sections stay distinct.
                result_id = self._generate_random_id()
                result_strings.append(f"<result-{result_id}>\n{result}\n</result-{result_id}>")

            # Return all results as one big string with tagged sections
            return "\n\n".join(result_strings)

        # Regular single tool call case
        if validate_function_call_pattern(code):
            logger.debug("Tool call validation failed. Attempting to extract function call using LLM.")
            functions_list = "\n\n".join(self.available_functions)
            code = self._extract_tool_call(code, functions_list)

        tool_name = self.extract_tool_name(code)

        if self._is_repeat_call(tool_name, code, "single"):
            return f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!"

        # Before executing the call
        if should_exit():
            logger.debug("Agent should exit flag detected before tool execution")
            return "Tool execution interrupted: agent_should_exit flag is set."

        # NOTE(security): see the note on eval in the bundled branch above.
        return eval(code.strip(), globals_dict)
    except Exception as e:
        error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
        tool_name = self.extract_tool_name(code)
        logger.info(f"Tool execution failed for `{tool_name}`: {str(e)}")

        # Record error in trajectory
        try:
            # Import here to avoid circular imports
            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
            from ra_aid.database.connection import get_db

            # Create repositories directly
            trajectory_repo = TrajectoryRepository(get_db())
            human_input_repo = HumanInputRepository(get_db())
            human_input_id = human_input_repo.get_most_recent_id()

            trajectory_repo.create(
                step_data={
                    "error_message": f"Tool execution failed for `{tool_name}`:\nError: {str(e)}",
                    "display_title": "Tool Error",
                    "code": code,
                    "tool_name": tool_name
                },
                record_type="tool_execution",
                human_input_id=human_input_id,
                is_error=True,
                error_message=str(e),
                error_type="ToolExecutionError",
                tool_name=tool_name,
                tool_parameters={"code": code}
            )
        except Exception as trajectory_error:
            # Just log and continue if there's an error in trajectory recording
            logger.error(f"Error recording trajectory for tool error display: {trajectory_error}")

        print_warning(f"Tool execution failed for `{tool_name}`:\nError: {str(e)}\n\nCode:\n\n````\n{code}\n````", title="Tool Error")
        raise ToolExecutionError(
            error_msg, base_message=msg, tool_name=tool_name
        ) from e
|
|
|
|
def _generate_random_id(self, length: int = 6) -> str:
|
|
"""Generate a random ID string for result tagging.
|
|
|
|
Args:
|
|
length: Length of the random ID to generate
|
|
|
|
Returns:
|
|
String of random alphanumeric characters
|
|
"""
|
|
chars = string.ascii_lowercase + string.digits
|
|
return ''.join(random.choice(chars) for _ in range(length))
|
|
|
|
def extract_tool_name(self, code: str) -> str:
|
|
"""Extract the tool name from the code."""
|
|
match = re.match(r"\s*([\w_\-]+)\s*\(", code)
|
|
if match:
|
|
return match.group(1)
|
|
return ""
|
|
|
|
def handle_fallback_response(
    self, fallback_response: list[Any], e: ToolExecutionError
) -> str:
    """Turn the fallback handler's outcome into the next ``last_result``.

    Args:
        fallback_response: Value produced by the fallback handler. A falsy
            value means the fallback did not succeed; otherwise element 1
            is used as the tool call result.
        e: The tool execution error that triggered the fallback.

    Returns:
        An empty string when the fallback failed (the error prompt is added
        to the chat history, and the failure is logged, recorded and shown
        as a warning). Otherwise a message naming the tool and embedding
        the fallback result.
    """
    err_msg = HumanMessage(content=self.error_message_template.format(e=e))

    if fallback_response:
        self.chat_history.append(self.fallback_fixed_msg)
        # The raw fallback invocation is deliberately left out: invocation
        # methods may differ and could confuse the model.
        parts = [
            f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n",
            f"<fallback tool name>{e.tool_name}</fallback tool name>\n",
            f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n",
        ]
        msg = "".join(parts)

        logger.info(f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.")
        return msg

    # Fallback produced nothing: tell the model about the original error.
    self.chat_history.append(err_msg)
    failure_text = f"Tool fallback was attempted but did not succeed. Original error: {str(e)}"
    logger.info(failure_text)

    # Record error in trajectory (best effort)
    try:
        # Import here to avoid circular imports
        from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
        from ra_aid.database.repositories.human_input_repository import HumanInputRepository
        from ra_aid.database.connection import get_db

        repo = TrajectoryRepository(get_db())
        inputs = HumanInputRepository(get_db())
        latest_input_id = inputs.get_most_recent_id()

        repo.create(
            step_data={
                "error_message": f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
                "display_title": "Fallback Failed",
                "tool_name": getattr(e, "tool_name", "unknown_tool"),
            },
            record_type="error",
            human_input_id=latest_input_id,
            is_error=True,
            error_message=str(e),
            error_type="FallbackFailedError",
            tool_name=getattr(e, "tool_name", "unknown_tool"),
        )
    except Exception as trajectory_error:
        # Recording problems are logged and otherwise ignored.
        logger.error(f"Error recording trajectory for fallback failed warning: {trajectory_error}")

    print_warning(failure_text, title="Fallback Failed")
    return ""
|
|
|
|
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
    """Wrap text in the chunk layout expected by print_agent_output.

    Args:
        content: Text for the AI message.

    Returns:
        ``{"agent": {"messages": [AIMessage]}}`` holding a single message.
    """
    message = AIMessage(content=content)
    return {"agent": {"messages": [message]}}
|
|
|
|
def _create_error_chunk(self, error_message: str) -> Dict[str, Any]:
|
|
"""Create an error chunk for the agent output stream."""
|
|
return {
|
|
"type": "error",
|
|
"message": error_message,
|
|
"tool_call": {
|
|
"name": "report_error",
|
|
"args": {"error": error_message},
|
|
},
|
|
}
|
|
|
|
def _trim_chat_history(
|
|
self, initial_messages: List[Any], chat_history: List[Any]
|
|
) -> List[Any]:
|
|
"""Trim chat history based on message count and token limits while preserving initial messages.
|
|
|
|
Applies both message count and token limits (if configured) to chat_history,
|
|
while preserving all initial_messages. Returns concatenated result.
|
|
|
|
Args:
|
|
initial_messages: List of initial messages to preserve
|
|
chat_history: List of chat messages that may be trimmed
|
|
|
|
Returns:
|
|
List[Any]: Concatenated initial_messages + trimmed chat_history
|
|
"""
|
|
# First apply message count limit
|
|
if len(chat_history) > self.max_history_messages:
|
|
chat_history = chat_history[-self.max_history_messages :]
|
|
|
|
# Skip token limiting if max_tokens is None
|
|
if self.max_tokens is None:
|
|
return initial_messages + chat_history
|
|
|
|
# Calculate initial messages token count
|
|
initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages)
|
|
|
|
# Remove messages from start of chat_history until under token limit
|
|
while chat_history:
|
|
total_tokens = initial_tokens + sum(
|
|
self._estimate_tokens(msg) for msg in chat_history
|
|
)
|
|
if total_tokens <= self.max_tokens:
|
|
break
|
|
chat_history.pop(0)
|
|
|
|
return initial_messages + chat_history
|
|
|
|
@staticmethod
|
|
def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
|
|
"""Estimate token count for a message or string."""
|
|
if content is None:
|
|
return 0
|
|
|
|
if isinstance(content, BaseMessage):
|
|
text = content.content
|
|
else:
|
|
text = content
|
|
|
|
# create-react-agent tool calls can be lists
|
|
if isinstance(text, List):
|
|
text = str(text)
|
|
|
|
if not text:
|
|
return 0
|
|
|
|
return len(text.encode("utf-8")) // 2.0
|
|
|
|
@staticmethod
def _first_parsable_call(text: str) -> Optional[str]:
    """Return the first ``name(...)`` span of ``text`` that parses as one call.

    For every ``name(`` opener, the span is extended to each following
    ``)`` in turn until ``ast`` accepts it as a single call expression, so
    nested parentheses and ``)`` inside string arguments are handled.
    Returns None when no span parses.
    """
    for opener in re.finditer(r"[\w_\-]+\(", text):
        close_at = text.find(")", opener.end())
        while close_at != -1:
            candidate = text[opener.start():close_at + 1]
            try:
                body = ast.parse(candidate).body
            except (SyntaxError, ValueError, RecursionError):
                body = []
            if (
                len(body) == 1
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Call)
            ):
                return candidate
            close_at = text.find(")", close_at + 1)
    return None

def _extract_tool_call(self, code: str, functions_list: str) -> str:
    """Ask a language model to repair a malformed tool call and extract it.

    The model from ``get_model()`` is prompted with EXTRACT_TOOL_CALL_PROMPT.
    The first span of its reply that parses as a single Python call is
    returned unchanged. If nothing parses, the legacy regex is used: the
    first ``name(args)`` match, cut at the first ``)``, with newlines in the
    arguments replaced by spaces.

    Args:
        code: The malformed code produced by the agent.
        functions_list: Descriptions of the available functions, inserted
            into the prompt.

    Returns:
        The extracted call text.

    Raises:
        ToolExecutionError: If the reply contains nothing resembling a call.
            The failure is also recorded to the trajectory store on a
            best-effort basis and shown as a warning.
    """
    model = get_model()
    logger.debug(f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```")
    prompt = EXTRACT_TOOL_CALL_PROMPT.format(
        functions_list=functions_list, code=code
    )
    response = model.invoke(prompt)
    response = response.content

    # Preferred path: a span the parser accepts as one complete call.
    fixed_code = self._first_parsable_call(response)
    if fixed_code is not None:
        logger.debug(f"Successfully extracted tool call: `{fixed_code}`")
        return fixed_code

    # Legacy path, kept so replies that never parse behave as before.
    pattern = r"([\w_\-]+)\((.*?)\)"
    matches = re.findall(pattern, response, re.DOTALL)
    if len(matches) == 0:
        logger.info("Failed to extract a valid tool call from the model's response.")

        # Record error in trajectory
        try:
            # Import here to avoid circular imports
            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
            from ra_aid.database.connection import get_db

            # Create repositories directly
            trajectory_repo = TrajectoryRepository(get_db())
            human_input_repo = HumanInputRepository(get_db())
            human_input_id = human_input_repo.get_most_recent_id()

            trajectory_repo.create(
                step_data={
                    "error_message": "Failed to extract a valid tool call from the model's response.",
                    "display_title": "Extraction Failed",
                    "code": code
                },
                record_type="error",
                human_input_id=human_input_id,
                is_error=True,
                error_message="Failed to extract a valid tool call from the model's response.",
                error_type="ExtractionError"
            )
        except Exception as trajectory_error:
            # Just log and continue if there's an error in trajectory recording
            logger.error(f"Error recording trajectory for extraction error display: {trajectory_error}")

        print_warning("Failed to extract a valid tool call from the model's response.", title="Extraction Failed")
        raise ToolExecutionError("Failed to extract tool call")

    ma = matches[0][0].strip()
    mb = matches[0][1].strip().replace("\n", " ")
    fixed_code = f"{ma}({mb})"
    logger.debug(f"Successfully extracted tool call: `{fixed_code}`")
    return fixed_code
|
|
|
|
def _record_stream_trajectory_error(
    self, step_data: Dict[str, Any], error_message: str, error_type: str, log_context: str
) -> None:
    """Best-effort write of an error record to the trajectory store.

    Failures while recording are logged and otherwise ignored, so recording
    can never interrupt the stream loop.

    Args:
        step_data: Payload stored with the record.
        error_message: Message stored on the record.
        error_type: Error type label stored on the record.
        log_context: Phrase completing "Error recording trajectory for ..."
            in the log line written when recording fails.
    """
    try:
        # Import here to avoid circular imports
        from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
        from ra_aid.database.repositories.human_input_repository import HumanInputRepository
        # get_db is the accessor the other trajectory-recording code in this
        # file imports from this module.
        from ra_aid.database.connection import get_db

        # Create repositories directly
        trajectory_repo = TrajectoryRepository(get_db())
        human_input_repo = HumanInputRepository(get_db())
        human_input_id = human_input_repo.get_most_recent_id()

        trajectory_repo.create(
            step_data=step_data,
            record_type="error",
            human_input_id=human_input_id,
            is_error=True,
            error_message=error_message,
            error_type=error_type
        )
    except Exception as trajectory_error:
        # Just log and continue if there's an error in trajectory recording
        logger.error(f"Error recording trajectory for {log_context}: {trajectory_error}")

def stream(
    self, messages_dict: Dict[str, List[Any]], _config: Optional[Dict[str, Any]] = None
) -> Generator[Dict[str, Any], None, None]:
    """Stream agent responses in a format compatible with print_agent_output.

    Runs the agent loop: build the prompt from the last result, invoke the
    model on the trimmed history, process thinking content, then execute
    the response as a tool call. The loop ends when ``should_exit()`` is
    set, or after three consecutive empty model responses, in which case
    the agent is marked as crashed.

    Args:
        messages_dict: Dict whose "messages" entry holds the initial
            messages, which are kept in every model invocation.
        _config: Not used by this method.

    Yields:
        An empty dict after each executed tool call or handled tool
        failure; a single error chunk before stopping when the model keeps
        returning empty responses.
    """
    initial_messages = messages_dict.get("messages", [])
    self.chat_history = []
    last_result = None
    empty_response_count = 0
    max_empty_responses = 3  # Maximum number of consecutive empty responses before giving up

    while True:
        # Check for should_exit
        if should_exit():
            logger.debug("Agent should exit flag detected in stream loop")
            break

        base_prompt = self._build_prompt(last_result)
        self.chat_history.append(HumanMessage(content=base_prompt))
        full_history = self._trim_chat_history(initial_messages, self.chat_history)
        response = self.model.invoke([self.sys_message] + full_history)

        # Check if model supports think tags
        provider = self.config.get("provider", "")
        model_name = self.config.get("model", "")
        model_config = models_params.get(provider, {}).get(model_name, {})
        supports_think_tag = model_config.get("supports_think_tag", False)
        supports_thinking = model_config.get("supports_thinking", False)

        # Process thinking content if supported
        response.content, _ = process_thinking_content(
            content=response.content,
            supports_think_tag=supports_think_tag,
            supports_thinking=supports_thinking,
            panel_title="💭 Thoughts",
            show_thoughts=self.config.get("show_thoughts", False)
        )

        # Check if the response is empty or doesn't contain a valid tool call
        if not response.content or not response.content.strip():
            empty_response_count += 1
            logger.info(f"Model returned empty response (count: {empty_response_count})")

            warning_message = f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call."
            logger.info(warning_message)

            # Record warning in trajectory
            self._record_stream_trajectory_error(
                step_data={
                    "warning_message": warning_message,
                    "display_title": "Empty Response",
                    "attempt": empty_response_count,
                    "max_attempts": max_empty_responses
                },
                error_message=warning_message,
                error_type="EmptyResponseWarning",
                log_context="empty response warning",
            )

            print_warning(warning_message, title="Empty Response")

            if empty_response_count >= max_empty_responses:
                # Too many empty responses: mark the crash and stop streaming.
                from ra_aid.agent_context import mark_agent_crashed
                crash_message = "Agent failed to make any tool calls after multiple attempts"
                mark_agent_crashed(crash_message)
                logger.error(crash_message)

                error_message = "The agent has crashed after multiple failed attempts to generate a valid tool call."
                logger.error(error_message)

                # Record error in trajectory
                self._record_stream_trajectory_error(
                    step_data={
                        "error_message": error_message,
                        "display_title": "Agent Crashed",
                        "crash_reason": crash_message,
                        "attempts": empty_response_count
                    },
                    error_message=error_message,
                    error_type="AgentCrashError",
                    log_context="agent crash",
                )

                print_error(error_message)

                yield self._create_error_chunk(crash_message)
                return

            # Send a message to the model explicitly telling it to make a tool call
            self.chat_history.append(AIMessage(content=""))  # Add the empty response
            self.chat_history.append(HumanMessage(content=NO_TOOL_CALL_PROMPT))
            continue

        # Reset empty response counter on successful response
        empty_response_count = 0

        try:
            last_result = self._execute_tool(response)
            self.chat_history.append(response)
            self.fallback_handler.reset_fallback_handler()
            yield {}

        except ToolExecutionError as e:
            logger.info(f"Tool execution error: {str(e)}. Attempting fallback...")
            fallback_response = self.fallback_handler.handle_failure(
                e, self, self.chat_history
            )
            last_result = self.handle_fallback_response(fallback_response, e)
            yield {}