|
|
|
@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
|
|
|
|
|
bool UseIntermediateOut>
|
|
|
|
|
struct FusedElemwiseAndActGradNoBroadcast {
|
|
|
|
|
HOSTDEVICE void operator()(size_t i) {
|
|
|
|
|
T x_val = x_[i];
|
|
|
|
|
T y_val = y_[i];
|
|
|
|
|
T out_val = out_[i];
|
|
|
|
|
T dout_val = dout_[i];
|
|
|
|
|
T intermediate_out_val = UseIntermediateOut
|
|
|
|
|
? intermediate_out_[i]
|
|
|
|
|
: dx_op_.GetIntermediateOut(x_val, y_val);
|
|
|
|
|
if (dx_ != nullptr) {
|
|
|
|
|
dx_[i] = UseIntermediateOut
|
|
|
|
|
? dx_op_.UseIntermediateOut(
|
|
|
|
|
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
|
|
|
|
|
: dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
|
|
|
|
|
dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
|
|
|
|
|
out_val, dout_val);
|
|
|
|
|
}
|
|
|
|
|
if (dy_ != nullptr) {
|
|
|
|
|
dy_[i] = UseIntermediateOut
|
|
|
|
|
? dy_op_.UseIntermediateOut(
|
|
|
|
|
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
|
|
|
|
|
: dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
|
|
|
|
|
dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
|
|
|
|
|
out_val, dout_val);
|
|
|
|
|
}
|
|
|
|
|
if (dintermediate_ != nullptr) {
|
|
|
|
|
dintermediate_[i] =
|
|
|
|
|
UseIntermediateOut
|
|
|
|
|
? dintermediate_op_.UseIntermediateOut(
|
|
|
|
|
x_[i], intermediate_out_[i], out_[i], dout_[i])
|
|
|
|
|
: dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
|
|
|
|
|
dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
|
|
|
|
|
x_val, intermediate_out_val, out_val, dout_val);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|