aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/models/action_history.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/models/action_history.py')
-rw-r--r--autogpts/autogpt/autogpt/models/action_history.py177
1 files changed, 177 insertions, 0 deletions
diff --git a/autogpts/autogpt/autogpt/models/action_history.py b/autogpts/autogpt/autogpt/models/action_history.py
new file mode 100644
index 000000000..b36d7e540
--- /dev/null
+++ b/autogpts/autogpt/autogpt/models/action_history.py
@@ -0,0 +1,177 @@
+from __future__ import annotations
+
+from typing import Any, Iterator, Literal, Optional
+
+from pydantic import BaseModel, Field
+
+from autogpt.prompts.utils import format_numbered_list, indent
+
+
+class Action(BaseModel):
+ name: str
+ args: dict[str, Any]
+ reasoning: str
+
+ def format_call(self) -> str:
+ return (
+ f"{self.name}"
+ f"({', '.join([f'{a}={repr(v)}' for a, v in self.args.items()])})"
+ )
+
+
+class ActionSuccessResult(BaseModel):
+ outputs: Any
+ status: Literal["success"] = "success"
+
+ def __str__(self) -> str:
+ outputs = str(self.outputs).replace("```", r"\```")
+ multiline = "\n" in outputs
+ return f"```\n{self.outputs}\n```" if multiline else str(self.outputs)
+
+
+class ErrorInfo(BaseModel):
+ args: tuple
+ message: str
+ exception_type: str
+ repr: str
+
+ @staticmethod
+ def from_exception(exception: Exception) -> ErrorInfo:
+ return ErrorInfo(
+ args=exception.args,
+ message=getattr(exception, "message", exception.args[0]),
+ exception_type=exception.__class__.__name__,
+ repr=repr(exception),
+ )
+
+ def __str__(self):
+ return repr(self)
+
+ def __repr__(self):
+ return self.repr
+
+
+class ActionErrorResult(BaseModel):
+ reason: str
+ error: Optional[ErrorInfo] = None
+ status: Literal["error"] = "error"
+
+ @staticmethod
+ def from_exception(exception: Exception) -> ActionErrorResult:
+ return ActionErrorResult(
+ reason=getattr(exception, "message", exception.args[0]),
+ error=ErrorInfo.from_exception(exception),
+ )
+
+ def __str__(self) -> str:
+ return f"Action failed: '{self.reason}'"
+
+
+class ActionInterruptedByHuman(BaseModel):
+ feedback: str
+ status: Literal["interrupted_by_human"] = "interrupted_by_human"
+
+ def __str__(self) -> str:
+ return (
+ 'The user interrupted the action with the following feedback: "%s"'
+ % self.feedback
+ )
+
+
+ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman
+
+
+class Episode(BaseModel):
+ action: Action
+ result: ActionResult | None
+
+ def __str__(self) -> str:
+ executed_action = f"Executed `{self.action.format_call()}`"
+ action_result = f": {self.result}" if self.result else "."
+ return executed_action + action_result
+
+
+class EpisodicActionHistory(BaseModel):
+ """Utility container for an action history"""
+
+ episodes: list[Episode] = Field(default_factory=list)
+ cursor: int = 0
+
+ @property
+ def current_episode(self) -> Episode | None:
+ if self.cursor == len(self):
+ return None
+ return self[self.cursor]
+
+ def __getitem__(self, key: int) -> Episode:
+ return self.episodes[key]
+
+ def __iter__(self) -> Iterator[Episode]:
+ return iter(self.episodes)
+
+ def __len__(self) -> int:
+ return len(self.episodes)
+
+ def __bool__(self) -> bool:
+ return len(self.episodes) > 0
+
+ def register_action(self, action: Action) -> None:
+ if not self.current_episode:
+ self.episodes.append(Episode(action=action, result=None))
+ assert self.current_episode
+ elif self.current_episode.action:
+ raise ValueError("Action for current cycle already set")
+
+ def register_result(self, result: ActionResult) -> None:
+ if not self.current_episode:
+ raise RuntimeError("Cannot register result for cycle without action")
+ elif self.current_episode.result:
+ raise ValueError("Result for current cycle already set")
+
+ self.current_episode.result = result
+ self.cursor = len(self.episodes)
+
+ def rewind(self, number_of_episodes: int = 0) -> None:
+ """Resets the history to an earlier state.
+
+ Params:
+ number_of_cycles (int): The number of cycles to rewind. Default is 0.
+ When set to 0, it will only reset the current cycle.
+ """
+ # Remove partial record of current cycle
+ if self.current_episode:
+ if self.current_episode.action and not self.current_episode.result:
+ self.episodes.pop(self.cursor)
+
+ # Rewind the specified number of cycles
+ if number_of_episodes > 0:
+ self.episodes = self.episodes[:-number_of_episodes]
+ self.cursor = len(self.episodes)
+
+ def fmt_list(self) -> str:
+ return format_numbered_list(self.episodes)
+
+ def fmt_paragraph(self) -> str:
+ steps: list[str] = []
+
+ for i, c in enumerate(self.episodes, 1):
+ step = f"### Step {i}: Executed `{c.action.format_call()}`\n"
+ step += f'- **Reasoning:** "{c.action.reasoning}"\n'
+ step += (
+ f"- **Status:** `{c.result.status if c.result else 'did_not_finish'}`\n"
+ )
+ if c.result:
+ if c.result.status == "success":
+ result = str(c.result)
+ result = "\n" + indent(result) if "\n" in result else result
+ step += f"- **Output:** {result}"
+ elif c.result.status == "error":
+ step += f"- **Reason:** {c.result.reason}\n"
+ if c.result.error:
+ step += f"- **Error:** {c.result.error}\n"
+ elif c.result.status == "interrupted_by_human":
+ step += f"- **Feedback:** {c.result.feedback}\n"
+
+ steps.append(step)
+
+ return "\n\n".join(steps)