added agent_should_exit context
This commit is contained in:
parent
6d2b0a148d
commit
6c85a39bbb
|
|
@ -21,12 +21,14 @@ class AgentContext:
|
|||
self.task_completed = False
|
||||
self.plan_completed = False
|
||||
self.completion_message = ""
|
||||
self.agent_should_exit = False
|
||||
|
||||
# Inherit state from parent if provided
|
||||
if parent_context:
|
||||
self.task_completed = parent_context.task_completed
|
||||
self.plan_completed = parent_context.plan_completed
|
||||
self.completion_message = parent_context.completion_message
|
||||
self.agent_should_exit = parent_context.agent_should_exit
|
||||
|
||||
def mark_task_completed(self, message: str) -> None:
|
||||
"""Mark the current task as completed.
|
||||
|
|
@ -53,6 +55,10 @@ class AgentContext:
|
|||
self.plan_completed = False
|
||||
self.completion_message = ""
|
||||
|
||||
def mark_should_exit(self) -> None:
|
||||
"""Mark that the agent should exit execution."""
|
||||
self.agent_should_exit = True
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
"""Check if the current context is marked as completed."""
|
||||
|
|
@ -148,3 +154,20 @@ def get_completion_message() -> str:
|
|||
"""
|
||||
context = get_current_context()
|
||||
return context.completion_message if context else ""
|
||||
|
||||
|
||||
def should_exit() -> bool:
|
||||
"""Check if the agent should exit execution.
|
||||
|
||||
Returns:
|
||||
True if the agent should exit, False otherwise
|
||||
"""
|
||||
context = get_current_context()
|
||||
return context.agent_should_exit if context else False
|
||||
|
||||
|
||||
def mark_should_exit() -> None:
|
||||
"""Mark that the agent should exit execution."""
|
||||
context = get_current_context()
|
||||
if context:
|
||||
context.mark_should_exit()
|
||||
|
|
@ -75,6 +75,7 @@ from ra_aid.agent_context import (
|
|||
is_completed,
|
||||
reset_completion_flags,
|
||||
get_completion_message,
|
||||
should_exit,
|
||||
)
|
||||
from ra_aid.tools.memory import (
|
||||
_global_memory,
|
||||
|
|
@ -903,7 +904,7 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
|||
check_interrupt()
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
if is_completed():
|
||||
if is_completed() or should_exit():
|
||||
reset_completion_flags()
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ra_aid.agent_context import mark_task_completed, mark_plan_completed
|
||||
from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit
|
||||
|
||||
|
||||
class WorkLogEntry(TypedDict):
|
||||
|
|
@ -339,6 +339,7 @@ def plan_implementation_completed(message: str) -> str:
|
|||
Args:
|
||||
message: Message explaining how the implementation plan was completed
|
||||
"""
|
||||
mark_should_exit()
|
||||
mark_plan_completed(message)
|
||||
_global_memory["tasks"].clear() # Clear task list when plan is completed
|
||||
_global_memory["task_id_counter"] = 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue