diff options
3 files changed, 8 insertions, 2 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index 91fe09d9d..2ebb56638 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -416,12 +416,17 @@ class OpenAIProvider( model_name: OpenAIModelName, completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, functions: Optional[list[CompletionModelFunction]] = None, + max_output_tokens: Optional[int] = None, **kwargs, ) -> ChatModelResponse[_T]: """Create a completion using the OpenAI API.""" openai_messages, completion_kwargs = self._get_chat_completion_args( - model_prompt, model_name, functions, **kwargs + model_prompt=model_prompt, + model_name=model_name, + functions=functions, + max_tokens=max_output_tokens, + **kwargs, ) tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs) diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py index e9c81a4d3..327718c11 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py @@ -357,6 +357,7 @@ class ChatModelProvider(ModelProvider): model_name: str, completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, functions: Optional[list[CompletionModelFunction]] = None, + max_output_tokens: Optional[int] = None, **kwargs, ) -> ChatModelResponse[_T]: ... diff --git a/autogpts/autogpt/autogpt/processing/text.py b/autogpts/autogpt/autogpt/processing/text.py index 8e5c0794b..4cebbabd6 100644 --- a/autogpts/autogpt/autogpt/processing/text.py +++ b/autogpts/autogpt/autogpt/processing/text.py @@ -160,7 +160,7 @@ async def _process_text( model_prompt=prompt.messages, model_name=model, temperature=0.5, - max_tokens=max_result_tokens, + max_output_tokens=max_result_tokens, completion_parser=lambda s: ( extract_list_from_json(s.content) if output_type is not str else None ), |