diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 0b18d34..d0f26cb 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -164,7 +164,6 @@ def create_agent(
         max_input_tokens = (
             get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
         )
-        print(f"max_input_tokens={max_input_tokens}")
 
         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py
new file mode 100644
index 0000000..91c285f
--- /dev/null
+++ b/ra_aid/anthropic_message_utils.py
@@ -0,0 +1,393 @@
+"""Utilities for handling Anthropic-specific message formats and trimming."""
+
+from typing import Callable, List, Literal, Optional, Sequence, Union
+
+from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
+
+from ra_aid.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+def _is_message_type(
+    message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]]
+) -> bool:
+    """Check if a message is of a specific type or types.
+
+    Args:
+        message: The message to check
+        message_types: Type(s) to check against (string name or class)
+
+    Returns:
+        bool: True if message matches any of the specified types
+    """
+    if not isinstance(message_types, list):
+        message_types = [message_types]
+
+    types_str = [t for t in message_types if isinstance(t, str)]
+    types_classes = tuple(t for t in message_types if isinstance(t, type))
+
+    return message.type in types_str or isinstance(message, types_classes)
+
+
+def has_tool_use(message: BaseMessage) -> bool:
+    """Check if a message contains tool use.
+
+    Args:
+        message: The message to check
+
+    Returns:
+        bool: True if the message contains tool use
+    """
+    if not isinstance(message, AIMessage):
+        return False
+
+    # Heuristic: string content that mentions tool_use
+    if isinstance(message.content, str) and "tool_use" in message.content:
+        return True
+
+    # Check content list for tool_use blocks
+    if isinstance(message.content, list):
+        for item in message.content:
+            if isinstance(item, dict) and item.get("type") == "tool_use":
+                return True
+
+    # Check additional_kwargs for tool_calls
+    if hasattr(message, "additional_kwargs") and message.additional_kwargs.get(
+        "tool_calls"
+    ):
+        return True
+
+    return False
+
+
+def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
+    """Check if two messages form a tool use/result pair.
+
+    Args:
+        message1: First message
+        message2: Second message
+
+    Returns:
+        bool: True if the messages form a tool use/result pair
+    """
+    return (
+        isinstance(message1, AIMessage)
+        and isinstance(message2, ToolMessage)
+        and has_tool_use(message1)
+    )
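+
+
+# Doctest-style sketch of the two predicates above (illustrative only; the
+# exact tool_calls payload shape depends on the provider):
+#
+#     >>> ai = AIMessage(
+#     ...     content="",
+#     ...     additional_kwargs={"tool_calls": [{"id": "toolu_01", "name": "run_shell"}]},
+#     ... )
+#     >>> has_tool_use(ai)
+#     True
+#     >>> is_tool_pair(ai, ToolMessage(content="ok", tool_call_id="toolu_01"))
+#     True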
+
+
+def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage:
+    """Fix AIMessage content-block format for Anthropic API compatibility."""
+    if not isinstance(message, AIMessage) or not isinstance(message.content, list):
+        return message
+
+    fixed_message = message.model_copy(deep=True)
+
+    # Ensure the first block is a valid thinking type
+    if fixed_message.content and isinstance(fixed_message.content[0], dict):
+        first_block_type = fixed_message.content[0].get("type")
+        if first_block_type not in ("thinking", "redacted_thinking"):
+            # Prepend a redacted_thinking block instead of a thinking block
+            fixed_message.content.insert(
+                0,
+                {
+                    "type": "redacted_thinking",
+                    "data": "ENCRYPTED_REASONING",  # Required field for redacted_thinking
+                },
+            )
+
+    # Ensure all thinking blocks have a valid structure
+    for i, block in enumerate(fixed_message.content):
+        if not isinstance(block, dict):
+            continue
+        if block.get("type") == "thinking":
+            # Convert thinking blocks to redacted_thinking to avoid signature validation
+            fixed_message.content[i] = {
+                "type": "redacted_thinking",
+                "data": "ENCRYPTED_REASONING",
+            }
+        elif block.get("type") == "redacted_thinking" and "data" not in block:
+            # Ensure the required data field exists
+            fixed_message.content[i]["data"] = "ENCRYPTED_REASONING"
+
+    return fixed_message
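+
+
+# Rough before/after sketch of the normalization above (placeholder payloads,
+# not real API data):
+#
+#     [{"type": "text", "text": "hi"}]
+#         -> [{"type": "redacted_thinking", "data": "ENCRYPTED_REASONING"},
+#             {"type": "text", "text": "hi"}]
+#
+#     [{"type": "thinking", "thinking": "..."}]
+#         -> [{"type": "redacted_thinking", "data": "ENCRYPTED_REASONING"}]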
+
+
+def anthropic_trim_messages(
+    messages: Sequence[BaseMessage],
+    *,
+    max_tokens: int,
+    token_counter: Callable[[List[BaseMessage]], int],
+    strategy: Literal["first", "last"] = "last",
+    num_messages_to_keep: int = 2,
+    allow_partial: bool = False,
+    include_system: bool = True,
+    start_on: Optional[Union[str, type, List[Union[str, type]]]] = None,
+) -> List[BaseMessage]:
+    """Trim messages to fit within a token limit, with Anthropic-specific handling.
+
+    This function is similar to langchain_core's trim_messages but with special
+    handling for Anthropic message formats to avoid API errors.
+
+    It always keeps the first num_messages_to_keep messages.
+
+    Args:
+        messages: Sequence of messages to trim
+        max_tokens: Maximum number of tokens allowed
+        token_counter: Function to count tokens in messages
+        strategy: Whether to keep the "first" or "last" messages
+        num_messages_to_keep: Number of leading messages to always keep
+        allow_partial: Accepted for trim_messages signature compatibility; unused
+        include_system: Accepted for trim_messages signature compatibility; unused
+        start_on: Accepted for trim_messages signature compatibility; unused
+
+    Returns:
+        List[BaseMessage]: Trimmed messages that fit within token limit
+    """
+    if not messages:
+        return []
+
+    messages = list(messages)
+
+    # Always keep the first num_messages_to_keep messages
+    kept_messages = messages[:num_messages_to_keep]
+    remaining_msgs = messages[num_messages_to_keep:]
+
+    # Debug: log the type of every message
+    logger.debug("All messages:")
+    for i, msg in enumerate(messages):
+        msg_type = type(msg).__name__
+        tool_use = (
+            "tool_use"
+            if isinstance(msg, AIMessage)
+            and hasattr(msg, "additional_kwargs")
+            and msg.additional_kwargs.get("tool_calls")
+            else ""
+        )
+        tool_result = (
+            f"tool_call_id: {msg.tool_call_id}"
+            if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id")
+            else ""
+        )
+        logger.debug("  [%d] %s %s %s", i, msg_type, tool_use, tool_result)
+
+    # For Anthropic, we need to maintain the conversation structure where:
+    # 1. Every AIMessage with tool_use is immediately followed by a ToolMessage
+    # 2. Every ToolMessage immediately follows the AIMessage whose tool_use it answers
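+    #
+    # For example, a trimmed sequence must keep pairs intact (hypothetical
+    # messages, for illustration):
+    #
+    #     [SystemMessage, HumanMessage,        # kept head
+    #      AIMessage(tool_use), ToolMessage,   # pair kept or dropped together
+    #      AIMessage]                          # plain reply may stand alone
+    #
+    # Dropping only one half of a pair is rejected by the Anthropic API.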
+
+    # First, check if we have any tool_use in the messages
+    has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
+    logger.debug("Has tool_use anywhere in messages: %s", has_tool_use_anywhere)
+
+    # Log details for AIMessages
+    for i, msg in enumerate(messages):
+        if isinstance(msg, AIMessage):
+            logger.debug("AIMessage[%d] has_tool_use: %s", i, has_tool_use(msg))
+            if hasattr(msg, "additional_kwargs"):
+                logger.debug(
+                    "  additional_kwargs keys: %s", list(msg.additional_kwargs.keys())
+                )
+
+    # If we have tool_use anywhere, we need to be very careful about trimming
+    if has_tool_use_anywhere:
+        # For safety, just keep all messages if we're under the token limit
+        if token_counter(messages) <= max_tokens:
+            logger.debug("All messages fit within token limit, keeping all")
+            return messages
+
+        # We need to identify all tool_use/tool_result relationships.
+        # Find all AIMessage+ToolMessage pairs.
+        pairs = []
+        i = 0
+        while i < len(messages) - 1:
+            if is_tool_pair(messages[i], messages[i + 1]):
+                pairs.append((i, i + 1))
+                logger.debug("Found tool_use pair: (%d, %d)", i, i + 1)
+                i += 2
+            else:
+                i += 1
+
+        logger.debug("Found %d AIMessage+ToolMessage pairs", len(pairs))
+
+        # For Anthropic, we must ensure that:
+        # 1. If we include an AIMessage with tool_use, we include the following ToolMessage
+        # 2. If we include a ToolMessage, we include the preceding AIMessage with tool_use
+        # The safest approach is to keep complete AIMessage+ToolMessage pairs together.
+
+        # Build the result, starting with kept_messages, taking care with the
+        # last kept message in case it has tool_use.
+        result = []
+
+        # Check if the last message in kept_messages has tool_use
+        if (
+            kept_messages
+            and isinstance(kept_messages[-1], AIMessage)
+            and has_tool_use(kept_messages[-1])
+        ):
+            # We need to find the corresponding ToolMessage
+            for i, (ai_idx, tool_idx) in enumerate(pairs):
+                if messages[ai_idx] is kept_messages[-1]:
+                    # Found the pair; add all kept_messages except the last one
+                    result.extend(kept_messages[:-1])
+                    # Add the AIMessage and ToolMessage as a pair
+                    result.extend([messages[ai_idx], messages[tool_idx]])
+                    # Remove this pair from the list of pairs to process later
+                    pairs = pairs[:i] + pairs[i + 1 :]
+                    break
+            else:
+                # If we didn't find a matching pair, just add all kept_messages
+                result.extend(kept_messages)
+        else:
+            # No tool_use in the last kept message, just add all kept_messages
+            result.extend(kept_messages)
+
+        # If we're using the "last" strategy, try to include pairs from the end
+        if strategy == "last":
+            # First collect all pairs we can include within the token limit
+            pairs_to_include = []
+
+            # Process pairs from the end (newest first)
+            for ai_idx, tool_idx in reversed(pairs):
+                # Try adding this pair
+                test_msgs = result.copy()
+
+                # Add all previously selected pairs
+                for prev_ai_idx, prev_tool_idx in pairs_to_include:
+                    test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
+
+                # Add this pair
+                test_msgs.extend([messages[ai_idx], messages[tool_idx]])
+
+                if token_counter(test_msgs) <= max_tokens:
+                    # This pair fits, add it to our list
+                    pairs_to_include.append((ai_idx, tool_idx))
+                    logger.debug("Added complete pair (%d, %d)", ai_idx, tool_idx)
+                else:
+                    # This pair would exceed the token limit
+                    logger.debug(
+                        "Pair (%d, %d) would exceed token limit, stopping",
+                        ai_idx,
+                        tool_idx,
+                    )
+                    break
+
+            # Add the selected pairs sorted by index so the original
+            # conversation order is preserved.
+            pairs_to_include.sort(key=lambda x: x[0])
+            for ai_idx, tool_idx in pairs_to_include:
+                result.extend([messages[ai_idx], messages[tool_idx]])
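+
+            # Walk-through with hypothetical indices: given pairs
+            # [(2, 3), (4, 5), (6, 7)] and a budget with room for two pairs,
+            # the reversed scan admits (6, 7), then (4, 5), then stops at
+            # (2, 3); the sort restores chronological order, so the result
+            # ends [..., messages[4], messages[5], messages[6], messages[7]].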
[{i}] {msg_type}") + + return valid_result + + elif strategy == "first": + result = [] + + # Process segments from the beginning + for i, segment in enumerate(segments): + # Try adding this segment + test_msgs = result + segment + if token_counter(kept_messages + test_msgs) <= max_tokens: + result = result + segment + print(f"DEBUG - Added segment {i} to result") + else: + # This segment would exceed the token limit + print(f"DEBUG - Segment {i} would exceed token limit, stopping") + break + + final_result = kept_messages + result + print("\nDEBUG - Final result:") + for i, msg in enumerate(final_result): + msg_type = type(msg).__name__ + print(f" [{i}] {msg_type}") + + return final_result diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index fc6d3ac..c46cbc9 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -1,11 +1,17 @@ """Utilities for handling token limits with Anthropic models.""" from functools import partial -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence from langchain_anthropic import ChatAnthropic -from langchain_core.messages import BaseMessage, trim_messages +from langchain_core.messages import AIMessage, BaseMessage, ToolMessage, trim_messages from langchain_core.messages.base import message_to_dict + +from ra_aid.anthropic_message_utils import ( + fix_anthropic_message_content, + anthropic_trim_messages, + has_tool_use, +) from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import token_counter @@ -13,7 +19,7 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params -from ra_aid.console.output import print_messages_compact +from ra_aid.console.output import cpm, print_messages_compact logger = get_logger(__name__) @@ -85,68 +91,58 @@ def create_token_counter_wrapper(model: str): def state_modifier( state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT ) -> list[BaseMessage]: - """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message. + """Given the agent state and max_tokens, return a trimmed list of messages. + + This uses anthropic_trim_messages which always keeps the first 2 messages. 
 
     Args:
         state: The current agent state containing messages
         model: The language model to use for token counting
-        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
+        max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
 
     Returns:
         list[BaseMessage]: Trimmed list of messages that fits within token limit
     """
     messages = state["messages"]
-
     if not messages:
         return []
 
-    first_message = messages[0]
-    remaining_messages = messages[1:]
-
     wrapped_token_counter = create_token_counter_wrapper(model.model)
-    print(f"max_input_tokens={max_input_tokens}")
-    max_input_tokens = 17000
-    first_tokens = wrapped_token_counter([first_message])
-    print(f"first_tokens={first_tokens}")
-    new_max_tokens = max_input_tokens - first_tokens
+    # Temporary debugging override: pin the limit to 21000 regardless of the
+    # caller-supplied value.
+    max_input_tokens = 21000
 
-    # Calculate total tokens before trimming
-    total_tokens_before = wrapped_token_counter(messages)
-    print(
-        f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
-    )
+    logger.debug("Starting token trimming with max_tokens: %s", max_input_tokens)
+    logger.debug("Current token total: %s", wrapped_token_counter(messages))
 
-    # Verify the token count is correct
-    if total_tokens_before < first_tokens:
-        print(f"WARNING: Token count inconsistency detected! Recounting...")
-        # Count message by message to debug
-        for i, msg in enumerate(messages):
-            msg_tokens = wrapped_token_counter([msg])
-            print(f"  Message {i}: {msg_tokens} tokens")
-        # Try alternative counting method
-        alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
-        print(f"  Alternative count method: {alt_count} tokens")
+    # Log extra detail about the messages to help debug trimming
+    for i, msg in enumerate(messages):
+        if isinstance(msg, AIMessage):
+            logger.debug("AIMessage[%d] content type: %s", i, type(msg.content))
+            logger.debug("AIMessage[%d] has_tool_use: %s", i, has_tool_use(msg))
+            if has_tool_use(msg) and i < len(messages) - 1:
+                logger.debug(
+                    "Next message is ToolMessage: %s",
+                    isinstance(messages[i + 1], ToolMessage),
+                )
 
-    print_messages_compact(messages)
-
-    trimmed_remaining = trim_messages(
-        remaining_messages,
+    result = anthropic_trim_messages(
+        messages,
         token_counter=wrapped_token_counter,
-        max_tokens=new_max_tokens,
+        max_tokens=max_input_tokens,
         strategy="last",
         allow_partial=False,
+        include_system=True,
+        num_messages_to_keep=2,
     )
 
-    result = [first_message] + trimmed_remaining
-
-    # Only show message if some messages were trimmed
     if len(result) < len(messages):
         print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
-        # Calculate total tokens after trimming
         total_tokens_after = wrapped_token_counter(result)
         print(f"New token total: {total_tokens_after}")
-
+        cpm("BEFORE TRIMMING")
+        print_messages_compact(messages)
+        cpm("AFTER TRIMMING")
+        print_messages_compact(result)
     return result
 
@@ -176,12 +172,14 @@ def sonnet_35_state_modifier(
     total_tokens_before = estimate_messages_tokens(messages)
     print(f"Current token total: {total_tokens_before}")
 
-    trimmed_remaining = trim_messages(
+    # Trim the remaining messages while preserving tool_use/tool_result pairs
+    trimmed_remaining = anthropic_trim_messages(
         remaining_messages,
         token_counter=estimate_messages_tokens,
         max_tokens=new_max_tokens,
         strategy="last",
         allow_partial=False,
+        include_system=True,
+        # The caller already keeps first_message, so keep no extra head here
+        num_messages_to_keep=0,
     )
 
     result = [first_message] + trimmed_remaining
@@ -193,6 +191,8 @@
         total_tokens_after = estimate_messages_tokens(result)
         print(f"New token total: {total_tokens_after}")
 
+    # anthropic_trim_messages keeps tool_use/tool_result pairs intact, so no
+    # further message-content fixing is needed here.
+
     return result