use interrupt for tighter control of agent loop

This commit is contained in:
AI Christianson 2025-03-07 17:31:27 -05:00
parent 638776c8f8
commit 3607803bf5
1 changed files with 129 additions and 9 deletions

View File

@ -310,7 +310,7 @@ def create_agent(
if is_anthropic_claude(config):
logger.debug("Using create_react_agent to instantiate agent.")
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs)
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
else:
logger.debug("Using CiaynAgent agent instance")
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
@ -321,7 +321,7 @@ def create_agent(
config = get_config_repository().get_all()
max_input_tokens = get_model_token_limit(config, agent_type)
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs)
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
def run_research_agent(
@ -991,16 +991,136 @@ def _handle_fallback_response(
msg_list.extend(msg_list_response)
# def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
# for chunk in agent.stream({"messages": msg_list}, config):
# logger.debug("Agent output: %s", chunk)
# check_interrupt()
# agent_type = get_agent_type(agent)
# print_agent_output(chunk, agent_type)
# if is_completed() or should_exit():
# reset_completion_flags()
# break
# def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
# while True: ## WE NEED TO ONLY KEEP ITERATING IF IT IS AN INTERRUPT, NOT UNCONDITIONALLY
# stream = agent.stream({"messages": msg_list}, config)
# for chunk in stream:
# logger.debug("Agent output: %s", chunk)
# check_interrupt()
# agent_type = get_agent_type(agent)
# print_agent_output(chunk, agent_type)
# if is_completed() or should_exit():
# reset_completion_flags()
# return True
# print("HERE!")
# def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
# while True:
# for chunk in agent.stream({"messages": msg_list}, config):
# print("Chunk received:", chunk)
# check_interrupt()
# agent_type = get_agent_type(agent)
# print_agent_output(chunk, agent_type)
# if is_completed() or should_exit():
# reset_completion_flags()
# return True
# print("HERE!")
# print("Config passed to _run_agent_stream:", config)
# print("Config keys:", list(config.keys()))
# # Ensure the configuration for state retrieval contains a 'configurable' key.
# state_config = config.copy()
# if "configurable" not in state_config:
# print("Key 'configurable' not found in config. Adding it as an empty dict.")
# state_config["configurable"] = {}
# print("Using state_config for agent.get_state():", state_config)
# try:
# state = agent.get_state(state_config)
# print("Agent state retrieved:", state)
# print("State type:", type(state))
# print("State attributes:", dir(state))
# except Exception as e:
# print("Error retrieving agent state with state_config", state_config, ":", e)
# raise
# # Since state.current is not available, we rely solely on state.next.
# try:
# next_node = state.next
# print("State next value:", next_node)
# except Exception as e:
# print("Error accessing state.next:", e)
# next_node = None
# # Resume execution if state.next is truthy (indicating further steps remain).
# if next_node:
# print("Resuming execution because state.next is nonempty:", next_node)
# agent.invoke(None, config)
# continue
# else:
# print("No further steps indicated; breaking out of loop.")
# break
# return True
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
for chunk in agent.stream({"messages": msg_list}, config):
logger.debug("Agent output: %s", chunk)
check_interrupt()
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if is_completed() or should_exit():
reset_completion_flags()
"""
Streams agent output while handling completion and interruption.
For each chunk, it logs the output, calls check_interrupt(), prints agent output,
and then checks if is_completed() or should_exit() are true. If so, it resets completion
flags and returns. After finishing a stream iteration (i.e. the for-loop over chunks),
the function retrieves the agent's state. If the state indicates further steps (i.e. state.next is non-empty),
it resumes execution via agent.invoke(None, config); otherwise, it exits the loop.
This function adheres to the latest LangGraph best practices (as of March 2025) for handling
human-in-the-loop interruptions using interrupt_after=["tools"].
"""
while True:
print("HERE")
# Process each chunk from the agent stream.
for chunk in agent.stream({"messages": msg_list}, config):
print("HERE IN FOR CHUNK")
logger.debug("Agent output: %s", chunk)
check_interrupt()
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if is_completed() or should_exit():
print("IS COMPLETED OR SHOULD EXIT TRIGGERED")
reset_completion_flags()
return True # Exit immediately when finished or signaled to exit.
logger.debug("Stream iteration ended; checking agent state for continuation.")
# Prepare state configuration, ensuring 'configurable' is present.
state_config = config.copy()
if "configurable" not in state_config:
logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
state_config["configurable"] = {}
logger.debug("Using state_config for agent.get_state(): %s", state_config)
try:
state = agent.get_state(state_config)
logger.debug("Agent state retrieved: %s", state)
except Exception as e:
logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
raise
# If the state indicates that further steps remain (i.e. state.next is non-empty),
# then resume execution by invoking the agent with no new input.
if state.next:
print("HAS NEXT STATE")
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
agent.invoke(None, config)
continue
else:
print("NO NEXT STATE")
logger.debug("No continuation indicated in state; exiting stream loop.")
break
print("WHILE EXITED")
return True
def run_agent_with_retry(
agent: RAgents,