fix agent context
This commit is contained in:
parent
3607803bf5
commit
60e4616313
|
|
@ -1,11 +1,12 @@
|
|||
"""Context manager for tracking agent state and completion status."""
|
||||
|
||||
import threading
|
||||
import threading # Keep for backward compatibility
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
# Thread-local storage for context variables
|
||||
_thread_local = threading.local()
|
||||
# Create contextvar to hold the agent context
|
||||
agent_context_var = contextvars.ContextVar("agent_context", default=None)
|
||||
|
||||
|
||||
class AgentContext:
|
||||
|
|
@ -121,7 +122,7 @@ def get_current_context() -> Optional[AgentContext]:
|
|||
Returns:
|
||||
The current AgentContext or None if no context is active
|
||||
"""
|
||||
return getattr(_thread_local, "current_context", None)
|
||||
return agent_context_var.get()
|
||||
|
||||
|
||||
def get_depth() -> int:
|
||||
|
|
@ -150,7 +151,7 @@ def agent_context(parent_context=None):
|
|||
The newly created AgentContext
|
||||
"""
|
||||
# Save the previous context
|
||||
previous_context = getattr(_thread_local, "current_context", None)
|
||||
previous_context = agent_context_var.get()
|
||||
|
||||
# Create a new context, inheriting from parent if provided
|
||||
# If parent_context is None but previous_context exists, use previous_context as parent
|
||||
|
|
@ -159,14 +160,14 @@ def agent_context(parent_context=None):
|
|||
else:
|
||||
context = AgentContext(parent_context)
|
||||
|
||||
# Set as current context
|
||||
_thread_local.current_context = context
|
||||
# Set as current context and get token for resetting later
|
||||
token = agent_context_var.set(context)
|
||||
|
||||
try:
|
||||
yield context
|
||||
finally:
|
||||
# Restore previous context
|
||||
_thread_local.current_context = previous_context
|
||||
agent_context_var.reset(token)
|
||||
|
||||
|
||||
def mark_task_completed(message: str) -> None:
|
||||
|
|
@ -271,4 +272,4 @@ def get_crash_message() -> Optional[str]:
|
|||
The crash message or None if the agent has not crashed
|
||||
"""
|
||||
context = get_current_context()
|
||||
return context.agent_crashed_message if context and context.is_crashed() else None
|
||||
return context.agent_crashed_message if context and context.is_crashed() else None
|
||||
|
|
@ -1077,16 +1077,13 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
|||
human-in-the-loop interruptions using interrupt_after=["tools"].
|
||||
"""
|
||||
while True:
|
||||
print("HERE")
|
||||
# Process each chunk from the agent stream.
|
||||
for chunk in agent.stream({"messages": msg_list}, config):
|
||||
print("HERE IN FOR CHUNK")
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
if is_completed() or should_exit():
|
||||
print("IS COMPLETED OR SHOULD EXIT TRIGGERED")
|
||||
reset_completion_flags()
|
||||
return True # Exit immediately when finished or signaled to exit.
|
||||
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
||||
|
|
@ -1108,20 +1105,15 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
|||
# If the state indicates that further steps remain (i.e. state.next is non-empty),
|
||||
# then resume execution by invoking the agent with no new input.
|
||||
if state.next:
|
||||
print("HAS NEXT STATE")
|
||||
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
|
||||
agent.invoke(None, config)
|
||||
continue
|
||||
else:
|
||||
print("NO NEXT STATE")
|
||||
logger.debug("No continuation indicated in state; exiting stream loop.")
|
||||
break
|
||||
|
||||
print("WHILE EXITED")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def run_agent_with_retry(
|
||||
agent: RAgents,
|
||||
prompt: str,
|
||||
|
|
|
|||
Loading…
Reference in New Issue