feat(agent_utils.py): add get_model_name_from_chat_model function to improve model handling
refactor(build_agent_kwargs): simplify state modifier logic by using model name instead of model attribute
This commit is contained in:
parent
f1274b3164
commit
6c159d39d4
|
|
@ -52,6 +52,7 @@ from ra_aid.database.repositories.human_input_repository import (
|
||||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.anthropic_token_limiter import (
|
from ra_aid.anthropic_token_limiter import (
|
||||||
|
get_model_name_from_chat_model,
|
||||||
sonnet_35_state_modifier,
|
sonnet_35_state_modifier,
|
||||||
state_modifier,
|
state_modifier,
|
||||||
get_model_token_limit,
|
get_model_token_limit,
|
||||||
|
|
@ -102,11 +103,10 @@ def build_agent_kwargs(
|
||||||
):
|
):
|
||||||
|
|
||||||
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
||||||
if not hasattr(model, 'model'):
|
model_name = get_model_name_from_chat_model(model)
|
||||||
return state_modifier(state, model, max_input_tokens=max_input_tokens)
|
|
||||||
|
|
||||||
if any(
|
if any(
|
||||||
pattern in model.model
|
pattern in model_name
|
||||||
for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
|
for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
|
||||||
):
|
):
|
||||||
return sonnet_35_state_modifier(
|
return sonnet_35_state_modifier(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue