Optimize fused_elemwise_activation_grad op. (#18041)

test=develop
revert-18229-add_multi_gpu_install_check
Yiqun Liu 6 years ago committed by GitHub
parent 466254151a
commit 660c1a65f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut> bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast { struct FusedElemwiseAndActGradNoBroadcast {
// Computes the gradients for element `i` of the fused elementwise+activation
// op in the no-broadcast case.
//
// Each input (x_, y_, out_, dout_) is loaded exactly once into a local so the
// three gradient branches below reuse the same values instead of re-indexing
// the arrays. The intermediate value is taken from intermediate_out_ when the
// UseIntermediateOut template flag is set; otherwise it is recomputed once via
// dx_op_.GetIntermediateOut(x, y) and shared by every branch.
//
// A gradient is written only when its destination pointer is non-null.
HOSTDEVICE void operator()(size_t i) {
  T x_val = x_[i];
  T y_val = y_[i];
  T out_val = out_[i];
  T dout_val = dout_[i];
  // Load the stored intermediate result, or rebuild it from the inputs.
  T intermediate_out_val = UseIntermediateOut
                               ? intermediate_out_[i]
                               : dx_op_.GetIntermediateOut(x_val, y_val);
  if (dx_ != nullptr) {
    dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
                                       out_val, dout_val);
  }
  if (dy_ != nullptr) {
    dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
                                       out_val, dout_val);
  }
  if (dintermediate_ != nullptr) {
    // Note: this functor takes no y argument.
    dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
        x_val, intermediate_out_val, out_val, dout_val);
  }
}

@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, intermediate_out); return dout * d_binary_fun_.Dx(x, intermediate_out);
} }
// Rebuilds the intermediate value from the raw inputs by applying unary_fun_
// to `y`. `x` is accepted but not read.
inline HOSTDEVICE T GetIntermediateOut(T x, T y) {
  T intermediate = unary_fun_(y);
  return intermediate;
}
private: private:
DBinaryFun d_binary_fun_; DBinaryFun d_binary_fun_;
UnaryFun unary_fun_; UnaryFun unary_fun_;
@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor {
} }
} }
// Rebuilds the intermediate value from the raw inputs by applying unary_fun_
// to `y`. `x` is accepted but not read.
inline HOSTDEVICE T GetIntermediateOut(T x, T y) {
  T intermediate = unary_fun_(y);
  return intermediate;
}
private: private:
DBinaryFun d_binary_fun_; DBinaryFun d_binary_fun_;
UnaryFun unary_fun_; UnaryFun unary_fun_;
@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y); return base * d_binary_fun_.Dx(x, y);
} }
// Rebuilds the intermediate value from the raw inputs by applying binary_fun_
// to the pair (x, y).
inline HOSTDEVICE T GetIntermediateOut(T x, T y) {
  T intermediate = binary_fun_(x, y);
  return intermediate;
}
private: private:
DUnaryFun d_unary_fun_; DUnaryFun d_unary_fun_;
BinaryFun binary_fun_; BinaryFun binary_fun_;
@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y); return base * d_binary_fun_.Dy(x, y);
} }
// Rebuilds the intermediate value from the raw inputs by applying binary_fun_
// to the pair (x, y).
inline HOSTDEVICE T GetIntermediateOut(T x, T y) {
  T intermediate = binary_fun_(x, y);
  return intermediate;
}
private: private:
DUnaryFun d_unary_fun_; DUnaryFun d_unary_fun_;
BinaryFun binary_fun_; BinaryFun binary_fun_;
@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, intermediate_out); return dout * d_binary_fun_.Dy(x, intermediate_out);
} }
// Rebuilds the intermediate value from the raw inputs by applying unary_fun_
// to `y`. `x` is accepted but not read.
inline HOSTDEVICE T GetIntermediateOut(T x, T y) {
  T intermediate = unary_fun_(y);
  return intermediate;
}
private: private:
DBinaryFun d_binary_fun_; DBinaryFun d_binary_fun_;
UnaryFun unary_fun_; UnaryFun unary_fun_;
@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor {
} }
} }
// Rebuilds the intermediate value from the raw inputs by applying binary_fun_
// to the pair (x, y).
inline HOSTDEVICE T GetIntermediateOut(T x, T y) {
  T intermediate = binary_fun_(x, y);
  return intermediate;
}
private: private:
DUnaryFun d_unary_fun_; DUnaryFun d_unary_fun_;
BinaryFun binary_fun_; BinaryFun binary_fun_;

Loading…
Cancel
Save