275 lines
8.3 KiB
Python
275 lines
8.3 KiB
Python
"""Context manager for tracking agent state and completion status."""
|
|
|
|
import threading # Keep for backward compatibility
|
|
import contextvars
|
|
from contextlib import contextmanager
|
|
from typing import Optional
|
|
|
|
# ContextVar holding the AgentContext that is currently active. Each thread
# and each asyncio task sees its own value; it reads as None until something
# sets it (see agent_context()).
agent_context_var = contextvars.ContextVar(
    "agent_context",
    default=None,
)
|
|
|
|
|
|
class AgentContext:
    """Mutable record of one agent's completion, exit, and crash state.

    Contexts form a chain through ``parent``. The parent link is used only to
    compute ``depth`` and to propagate exit requests upward through
    ``mark_should_exit``. Completion and crash state are never copied from,
    or pushed to, the parent.
    """

    def __init__(self, parent_context: Optional["AgentContext"] = None):
        """Initialize a new agent context with every flag cleared.

        Args:
            parent_context: Optional enclosing context. It is stored as
                ``self.parent`` (used by ``depth`` and exit propagation).
                No flags are inherited from it.
        """
        self.parent = parent_context

        # Completion state: set by mark_task_completed / mark_plan_completed,
        # cleared by reset_completion_flags. Not inherited from the parent.
        self.task_completed: bool = False
        self.plan_completed: bool = False
        self.completion_message: str = ""

        # Exit request. mark_should_exit may also set this on ancestors.
        self.agent_should_exit: bool = False

        # Crash state. Local to this context; never propagated.
        self.agent_has_crashed: bool = False
        self.agent_crashed_message: Optional[str] = None

    def __repr__(self) -> str:
        """Return a compact debug summary of this context's flags."""
        return (
            f"{type(self).__name__}(depth={self.depth}, "
            f"task_completed={self.task_completed}, "
            f"plan_completed={self.plan_completed}, "
            f"should_exit={self.agent_should_exit}, "
            f"crashed={self.agent_has_crashed})"
        )

    def mark_task_completed(self, message: str) -> None:
        """Mark the current task as completed.

        Args:
            message: Completion message explaining how/why the task is complete.
        """
        self.task_completed = True
        self.completion_message = message

    def mark_plan_completed(self, message: str) -> None:
        """Mark the current plan (and therefore the task) as completed.

        Args:
            message: Completion message explaining how the plan was completed.
        """
        # A completed plan implies a completed task.
        self.task_completed = True
        self.plan_completed = True
        self.completion_message = message

    def reset_completion_flags(self) -> None:
        """Clear task/plan completion and the completion message.

        ``agent_should_exit`` and the crash state are intentionally left
        unchanged.
        """
        self.task_completed = False
        self.plan_completed = False
        self.completion_message = ""

    def mark_should_exit(self, propagation_depth: Optional[int] = 0) -> None:
        """Mark that the agent should exit execution.

        Args:
            propagation_depth: How far up the context hierarchy to propagate the flag.
                None: Propagate to all parent contexts
                0 (default): Only mark the current context
                1: Mark the current context and its immediate parent
                2+: Propagate up the specified number of levels
                Negative values behave like 0.
        """
        self.agent_should_exit = True

        if self.parent is None:
            return
        if propagation_depth is None:
            # Unlimited: keep passing None so every ancestor gets marked.
            self.parent.mark_should_exit(None)
        elif propagation_depth > 0:
            # Each hop to a parent consumes one level of depth. Dispatch
            # through the parent's method so subclass overrides still apply.
            self.parent.mark_should_exit(propagation_depth - 1)
        # propagation_depth <= 0: stop at this context.

    def mark_agent_crashed(self, message: str) -> None:
        """Mark the agent as crashed with the given message.

        Unlike exit state, crash state does not propagate to parent contexts.

        Args:
            message: Error message explaining the crash.
        """
        self.agent_has_crashed = True
        self.agent_crashed_message = message

    def is_crashed(self) -> bool:
        """Check if the agent has crashed.

        Returns:
            True if the agent has crashed, False otherwise.
        """
        return self.agent_has_crashed

    @property
    def is_completed(self) -> bool:
        """Whether the task or the plan has been marked completed."""
        return self.task_completed or self.plan_completed

    @property
    def depth(self) -> int:
        """Calculate the depth of this context from its parent chain.

        Returns:
            int: 0 for a context with no parent, parent.depth + 1 otherwise.
        """
        if self.parent is None:
            return 0
        return self.parent.depth + 1
