aboutsummaryrefslogtreecommitdiff
path: root/frontend
diff options
context:
space:
mode:
authorGravatar hunteraraujo <hunter_araujo@msn.com> 2023-10-08 22:33:21 -0700
committerGravatar hunteraraujo <hunter_araujo@msn.com> 2023-10-08 22:33:21 -0700
commita7e27d1a645a2ada0148b307ee317f24f808a74a (patch)
tree0ce771a56dbba12b697665e0e62f05a9361e3b10 /frontend
parentAdd TaskQueueViewModel for managing benchmark tasks and leaderboard submissions (diff)
downloadAuto-GPT-a7e27d1a645a2ada0148b307ee317f24f808a74a.tar.gz
Auto-GPT-a7e27d1a645a2ada0148b307ee317f24f808a74a.tar.bz2
Auto-GPT-a7e27d1a645a2ada0148b307ee317f24f808a74a.zip
Remove duplicate functionality from SkillTreeViewModel
Diffstat (limited to 'frontend')
-rw-r--r--frontend/lib/viewmodels/skill_tree_viewmodel.dart293
1 files changed, 8 insertions, 285 deletions
diff --git a/frontend/lib/viewmodels/skill_tree_viewmodel.dart b/frontend/lib/viewmodels/skill_tree_viewmodel.dart
index 2a17f3adf..5383d127c 100644
--- a/frontend/lib/viewmodels/skill_tree_viewmodel.dart
+++ b/frontend/lib/viewmodels/skill_tree_viewmodel.dart
@@ -1,60 +1,27 @@
import 'dart:convert';
-import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_run.dart';
-import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_step_request_body.dart';
-import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_request_body.dart';
-import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
-import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_category.dart';
-import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
-import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
-import 'package:auto_gpt_flutter_client/models/step.dart';
-import 'package:auto_gpt_flutter_client/models/task.dart';
-import 'package:auto_gpt_flutter_client/models/test_option.dart';
-import 'package:auto_gpt_flutter_client/models/test_suite.dart';
-import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
-import 'package:auto_gpt_flutter_client/services/leaderboard_service.dart';
-import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
-import 'package:auto_gpt_flutter_client/viewmodels/task_viewmodel.dart';
-import 'package:collection/collection.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:graphview/GraphView.dart';
-import 'package:uuid/uuid.dart';
-import 'package:auto_gpt_flutter_client/utils/stack.dart';
-class SkillTreeViewModel extends ChangeNotifier {
- // TODO: Potentially move to task queue view model when we create one
- final BenchmarkService benchmarkService;
- // TODO: Potentially move to task queue view model when we create one
- final LeaderboardService leaderboardService;
- // TODO: Potentially move to task queue view model when we create one
- bool isBenchmarkRunning = false;
- // TODO: Potentially move to task queue view model when we create one
- // TODO: clear when clicking a new node
- Map<SkillTreeNode, BenchmarkTaskStatus> benchmarkStatusMap = {};
-
- List<BenchmarkRun> currentBenchmarkRuns = [];
+import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_category.dart';
+import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
+import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
+class SkillTreeViewModel extends ChangeNotifier {
List<SkillTreeNode> _skillTreeNodes = [];
- List<SkillTreeEdge> _skillTreeEdges = [];
- SkillTreeNode? _selectedNode;
- // TODO: Potentially move to task queue view model when we create one
- List<SkillTreeNode>? _selectedNodeHierarchy;
-
- TestOption _selectedOption = TestOption.runSingleTest;
- TestOption get selectedOption => _selectedOption;
-
List<SkillTreeNode> get skillTreeNodes => _skillTreeNodes;
+
+ List<SkillTreeEdge> _skillTreeEdges = [];
List<SkillTreeEdge> get skillTreeEdges => _skillTreeEdges;
+
+ SkillTreeNode? _selectedNode;
SkillTreeNode? get selectedNode => _selectedNode;
- List<SkillTreeNode>? get selectedNodeHierarchy => _selectedNodeHierarchy;
final Graph graph = Graph();
SugiyamaConfiguration builder = SugiyamaConfiguration();
SkillTreeCategory currentSkillTreeType = SkillTreeCategory.general;
- SkillTreeViewModel(this.benchmarkService, this.leaderboardService);
-
Future<void> initializeSkillTree() async {
try {
resetState();
@@ -94,143 +61,19 @@ class SkillTreeViewModel extends ChangeNotifier {
_skillTreeNodes = [];
_skillTreeEdges = [];
_selectedNode = null;
- _selectedNodeHierarchy = null;
}
void toggleNodeSelection(String nodeId) {
- if (isBenchmarkRunning) return;
if (_selectedNode?.id == nodeId) {
// Unselect the node if it's already selected
_selectedNode = null;
- _selectedNodeHierarchy = null;
} else {
// Select the new node
_selectedNode = _skillTreeNodes.firstWhere((node) => node.id == nodeId);
- updateSelectedNodeHierarchyBasedOnOption(_selectedOption);
- }
- notifyListeners();
- }
-
- void updateSelectedNodeHierarchyBasedOnOption(TestOption selectedOption) {
- _selectedOption = selectedOption;
- switch (selectedOption) {
- case TestOption.runSingleTest:
- _selectedNodeHierarchy = _selectedNode != null ? [_selectedNode!] : [];
- break;
-
- case TestOption.runTestSuiteIncludingSelectedNodeAndAncestors:
- if (_selectedNode != null) {
- populateSelectedNodeHierarchy(_selectedNode!.id);
- }
- break;
-
- case TestOption.runAllTestsInCategory:
- if (_selectedNode != null) {
- _getAllNodesInDepthFirstOrderEnsuringParents();
- }
- break;
- }
- notifyListeners();
- }
-
- void _getAllNodesInDepthFirstOrderEnsuringParents() {
- var nodes = <SkillTreeNode>[];
- var stack = Stack<SkillTreeNode>();
- var visited = <String>{};
-
- // Identify the root node by its label
- var root = _skillTreeNodes.firstWhere((node) => node.label == "WriteFile");
-
- stack.push(root);
- visited.add(root.id);
-
- while (stack.isNotEmpty) {
- var node = stack.peek(); // Peek the top node, but do not remove it yet
- var parents = _getParentsOfNodeUsingEdges(node.id);
-
- // Check if all parents are visited
- if (parents.every((parent) => visited.contains(parent.id))) {
- nodes.add(node);
- stack.pop(); // Remove the node only when all its parents are visited
-
- // Get the children of the current node using edges
- var children = _getChildrenOfNodeUsingEdges(node.id)
- .where((child) => !visited.contains(child.id));
-
- children.forEach((child) {
- visited.add(child.id);
- stack.push(child);
- });
- } else {
- stack
- .pop(); // Remove the node if not all parents are visited, it will be re-added when its parents are visited
- }
- }
-
- _selectedNodeHierarchy = nodes;
- }
-
- List<SkillTreeNode> _getParentsOfNodeUsingEdges(String nodeId) {
- var parents = <SkillTreeNode>[];
-
- for (var edge in _skillTreeEdges) {
- if (edge.to == nodeId) {
- parents.add(_skillTreeNodes.firstWhere((node) => node.id == edge.from));
- }
- }
-
- return parents;
- }
-
- List<SkillTreeNode> _getChildrenOfNodeUsingEdges(String nodeId) {
- var children = <SkillTreeNode>[];
-
- for (var edge in _skillTreeEdges) {
- if (edge.from == nodeId) {
- children.add(_skillTreeNodes.firstWhere((node) => node.id == edge.to));
- }
}
-
- return children;
- }
-
- // TODO: Do we want to continue testing other branches of tree if one branch side fails benchmarking?
- void populateSelectedNodeHierarchy(String startNodeId) {
- // Initialize an empty list to hold the nodes in all hierarchies.
- _selectedNodeHierarchy = <SkillTreeNode>[];
-
- // Initialize a set to keep track of nodes that have been added.
- final addedNodes = <String>{};
-
- // Start the recursive population of the hierarchy from the startNodeId.
- recursivePopulateHierarchy(startNodeId, addedNodes);
-
- // Notify listeners about the change in the selectedNodeHierarchy state.
notifyListeners();
}
- void recursivePopulateHierarchy(String nodeId, Set<String> addedNodes) {
- // Find the current node in the skill tree nodes list.
- final currentNode =
- _skillTreeNodes.firstWhereOrNull((node) => node.id == nodeId);
-
- // If the node is found and it hasn't been added yet, proceed with the population.
- if (currentNode != null && addedNodes.add(currentNode.id)) {
- // Find all parent edges for the current node.
- final parentEdges =
- _skillTreeEdges.where((edge) => edge.to == currentNode.id);
-
- // For each parent edge found, recurse to the parent node.
- for (final parentEdge in parentEdges) {
- // Recurse to the parent node identified by the 'from' field of the edge.
- recursivePopulateHierarchy(parentEdge.from, addedNodes);
- }
-
- // After processing all parent nodes, add the current node to the list.
- _selectedNodeHierarchy!.add(currentNode);
- }
- }
-
// Function to get a node by its ID
SkillTreeNode? getNodeById(String nodeId) {
try {
@@ -241,124 +84,4 @@ class SkillTreeViewModel extends ChangeNotifier {
return null;
}
}
-
- // TODO: Move to task queue view model
- Future<void> runBenchmark(
- ChatViewModel chatViewModel, TaskViewModel taskViewModel) async {
- // Clear the benchmarkStatusList
- benchmarkStatusMap.clear();
-
- // Reset the current benchmark runs list to be empty at the start of a new benchmark
- currentBenchmarkRuns = [];
-
- // Create a new TestSuite object with the current timestamp
- final testSuite =
- TestSuite(timestamp: DateTime.now().toIso8601String(), tests: []);
-
- // Set the benchmark running flag to true
- isBenchmarkRunning = true;
- // Notify listeners
- notifyListeners();
-
- // Populate benchmarkStatusList with node hierarchy
- for (var node in _selectedNodeHierarchy!) {
- benchmarkStatusMap[node] = BenchmarkTaskStatus.notStarted;
- }
-
- try {
- // Loop through the nodes in the hierarchy
- for (var node in _selectedNodeHierarchy!) {
- benchmarkStatusMap[node] = BenchmarkTaskStatus.inProgress;
- notifyListeners();
-
- // Create a BenchmarkTaskRequestBody
- final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
- input: node.data.task, evalId: node.data.evalId);
-
- // Create a new benchmark task
- final createdTask = await benchmarkService
- .createBenchmarkTask(benchmarkTaskRequestBody);
-
- // Create a new Task object
- final task =
- Task(id: createdTask['task_id'], title: createdTask['input']);
-
- // Update the current task ID in ChatViewModel
- chatViewModel.setCurrentTaskId(task.id);
-
- // Execute the first step and initialize the Step object
- Map<String, dynamic> stepResponse =
- await benchmarkService.executeBenchmarkStep(
- task.id, BenchmarkStepRequestBody(input: node.data.task));
- Step step = Step.fromMap(stepResponse);
- chatViewModel.fetchChatsForTask();
-
- // Check if it's the last step
- while (!step.isLast) {
- // Execute next step and update the Step object
- stepResponse = await benchmarkService.executeBenchmarkStep(
- task.id, BenchmarkStepRequestBody(input: null));
- step = Step.fromMap(stepResponse);
-
- // Fetch chats for the task
- chatViewModel.fetchChatsForTask();
- }
-
- // Trigger the evaluation
- final evaluationResponse =
- await benchmarkService.triggerEvaluation(task.id);
-
- // Decode the evaluationResponse into a BenchmarkRun object
- BenchmarkRun benchmarkRun = BenchmarkRun.fromJson(evaluationResponse);
-
- // Add the benchmark run object to the list of current benchmark runs
- currentBenchmarkRuns.add(benchmarkRun);
-
- // Update the benchmarkStatusList based on the evaluation response
- bool successStatus = benchmarkRun.metrics.success;
- benchmarkStatusMap[node] = successStatus
- ? BenchmarkTaskStatus.success
- : BenchmarkTaskStatus.failure;
- await Future.delayed(Duration(seconds: 1));
- notifyListeners();
-
- testSuite.tests.add(task);
- // If successStatus is false, break out of the loop
- if (!successStatus) {
- print(
- "Benchmark for node ${node.id} failed. Stopping all benchmarks.");
- break;
- }
- }
-
- // Add the TestSuite to the TaskViewModel
- taskViewModel.addTestSuite(testSuite);
- } catch (e) {
- print("Error while running benchmark: $e");
- }
-
- // Reset the benchmark running flag
- isBenchmarkRunning = false;
- notifyListeners();
- }
-
- // TODO: Move to task queue view model
- Future<void> submitToLeaderboard(
- String teamName, String repoUrl, String agentGitCommitSha) async {
- // Create a UUID.v4 for our unique run ID
- String uuid = const Uuid().v4();
-
- for (var run in currentBenchmarkRuns) {
- run.repositoryInfo.teamName = teamName;
- run.repositoryInfo.repoUrl = repoUrl;
- run.repositoryInfo.agentGitCommitSha = agentGitCommitSha;
- run.runDetails.runId = uuid;
-
- await leaderboardService.submitReport(run);
- print('Completed submission to leaderboard!');
- }
-
- // Clear the currentBenchmarkRuns list after submitting to the leaderboard
- currentBenchmarkRuns.clear();
- }
}