from __future__ import annotations
from abc import abstractmethod
from copy import deepcopy
from typing import TYPE_CHECKING, TypeVar, Generic, Type, Optional, Dict, Any, Callable, Tuple, Sequence, Union
from golem.core.dag.graph import Graph
from golem.core.log import default_log
from golem.core.optimisers.graph import OptGraph, OptNode
from golem.core.adapter.adapt_registry import AdaptRegistry
from golem.core.optimisers.opt_history_objects.individual import Individual
if TYPE_CHECKING:
from golem.core.optimisers.genetic.operators.operator import PopulationT
# Type variable for the domain graph class an adapter is parameterized with
# (it annotates ``base_graph_class`` and the domain side of adapt/restore).
DomainStructureType = TypeVar('DomainStructureType')
class BaseOptimizationAdapter(Generic[DomainStructureType]):
    """Base class for adapters between domain graphs and optimizer graphs.

    ``adapt`` maps domain graphs to the internal representation and
    ``restore`` maps them back. Subclasses provide the single-graph
    conversions by overriding ``_adapt`` and ``_restore``.

    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced at instantiation time; a subclass that forgets to override
    ``_adapt``/``_restore`` fails with ``NotImplementedError`` on first use.
    """

    def __init__(self, base_graph_class: Type[DomainStructureType] = Graph):
        """Remember the domain graph class handled by this adapter.

        Args:
            base_graph_class: class of domain graphs; only objects of exactly
                this type are mapped by ``adapt``.
        """
        self._log = default_log(self)
        self.domain_graph_class = base_graph_class
        self.opt_graph_class = OptGraph

    def restore_func(self, fun: Callable) -> Callable:
        """Wraps native function so that it could accept domain graphs as arguments.

        Behavior: ``restore( f(Graph)->Graph ) => f'(DomainGraph)->DomainGraph``

        Implementation details.
        The method wraps callable into a function that transforms its args & return value.
        Arguments are transformed by ``adapt`` (that maps domain graphs to internal graphs).
        Return value is transformed by ``restore`` (that maps internal graphs to domain graphs).

        Args:
            fun: native function that accepts native args (i.e. optimization graph)

        Returns:
            Callable: domain function that can accept domain graphs
        """
        return _transform(fun, f_args=self.adapt, f_ret=self.restore)

    def adapt_func(self, fun: Callable) -> Callable:
        """Wraps domain function so that it could accept native optimization graphs
        as arguments. If the function was registered as native, it is returned as-is.
        ``AdaptRegistry`` is responsible for function registration.

        Behavior: ``adapt( f(DomainGraph)->DomainGraph ) => f'(Graph)->Graph``

        Implementation details.
        The method wraps callable into a function that transforms its args & return value.
        Arguments are transformed by ``restore`` (that maps internal graphs to domain graphs).
        Return value is transformed by ``adapt`` (that maps domain graphs to internal graphs).

        Args:
            fun: domain function that accepts domain graphs

        Returns:
            Callable: native function that can accept opt graphs
            and be used inside Optimizer
        """
        if AdaptRegistry.is_native(fun):
            return fun
        return _transform(fun, f_args=self.restore, f_ret=self.adapt)

    def adapt(self, item: Union[DomainStructureType, Sequence[DomainStructureType]]) \
            -> Union[Graph, Sequence[Graph]]:
        """Maps domain graphs to internal graph representation used by optimizer.
        Performs mapping only if argument has a type of domain graph.

        Anything else, including empty sequences, is returned unchanged.

        Args:
            item: a domain graph or sequence of them

        Returns:
            Graph | Sequence: mapped internal graph or sequence of them
        """
        # Exact type comparison: instances of subclasses are not mapped.
        if type(item) is self.domain_graph_class:
            return self._adapt(item)
        # The sequence kind is decided by its first element, so an empty
        # sequence must be excluded before item[0] is read (IndexError otherwise).
        if (isinstance(item, Sequence)
                and len(item) > 0
                and type(item[0]) is self.domain_graph_class):
            return [self._adapt(graph) for graph in item]
        return item

    def restore(self, item: Union[Graph, Individual, PopulationT]) \
            -> Union[DomainStructureType, Sequence[DomainStructureType]]:
        """Maps graphs from internal representation to domain graphs.
        Performs mapping only if argument has a type of internal representation.

        Mapped inputs are: a single internal graph, a single ``Individual``,
        or a non-empty sequence whose first element is an ``Individual``.
        Anything else, including empty sequences, is returned unchanged.

        Args:
            item: an internal graph, an individual or a sequence of individuals

        Returns:
            Graph | Sequence: mapped domain graph or sequence of them
        """
        # Exact type comparison: instances of subclasses are not mapped.
        if type(item) is self.opt_graph_class:
            return self._restore(item)
        if isinstance(item, Individual):
            return self._restore(item.graph, item.metadata)
        # Same empty-sequence guard as in ``adapt``: never index an empty sequence.
        if (isinstance(item, Sequence)
                and len(item) > 0
                and isinstance(item[0], Individual)):
            return [self._restore(ind.graph, ind.metadata) for ind in item]
        return item

    @abstractmethod
    def _adapt(self, adaptee: DomainStructureType) -> Graph:
        """Implementation of ``adapt`` for single graph."""
        raise NotImplementedError()

    @abstractmethod
    def _restore(self, opt_graph: Graph, metadata: Optional[Dict[str, Any]] = None) -> DomainStructureType:
        """Implementation of ``restore`` for single graph."""
        raise NotImplementedError()
