Optimize fused_elewise_activation_grad op. (#18041)

test=develop
revert-18229-add_multi_gpu_install_check
Yiqun Liu 6 years ago committed by GitHub
parent 466254151a
commit 660c1a65f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
T x_val = x_[i];
T y_val = y_[i];
T out_val = out_[i];
T dout_val = dout_[i];
T intermediate_out_val = UseIntermediateOut
? intermediate_out_[i]
: dx_op_.GetIntermediateOut(x_val, y_val);
if (dx_ != nullptr) {
dx_[i] = UseIntermediateOut
? dx_op_.UseIntermediateOut(
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
out_val, dout_val);
}
if (dy_ != nullptr) {
dy_[i] = UseIntermediateOut
? dy_op_.UseIntermediateOut(
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
out_val, dout_val);
}
if (dintermediate_ != nullptr) {
dintermediate_[i] =
UseIntermediateOut
? dintermediate_op_.UseIntermediateOut(
x_[i], intermediate_out_[i], out_[i], dout_[i])
: dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
x_val, intermediate_out_val, out_val, dout_val);
}
}

@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor {
}
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, intermediate_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor {
}
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;

Loading…
Cancel
Save