@@ -47,25 +47,65 @@ namespace operators {
  * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
  *    pre=2*3, n=4*5, post=1
  *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
+ *
+ * New parameter *mid_flag* is added to handle the m*n*k & m*1*k
+ * broadcast case.
+ * 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
+ *    mid_flag must not be NULL.
+ *    x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
  */
 inline void get_mid_dims(const framework::DDim &x_dims,
                          const framework::DDim &y_dims, const int axis,
-                         int *pre, int *n, int *post) {
+                         int *pre, int *n, int *post, int *mid_flag = NULL) {
   *pre = 1;
   *n = 1;
   *post = 1;
-  for (int i = 0; i < axis; ++i) {
-    (*pre) *= x_dims[i];
-  }
+  if (mid_flag != NULL) {
+    *mid_flag = 0;
+    int mid = 0;
+    for (int i = 0; i < axis; ++i) {
+      (*pre) *= x_dims[i];
+    }
+    for (int i = 0; i < y_dims.size(); ++i) {
+      if (x_dims[i + axis] != y_dims[i]) {
+        // Only a single y_dims[i] == 1 is supported for now.
+        PADDLE_ENFORCE_EQ(*mid_flag, 0,
+                          "Broadcast only supports a single 1 in y_dims.");
+        PADDLE_ENFORCE_EQ(y_dims[i], 1, "Broadcast dimension mismatch.");
+        // The m*n*k & m*1*k case: fold the matched leading y dims into pre.
+        for (int j = 0; j < i; ++j) {
+          (*pre) *= y_dims[j];
+        }
+        *n = std::max(x_dims[i + axis], y_dims[i]);
+        *mid_flag = 1;
+        mid = i;
+        break;
+      }
+      (*n) *= y_dims[i];
+    }
+    if (*mid_flag) {
+      for (int i = mid + 1; i < x_dims.size(); ++i) {
+        (*post) *= x_dims[i];
+      }
+    } else {
+      for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+        (*post) *= x_dims[i];
+      }
+    }
+  } else {  // for fused_elementwise_activation_op: keep the old version.
+    for (int i = 0; i < axis; ++i) {
+      (*pre) *= x_dims[i];
+    }
 
-  for (int i = 0; i < y_dims.size(); ++i) {
-    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
-                      "Broadcast dimension mismatch.");
-    (*n) *= y_dims[i];
-  }
+    for (int i = 0; i < y_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
+                        "Broadcast dimension mismatch.");
+      (*n) *= y_dims[i];
+    }
 
-  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-    (*post) *= x_dims[i];
-  }
+    for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+      (*post) *= x_dims[i];
+    }
+  }
 }
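For reference, example 3 can be traced through the new mid_flag branch by hand. Below is a minimal standalone sketch of that trace; a plain std::vector stands in for framework::DDim, and the program is an illustration rather than Paddle API:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Example 3: shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5), axis = 0.
  std::vector<int> x_dims = {2, 3, 4, 5};
  std::vector<int> y_dims = {2, 1, 4, 5};
  const int axis = 0;
  int pre = 1, n = 1, post = 1, mid_flag = 0, mid = 0;
  for (int i = 0; i < axis; ++i) pre *= x_dims[i];
  for (int i = 0; i < static_cast<int>(y_dims.size()); ++i) {
    if (x_dims[i + axis] != y_dims[i]) {
      // The single mismatching axis: fold the matched leading y dims into
      // pre and keep the mismatching extent as n.
      for (int j = 0; j < i; ++j) pre *= y_dims[j];
      n = std::max(x_dims[i + axis], y_dims[i]);
      mid_flag = 1;
      mid = i;
      break;
    }
    n *= y_dims[i];
  }
  if (mid_flag) {
    for (int i = mid + 1; i < static_cast<int>(x_dims.size()); ++i)
      post *= x_dims[i];
  }
  // Prints: pre=2 n=3 post=20 mid_flag=1, i.e. x is viewed as (2, 3, 20)
  // and y as (2, 1, 20), matching example 3 in the comment.
  std::printf("pre=%d n=%d post=%d mid_flag=%d\n", pre, n, post, mid_flag);
  return 0;
}
```

Note that the n accumulated so far is overwritten at the mismatching axis and the matched leading y dims are folded into pre instead, which is why y ends up viewed as (2, 1, 20).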
@@ -171,7 +211,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext>
         }
       }
     }
-
     return *this;
   }
@@ -268,6 +307,15 @@ class TransformFunctor {
           MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
   }
 
+  inline void RunMidRowWise(int n, int pre, int post) const {
+    platform::Transform<DeviceContext> trans;
+    for (int i = 0; i < pre; i++) {
+      trans(ctx_, x_ + i * n * post, x_ + (i + 1) * n * post,
+            RowwiseTransformIterator<T, DeviceContext>(y_ + i * post, post),
+            z_ + i * n * post, func_);
+    }
+  }
+
  private:
   const T *x_;
   const T *y_;
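RunMidRowWise realizes the forward (pre, n, post) * (pre, 1, post) case: for each of the pre outer slices of x, the matching post-length slice of y is cycled across the slice's n rows by RowwiseTransformIterator. A plain-loop sketch of the same index arithmetic, with '+' standing in for func_ and the concrete extents chosen only for illustration:

```cpp
#include <cstdio>

int main() {
  // x viewed as (pre, n, post) = (2, 3, 2); y viewed as (2, 1, 2).
  const int pre = 2, n = 3, post = 2;
  float x[pre * n * post], y[pre * post], z[pre * n * post];
  for (int i = 0; i < pre * n * post; ++i) x[i] = static_cast<float>(i);
  for (int i = 0; i < pre * post; ++i) y[i] = 10.0f * (i + 1);
  for (int i = 0; i < pre; ++i)        // outer slice at offset i * n * post
    for (int j = 0; j < n; ++j)        // the n rows of this slice
      for (int k = 0; k < post; ++k)   // y slice of length post is reused
        z[(i * n + j) * post + k] =
            x[(i * n + j) * post + k] + y[i * post + k];
  for (int i = 0; i < pre * n * post; ++i) std::printf("%g ", z[i]);
  std::printf("\n");
  return 0;
}
```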
@@ -501,6 +549,88 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
 #endif
 
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcastMid2CPU(const T *x, const T *y, const T *out,
+                                         const T *dout, int pre, int n,
+                                         int post, DX_OP dx_op, DY_OP dy_op,
+                                         T *dx, T *dy) {
+  for (int i = 0; i < pre; ++i) {
+    for (int j = 0; j < n; ++j) {
+      for (int k = 0; k < post; ++k) {
+        int x_offset = i * n * post + j * post + k;
+        int y_offset = i * post + k;
+        if (dx != nullptr) {
+          dx[x_offset] =
+              dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+        }
+        if (dy != nullptr) {
+          T tmp =
+              dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+          if (j == 0) {
+            dy[y_offset] = tmp;
+          } else {
+            dy[y_offset] += tmp;
+          }
+        }
+      }
+    }
+  }
+}
+
+#ifdef __NVCC__
+template <typename T, typename DX_OP, typename DY_OP>
+static __global__ void ElemwiseGradBroadcastMid2CUDAKernel(
+    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
+    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+  // One block per y element: blockIdx.x encodes (i, k). Threads cooperate
+  // over the middle dimension j, then reduce the dy contributions.
+  int tid = blockIdx.x;
+  int i = tid / post;
+  int k = tid % post;
+  int y_offset = i * post + k;
+
+  T val(0);
+  for (int j = threadIdx.x; j < n; j += ELEMWISE_MAX_BLOCK_DIM) {
+    int x_offset = i * n * post + j * post + k;
+    if (dx != nullptr) {
+      dx[x_offset] =
+          dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+    }
+    if (dy != nullptr) {
+      val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+    }
+  }
+
+  if (dy) {
+    int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
+    val = paddle::platform::reduceSum(val, threadIdx.x, h);
+    if (threadIdx.x == 0) {
+      dy[y_offset] = val;
+    }
+  }
+}
+
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcastMid2CUDA(cudaStream_t stream, const T *x,
+                                          const T *y, const T *out,
+                                          const T *dout, int pre, int n,
+                                          int post, DX_OP dx_op, DY_OP dy_op,
+                                          T *dx, T *dy) {
+  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, n);
+  int grid_size = pre * post;
+  ElemwiseGradBroadcastMid2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
+      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
+}
+
+#endif
+
 template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
 void ElemwiseGradComputeNoBroadcast(
     const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
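The backward pass mirrors the forward decomposition: dx keeps the full (pre, n, post) layout, while dy reduces the dout contributions over the middle dimension n; in the CUDA launcher above, each of the pre * post blocks performs one such reduction. A minimal numeric check of that reduction for Out = X + Y, where dOut/dY == 1 and dy is therefore just dout summed over j (the extents are assumptions for illustration):

```cpp
#include <cstdio>

int main() {
  const int pre = 2, n = 3, post = 2;
  float dout[pre * n * post], dy[pre * post] = {};
  for (int i = 0; i < pre * n * post; ++i)
    dout[i] = static_cast<float>(i + 1);
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < post; ++k)
        dy[i * post + k] += dout[(i * n + j) * post + k];
  // Each dy element aggregates n = 3 dout entries; this is what the j == 0
  // initialize-then-accumulate in ElemwiseGradBroadcastMid2CPU computes.
  for (int i = 0; i < pre * post; ++i) std::printf("%g ", dy[i]);
  std::printf("\n");
  return 0;
}
```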
@@ -533,23 +663,39 @@ void ElemwiseGradComputeWithBroadcast(
   auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
 
-  int pre, n, post;
-  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
-  if (post == 1) {
-    int h = pre;
-    int w = n;
+  int pre, n, post, mid_flag = 0;
+  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &mid_flag);
+  if (mid_flag) {
+    PADDLE_ENFORCE_EQ(mid_flag, 1, "mid_flag must be 0 or 1.");
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+      ElemwiseGradBroadcastMid2CUDA(
+          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
+          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+#endif
+    } else {
+      ElemwiseGradBroadcastMid2CPU(
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
+          post, dx_op, dy_op,
+          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+    }
+  } else if (post == 1) {
     if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef __NVCC__
       ElemwiseGradBroadcast1CUDA(
           ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
-          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
+          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, dx_op, dy_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
 #endif
     } else {
       ElemwiseGradBroadcast1CPU(
-          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
-          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
+          dx_op, dy_op,
+          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
     }
   } else {
@@ -689,9 +835,12 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                  "Axis should be in range [0, x_dims)");
   auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
-
-  int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
+  int pre, n, post, mid_flag = 0;
+  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &mid_flag);
+  if (mid_flag) {
+    functor.RunMidRowWise(n, pre, post);
+    return;
+  }
   if (post == 1) {
     functor.RunRowWise(n, pre);
     return;
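Taken together, ElementwiseComputeEx now dispatches on (mid_flag, post). The sketch below replays that decision for the three shape examples from the header comment; dispatch() is a hypothetical standalone helper that mirrors get_mid_dims plus the branch order above, not Paddle API:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical helper: get_mid_dims logic followed by the dispatch order.
const char *dispatch(const std::vector<int> &x, const std::vector<int> &y,
                     int axis) {
  int pre = 1, n = 1, post = 1, mid_flag = 0, mid = 0;
  for (int i = 0; i < axis; ++i) pre *= x[i];
  for (int i = 0; i < static_cast<int>(y.size()); ++i) {
    if (x[i + axis] != y[i]) {
      for (int j = 0; j < i; ++j) pre *= y[j];
      n = std::max(x[i + axis], y[i]);
      mid_flag = 1;
      mid = i;
      break;
    }
    n *= y[i];
  }
  if (mid_flag) {
    for (int i = mid + 1; i < static_cast<int>(x.size()); ++i) post *= x[i];
    return "RunMidRowWise";  // the new m*n*k & m*1*k path
  }
  for (int i = axis + static_cast<int>(y.size());
       i < static_cast<int>(x.size()); ++i)
    post *= x[i];
  return post == 1 ? "RunRowWise" : "RunMidWise";
}

int main() {
  // Example 3 -> RunMidRowWise; example 2 -> RunRowWise;
  // example 1 -> RunMidWise.
  std::printf("%s\n", dispatch({2, 3, 4, 5}, {2, 1, 4, 5}, 0));
  std::printf("%s\n", dispatch({2, 3, 4, 5}, {4, 5}, 2));
  std::printf("%s\n", dispatch({2, 3, 4, 5}, {3, 4}, 1));
  return 0;
}
```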