RA.Aid/ra_aid/agent_context.py

275 lines
8.3 KiB
Python

"""Context manager for tracking agent state and completion status."""
import threading # Keep for backward compatibility
import contextvars
from contextlib import contextmanager
from typing import Optional
# Context variable holding the currently active AgentContext.
# Its default is None, so a lookup made while no context is set yields None.
agent_context_var = contextvars.ContextVar("agent_context", default=None)
class AgentContext:
    """State holder for one agent: completion, exit and crash status.

    Every instance keeps a reference to an optional parent context. The
    resulting chain is used to compute nesting depth and to propagate the
    exit flag upwards.
    """

    def __init__(self, parent_context=None):
        """Create a context with every status flag cleared.

        Args:
            parent_context: Optional enclosing context. Only the reference is
                stored; none of the parent's flags are copied into this one.
        """
        self.parent = parent_context

        # Completion state always starts clean, whatever the parent holds.
        self.task_completed = False
        self.plan_completed = False
        self.completion_message = ""

        # Exit and crash state.
        self.agent_should_exit = False
        self.agent_has_crashed = False
        self.agent_crashed_message = None

    def mark_task_completed(self, message: str) -> None:
        """Flag the current task as finished.

        Args:
            message: Explanation of how/why the task is complete
        """
        self.task_completed = True
        self.completion_message = message

    def mark_plan_completed(self, message: str) -> None:
        """Flag the current plan as finished.

        Completing a plan also marks the task as completed.

        Args:
            message: Explanation of how the plan was completed
        """
        self.task_completed = self.plan_completed = True
        self.completion_message = message

    def reset_completion_flags(self) -> None:
        """Clear the task/plan completion flags and the completion message."""
        self.task_completed = self.plan_completed = False
        self.completion_message = ""

    def mark_should_exit(self, propagation_depth: Optional[int] = 0) -> None:
        """Flag this context (and optionally its ancestors) for exit.

        Args:
            propagation_depth: How far up the parent chain the flag travels.
                None: every ancestor is flagged
                0 (default): only this context is flagged
                1: this context and its immediate parent
                2+: that many levels above this context
        """
        self.agent_should_exit = True

        # Work out what depth the parent should receive, or stop here.
        if propagation_depth is None:
            parent_depth = None  # unlimited propagation stays unlimited
        elif propagation_depth > 0:
            parent_depth = propagation_depth - 1
        else:
            return  # depth exhausted: nothing to hand to the parent

        if self.parent:
            self.parent.mark_should_exit(parent_depth)

    def mark_agent_crashed(self, message: str) -> None:
        """Record that the agent crashed.

        Only this context is updated; the parent is not touched.

        Args:
            message: Error message explaining the crash
        """
        self.agent_has_crashed, self.agent_crashed_message = True, message

    def is_crashed(self) -> bool:
        """Report whether this context has been marked as crashed.

        Returns:
            True if the agent has crashed, False otherwise
        """
        return self.agent_has_crashed

    @property
    def is_completed(self) -> bool:
        """Whether either the task or the plan has been marked completed."""
        return self.task_completed or self.plan_completed

    @property
    def depth(self) -> int:
        """Nesting level of this context within its parent chain.

        Returns:
            int: 0 when there is no parent, otherwise parent.depth + 1
        """
        return 0 if self.parent is None else self.parent.depth + 1
def get_current_context() -> Optional[AgentContext]:
    """Look up the agent context that is active right now.

    The value is read from the module's context variable, so it reflects
    whatever was most recently set in the caller's contextvars context.

    Returns:
        The active AgentContext, or None when none has been set
    """
    active = agent_context_var.get()
    return active
def get_depth() -> int:
    """Report how deeply nested the active agent context is.

    Returns:
        int: The active context's depth, or 0 when no context is active
    """
    current = get_current_context()
    return 0 if current is None else current.depth
@contextmanager
def agent_context(parent_context=None):
    """Run a block of code under a freshly created agent context.

    A new AgentContext is installed as the active context for the duration of
    the with block, and the previously active one is restored on exit, even
    if the block raises.

    Args:
        parent_context: Optional parent for the new context. When omitted,
            the context that was active on entry (if any) becomes the parent.

    Yields:
        The newly created AgentContext
    """
    enclosing = agent_context_var.get()

    # An explicit parent wins; otherwise nest under whatever is active now
    # (which may itself be None, giving a root context).
    effective_parent = enclosing if parent_context is None else parent_context
    new_context = AgentContext(effective_parent)

    # Keep the token so the prior value can be restored exactly.
    reset_token = agent_context_var.set(new_context)
    try:
        yield new_context
    finally:
        agent_context_var.reset(reset_token)
def mark_task_completed(message: str) -> None:
    """Flag the task of the active context as completed.

    Does nothing when no context is active.

    Args:
        message: Explanation of how/why the task is complete
    """
    active = get_current_context()
    if not active:
        return
    active.mark_task_completed(message)
def mark_plan_completed(message: str) -> None:
    """Flag the plan of the active context as completed.

    Does nothing when no context is active.

    Args:
        message: Explanation of how the plan was completed
    """
    active = get_current_context()
    if not active:
        return
    active.mark_plan_completed(message)
def reset_completion_flags() -> None:
    """Clear the completion flags of the active context, if there is one."""
    active = get_current_context()
    if not active:
        return
    active.reset_completion_flags()
def is_completed() -> bool:
    """Report whether the active context has been marked as completed.

    Returns:
        The active context's completion status, or False when no context
        is active
    """
    active = get_current_context()
    if not active:
        return False
    return active.is_completed
def get_completion_message() -> str:
    """Fetch the completion message stored on the active context.

    Returns:
        The stored completion message, or an empty string when no context
        is active
    """
    active = get_current_context()
    if not active:
        return ""
    return active.completion_message
def should_exit() -> bool:
    """Report whether the active context has been flagged for exit.

    Returns:
        The active context's exit flag, or False when no context is active
    """
    active = get_current_context()
    if not active:
        return False
    return active.agent_should_exit
def mark_should_exit(propagation_depth: Optional[int] = 0) -> None:
    """Flag the active context (and optionally its ancestors) for exit.

    Does nothing when no context is active.

    Args:
        propagation_depth: How far up the parent chain the flag travels.
            None: every ancestor is flagged
            0 (default): only the active context is flagged
            1: the active context and its immediate parent
            2+: that many levels above the active context
    """
    active = get_current_context()
    if not active:
        return
    active.mark_should_exit(propagation_depth)
def is_crashed() -> bool:
    """Report whether the active context has been marked as crashed.

    Returns:
        The active context's crash status, or False when no context is active
    """
    active = get_current_context()
    if not active:
        return False
    return active.is_crashed()
def mark_agent_crashed(message: str) -> None:
    """Record a crash on the active context.

    Does nothing when no context is active.

    Args:
        message: Error message explaining the crash
    """
    active = get_current_context()
    if not active:
        return
    active.mark_agent_crashed(message)
def get_crash_message() -> Optional[str]:
    """Fetch the crash message stored on the active context.

    Returns:
        The stored crash message when a context is active and marked as
        crashed, otherwise None
    """
    active = get_current_context()
    if active and active.is_crashed():
        return active.agent_crashed_message
    return None