@@ -89,6 +89,7 @@ class Cell(Cell_):
         self._scope = None
         self._phase = 'train'
         self._parameter_layout_dict = {}
+        self._parallel_parameter_name_list = ()
         self._create_time = int(time.time() * 1e9)
         self.phase_prefix = ""
         self.parameter_broadcast_done = False
@@ -213,6 +214,16 @@ class Cell(Cell_):
             raise TypeError("'parameter_layout_dict' must be dict type.")
         self._parameter_layout_dict = value
 
+    @property
+    def parallel_parameter_name_list(self):
+        return self._parallel_parameter_name_list
+
+    @parallel_parameter_name_list.setter
+    def parallel_parameter_name_list(self, value):
+        if not isinstance(value, list):
+            raise TypeError("'parallel_parameter_name_list' must be list type.")
+        self._parallel_parameter_name_list = value
+
     def get_func_graph_proto(self):
         """Return graph binary proto."""
         return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
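
A quick sketch of how the new accessors behave (hypothetical: `net` is any Cell
instance and the parameter names are made up). The setter accepts only a list,
so anything else, including a tuple, raises:

    >>> net.parallel_parameter_name_list = ['conv1.weight', 'fc.bias']
    >>> net.parallel_parameter_name_list
    ['conv1.weight', 'fc.bias']
    >>> net.parallel_parameter_name_list = ('conv1.weight',)
    Traceback (most recent call last):
        ...
    TypeError: 'parallel_parameter_name_list' must be list type.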
@@ -656,6 +667,28 @@ class Cell(Cell_):
         """
         return None
 
+    def remove_redundant_parameters(self):
+        """Remove the redundant parameters."""
+        cells = self.cells_and_names()
+        for _, cell in cells:
+            params = cell._params.items()
+            for param_name, param in list(params):
+                if param.name not in self.parallel_parameter_name_list:
+                    cell._params.pop(param_name)
+                    logger.info("remove the redundant parameter: %s", param.name)
+                    continue
+            cell_dict = cell.__dict__
+            for key in cell_dict:
+                if isinstance(cell_dict[key], ParameterTuple):
+                    param_tuple = cell_dict[key]
+                    new_param_tuple = []
+                    for param in param_tuple:
+                        if param.name not in self.parallel_parameter_name_list:
+                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
+                            continue
+                        new_param_tuple.append(param)
+                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
+
     def init_parameters_data(self, auto_parallel_mode=False):
         """
         Initialize all parameters and replace the original saved parameters in cell.
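
A minimal usage sketch (hypothetical `net` with two dense layers; assumes the
parallel layout keeps only `fc1.weight` on this rank). Any parameter whose name
is absent from `parallel_parameter_name_list` is dropped from each subcell's
`_params` dict and filtered out of every ParameterTuple attribute:

    >>> net.parallel_parameter_name_list = ['fc1.weight']
    >>> net.remove_redundant_parameters()
    >>> [p.name for p in net.get_parameters()]
    ['fc1.weight']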
@@ -750,7 +783,7 @@ class Cell(Cell_):
         """
         Returns all trainable parameters.
 
-        Returns a list of all trainable parmeters.
+        Returns a list of all trainable parameters.
 
         Args:
             recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
@@ -1031,7 +1064,7 @@ class Cell(Cell_):
         Note:
             fn must be defined as the following code. `cell_name` is the name of registered cell.
             `grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the
-            next cell or primitve, which may be modified and returned.
+            next cell or primitive, which may be modified and returned.
             hook_fn(cell_name, grad_input, grad_output) -> Tensor or None.
 
         Args:
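
Assuming this docstring belongs to Cell.register_backward_hook, a sketch of a
conforming hook (the body is illustrative, not part of this change; `net` is
hypothetical):

    >>> def hook_fn(cell_name, grad_input, grad_output):
    ...     print('bprop fired for', cell_name)
    ...     return None  # return a Tensor here to replace grad_output
    >>> net.register_backward_hook(hook_fn)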