fix loss.gradVar

branch: fix_gru_py
chengduoZH 7 years ago
parent 8c6dae7776
commit a3ca4c99eb

@@ -145,14 +145,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    } else if (IsDistTrainOp(*op, send_op)) {
      CreateComputationalOps(&result, *op, 1);
    } else if (IsScaleLossOp(*op)) {
      CreateComputationalOps(&result, *op, places_.size());
      // the user can customize loss@grad if use_default_grad_scale_ is not set
      if (use_default_grad_scale_) {
        CreateScaleLossGradOp(&result);
      }
      is_forwarding = false;
    } else {
      if (IsScaleLossGradOp(*op)) continue;
      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
      if (op_dev_id == -1) {  // var on all devices
        CreateComputationalOps(&result, *op, places_.size());
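For context on the branch above: when the builder reaches the scale-loss op it creates computation ops on every place and, if use_default_grad_scale_ is set, a ScaleLossGradOp that seeds loss@GRAD on each device. A minimal Python sketch of the default scaling, assuming the handle initializes the gradient to 1/num_devices so that summing per-device gradients yields their average (names here are illustrative, not Paddle's API):

```python
# Sketch of the default loss-gradient scaling assumed above:
# each device starts backprop from loss@GRAD = 1 / num_devices,
# so summing per-device gradients averages them.
def default_loss_grad(num_devices: int) -> float:
    return 1.0 / num_devices

if __name__ == "__main__":
    # e.g. with 4 devices, every copy of loss@GRAD is seeded with 0.25
    assert default_loss_grad(4) == 0.25
```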
@@ -401,12 +399,6 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
  // FIXME(yy): Do not hard code like this
  return op.OutputArgumentNames().size() == 1 &&
         op.OutputArgumentNames()[0] == loss_var_name_;
}

bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
  // FIXME(yy): Do not hard code like this
  return op.OutputArgumentNames().size() == 1 &&
         op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
}
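Both predicates identify an op purely by its single output name: IsScaleLossOp matches the op that writes the loss variable, IsScaleLossGradOp the op that writes its gradient, where GradVarName appends Paddle's gradient suffix ("@GRAD"). A rough, illustrative Python restatement of the heuristic (not Paddle's API):

```python
# Illustrative re-statement of the two name-based heuristics above;
# "@GRAD" is Paddle's gradient-variable suffix.
GRAD_SUFFIX = "@GRAD"

def is_scale_loss_op(output_names, loss_var_name):
    # The op whose single output is the loss variable itself.
    return len(output_names) == 1 and output_names[0] == loss_var_name

def is_scale_loss_grad_op(output_names, loss_var_name):
    # The op whose single output is the loss gradient variable.
    return (len(output_names) == 1
            and output_names[0] == loss_var_name + GRAD_SUFFIX)

assert is_scale_loss_grad_op(["mean_0.tmp_0@GRAD"], "mean_0.tmp_0")
```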

@@ -67,8 +67,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  bool IsScaleLossOp(const OpDesc &op) const;
  bool IsScaleLossGradOp(const OpDesc &op) const;
  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;

  /**

@@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()
    # FIXME(zcd): prevent loss.grad from being optimized by mem_opt.
    loss.block.var(_append_grad_suffix_(loss.name)).persistable = True
    if parameter_list is not None:
        parameters = parameter_list
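The Python-side fix pins the loss-gradient variable as persistable so that the memory-optimization pass leaves its buffer alone; _append_grad_suffix_ derives the gradient name from the loss name. A minimal sketch of why this works, assuming (as is typical for such passes) that only non-persistable variables are candidates for buffer reuse; the class and helper below are illustrative, not Paddle's API:

```python
# Toy model of a memory-reuse pass: persistable variables are never
# candidates for renaming or buffer reuse, so marking loss@GRAD
# persistable shields it from mem_opt. Illustrative only.
class Var:
    def __init__(self, name, persistable=False):
        self.name = name
        self.persistable = persistable

def reuse_candidates(variables):
    # Hypothetical mem_opt filter: skip persistable variables.
    return [v for v in variables if not v.persistable]

loss_grad = Var("mean_0.tmp_0@GRAD", persistable=True)  # the pinned var
scratch = Var("fc_0.tmp_1@GRAD")                        # still reusable
assert reuse_candidates([loss_grad, scratch]) == [scratch]
```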
