|
|
|
@ -112,14 +112,7 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
self._outputs = NestSequence(outputs, need_check=True)
|
|
|
|
|
self._params = parameters if parameters is not None else []
|
|
|
|
|
|
|
|
|
|
# Check all params from main program can be found in self._params:
|
|
|
|
|
# 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
|
|
|
|
|
# 2. parameter from transformed program shall be found in self._params.
|
|
|
|
|
# Because they share same data with ParamBase of original dygraph.
|
|
|
|
|
self._check_params_all_inited(main_program)
|
|
|
|
|
self._prune_unused_params(main_program)
|
|
|
|
|
|
|
|
|
|
self._infer_program = main_program
|
|
|
|
|
self._infer_program = self._verify_program(main_program)
|
|
|
|
|
self._train_program = self._append_backward_desc()
|
|
|
|
|
# Switch infer or train by train() and eval()
|
|
|
|
|
self._trace_program = None
|
|
|
|
@ -128,6 +121,20 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
# Set default mode to train
|
|
|
|
|
self.train()
|
|
|
|
|
|
|
|
|
|
def _verify_program(self, main_program):
|
|
|
|
|
"""
|
|
|
|
|
Verify that the program parameter is initialized, prune some unused params,
|
|
|
|
|
and remove redundant op callstack.
|
|
|
|
|
"""
|
|
|
|
|
# 1. Check all params from main program can be found in self._params
|
|
|
|
|
self._check_params_all_inited(main_program)
|
|
|
|
|
# 2. Prune the parameters not used anywhere in the program.
|
|
|
|
|
self._prune_unused_params(main_program)
|
|
|
|
|
# 3. Remove op's python call stack with redundant low-level error messages.
|
|
|
|
|
main_program = self._remove_op_call_stack(main_program)
|
|
|
|
|
|
|
|
|
|
return main_program
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
|
def _append_backward_desc(self):
|
|
|
|
|
program = self._infer_program.clone()
|
|
|
|
@ -295,6 +302,19 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
continue
|
|
|
|
|
param._set_grad_type(grad_var.type())
|
|
|
|
|
|
|
|
|
|
def _remove_op_call_stack(self, main_program):
|
|
|
|
|
"""
|
|
|
|
|
Remove op's python call stack with redundant low-level error messages related to
|
|
|
|
|
transforamtions to avoid confusing users.
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(main_program, framework.Program)
|
|
|
|
|
for block in main_program.blocks:
|
|
|
|
|
for op in block.ops:
|
|
|
|
|
if op.has_attr("op_callstack"):
|
|
|
|
|
op._remove_attr("op_callstack")
|
|
|
|
|
|
|
|
|
|
return main_program
|
|
|
|
|
|
|
|
|
|
def _check_params_all_inited(self, main_program):
|
|
|
|
|
"""
|
|
|
|
|
Check all params from main program are already initialized, see details as follows:
|
|
|
|
|