|
|
|
@ -112,14 +112,14 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
self._outputs = NestSequence(outputs, need_check=True)
|
|
|
|
|
self._params = parameters if parameters is not None else []
|
|
|
|
|
|
|
|
|
|
self._infer_program = self._verify_program(main_program)
|
|
|
|
|
self._train_program = self._append_backward_desc()
|
|
|
|
|
# Switch infer or train by train() and eval()
|
|
|
|
|
self._trace_program = None
|
|
|
|
|
main_program = self._verify_program(main_program)
|
|
|
|
|
self._infer_program = self._clone_for_test(main_program)
|
|
|
|
|
self._train_program = self._append_backward_desc(main_program)
|
|
|
|
|
|
|
|
|
|
self._set_grad_type(self._params)
|
|
|
|
|
self._inner_scope = core.Scope()
|
|
|
|
|
# Set default mode to train
|
|
|
|
|
self.train()
|
|
|
|
|
self.training = True
|
|
|
|
|
|
|
|
|
|
def _verify_program(self, main_program):
|
|
|
|
|
"""
|
|
|
|
@ -136,8 +136,8 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
return main_program
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
|
def _append_backward_desc(self):
|
|
|
|
|
program = self._infer_program.clone()
|
|
|
|
|
def _append_backward_desc(self, main_program):
|
|
|
|
|
program = main_program.clone()
|
|
|
|
|
targets = []
|
|
|
|
|
for out in self._outputs.tolist():
|
|
|
|
|
if isinstance(out, framework.Variable):
|
|
|
|
@ -165,15 +165,6 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
|
|
|
|
|
self._params = required_params
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
|
# self.training is inherited from layers.Layer
|
|
|
|
|
self.training = True
|
|
|
|
|
self._trace_program = self._train_program
|
|
|
|
|
|
|
|
|
|
def eval(self):
|
|
|
|
|
self.training = False
|
|
|
|
|
self._trace_program = self._infer_program
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs):
|
|
|
|
|
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
|
|
|
|
|
|
|
|
|
@ -186,7 +177,7 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
outputs={'Out': valid_vars(out_vars),
|
|
|
|
|
'OutScope': tmp_scope_vec},
|
|
|
|
|
attrs={
|
|
|
|
|
'global_block': self._trace_program.desc.block(0),
|
|
|
|
|
'global_block': self.program.desc.block(0),
|
|
|
|
|
'start_op_index': 0,
|
|
|
|
|
'end_op_index': self._infer_program.desc.block(0).op_size(),
|
|
|
|
|
'is_test': not self.training
|
|
|
|
@ -195,6 +186,10 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
restored_nest_out = self._restore_out(out_vars)
|
|
|
|
|
return self._remove_no_value(restored_nest_out)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def program(self):
|
|
|
|
|
return self._train_program if self.training else self._infer_program
|
|
|
|
|
|
|
|
|
|
def _prepare(self, inputs):
|
|
|
|
|
"""
|
|
|
|
|
Prepare inputs, outputs, attrs.
|
|
|
|
@ -253,6 +248,10 @@ class PartialProgramLayer(layers.Layer):
|
|
|
|
|
|
|
|
|
|
return outs
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
|
def _clone_for_test(self, main_program):
|
|
|
|
|
return main_program.clone(for_test=True)
|
|
|
|
|
|
|
|
|
|
def _is_no_value(self, var):
|
|
|
|
|
if isinstance(var, core.VarBase):
|
|
|
|
|
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
|
|
|
|
|