Add --temperature CLI parameter.
parent ace34633de
commit a28ed59bca

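Reviewer note: end to end, the new flag flows from parse_arguments into initialize_llm, as the hunks below show. A minimal sketch of that flow, assuming an import path for initialize_llm (only parse_arguments' module, ra_aid.__main__, is confirmed by this diff, via the tests):

    from ra_aid.__main__ import parse_arguments
    from ra_aid.llm import initialize_llm  # import path assumed, not shown in this diff

    # '--temperature' is parsed as a float; omitting it leaves args.temperature as None.
    args = parse_arguments(['--message', 'hello', '--temperature', '0.7'])
    model = initialize_llm(args.provider, args.model, temperature=args.temperature)
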
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add CiaynAgent to support models that do not have, or are not good at, agentic function calling.
 - Improve env var validation.
+- Add --temperature CLI parameter.
 
 ## [0.10.3] - 2024-12-27
 
@@ -104,6 +104,12 @@ Examples:
         action='store_true',
         help='Enable verbose logging output'
     )
+    parser.add_argument(
+        '--temperature',
+        type=float,
+        help='LLM temperature (0.0-2.0). Controls randomness in responses',
+        default=None
+    )
 
     if args is None:
         args = sys.argv[1:]

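Reviewer note: default=None (rather than a numeric default) is what lets downstream code tell "flag omitted" apart from an explicit value, and type=float makes argparse convert and reject non-numeric input. A small sketch, assuming a bare --message invocation parses cleanly, as the tests suggest:

    args = parse_arguments(['--message', 'test'])
    assert args.temperature is None  # flag omitted: provider defaults apply downstream

    args = parse_arguments(['--message', 'test', '--temperature', '0.7'])
    assert args.temperature == 0.7   # argparse converted the string '0.7' via type=float
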
@@ -129,6 +135,10 @@ Examples:
     if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
         parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
 
+    # Validate temperature range if provided
+    if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0):
+        parser.error('Temperature must be between 0.0 and 2.0')
+
     return parsed_args
 
 # Create console instance

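Reviewer note: parser.error() prints the message to stderr and exits with status 2, so an out-of-range value raises SystemExit instead of returning, which is exactly what the new tests assert. The bounds are inclusive, so 0.0 and 2.0 themselves pass validation:

    import pytest

    with pytest.raises(SystemExit):
        parse_arguments(['--message', 'test', '--temperature', '2.1'])  # 2.1 > 2.0, rejected
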
@@ -179,7 +189,7 @@ def main():
         ))
 
         # Create the base model after validation
-        model = initialize_llm(args.provider, args.model)
+        model = initialize_llm(args.provider, args.model, temperature=args.temperature)
 
         # Handle chat mode
         if args.chat:

@@ -3,7 +3,7 @@ from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 
-def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
+def initialize_llm(provider: str, model_name: str, temperature: float | None = None) -> BaseChatModel:
     """Initialize a language model client based on the specified provider and model.
 
     Note: Environment variables must be validated before calling this function.

@@ -12,6 +12,8 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
     Args:
         provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible')
        model_name: Name of the model to use
+        temperature: Optional temperature setting for controlling randomness (0.0-2.0).
+            If not specified, provider-specific defaults are used.
 
     Returns:
         BaseChatModel: Configured language model client

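Reviewer note: with the new signature, omitting temperature (or passing None) is equivalent to the old two-argument call, so existing callers are unaffected. Illustrative calls, with a hypothetical model name and environment variables assumed to be validated beforehand, per the docstring:

    model = initialize_llm('openai', 'gpt-4o')                    # client library's default temperature
    model = initialize_llm('openai', 'gpt-4o', temperature=0.2)   # explicit temperature forwarded
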
@@ -23,23 +25,26 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             model=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
             model_name=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "openrouter":
         return ChatOpenAI(
             api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
             model=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             base_url=os.getenv("OPENAI_API_BASE"),
-            temperature=0.3,
+            temperature=temperature if temperature is not None else 0.3,
             model=model_name,
         )
     else:

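Reviewer note: the repeated **({"temperature": temperature} if temperature is not None else {}) is a conditional-kwargs idiom: splatting an empty dict passes no argument at all, so the client keeps its own default, rather than receiving an explicit temperature=None. A self-contained demonstration:

    def extra_kwargs(temperature: float | None) -> dict:
        # Splatting {} contributes nothing to the call; splatting the one-entry
        # dict is equivalent to writing temperature=... at the call site.
        return {"temperature": temperature} if temperature is not None else {}

    assert extra_kwargs(None) == {}
    assert extra_kwargs(0.7) == {"temperature": 0.7}
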
@@ -154,6 +154,83 @@ def test_initialize_unsupported_provider(clean_env):
         initialize_llm("unsupported", "model")
     assert str(exc_info.value) == "Unsupported provider: unsupported"
 
+def test_temperature_defaults(clean_env, mock_openai, mock_anthropic):
+    """Test default temperature behavior for different providers."""
+    os.environ["OPENAI_API_KEY"] = "test-key"
+    os.environ["ANTHROPIC_API_KEY"] = "test-key"
+    os.environ["OPENAI_API_BASE"] = "http://test-url"
+
+    # Test openai-compatible default temperature
+    initialize_llm("openai-compatible", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        base_url="http://test-url",
+        model="test-model",
+        temperature=0.3
+    )
+
+    # Test other providers don't set temperature by default
+    initialize_llm("openai", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model"
+    )
+
+    initialize_llm("anthropic", "test-model")
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model"
+    )
+
+def test_explicit_temperature(clean_env, mock_openai, mock_anthropic):
+    """Test explicit temperature setting for each provider."""
+    os.environ["OPENAI_API_KEY"] = "test-key"
+    os.environ["ANTHROPIC_API_KEY"] = "test-key"
+    os.environ["OPENROUTER_API_KEY"] = "test-key"
+
+    test_temp = 0.7
+
+    # Test OpenAI
+    initialize_llm("openai", "test-model", temperature=test_temp)
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=test_temp
+    )
+
+    # Test Anthropic
+    initialize_llm("anthropic", "test-model", temperature=test_temp)
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model",
+        temperature=test_temp
+    )
+
+    # Test OpenRouter
+    initialize_llm("openrouter", "test-model", temperature=test_temp)
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        base_url="https://openrouter.ai/api/v1",
+        model="test-model",
+        temperature=test_temp
+    )
+
+def test_temperature_validation(clean_env, mock_openai):
+    """Test temperature validation in command line arguments."""
+    from ra_aid.__main__ import parse_arguments
+
+    # Test temperature below minimum
+    with pytest.raises(SystemExit):
+        parse_arguments(['--message', 'test', '--temperature', '-0.1'])
+
+    # Test temperature above maximum
+    with pytest.raises(SystemExit):
+        parse_arguments(['--message', 'test', '--temperature', '2.1'])
+
+    # Test valid temperature
+    args = parse_arguments(['--message', 'test', '--temperature', '0.7'])
+    assert args.temperature == 0.7
+
 def test_provider_name_validation():
     """Test provider name validation and normalization."""
     # Test all supported providers

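Reviewer note: assert_called_with only checks the most recent call on a mock, which is why test_temperature_defaults can reuse mock_openai for the openai-compatible and plain openai cases back to back. The mock_openai/mock_anthropic fixtures are not part of this diff; presumably they replace the client classes with MagicMock, roughly like this sketch (an assumption, including the patch target path, not the repository's code):

    import pytest
    from unittest.mock import MagicMock

    @pytest.fixture
    def mock_openai(monkeypatch):
        mock = MagicMock()
        # Patch target assumed; the llm module's real path is not shown in this diff.
        monkeypatch.setattr("ra_aid.llm.ChatOpenAI", mock)
        return mock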