From 7bb7c30842a84b42ea9cf724864470a5bf4ab982 Mon Sep 17 00:00:00 2001
From: Reinier van der Leer
Date: Thu, 18 Apr 2024 21:48:22 +0200
Subject: feat(agent/core): Add `max_output_tokens` parameter to
 `create_chat_completion` interface

---
 autogpts/autogpt/autogpt/core/resource/model_providers/openai.py | 7 ++++++-
 autogpts/autogpt/autogpt/core/resource/model_providers/schema.py | 1 +
 autogpts/autogpt/autogpt/processing/text.py                      | 2 +-
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 91fe09d9d..2ebb56638 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -416,12 +416,17 @@ class OpenAIProvider(
         model_name: OpenAIModelName,
         completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
+        max_output_tokens: Optional[int] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
         """Create a completion using the OpenAI API."""
         openai_messages, completion_kwargs = self._get_chat_completion_args(
-            model_prompt, model_name, functions, **kwargs
+            model_prompt=model_prompt,
+            model_name=model_name,
+            functions=functions,
+            max_tokens=max_output_tokens,
+            **kwargs,
         )
         tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)

diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index e9c81a4d3..327718c11 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -357,6 +357,7 @@ class ChatModelProvider(ModelProvider):
         model_name: str,
         completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
+        max_output_tokens: Optional[int] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
         ...

diff --git a/autogpts/autogpt/autogpt/processing/text.py b/autogpts/autogpt/autogpt/processing/text.py
index 8e5c0794b..4cebbabd6 100644
--- a/autogpts/autogpt/autogpt/processing/text.py
+++ b/autogpts/autogpt/autogpt/processing/text.py
@@ -160,7 +160,7 @@ async def _process_text(
         model_prompt=prompt.messages,
         model_name=model,
         temperature=0.5,
-        max_tokens=max_result_tokens,
+        max_output_tokens=max_result_tokens,
         completion_parser=lambda s: (
             extract_list_from_json(s.content) if output_type is not str else None
         ),
--
cgit v1.2.3
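
For reference, a minimal sketch of how a caller could exercise the new parameter. This is not part of the patch: `short_answer` is a hypothetical helper, provider construction/configuration is elided, and the `ChatMessage.user(...)` and `response.response.content` shapes are assumptions based on the schema module touched by this diff.

    # Sketch only -- not part of the patch. Assumes `provider` is an
    # already-configured OpenAIProvider (or any other ChatModelProvider
    # implementation) and that ChatMessage.user() exists as in the
    # AutoGPT schema module.
    from autogpt.core.resource.model_providers.schema import ChatMessage

    async def short_answer(provider, model_name: str):
        response = await provider.create_chat_completion(
            model_prompt=[ChatMessage.user("Name the capital of France.")],
            model_name=model_name,
            # New in this patch: caps the length of the model's reply.
            # OpenAIProvider forwards it to the API as `max_tokens`.
            max_output_tokens=32,
        )
        # ChatModelResponse wraps the assistant message produced by the model.
        return response.response.content

Design note: keeping the provider-agnostic name `max_output_tokens` at the `ChatModelProvider` interface and translating it to OpenAI's `max_tokens` inside `_get_chat_completion_args` lets other provider implementations map the same parameter onto their own API's equivalent.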