author    Reinier van der Leer <pwuts@agpt.co>  2024-04-18 21:48:22 +0200
committer Reinier van der Leer <pwuts@agpt.co>  2024-04-22 17:47:15 +0200
commit    7bb7c30842a84b42ea9cf724864470a5bf4ab982 (patch)
tree      ac5f70a19f735a8bfbb37a6fe0f05248a7397d41
parent    refactor(agent): Add `ChatModelProvider.get_available_models()` and remove `A... (diff)
download  Auto-GPT-7bb7c30842a84b42ea9cf724864470a5bf4ab982.tar.gz
          Auto-GPT-7bb7c30842a84b42ea9cf724864470a5bf4ab982.tar.bz2
          Auto-GPT-7bb7c30842a84b42ea9cf724864470a5bf4ab982.zip
feat(agent/core): Add `max_output_tokens` parameter to `create_chat_completion` interface
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/openai.py  7
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/schema.py  1
-rw-r--r--  autogpts/autogpt/autogpt/processing/text.py                       2
3 files changed, 8 insertions, 2 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 91fe09d9d..2ebb56638 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -416,12 +416,17 @@ class OpenAIProvider(
         model_name: OpenAIModelName,
         completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
+        max_output_tokens: Optional[int] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
         """Create a completion using the OpenAI API."""
         openai_messages, completion_kwargs = self._get_chat_completion_args(
-            model_prompt, model_name, functions, **kwargs
+            model_prompt=model_prompt,
+            model_name=model_name,
+            functions=functions,
+            max_tokens=max_output_tokens,
+            **kwargs,
         )
         tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
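The provider passes the new parameter into its internal argument builder under OpenAI's `max_tokens` keyword. Below is a minimal sketch of that forwarding pattern using a hypothetical helper (the body of `_get_chat_completion_args` is not part of this diff), assuming unset optional values are dropped before the API call:

from typing import Any, Optional

def build_completion_kwargs(
    model_name: str,
    max_tokens: Optional[int] = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Hypothetical helper illustrating the forwarding pattern:
    collect chat-completion kwargs and omit max_tokens when it is None,
    so the backend falls back to its own default output limit."""
    completion_kwargs: dict[str, Any] = {"model": model_name, **kwargs}
    if max_tokens is not None:
        completion_kwargs["max_tokens"] = max_tokens
    return completion_kwargs

# build_completion_kwargs("gpt-4", max_tokens=500) -> {"model": "gpt-4", "max_tokens": 500}
# build_completion_kwargs("gpt-4")                 -> {"model": "gpt-4"}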
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index e9c81a4d3..327718c11 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -357,6 +357,7 @@ class ChatModelProvider(ModelProvider):
         model_name: str,
         completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
+        max_output_tokens: Optional[int] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
         ...
diff --git a/autogpts/autogpt/autogpt/processing/text.py b/autogpts/autogpt/autogpt/processing/text.py
index 8e5c0794b..4cebbabd6 100644
--- a/autogpts/autogpt/autogpt/processing/text.py
+++ b/autogpts/autogpt/autogpt/processing/text.py
@@ -160,7 +160,7 @@ async def _process_text(
         model_prompt=prompt.messages,
         model_name=model,
         temperature=0.5,
-        max_tokens=max_result_tokens,
+        max_output_tokens=max_result_tokens,
         completion_parser=lambda s: (
             extract_list_from_json(s.content) if output_type is not str else None
         ),