aboutsummaryrefslogtreecommitdiff
path: root/frontend/lib/viewmodels/skill_tree_viewmodel.dart
blob: 2a17f3adfea0b4a5e5ba56091f9996f34ca9e969 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import 'dart:convert';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_run.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_step_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_category.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
import 'package:auto_gpt_flutter_client/models/step.dart';
import 'package:auto_gpt_flutter_client/models/task.dart';
import 'package:auto_gpt_flutter_client/models/test_option.dart';
import 'package:auto_gpt_flutter_client/models/test_suite.dart';
import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
import 'package:auto_gpt_flutter_client/services/leaderboard_service.dart';
import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
import 'package:auto_gpt_flutter_client/viewmodels/task_viewmodel.dart';
import 'package:collection/collection.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:graphview/GraphView.dart';
import 'package:uuid/uuid.dart';
import 'package:auto_gpt_flutter_client/utils/stack.dart';

class SkillTreeViewModel extends ChangeNotifier {
  // TODO: Potentially move to task queue view model when we create one
  /// Service used by [runBenchmark] to create benchmark tasks, execute their
  /// steps and trigger the evaluation.
  final BenchmarkService benchmarkService;
  // TODO: Potentially move to task queue view model when we create one
  /// Service used by [submitToLeaderboard] to submit benchmark reports.
  final LeaderboardService leaderboardService;
  // TODO: Potentially move to task queue view model when we create one
  /// Whether [runBenchmark] is currently executing.
  ///
  /// While `true`, [toggleNodeSelection] ignores selection changes.
  bool isBenchmarkRunning = false;
  // TODO: Potentially move to task queue view model when we create one
  // TODO: clear when clicking a new node
  /// Status of each node taking part in the most recent benchmark run.
  ///
  /// Cleared and repopulated at the start of [runBenchmark].
  Map<SkillTreeNode, BenchmarkTaskStatus> benchmarkStatusMap = {};

  /// Evaluation results collected by the most recent [runBenchmark] call.
  ///
  /// Emptied again once [submitToLeaderboard] has submitted them.
  List<BenchmarkRun> currentBenchmarkRuns = [];

  // Nodes and edges of the currently loaded tree; filled by
  // [initializeSkillTree] and replaced with empty lists by [resetState].
  List<SkillTreeNode> _skillTreeNodes = [];
  List<SkillTreeEdge> _skillTreeEdges = [];
  // Currently selected node, or null when nothing is selected.
  SkillTreeNode? _selectedNode;
  // TODO: Potentially move to task queue view model when we create one
  // Ordered list of nodes that [runBenchmark] will execute; null when no
  // selection has been made since the last reset/unselect.
  List<SkillTreeNode>? _selectedNodeHierarchy;

  // Test option last passed to [updateSelectedNodeHierarchyBasedOnOption].
  TestOption _selectedOption = TestOption.runSingleTest;

  /// The test option that determines how the selected-node hierarchy is built.
  TestOption get selectedOption => _selectedOption;

  /// The nodes of the currently loaded skill tree.
  List<SkillTreeNode> get skillTreeNodes => _skillTreeNodes;

  /// The edges of the currently loaded skill tree.
  List<SkillTreeEdge> get skillTreeEdges => _skillTreeEdges;

  /// The currently selected node, or `null` when nothing is selected.
  SkillTreeNode? get selectedNode => _selectedNode;

  /// The ordered nodes to benchmark for the current selection, or `null` when
  /// no hierarchy has been computed.
  List<SkillTreeNode>? get selectedNodeHierarchy => _selectedNodeHierarchy;

  /// graphview [Graph] instance exposed to consumers.
  ///
  /// NOTE(review): nothing in this class adds nodes or edges to it —
  /// presumably the view populates it; confirm against the widget code.
  final Graph graph = Graph();

  /// Layout configuration; [initializeSkillTree] sets its orientation to
  /// left-to-right and its bend point shape to a curve.
  SugiyamaConfiguration builder = SugiyamaConfiguration();

  /// Category whose `jsonFileName` asset is loaded by [initializeSkillTree].
  SkillTreeCategory currentSkillTreeType = SkillTreeCategory.general;