|
|
|
|
|
|
def get_current_context() -> Optional[AgentContext]:
    """Return the AgentContext active in the current execution context.

    The value comes from ``agent_context_var``, so it is scoped per thread
    and per asyncio task rather than being process-global.

    Returns:
        The active AgentContext, or None if nothing has set one (the
        ContextVar's default).
    """
    active = agent_context_var.get()
    return active
|
|
|
|
|
|
def get_depth() -> int:
    """Return how deeply the active agent context is nested.

    Returns:
        int: The ``depth`` of the current AgentContext (0 for a root
        context), or 0 when no context is active.
    """
    current = get_current_context()
    return 0 if current is None else current.depth
|
|
|
|
|
|
@contextmanager
def agent_context(parent_context=None):
    """Activate a fresh AgentContext for the duration of a ``with`` block.

    The new context's parent is ``parent_context`` when one is given.
    Otherwise it is whatever context was active on entry (possibly None), so
    nested ``with agent_context():`` blocks chain automatically. On exit the
    previously active context is restored, even if the body raised.

    Args:
        parent_context: Optional explicit parent. Only the parent link is
            established; no flags are copied from it.

    Yields:
        The newly created AgentContext.
    """
    outer = agent_context_var.get()
    # An explicit parent wins; otherwise chain onto the currently active context.
    parent = outer if parent_context is None else parent_context
    new_ctx = AgentContext(parent)

    reset_token = agent_context_var.set(new_ctx)
    try:
        yield new_ctx
    finally:
        # Token-based reset restores exactly the value seen before set().
        agent_context_var.reset(reset_token)
|
|
|
|
|
|
def mark_task_completed(message: str) -> None:
    """Flag the active context's task as completed.

    Does nothing when no agent context is active.

    Args:
        message: Explanation of how/why the task was completed.
    """
    context = get_current_context()
    if not context:
        return
    context.mark_task_completed(message)
|
|
|
|
|
|
def mark_plan_completed(message: str) -> None:
    """Flag the active context's plan (and thus its task) as completed.

    Does nothing when no agent context is active.

    Args:
        message: Explanation of how the plan was completed.
    """
    context = get_current_context()
    if not context:
        return
    context.mark_plan_completed(message)
|
|
|
|
|
|
def reset_completion_flags() -> None:
    """Clear task/plan completion state on the active context, if any."""
    context = get_current_context()
    if not context:
        return
    context.reset_completion_flags()
|
|
|
|
|
|
def is_completed() -> bool:
    """Report whether the active context's task or plan is completed.

    Returns:
        The context's ``is_completed`` value, or False when no context is
        active.
    """
    context = get_current_context()
    if not context:
        return False
    return context.is_completed
|
|
|
|
|
|
def get_completion_message() -> str:
    """Fetch the completion message stored on the active context.

    Returns:
        The message, or "" when no context is active or none was recorded.
    """
    context = get_current_context()
    if not context:
        return ""
    return context.completion_message
|
|
|
|
|
|
def should_exit() -> bool:
    """Report whether the active agent has been asked to exit.

    Returns:
        The context's ``agent_should_exit`` flag, or False when no context
        is active.
    """
    context = get_current_context()
    if not context:
        return False
    return context.agent_should_exit
|
|
|
|
|
|
def mark_should_exit(propagation_depth: Optional[int] = 0) -> None:
    """Ask the active agent (and optionally its ancestors) to exit.

    Does nothing when no agent context is active.

    Args:
        propagation_depth: How many parent levels to propagate the flag to.
            None propagates to every ancestor; 0 (the default) marks only
            the current context; 1 adds the immediate parent; N adds N levels.
    """
    context = get_current_context()
    if not context:
        return
    context.mark_should_exit(propagation_depth)
|
|
|
|
|
|
def is_crashed() -> bool:
    """Report whether the active agent has been marked as crashed.

    Returns:
        The context's crash state, or False when no context is active.
    """
    context = get_current_context()
    if not context:
        return False
    return context.is_crashed()
|
|
|
|
|
|
def mark_agent_crashed(message: str) -> None:
    """Record a crash, with its message, on the active context.

    Does nothing when no agent context is active. Crash state is not
    propagated to parent contexts.

    Args:
        message: Error message describing the crash.
    """
    context = get_current_context()
    if not context:
        return
    context.mark_agent_crashed(message)
|
|
|
|
|
|
def get_crash_message() -> Optional[str]:
    """Fetch the crash message from the active context.

    Returns:
        The recorded crash message, or None when no context is active or the
        agent has not crashed.
    """
    context = get_current_context()
    if not context or not context.is_crashed():
        return None
    return context.agent_crashed_message