class IdentityAdapter(BaseOptimizationAdapter[DomainStructureType]):
    """Pass-through adapter: both conversions hand back the graph they were given."""

    def _restore(self, opt_graph: Graph, metadata: Optional[Dict[str, Any]] = None) -> DomainStructureType:
        """Return ``opt_graph`` itself; ``metadata`` is not used."""
        return opt_graph

    def _adapt(self, adaptee: DomainStructureType) -> Graph:
        """Return ``adaptee`` itself, without copying or converting it."""
        return adaptee
class DirectAdapter(BaseOptimizationAdapter[DomainStructureType]):
    """Naive optimization adapter for arbitrary class that just overwrites __class__.

    Both directions deep-copy the input graph first, so the object passed in
    is never modified; only the copy and its nodes get a new ``__class__``.
    """

    def __init__(self,
                 base_graph_class: Type[DomainStructureType] = OptGraph,
                 base_node_class: Type = OptNode):
        """Store the domain graph class and the domain node class.

        Args:
            base_graph_class: class assigned to graphs produced by ``_restore``
            base_node_class: class assigned to nodes of graphs produced by ``_restore``
        """
        super().__init__(base_graph_class)
        self.domain_node_class = base_node_class

    @staticmethod
    def _recast_copy(source: Any, graph_class: Type, node_class: Type) -> Any:
        """Deep-copy ``source`` and reassign ``__class__`` of the copy and its nodes."""
        clone = deepcopy(source)
        # The graph class is switched before ``nodes`` is read, as in the
        # original statement order.
        clone.__class__ = graph_class
        for vertex in clone.nodes:
            vertex.__class__ = node_class
        return clone

    def _adapt(self, adaptee: DomainStructureType) -> Graph:
        """Return a copy of ``adaptee`` recast to ``opt_graph_class`` with ``OptNode`` nodes."""
        return self._recast_copy(adaptee, self.opt_graph_class, OptNode)

    def _restore(self, opt_graph: Graph, metadata: Optional[Dict[str, Any]] = None) -> DomainStructureType:
        """Return a copy of ``opt_graph`` recast to the domain graph and node classes.

        ``metadata`` is accepted for interface compatibility and is not used.
        """
        return self._recast_copy(opt_graph, self.domain_graph_class, self.domain_node_class)