  /// Creates a view model backed by the given benchmark and leaderboard
  /// services.
  SkillTreeViewModel(this.benchmarkService, this.leaderboardService);

  /// Loads the skill tree for [currentSkillTreeType] from the bundled assets.
  ///
  /// Clears the current nodes, edges and selection via [resetState], reads
  /// `assets/<jsonFileName>`, and builds the node and edge lists from the
  /// `nodes` and `edges` arrays of the decoded JSON object.
  ///
  /// The parsed lists are published only after the whole file has been
  /// decoded. A malformed entry, or a second call overlapping the `await`,
  /// therefore cannot leave a half-populated or mixed tree behind.
  ///
  /// Failures (missing asset, invalid JSON, unexpected shape) are logged and
  /// swallowed, leaving the view model in the reset (empty) state. Listeners
  /// are notified in both cases, because the state was reset either way.
  Future<void> initializeSkillTree() async {
    resetState();

    try {
      final fileName = currentSkillTreeType.jsonFileName;

      // Read and decode the JSON description of the tree from the assets.
      final jsonContent = await rootBundle.loadString('assets/$fileName');
      final Map<String, dynamic> decodedJson = jsonDecode(jsonContent);

      // Parse into locals first; nothing is published if any entry is invalid.
      final nodes = <SkillTreeNode>[];
      for (final nodeMap in decodedJson['nodes'] as List<dynamic>) {
        nodes.add(SkillTreeNode.fromJson(nodeMap));
      }

      final edges = <SkillTreeEdge>[];
      for (final edgeMap in decodedJson['edges'] as List<dynamic>) {
        edges.add(SkillTreeEdge.fromJson(edgeMap));
      }

      _skillTreeNodes = nodes;
      _skillTreeEdges = edges;

      builder.orientation = SugiyamaConfiguration.ORIENTATION_LEFT_RIGHT;
      builder.bendPointShape = CurvedBendPointShape(curveLength: 20);
    } catch (e, stackTrace) {
      // Best-effort load: keep the app running with an empty tree.
      debugPrint('Failed to initialize skill tree: $e\n$stackTrace');
    }

    notifyListeners();
  }

  /// Drops the loaded tree and the current selection.
  ///
  /// The node and edge lists are replaced with fresh empty lists (rather than
  /// cleared in place) and the selected node and its hierarchy become `null`.
  /// Listeners are not notified by this method.
  void resetState() {
    // Selection first, then the tree data it referred to.
    _selectedNodeHierarchy = null;
    _selectedNode = null;
    _skillTreeEdges = <SkillTreeEdge>[];
    _skillTreeNodes = <SkillTreeNode>[];
  }

  /// Toggles the selection of the node identified by [nodeId].
  ///
  /// Does nothing while a benchmark is running. Selecting the node that is
  /// already selected clears both the selection and its hierarchy. Selecting
  /// any other node makes it the current selection and rebuilds the hierarchy
  /// for the active [selectedOption].
  ///
  /// An unknown [nodeId] is logged and ignored instead of throwing, and the
  /// current selection is left as it was.
  void toggleNodeSelection(String nodeId) {
    if (isBenchmarkRunning) return;

    if (_selectedNode?.id == nodeId) {
      // Unselect the node if it's already selected.
      _selectedNode = null;
      _selectedNodeHierarchy = null;
      notifyListeners();
      return;
    }

    final node = _skillTreeNodes.firstWhereOrNull((n) => n.id == nodeId);
    if (node == null) {
      debugPrint('Cannot select node: no node with ID $nodeId');
      return;
    }

    _selectedNode = node;
    // Rebuilds the hierarchy and notifies listeners itself, so no additional
    // notification is sent from here.
    updateSelectedNodeHierarchyBasedOnOption(_selectedOption);
  }

  /// Stores [selectedOption] as the active test option and recomputes the
  /// selected-node hierarchy to match it.
  ///
  /// * [TestOption.runSingleTest]: the hierarchy is just the selected node, or
  ///   an empty list when nothing is selected.
  /// * [TestOption.runTestSuiteIncludingSelectedNodeAndAncestors]: built by
  ///   [populateSelectedNodeHierarchy] starting at the selected node.
  /// * [TestOption.runAllTestsInCategory]: built by
  ///   [_getAllNodesInDepthFirstOrderEnsuringParents].
  ///
  /// For the last two options the hierarchy is left untouched when no node is
  /// selected. Listeners are always notified at the end.
  void updateSelectedNodeHierarchyBasedOnOption(TestOption selectedOption) {
    _selectedOption = selectedOption;

    // Local copy so the null check promotes the type.
    final current = _selectedNode;

    switch (selectedOption) {
      case TestOption.runSingleTest:
        _selectedNodeHierarchy = <SkillTreeNode>[
          if (current != null) current,
        ];
        break;

      case TestOption.runTestSuiteIncludingSelectedNodeAndAncestors:
        if (current == null) break;
        populateSelectedNodeHierarchy(current.id);
        break;

      case TestOption.runAllTestsInCategory:
        if (current == null) break;
        _getAllNodesInDepthFirstOrderEnsuringParents();
        break;
    }

    notifyListeners();
  }

  /// Rebuilds [_selectedNodeHierarchy] with the nodes reachable from the tree
  /// root, in a depth-first order in which no node precedes any of its
  /// parents.
  ///
  /// The root is the node labelled "WriteFile". If the loaded tree has no such
  /// node, every node without an incoming edge is used as a root instead of
  /// throwing.
  ///
  /// A node is emitted only once all of its parents have been emitted. When a
  /// node is popped too early it is dropped from the stack for now; it is
  /// pushed again by each parent that is emitted later, so the last parent to
  /// be emitted releases it. Nodes with a parent that is never emitted (for
  /// example a parent unreachable from the root, or a cycle) are left out.
  void _getAllNodesInDepthFirstOrderEnsuringParents() {
    final ordered = <SkillTreeNode>[];
    final emitted = <String>{};
    final stack = Stack<SkillTreeNode>();

    // Identify the root node by its label, falling back to parentless nodes.
    final labelledRoot =
        _skillTreeNodes.firstWhereOrNull((node) => node.label == "WriteFile");
    final roots = labelledRoot != null
        ? <SkillTreeNode>[labelledRoot]
        : _skillTreeNodes
            .where((node) => _getParentsOfNodeUsingEdges(node.id).isEmpty)
            .toList();

    // Push in reverse so the first root ends up on top of the stack.
    for (final root in roots.reversed) {
      stack.push(root);
    }

    while (stack.isNotEmpty) {
      final node = stack.peek();
      stack.pop();

      // A node can be pushed once per parent; emit it only the first time.
      if (emitted.contains(node.id)) continue;

      final parents = _getParentsOfNodeUsingEdges(node.id);
      final allParentsEmitted =
          parents.every((parent) => emitted.contains(parent.id));
      // Too early: a parent that is emitted later will push this node again.
      if (!allParentsEmitted) continue;

      emitted.add(node.id);
      ordered.add(node);

      for (final child in _getChildrenOfNodeUsingEdges(node.id)) {
        if (!emitted.contains(child.id)) {
          stack.push(child);
        }
      }
    }

    _selectedNodeHierarchy = ordered;
  }

  /// Returns the parents of the node [nodeId], i.e. the source node of every
  /// edge whose `to` is [nodeId], in edge order.
  ///
  /// Edges whose source id matches no loaded node are skipped rather than
  /// raising a `StateError`, matching how [recursivePopulateHierarchy] treats
  /// unknown node ids.
  List<SkillTreeNode> _getParentsOfNodeUsingEdges(String nodeId) {
    return _skillTreeEdges
        .where((edge) => edge.to == nodeId)
        .map((edge) =>
            _skillTreeNodes.firstWhereOrNull((node) => node.id == edge.from))
        .whereType<SkillTreeNode>()
        .toList();
  }

  /// Returns the children of the node [nodeId], i.e. the target node of every
  /// edge whose `from` is [nodeId], in edge order.
  ///
  /// Edges whose target id matches no loaded node are skipped rather than
  /// raising a `StateError`, matching how [recursivePopulateHierarchy] treats
  /// unknown node ids.
  List<SkillTreeNode> _getChildrenOfNodeUsingEdges(String nodeId) {
    return _skillTreeEdges
        .where((edge) => edge.from == nodeId)
        .map((edge) =>
            _skillTreeNodes.firstWhereOrNull((node) => node.id == edge.to))
        .whereType<SkillTreeNode>()
        .toList();
  }

  // TODO: Do we want to continue testing other branches of tree if one branch side fails benchmarking?
  /// Rebuilds the selected-node hierarchy as the node [startNodeId] preceded
  /// by all of its ancestors, then notifies listeners.
  ///
  /// The hierarchy is reset to an empty list first, so an unknown
  /// [startNodeId] results in an empty hierarchy.
  void populateSelectedNodeHierarchy(String startNodeId) {
    // Start from a clean list; the recursion appends to it.
    _selectedNodeHierarchy = [];

    // The fresh set tracks ids already added so shared ancestors appear once.
    recursivePopulateHierarchy(startNodeId, <String>{});

    notifyListeners();
  }

  /// Appends the node [nodeId] to the selected-node hierarchy, after first
  /// appending every ancestor that is not yet in [addedNodes].
  ///
  /// Parents are visited in edge order and recursed into before the node
  /// itself is appended, so ancestors always come earlier in the list.
  /// Unknown ids and ids already present in [addedNodes] are ignored.
  ///
  /// Expects the hierarchy list to have been initialised beforehand, as
  /// [populateSelectedNodeHierarchy] does; appending throws otherwise.
  void recursivePopulateHierarchy(String nodeId, Set<String> addedNodes) {
    final node = _skillTreeNodes.firstWhereOrNull((n) => n.id == nodeId);

    // Nothing to do for ids that do not belong to the loaded tree.
    if (node == null) return;

    // `Set.add` returns false when the id was already recorded.
    if (!addedNodes.add(node.id)) return;

    // Walk up through every edge that points at this node.
    for (final edge in _skillTreeEdges) {
      if (edge.to == node.id) {
        recursivePopulateHierarchy(edge.from, addedNodes);
      }
    }

    // All ancestors are in place; the node itself goes last.
    _selectedNodeHierarchy!.add(node);
  }

  /// Returns the loaded node whose id equals [nodeId], or `null` when there
  /// is none.
  ///
  /// A miss is logged. The lookup uses `firstWhereOrNull`, so a missing node
  /// no longer relies on catching the `StateError` thrown by `firstWhere`.
  SkillTreeNode? getNodeById(String nodeId) {
    final node = _skillTreeNodes.firstWhereOrNull((n) => n.id == nodeId);
    if (node == null) {
      debugPrint('Node with ID $nodeId not found');
    }
    return node;
  }

