about summary refs log tree commit diff
path: root/frontend
diff options
context:
space:
mode:
author hunteraraujo <hunter_araujo@msn.com> 2023-10-08 22:28:57 -0700
committer hunteraraujo <hunter_araujo@msn.com> 2023-10-08 22:28:57 -0700
commit 7e5bdac2a0ade793c7c99ef78569334dd2553c94 (patch)
tree 3d76a89bb5d650dffc8afc2d8a74cda521ff5fa3 /frontend
parent buddy entering the arena (#5638) (diff)
download Auto-GPT-7e5bdac2a0ade793c7c99ef78569334dd2553c94.tar.gz
Auto-GPT-7e5bdac2a0ade793c7c99ef78569334dd2553c94.tar.bz2
Auto-GPT-7e5bdac2a0ade793c7c99ef78569334dd2553c94.zip
Add TaskQueueViewModel for managing benchmark tasks and leaderboard submissions
Diffstat (limited to 'frontend')
-rw-r--r-- frontend/lib/viewmodels/task_queue_viewmodel.dart 273
1 files changed, 273 insertions, 0 deletions
diff --git a/frontend/lib/viewmodels/task_queue_viewmodel.dart b/frontend/lib/viewmodels/task_queue_viewmodel.dart
new file mode 100644
index 000000000..42f840cd0
--- /dev/null
+++ b/frontend/lib/viewmodels/task_queue_viewmodel.dart
@@ -0,0 +1,273 @@
+import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_run.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_step_request_body.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_request_body.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
+import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
+import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
+import 'package:auto_gpt_flutter_client/models/step.dart';
+import 'package:auto_gpt_flutter_client/models/task.dart';
+import 'package:auto_gpt_flutter_client/models/test_option.dart';
+import 'package:auto_gpt_flutter_client/models/test_suite.dart';
+import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
+import 'package:auto_gpt_flutter_client/services/leaderboard_service.dart';
+import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
+import 'package:auto_gpt_flutter_client/viewmodels/task_viewmodel.dart';
+import 'package:collection/collection.dart';
+import 'package:flutter/foundation.dart';
+import 'package:uuid/uuid.dart';
+import 'package:auto_gpt_flutter_client/utils/stack.dart';
+
+/// View model that assembles the queue of skill-tree nodes to benchmark, runs
+/// that queue through [benchmarkService], and submits the resulting runs via
+/// [leaderboardService].
+class TaskQueueViewModel extends ChangeNotifier {
+  /// Label of the node used as the traversal root when
+  /// [TestOption.runAllTestsInCategory] is selected.
+  static const String _rootNodeLabel = "WriteFile";
+
+  /// Service used to create benchmark tasks, execute their steps and trigger
+  /// their evaluation.
+  final BenchmarkService benchmarkService;
+
+  /// Service used to submit benchmark runs.
+  final LeaderboardService leaderboardService;
+
+  /// Whether a [runBenchmark] call is currently in progress.
+  bool isBenchmarkRunning = false;
+
+  /// Status of every node queued by the most recent [runBenchmark] call.
+  Map<SkillTreeNode, BenchmarkTaskStatus> benchmarkStatusMap = {};
+
+  /// Runs collected by the most recent [runBenchmark] call; emptied again by
+  /// [submitToLeaderboard].
+  List<BenchmarkRun> currentBenchmarkRuns = [];
+
+  List<SkillTreeNode>? _selectedNodeHierarchy;
+  TestOption _selectedOption = TestOption.runSingleTest;
+
+  /// The option last passed to [updateSelectedNodeHierarchyBasedOnOption].
+  TestOption get selectedOption => _selectedOption;
+
+  /// The ordered queue of nodes that [runBenchmark] will execute, or `null`
+  /// if no queue has been built yet.
+  List<SkillTreeNode>? get selectedNodeHierarchy => _selectedNodeHierarchy;
+
+  TaskQueueViewModel(this.benchmarkService, this.leaderboardService);
+
+  /// Rebuilds [selectedNodeHierarchy] for [selectedOption] and notifies
+  /// listeners once.
+  ///
+  /// When [selectedNode] is `null` the queue is emptied for every option, so
+  /// a queue built for an earlier selection never lingers.
+  void updateSelectedNodeHierarchyBasedOnOption(
+      TestOption selectedOption,
+      SkillTreeNode? selectedNode,
+      List<SkillTreeNode> nodes,
+      List<SkillTreeEdge> edges) {
+    _selectedOption = selectedOption;
+
+    if (selectedNode == null) {
+      _selectedNodeHierarchy = [];
+    } else {
+      switch (selectedOption) {
+        case TestOption.runSingleTest:
+          _selectedNodeHierarchy = [selectedNode];
+          break;
+
+        case TestOption.runTestSuiteIncludingSelectedNodeAndAncestors:
+          // Build the queue directly instead of going through
+          // populateSelectedNodeHierarchy, which would notify a second time.
+          _selectedNodeHierarchy = <SkillTreeNode>[];
+          recursivePopulateHierarchy(
+              selectedNode.id, <String>{}, nodes, edges);
+          break;
+
+        case TestOption.runAllTestsInCategory:
+          _getAllNodesInDepthFirstOrderEnsuringParents(nodes, edges);
+          break;
+      }
+    }
+    notifyListeners();
+  }
+
+  /// Stores in [selectedNodeHierarchy] every node reachable from the root
+  /// node, in depth-first order, such that a node is listed only after all of
+  /// its parents.
+  ///
+  /// The root is the first node whose label equals [_rootNodeLabel]; if there
+  /// is none, the queue is left empty. A node that has a parent which is
+  /// never listed (for example a parent not reachable from the root) is
+  /// omitted.
+  void _getAllNodesInDepthFirstOrderEnsuringParents(
+      List<SkillTreeNode> skillTreeNodes, List<SkillTreeEdge> skillTreeEdges) {
+    final ordered = <SkillTreeNode>[];
+
+    final root = skillTreeNodes
+        .firstWhereOrNull((node) => node.label == _rootNodeLabel);
+    if (root == null) {
+      debugPrint("No root node labelled $_rootNodeLabel found.");
+      _selectedNodeHierarchy = ordered;
+      return;
+    }
+
+    // Ids of nodes already appended to `ordered`. A node counts as handled
+    // only once it is emitted, never merely because it sits on the stack;
+    // otherwise a child could be listed ahead of a parent still waiting
+    // there.
+    final emitted = <String>{};
+    final stack = Stack<SkillTreeNode>();
+    stack.push(root);
+
+    while (stack.isNotEmpty) {
+      final node = stack.peek();
+      stack.pop();
+
+      // A node is pushed once per emitted parent, so duplicates are expected.
+      if (emitted.contains(node.id)) {
+        continue;
+      }
+
+      final parents =
+          _getParentsOfNodeUsingEdges(node.id, skillTreeNodes, skillTreeEdges);
+      if (!parents.every((parent) => emitted.contains(parent.id))) {
+        // Not ready yet: the node is pushed again when each remaining parent
+        // is emitted, so dropping it here does not lose it.
+        continue;
+      }
+
+      emitted.add(node.id);
+      ordered.add(node);
+
+      final children =
+          _getChildrenOfNodeUsingEdges(node.id, skillTreeNodes, skillTreeEdges);
+      for (final child in children) {
+        if (!emitted.contains(child.id)) {
+          stack.push(child);
+        }
+      }
+    }
+
+    _selectedNodeHierarchy = ordered;
+  }
+
+  /// Returns the nodes in [nodes] that have an edge pointing to [nodeId].
+  ///
+  /// Edges whose source id matches no entry of [nodes] are skipped.
+  List<SkillTreeNode> _getParentsOfNodeUsingEdges(
+      String nodeId, List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
+    final parents = <SkillTreeNode>[];
+
+    for (final edge in edges) {
+      if (edge.to != nodeId) {
+        continue;
+      }
+      final parent = nodes.firstWhereOrNull((node) => node.id == edge.from);
+      if (parent != null) {
+        parents.add(parent);
+      }
+    }
+
+    return parents;
+  }
+
+  /// Returns the nodes in [nodes] that [nodeId] has an edge pointing to.
+  ///
+  /// Edges whose target id matches no entry of [nodes] are skipped.
+  List<SkillTreeNode> _getChildrenOfNodeUsingEdges(
+      String nodeId, List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
+    final children = <SkillTreeNode>[];
+
+    for (final edge in edges) {
+      if (edge.from != nodeId) {
+        continue;
+      }
+      final child = nodes.firstWhereOrNull((node) => node.id == edge.to);
+      if (child != null) {
+        children.add(child);
+      }
+    }
+
+    return children;
+  }
+
+  // TODO: Do we want to continue testing other branches of tree if one branch side fails benchmarking?
+  /// Replaces [selectedNodeHierarchy] with the node [startNodeId] preceded by
+  /// all of its ancestors, then notifies listeners.
+  void populateSelectedNodeHierarchy(String startNodeId,
+      List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
+    _selectedNodeHierarchy = <SkillTreeNode>[];
+    recursivePopulateHierarchy(startNodeId, <String>{}, nodes, edges);
+    notifyListeners();
+  }
+
+  /// Appends the node [nodeId] to [selectedNodeHierarchy] after first
+  /// appending, recursively, every parent reached through [edges].
+  ///
+  /// [addedNodes] holds the ids already handled; a node whose id is in it is
+  /// skipped, which also stops the recursion on cyclic edges. Ids that match
+  /// no entry of [nodes] are ignored.
+  void recursivePopulateHierarchy(String nodeId, Set<String> addedNodes,
+      List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
+    final currentNode = nodes.firstWhereOrNull((node) => node.id == nodeId);
+    if (currentNode == null || !addedNodes.add(currentNode.id)) {
+      return;
+    }
+
+    // Parents go first so that the resulting list runs from ancestors down to
+    // the node itself.
+    for (final parentEdge in edges.where((edge) => edge.to == currentNode.id)) {
+      recursivePopulateHierarchy(parentEdge.from, addedNodes, nodes, edges);
+    }
+
+    // Create the list on demand so a direct call made before any hierarchy
+    // exists does not fail on a null list.
+    (_selectedNodeHierarchy ??= <SkillTreeNode>[]).add(currentNode);
+  }
+
+  /// Runs every node of [selectedNodeHierarchy] in order against
+  /// [benchmarkService], stopping at the first node whose evaluation is not
+  /// successful.
+  ///
+  /// For each node a benchmark task is created, its steps are executed until
+  /// one reports `isLast`, and the evaluation is triggered. Progress is
+  /// published through [benchmarkStatusMap], the runs are collected in
+  /// [currentBenchmarkRuns], and when the loop finishes without an exception
+  /// the resulting test suite is added to [taskViewModel].
+  ///
+  /// Exceptions are logged, not rethrown; the node that was in progress is
+  /// marked as failed. [isBenchmarkRunning] is always reset at the end.
+  Future<void> runBenchmark(
+      ChatViewModel chatViewModel, TaskViewModel taskViewModel) async {
+    // Work on a snapshot: the queue may be null (nothing selected yet) and
+    // may be rebuilt while this method is suspended on an await.
+    final queue = List<SkillTreeNode>.of(_selectedNodeHierarchy ?? const []);
+
+    benchmarkStatusMap.clear();
+    currentBenchmarkRuns = [];
+
+    final testSuite =
+        TestSuite(timestamp: DateTime.now().toIso8601String(), tests: []);
+
+    for (final node in queue) {
+      benchmarkStatusMap[node] = BenchmarkTaskStatus.notStarted;
+    }
+
+    isBenchmarkRunning = true;
+    notifyListeners();
+
+    // Node currently being benchmarked, so the error path can update it.
+    SkillTreeNode? activeNode;
+
+    try {
+      for (final node in queue) {
+        activeNode = node;
+        benchmarkStatusMap[node] = BenchmarkTaskStatus.inProgress;
+        notifyListeners();
+
+        final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
+            input: node.data.task, evalId: node.data.evalId);
+        final createdTask = await benchmarkService
+            .createBenchmarkTask(benchmarkTaskRequestBody);
+
+        final task =
+            Task(id: createdTask['task_id'], title: createdTask['input']);
+
+        // Point the chat view at the task that is about to run.
+        chatViewModel.setCurrentTaskId(task.id);
+
+        // The first step carries the task input; follow-up steps send none.
+        Map<String, dynamic> stepResponse =
+            await benchmarkService.executeBenchmarkStep(
+                task.id, BenchmarkStepRequestBody(input: node.data.task));
+        Step step = Step.fromMap(stepResponse);
+        chatViewModel.fetchChatsForTask();
+
+        while (!step.isLast) {
+          stepResponse = await benchmarkService.executeBenchmarkStep(
+              task.id, BenchmarkStepRequestBody(input: null));
+          step = Step.fromMap(stepResponse);
+          chatViewModel.fetchChatsForTask();
+        }
+
+        final evaluationResponse =
+            await benchmarkService.triggerEvaluation(task.id);
+        final benchmarkRun = BenchmarkRun.fromJson(evaluationResponse);
+        currentBenchmarkRuns.add(benchmarkRun);
+
+        final successStatus = benchmarkRun.metrics.success;
+        benchmarkStatusMap[node] = successStatus
+            ? BenchmarkTaskStatus.success
+            : BenchmarkTaskStatus.failure;
+        await Future.delayed(const Duration(seconds: 1));
+        notifyListeners();
+
+        testSuite.tests.add(task);
+        if (!successStatus) {
+          debugPrint(
+              "Benchmark for node ${node.id} failed. Stopping all benchmarks.");
+          break;
+        }
+      }
+
+      taskViewModel.addTestSuite(testSuite);
+    } catch (e) {
+      debugPrint("Error while running benchmark: $e");
+
+      // Do not leave the interrupted node showing as in progress forever.
+      final interrupted = activeNode;
+      if (interrupted != null &&
+          benchmarkStatusMap[interrupted] == BenchmarkTaskStatus.inProgress) {
+        benchmarkStatusMap[interrupted] = BenchmarkTaskStatus.failure;
+      }
+    } finally {
+      isBenchmarkRunning = false;
+      notifyListeners();
+    }
+  }
+
+  /// Submits every run in [currentBenchmarkRuns] to the leaderboard, then
+  /// empties the list.
+  ///
+  /// Each run is stamped with [teamName], [repoUrl], [agentGitCommitSha] and
+  /// one freshly generated v4 UUID shared by all runs of this call. An
+  /// exception from the leaderboard service propagates to the caller and
+  /// leaves the list untouched.
+  Future<void> submitToLeaderboard(
+      String teamName, String repoUrl, String agentGitCommitSha) async {
+    final runId = const Uuid().v4();
+
+    for (final run in currentBenchmarkRuns) {
+      run.repositoryInfo
+        ..teamName = teamName
+        ..repoUrl = repoUrl
+        ..agentGitCommitSha = agentGitCommitSha;
+      run.runDetails.runId = runId;
+
+      await leaderboardService.submitReport(run);
+      debugPrint('Completed submission to leaderboard!');
+    }
+
+    currentBenchmarkRuns.clear();
+  }
+}