restore the huber_loss_op

test=develop
revert-15207-remove_op_handle_lock_and_fix_var
peizhilin 7 years ago
parent 01c00b07dd
commit e49276e731

@ -104,19 +104,15 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
if (out0) {
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
// MSVC not treat it well when partial template arguments were specified
x_grad.device(place) =
out_grad *
residual.unaryExpr(HuberLossBackward<T>(delta, static_cast<T>(-1.0)));
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
}
if (out1) {
out1->mutable_data<T>(context.GetPlace());
auto y_grad = EigenVector<T>::Flatten(*out1);
// MSVC not treat it well when partial template arguments were specified
y_grad.device(place) =
out_grad *
residual.unaryExpr(HuberLossBackward<T>(delta, static_cast<T>(1.0)));
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
}
}
};

Loading…
Cancel
Save