diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/agents/agent.py')
-rw-r--r-- | autogpts/autogpt/autogpt/agents/agent.py | 135 |
1 files changed, 68 insertions, 67 deletions
diff --git a/autogpts/autogpt/autogpt/agents/agent.py b/autogpts/autogpt/autogpt/agents/agent.py index 0ffd2db2a..3572cbed0 100644 --- a/autogpts/autogpt/autogpt/agents/agent.py +++ b/autogpts/autogpt/autogpt/agents/agent.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional import sentry_sdk from pydantic import Field -from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy from autogpt.commands.execute_code import CodeExecutorComponent from autogpt.commands.git_operations import GitOperationsComponent from autogpt.commands.image_gen import ImageGeneratorComponent @@ -19,9 +18,11 @@ from autogpt.commands.web_selenium import WebSeleniumComponent from autogpt.components.event_history import EventHistoryComponent from autogpt.core.configuration import Configurable from autogpt.core.prompting import ChatPrompt -from autogpt.core.resource.model_providers import ChatMessage, ChatModelProvider -from autogpt.core.resource.model_providers.schema import ( +from autogpt.core.resource.model_providers import ( AssistantChatMessage, + AssistantFunctionCall, + ChatMessage, + ChatModelProvider, ChatModelResponse, ) from autogpt.core.runner.client_lib.logging.helpers import dump_prompt @@ -33,12 +34,12 @@ from autogpt.logs.log_cycle import ( USER_INPUT_FILE_NAME, LogCycleHandler, ) -from autogpt.logs.utils import fmt_kwargs from autogpt.models.action_history import ( ActionErrorResult, ActionInterruptedByHuman, ActionResult, ActionSuccessResult, + EpisodicActionHistory, ) from autogpt.models.command import Command, CommandOutput from autogpt.utils.exceptions import ( @@ -49,15 +50,14 @@ from autogpt.utils.exceptions import ( UnknownCommandError, ) -from .base import ( - BaseAgent, - BaseAgentConfiguration, - BaseAgentSettings, - ThoughtProcessOutput, -) +from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings from .features.agent_file_manager import FileManagerComponent from .features.context import ContextComponent from .features.watchdog import WatchdogComponent +from .prompt_strategies.one_shot import ( + OneShotAgentActionProposal, + OneShotAgentPromptStrategy, +) from .protocols import ( AfterExecute, AfterParse, @@ -79,6 +79,11 @@ class AgentConfiguration(BaseAgentConfiguration): class AgentSettings(BaseAgentSettings): config: AgentConfiguration = Field(default_factory=AgentConfiguration) + history: EpisodicActionHistory[OneShotAgentActionProposal] = Field( + default_factory=EpisodicActionHistory[OneShotAgentActionProposal] + ) + """(STATE) The action history of the agent.""" + class Agent(BaseAgent, Configurable[AgentSettings]): default_settings: AgentSettings = AgentSettings( @@ -137,7 +142,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]): self.event_history = settings.history self.legacy_config = legacy_config - async def propose_action(self) -> ThoughtProcessOutput: + async def propose_action(self) -> OneShotAgentActionProposal: """Proposes the next action to execute, based on the task and current state. Returns: @@ -188,12 +193,12 @@ class Agent(BaseAgent, Configurable[AgentSettings]): async def complete_and_parse( self, prompt: ChatPrompt, exception: Optional[Exception] = None - ) -> ThoughtProcessOutput: + ) -> OneShotAgentActionProposal: if exception: prompt.messages.append(ChatMessage.system(f"Error: {exception}")) response: ChatModelResponse[ - ThoughtProcessOutput + OneShotAgentActionProposal ] = await self.llm_provider.create_chat_completion( prompt.messages, model_name=self.llm.name, @@ -210,7 +215,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]): self.state.ai_profile.ai_name, self.created_at, self.config.cycle_count, - result.thoughts, + result.thoughts.dict(), NEXT_ACTION_FILE_NAME, ) @@ -220,13 +225,13 @@ class Agent(BaseAgent, Configurable[AgentSettings]): def parse_and_validate_response( self, llm_response: AssistantChatMessage - ) -> ThoughtProcessOutput: + ) -> OneShotAgentActionProposal: parsed_response = self.prompt_strategy.parse_response_content(llm_response) # Validate command arguments - command_name = parsed_response.command_name + command_name = parsed_response.use_tool.name command = self._get_command(command_name) - if arg_errors := command.validate_args(parsed_response.command_args)[1]: + if arg_errors := command.validate_args(parsed_response.use_tool.arguments)[1]: fmt_errors = [ f"{'.'.join(str(p) for p in f.path)}: {f.message}" if f.path @@ -242,49 +247,50 @@ class Agent(BaseAgent, Configurable[AgentSettings]): async def execute( self, - command_name: str, - command_args: dict[str, str] = {}, - user_input: str = "", + proposal: OneShotAgentActionProposal, + user_feedback: str = "", ) -> ActionResult: - result: ActionResult - - if command_name == "human_feedback": - result = ActionInterruptedByHuman(feedback=user_input) - self.log_cycle_handler.log_cycle( - self.state.ai_profile.ai_name, - self.created_at, - self.config.cycle_count, - user_input, - USER_INPUT_FILE_NAME, + tool = proposal.use_tool + + # Get commands + self.commands = await self.run_pipeline(CommandProvider.get_commands) + self._remove_disabled_commands() + + try: + return_value = await self._execute_tool(tool) + + result = ActionSuccessResult(outputs=return_value) + except AgentTerminated: + raise + except AgentException as e: + result = ActionErrorResult.from_exception(e) + logger.warning(f"{tool} raised an error: {e}") + sentry_sdk.capture_exception(e) + + result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name) + if result_tlength > self.send_token_limit // 3: + result = ActionErrorResult( + reason=f"Command {tool.name} returned too much output. " + "Do not execute this command again with the same arguments." ) - else: - # Get commands - self.commands = await self.run_pipeline(CommandProvider.get_commands) - self._remove_disabled_commands() - - try: - return_value = await self._execute_command( - command_name=command_name, - arguments=command_args, - ) - - result = ActionSuccessResult(outputs=return_value) - except AgentTerminated: - raise - except AgentException as e: - result = ActionErrorResult.from_exception(e) - logger.warning( - f"{command_name}({fmt_kwargs(command_args)}) raised an error: {e}" - ) - sentry_sdk.capture_exception(e) - - result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name) - if result_tlength > self.send_token_limit // 3: - result = ActionErrorResult( - reason=f"Command {command_name} returned too much output. " - "Do not execute this command again with the same arguments." - ) + await self.run_pipeline(AfterExecute.after_execute, result) + + logger.debug("\n".join(self.trace)) + + return result + + async def do_not_execute( + self, denied_proposal: OneShotAgentActionProposal, user_feedback: str + ) -> ActionResult: + result = ActionInterruptedByHuman(feedback=user_feedback) + self.log_cycle_handler.log_cycle( + self.state.ai_profile.ai_name, + self.created_at, + self.config.cycle_count, + user_feedback, + USER_INPUT_FILE_NAME, + ) await self.run_pipeline(AfterExecute.after_execute, result) @@ -292,24 +298,19 @@ class Agent(BaseAgent, Configurable[AgentSettings]): return result - async def _execute_command( - self, - command_name: str, - arguments: dict[str, str], - ) -> CommandOutput: + async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput: """Execute the command and return the result Args: - command_name (str): The name of the command to execute - arguments (dict): The arguments for the command + tool_call (AssistantFunctionCall): The tool call to execute Returns: - str: The result of the command + str: The execution result """ # Execute a native command with the same name or alias, if it exists - command = self._get_command(command_name) + command = self._get_command(tool_call.name) try: - result = command(**arguments) + result = command(**tool_call.arguments) if inspect.isawaitable(result): return await result return result |