diff --git a/backend/app/agent/agent_model.py b/backend/app/agent/agent_model.py index 154181b8..eb5c27ba 100644 --- a/backend/app/agent/agent_model.py +++ b/backend/app/agent/agent_model.py @@ -136,11 +136,11 @@ def agent_model( ) model_platform_enum = None - if ( - effective_config["model_platform"].lower() == "anthropic" - and model_config.get("cache_control") is None - ): - model_config["cache_control"] = "5m" + if effective_config["model_platform"].lower() == "anthropic": + if model_config.get("cache_control") is None: + model_config["cache_control"] = "5m" + if model_config.get("max_tokens") is None: + model_config["max_tokens"] = 64000 model = ModelFactory.create( model_platform=effective_config["model_platform"], diff --git a/backend/app/component/model_validation.py b/backend/app/component/model_validation.py index c8da48be..bad0ad48 100644 --- a/backend/app/component/model_validation.py +++ b/backend/app/component/model_validation.py @@ -227,6 +227,10 @@ def create_agent( raise ValueError(f"Invalid model_type: {model_type}") if platform is None: raise ValueError(f"Invalid model_platform: {model_platform}") + if str(platform).lower() == "anthropic": + model_config_dict = dict(model_config_dict or {}) + if model_config_dict.get("max_tokens") is None: + model_config_dict["max_tokens"] = 4096 model = ModelFactory.create( model_platform=platform, model_type=mtype, @@ -326,6 +330,10 @@ def validate_model_with_details( "Creating model", extra={"platform": model_platform, "model_type": model_type}, ) + if str(model_platform).lower() == "anthropic": + model_config_dict = dict(model_config_dict or {}) + if model_config_dict.get("max_tokens") is None: + model_config_dict["max_tokens"] = 4096 model = ModelFactory.create( model_platform=model_platform, model_type=model_type,