|
|
|
@ -690,6 +690,7 @@ class Cell(Cell_):
|
|
|
|
|
Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
|
|
|
|
|
"""
|
|
|
|
|
replace = dict()
|
|
|
|
|
|
|
|
|
|
def _updata(param):
|
|
|
|
|
if param in replace:
|
|
|
|
|
return replace[param]
|
|
|
|
@ -1078,6 +1079,10 @@ class GraphKernel(Cell):
|
|
|
|
|
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
|
|
|
|
|
enable_graph_kernel in context is set to True.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
auto_prefix (bool): Recursively generate namespaces. Default: True.
|
|
|
|
|
flags (dict) : Set graph flags. Default: None.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> class Relu(nn.GraphKernel):
|
|
|
|
|
... def __init__(self):
|
|
|
|
@ -1088,8 +1093,8 @@ class GraphKernel(Cell):
|
|
|
|
|
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, auto_prefix=True, pips=None):
|
|
|
|
|
super(GraphKernel, self).__init__(auto_prefix, pips)
|
|
|
|
|
def __init__(self, auto_prefix=True, flags=None):
|
|
|
|
|
super(GraphKernel, self).__init__(auto_prefix, flags)
|
|
|
|
|
class_name = self.__class__.__name__
|
|
|
|
|
self.add_flags(graph_kernel=class_name)
|
|
|
|
|
|
|
|
|
|