Source code for golem.core.dag.graph

from abc import ABC, abstractmethod
from enum import Enum
from os import PathLike
from typing import Dict, List, Optional, Sequence, Union, Tuple, TypeVar

from golem.core.dag.graph_node import GraphNode
from golem.visualisation.graph_viz import GraphVisualizer, NodeColorType

NodeType = TypeVar('NodeType', bound=GraphNode, covariant=False, contravariant=False)


class ReconnectType(Enum):
    """Defines allowed kinds of removals in Graph. Used by mutations."""
    none = 'none'  # do not reconnect predecessors
    single = 'single'  # reconnect a predecessor only if it's single
    all = 'all'  # reconnect all predecessors to all successors


[docs]class Graph(ABC): """Defines abstract graph interface that's required by graph optimisation process. """
[docs] @abstractmethod def add_node(self, node: GraphNode): """Adds new node to the graph together with its parent nodes. Args: node: graph nodes """ raise NotImplementedError()
[docs] @abstractmethod def update_node(self, old_node: GraphNode, new_node: GraphNode): """Replaces ``old_node`` node with ``new_node`` :param old_node: node to be replaced :param new_node: node to be placed instead """ raise NotImplementedError()
[docs] @abstractmethod def update_subtree(self, old_subtree: GraphNode, new_subtree: GraphNode): """Changes ``old_subtree`` subtree to ``new_subtree`` Args: old_subtree: node and its subtree to be removed new_subtree: node and its subtree to be placed instead """ raise NotImplementedError()
[docs] @abstractmethod def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single): """Removes ``node`` from the graph. If ``node`` has only one child, then connects all of the ``node`` parents to it. Args: node: node of the graph to be deleted reconnect: defines how to treat left edges between parents and children """ raise NotImplementedError()
[docs] @abstractmethod def delete_subtree(self, subtree: GraphNode): """Deletes given node with all its parents. Deletes all edges from removed nodes to remaining graph nodes Args: subtree: node to be deleted with all of its parents and their connections amongst the remaining graph nodes """ raise NotImplementedError()
[docs] @abstractmethod def node_children(self, node: GraphNode) -> Sequence[Optional[GraphNode]]: """Returns all children of the ``node`` Args: node: for getting children from Returns: children of the ``node`` """ raise NotImplementedError()
[docs] @abstractmethod def connect_nodes(self, node_parent: GraphNode, node_child: GraphNode): """Adds edge between ``parent`` and ``child`` Args: node_parent: acts like parent in graph connection relations node_child: acts like child in graph connection relations """ raise NotImplementedError()
[docs] @abstractmethod def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode, clean_up_leftovers: bool = False): """Removes an edge between two nodes Args: node_parent: where the removing edge comes out node_child: where the removing edge enters clean_up_leftovers: whether to remove the remaining invalid vertices with edges or not """ raise NotImplementedError()
[docs] @abstractmethod def get_edges(self) -> Sequence[Tuple[GraphNode, GraphNode]]: """Gets all available edges in this graph Returns: pairs of parent_node -> child_node """ raise NotImplementedError()
[docs] def get_nodes_by_name(self, name: str) -> List[GraphNode]: """Returns list of nodes with the required ``name`` Args: name: name to filter by Returns: list: relevant nodes (empty if there are no such nodes) """ appropriate_nodes = filter(lambda x: x.name == name, self.nodes) return list(appropriate_nodes)
[docs] def get_node_by_uid(self, uid: str) -> Optional[GraphNode]: """Returns node with the required ``uid`` Args: uid: uid of node to filter by Returns: Optional[Node]: relevant node (None if there is no such node) """ appropriate_nodes = list(filter(lambda x: x.uid == uid, self.nodes)) return appropriate_nodes[0] if appropriate_nodes else None
[docs] @abstractmethod def __eq__(self, other_graph: 'Graph') -> bool: """Compares this graph with the ``other_graph`` Args: other_graph: another graph Returns: is it equal to ``other_graph`` in terms of the graphs """ raise NotImplementedError()
def root_nodes(self) -> Sequence[GraphNode]: raise NotImplementedError() @property def root_node(self) -> Union[GraphNode, Sequence[GraphNode]]: """Gets the final layer node(s) of the graph Returns: the final layer node(s) """ roots = self.root_nodes() if len(roots) == 1: return roots[0] return roots @property @abstractmethod def nodes(self) -> List[GraphNode]: """Return list of all graph nodes Returns: graph nodes """ raise NotImplementedError() @nodes.setter @abstractmethod def nodes(self, new_nodes: List[GraphNode]): raise NotImplementedError() @property @abstractmethod def depth(self) -> int: """Gets this graph depth from its sink-node to its source-node Returns: length of a path from the root node to the farthest primary node """ raise NotImplementedError() @property def length(self) -> int: """Return size of the graph (number of nodes) Returns: graph size """ return len(self.nodes)
[docs] def show(self, save_path: Optional[Union[PathLike, str]] = None, engine: Optional[str] = None, node_color: Optional[NodeColorType] = None, dpi: Optional[int] = None, node_size_scale: Optional[float] = None, font_size_scale: Optional[float] = None, edge_curvature_scale: Optional[float] = None, nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None): """Visualizes graph or saves its picture to the specified ``path`` Args: save_path: optional, save location of the graph visualization image. engine: engine to visualize the graph. Possible values: 'matplotlib', 'pyvis', 'graphviz'. node_color: color of nodes to use. node_size_scale: use to make node size bigger or lesser. Supported only for the engine 'matplotlib'. font_size_scale: use to make font size bigger or lesser. Supported only for the engine 'matplotlib'. edge_curvature_scale: use to make edges more or less curved. Supported only for the engine 'matplotlib'. dpi: DPI of the output image. Not supported for the engine 'pyvis'. nodes_labels: labels to display near nodes edges_labels: labels to display near edges """ GraphVisualizer(graph=self)\ .visualise(save_path=save_path, engine=engine, node_color=node_color, dpi=dpi, node_size_scale=node_size_scale, font_size_scale=font_size_scale, edge_curvature_scale=edge_curvature_scale, nodes_labels=nodes_labels, edges_labels=edges_labels)
@property def graph_description(self) -> Dict: """Return summary characteristics of the graph Returns: dict: containing information about the graph """ return { 'depth': self.depth, 'length': self.length, 'nodes': self.nodes, } @property def descriptive_id(self) -> str: """Returns human-readable identifier of the graph. Returns: str: text description of the content in the node and its parameters """ if self.root_nodes: return self.root_node.descriptive_id else: return sorted(self.nodes, key=lambda x: x.uid)[0].descriptive_id def __str__(self): return str(self.graph_description) def __repr__(self): return self.__str__() def __len__(self): return self.length