Diffstat (limited to 'autogpts/autogpt/autogpt/agents/agent.py')
-rw-r--r--  autogpts/autogpt/autogpt/agents/agent.py | 135
1 file changed, 68 insertions(+), 67 deletions(-)
diff --git a/autogpts/autogpt/autogpt/agents/agent.py b/autogpts/autogpt/autogpt/agents/agent.py
index 0ffd2db2a..3572cbed0 100644
--- a/autogpts/autogpt/autogpt/agents/agent.py
+++ b/autogpts/autogpt/autogpt/agents/agent.py
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional
import sentry_sdk
from pydantic import Field
-from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy
from autogpt.commands.execute_code import CodeExecutorComponent
from autogpt.commands.git_operations import GitOperationsComponent
from autogpt.commands.image_gen import ImageGeneratorComponent
@@ -19,9 +18,11 @@ from autogpt.commands.web_selenium import WebSeleniumComponent
from autogpt.components.event_history import EventHistoryComponent
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
-from autogpt.core.resource.model_providers import ChatMessage, ChatModelProvider
-from autogpt.core.resource.model_providers.schema import (
+from autogpt.core.resource.model_providers import (
AssistantChatMessage,
+ AssistantFunctionCall,
+ ChatMessage,
+ ChatModelProvider,
ChatModelResponse,
)
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
@@ -33,12 +34,12 @@ from autogpt.logs.log_cycle import (
USER_INPUT_FILE_NAME,
LogCycleHandler,
)
-from autogpt.logs.utils import fmt_kwargs
from autogpt.models.action_history import (
ActionErrorResult,
ActionInterruptedByHuman,
ActionResult,
ActionSuccessResult,
+ EpisodicActionHistory,
)
from autogpt.models.command import Command, CommandOutput
from autogpt.utils.exceptions import (
@@ -49,15 +50,14 @@ from autogpt.utils.exceptions import (
UnknownCommandError,
)
-from .base import (
- BaseAgent,
- BaseAgentConfiguration,
- BaseAgentSettings,
- ThoughtProcessOutput,
-)
+from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .features.agent_file_manager import FileManagerComponent
from .features.context import ContextComponent
from .features.watchdog import WatchdogComponent
+from .prompt_strategies.one_shot import (
+ OneShotAgentActionProposal,
+ OneShotAgentPromptStrategy,
+)
from .protocols import (
AfterExecute,
AfterParse,
@@ -79,6 +79,11 @@ class AgentConfiguration(BaseAgentConfiguration):
class AgentSettings(BaseAgentSettings):
config: AgentConfiguration = Field(default_factory=AgentConfiguration)
+ history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
+ default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
+ )
+ """(STATE) The action history of the agent."""
+
class Agent(BaseAgent, Configurable[AgentSettings]):
default_settings: AgentSettings = AgentSettings(
@@ -137,7 +142,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
self.event_history = settings.history
self.legacy_config = legacy_config
- async def propose_action(self) -> ThoughtProcessOutput:
+ async def propose_action(self) -> OneShotAgentActionProposal:
"""Proposes the next action to execute, based on the task and current state.
Returns:
@@ -188,12 +193,12 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
async def complete_and_parse(
self, prompt: ChatPrompt, exception: Optional[Exception] = None
- ) -> ThoughtProcessOutput:
+ ) -> OneShotAgentActionProposal:
if exception:
prompt.messages.append(ChatMessage.system(f"Error: {exception}"))
response: ChatModelResponse[
- ThoughtProcessOutput
+ OneShotAgentActionProposal
] = await self.llm_provider.create_chat_completion(
prompt.messages,
model_name=self.llm.name,
@@ -210,7 +215,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
self.state.ai_profile.ai_name,
self.created_at,
self.config.cycle_count,
- result.thoughts,
+ result.thoughts.dict(),
NEXT_ACTION_FILE_NAME,
)
@@ -220,13 +225,13 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
def parse_and_validate_response(
self, llm_response: AssistantChatMessage
- ) -> ThoughtProcessOutput:
+ ) -> OneShotAgentActionProposal:
parsed_response = self.prompt_strategy.parse_response_content(llm_response)
# Validate command arguments
- command_name = parsed_response.command_name
+ command_name = parsed_response.use_tool.name
command = self._get_command(command_name)
- if arg_errors := command.validate_args(parsed_response.command_args)[1]:
+ if arg_errors := command.validate_args(parsed_response.use_tool.arguments)[1]:
fmt_errors = [
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
if f.path
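Validation now reads the command off a structured function call (`use_tool.name`, `use_tool.arguments`) instead of the old loose `command_name`/`command_args` pair. A rough sketch of that shape; the class names and example values here are invented:

# Rough sketch of the structured tool-call shape; everything beyond the
# use_tool.name / use_tool.arguments access pattern is an assumption.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FunctionCallSketch:
    name: str
    arguments: dict[str, Any] = field(default_factory=dict)


@dataclass
class ProposalSketch:
    use_tool: FunctionCallSketch


proposal = ProposalSketch(FunctionCallSketch("read_file", {"path": "notes.txt"}))
# Validation reads straight off the tool call:
command_name = proposal.use_tool.name       # "read_file"
command_args = proposal.use_tool.arguments  # {"path": "notes.txt"}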
@@ -242,49 +247,50 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
async def execute(
self,
- command_name: str,
- command_args: dict[str, str] = {},
- user_input: str = "",
+ proposal: OneShotAgentActionProposal,
+ user_feedback: str = "",
) -> ActionResult:
- result: ActionResult
-
- if command_name == "human_feedback":
- result = ActionInterruptedByHuman(feedback=user_input)
- self.log_cycle_handler.log_cycle(
- self.state.ai_profile.ai_name,
- self.created_at,
- self.config.cycle_count,
- user_input,
- USER_INPUT_FILE_NAME,
+ tool = proposal.use_tool
+
+ # Get commands
+ self.commands = await self.run_pipeline(CommandProvider.get_commands)
+ self._remove_disabled_commands()
+
+ try:
+ return_value = await self._execute_tool(tool)
+
+ result = ActionSuccessResult(outputs=return_value)
+ except AgentTerminated:
+ raise
+ except AgentException as e:
+ result = ActionErrorResult.from_exception(e)
+ logger.warning(f"{tool} raised an error: {e}")
+ sentry_sdk.capture_exception(e)
+
+ result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
+ if result_tlength > self.send_token_limit // 3:
+ result = ActionErrorResult(
+ reason=f"Command {tool.name} returned too much output. "
+ "Do not execute this command again with the same arguments."
)
- else:
- # Get commands
- self.commands = await self.run_pipeline(CommandProvider.get_commands)
- self._remove_disabled_commands()
-
- try:
- return_value = await self._execute_command(
- command_name=command_name,
- arguments=command_args,
- )
-
- result = ActionSuccessResult(outputs=return_value)
- except AgentTerminated:
- raise
- except AgentException as e:
- result = ActionErrorResult.from_exception(e)
- logger.warning(
- f"{command_name}({fmt_kwargs(command_args)}) raised an error: {e}"
- )
- sentry_sdk.capture_exception(e)
-
- result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
- if result_tlength > self.send_token_limit // 3:
- result = ActionErrorResult(
- reason=f"Command {command_name} returned too much output. "
- "Do not execute this command again with the same arguments."
- )
+ await self.run_pipeline(AfterExecute.after_execute, result)
+
+ logger.debug("\n".join(self.trace))
+
+ return result
+
+ async def do_not_execute(
+ self, denied_proposal: OneShotAgentActionProposal, user_feedback: str
+ ) -> ActionResult:
+ result = ActionInterruptedByHuman(feedback=user_feedback)
+ self.log_cycle_handler.log_cycle(
+ self.state.ai_profile.ai_name,
+ self.created_at,
+ self.config.cycle_count,
+ user_feedback,
+ USER_INPUT_FILE_NAME,
+ )
await self.run_pipeline(AfterExecute.after_execute, result)
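The human-feedback branch is pulled out of `execute` into the new `do_not_execute`, so the caller picks the path after showing the proposal to the user. A hedged sketch of what a calling loop might look like (AutoGPT's real run loop lives elsewhere; `get_user_verdict` is a stand-in):

# Hypothetical caller -- illustrates the execute / do_not_execute split,
# not AutoGPT's actual run loop.
async def run_one_cycle(agent, get_user_verdict) -> None:
    proposal = await agent.propose_action()
    approved, feedback = get_user_verdict(proposal)  # e.g. a CLI prompt
    if approved:
        result = await agent.execute(proposal)
    else:
        # Records the denial as ActionInterruptedByHuman with the feedback.
        result = await agent.do_not_execute(proposal, feedback)
    print(result)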
@@ -292,24 +298,19 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
return result
- async def _execute_command(
- self,
- command_name: str,
- arguments: dict[str, str],
- ) -> CommandOutput:
+ async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput:
"""Execute the command and return the result
Args:
- command_name (str): The name of the command to execute
- arguments (dict): The arguments for the command
+ tool_call (AssistantFunctionCall): The tool call to execute
Returns:
- str: The result of the command
+ str: The execution result
"""
# Execute a native command with the same name or alias, if it exists
- command = self._get_command(command_name)
+ command = self._get_command(tool_call.name)
try:
- result = command(**arguments)
+ result = command(**tool_call.arguments)
if inspect.isawaitable(result):
return await result
return result
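`_execute_tool` keeps the sync-or-async duality: a command may return a plain value or an awaitable, and `inspect.isawaitable` decides whether to await. A standalone sketch of that dispatch pattern with two toy commands:

# Standalone sketch of the sync-or-async dispatch used in _execute_tool.
import asyncio
import inspect


async def call_tool(func, **arguments):
    result = func(**arguments)
    if inspect.isawaitable(result):  # async commands return a coroutine
        return await result
    return result  # sync commands return the value directly


def list_dir(path: str) -> list[str]:
    # Toy synchronous command.
    return ["a.txt", "b.txt"] if path == "." else []


async def fetch_page(url: str) -> str:
    # Toy asynchronous command; sleep stands in for real I/O.
    await asyncio.sleep(0)
    return f"<html>{url}</html>"


async def main() -> None:
    print(await call_tool(list_dir, path="."))
    print(await call_tool(fetch_page, url="https://example.com"))


asyncio.run(main())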