diff options
Diffstat (limited to 'benchmark/agbenchmark/utils/dependencies/graphs.py')
-rw-r--r-- | benchmark/agbenchmark/utils/dependencies/graphs.py | 445 |
1 files changed, 445 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/utils/dependencies/graphs.py b/benchmark/agbenchmark/utils/dependencies/graphs.py new file mode 100644 index 000000000..47d3d5c09 --- /dev/null +++ b/benchmark/agbenchmark/utils/dependencies/graphs.py @@ -0,0 +1,445 @@ +import json +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np +from pyvis.network import Network + +from agbenchmark.generate_test import DATA_CATEGORY +from agbenchmark.utils.utils import write_pretty_json + +logger = logging.getLogger(__name__) + + +def bezier_curve( + src: np.ndarray, ctrl: List[float], dst: np.ndarray +) -> List[np.ndarray]: + """ + Generate Bézier curve points. + + Args: + - src (np.ndarray): The source point. + - ctrl (List[float]): The control point. + - dst (np.ndarray): The destination point. + + Returns: + - List[np.ndarray]: The Bézier curve points. + """ + curve = [] + for t in np.linspace(0, 1, num=100): + curve_point = ( + np.outer((1 - t) ** 2, src) + + 2 * np.outer((1 - t) * t, ctrl) + + np.outer(t**2, dst) + ) + curve.append(curve_point[0]) + return curve + + +def curved_edges( + G: nx.Graph, pos: Dict[Any, Tuple[float, float]], dist: float = 0.2 +) -> None: + """ + Draw curved edges for nodes on the same level. + + Args: + - G (Any): The graph object. + - pos (Dict[Any, Tuple[float, float]]): Dictionary with node positions. + - dist (float, optional): Distance for curvature. Defaults to 0.2. + + Returns: + - None + """ + ax = plt.gca() + for u, v, data in G.edges(data=True): + src = np.array(pos[u]) + dst = np.array(pos[v]) + + same_level = abs(src[1] - dst[1]) < 0.01 + + if same_level: + control = [(src[0] + dst[0]) / 2, src[1] + dist] + curve = bezier_curve(src, control, dst) + arrow = patches.FancyArrowPatch( + posA=curve[0], # type: ignore + posB=curve[-1], # type: ignore + connectionstyle=f"arc3,rad=0.2", + color="gray", + arrowstyle="-|>", + mutation_scale=15.0, + lw=1, + shrinkA=10, + shrinkB=10, + ) + ax.add_patch(arrow) + else: + ax.annotate( + "", + xy=dst, + xytext=src, + arrowprops=dict( + arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10 + ), + ) + + +def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]: + """Compute positions as a tree layout centered on the root with alternating vertical shifts.""" + bfs_tree = nx.bfs_tree(graph, source=root_node) + levels = { + node: depth + for node, depth in nx.single_source_shortest_path_length( + bfs_tree, root_node + ).items() + } + + pos = {} + max_depth = max(levels.values()) + level_positions = {i: 0 for i in range(max_depth + 1)} # type: ignore + + # Count the number of nodes per level to compute the width + level_count: Any = {} + for node, level in levels.items(): + level_count[level] = level_count.get(level, 0) + 1 + + vertical_offset = ( + 0.07 # The amount of vertical shift per node within the same level + ) + + # Assign positions + for node, level in sorted(levels.items(), key=lambda x: x[1]): + total_nodes_in_level = level_count[level] + horizontal_spacing = 1.0 / (total_nodes_in_level + 1) + pos_x = ( + 0.5 + - (total_nodes_in_level - 1) * horizontal_spacing / 2 + + level_positions[level] * horizontal_spacing + ) + + # Alternately shift nodes up and down within the same level + pos_y = ( + -level + + (level_positions[level] % 2) * vertical_offset + - ((level_positions[level] + 1) % 2) * vertical_offset + ) + pos[node] = (pos_x, pos_y) + + level_positions[level] += 1 + + return pos + + +def graph_spring_layout( + dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True +) -> None: + num_nodes = len(dag.nodes()) + # Setting up the figure and axis + fig, ax = plt.subplots() + ax.axis("off") # Turn off the axis + + base = 3.0 + + if num_nodes > 10: + base /= 1 + math.log(num_nodes) + font_size = base * 10 + + font_size = max(10, base * 10) + node_size = max(300, base * 1000) + + if tree: + root_node = [node for node, degree in dag.in_degree() if degree == 0][0] + pos = tree_layout(dag, root_node) + else: + # Adjust k for the spring layout based on node count + k_value = 3 / math.sqrt(num_nodes) + + pos = nx.spring_layout(dag, k=k_value, iterations=50) + + # Draw nodes and labels + nx.draw_networkx_nodes(dag, pos, node_color="skyblue", node_size=int(node_size)) + nx.draw_networkx_labels(dag, pos, labels=labels, font_size=int(font_size)) + + # Draw curved edges + curved_edges(dag, pos) # type: ignore + + plt.tight_layout() + plt.show() + + +def rgb_to_hex(rgb: Tuple[float, float, float]) -> str: + return "#{:02x}{:02x}{:02x}".format( + int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255) + ) + + +def get_category_colors(categories: Dict[Any, str]) -> Dict[str, str]: + unique_categories = set(categories.values()) + colormap = plt.cm.get_cmap("tab10", len(unique_categories)) # type: ignore + return { + category: rgb_to_hex(colormap(i)[:3]) + for i, category in enumerate(unique_categories) + } + + +def graph_interactive_network( + dag: nx.DiGraph, + labels: Dict[Any, Dict[str, Any]], + html_graph_path: str = "", +) -> None: + nt = Network(notebook=True, width="100%", height="800px", directed=True) + + category_colors = get_category_colors(DATA_CATEGORY) + + # Add nodes and edges to the pyvis network + for node, json_data in labels.items(): + label = json_data.get("name", "") + # remove the first 4 letters of label + label_without_test = label[4:] + node_id_str = node.nodeid + + # Get the category for this label + category = DATA_CATEGORY.get( + label, "unknown" + ) # Default to 'unknown' if label not found + + # Get the color for this category + color = category_colors.get(category, "grey") + + nt.add_node( + node_id_str, + label=label_without_test, + color=color, + data=json_data, + ) + + # Add edges to the pyvis network + for edge in dag.edges(): + source_id_str = edge[0].nodeid + target_id_str = edge[1].nodeid + edge_id_str = ( + f"{source_id_str}_to_{target_id_str}" # Construct a unique edge id + ) + if not (source_id_str in nt.get_nodes() and target_id_str in nt.get_nodes()): + logger.warning( + f"Skipping edge {source_id_str} -> {target_id_str} due to missing nodes" + ) + continue + nt.add_edge(source_id_str, target_id_str, id=edge_id_str) + + # Configure physics for hierarchical layout + hierarchical_options = { + "enabled": True, + "levelSeparation": 200, # Increased vertical spacing between levels + "nodeSpacing": 250, # Increased spacing between nodes on the same level + "treeSpacing": 250, # Increased spacing between different trees (for forest) + "blockShifting": True, + "edgeMinimization": True, + "parentCentralization": True, + "direction": "UD", + "sortMethod": "directed", + } + + physics_options = { + "stabilization": { + "enabled": True, + "iterations": 1000, # Default is often around 100 + }, + "hierarchicalRepulsion": { + "centralGravity": 0.0, + "springLength": 200, # Increased edge length + "springConstant": 0.01, + "nodeDistance": 250, # Increased minimum distance between nodes + "damping": 0.09, + }, + "solver": "hierarchicalRepulsion", + "timestep": 0.5, + } + + nt.options = { + "nodes": { + "font": { + "size": 20, # Increased font size for labels + "color": "black", # Set a readable font color + }, + "shapeProperties": {"useBorderWithImage": True}, + }, + "edges": { + "length": 250, # Increased edge length + }, + "physics": physics_options, + "layout": {"hierarchical": hierarchical_options}, + } + + # Serialize the graph to JSON and save in appropriate locations + graph_data = {"nodes": nt.nodes, "edges": nt.edges} + logger.debug(f"Generated graph data:\n{json.dumps(graph_data, indent=4)}") + + # FIXME: use more reliable method to find the right location for these files. + # This will fail in all cases except if run from the root of our repo. + home_path = Path.cwd() + write_pretty_json(graph_data, home_path / "frontend" / "public" / "graph.json") + + flutter_app_path = home_path.parent / "frontend" / "assets" + + # Optionally, save to a file + # Sync with the flutter UI + # this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false + write_pretty_json(graph_data, flutter_app_path / "tree_structure.json") + validate_skill_tree(graph_data, "") + + # Extract node IDs with category "coding" + + coding_tree = extract_subgraph_based_on_category(graph_data.copy(), "coding") + validate_skill_tree(coding_tree, "coding") + write_pretty_json( + coding_tree, + flutter_app_path / "coding_tree_structure.json", + ) + + data_tree = extract_subgraph_based_on_category(graph_data.copy(), "data") + # validate_skill_tree(data_tree, "data") + write_pretty_json( + data_tree, + flutter_app_path / "data_tree_structure.json", + ) + + general_tree = extract_subgraph_based_on_category(graph_data.copy(), "general") + validate_skill_tree(general_tree, "general") + write_pretty_json( + general_tree, + flutter_app_path / "general_tree_structure.json", + ) + + scrape_synthesize_tree = extract_subgraph_based_on_category( + graph_data.copy(), "scrape_synthesize" + ) + validate_skill_tree(scrape_synthesize_tree, "scrape_synthesize") + write_pretty_json( + scrape_synthesize_tree, + flutter_app_path / "scrape_synthesize_tree_structure.json", + ) + + if html_graph_path: + file_path = str(Path(html_graph_path).resolve()) + + nt.write_html(file_path) + + +def extract_subgraph_based_on_category(graph, category): + """ + Extracts a subgraph that includes all nodes and edges required to reach all nodes with a specified category. + + :param graph: The original graph. + :param category: The target category. + :return: Subgraph with nodes and edges required to reach the nodes with the given category. + """ + + subgraph = {"nodes": [], "edges": []} + visited = set() + + def reverse_dfs(node_id): + if node_id in visited: + return + visited.add(node_id) + + node_data = next(node for node in graph["nodes"] if node["id"] == node_id) + + # Add the node to the subgraph if it's not already present. + if node_data not in subgraph["nodes"]: + subgraph["nodes"].append(node_data) + + for edge in graph["edges"]: + if edge["to"] == node_id: + if edge not in subgraph["edges"]: + subgraph["edges"].append(edge) + reverse_dfs(edge["from"]) + + # Identify nodes with the target category and initiate reverse DFS from them. + nodes_with_target_category = [ + node["id"] for node in graph["nodes"] if category in node["data"]["category"] + ] + + for node_id in nodes_with_target_category: + reverse_dfs(node_id) + + return subgraph + + +def is_circular(graph): + def dfs(node, visited, stack, parent_map): + visited.add(node) + stack.add(node) + for edge in graph["edges"]: + if edge["from"] == node: + if edge["to"] in stack: + # Detected a cycle + cycle_path = [] + current = node + while current != edge["to"]: + cycle_path.append(current) + current = parent_map.get(current) + cycle_path.append(edge["to"]) + cycle_path.append(node) + return cycle_path[::-1] + elif edge["to"] not in visited: + parent_map[edge["to"]] = node + cycle_path = dfs(edge["to"], visited, stack, parent_map) + if cycle_path: + return cycle_path + stack.remove(node) + return None + + visited = set() + stack = set() + parent_map = {} + for node in graph["nodes"]: + node_id = node["id"] + if node_id not in visited: + cycle_path = dfs(node_id, visited, stack, parent_map) + if cycle_path: + return cycle_path + return None + + +def get_roots(graph): + """ + Return the roots of a graph. Roots are nodes with no incoming edges. + """ + # Create a set of all node IDs + all_nodes = {node["id"] for node in graph["nodes"]} + + # Create a set of nodes with incoming edges + nodes_with_incoming_edges = {edge["to"] for edge in graph["edges"]} + + # Roots are nodes that have no incoming edges + roots = all_nodes - nodes_with_incoming_edges + + return list(roots) + + +def validate_skill_tree(graph, skill_tree_name): + """ + Validate if a given graph represents a valid skill tree and raise appropriate exceptions if not. + + :param graph: A dictionary representing the graph with 'nodes' and 'edges'. + :raises: ValueError with a description of the invalidity. + """ + # Check for circularity + cycle_path = is_circular(graph) + if cycle_path: + cycle_str = " -> ".join(cycle_path) + raise ValueError( + f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}." + ) + + # Check for multiple roots + roots = get_roots(graph) + if len(roots) > 1: + raise ValueError(f"{skill_tree_name} skill tree has multiple roots: {roots}.") + elif not roots: + raise ValueError(f"{skill_tree_name} skill tree has no roots.") |