author    Reinier van der Leer <pwuts@agpt.co>  2024-05-04 20:33:25 +0200
committer GitHub <noreply@github.com>           2024-05-04 20:33:25 +0200
commit    39c46ef6be4e7772bc2a39e8b5b72066391ec689
tree      7e6c13e560b918207f9818dc868ca957258a7abd
parent    Create .pr_agent.toml
feat(agent/core): Add Anthropic Claude 3 support (#7085)
- feat(agent/core): Add `AnthropicProvider`
  - Add `ANTHROPIC_API_KEY` to .env.template and docs

  Notable differences in logic compared to `OpenAIProvider`:
  - Merges subsequent user messages in `AnthropicProvider._get_chat_completion_args`
  - Merges and extracts all system messages into `system` parameter in `AnthropicProvider._get_chat_completion_args`
  - Supports prefill; merges prefill content (if any) into generated response

- Prompt changes to improve compatibility with `AnthropicProvider`

  Anthropic has a slightly different API compared to OpenAI, and has much stricter input validation. E.g. Anthropic only supports a single `system` prompt, where OpenAI allows multiple `system` messages. Anthropic also forbids sequences of multiple `user` or `assistant` messages and requires that messages alternate between roles.

  - Move response format instruction from separate message into main system prompt
  - Fix clock message format
  - Add pre-fill to `OneShot` generated prompt

- refactor(agent/core): Tweak `model_providers.schema`
  - Simplify `ModelProviderUsage`
    - Remove attribute `total_tokens` as it is always equal to `prompt_tokens + completion_tokens`
    - Modify signature of `update_usage(..)`; no longer requires a full `ModelResponse` object as input
  - Improve `ModelProviderBudget`
    - Change type of attribute `usage` to `defaultdict[str, ModelProviderUsage]` -> allow per-model usage tracking
    - Modify signature of `update_usage_and_cost(..)`; no longer requires a full `ModelResponse` object as input
  - Allow `ModelProviderBudget` zero-argument instantiation
  - Fix type of `AssistantChatMessage.role` to match `ChatMessage.role` (str -> `ChatMessage.Role`)
  - Add shared attributes and constructor to `ModelProvider` base class
  - Add `max_output_tokens` parameter to `create_chat_completion` interface
  - Add pre-filling as a global feature (see the usage sketch after this list)
    - Add `prefill_response` field to `ChatPrompt` model
    - Add `prefill_response` parameter to `create_chat_completion` interface
  - Add `ChatModelProvider.get_available_models()` and remove `ApiManager`
  - Remove unused `OpenAIChatParser` typedef in openai.py
  - Remove redundant `budget` attribute definition on `OpenAISettings`
  - Remove unnecessary `usage` in `OpenAIProvider` > `default_settings` > `budget`

- feat(agent): Allow use of any available LLM provider through `MultiProvider`
  - Add `MultiProvider` (`model_providers.multi`)
  - Replace all references to / uses of `OpenAIProvider` with `MultiProvider`
  - Change type of `Config.smart_llm` and `Config.fast_llm` from `str` to `ModelName`

- feat(agent/core): Validate function call arguments in `create_chat_completion`
  - Add `validate_call` method to `CompletionModelFunction` in `model_providers.schema`
  - Add `validate_tool_calls` utility function in `model_providers.utils`
  - Add tool call validation step to `create_chat_completion` in `OpenAIProvider` and `AnthropicProvider`
  - Remove (now redundant) command argument validation logic in agent.py and models/command.py

- refactor(agent): Rename `get_openai_command_specs` to `function_specs_from_commands`
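Below is a minimal usage sketch of the new `prefill_response` parameter and `MultiProvider` routing described above. It is not part of the diff; the model choice and prompt content are illustrative only, and the relevant API key must be set in the environment.

# Hedged sketch (not in this commit): exercising the new create_chat_completion
# interface with a response prefill, routed through MultiProvider.
import asyncio

from autogpt.core.resource.model_providers import ChatMessage, MultiProvider
from autogpt.core.resource.model_providers.anthropic import AnthropicModelName


async def main() -> None:
    provider = MultiProvider()  # lazily instantiates the provider for the chosen model
    response = await provider.create_chat_completion(
        model_prompt=[
            ChatMessage.system("Respond with a JSON object."),
            ChatMessage.user("Summarize the current task status."),
        ],
        model_name=AnthropicModelName.CLAUDE3_SONNET_v1,  # illustrative model choice
        prefill_response='{\n  "thoughts":',  # steers the start of the model's output
        max_output_tokens=1024,
    )
    # AnthropicProvider merges the prefill back into the returned message content
    print(response.response.content)


asyncio.run(main())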
-rw-r--r--  autogpts/autogpt/.env.template                                        |   7
-rw-r--r--  autogpts/autogpt/agbenchmark_config/benchmarks.py                     |  10
-rw-r--r--  autogpts/autogpt/autogpt/agents/agent.py                              |  43
-rw-r--r--  autogpts/autogpt/autogpt/agents/base.py                               |  15
-rw-r--r--  autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py         |  42
-rw-r--r--  autogpts/autogpt/autogpt/app/agent_protocol_server.py                 |  25
-rw-r--r--  autogpts/autogpt/autogpt/app/configurator.py                          |  10
-rw-r--r--  autogpts/autogpt/autogpt/app/main.py                                  |  30
-rw-r--r--  autogpts/autogpt/autogpt/commands/system.py                           |   4
-rw-r--r--  autogpts/autogpt/autogpt/config/config.py                             |  10
-rw-r--r--  autogpts/autogpt/autogpt/core/prompting/schema.py                     |   1
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/__init__.py    |   4
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/anthropic.py   | 495
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/multi.py       | 162
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/openai.py      |  16
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/schema.py      |  58
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/utils.py       |  71
-rw-r--r--  autogpts/autogpt/autogpt/llm/providers/openai.py                      |   2
-rw-r--r--  autogpts/autogpt/autogpt/models/command.py                            |  16
-rw-r--r--  autogpts/autogpt/poetry.lock                                          |  26
-rw-r--r--  autogpts/autogpt/pyproject.toml                                       |   1
-rw-r--r--  autogpts/autogpt/tests/conftest.py                                    |   8
-rw-r--r--  autogpts/autogpt/tests/unit/test_config.py                            |  15
-rw-r--r--  docs/content/AutoGPT/configuration/options.md                         |   1
24 files changed, 923 insertions, 149 deletions
diff --git a/autogpts/autogpt/.env.template b/autogpts/autogpt/.env.template
index ba514c56b..14c7bcaa5 100644
--- a/autogpts/autogpt/.env.template
+++ b/autogpts/autogpt/.env.template
@@ -2,8 +2,11 @@
### AutoGPT - GENERAL SETTINGS
################################################################################
-## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key)
-OPENAI_API_KEY=your-openai-api-key
+## OPENAI_API_KEY - OpenAI API Key (Example: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
+# OPENAI_API_KEY=
+
+## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
+# ANTHROPIC_API_KEY=
## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
diff --git a/autogpts/autogpt/agbenchmark_config/benchmarks.py b/autogpts/autogpt/agbenchmark_config/benchmarks.py
index 1281f9fff..c574dc303 100644
--- a/autogpts/autogpt/agbenchmark_config/benchmarks.py
+++ b/autogpts/autogpt/agbenchmark_config/benchmarks.py
@@ -5,8 +5,7 @@ from pathlib import Path
from autogpt.agent_manager.agent_manager import AgentManager
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
-from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy
-from autogpt.app.main import _configure_openai_provider, run_interaction_loop
+from autogpt.app.main import _configure_llm_provider, run_interaction_loop
from autogpt.config import AIProfile, ConfigBuilder
from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging
@@ -38,10 +37,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
ai_goals=[task],
)
- agent_prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(
- deep=True
- )
- agent_prompt_config.use_functions_api = config.openai_functions
agent_settings = AgentSettings(
name=Agent.default_settings.name,
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
@@ -53,7 +48,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
allow_fs_access=not config.restrict_to_workspace,
use_functions_api=config.openai_functions,
),
- prompt_config=agent_prompt_config,
history=Agent.default_settings.history.copy(deep=True),
)
@@ -66,7 +60,7 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
agent = Agent(
settings=agent_settings,
- llm_provider=_configure_openai_provider(config),
+ llm_provider=_configure_llm_provider(config),
file_storage=file_storage,
legacy_config=config,
)
diff --git a/autogpts/autogpt/autogpt/agents/agent.py b/autogpts/autogpt/autogpt/agents/agent.py
index 3572cbed0..4a66a7ca4 100644
--- a/autogpts/autogpt/autogpt/agents/agent.py
+++ b/autogpts/autogpt/autogpt/agents/agent.py
@@ -19,7 +19,6 @@ from autogpt.components.event_history import EventHistoryComponent
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import (
- AssistantChatMessage,
AssistantFunctionCall,
ChatMessage,
ChatModelProvider,
@@ -27,7 +26,7 @@ from autogpt.core.resource.model_providers import (
)
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
from autogpt.file_storage.base import FileStorage
-from autogpt.llm.providers.openai import get_openai_command_specs
+from autogpt.llm.providers.openai import function_specs_from_commands
from autogpt.logs.log_cycle import (
CURRENT_CONTEXT_FILE_NAME,
NEXT_ACTION_FILE_NAME,
@@ -46,7 +45,6 @@ from autogpt.utils.exceptions import (
AgentException,
AgentTerminated,
CommandExecutionError,
- InvalidArgumentError,
UnknownCommandError,
)
@@ -104,7 +102,11 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
self.ai_profile = settings.ai_profile
self.directives = settings.directives
prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
- prompt_config.use_functions_api = settings.config.use_functions_api
+ prompt_config.use_functions_api = (
+ settings.config.use_functions_api
+ # Anthropic currently doesn't support tools + prefilling :(
+ and self.llm.provider_name != "anthropic"
+ )
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
self.commands: list[Command] = []
@@ -172,7 +174,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
task=self.state.task,
ai_profile=self.state.ai_profile,
ai_directives=directives,
- commands=get_openai_command_specs(self.commands),
+ commands=function_specs_from_commands(self.commands),
include_os_info=self.legacy_config.execute_local_commands,
)
@@ -202,12 +204,9 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
] = await self.llm_provider.create_chat_completion(
prompt.messages,
model_name=self.llm.name,
- completion_parser=self.parse_and_validate_response,
- functions=(
- get_openai_command_specs(self.commands)
- if self.config.use_functions_api
- else []
- ),
+ completion_parser=self.prompt_strategy.parse_response_content,
+ functions=prompt.functions,
+ prefill_response=prompt.prefill_response,
)
result = response.parsed_result
@@ -223,28 +222,6 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
return result
- def parse_and_validate_response(
- self, llm_response: AssistantChatMessage
- ) -> OneShotAgentActionProposal:
- parsed_response = self.prompt_strategy.parse_response_content(llm_response)
-
- # Validate command arguments
- command_name = parsed_response.use_tool.name
- command = self._get_command(command_name)
- if arg_errors := command.validate_args(parsed_response.use_tool.arguments)[1]:
- fmt_errors = [
- f"{'.'.join(str(p) for p in f.path)}: {f.message}"
- if f.path
- else f.message
- for f in arg_errors
- ]
- raise InvalidArgumentError(
- f"The set of arguments supplied for {command_name} is invalid:\n"
- + "\n".join(fmt_errors)
- )
-
- return parsed_response
-
async def execute(
self,
proposal: OneShotAgentActionProposal,
diff --git a/autogpts/autogpt/autogpt/agents/base.py b/autogpts/autogpt/autogpt/agents/base.py
index cf8e3cac8..515515701 100644
--- a/autogpts/autogpt/autogpt/agents/base.py
+++ b/autogpts/autogpt/autogpt/agents/base.py
@@ -39,11 +39,12 @@ from autogpt.core.configuration import (
SystemSettings,
UserConfigurable,
)
-from autogpt.core.resource.model_providers import AssistantFunctionCall
-from autogpt.core.resource.model_providers.openai import (
- OPEN_AI_CHAT_MODELS,
- OpenAIModelName,
+from autogpt.core.resource.model_providers import (
+ CHAT_MODELS,
+ AssistantFunctionCall,
+ ModelName,
)
+from autogpt.core.resource.model_providers.openai import OpenAIModelName
from autogpt.models.utils import ModelWithSummary
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
@@ -56,8 +57,8 @@ P = ParamSpec("P")
class BaseAgentConfiguration(SystemConfiguration):
allow_fs_access: bool = UserConfigurable(default=False)
- fast_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
- smart_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT4)
+ fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
+ smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
use_functions_api: bool = UserConfigurable(default=False)
default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
@@ -174,7 +175,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
llm_name = (
self.config.smart_llm if self.config.big_brain else self.config.fast_llm
)
- return OPEN_AI_CHAT_MODELS[llm_name]
+ return CHAT_MODELS[llm_name]
@property
def send_token_limit(self) -> int:
diff --git a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
index 53fadaa7c..ff08f4669 100644
--- a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
+++ b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
@@ -122,7 +122,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
1. System prompt
3. `cycle_instruction`
"""
- system_prompt = self.build_system_prompt(
+ system_prompt, response_prefill = self.build_system_prompt(
ai_profile=ai_profile,
ai_directives=ai_directives,
commands=commands,
@@ -131,24 +131,34 @@ class OneShotAgentPromptStrategy(PromptStrategy):
final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction)
- prompt = ChatPrompt(
+ return ChatPrompt(
messages=[
ChatMessage.system(system_prompt),
ChatMessage.user(f'"""{task}"""'),
*messages,
final_instruction_msg,
],
+ prefill_response=response_prefill,
+ functions=commands if self.config.use_functions_api else [],
)
- return prompt
-
def build_system_prompt(
self,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
include_os_info: bool,
- ) -> str:
+ ) -> tuple[str, str]:
+ """
+ Builds the system prompt.
+
+ Returns:
+ str: The system prompt body
+ str: The desired start for the LLM's response; used to steer the output
+ """
+ response_fmt_instruction, response_prefill = self.response_format_instruction(
+ self.config.use_functions_api
+ )
system_prompt_parts = (
self._generate_intro_prompt(ai_profile)
+ (self._generate_os_info() if include_os_info else [])
@@ -169,16 +179,16 @@ class OneShotAgentPromptStrategy(PromptStrategy):
" in the next message. Your job is to complete the task while following"
" your directives as given above, and terminate when your task is done."
]
- + [
- "## RESPONSE FORMAT\n"
- + self.response_format_instruction(self.config.use_functions_api)
- ]
+ + ["## RESPONSE FORMAT\n" + response_fmt_instruction]
)
# Join non-empty parts together into paragraph format
- return "\n\n".join(filter(None, system_prompt_parts)).strip("\n")
+ return (
+ "\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
+ response_prefill,
+ )
- def response_format_instruction(self, use_functions_api: bool) -> str:
+ def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
response_schema = self.response_schema.copy(deep=True)
if (
use_functions_api
@@ -193,11 +203,15 @@ class OneShotAgentPromptStrategy(PromptStrategy):
"\n",
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
)
+ response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
return (
- f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
- f"{response_format}"
- + ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
+ (
+ f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
+ f"{response_format}"
+ + ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
+ ),
+ response_prefill,
)
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
diff --git a/autogpts/autogpt/autogpt/app/agent_protocol_server.py b/autogpts/autogpt/autogpt/app/agent_protocol_server.py
index cdaf1f460..2eb09706e 100644
--- a/autogpts/autogpt/autogpt/app/agent_protocol_server.py
+++ b/autogpts/autogpt/autogpt/app/agent_protocol_server.py
@@ -34,7 +34,6 @@ from autogpt.agent_manager import AgentManager
from autogpt.app.utils import is_port_free
from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget
-from autogpt.core.resource.model_providers.openai import OpenAIProvider
from autogpt.file_storage import FileStorage
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult
from autogpt.utils.exceptions import AgentFinished
@@ -464,20 +463,18 @@ class AgentProtocolServer:
if task.additional_input and (user_id := task.additional_input.get("user_id")):
_extra_request_headers["AutoGPT-UserID"] = user_id
- task_llm_provider = None
- if isinstance(self.llm_provider, OpenAIProvider):
- settings = self.llm_provider._settings.copy()
- settings.budget = task_llm_budget
- settings.configuration = task_llm_provider_config # type: ignore
- task_llm_provider = OpenAIProvider(
- settings=settings,
- logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
- )
-
- if task_llm_provider and task_llm_provider._budget:
- self._task_budgets[task.task_id] = task_llm_provider._budget
+ settings = self.llm_provider._settings.copy()
+ settings.budget = task_llm_budget
+ settings.configuration = task_llm_provider_config
+ task_llm_provider = self.llm_provider.__class__(
+ settings=settings,
+ logger=logger.getChild(
+ f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
+ ),
+ )
+ self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore
- return task_llm_provider or self.llm_provider
+ return task_llm_provider
def task_agent_id(task_id: str | int) -> str:
diff --git a/autogpts/autogpt/autogpt/app/configurator.py b/autogpts/autogpt/autogpt/app/configurator.py
index 085d5dbf1..2463b6fcf 100644
--- a/autogpts/autogpt/autogpt/app/configurator.py
+++ b/autogpts/autogpt/autogpt/app/configurator.py
@@ -10,7 +10,7 @@ from colorama import Back, Fore, Style
from autogpt.config import Config
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
-from autogpt.core.resource.model_providers.openai import OpenAIModelName, OpenAIProvider
+from autogpt.core.resource.model_providers import ModelName, MultiProvider
from autogpt.logs.helpers import request_user_double_check
from autogpt.memory.vector import get_supported_memory_backends
from autogpt.utils import utils
@@ -150,11 +150,11 @@ async def apply_overrides_to_config(
async def check_model(
- model_name: OpenAIModelName, model_type: Literal["smart_llm", "fast_llm"]
-) -> OpenAIModelName:
+ model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
+) -> ModelName:
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
- openai = OpenAIProvider()
- models = await openai.get_available_models()
+ multi_provider = MultiProvider()
+ models = await multi_provider.get_available_models()
if any(model_name == m.name for m in models):
return model_name
diff --git a/autogpts/autogpt/autogpt/app/main.py b/autogpts/autogpt/autogpt/app/main.py
index aaab5fe48..04354fb10 100644
--- a/autogpts/autogpt/autogpt/app/main.py
+++ b/autogpts/autogpt/autogpt/app/main.py
@@ -35,7 +35,7 @@ from autogpt.config import (
ConfigBuilder,
assert_config_has_openai_api_key,
)
-from autogpt.core.resource.model_providers.openai import OpenAIProvider
+from autogpt.core.resource.model_providers import MultiProvider
from autogpt.core.runner.client_lib.utils import coroutine
from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging
@@ -123,7 +123,7 @@ async def run_auto_gpt(
skip_news=skip_news,
)
- llm_provider = _configure_openai_provider(config)
+ llm_provider = _configure_llm_provider(config)
logger = logging.getLogger(__name__)
@@ -399,7 +399,7 @@ async def run_auto_gpt_server(
allow_downloads=allow_downloads,
)
- llm_provider = _configure_openai_provider(config)
+ llm_provider = _configure_llm_provider(config)
# Set up & start server
database = AgentDB(
@@ -421,24 +421,12 @@ async def run_auto_gpt_server(
)
-def _configure_openai_provider(config: Config) -> OpenAIProvider:
- """Create a configured OpenAIProvider object.
-
- Args:
- config: The program's configuration.
-
- Returns:
- A configured OpenAIProvider object.
- """
- if config.openai_credentials is None:
- raise RuntimeError("OpenAI key is not configured")
-
- openai_settings = OpenAIProvider.default_settings.copy(deep=True)
- openai_settings.credentials = config.openai_credentials
- return OpenAIProvider(
- settings=openai_settings,
- logger=logging.getLogger("OpenAIProvider"),
- )
+def _configure_llm_provider(config: Config) -> MultiProvider:
+ multi_provider = MultiProvider()
+ for model in [config.smart_llm, config.fast_llm]:
+ # Ensure model providers for configured LLMs are available
+ multi_provider.get_model_provider(model)
+ return multi_provider
def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:
diff --git a/autogpts/autogpt/autogpt/commands/system.py b/autogpts/autogpt/autogpt/commands/system.py
index 85ab5c100..ce2640529 100644
--- a/autogpts/autogpt/autogpt/commands/system.py
+++ b/autogpts/autogpt/autogpt/commands/system.py
@@ -31,7 +31,9 @@ class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
def get_messages(self) -> Iterator[ChatMessage]:
# Clock
- yield ChatMessage.system(f"The current time and date is {time.strftime('%c')}")
+ yield ChatMessage.system(
+ f"## Clock\nThe current time and date is {time.strftime('%c')}"
+ )
def get_commands(self) -> Iterator[Command]:
yield self.finish
diff --git a/autogpts/autogpt/autogpt/config/config.py b/autogpts/autogpt/autogpt/config/config.py
index 8dd648188..11d35b673 100644
--- a/autogpts/autogpt/autogpt/config/config.py
+++ b/autogpts/autogpt/autogpt/config/config.py
@@ -17,8 +17,8 @@ from autogpt.core.configuration.schema import (
SystemSettings,
UserConfigurable,
)
+from autogpt.core.resource.model_providers import CHAT_MODELS, ModelName
from autogpt.core.resource.model_providers.openai import (
- OPEN_AI_CHAT_MODELS,
OpenAICredentials,
OpenAIModelName,
)
@@ -74,11 +74,11 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
)
# Model configuration
- fast_llm: OpenAIModelName = UserConfigurable(
+ fast_llm: ModelName = UserConfigurable(
default=OpenAIModelName.GPT3,
from_env="FAST_LLM",
)
- smart_llm: OpenAIModelName = UserConfigurable(
+ smart_llm: ModelName = UserConfigurable(
default=OpenAIModelName.GPT4_TURBO,
from_env="SMART_LLM",
)
@@ -206,8 +206,8 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
if v:
smart_llm = values["smart_llm"]
- assert OPEN_AI_CHAT_MODELS[smart_llm].has_function_call_api, (
- f"Model {smart_llm} does not support OpenAI Functions. "
+ assert CHAT_MODELS[smart_llm].has_function_call_api, (
+ f"Model {smart_llm} does not support tool calling. "
"Please disable OPENAI_FUNCTIONS or choose a suitable model."
)
return v
diff --git a/autogpts/autogpt/autogpt/core/prompting/schema.py b/autogpts/autogpt/autogpt/core/prompting/schema.py
index 45efc40fe..fcc7c6b61 100644
--- a/autogpts/autogpt/autogpt/core/prompting/schema.py
+++ b/autogpts/autogpt/autogpt/core/prompting/schema.py
@@ -24,6 +24,7 @@ class LanguageModelClassification(str, enum.Enum):
class ChatPrompt(BaseModel):
messages: list[ChatMessage]
functions: list[CompletionModelFunction] = Field(default_factory=list)
+ prefill_response: str = ""
def raw(self) -> list[ChatMessageDict]:
return [m.dict() for m in self.messages]
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/__init__.py b/autogpts/autogpt/autogpt/core/resource/model_providers/__init__.py
index b896760d2..7fb98170e 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/__init__.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/__init__.py
@@ -1,3 +1,4 @@
+from .multi import CHAT_MODELS, ModelName, MultiProvider
from .openai import (
OPEN_AI_CHAT_MODELS,
OPEN_AI_EMBEDDING_MODELS,
@@ -42,11 +43,13 @@ __all__ = [
"ChatModelProvider",
"ChatModelResponse",
"CompletionModelFunction",
+ "CHAT_MODELS",
"Embedding",
"EmbeddingModelInfo",
"EmbeddingModelProvider",
"EmbeddingModelResponse",
"ModelInfo",
+ "ModelName",
"ModelProvider",
"ModelProviderBudget",
"ModelProviderCredentials",
@@ -56,6 +59,7 @@ __all__ = [
"ModelProviderUsage",
"ModelResponse",
"ModelTokenizer",
+ "MultiProvider",
"OPEN_AI_MODELS",
"OPEN_AI_CHAT_MODELS",
"OPEN_AI_EMBEDDING_MODELS",
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/anthropic.py b/autogpts/autogpt/autogpt/core/resource/model_providers/anthropic.py
new file mode 100644
index 000000000..3d5967f1c
--- /dev/null
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/anthropic.py
@@ -0,0 +1,495 @@
+from __future__ import annotations
+
+import enum
+import logging
+from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
+
+import sentry_sdk
+import tenacity
+import tiktoken
+from anthropic import APIConnectionError, APIStatusError
+from pydantic import SecretStr
+
+from autogpt.core.configuration import Configurable, UserConfigurable
+from autogpt.core.resource.model_providers.schema import (
+ AssistantChatMessage,
+ AssistantFunctionCall,
+ AssistantToolCall,
+ ChatMessage,
+ ChatModelInfo,
+ ChatModelProvider,
+ ChatModelResponse,
+ CompletionModelFunction,
+ ModelProviderBudget,
+ ModelProviderConfiguration,
+ ModelProviderCredentials,
+ ModelProviderName,
+ ModelProviderSettings,
+ ModelTokenizer,
+ ToolResultMessage,
+)
+
+from .utils import validate_tool_calls
+
+if TYPE_CHECKING:
+ from anthropic.types.beta.tools import MessageCreateParams
+ from anthropic.types.beta.tools import ToolsBetaMessage as Message
+ from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
+class AnthropicModelName(str, enum.Enum):
+ CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
+ CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
+ CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
+
+
+ANTHROPIC_CHAT_MODELS = {
+ info.name: info
+ for info in [
+ ChatModelInfo(
+ name=AnthropicModelName.CLAUDE3_OPUS_v1,
+ provider_name=ModelProviderName.ANTHROPIC,
+ prompt_token_cost=15 / 1e6,
+ completion_token_cost=75 / 1e6,
+ max_tokens=200000,
+ has_function_call_api=True,
+ ),
+ ChatModelInfo(
+ name=AnthropicModelName.CLAUDE3_SONNET_v1,
+ provider_name=ModelProviderName.ANTHROPIC,
+ prompt_token_cost=3 / 1e6,
+ completion_token_cost=15 / 1e6,
+ max_tokens=200000,
+ has_function_call_api=True,
+ ),
+ ChatModelInfo(
+ name=AnthropicModelName.CLAUDE3_HAIKU_v1,
+ provider_name=ModelProviderName.ANTHROPIC,
+ prompt_token_cost=0.25 / 1e6,
+ completion_token_cost=1.25 / 1e6,
+ max_tokens=200000,
+ has_function_call_api=True,
+ ),
+ ]
+}
+
+
+class AnthropicConfiguration(ModelProviderConfiguration):
+ fix_failed_parse_tries: int = UserConfigurable(3)
+
+
+class AnthropicCredentials(ModelProviderCredentials):
+ """Credentials for Anthropic."""
+
+ api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
+ api_base: Optional[SecretStr] = UserConfigurable(
+ default=None, from_env="ANTHROPIC_API_BASE_URL"
+ )
+
+ def get_api_access_kwargs(self) -> dict[str, str]:
+ return {
+ k: (v.get_secret_value() if type(v) is SecretStr else v)
+ for k, v in {
+ "api_key": self.api_key,
+ "base_url": self.api_base,
+ }.items()
+ if v is not None
+ }
+
+
+class AnthropicSettings(ModelProviderSettings):
+ configuration: AnthropicConfiguration
+ credentials: Optional[AnthropicCredentials]
+ budget: ModelProviderBudget
+
+
+class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
+ default_settings = AnthropicSettings(
+ name="anthropic_provider",
+ description="Provides access to Anthropic's API.",
+ configuration=AnthropicConfiguration(
+ retries_per_request=7,
+ ),
+ credentials=None,
+ budget=ModelProviderBudget(),
+ )
+
+ _settings: AnthropicSettings
+ _configuration: AnthropicConfiguration
+ _credentials: AnthropicCredentials
+ _budget: ModelProviderBudget
+
+ def __init__(
+ self,
+ settings: Optional[AnthropicSettings] = None,
+ logger: Optional[logging.Logger] = None,
+ ):
+ if not settings:
+ settings = self.default_settings.copy(deep=True)
+ if not settings.credentials:
+ settings.credentials = AnthropicCredentials.from_env()
+
+ super(AnthropicProvider, self).__init__(settings=settings, logger=logger)
+
+ from anthropic import AsyncAnthropic
+
+ self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
+
+ async def get_available_models(self) -> list[ChatModelInfo]:
+ return list(ANTHROPIC_CHAT_MODELS.values())
+
+ def get_token_limit(self, model_name: str) -> int:
+ """Get the token limit for a given model."""
+ return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
+
+ @classmethod
+ def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
+ # HACK: No official tokenizer is available for Claude 3
+ return tiktoken.encoding_for_model(model_name)
+
+ @classmethod
+ def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
+ return 0 # HACK: No official tokenizer is available for Claude 3
+
+ @classmethod
+ def count_message_tokens(
+ cls,
+ messages: ChatMessage | list[ChatMessage],
+ model_name: AnthropicModelName,
+ ) -> int:
+ return 0 # HACK: No official tokenizer is available for Claude 3
+
+ async def create_chat_completion(
+ self,
+ model_prompt: list[ChatMessage],
+ model_name: AnthropicModelName,
+ completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
+ functions: Optional[list[CompletionModelFunction]] = None,
+ max_output_tokens: Optional[int] = None,
+ prefill_response: str = "",
+ **kwargs,
+ ) -> ChatModelResponse[_T]:
+ """Create a completion using the Anthropic API."""
+ anthropic_messages, completion_kwargs = self._get_chat_completion_args(
+ prompt_messages=model_prompt,
+ model=model_name,
+ functions=functions,
+ max_output_tokens=max_output_tokens,
+ **kwargs,
+ )
+
+ total_cost = 0.0
+ attempts = 0
+ while True:
+ completion_kwargs["messages"] = anthropic_messages.copy()
+ if prefill_response:
+ completion_kwargs["messages"].append(
+ {"role": "assistant", "content": prefill_response}
+ )
+
+ (
+ _assistant_msg,
+ cost,
+ t_input,
+ t_output,
+ ) = await self._create_chat_completion(completion_kwargs)
+ total_cost += cost
+ self._logger.debug(
+ f"Completion usage: {t_input} input, {t_output} output "
+ f"- ${round(cost, 5)}"
+ )
+
+ # Merge prefill into generated response
+ if prefill_response:
+ first_text_block = next(
+ b for b in _assistant_msg.content if b.type == "text"
+ )
+ first_text_block.text = prefill_response + first_text_block.text
+
+ assistant_msg = AssistantChatMessage(
+ content="\n\n".join(
+ b.text for b in _assistant_msg.content if b.type == "text"
+ ),
+ tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
+ )
+
+ # If parsing the response fails, append the error to the prompt, and let the
+ # LLM fix its mistake(s).
+ attempts += 1
+ tool_call_errors = []
+ try:
+ # Validate tool calls
+ if assistant_msg.tool_calls and functions:
+ tool_call_errors = validate_tool_calls(
+ assistant_msg.tool_calls, functions
+ )
+ if tool_call_errors:
+ raise ValueError(
+ "Invalid tool use(s):\n"
+ + "\n".join(str(e) for e in tool_call_errors)
+ )
+
+ parsed_result = completion_parser(assistant_msg)
+ break
+ except Exception as e:
+ self._logger.debug(
+ f"Parsing failed on response: '''{_assistant_msg}'''"
+ )
+ self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
+ sentry_sdk.capture_exception(
+ error=e,
+ extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
+ )
+ if attempts < self._configuration.fix_failed_parse_tries:
+ anthropic_messages.append(
+ _assistant_msg.dict(include={"role", "content"})
+ )
+ anthropic_messages.append(
+ {
+ "role": "user",
+ "content": [
+ *(
+ # tool_result is required if last assistant message
+ # had tool_use block(s)
+ {
+ "type": "tool_result",
+ "tool_use_id": tc.id,
+ "is_error": True,
+ "content": [
+ {
+ "type": "text",
+ "text": "Not executed because parsing "
+ "of your last message failed"
+ if not tool_call_errors
+ else str(e)
+ if (
+ e := next(
+ (
+ tce
+ for tce in tool_call_errors
+ if tce.name
+ == tc.function.name
+ ),
+ None,
+ )
+ )
+ else "Not executed because validation "
+ "of tool input failed",
+ }
+ ],
+ }
+ for tc in assistant_msg.tool_calls or []
+ ),
+ {
+ "type": "text",
+ "text": (
+ "ERROR PARSING YOUR RESPONSE:\n\n"
+ f"{e.__class__.__name__}: {e}"
+ ),
+ },
+ ],
+ }
+ )
+ else:
+ raise
+
+ if attempts > 1:
+ self._logger.debug(
+ f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
+ )
+
+ return ChatModelResponse(
+ response=assistant_msg,
+ parsed_result=parsed_result,
+ model_info=ANTHROPIC_CHAT_MODELS[model_name],
+ prompt_tokens_used=t_input,
+ completion_tokens_used=t_output,
+ )
+
+ def _get_chat_completion_args(
+ self,
+ prompt_messages: list[ChatMessage],
+ model: AnthropicModelName,
+ functions: Optional[list[CompletionModelFunction]] = None,
+ max_output_tokens: Optional[int] = None,
+ **kwargs,
+ ) -> tuple[list[MessageParam], MessageCreateParams]:
+ """Prepare arguments for message completion API call.
+
+ Args:
+ prompt_messages: List of ChatMessages.
+ model: The model to use.
+ functions: Optional list of functions available to the LLM.
+ kwargs: Additional keyword arguments.
+
+ Returns:
+ list[MessageParam]: Prompt messages for the Anthropic call
+ dict[str, Any]: Any other kwargs for the Anthropic call
+ """
+ kwargs["model"] = model
+
+ if functions:
+ kwargs["tools"] = [
+ {
+ "name": f.name,
+ "description": f.description,
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ name: param.to_dict()
+ for name, param in f.parameters.items()
+ },
+ "required": [
+ name
+ for name, param in f.parameters.items()
+ if param.required
+ ],
+ },
+ }
+ for f in functions
+ ]
+
+ kwargs["max_tokens"] = max_output_tokens or 4096
+
+ if extra_headers := self._configuration.extra_request_headers:
+ kwargs["extra_headers"] = kwargs.get("extra_headers", {})
+ kwargs["extra_headers"].update(extra_headers.copy())
+
+ system_messages = [
+ m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
+ ]
+ if (_n := len(system_messages)) > 1:
+ self._logger.warning(
+ f"Prompt has {_n} system messages; Anthropic supports only 1. "
+ "They will be merged, and removed from the rest of the prompt."
+ )
+ kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)
+
+ messages: list[MessageParam] = []
+ for message in prompt_messages:
+ if message.role == ChatMessage.Role.SYSTEM:
+ continue
+ elif message.role == ChatMessage.Role.USER:
+ # Merge subsequent user messages
+ if messages and (prev_msg := messages[-1])["role"] == "user":
+ if isinstance(prev_msg["content"], str):
+ prev_msg["content"] += f"\n\n{message.content}"
+ else:
+ assert isinstance(prev_msg["content"], list)
+ prev_msg["content"].append(
+ {"type": "text", "text": message.content}
+ )
+ else:
+ messages.append({"role": "user", "content": message.content})
+ # TODO: add support for image blocks
+ elif message.role == ChatMessage.Role.ASSISTANT:
+ if isinstance(message, AssistantChatMessage) and message.tool_calls:
+ messages.append(
+ {
+ "role": "assistant",
+ "content": [
+ *(
+ [{"type": "text", "text": message.content}]
+ if message.content
+ else []
+ ),
+ *(
+ {
+ "type": "tool_use",
+ "id": tc.id,
+ "name": tc.function.name,
+ "input": tc.function.arguments,
+ }
+ for tc in message.tool_calls
+ ),
+ ],
+ }
+ )
+ elif message.content:
+ messages.append(
+ {
+ "role": "assistant",
+ "content": message.content,
+ }
+ )
+ elif isinstance(message, ToolResultMessage):
+ messages.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": message.tool_call_id,
+ "content": [{"type": "text", "text": message.content}],
+ "is_error": message.is_error,
+ }
+ ],
+ }
+ )
+
+ return messages, kwargs # type: ignore
+
+ async def _create_chat_completion(
+ self, completion_kwargs: MessageCreateParams
+ ) -> tuple[Message, float, int, int]:
+ """
+ Create a chat completion using the Anthropic API with retry handling.
+
+ Params:
+ completion_kwargs: Keyword arguments for an Anthropic Messages API call
+
+ Returns:
+ Message: The message completion object
+ float: The cost ($) of this completion
+ int: Number of input tokens used
+ int: Number of output tokens used
+ """
+
+ @self._retry_api_request
+ async def _create_chat_completion_with_retry(
+ completion_kwargs: MessageCreateParams,
+ ) -> Message:
+ return await self._client.beta.tools.messages.create(
+ **completion_kwargs # type: ignore
+ )
+
+ response = await _create_chat_completion_with_retry(completion_kwargs)
+
+ cost = self._budget.update_usage_and_cost(
+ model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
+ input_tokens_used=response.usage.input_tokens,
+ output_tokens_used=response.usage.output_tokens,
+ )
+ return response, cost, response.usage.input_tokens, response.usage.output_tokens
+
+ def _parse_assistant_tool_calls(
+ self, assistant_message: Message
+ ) -> list[AssistantToolCall]:
+ return [
+ AssistantToolCall(
+ id=c.id,
+ type="function",
+ function=AssistantFunctionCall(name=c.name, arguments=c.input),
+ )
+ for c in assistant_message.content
+ if c.type == "tool_use"
+ ]
+
+ def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
+ return tenacity.retry(
+ retry=(
+ tenacity.retry_if_exception_type(APIConnectionError)
+ | tenacity.retry_if_exception(
+ lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
+ )
+ ),
+ wait=tenacity.wait_exponential(),
+ stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
+ after=tenacity.after_log(self._logger, logging.DEBUG),
+ )(func)
+
+ def __repr__(self):
+ return "AnthropicProvider()"
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/multi.py b/autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
new file mode 100644
index 000000000..f194e0256
--- /dev/null
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+import logging
+from typing import Callable, Iterator, Optional, TypeVar
+
+from pydantic import ValidationError
+
+from autogpt.core.configuration import Configurable
+
+from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
+from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
+from .schema import (
+ AssistantChatMessage,
+ ChatMessage,
+ ChatModelInfo,
+ ChatModelProvider,
+ ChatModelResponse,
+ CompletionModelFunction,
+ ModelProviderBudget,
+ ModelProviderConfiguration,
+ ModelProviderName,
+ ModelProviderSettings,
+ ModelTokenizer,
+)
+
+_T = TypeVar("_T")
+
+ModelName = AnthropicModelName | OpenAIModelName
+
+CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
+
+
+class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
+ default_settings = ModelProviderSettings(
+ name="multi_provider",
+ description=(
+ "Provides access to all of the available models, regardless of provider."
+ ),
+ configuration=ModelProviderConfiguration(
+ retries_per_request=7,
+ ),
+ budget=ModelProviderBudget(),
+ )
+
+ _budget: ModelProviderBudget
+
+ _provider_instances: dict[ModelProviderName, ChatModelProvider]
+
+ def __init__(
+ self,
+ settings: Optional[ModelProviderSettings] = None,
+ logger: Optional[logging.Logger] = None,
+ ):
+ super(MultiProvider, self).__init__(settings=settings, logger=logger)
+ self._budget = self._settings.budget or ModelProviderBudget()
+
+ self._provider_instances = {}
+
+ async def get_available_models(self) -> list[ChatModelInfo]:
+ models = []
+ for provider in self.get_available_providers():
+ models.extend(await provider.get_available_models())
+ return models
+
+ def get_token_limit(self, model_name: ModelName) -> int:
+ """Get the token limit for a given model."""
+ return self.get_model_provider(model_name).get_token_limit(model_name)
+
+ @classmethod
+ def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
+ return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
+
+ @classmethod
+ def count_tokens(cls, text: str, model_name: ModelName) -> int:
+ return cls._get_model_provider_class(model_name).count_tokens(
+ text=text, model_name=model_name
+ )
+
+ @classmethod
+ def count_message_tokens(
+ cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
+ ) -> int:
+ return cls._get_model_provider_class(model_name).count_message_tokens(
+ messages=messages, model_name=model_name
+ )
+
+ async def create_chat_completion(
+ self,
+ model_prompt: list[ChatMessage],
+ model_name: ModelName,
+ completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
+ functions: Optional[list[CompletionModelFunction]] = None,
+ max_output_tokens: Optional[int] = None,
+ prefill_response: str = "",
+ **kwargs,
+ ) -> ChatModelResponse[_T]:
+ """Create a completion using the Anthropic API."""
+ return await self.get_model_provider(model_name).create_chat_completion(
+ model_prompt=model_prompt,
+ model_name=model_name,
+ completion_parser=completion_parser,
+ functions=functions,
+ max_output_tokens=max_output_tokens,
+ prefill_response=prefill_response,
+ **kwargs,
+ )
+
+ def get_model_provider(self, model: ModelName) -> ChatModelProvider:
+ model_info = CHAT_MODELS[model]
+ return self._get_provider(model_info.provider_name)
+
+ def get_available_providers(self) -> Iterator[ChatModelProvider]:
+ for provider_name in ModelProviderName:
+ try:
+ yield self._get_provider(provider_name)
+ except Exception:
+ pass
+
+ def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
+ _provider = self._provider_instances.get(provider_name)
+ if not _provider:
+ Provider = self._get_provider_class(provider_name)
+ settings = Provider.default_settings.copy(deep=True)
+ settings.budget = self._budget
+ settings.configuration.extra_request_headers.update(
+ self._settings.configuration.extra_request_headers
+ )
+ if settings.credentials is None:
+ try:
+ Credentials = settings.__fields__["credentials"].type_
+ settings.credentials = Credentials.from_env()
+ except ValidationError as e:
+ raise ValueError(
+ f"{provider_name} is unavailable: can't load credentials"
+ ) from e
+
+ self._provider_instances[provider_name] = _provider = Provider(
+ settings=settings, logger=self._logger
+ )
+ _provider._budget = self._budget # Object binding not preserved by Pydantic
+ return _provider
+
+ @classmethod
+ def _get_model_provider_class(
+ cls, model_name: ModelName
+ ) -> type[AnthropicProvider | OpenAIProvider]:
+ return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
+
+ @classmethod
+ def _get_provider_class(
+ cls, provider_name: ModelProviderName
+ ) -> type[AnthropicProvider | OpenAIProvider]:
+ try:
+ return {
+ ModelProviderName.ANTHROPIC: AnthropicProvider,
+ ModelProviderName.OPENAI: OpenAIProvider,
+ }[provider_name]
+ except KeyError:
+ raise ValueError(f"{provider_name} is not a known provider") from None
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}()"
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 0e50f1fa3..c3c5aabb4 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -42,6 +42,8 @@ from autogpt.core.resource.model_providers.schema import (
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
+from .utils import validate_tool_calls
+
_T = TypeVar("_T")
_P = ParamSpec("_P")
@@ -298,6 +300,7 @@ class OpenAIProvider(
budget=ModelProviderBudget(),
)
+ _settings: OpenAISettings
_configuration: OpenAIConfiguration
_credentials: OpenAICredentials
_budget: ModelProviderBudget
@@ -312,11 +315,7 @@ class OpenAIProvider(
if not settings.credentials:
settings.credentials = OpenAICredentials.from_env()
- self._settings = settings
-
- self._configuration = settings.configuration
- self._credentials = settings.credentials
- self._budget = settings.budget
+ super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
if self._credentials.api_type == "azure":
from openai import AsyncAzureOpenAI
@@ -329,8 +328,6 @@ class OpenAIProvider(
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
- self._logger = logger or logging.getLogger(__name__)
-
async def get_available_models(self) -> list[ChatModelInfo]:
_models = (await self._client.models.list()).data
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
@@ -398,6 +395,7 @@ class OpenAIProvider(
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
+ prefill_response: str = "", # not supported by OpenAI
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API and parse it."""
@@ -432,6 +430,10 @@ class OpenAIProvider(
)
parse_errors += _errors
+ # Validate tool calls
+ if not parse_errors and tool_calls and functions:
+ parse_errors += validate_tool_calls(tool_calls, functions)
+
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=tool_calls or None,
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index 60df855f2..bb2e29490 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -1,8 +1,10 @@
import abc
import enum
+import logging
import math
from collections import defaultdict
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
ClassVar,
@@ -28,6 +30,9 @@ from autogpt.core.resource.schema import (
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.logs.utils import fmt_kwargs
+if TYPE_CHECKING:
+ from jsonschema import ValidationError
+
class ModelProviderService(str, enum.Enum):
"""A ModelService describes what kind of service the model provides."""
@@ -39,6 +44,7 @@ class ModelProviderService(str, enum.Enum):
class ModelProviderName(str, enum.Enum):
OPENAI = "openai"
+ ANTHROPIC = "anthropic"
class ChatMessage(BaseModel):
@@ -100,6 +106,12 @@ class AssistantChatMessage(ChatMessage):
tool_calls: Optional[list[AssistantToolCall]] = None
+class ToolResultMessage(ChatMessage):
+ role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
+ is_error: bool = False
+ tool_call_id: str
+
+
class AssistantChatMessageDict(TypedDict, total=False):
role: str
content: str
@@ -146,6 +158,30 @@ class CompletionModelFunction(BaseModel):
)
return f"{self.name}: {self.description}. Params: ({params})"
+ def validate_call(
+ self, function_call: AssistantFunctionCall
+ ) -> tuple[bool, list["ValidationError"]]:
+ """
+ Validates the given function call against the function's parameter specs
+
+ Returns:
+ bool: Whether the given set of arguments is valid for this command
+ list[ValidationError]: Issues with the set of arguments (if any)
+
+ Raises:
+ ValueError: If the function_call doesn't call this function
+ """
+ if function_call.name != self.name:
+ raise ValueError(
+ f"Can't validate {function_call.name} call using {self.name} spec"
+ )
+
+ params_schema = JSONSchema(
+ type=JSONSchema.Type.OBJECT,
+ properties={name: spec for name, spec in self.parameters.items()},
+ )
+ return params_schema.validate_object(function_call.arguments)
+
class ModelInfo(BaseModel):
"""Struct for model information.
@@ -229,7 +265,7 @@ class ModelProviderBudget(ProviderBudget):
class ModelProviderSettings(ProviderSettings):
resource_type: ResourceType = ResourceType.MODEL
configuration: ModelProviderConfiguration
- credentials: ModelProviderCredentials
+ credentials: Optional[ModelProviderCredentials] = None
budget: Optional[ModelProviderBudget] = None
@@ -238,9 +274,28 @@ class ModelProvider(abc.ABC):
default_settings: ClassVar[ModelProviderSettings]
+ _settings: ModelProviderSettings
_configuration: ModelProviderConfiguration
+ _credentials: Optional[ModelProviderCredentials] = None
_budget: Optional[ModelProviderBudget] = None
+ _logger: logging.Logger
+
+ def __init__(
+ self,
+ settings: Optional[ModelProviderSettings] = None,
+ logger: Optional[logging.Logger] = None,
+ ):
+ if not settings:
+ settings = self.default_settings.copy(deep=True)
+
+ self._settings = settings
+ self._configuration = settings.configuration
+ self._credentials = settings.credentials
+ self._budget = settings.budget
+
+ self._logger = logger or logging.getLogger(self.__module__)
+
@abc.abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
...
@@ -358,6 +413,7 @@ class ChatModelProvider(ModelProvider):
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
+ prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
...
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/utils.py b/autogpts/autogpt/autogpt/core/resource/model_providers/utils.py
new file mode 100644
index 000000000..5b83b047b
--- /dev/null
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/utils.py
@@ -0,0 +1,71 @@
+from typing import Any
+
+from .schema import AssistantToolCall, CompletionModelFunction
+
+
+class InvalidFunctionCallError(Exception):
+ def __init__(self, name: str, arguments: dict[str, Any], message: str):
+ self.message = message
+ self.name = name
+ self.arguments = arguments
+ super().__init__(message)
+
+ def __str__(self) -> str:
+ return f"Invalid function call for {self.name}: {self.message}"
+
+
+def validate_tool_calls(
+ tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
+) -> list[InvalidFunctionCallError]:
+ """
+ Validates a list of tool calls against a list of functions.
+
+ 1. Tries to find a function matching each tool call
+ 2. If a matching function is found, validates the tool call's arguments,
+ reporting any resulting errors
+ 3. If no matching function is found, an error "Unknown function X" is reported
+ 4. A list of all errors encountered during validation is returned
+
+ Params:
+ tool_calls: A list of tool calls to validate.
+ functions: A list of functions to validate against.
+
+ Returns:
+ list[InvalidFunctionCallError]: All errors encountered during validation.
+ """
+ errors: list[InvalidFunctionCallError] = []
+ for tool_call in tool_calls:
+ function_call = tool_call.function
+
+ if function := next(
+ (f for f in functions if f.name == function_call.name),
+ None,
+ ):
+ is_valid, validation_errors = function.validate_call(function_call)
+ if not is_valid:
+ fmt_errors = [
+ f"{'.'.join(str(p) for p in f.path)}: {f.message}"
+ if f.path
+ else f.message
+ for f in validation_errors
+ ]
+ errors.append(
+ InvalidFunctionCallError(
+ name=function_call.name,
+ arguments=function_call.arguments,
+ message=(
+ "The set of arguments supplied is invalid:\n"
+ + "\n".join(fmt_errors)
+ ),
+ )
+ )
+ else:
+ errors.append(
+ InvalidFunctionCallError(
+ name=function_call.name,
+ arguments=function_call.arguments,
+ message=f"Unknown function {function_call.name}",
+ )
+ )
+
+ return errors
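A short usage sketch (the function spec and tool call below are made up for illustration) of `validate_tool_calls`, which both providers now call before parsing a completion:

# Illustrative: validating a tool call against a function spec.
from autogpt.core.resource.model_providers.schema import (
    AssistantFunctionCall,
    AssistantToolCall,
    CompletionModelFunction,
)
from autogpt.core.resource.model_providers.utils import validate_tool_calls
from autogpt.core.utils.json_schema import JSONSchema

web_search = CompletionModelFunction(
    name="web_search",
    description="Search the web for information.",
    parameters={
        "query": JSONSchema(type=JSONSchema.Type.STRING, required=True),
    },
)
bad_call = AssistantToolCall(
    id="toolu_01",
    type="function",
    function=AssistantFunctionCall(name="web_search", arguments={"q": "AutoGPT"}),
)
for error in validate_tool_calls([bad_call], [web_search]):
    # e.g. "Invalid function call for web_search: The set of arguments supplied is invalid: ..."
    print(error)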
diff --git a/autogpts/autogpt/autogpt/llm/providers/openai.py b/autogpts/autogpt/autogpt/llm/providers/openai.py
index 18a9d2b07..e6423827c 100644
--- a/autogpts/autogpt/autogpt/llm/providers/openai.py
+++ b/autogpts/autogpt/autogpt/llm/providers/openai.py
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Callable)
-def get_openai_command_specs(
+def function_specs_from_commands(
commands: Iterable[Command],
) -> list[CompletionModelFunction]:
"""Get OpenAI-consumable function specs for the agent's available commands.
diff --git a/autogpts/autogpt/autogpt/models/command.py b/autogpts/autogpt/autogpt/models/command.py
index c629e5126..29bed5864 100644
--- a/autogpts/autogpt/autogpt/models/command.py
+++ b/autogpts/autogpt/autogpt/models/command.py
@@ -3,8 +3,6 @@ from __future__ import annotations
import inspect
from typing import Any, Callable
-from autogpt.core.utils.json_schema import JSONSchema
-
from .command_parameter import CommandParameter
from .context_item import ContextItem
@@ -42,20 +40,6 @@ class Command:
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self.method)
- def validate_args(self, args: dict[str, Any]):
- """
- Validates the given arguments against the command's parameter specifications
-
- Returns:
- bool: Whether the given set of arguments is valid for this command
- list[ValidationError]: Issues with the set of arguments (if any)
- """
- params_schema = JSONSchema(
- type=JSONSchema.Type.OBJECT,
- properties={p.name: p.spec for p in self.parameters},
- )
- return params_schema.validate_object(args)
-
def _parameters_match(
self, func: Callable, parameters: list[CommandParameter]
) -> bool:
diff --git a/autogpts/autogpt/poetry.lock b/autogpts/autogpt/poetry.lock
index 77b64e79c..251cbb3bd 100644
--- a/autogpts/autogpt/poetry.lock
+++ b/autogpts/autogpt/poetry.lock
@@ -168,6 +168,30 @@ files = [
frozenlist = ">=1.1.0"
[[package]]
+name = "anthropic"
+version = "0.25.1"
+description = "The official Python library for the anthropic API"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "anthropic-0.25.1-py3-none-any.whl", hash = "sha256:95d0cedc2a4b5beae3a78f9030aea4001caea5f46c6d263cce377c891c594e71"},
+ {file = "anthropic-0.25.1.tar.gz", hash = "sha256:0c01b30b77d041a8d07c532737bae69da58086031217150008e4541f52a64bd9"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+tokenizers = ">=0.13.0"
+typing-extensions = ">=4.7,<5"
+
+[package.extras]
+bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
+vertex = ["google-auth (>=2,<3)"]
+
+[[package]]
name = "anyio"
version = "4.2.0"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
@@ -7234,4 +7258,4 @@ benchmark = ["agbenchmark"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "e6eab5c079d53f075ce701e86a2007e7ebeb635ac067d25f555bfea363bcc630"
+content-hash = "ad1e3c4706465733d04ddab975af630975bd528efce152c1da01eded53069eca"
diff --git a/autogpts/autogpt/pyproject.toml b/autogpts/autogpt/pyproject.toml
index e1b0c32f6..99f58774b 100644
--- a/autogpts/autogpt/pyproject.toml
+++ b/autogpts/autogpt/pyproject.toml
@@ -22,6 +22,7 @@ serve = "autogpt.app.cli:serve"
[tool.poetry.dependencies]
python = "^3.10"
+anthropic = "^0.25.1"
# autogpt-forge = { path = "../forge" }
autogpt-forge = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "autogpts/forge"}
beautifulsoup4 = "^4.12.2"
diff --git a/autogpts/autogpt/tests/conftest.py b/autogpts/autogpt/tests/conftest.py
index 29479f6f0..64376446d 100644
--- a/autogpts/autogpt/tests/conftest.py
+++ b/autogpts/autogpt/tests/conftest.py
@@ -8,9 +8,9 @@ import pytest
from pytest_mock import MockerFixture
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
-from autogpt.app.main import _configure_openai_provider
+from autogpt.app.main import _configure_llm_provider
from autogpt.config import AIProfile, Config, ConfigBuilder
-from autogpt.core.resource.model_providers import ChatModelProvider, OpenAIProvider
+from autogpt.core.resource.model_providers import ChatModelProvider
from autogpt.file_storage.local import (
FileStorage,
FileStorageConfiguration,
@@ -73,8 +73,8 @@ def setup_logger(config: Config):
@pytest.fixture
-def llm_provider(config: Config) -> OpenAIProvider:
- return _configure_openai_provider(config)
+def llm_provider(config: Config) -> ChatModelProvider:
+ return _configure_llm_provider(config)
@pytest.fixture
diff --git a/autogpts/autogpt/tests/unit/test_config.py b/autogpts/autogpt/tests/unit/test_config.py
index f52efcd8c..d6120dec6 100644
--- a/autogpts/autogpt/tests/unit/test_config.py
+++ b/autogpts/autogpt/tests/unit/test_config.py
@@ -14,7 +14,6 @@ from pydantic import SecretStr
from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
from autogpt.config import Config, ConfigBuilder
-from autogpt.core.resource.model_providers.openai import OpenAIModelName
from autogpt.core.resource.model_providers.schema import (
ChatModelInfo,
ModelProviderName,
@@ -39,8 +38,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
"""
Test if models update to gpt-3.5-turbo if gpt-4 is not available.
"""
- config.fast_llm = OpenAIModelName.GPT4_TURBO
- config.smart_llm = OpenAIModelName.GPT4_TURBO
+ config.fast_llm = GPT_4_MODEL
+ config.smart_llm = GPT_4_MODEL
mock_list_models.return_value = asyncio.Future()
mock_list_models.return_value.set_result(
@@ -56,8 +55,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
gpt4only=False,
)
- assert config.fast_llm == "gpt-3.5-turbo"
- assert config.smart_llm == "gpt-3.5-turbo"
+ assert config.fast_llm == GPT_3_MODEL
+ assert config.smart_llm == GPT_3_MODEL
def test_missing_azure_config(config: Config) -> None:
@@ -148,8 +147,7 @@ def test_azure_config(config_with_azure: Config) -> None:
@pytest.mark.asyncio
async def test_create_config_gpt4only(config: Config) -> None:
with mock.patch(
- "autogpt.core.resource.model_providers.openai."
- "OpenAIProvider.get_available_models"
+ "autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
) as mock_get_models:
mock_get_models.return_value = [
ChatModelInfo(
@@ -169,8 +167,7 @@ async def test_create_config_gpt4only(config: Config) -> None:
@pytest.mark.asyncio
async def test_create_config_gpt3only(config: Config) -> None:
with mock.patch(
- "autogpt.core.resource.model_providers.openai."
- "OpenAIProvider.get_available_models"
+ "autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
) as mock_get_models:
mock_get_models.return_value = [
ChatModelInfo(
diff --git a/docs/content/AutoGPT/configuration/options.md b/docs/content/AutoGPT/configuration/options.md
index 9003c7378..17602102b 100644
--- a/docs/content/AutoGPT/configuration/options.md
+++ b/docs/content/AutoGPT/configuration/options.md
@@ -7,6 +7,7 @@ Configuration is controlled through the `Config` object. You can set configurati
- `AI_SETTINGS_FILE`: Location of the AI Settings file relative to the AutoGPT root directory. Default: ai_settings.yaml
- `AUDIO_TO_TEXT_PROVIDER`: Audio To Text Provider. Only option currently is `huggingface`. Default: huggingface
- `AUTHORISE_COMMAND_KEY`: Key response accepted when authorising commands. Default: y
+- `ANTHROPIC_API_KEY`: Set this if you want to use Anthropic models with AutoGPT
- `AZURE_CONFIG_FILE`: Location of the Azure Config file relative to the AutoGPT root directory. Default: azure.yaml
- `BROWSE_CHUNK_MAX_LENGTH`: When browsing website, define the length of chunks to summarize. Default: 3000
- `BROWSE_SPACY_LANGUAGE_MODEL`: [spaCy language model](https://spacy.io/usage/models) to use when creating chunks. Default: en_core_web_sm