fix agent context
This commit is contained in:
parent
3607803bf5
commit
60e4616313
|
|
@ -1,11 +1,12 @@
|
||||||
"""Context manager for tracking agent state and completion status."""
|
"""Context manager for tracking agent state and completion status."""
|
||||||
|
|
||||||
import threading
|
import threading # Keep for backward compatibility
|
||||||
|
import contextvars
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
# Thread-local storage for context variables
|
# Create contextvar to hold the agent context
|
||||||
_thread_local = threading.local()
|
agent_context_var = contextvars.ContextVar("agent_context", default=None)
|
||||||
|
|
||||||
|
|
||||||
class AgentContext:
|
class AgentContext:
|
||||||
|
|
@ -121,7 +122,7 @@ def get_current_context() -> Optional[AgentContext]:
|
||||||
Returns:
|
Returns:
|
||||||
The current AgentContext or None if no context is active
|
The current AgentContext or None if no context is active
|
||||||
"""
|
"""
|
||||||
return getattr(_thread_local, "current_context", None)
|
return agent_context_var.get()
|
||||||
|
|
||||||
|
|
||||||
def get_depth() -> int:
|
def get_depth() -> int:
|
||||||
|
|
@ -150,7 +151,7 @@ def agent_context(parent_context=None):
|
||||||
The newly created AgentContext
|
The newly created AgentContext
|
||||||
"""
|
"""
|
||||||
# Save the previous context
|
# Save the previous context
|
||||||
previous_context = getattr(_thread_local, "current_context", None)
|
previous_context = agent_context_var.get()
|
||||||
|
|
||||||
# Create a new context, inheriting from parent if provided
|
# Create a new context, inheriting from parent if provided
|
||||||
# If parent_context is None but previous_context exists, use previous_context as parent
|
# If parent_context is None but previous_context exists, use previous_context as parent
|
||||||
|
|
@ -159,14 +160,14 @@ def agent_context(parent_context=None):
|
||||||
else:
|
else:
|
||||||
context = AgentContext(parent_context)
|
context = AgentContext(parent_context)
|
||||||
|
|
||||||
# Set as current context
|
# Set as current context and get token for resetting later
|
||||||
_thread_local.current_context = context
|
token = agent_context_var.set(context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield context
|
yield context
|
||||||
finally:
|
finally:
|
||||||
# Restore previous context
|
# Restore previous context
|
||||||
_thread_local.current_context = previous_context
|
agent_context_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
def mark_task_completed(message: str) -> None:
|
def mark_task_completed(message: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1077,16 +1077,13 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
||||||
human-in-the-loop interruptions using interrupt_after=["tools"].
|
human-in-the-loop interruptions using interrupt_after=["tools"].
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
print("HERE")
|
|
||||||
# Process each chunk from the agent stream.
|
# Process each chunk from the agent stream.
|
||||||
for chunk in agent.stream({"messages": msg_list}, config):
|
for chunk in agent.stream({"messages": msg_list}, config):
|
||||||
print("HERE IN FOR CHUNK")
|
|
||||||
logger.debug("Agent output: %s", chunk)
|
logger.debug("Agent output: %s", chunk)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
agent_type = get_agent_type(agent)
|
agent_type = get_agent_type(agent)
|
||||||
print_agent_output(chunk, agent_type)
|
print_agent_output(chunk, agent_type)
|
||||||
if is_completed() or should_exit():
|
if is_completed() or should_exit():
|
||||||
print("IS COMPLETED OR SHOULD EXIT TRIGGERED")
|
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
return True # Exit immediately when finished or signaled to exit.
|
return True # Exit immediately when finished or signaled to exit.
|
||||||
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
||||||
|
|
@ -1108,20 +1105,15 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
||||||
# If the state indicates that further steps remain (i.e. state.next is non-empty),
|
# If the state indicates that further steps remain (i.e. state.next is non-empty),
|
||||||
# then resume execution by invoking the agent with no new input.
|
# then resume execution by invoking the agent with no new input.
|
||||||
if state.next:
|
if state.next:
|
||||||
print("HAS NEXT STATE")
|
|
||||||
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
|
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
|
||||||
agent.invoke(None, config)
|
agent.invoke(None, config)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
print("NO NEXT STATE")
|
|
||||||
logger.debug("No continuation indicated in state; exiting stream loop.")
|
logger.debug("No continuation indicated in state; exiting stream loop.")
|
||||||
break
|
break
|
||||||
|
|
||||||
print("WHILE EXITED")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_agent_with_retry(
|
def run_agent_with_retry(
|
||||||
agent: RAgents,
|
agent: RAgents,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue