aboutsummaryrefslogtreecommitdiff
path: root/benchmark/agbenchmark/utils/dependencies/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'benchmark/agbenchmark/utils/dependencies/main.py')
-rw-r--r--benchmark/agbenchmark/utils/dependencies/main.py253
1 files changed, 253 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/utils/dependencies/main.py b/benchmark/agbenchmark/utils/dependencies/main.py
new file mode 100644
index 000000000..7dab3b51b
--- /dev/null
+++ b/benchmark/agbenchmark/utils/dependencies/main.py
@@ -0,0 +1,253 @@
+"""
+A module to manage dependencies between pytest tests.
+
+This module provides the methods implementing the main logic. These are used in the pytest hooks that are in
+__init__.py.
+"""
+
+import collections
+import json
+import os
+from typing import Any, Generator
+
+import colorama
+import networkx
+from _pytest.nodes import Item
+
+from .constants import MARKER_KWARG_DEPENDENCIES, MARKER_NAME
+from .graphs import graph_interactive_network
+from .util import clean_nodeid, get_absolute_nodeid, get_markers, get_name
+
+
+class TestResult(object):
+ """Keeps track of the results of a single test."""
+
+ STEPS = ["setup", "call", "teardown"]
+ GOOD_OUTCOMES = ["passed"]
+
+ def __init__(self, nodeid: str) -> None:
+ """Create a new instance for a test with a given node id."""
+ self.nodeid = nodeid
+ self.results: dict[str, Any] = {}
+
+ def register_result(self, result: Any) -> None:
+ """Register a result of this test."""
+ if result.when not in self.STEPS:
+ raise ValueError(
+ f"Received result for unknown step {result.when} of test {self.nodeid}"
+ )
+ if result.when in self.results:
+ raise AttributeError(
+ f"Received multiple results for step {result.when} of test {self.nodeid}"
+ )
+ self.results[result.when] = result.outcome
+
+ @property
+ def success(self) -> bool:
+ """Whether the entire test was successful."""
+ return all(
+ self.results.get(step, None) in self.GOOD_OUTCOMES for step in self.STEPS
+ )
+
+
+class TestDependencies(object):
+ """Information about the resolved dependencies of a single test."""
+
+ def __init__(self, item: Item, manager: "DependencyManager") -> None:
+ """Create a new instance for a given test."""
+ self.nodeid = clean_nodeid(item.nodeid)
+ self.dependencies = set()
+ self.unresolved = set()
+
+ markers = get_markers(item, MARKER_NAME)
+ dependencies = [
+ dep
+ for marker in markers
+ for dep in marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])
+ ]
+ for dependency in dependencies:
+ # If the name is not known, try to make it absolute (ie file::[class::]method)
+ if dependency not in manager.name_to_nodeids:
+ absolute_dependency = get_absolute_nodeid(dependency, self.nodeid)
+ if absolute_dependency in manager.name_to_nodeids:
+ dependency = absolute_dependency
+
+ # Add all items matching the name
+ if dependency in manager.name_to_nodeids:
+ for nodeid in manager.name_to_nodeids[dependency]:
+ self.dependencies.add(nodeid)
+ else:
+ self.unresolved.add(dependency)
+
+
+class DependencyManager(object):
+ """Keep track of tests, their names and their dependencies."""
+
+ def __init__(self) -> None:
+ """Create a new DependencyManager."""
+ self.options: dict[str, Any] = {}
+ self._items: list[Item] | None = None
+ self._name_to_nodeids: Any = None
+ self._nodeid_to_item: Any = None
+ self._results: Any = None
+
+ @property
+ def items(self) -> list[Item]:
+ """The collected tests that are managed by this instance."""
+ if self._items is None:
+ raise AttributeError("The items attribute has not been set yet")
+ return self._items
+
+ @items.setter
+ def items(self, items: list[Item]) -> None:
+ if self._items is not None:
+ raise AttributeError("The items attribute has already been set")
+ self._items = items
+
+ self._name_to_nodeids = collections.defaultdict(list)
+ self._nodeid_to_item = {}
+ self._results = {}
+ self._dependencies = {}
+
+ for item in items:
+ nodeid = clean_nodeid(item.nodeid)
+ # Add the mapping from nodeid to the test item
+ self._nodeid_to_item[nodeid] = item
+ # Add the mappings from all names to the node id
+ name = get_name(item)
+ self._name_to_nodeids[name].append(nodeid)
+ # Create the object that will contain the results of this test
+ self._results[nodeid] = TestResult(clean_nodeid(item.nodeid))
+
+ # Don't allow using unknown keys on the name_to_nodeids mapping
+ self._name_to_nodeids.default_factory = None
+
+ for item in items:
+ nodeid = clean_nodeid(item.nodeid)
+ # Process the dependencies of this test
+ # This uses the mappings created in the previous loop, and can thus not be merged into that loop
+ self._dependencies[nodeid] = TestDependencies(item, self)
+
+ @property
+ def name_to_nodeids(self) -> dict[str, list[str]]:
+ """A mapping from names to matching node id(s)."""
+ assert self.items is not None
+ return self._name_to_nodeids
+
+ @property
+ def nodeid_to_item(self) -> dict[str, Item]:
+ """A mapping from node ids to test items."""
+ assert self.items is not None
+ return self._nodeid_to_item
+
+ @property
+ def results(self) -> dict[str, TestResult]:
+ """The results of the tests."""
+ assert self.items is not None
+ return self._results
+
+ @property
+ def dependencies(self) -> dict[str, TestDependencies]:
+ """The dependencies of the tests."""
+ assert self.items is not None
+ return self._dependencies
+
+ def print_name_map(self, verbose: bool = False) -> None:
+ """Print a human-readable version of the name -> test mapping."""
+ print("Available dependency names:")
+ for name, nodeids in sorted(self.name_to_nodeids.items(), key=lambda x: x[0]):
+ if len(nodeids) == 1:
+ if name == nodeids[0]:
+ # This is just the base name, only print this when verbose
+ if verbose:
+ print(f" {name}")
+ else:
+ # Name refers to a single node id, so use the short format
+ print(f" {name} -> {nodeids[0]}")
+ else:
+ # Name refers to multiple node ids, so use the long format
+ print(f" {name} ->")
+ for nodeid in sorted(nodeids):
+ print(f" {nodeid}")
+
+ def print_processed_dependencies(self, colors: bool = False) -> None:
+ """Print a human-readable list of the processed dependencies."""
+ missing = "MISSING"
+ if colors:
+ missing = f"{colorama.Fore.RED}{missing}{colorama.Fore.RESET}"
+ colorama.init()
+ try:
+ print("Dependencies:")
+ for nodeid, info in sorted(self.dependencies.items(), key=lambda x: x[0]):
+ descriptions = []
+ for dependency in info.dependencies:
+ descriptions.append(dependency)
+ for dependency in info.unresolved:
+ descriptions.append(f"{dependency} ({missing})")
+ if descriptions:
+ print(f" {nodeid} depends on")
+ for description in sorted(descriptions):
+ print(f" {description}")
+ finally:
+ if colors:
+ colorama.deinit()
+
+ @property
+ def sorted_items(self) -> Generator:
+ """Get a sorted list of tests where all tests are sorted after their dependencies."""
+ # Build a directed graph for sorting
+ build_skill_tree = os.getenv("BUILD_SKILL_TREE")
+ BUILD_SKILL_TREE = (
+ build_skill_tree.lower() == "true" if build_skill_tree else False
+ )
+ dag = networkx.DiGraph()
+
+ # Insert all items as nodes, to prevent items that have no dependencies and are not dependencies themselves from
+ # being lost
+ dag.add_nodes_from(self.items)
+
+ # Insert edges for all the dependencies
+ for item in self.items:
+ nodeid = clean_nodeid(item.nodeid)
+ for dependency in self.dependencies[nodeid].dependencies:
+ dag.add_edge(self.nodeid_to_item[dependency], item)
+
+ labels = {}
+ for item in self.items:
+ try:
+ with open(item.cls.CHALLENGE_LOCATION) as f:
+ data = json.load(f)
+ except:
+ data = {}
+
+ node_name = get_name(item)
+ data["name"] = node_name
+ labels[item] = data
+
+ # only build the tree if it's specified in the env and is a whole run
+ if BUILD_SKILL_TREE:
+ # graph_spring_layout(dag, labels)
+ graph_interactive_network(dag, labels, html_graph_path="")
+
+ # Sort based on the dependencies
+ return networkx.topological_sort(dag)
+
+ def register_result(self, item: Item, result: Any) -> None:
+ """Register a result of a test."""
+ nodeid = clean_nodeid(item.nodeid)
+ self.results[nodeid].register_result(result)
+
+ def get_failed(self, item: Item) -> Any:
+ """Get a list of unfulfilled dependencies for a test."""
+ nodeid = clean_nodeid(item.nodeid)
+ failed = []
+ for dependency in self.dependencies[nodeid].dependencies:
+ result = self.results[dependency]
+ if not result.success:
+ failed.append(dependency)
+ return failed
+
+ def get_missing(self, item: Item) -> Any:
+ """Get a list of missing dependencies for a test."""
+ nodeid = clean_nodeid(item.nodeid)
+ return self.dependencies[nodeid].unresolved