added agent_should_exit context

This commit is contained in:
AI Christianson 2025-02-27 10:14:17 -05:00
parent 6d2b0a148d
commit 6c85a39bbb
3 changed files with 27 additions and 2 deletions

View File

@ -21,12 +21,14 @@ class AgentContext:
self.task_completed = False
self.plan_completed = False
self.completion_message = ""
self.agent_should_exit = False
# Inherit state from parent if provided
if parent_context:
self.task_completed = parent_context.task_completed
self.plan_completed = parent_context.plan_completed
self.completion_message = parent_context.completion_message
self.agent_should_exit = parent_context.agent_should_exit
def mark_task_completed(self, message: str) -> None:
"""Mark the current task as completed.
@ -53,6 +55,10 @@ class AgentContext:
self.plan_completed = False
self.completion_message = ""
def mark_should_exit(self) -> None:
"""Mark that the agent should exit execution."""
self.agent_should_exit = True
@property
def is_completed(self) -> bool:
"""Check if the current context is marked as completed."""
@ -148,3 +154,20 @@ def get_completion_message() -> str:
"""
context = get_current_context()
return context.completion_message if context else ""
def should_exit() -> bool:
"""Check if the agent should exit execution.
Returns:
True if the agent should exit, False otherwise
"""
context = get_current_context()
return context.agent_should_exit if context else False
def mark_should_exit() -> None:
"""Mark that the agent should exit execution."""
context = get_current_context()
if context:
context.mark_should_exit()

View File

@ -75,6 +75,7 @@ from ra_aid.agent_context import (
is_completed,
reset_completion_flags,
get_completion_message,
should_exit,
)
from ra_aid.tools.memory import (
_global_memory,
@ -903,7 +904,7 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
check_interrupt()
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if is_completed():
if is_completed() or should_exit():
reset_completion_flags()
break

View File

@ -12,7 +12,7 @@ from rich.markdown import Markdown
from rich.panel import Panel
from typing_extensions import TypedDict
from ra_aid.agent_context import mark_task_completed, mark_plan_completed
from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit
class WorkLogEntry(TypedDict):
@ -339,6 +339,7 @@ def plan_implementation_completed(message: str) -> str:
Args:
message: Message explaining how the implementation plan was completed
"""
mark_should_exit()
mark_plan_completed(message)
_global_memory["tasks"].clear() # Clear task list when plan is completed
_global_memory["task_id_counter"] = 1