# Source code for golem.core.dag.graph

from abc import ABC, abstractmethod
from enum import Enum
from os import PathLike
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, TypeVar, Union

import networkx as nx

from golem.core.dag.graph_node import GraphNode
from golem.visualisation.graph_viz import GraphVisualizer, NodeColorType

# Generic type variable for node types; any type used for it must be a GraphNode.
# It is invariant (TypeVar's default), i.e. neither covariant nor contravariant.
NodeType = TypeVar('NodeType', bound=GraphNode)


class ReconnectType(Enum):
    """Allowed kinds of node removal in a Graph.

    Mutations use these values to choose how a removed node's predecessors
    are (or are not) reconnected.
    """

    # Leave predecessors disconnected.
    none = 'none'
    # Reconnect a predecessor, but only when it is the single one.
    single = 'single'
    # Connect every predecessor to every successor.
    all = 'all'


class Graph(ABC):
    """Defines abstract graph interface that's required by graph optimisation process.

    Concrete graphs must implement node/edge manipulation, equality and the
    ``nodes``/``depth`` properties. Lookups, visualisation and textual
    descriptions are provided here on top of those primitives.
    """

    @abstractmethod
    def add_node(self, node: GraphNode):
        """Adds new node to the graph together with its parent nodes.

        Args:
            node: graph node to be added
        """
        raise NotImplementedError()

    @abstractmethod
    def update_node(self, old_node: GraphNode, new_node: GraphNode):
        """Replaces ``old_node`` node with ``new_node``.

        Args:
            old_node: node to be replaced
            new_node: node to be placed instead
        """
        raise NotImplementedError()

    @abstractmethod
    def update_subtree(self, old_subtree: GraphNode, new_subtree: GraphNode):
        """Changes ``old_subtree`` subtree to ``new_subtree``.

        Args:
            old_subtree: node and its subtree to be removed
            new_subtree: node and its subtree to be placed instead
        """
        raise NotImplementedError()

    @abstractmethod
    def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single):
        """Removes ``node`` from the graph.

        If ``node`` has only one child, then connects all of the ``node`` parents to it.

        Args:
            node: node of the graph to be deleted
            reconnect: defines how to treat left edges between parents and children
                (see :class:`ReconnectType`)
        """
        raise NotImplementedError()

    @abstractmethod
    def delete_subtree(self, subtree: GraphNode):
        """Deletes given node with all its parents.

        Deletes all edges from removed nodes to remaining graph nodes.

        Args:
            subtree: node to be deleted with all of its parents
                and their connections amongst the remaining graph nodes
        """
        raise NotImplementedError()

    @abstractmethod
    def node_children(self, node: GraphNode) -> Sequence[Optional[GraphNode]]:
        """Returns all children of the ``node``.

        Args:
            node: node to get the children of

        Returns:
            children of the ``node``
        """
        raise NotImplementedError()

    @abstractmethod
    def connect_nodes(self, node_parent: GraphNode, node_child: GraphNode):
        """Adds edge between ``parent`` and ``child``.

        Args:
            node_parent: acts like parent in graph connection relations
            node_child: acts like child in graph connection relations
        """
        raise NotImplementedError()

    @abstractmethod
    def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
                         clean_up_leftovers: bool = False):
        """Removes an edge between two nodes.

        Args:
            node_parent: where the removing edge comes out
            node_child: where the removing edge enters
            clean_up_leftovers: whether to remove the remaining invalid vertices with edges or not
        """
        raise NotImplementedError()

    @abstractmethod
    def get_edges(self) -> Sequence[Tuple[GraphNode, GraphNode]]:
        """Gets all available edges in this graph.

        Returns:
            pairs of parent_node -> child_node
        """
        raise NotImplementedError()

    def get_nodes_by_name(self, name: str) -> List[GraphNode]:
        """Returns list of nodes with the required ``name``.

        Args:
            name: name to filter by

        Returns:
            list: relevant nodes (empty if there are no such nodes), in ``self.nodes`` order
        """
        return [node for node in self.nodes if node.name == name]

    def get_node_by_uid(self, uid: str) -> Optional[GraphNode]:
        """Returns node with the required ``uid``.

        Args:
            uid: uid of node to filter by

        Returns:
            Optional[Node]: first node (in ``self.nodes`` order) with that uid,
            or None if there is no such node
        """
        # Stop at the first match instead of filtering the whole node list.
        return next((node for node in self.nodes if node.uid == uid), None)

    @abstractmethod
    def __eq__(self, other_graph: 'Graph') -> bool:
        """Compares this graph with the ``other_graph``.

        Args:
            other_graph: another graph

        Returns:
            is it equal to ``other_graph`` in terms of the graphs
        """
        raise NotImplementedError()

    def root_nodes(self) -> Sequence[GraphNode]:
        """Returns the final layer node(s) of the graph.

        Not implemented in this base class; subclasses are expected to override it.

        Returns:
            sequence of root nodes

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    @property
    def root_node(self) -> Union[GraphNode, Sequence[GraphNode]]:
        """Gets the final layer node(s) of the graph.

        Returns:
            the single root node if there is exactly one, otherwise the
            (possibly empty) sequence returned by :meth:`root_nodes`
        """
        roots = self.root_nodes()
        if len(roots) == 1:
            return roots[0]
        return roots

    @property
    @abstractmethod
    def nodes(self) -> List[GraphNode]:
        """Return list of all graph nodes.

        Returns:
            graph nodes
        """
        raise NotImplementedError()

    @nodes.setter
    @abstractmethod
    def nodes(self, new_nodes: List[GraphNode]):
        """Replaces the list of graph nodes with ``new_nodes``."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def depth(self) -> int:
        """Gets this graph depth from its sink-node to its source-node.

        Returns:
            length of a path from the root node to the farthest primary node
        """
        raise NotImplementedError()

    @property
    def length(self) -> int:
        """Return size of the graph (number of nodes).

        Returns:
            graph size
        """
        return len(self.nodes)

    def show(self, save_path: Optional[Union[PathLike, str]] = None, engine: Optional[str] = None,
             node_color: Optional[NodeColorType] = None, dpi: Optional[int] = None,
             node_size_scale: Optional[float] = None, font_size_scale: Optional[float] = None,
             edge_curvature_scale: Optional[float] = None, title: Optional[str] = None,
             node_names_placement: Optional[Literal['auto', 'nodes', 'legend', 'none']] = None,
             nodes_labels: Optional[Dict[int, str]] = None,
             edges_labels: Optional[Dict[int, str]] = None,
             nodes_layout_function: Optional[Callable[[nx.DiGraph], Dict[Any, Tuple[float, float]]]] = None):
        """Visualizes graph or saves its picture to the specified ``path``.

        All arguments are forwarded unchanged to ``GraphVisualizer.visualise``.

        Args:
            save_path: optional, save location of the graph visualization image.
            engine: engine to visualize the graph. Possible values: 'matplotlib', 'pyvis', 'graphviz'.
            node_color: color of nodes to use.
            dpi: DPI of the output image. Not supported for the engine 'pyvis'.
            node_size_scale: use to make node size bigger or lesser. Supported only for the engine 'matplotlib'.
            font_size_scale: use to make font size bigger or lesser. Supported only for the engine 'matplotlib'.
            edge_curvature_scale: use to make edges more or less curved.
                Supported only for the engine 'matplotlib'.
            title: title for plot
            node_names_placement: variant of node names displaying. Defaults to ``auto``.
                Possible options:

                - ``auto`` -> empirical rule by node size
                - ``nodes`` -> place node names on top of the nodes
                - ``legend`` -> place node names at the legend
                - ``none`` -> do not show node names

            nodes_labels: labels to display near nodes
            edges_labels: labels to display near edges
            nodes_layout_function: any of `Networkx layout functions
                <https://networkx.org/documentation/stable/reference/drawing.html#module-networkx.drawing.layout>`_.
        """
        GraphVisualizer(graph=self) \
            .visualise(save_path=save_path, engine=engine, node_color=node_color, dpi=dpi,
                       node_size_scale=node_size_scale, font_size_scale=font_size_scale,
                       edge_curvature_scale=edge_curvature_scale, node_names_placement=node_names_placement,
                       title=title, nodes_layout_function=nodes_layout_function,
                       nodes_labels=nodes_labels, edges_labels=edges_labels)

    @property
    def graph_description(self) -> Dict:
        """Return summary characteristics of the graph.

        Returns:
            dict: containing information about the graph (``depth``, ``length`` and ``nodes``)
        """
        return {
            'depth': self.depth,
            'length': self.length,
            'nodes': self.nodes,
        }

    @property
    def descriptive_id(self) -> str:
        """Returns human-readable identifier of the graph.

        The identifier is the concatenation of the root nodes' descriptive ids
        (just the root's own id when there is a single root). If the graph has
        no root nodes, the descriptive id of the node with the smallest ``uid``
        is used. An empty graph yields an empty string.

        Returns:
            str: text description of the content in the node and its parameters
        """
        # ``root_nodes`` must be *called*: testing the bound method object itself
        # is always truthy, which made the uid-based fallback unreachable.
        roots = self.root_nodes()
        if roots:
            # Several roots are joined instead of failing on ``Sequence.descriptive_id``.
            return ''.join(root.descriptive_id for root in roots)
        if not self.nodes:
            return ''
        # ``min`` returns the first minimal element, exactly like ``sorted(...)[0]``.
        return min(self.nodes, key=lambda x: x.uid).descriptive_id

    def __str__(self):
        return str(self.graph_description)

    def __repr__(self):
        return self.__str__()

    def __len__(self):
        return self.length