  // TODO: Move to task queue view model
  /// Runs the benchmark for every node of the selected-node hierarchy, in
  /// order, stopping at the first node whose evaluation does not succeed.
  ///
  /// For each node a benchmark task is created, its steps are executed until
  /// a step reports `isLast`, and the evaluation is triggered. The resulting
  /// [BenchmarkRun] is appended to [currentBenchmarkRuns] and the node's entry
  /// in [benchmarkStatusMap] is updated. When the loop ends without an
  /// exception, the collected tasks are handed to [taskViewModel] as one
  /// [TestSuite].
  ///
  /// The call is ignored when a benchmark is already running or when no
  /// hierarchy has been selected. Previously the latter threw after the
  /// running flag had been set, leaving the view model locked.
  ///
  /// Errors raised while benchmarking are logged and swallowed. The node that
  /// was in progress is marked as failed, and [isBenchmarkRunning] is always
  /// reset before the returned future completes.
  Future<void> runBenchmark(
      ChatViewModel chatViewModel, TaskViewModel taskViewModel) async {
    // A second concurrent run would clobber the shared status map and runs.
    if (isBenchmarkRunning) return;

    final selectedHierarchy = _selectedNodeHierarchy;
    if (selectedHierarchy == null) {
      debugPrint('Cannot run benchmark: no node hierarchy is selected.');
      return;
    }
    // Work on a snapshot so later changes to the list cannot affect this run.
    final hierarchy = List<SkillTreeNode>.of(selectedHierarchy);

    // Start from a clean slate for the new run.
    benchmarkStatusMap.clear();
    currentBenchmarkRuns = [];

    // Test suite identified by the current timestamp.
    final testSuite =
        TestSuite(timestamp: DateTime.now().toIso8601String(), tests: []);

    isBenchmarkRunning = true;
    notifyListeners();

    for (final node in hierarchy) {
      benchmarkStatusMap[node] = BenchmarkTaskStatus.notStarted;
    }

    // Node whose benchmark is currently executing; used to flag it on error.
    SkillTreeNode? activeNode;

    try {
      for (final node in hierarchy) {
        activeNode = node;
        benchmarkStatusMap[node] = BenchmarkTaskStatus.inProgress;
        notifyListeners();

        // Create the benchmark task for this node.
        final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
            input: node.data.task, evalId: node.data.evalId);
        final createdTask = await benchmarkService
            .createBenchmarkTask(benchmarkTaskRequestBody);
        final task =
            Task(id: createdTask['task_id'], title: createdTask['input']);

        // Point the chat view at the task that is being executed.
        chatViewModel.setCurrentTaskId(task.id);

        // The first step carries the task input; later steps carry none.
        Map<String, dynamic> stepResponse =
            await benchmarkService.executeBenchmarkStep(
                task.id, BenchmarkStepRequestBody(input: node.data.task));
        Step step = Step.fromMap(stepResponse);
        // NOTE(review): not awaited, as before — if this returns a Future its
        // errors are not observed here; confirm that is intended.
        chatViewModel.fetchChatsForTask();

        while (!step.isLast) {
          stepResponse = await benchmarkService.executeBenchmarkStep(
              task.id, BenchmarkStepRequestBody(input: null));
          step = Step.fromMap(stepResponse);
          chatViewModel.fetchChatsForTask();
        }

        // Evaluate the finished task and record the outcome.
        final evaluationResponse =
            await benchmarkService.triggerEvaluation(task.id);
        final benchmarkRun = BenchmarkRun.fromJson(evaluationResponse);
        currentBenchmarkRuns.add(benchmarkRun);

        final successStatus = benchmarkRun.metrics.success;
        benchmarkStatusMap[node] = successStatus
            ? BenchmarkTaskStatus.success
            : BenchmarkTaskStatus.failure;
        // The node has its final status; a later error must not override it.
        activeNode = null;

        await Future.delayed(const Duration(seconds: 1));
        notifyListeners();

        testSuite.tests.add(task);

        if (!successStatus) {
          debugPrint(
              "Benchmark for node ${node.id} failed. Stopping all benchmarks.");
          break;
        }
      }

      taskViewModel.addTestSuite(testSuite);
    } catch (e, stackTrace) {
      debugPrint("Error while running benchmark: $e\n$stackTrace");

      // Do not leave the interrupted node displayed as in progress.
      final interruptedNode = activeNode;
      if (interruptedNode != null) {
        benchmarkStatusMap[interruptedNode] = BenchmarkTaskStatus.failure;
      }
    } finally {
      // Always unlock the view model, whatever happened above.
      isBenchmarkRunning = false;
      notifyListeners();
    }
  }

  // TODO: Move to task queue view model
  /// Submits every run in [currentBenchmarkRuns] to the leaderboard.
  ///
  /// Each run is stamped with [teamName], [repoUrl], [agentGitCommitSha] and
  /// one freshly generated v4 UUID shared by all runs of this submission. The
  /// runs are then submitted one after another, and the list is cleared once
  /// all of them have been submitted.
  ///
  /// There is no error handling here: if a submission throws, the exception
  /// propagates to the caller and the list is not cleared.
  Future<void> submitToLeaderboard(
      String teamName, String repoUrl, String agentGitCommitSha) async {
    // One run id groups all reports of this submission.
    final String sharedRunId = const Uuid().v4();

    for (final benchmarkRun in currentBenchmarkRuns) {
      benchmarkRun.repositoryInfo
        ..teamName = teamName
        ..repoUrl = repoUrl
        ..agentGitCommitSha = agentGitCommitSha;
      benchmarkRun.runDetails.runId = sharedRunId;

      await leaderboardService.submitReport(benchmarkRun);
      print('Completed submission to leaderboard!');
    }

    // Everything was submitted; drop the runs so they are not sent twice.
    currentBenchmarkRuns.clear();
  }
}