|
|
@ -110,11 +110,13 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
self._inputs = NestSequence(inputs)
|
|
|
|
self._inputs = NestSequence(inputs)
|
|
|
|
self._outputs = NestSequence(outputs, need_check=True)
|
|
|
|
self._outputs = NestSequence(outputs, need_check=True)
|
|
|
|
self._params = parameters if parameters is not None else []
|
|
|
|
self._params = parameters if parameters is not None else []
|
|
|
|
|
|
|
|
|
|
|
|
# Check all params from main program can be found in self._params:
|
|
|
|
# Check all params from main program can be found in self._params:
|
|
|
|
# 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
|
|
|
|
# 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
|
|
|
|
# 2. parameter from transformed program shall be found in self._params.
|
|
|
|
# 2. parameter from transformed program shall be found in self._params.
|
|
|
|
# Because they share same data with ParamBase of original dygraph.
|
|
|
|
# Because they share same data with ParamBase of original dygraph.
|
|
|
|
self._check_params_all_inited(main_program)
|
|
|
|
self._check_params_all_inited(main_program)
|
|
|
|
|
|
|
|
self._prune_unused_params(main_program)
|
|
|
|
|
|
|
|
|
|
|
|
self._infer_program = main_program
|
|
|
|
self._infer_program = main_program
|
|
|
|
self._train_program = self._append_backward_desc()
|
|
|
|
self._train_program = self._append_backward_desc()
|
|
|
@ -138,6 +140,23 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
return program
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prune_unused_params(self, program):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Prune the parameters not used anywhere in the program.
|
|
|
|
|
|
|
|
The `@declarative` may only decorated a sub function which
|
|
|
|
|
|
|
|
contains some unused parameters created in `__init__`.
|
|
|
|
|
|
|
|
So prune these parameters to avoid unnecessary operations in
|
|
|
|
|
|
|
|
`run_program_op`.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
required_params = []
|
|
|
|
|
|
|
|
for param in self._params:
|
|
|
|
|
|
|
|
for block in program.blocks:
|
|
|
|
|
|
|
|
if param.name in block.vars:
|
|
|
|
|
|
|
|
required_params.append(param)
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._params = required_params
|
|
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
def train(self):
|
|
|
|
# self.training is inherited from layers.Layer
|
|
|
|
# self.training is inherited from layers.Layer
|
|
|
|
self.training = True
|
|
|
|
self.training = True
|
|
|
|