fix_gru_py
chengduoZH 7 years ago
parent 8231960fb1
commit 8c6dae7776

@@ -145,12 +145,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    } else if (IsDistTrainOp(*op, send_op)) {
      CreateComputationalOps(&result, *op, 1);
    } else if (IsScaleLossOp(*op)) {
      CreateComputationalOps(&result, *op, places_.size());
      // user can customize loss@grad if not use_default_grad_scale_
      if (use_default_grad_scale_) {
        CreateScaleLossGradOp(&result);
      }
      is_forwarding = false;
    } else {
      if (IsScaleLossGradOp(*op)) continue;
      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
      if (op_dev_id == -1) {  // var on all device
        CreateComputationalOps(&result, *op, places_.size());
@@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
  // FIXME(yy): Do not hard code like this
  return op.OutputArgumentNames().size() == 1 &&
         (op.OutputArgumentNames()[0]) == loss_var_name_;
}
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
  // FIXME(yy): Do not hard code like this
  return op.OutputArgumentNames().size() == 1 &&
         op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
}
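For reference, below is a minimal standalone sketch of the two name-based checks above, not the framework API: it uses plain strings in place of OpDesc, assumes GradVarName() simply appends an "@GRAD" suffix to the loss variable name, and the sample variable names in main() are hypothetical.

// sketch.cc -- simplified, self-contained mimic of IsScaleLossOp /
// IsScaleLossGradOp; the "@GRAD" suffix and the sample names are assumptions.
#include <cassert>
#include <string>
#include <vector>

namespace sketch {
// Assumed convention: the gradient of <var> is named "<var>@GRAD".
std::string GradVarName(const std::string &name) { return name + "@GRAD"; }

// An op is treated as the scale-loss op when its sole output is the loss var.
bool IsScaleLossOp(const std::vector<std::string> &outputs,
                   const std::string &loss_var_name) {
  return outputs.size() == 1 && outputs[0] == loss_var_name;
}

// ...and as the scale-loss-grad op when its sole output is the loss gradient.
bool IsScaleLossGradOp(const std::vector<std::string> &outputs,
                       const std::string &loss_var_name) {
  return outputs.size() == 1 && outputs[0] == GradVarName(loss_var_name);
}
}  // namespace sketch

int main() {
  const std::string loss = "mean_0.tmp_0";  // hypothetical loss variable name
  assert(sketch::IsScaleLossOp({loss}, loss));
  assert(sketch::IsScaleLossGradOp({loss + "@GRAD"}, loss));
  assert(!sketch::IsScaleLossGradOp({"fc_0.w@GRAD", "fc_0.b@GRAD"}, loss));
  return 0;
}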

@@ -67,6 +67,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  bool IsScaleLossOp(const OpDesc &op) const;
  bool IsScaleLossGradOp(const OpDesc &op) const;
  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
  /**
