You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h

1815 lines
64 KiB

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <iterator>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
/*
* Out = X ⊙ Y
* If Y's shape does not match X's shape, they will be reshaped.
* For example:
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*
* New parameter: *mid_flag* is added to solve m*n*k & m*1*k
* broadcast cases.
* 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
* mid_flag should not be NULL.
* x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
*/
// Collapses the broadcast of `y_dims` onto `x_dims` (aligned at x's dimension
// `axis`) into three extents written through `pre`, `n` and `post`.
//
// With `mid_flag == nullptr` (legacy path used by the fused
// elementwise+activation ops) every y dimension must equal the x dimension it
// is aligned with:
//   pre  = product of x_dims[0 .. axis)
//   n    = product of y_dims
//   post = product of x_dims[axis + y_dims.size() ..)
//
// With a non-null `mid_flag`, the first y dimension that differs from x must
// be 1. In that case `*mid_flag` is set to 1 and, with `mid` being that
// index in y_dims:
//   pre  = product of x_dims[0 .. axis + mid)
//   n    = x_dims[axis + mid]
//   post = product of x_dims[axis + mid + 1 ..)
// Otherwise `*mid_flag` stays 0 and the result matches the legacy path.
inline void get_mid_dims(const framework::DDim &x_dims,
                         const framework::DDim &y_dims, const int axis,
                         int *pre, int *n, int *post,
                         int *mid_flag = nullptr) {
  *pre = 1;
  *n = 1;
  *post = 1;
  if (mid_flag != nullptr) {
    *mid_flag = 0;
    int mid = 0;
    for (int i = 0; i < axis; ++i) {
      (*pre) *= x_dims[i];
    }
    for (int i = 0; i < y_dims.size(); ++i) {
      if (x_dims[i + axis] != y_dims[i]) {
        // only support single y_dims[i] = 1 now.
        PADDLE_ENFORCE_EQ(*mid_flag, 0,
                          "Broadcast support y_dims with single 1.");
        PADDLE_ENFORCE_EQ(y_dims[i], 1, "Broadcast dimension mismatch.");
        // m*n*k m*1*k: the leading matching y dimensions join `pre`.
        for (int j = 0; j < i; ++j) {
          (*pre) *= y_dims[j];
        }
        *n = std::max(x_dims[i + axis], y_dims[i]);
        *mid_flag = 1;
        mid = i;
        break;
      }
      (*n) *= y_dims[i];
    }
    if (*mid_flag) {
      // `mid` indexes y_dims; the x dimension it is aligned with is
      // `axis + mid`, so the trailing extent starts one past that. (Starting
      // at `mid + 1` is only correct when axis == 0.)
      for (int i = axis + mid + 1; i < x_dims.size(); ++i) {
        (*post) *= x_dims[i];
      }
    } else {
      for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
        (*post) *= x_dims[i];
      }
    }
  } else {  // for fused_elementwise_activation_op. keep the old version.
    for (int i = 0; i < axis; ++i) {
      (*pre) *= x_dims[i];
    }
    for (int i = 0; i < y_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                        "Broadcast dimension mismatch.");
      (*n) *= y_dims[i];
    }
    for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
      (*post) *= x_dims[i];
    }
  }
}
// Returns a copy of `dims` with every trailing dimension of extent 1 dropped.
// If nothing remains, a DDim built from an empty framework::make_dim() is
// returned.
inline framework::DDim trim_trailing_singular_dims(
    const framework::DDim &dims) {
  // `kept` ends up one past the last dimension whose extent is not 1.
  auto kept = dims.size();
  while (kept != 0 && dims[kept - 1] == 1) {
    --kept;
  }
  if (kept == 0) {
    return framework::DDim(framework::make_dim());
  }
  std::vector<int> kept_dims;
  kept_dims.resize(kept);
  for (int d = 0; d < kept; ++d) {
    kept_dims[d] = dims[d];
  }
  return framework::make_ddim(kept_dims);
}
// Iterator over the broadcast operand that wraps around every `n` elements.
// Only declared here; it is specialized per device context below.
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
// Iterator over the broadcast operand that holds each element for `post`
// steps and wraps around after `n` elements. Only declared here; it is
// specialized per device context below.
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
// NOTE(dzhwinter): std::iterator (used as a base class below) is deprecated in C++17.
// CPU specialization: walks `ptr[0 .. n)` cyclically, so a length-n operand
// can be paired element by element with a longer range.
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
                           T *, T &> {
 public:
  // `ptr` is the first element of the operand, `n` the wrap-around period.
  RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

  // Steps to the next element, wrapping back to index 0 after n - 1.
  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
    ++i_;
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
    return *this;
  }

  // NOTE: advances THIS iterator by `n` steps in place and returns *this; it
  // does not produce a new iterator like a conventional operator+.
  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
    while (n-- > 0) {
      ++i_;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  // Two iterators compare equal when they currently refer to the same
  // address.
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // const-qualified so that the comparisons above, which dereference a const
  // `rhs`, are well-formed.
  const T &operator*() const { return ptr_[i_]; }

 private:
  const T *ptr_;
  int i_;
  int64_t n_;
};
// CPU specialization: yields ptr[0] for `post` consecutive steps, then ptr[1]
// for `post` steps, ..., and wraps back to ptr[0] after ptr[n - 1].
template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
                           T *, T &> {
 public:
  // `ptr` is the first element of the operand, `n` its length and `post` the
  // number of steps each element is repeated for.
  MidWiseTransformIterator(const T *ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

  // Advances the inner repeat counter; every `post` steps moves on to the
  // next element, wrapping to index 0 after n - 1.
  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
    ++j_;
    if (UNLIKELY(j_ == post_)) {
      ++i_;
      j_ = 0;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  // NOTE: advances THIS iterator by `n` steps in place and returns *this; it
  // does not produce a new iterator like a conventional operator+.
  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
    while (n-- > 0) {
      ++j_;
      if (UNLIKELY(j_ == post_)) {
        ++i_;
        j_ = 0;
        if (UNLIKELY(i_ == n_)) {
          i_ = 0;
        }
      }
    }
    return *this;
  }

  // Compares only the element currently referred to; the repeat counter j_
  // is not part of the comparison.
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // const-qualified so that the comparisons above, which dereference a const
  // `rhs`, are well-formed.
  const T &operator*() const { return ptr_[i_]; }

 private:
  const T *ptr_;
  int64_t i_;
  int64_t j_;
  int64_t n_;
  int64_t post_;
};
#ifdef __NVCC__
// CUDA specialization built on thrust::iterator_adaptor: dereferencing maps
// the distance travelled from `begin_` back into [0, n_) with a modulo.
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
 public:
  using super_t = thrust::iterator_adaptor<
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *>;

  // `x` is the first element of the operand, `n` the wrap-around period.
  HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
      : super_t(x), n_(n), begin_(x) {}

  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T *begin_;

  // Hook used through thrust::iterator_core_access.
  HOSTDEVICE typename super_t::reference dereference() const {
    const auto travelled = this->base() - begin_;
    return *(begin_ + travelled % n_);
  }
};
// CUDA specialization built on thrust::iterator_adaptor: dereferencing maps
// the distance travelled from `begin_` to element (distance / post_) % n_.
template <typename T>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
 public:
  using super_t = thrust::iterator_adaptor<
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *>;

  // `x` is the first element of the operand, `n` its length and `post` the
  // number of consecutive positions that share one element.
  HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
      : super_t(x), post_(post), n_(n), begin_(x) {}

  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T *begin_;

  // Hook used through thrust::iterator_core_access.
  HOSTDEVICE typename super_t::reference dereference() const {
    const auto travelled = this->base() - begin_;
    return *(begin_ + ((travelled / post_) % n_));
  }
};
#endif
template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, const DeviceContext &ctx, Functor func)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<OutType>(ctx.GetPlace())),
nx_(x->numel()),
ctx_(ctx),
func_(func) {}
inline void Run() const {
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
}
inline void RunRowWise(int n, int pre) const {
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
z_, func_);
}
inline void RunMidWise(int n, int pre, int post) const {
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
}
inline void RunMidRowWise(int n, int pre, int post) const {
platform::Transform<DeviceContext> trans;
for (int i = 0; i < pre; i++) {
trans(ctx_, x_ + i * n * post, x_ + (i + 1) * n * post,
RowwiseTransformIterator<T, DeviceContext>(y_ + i * post, post),
z_ + i * n * post, func_);
}
}
private:
const T *x_;
const T *y_;
OutType *z_;
int64_t nx_;
const DeviceContext &ctx_;
Functor func_;
};
// Per-index gradient functor for the case where x, y, out and dout all have
// the same number of elements. The struct is brace-initialized, so the member
// order (x_, y_, out_, dout_, dx_op_, dy_op_, dx_, dy_) is part of its
// interface.
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  const T *x_;
  const T *y_;
  const T *out_;
  const T *dout_;

  // Writes the gradient(s) for element `i`; a null dx_ or dy_ means that
  // gradient is not requested and is skipped.
  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    if (dy_ != nullptr) dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
  T *dx_;
  T *dy_;
};
// CPU gradient for an (h, w) x combined with a length-w y broadcast over
// rows.
//   dx[i * w + j] = dx_op(x, y[j], out, dout) at that position
//   dy[j]         = sum over i of dy_op(x, y[j], out, dout)
// Either dx or dy may be null, in which case that gradient is skipped. The
// row sum is accumulated in increasing row order, starting from row 0.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T *dx, T *dy) {
  for (int row = 0; row < h; ++row) {
    const int row_start = row * w;
    for (int col = 0; col < w; ++col) {
      const int pos = row_start + col;
      if (dx != nullptr) {
        dx[pos] = dx_op(x[pos], y[col], out[pos], dout[pos]);
      }
      if (dy == nullptr) {
        continue;
      }
      T contribution = dy_op(x[pos], y[col], out[pos], dout[pos]);
      if (row == 0) {
        dy[col] = contribution;  // first row initializes the accumulator
      } else {
        dy[col] += contribution;
      }
    }
  }
}
#ifdef __NVCC__
// CUDA gradient kernel for an (h, w) x combined with a length-w y broadcast
// over rows. Block blockIdx.x handles column j; its threads visit rows
// threadIdx.x, threadIdx.x + ELEMWISE_MAX_BLOCK_DIM, ... of that column.
// dx (if non-null) is written per element; dy[j] (if non-null) receives the
// reduced sum of dy_op over the column.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
// Per-thread partial sum of the dy contributions.
T val(0);
// do/while: the first iteration runs without checking i < h, so the launch
// must not use more than h threads per block.
do {
int x_offset = i * w + j;
if (dx) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (dy) {
// Number of threads that hold a partial sum: min(h, ELEMWISE_MAX_BLOCK_DIM).
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
// NOTE(review): reduceSum is presumably a block-wide sum over the first h
// threads -- confirm against cuda_device_function.h.
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
}
// Thread-block extents used by the 2D kernel below (and by its launcher).
#define BLOCK_X 32
#define BLOCK_Y 32
// The 2D-block variant is expected to be faster because it exposes more
// parallelism and lets memory accesses coalesce.
//
// Same contract as ElemwiseGradBroadcast1CUDAKernel: x is (h, w), y has w
// elements; dx (if non-null) is written per element and dy[m] (if non-null)
// receives the sum of dy_op over column m.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
// Per-thread partial sums, indexed [threadIdx.y][threadIdx.x]; each row has
// one extra column (NOTE(review): presumably padding against shared-memory
// bank conflicts on the transposed read below -- confirm).
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
// NOTE(review): this `val` is never read; the `T val` declared inside the
// loop shadows it.
T val(0);
size_t width_stride = gridDim.x * blockDim.x;
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
// w and h rounded up to the next multiple of BLOCK_X / BLOCK_Y.
size_t full_width =
(w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
size_t full_height =
(h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
// The loops run to the rounded-up bounds; real work is guarded by
// `m < w && n < h` (presumably so that all threads of the block reach the
// __syncthreads() calls the same number of times -- confirm).
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int x_offset = n * w + m;
if (dx && m < w && n < h) {
dx[x_offset] = dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
}
if (dy) {
if (m < w && n < h) {
T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dy) {
// Transposed read: thread (x, y) picks up the partial sum stored by thread
// (y, x).
T my_val = sdata[threadIdx.x][threadIdx.y];
// XOR-shuffle loop with offsets warpSize/2, ..., 1 (NOTE(review): presumably
// leaves the warp-wide sum in every lane -- confirm CudaShuffleXorSync).
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
// Threads with threadIdx.y == 0 publish the totals stored in sdata[0][*].
if (threadIdx.y == 0 && m < w) {
dy[m] = sdata[0][threadIdx.x];
}
}
}
}
// Launches the row-broadcast gradient on `stream` for an (h, w) x and a
// length-w y. When either extent is below 16 the 1D kernel is used (one block
// per column); otherwise the 2D BLOCK_X x BLOCK_Y kernel is used.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
                                       const T *y, const T *out, const T *dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T *dx, T *dy) {
  constexpr int kSmallExtent = 16;
  const bool small_problem = (w < kSmallExtent || h < kSmallExtent);
  if (small_problem) {
    // 1D launch: w blocks of min(ELEMWISE_MAX_BLOCK_DIM, h) threads.
    const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
    const int blocks = w;
    ElemwiseGradBroadcast1CUDAKernel<<<blocks, threads, 0, stream>>>(
        x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
  } else {
    // 2D launch: enough BLOCK_X-wide blocks to cover the w columns.
    const dim3 threads = dim3(BLOCK_X, BLOCK_Y);
    const int blocks = (w + BLOCK_X - 1) / BLOCK_X;
    FastElemwiseGradBroadcast1CUDAKernel<<<blocks, threads, 0, stream>>>(
        x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
  }
}
#endif
// CPU gradient for a (pre, n, post) x combined with a length-n y that is
// broadcast over the first and last axes.
//   dx[offset] = dx_op(x, y[j], out, dout) at that offset
//   dy[j]      = sum over (i, k) of dy_op(x, y[j], out, dout)
// Either dx or dy may be null, in which case that gradient is skipped.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      for (int inner = 0; inner < post; ++inner) {
        const int pos = outer * n * post + mid * post + inner;
        if (dx != nullptr) {
          dx[pos] = dx_op(x[pos], y[mid], out[pos], dout[pos]);
        }
        if (dy == nullptr) {
          continue;
        }
        T contribution = dy_op(x[pos], y[mid], out[pos], dout[pos]);
        // The very first visit of y[mid] initializes the accumulator.
        const bool first_visit = (outer == 0 && inner == 0);
        if (first_visit) {
          dy[mid] = contribution;
        } else {
          dy[mid] += contribution;
        }
      }
    }
  }
}
#ifdef __NVCC__
// CUDA gradient kernel for a (pre, n, post) x combined with a length-n y.
// Block blockIdx.x handles index j of the middle axis; its threads walk the
// pre * post (i, k) pairs in steps of ELEMWISE_MAX_BLOCK_DIM. dx (if
// non-null) is written per element; dy[j] (if non-null) receives the reduced
// sum of dy_op over all (i, k).
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
// Per-thread partial sum of the dy contributions.
T val(0);
// Flattened (i, k) index handled by this thread in the current iteration.
int ttid = tid;
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (dy) {
// Number of threads that hold a partial sum:
// min(pre * post, ELEMWISE_MAX_BLOCK_DIM).
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
// NOTE(review): reduceSum is presumably a block-wide sum over the first h
// threads -- confirm against cuda_device_function.h.
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
}
// Launches ElemwiseGradBroadcast2CUDAKernel on `stream`: one block per index
// of the middle axis (n blocks), each with
// min(ELEMWISE_MAX_BLOCK_DIM, pre * post) threads.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
                                       const T *y, const T *out, const T *dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T *dx, T *dy) {
  const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  const int blocks = n;
  ElemwiseGradBroadcast2CUDAKernel<<<blocks, threads, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif
// CPU gradient for a (pre, n, post) x combined with a (pre, 1, post) y that
// is broadcast along the middle axis.
//   dx[x_pos] = dx_op(x, y[i * post + k], out, dout) at that position
//   dy[i * post + k] = sum over j of dy_op(x, y[i * post + k], out, dout)
// Either dx or dy may be null, in which case that gradient is skipped.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CPU(const T *x, const T *y, const T *out,
                                         const T *dout, int pre, int n,
                                         int post, DX_OP dx_op, DY_OP dy_op,
                                         T *dx, T *dy) {
  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      for (int inner = 0; inner < post; ++inner) {
        const int x_pos = outer * n * post + mid * post + inner;
        const int y_pos = outer * post + inner;
        if (dx != nullptr) {
          dx[x_pos] = dx_op(x[x_pos], y[y_pos], out[x_pos], dout[x_pos]);
        }
        if (dy == nullptr) {
          continue;
        }
        T contribution = dy_op(x[x_pos], y[y_pos], out[x_pos], dout[x_pos]);
        if (mid == 0) {
          dy[y_pos] = contribution;  // first middle index initializes
        } else {
          dy[y_pos] += contribution;
        }
      }
    }
  }
}
#ifdef __NVCC__
// CUDA gradient kernel for a (pre, n, post) x combined with a (pre, 1, post)
// y that is broadcast along the middle axis.
//
// Block blockIdx.x owns one y element, i.e. one (i, k) pair with
// y_offset = i * post + k == blockIdx.x. The threads of the block stride over
// the middle axis j, so that
//   dx[x_offset] is written for every j (if dx is non-null), and
//   dy[y_offset] receives the sum of dy_op over all j (if dy is non-null).
//
// The stride is over j rather than over the block index: advancing the block
// index would fold the contributions of other y elements into this block's
// sum once pre * post exceeds ELEMWISE_MAX_BLOCK_DIM, and would leave rows
// j >= ELEMWISE_MAX_BLOCK_DIM unvisited when n exceeds the block size.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcastMid2CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  int tid = threadIdx.x;
  int y_offset = blockIdx.x;
  int i = y_offset / post;
  int k = y_offset % post;
  // Per-thread partial sum of the dy contributions.
  T val(0);
  if (i < pre) {
    for (int j = tid; j < n; j += blockDim.x) {
      int x_offset = i * n * post + j * post + k;
      if (dx != nullptr) {
        dx[x_offset] =
            dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
      }
      if (dy != nullptr) {
        val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
      }
    }
  }
  if (dy) {
    // Number of threads that hold a partial sum:
    // min(n, ELEMWISE_MAX_BLOCK_DIM). Every thread of the block takes part in
    // the reduction call.
    int h = n;
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    val = paddle::platform::reduceSum(val, tid, h);
    if (tid == 0 && i < pre) {
      dy[y_offset] = val;
    }
  }
}
// Launches ElemwiseGradBroadcastMid2CUDAKernel on `stream`: pre * post
// blocks, each with min(ELEMWISE_MAX_BLOCK_DIM, n) threads.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CUDA(cudaStream_t stream, const T *x,
                                          const T *y, const T *out,
                                          const T *dout, int pre, int n,
                                          int post, DX_OP dx_op, DY_OP dy_op,
                                          T *dx, T *dy) {
  const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, n);
  const int blocks = pre * post;
  ElemwiseGradBroadcastMid2CUDAKernel<<<blocks, threads, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif
// Gradient for the case without broadcasting: runs ElemwiseGradNoBroadcast
// over product(x_dim) indices via platform::ForRange. `y_dim` and `axis` are
// not used here. A null dx or dy means that gradient is not requested;
// otherwise its buffer is allocated on the context's place.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  size_t numel = static_cast<size_t>(framework::product(x_dim));
#if !defined(_WIN32)
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), numel);
#else
  platform::ForRange<DeviceContext> for_range(
      ctx.device_context<DeviceContext>(), numel);
#endif  // !_WIN32
  // Member order of the aggregate: x, y, out, dout, dx_op, dy_op, dx, dy.
  ElemwiseGradNoBroadcast<T, DX_OP, DY_OP> per_index{
      x.data<T>(),
      y.data<T>(),
      out.data<T>(),
      dout.data<T>(),
      dx_op,
      dy_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())};
  for_range(per_index);
}
// Gradient for the broadcasting case. Resolves `axis`, trims y's trailing
// 1-dimensions, collapses the shapes with get_mid_dims and dispatches to:
//   * the Mid2 routines when a middle-axis broadcast was detected,
//   * the Broadcast1 routines when post == 1,
//   * the Broadcast2 routines otherwise,
// choosing the CUDA or CPU flavour from the context's place. The CUDA calls
// are only compiled under __NVCC__. A null dx or dy means that gradient is
// not requested.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  // axis == -1 aligns y with the trailing dimensions of x.
  if (axis == -1) {
    axis = x_dim.size() - y_dim_untrimed.size();
  }
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  if (y_dim.size() == 0) {
    axis = x_dim.size();
  }

  int pre, n, post, mid_flag = 0;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &mid_flag);

  // Output buffers are only materialized by the branch that actually runs.
  auto dx_data = [&]() -> T * {
    return dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  };
  auto dy_data = [&]() -> T * {
    return dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
  };

  if (mid_flag) {
    PADDLE_ENFORCE_EQ(mid_flag, 1, "mid_flag should be no more than 1.");
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcastMid2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx_data(), dy_data());
#endif
    } else {
      ElemwiseGradBroadcastMid2CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                                   dout.data<T>(), pre, n, post, dx_op, dy_op,
                                   dx_data(), dy_data());
    }
    return;
  }

  if (post == 1) {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, dx_op, dy_op,
          dx_data(), dy_data());
#endif
    } else {
      ElemwiseGradBroadcast1CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                                dout.data<T>(), pre, n, dx_op, dy_op,
                                dx_data(), dy_data());
    }
    return;
  }

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    ElemwiseGradBroadcast2CUDA(
        ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
        y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
        dy_op, dx_data(), dy_data());
#endif
  } else {
    ElemwiseGradBroadcast2CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                              dout.data<T>(), pre, n, post, dx_op, dy_op,
                              dx_data(), dy_data());
  }
}
// Entry point for element-wise gradients: uses the no-broadcast path when x
// and y have identical dims and the broadcasting path otherwise.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                         const framework::Tensor &x, const framework::Tensor &y,
                         const framework::Tensor &out,
                         const framework::Tensor &dout, int axis,
                         framework::Tensor *dx, framework::Tensor *dy,
                         DX_OP dx_op, DY_OP dy_op) {
  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
  const bool same_shape = (x.dims() == y.dims());
  if (same_shape) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }
  // Shapes differ: y has to be broadcast against x.
  ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
      ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
// Gradient entry point that derives shapes from dout and dy instead of x and
// y. The x shape is taken from dout.dims(); the y shape from dy->dims(), or
// from dout.dims() when dy is null. Identical shapes go through the
// no-broadcast path, anything else through the broadcasting path.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
                                 const framework::Tensor &x,
                                 const framework::Tensor &y,
                                 const framework::Tensor &out,
                                 const framework::Tensor &dout, int axis,
                                 framework::Tensor *dx, framework::Tensor *dy,
                                 DX_OP dx_op, DY_OP dy_op) {
  if (dy == nullptr) {
    // No dy requested: both shapes are dout's shape.
    const framework::DDim &x_shape = dout.dims();
    auto y_shape = x_shape;
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_shape, y_shape, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }

  const framework::DDim &y_shape = dy->dims();
  if (dout.dims() == dy->dims()) {
    const framework::DDim &x_shape = dout.dims();
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_shape, y_shape, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }

  // dy's shape differs from dout's: dy is reduced over the broadcast axes.
  auto x_shape = dout.dims();
  ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
      ctx, x_shape, y_shape, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
// Deprecated
// Deprecated Eigen-based gradient entry point.
//
// Allocates dx/dy (when non-null) on the context's place and then dispatches
// to one of three caller-supplied functors:
//   * `functor`           when x and y have identical dims,
//   * `broadcastfunctor`  when the collapsed shape has post == 1,
//   * `broadcast2functor` otherwise.
template <typename DeviceContext, typename T, typename functor,
          typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
                            const framework::Tensor *x,
                            const framework::Tensor *y,
                            const framework::Tensor *out,
                            const framework::Tensor *dout, int axis,
                            framework::Tensor *dx, framework::Tensor *dy) {
  auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }
  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }
  // axis == -1 aligns y with the trailing dimensions of x; it is resolved
  // against the untrimmed rank of y.
  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  // trim_trailing_singular_dims returns the trimmed shape instead of
  // modifying its argument, so the result has to be assigned back.
  y_dims = trim_trailing_singular_dims(y_dims);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
// Forward element-wise computation z = func(x, y) with optional broadcasting
// of y. Requires rank(x) >= rank(y). When the shapes are identical a plain
// transform is used; otherwise `axis` is resolved (-1 aligns y with x's
// trailing dimensions), y's trailing 1-dimensions are trimmed, and the
// collapsed (pre, n, post) shape selects the mid-row-wise, row-wise or
// mid-wise transform.
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                          const framework::Tensor *x,
                          const framework::Tensor *y, int axis, Functor func,
                          framework::Tensor *z) {
  TransformFunctor<Functor, T, DeviceContext, OutType> transformer(
      x, y, z, ctx.template device_context<DeviceContext>(), func);

  auto x_dims = x->dims();
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims_untrimed) {
    transformer.Run();
    return;
  }

  if (axis == -1) {
    axis = x_dims.size() - y_dims_untrimed.size();
  }
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");

  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
  if (y_dims.size() == 0) {
    axis = x_dims.size();
  }

  int pre, n, post, mid_flag = 0;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &mid_flag);

  if (mid_flag) {
    transformer.RunMidRowWise(n, pre, post);
  } else if (post == 1) {
    transformer.RunRowWise(n, pre);
  } else {
    transformer.RunMidWise(n, pre, post);
  }
}
// FusedElemwiseAndAct
// --- forward
// Per-index forward functor of the fused elementwise+activation op for
// operands of identical size. The struct is brace-initialized, so the member
// order (x_, y_, compound_functor_, out_, intermediate_out_) is part of its
// interface.
template <typename T, typename CompoundFunctor, bool KeepIntermediateOut>
struct FusedElemwiseAndActNoBroadcast {
  // Computes out_[i]; when KeepIntermediateOut is set the intermediate value
  // is also stored to intermediate_out_[i] and reused for the final result.
  HOSTDEVICE void operator()(size_t i) {
    T rhs = y_[i];
    T lhs = x_[i];
    if (!KeepIntermediateOut) {
      out_[i] = compound_functor_.GetOut(lhs, rhs);
      return;
    }
    T inner = compound_functor_.GetIntermediateOut(lhs, rhs);
    intermediate_out_[i] = inner;
    out_[i] = compound_functor_.GetOutUseIntermediateOut(lhs, inner);
  }

  const T *x_;
  const T *y_;
  CompoundFunctor compound_functor_;
  T *out_;
  T *intermediate_out_;
};
// FusedElemwiseAndActBroadcast1:
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2,
// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20)
// CPU forward pass of the fused op for an (h, w) full-size operand and a
// length-w operand broadcast over rows. BcastY selects which operand is the
// broadcast one (true: y, false: x). With KeepIntermediateOut the
// intermediate value is stored too: at the full offset when
// SameShapeOfIntermediateOutAndOut is set or x is the broadcast operand,
// and at the column index otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y,
                                             CompoundFunctor compound_functor,
                                             int h, int w, T *out,
                                             T *intermediate_out) {
  for (int row = 0; row < h; ++row) {
    for (int col = 0; col < w; ++col) {
      const int pos = row * w + col;
      T rhs = BcastY ? y[col] : y[pos];
      T lhs = BcastY ? x[pos] : x[col];
      if (!KeepIntermediateOut) {
        out[pos] = compound_functor.GetOut(lhs, rhs);
        continue;
      }
      T inner = compound_functor.GetIntermediateOut(lhs, rhs);
      // Only the "different shape + broadcast y" combination indexes the
      // intermediate output by column.
      int64_t inner_pos = pos;
      if (!SameShapeOfIntermediateOutAndOut && BcastY) {
        inner_pos = col;
      }
      intermediate_out[inner_pos] = inner;
      out[pos] = compound_functor.GetOutUseIntermediateOut(lhs, inner);
    }
  }
}
// FusedElemwiseAndActBroadcast2
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1,
// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1)
// pre = 2, n = 12, post = 5
// CPU forward pass of the fused op for a (pre, n, post) full-size operand
// and a length-n operand broadcast over the first and last axes. BcastY
// selects which operand is the broadcast one (true: y, false: x). With
// KeepIntermediateOut the intermediate value is stored too: at the full
// offset when SameShapeOfIntermediateOutAndOut is set or x is the broadcast
// operand, and at the middle index otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre,
                                             int n, int post,
                                             CompoundFunctor compound_functor,
                                             T *out, T *intermediate_out) {
  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      for (int inner_idx = 0; inner_idx < post; ++inner_idx) {
        const int pos = outer * n * post + mid * post + inner_idx;
        T rhs = BcastY ? y[mid] : y[pos];
        T lhs = BcastY ? x[pos] : x[mid];
        if (!KeepIntermediateOut) {
          out[pos] = compound_functor.GetOut(lhs, rhs);
          continue;
        }
        T inner = compound_functor.GetIntermediateOut(lhs, rhs);
        // Only the "different shape + broadcast y" combination indexes the
        // intermediate output by the middle index.
        int64_t inner_pos = pos;
        if (!SameShapeOfIntermediateOutAndOut && BcastY) {
          inner_pos = mid;
        }
        intermediate_out[inner_pos] = inner;
        out[pos] = compound_functor.GetOutUseIntermediateOut(lhs, inner);
      }
    }
  }
}
#ifdef __NVCC__
// CUDA forward kernel of the fused op for an (h, w) full-size operand and a
// length-w broadcast operand. Block blockIdx.x handles one column; its
// threads visit rows threadIdx.x, threadIdx.x + ELEMWISE_MAX_BLOCK_DIM, ...
// The template flags have the same meaning as in
// FusedElemwiseAndActBroadcast1CPU.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
    const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
    T *out, T *intermediate_out) {
  int col = blockIdx.x;
  for (int row = threadIdx.x; row < h; row += ELEMWISE_MAX_BLOCK_DIM) {
    int pos = row * w + col;
    T rhs = BcastY ? y[col] : y[pos];
    T lhs = BcastY ? x[pos] : x[col];
    if (!KeepIntermediateOut) {
      out[pos] = compound_functor.GetOut(lhs, rhs);
      continue;
    }
    T inner = compound_functor.GetIntermediateOut(lhs, rhs);
    // Only the "different shape + broadcast y" combination indexes the
    // intermediate output by column.
    int64_t inner_pos = pos;
    if (!SameShapeOfIntermediateOutAndOut && BcastY) {
      inner_pos = col;
    }
    intermediate_out[inner_pos] = inner;
    out[pos] = compound_functor.GetOutUseIntermediateOut(lhs, inner);
  }
}
// Launches FusedElemwiseAndActBroadcast1CUDAKernel on `stream`: one block per
// column (w blocks), each with min(ELEMWISE_MAX_BLOCK_DIM, h) threads.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CUDA(cudaStream_t stream, const T *x,
                                              const T *y,
                                              CompoundFunctor compound_functor,
                                              int h, int w, T *out,
                                              T *intermediate_out) {
  const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  const int blocks = w;
  FusedElemwiseAndActBroadcast1CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<blocks, threads, 0, stream>>>(
      x, y, h, w, compound_functor, out, intermediate_out);
}
// CUDA forward kernel of the fused op for a (pre, n, post) full-size operand
// and a length-n broadcast operand. Block blockIdx.x handles one index of the
// middle axis; its threads walk the pre * post (i, k) pairs in steps of
// ELEMWISE_MAX_BLOCK_DIM. The template flags have the same meaning as in
// FusedElemwiseAndActBroadcast2CPU.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel(
    const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n,
    int post, T *out, T *intermediate_out) {
  int mid = blockIdx.x;
  // `flat` enumerates the (i, k) pairs; the loop ends once i reaches pre.
  for (int flat = threadIdx.x; flat / post < pre;
       flat += ELEMWISE_MAX_BLOCK_DIM) {
    int i = flat / post;
    int k = flat % post;
    int pos = i * n * post + mid * post + k;
    T rhs = BcastY ? y[mid] : y[pos];
    T lhs = BcastY ? x[pos] : x[mid];
    if (!KeepIntermediateOut) {
      out[pos] = compound_functor.GetOut(lhs, rhs);
      continue;
    }
    T inner = compound_functor.GetIntermediateOut(lhs, rhs);
    // Only the "different shape + broadcast y" combination indexes the
    // intermediate output by the middle index.
    int64_t inner_pos = pos;
    if (!SameShapeOfIntermediateOutAndOut && BcastY) {
      inner_pos = mid;
    }
    intermediate_out[inner_pos] = inner;
    out[pos] = compound_functor.GetOutUseIntermediateOut(lhs, inner);
  }
}
// Launches FusedElemwiseAndActBroadcast2CUDAKernel on `stream`: one block per
// index of the middle axis (n blocks), each with
// min(ELEMWISE_MAX_BLOCK_DIM, pre * post) threads.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CUDA(cudaStream_t stream, const T *x,
                                              const T *y, int pre, int n,
                                              int post,
                                              CompoundFunctor compound_functor,
                                              T *out, T *intermediate_out) {
  const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  const int blocks = n;
  FusedElemwiseAndActBroadcast2CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<blocks, threads, 0, stream>>>(
      x, y, compound_functor, pre, n, post, out, intermediate_out);
}
#endif
// Forward pass of the fused op without broadcasting: runs
// FusedElemwiseAndActNoBroadcast over product(x_dim) indices via
// platform::ForRange. `out` is always allocated on the context's place;
// `intermediate_out` only when it is non-null.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut>
void FusedElemwiseAndActComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::Tensor &x, const framework::Tensor &y,
    CompoundFunctor compound_functor, framework::Tensor *out,
    framework::Tensor *intermediate_out) {
  size_t numel = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), numel);
  // Member order of the aggregate: x, y, functor, out, intermediate_out.
  FusedElemwiseAndActNoBroadcast<T, CompoundFunctor, KeepIntermediateOut>
      per_index{x.data<T>(), y.data<T>(), compound_functor,
                out->mutable_data<T>(ctx.GetPlace()),
                intermediate_out == nullptr
                    ? nullptr
                    : intermediate_out->mutable_data<T>(ctx.GetPlace())};
  for_range(per_index);
}
// Forward pass of the fused op with broadcasting. Resolves `axis`, trims the
// trailing 1-dimensions of y's shape, collapses the shapes with the legacy
// get_mid_dims path (no mid_flag) and dispatches to the Broadcast1 routines
// when post == 1 and to the Broadcast2 routines otherwise, choosing CUDA or
// CPU from the context's place. The CUDA calls are only compiled under
// __NVCC__.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool BcastY, bool KeepIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, CompoundFunctor compound_functor, int axis,
    framework::Tensor *out, framework::Tensor *intermediate_out) {
  // axis == -1 aligns y with the trailing dimensions of x.
  if (axis == -1) {
    axis = x_dim.size() - y_dim_untrimed.size();
  }
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  if (y_dim.size() == 0) {
    axis = x_dim.size();
  }

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  // Output buffers are only materialized by the branch that actually runs.
  auto out_data = [&]() -> T * {
    return out->mutable_data<T>(ctx.GetPlace());
  };
  auto intermediate_data = [&]() -> T * {
    return intermediate_out == nullptr
               ? nullptr
               : intermediate_out->mutable_data<T>(ctx.GetPlace());
  };

  if (post == 1) {
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      FusedElemwiseAndActBroadcast1CUDA<T, CompoundFunctor, BcastY,
                                        KeepIntermediateOut,
                                        SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), compound_functor, h, w, out_data(),
          intermediate_data());
#endif
    } else {
      FusedElemwiseAndActBroadcast1CPU<T, CompoundFunctor, BcastY,
                                       KeepIntermediateOut,
                                       SameShapeOfIntermediateOutAndOut>(
          x.data<T>(), y.data<T>(), compound_functor, h, w, out_data(),
          intermediate_data());
    }
    return;
  }

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    FusedElemwiseAndActBroadcast2CUDA<T, CompoundFunctor, BcastY,
                                      KeepIntermediateOut,
                                      SameShapeOfIntermediateOutAndOut>(
        ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
        y.data<T>(), pre, n, post, compound_functor, out_data(),
        intermediate_data());
#endif
  } else {
    FusedElemwiseAndActBroadcast2CPU<T, CompoundFunctor, BcastY,
                                     KeepIntermediateOut,
                                     SameShapeOfIntermediateOutAndOut>(
        x.data<T>(), y.data<T>(), pre, n, post, compound_functor, out_data(),
        intermediate_data());
  }
}
// --- backward
/*
 * Per-element backward functor for the non-broadcast case.
 *
 * For index i it loads x, y, out and dout, obtains the intermediate value
 * (read from intermediate_out_ when UseIntermediateOut is true, otherwise
 * recomputed through dx_op_.GetIntermediateOut), and writes whichever of
 * dx_, dy_ and dintermediate_ is non-null.
 *
 * The members are filled by aggregate initialisation, so their declaration
 * order must not change.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
    T xi = x_[i];
    T yi = y_[i];
    T out_i = out_[i];
    T dout_i = dout_[i];
    // Only touch intermediate_out_ when the caller promised it is available.
    T inter_i = UseIntermediateOut ? intermediate_out_[i]
                                   : dx_op_.GetIntermediateOut(xi, yi);

    if (dx_) {
      dx_[i] = dx_op_.UseIntermediateOut(xi, yi, inter_i, out_i, dout_i);
    }
    if (dy_) {
      dy_[i] = dy_op_.UseIntermediateOut(xi, yi, inter_i, out_i, dout_i);
    }
    if (dintermediate_) {
      dintermediate_[i] =
          dintermediate_op_.UseIntermediateOut(xi, inter_i, out_i, dout_i);
    }
  }

  // Inputs of the forward pass and the incoming gradient.
  const T *x_;
  const T *y_;
  const T *intermediate_out_;
  const T *out_;
  const T *dout_;
  // Gradient functors.
  DX_OP dx_op_;
  DY_OP dy_op_;
  DIntermediate_OP dintermediate_op_;
  // Outputs; any of them may be nullptr to skip that gradient.
  T *dx_;
  T *dy_;
  T *dintermediate_;
};
/*
 * Backward pass for the case where no broadcasting is needed.
 *
 * Runs FusedElemwiseAndActGradNoBroadcast over product(x_dim) elements via
 * platform::ForRange. `intermediate_out`, `dx`, `dy` and `dintermediate` may
 * each be nullptr and are then forwarded as null pointers. `y_dim` and `axis`
 * are accepted but not used by this function.
 */
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          typename DIntermediate_OP, bool UseIntermediateOut>
void FusedElemwiseAndActGradComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
    framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
  const size_t numel = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> run_over_elements(
      ctx.template device_context<DeviceContext>(), numel);

  // Read-only buffers.
  const T *x_data = x->data<T>();
  const T *y_data = y->data<T>();
  const T *intermediate_data = nullptr;
  if (intermediate_out) {
    intermediate_data = intermediate_out->data<T>();
  }
  const T *out_data = out->data<T>();
  const T *dout_data = dout->data<T>();

  // Gradient buffers; a missing tensor becomes a null pointer.
  T *dx_data = nullptr;
  if (dx != nullptr) {
    dx_data = dx->mutable_data<T>(ctx.GetPlace());
  }
  T *dy_data = nullptr;
  if (dy != nullptr) {
    dy_data = dy->mutable_data<T>(ctx.GetPlace());
  }
  T *dintermediate_data = nullptr;
  if (dintermediate != nullptr) {
    dintermediate_data = dintermediate->mutable_data<T>(ctx.GetPlace());
  }

  run_over_elements(
      FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, DIntermediate_OP,
                                         UseIntermediateOut>{
          x_data, y_data, intermediate_data, out_data, dout_data, dx_op, dy_op,
          dintermediate_op, dx_data, dy_data, dintermediate_data});
}
/*
 * CPU backward pass over an (h, w) view where one operand has only `w`
 * elements and is broadcast along the h axis.
 *
 * For every (i, j) with offset = i * w + j:
 *   - BcastY == true : x is indexed by offset, y by j.
 *   - BcastY == false: x is indexed by j, y by offset.
 *   - intermediate_out / d_intermediate use offset when
 *     SameShapeOfIntermediateOutAndOut is true, otherwise the same index as y.
 * out and dout are always indexed by offset.
 *
 * The gradient of the broadcast operand (dy when BcastY, dx otherwise) and a
 * differently-shaped d_intermediate are accumulated over i: the i == 0 visit
 * overwrites the slot, later visits add to it. dx, dy and d_intermediate may
 * each be nullptr to skip that gradient. intermediate_out is only read when
 * UseIntermediateOut is true.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  int64_t tmp_out_idx, x_idx, y_idx;
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      int offset = i * w + j;

      tmp_out_idx = BcastY ? j : offset;
      y_idx = BcastY ? j : offset;
      x_idx = BcastY ? offset : j;

      if (SameShapeOfIntermediateOutAndOut) {
        tmp_out_idx = offset;
      }

      if (dx != nullptr) {
        T tmp = UseIntermediateOut
                    ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                               intermediate_out[tmp_out_idx],
                                               out[offset], dout[offset])
                    : dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                      dout[offset]);

        if (BcastY) {
          dx[x_idx] = tmp;
        } else {
          // x is the broadcast operand: reduce over the h axis.
          if (i == 0) {
            dx[x_idx] = tmp;
          } else {
            dx[x_idx] += tmp;
          }
        }
      }
      if (dy != nullptr) {
        T tmp = UseIntermediateOut
                    ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                               intermediate_out[tmp_out_idx],
                                               out[offset], dout[offset])
                    : dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                      dout[offset]);
        if (BcastY) {
          // y is the broadcast operand: reduce over the h axis.
          if (i == 0) {
            dy[y_idx] = tmp;
          } else {
            dy[y_idx] += tmp;
          }
        } else {
          dy[y_idx] = tmp;
        }
      }
      if (d_intermediate != nullptr) {
        // dout has the shape of out, so it is indexed by offset in both
        // branches (the Recompute branch previously read dout[i]).
        T tmp = UseIntermediateOut
                    ? dintermediate_op.UseIntermediateOut(
                          x[x_idx], intermediate_out[tmp_out_idx], out[offset],
                          dout[offset])
                    : dintermediate_op.Recompute(x[x_idx], y[y_idx],
                                                 out[offset], dout[offset]);
        if (SameShapeOfIntermediateOutAndOut) {
          d_intermediate[tmp_out_idx] = tmp;
        } else {
          if (i == 0) {
            d_intermediate[tmp_out_idx] = tmp;
          } else {
            d_intermediate[tmp_out_idx] += tmp;
          }
        }
      }
    }
  }
}
/*
 * CPU backward pass over a (pre, n, post) view where one operand has only `n`
 * elements and is broadcast along the pre and post axes.
 *
 * For every (i, j, k) with offset = i * n * post + j * post + k:
 *   - BcastY == true : x is indexed by offset, y by j.
 *   - BcastY == false: x is indexed by j, y by offset.
 *   - intermediate_out / d_intermediate use offset when
 *     SameShapeOfIntermediateOutAndOut is true, otherwise the same index as y.
 * out and dout are always indexed by offset.
 *
 * The gradient of the broadcast operand (dy when BcastY, dx otherwise) is
 * accumulated over all (i, k): the (0, 0) visit overwrites the slot, every
 * later visit adds to it. dx, dy and d_intermediate may each be nullptr to
 * skip that gradient. intermediate_out is only read when UseIntermediateOut
 * is true.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  int64_t tmp_out_idx, x_idx, y_idx;
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        int offset = i * n * post + j * post + k;

        tmp_out_idx = BcastY ? j : offset;
        y_idx = BcastY ? j : offset;
        x_idx = BcastY ? offset : j;

        if (SameShapeOfIntermediateOutAndOut) {
          tmp_out_idx = offset;
        }

        if (dx != nullptr) {
          T tmp = UseIntermediateOut
                      ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                                 intermediate_out[tmp_out_idx],
                                                 out[offset], dout[offset])
                      : dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                        dout[offset]);

          if (BcastY) {
            dx[x_idx] = tmp;
          } else {
            // x is the broadcast operand: reduce over the pre and post axes.
            if (i == 0 && k == 0) {
              dx[x_idx] = tmp;
            } else {
              dx[x_idx] += tmp;
            }
          }
        }
        if (dy != nullptr) {
          T tmp = UseIntermediateOut
                      ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                                 intermediate_out[tmp_out_idx],
                                                 out[offset], dout[offset])
                      : dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                        dout[offset]);
          if (BcastY) {
            // y is the broadcast operand: reduce over the pre and post axes.
            if (i == 0 && k == 0) {
              dy[y_idx] = tmp;
            } else {
              dy[y_idx] += tmp;
            }
          } else {
            dy[y_idx] = tmp;
          }
        }
        if (d_intermediate != nullptr) {
          // dout has the shape of out, so it is indexed by offset in both
          // branches (the Recompute branch previously read dout[i]).
          T tmp = UseIntermediateOut
                      ? dintermediate_op.UseIntermediateOut(
                            x[x_idx], intermediate_out[tmp_out_idx],
                            out[offset], dout[offset])
                      : dintermediate_op.Recompute(x[x_idx], y[y_idx],
                                                   out[offset], dout[offset]);
          if (SameShapeOfIntermediateOutAndOut) {
            d_intermediate[tmp_out_idx] = tmp;
          } else {
            // With BcastY every (i, k) pair of a given j lands on slot j, so
            // only the very first visit may overwrite it; testing i == 0
            // alone dropped the k > 0 contributions of the first slab. The
            // non-BcastY layout keeps its historical i == 0 rule.
            const bool first_write = BcastY ? (i == 0 && k == 0) : (i == 0);
            if (first_write) {
              d_intermediate[tmp_out_idx] = tmp;
            } else {
              d_intermediate[tmp_out_idx] += tmp;
            }
          }
        }
      }
    }
  }
}
#ifdef __NVCC__
/*
 * CUDA backward kernel over an (h, w) view where one operand has only `w`
 * elements and is broadcast along the h axis.
 *
 * Block blockIdx.x handles column j; thread threadIdx.x starts at row
 * i = threadIdx.x and advances by ELEMWISE_MAX_BLOCK_DIM until i >= h.
 * Indexing matches the CPU version: with offset = i * w + j, x uses offset
 * and y uses j when BcastY is true (the other way round otherwise), and
 * intermediate_out / d_intermediate use offset when
 * SameShapeOfIntermediateOutAndOut is true, otherwise the same index as y.
 *
 * Per-thread partial sums for the broadcast operand's gradient (`val`) and
 * for a differently-shaped d_intermediate (`inter_val`) are combined with
 * paddle::platform::reduceSum and written by thread 0 to slot j. dx, dy and
 * d_intermediate may each be nullptr to skip that gradient.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  int j = blockIdx.x;
  int i = threadIdx.x;
  int tid = threadIdx.x;
  T val(0), inter_val(0);
  int64_t tmp_out_idx, x_idx, y_idx;

  do {
    int offset = i * w + j;

    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
      T tmp =
          UseIntermediateOut
              ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);

      if (BcastY) {
        dx[x_idx] = tmp;
      } else {
        val += tmp;
      }
    }
    if (dy != nullptr) {
      T tmp =
          UseIntermediateOut
              ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
      if (BcastY) {
        val += tmp;
      } else {
        dy[y_idx] = tmp;
      }
    }
    if (d_intermediate != nullptr) {
      // The first argument is the X value, as in the non-broadcast functor
      // and the CPU implementations (this kernel used to pass y[y_idx]).
      T tmp = UseIntermediateOut
                  ? dintermediate_op.UseIntermediateOut(
                        x[x_idx], intermediate_out[tmp_out_idx], out[offset],
                        dout[offset])
                  : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                               dout[offset]);
      if (SameShapeOfIntermediateOutAndOut) {
        d_intermediate[tmp_out_idx] = tmp;
      } else {
        inter_val += tmp;
      }
    }

    i += ELEMWISE_MAX_BLOCK_DIM;
  } while (i < h);

  // Number of threads that hold a partial sum.
  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
  if (BcastY) {
    if (dy) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
  if (!SameShapeOfIntermediateOutAndOut) {
    if (d_intermediate) {
      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
      if (threadIdx.x == 0) {
        d_intermediate[j] = inter_val;
      }
    }
  }
}
/*
 * Host-side launcher for FusedElemwiseAndActGradBroadcast1CUDAKernel.
 *
 * The kernel is enqueued on `stream` with `w` blocks, each holding
 * min(ELEMWISE_MAX_BLOCK_DIM, h) threads; all pointers, extents and functors
 * are forwarded unchanged.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(
    cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
    const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  // One block per column of the (h, w) view.
  const int num_blocks = w;
  // Never request more threads than there are rows.
  const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, h);

  FusedElemwiseAndActGradBroadcast1CUDAKernel<
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
      SameShapeOfIntermediateOutAndOut>
      <<<num_blocks, threads_per_block, 0, stream>>>(
          x, y, intermediate_out, out, dout, h, w, dx_op, dy_op,
          dintermediate_op, dx, dy, d_intermediate);
}
/*
 * CUDA backward kernel over a (pre, n, post) view where one operand has only
 * `n` elements and is broadcast along the pre and post axes.
 *
 * Block blockIdx.x handles middle index j. Each thread walks the flattened
 * (i, k) pairs starting at ttid = threadIdx.x in steps of
 * ELEMWISE_MAX_BLOCK_DIM, with i = ttid / post and k = ttid % post, and stops
 * once i >= pre. Indexing matches the CPU version: with
 * offset = i * n * post + j * post + k, x uses offset and y uses j when
 * BcastY is true (the other way round otherwise), and intermediate_out /
 * d_intermediate use offset when SameShapeOfIntermediateOutAndOut is true,
 * otherwise the same index as y.
 *
 * Per-thread partial sums for the broadcast operand's gradient (`val`) and
 * for a differently-shaped d_intermediate (`inter_val`) are combined with
 * paddle::platform::reduceSum and written by thread 0 to slot j. dx, dy and
 * d_intermediate may each be nullptr to skip that gradient.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  int tid = threadIdx.x;
  int j = blockIdx.x;

  T val(0), inter_val(0);
  int ttid = tid;
  int64_t tmp_out_idx, x_idx, y_idx;
  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int offset = i * n * post + j * post + k;

    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
      T tmp =
          UseIntermediateOut
              ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);

      if (BcastY) {
        dx[x_idx] = tmp;
      } else {
        val += tmp;
      }
    }
    if (dy != nullptr) {
      T tmp =
          UseIntermediateOut
              ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
      if (BcastY) {
        val += tmp;
      } else {
        dy[y_idx] = tmp;
      }
    }
    if (d_intermediate != nullptr) {
      // The first argument is the X value, as in the non-broadcast functor
      // and the CPU implementations (this kernel used to pass y[y_idx]).
      T tmp = UseIntermediateOut
                  ? dintermediate_op.UseIntermediateOut(
                        x[x_idx], intermediate_out[tmp_out_idx], out[offset],
                        dout[offset])
                  : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                               dout[offset]);
      if (SameShapeOfIntermediateOutAndOut) {
        d_intermediate[tmp_out_idx] = tmp;
      } else {
        inter_val += tmp;
      }
    }

    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

  // Number of threads that hold a partial sum.
  int h = pre * post;
  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
  if (BcastY) {
    if (dy) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
  if (!SameShapeOfIntermediateOutAndOut) {
    if (d_intermediate) {
      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
      if (threadIdx.x == 0) {
        d_intermediate[j] = inter_val;
      }
    }
  }
}
/*
 * Host-side launcher for FusedElemwiseAndActGradBroadcast2CUDAKernel.
 *
 * The kernel is enqueued on `stream` with `n` blocks, each holding
 * min(ELEMWISE_MAX_BLOCK_DIM, pre * post) threads; all pointers, extents and
 * functors are forwarded unchanged.
 */
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CUDA(
    cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
    const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
    DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
    T *dintermediate) {
  // One block per index along the middle (n) axis.
  const int num_blocks = n;
  // Never request more threads than there are (pre, post) pairs to visit.
  const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);

  FusedElemwiseAndActGradBroadcast2CUDAKernel<
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
      SameShapeOfIntermediateOutAndOut>
      <<<num_blocks, threads_per_block, 0, stream>>>(
          x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op,
          dintermediate_op, dx, dy, dintermediate);
}
#endif
/*
 * Backward pass for the broadcasting case.
 *
 * Resolves `axis` (-1 means x_dim.size() - y_dim_untrimed.size()), trims the
 * trailing singular dimensions of Y, and derives the (pre, n, post) extents
 * via get_mid_dims. When post == 1 the "Broadcast1" (pre x n) routines are
 * used, otherwise the "Broadcast2" (pre x n x post) routines. The CPU or CUDA
 * flavour is picked from the place of `ctx`; when the place is a GPU place
 * but the file is not compiled with __NVCC__, nothing is computed.
 *
 * `intermediate_out`, `dx`, `dy` and `dintermediate` may each be nullptr and
 * are then forwarded as null pointers.
 */
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          typename DIntermediate_OP, bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
    framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
  if (axis == -1) {
    axis = x_dim.size() - y_dim_untrimed.size();
  }
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  if (y_dim.size() == 0) {
    axis = x_dim.size();
  }

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
  const bool two_dim_view = (post == 1);

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    auto stream = ctx.template device_context<DeviceContext>().stream();
    const T *x_data = x->data<T>();
    const T *y_data = y->data<T>();
    const T *intermediate_data =
        intermediate_out == nullptr ? nullptr : intermediate_out->data<T>();
    const T *out_data = out->data<T>();
    const T *dout_data = dout->data<T>();
    T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
    T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
    T *dintermediate_data =
        dintermediate == nullptr
            ? nullptr
            : dintermediate->mutable_data<T>(ctx.GetPlace());
    if (two_dim_view) {
      FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          stream, x_data, y_data, intermediate_data, out_data, dout_data,
          pre /*h*/, n /*w*/, dx_op, dy_op, dintermediate_op, dx_data, dy_data,
          dintermediate_data);
    } else {
      FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          stream, x_data, y_data, intermediate_data, out_data, dout_data, pre,
          n, post, dx_op, dy_op, dintermediate_op, dx_data, dy_data,
          dintermediate_data);
    }
#endif
    // Without __NVCC__ there is no GPU implementation to call.
    return;
  }

  const T *x_data = x->data<T>();
  const T *y_data = y->data<T>();
  const T *intermediate_data =
      intermediate_out == nullptr ? nullptr : intermediate_out->data<T>();
  const T *out_data = out->data<T>();
  const T *dout_data = dout->data<T>();
  T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
  T *dintermediate_data =
      dintermediate == nullptr
          ? nullptr
          : dintermediate->mutable_data<T>(ctx.GetPlace());
  if (two_dim_view) {
    FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                         UseIntermediateOut, BcastY,
                                         SameShapeOfIntermediateOutAndOut>(
        x_data, y_data, intermediate_data, out_data, dout_data, pre /*h*/,
        n /*w*/, dx_op, dy_op, dintermediate_op, dx_data, dy_data,
        dintermediate_data);
  } else {
    FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                         UseIntermediateOut, BcastY,
                                         SameShapeOfIntermediateOutAndOut>(
        x_data, y_data, intermediate_data, out_data, dout_data, pre, n, post,
        dx_op, dy_op, dintermediate_op, dx_data, dy_data, dintermediate_data);
  }
}
/*
 * Entry point of the fused elementwise + activation backward pass.
 *
 * Dispatches on the shapes of X and Y:
 *   - identical shapes      -> FusedElemwiseAndActGradComputeNoBroadcast;
 *   - Y must be broadcast   -> ...WithBroadcast<BcastY = true>(x_dim, y_dim);
 *   - X must be broadcast   -> ...WithBroadcast<BcastY = false>(y_dim, x_dim).
 *
 * Y is treated as the broadcast operand when its rank is not larger than X's
 * and, for equal ranks, no dimension of X is smaller than the matching
 * dimension of Y. When UseIntermediateOut is set, `intermediate_out` must not
 * be nullptr.
 */
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          typename DIntermediate_OP, bool UseIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeEx(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *out,
    const framework::Tensor *intermediate_out, const framework::Tensor *dout,
    int axis, framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
  const framework::DDim &x_dim = x->dims();
  const framework::DDim &y_dim = y->dims();
  if (UseIntermediateOut) {
    PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
  }

  if (x_dim == y_dim) {
    FusedElemwiseAndActGradComputeNoBroadcast<
        DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>(
        ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
        dintermediate, dx_op, dy_op, dintermediate_op);
    return;
  }

  // Shapes differ: decide which operand is the one being broadcast.
  const int x_rank = x_dim.size();
  const int y_rank = y_dim.size();
  bool y_is_broadcast = x_rank >= y_rank;
  if (x_rank == y_rank) {
    for (int d = 0; y_is_broadcast && d < x_rank; ++d) {
      y_is_broadcast = !(x_dim[d] < y_dim[d]);
    }
  }

  // z = f1(x, f2(y))
  // z = f1(f2(x, y))
  if (y_is_broadcast) {
    FusedElemwiseAndActGradComputeWithBroadcast<
        DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
        true /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
        ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
        dintermediate, dx_op, dy_op, dintermediate_op);
  } else {
    // X is the smaller operand, so the two shapes swap roles.
    FusedElemwiseAndActGradComputeWithBroadcast<
        DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
        false /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
        ctx, y_dim, x_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
        dintermediate, dx_op, dy_op, dintermediate_op);
  }
}
/*
 * Entry point of the fused elementwise + activation forward pass.
 *
 * Dispatches on the shapes of X and Y:
 *   - identical shapes      -> FusedElemwiseAndActComputeNoBroadcast;
 *   - Y must be broadcast   -> ...WithBroadcast<BcastY = true>(x_dim, y_dim);
 *   - X must be broadcast   -> ...WithBroadcast<BcastY = false>(y_dim, x_dim).
 *
 * Y is treated as the broadcast operand when its rank is not larger than X's
 * and, for equal ranks, no dimension of X is smaller than the matching
 * dimension of Y. When KeepIntermediateOut is set, `intermediate_out` must
 * not be nullptr.
 */
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
                                  const framework::Tensor &x,
                                  const framework::Tensor &y, int axis,
                                  CompoundFunctor compound_functor,
                                  framework::Tensor *out,
                                  framework::Tensor *intermediate_out) {
  if (KeepIntermediateOut) {
    PADDLE_ENFORCE(intermediate_out,
                   "The save_intermediate_out is opened, "
                   "intermediate_out should not be nullptr.");
  }

  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();

  if (x_dim == y_dim) {
    FusedElemwiseAndActComputeNoBroadcast<DeviceContext, T, CompoundFunctor,
                                          KeepIntermediateOut>(
        ctx, x_dim, x, y, compound_functor, out, intermediate_out);
    return;
  }

  // Whether the shape of Y is a continuous subsequence of X.
  // For more information please refer to the op's introduction.
  const int x_rank = x_dim.size();
  const int y_rank = y_dim.size();
  bool y_is_broadcast = x_rank >= y_rank;
  if (x_rank == y_rank) {
    for (int d = 0; y_is_broadcast && d < x_rank; ++d) {
      y_is_broadcast = !(x_dim[d] < y_dim[d]);
    }
  }

  // z = f1(x, f2(y))
  // z = f1(f2(x, y))
  if (y_is_broadcast) {
    // Y is broadcast, Out has the shape of X.
    //   'f2(y)'   : intermediate_out has the shape of Y.
    //   'f2(x, y)': intermediate_out has the shape of Out.
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, true /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out,
        intermediate_out);
  } else {
    // X is broadcast, Out has the shape of Y.
    //   'f2(y)'   : intermediate_out has the shape of Out.
    //   'f2(x, y)': intermediate_out has the shape of Out.
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, false /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out,
        intermediate_out);
  }
}
/*
 * Produces a tensor that can always be read as the second-order gradient
 * input.
 *
 * If `ddx` is provided, `*ddx_safe` is assigned from it. Otherwise a
 * temporary tensor with the dims of `x` is allocated through `ctx` and
 * filled with zeros via math::SetConstant.
 */
template <typename DeviceContext, typename T>
static inline void GetDoubleGradSafeTensor(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *ddx, framework::Tensor *ddx_safe) {
  if (ddx != nullptr) {
    *ddx_safe = *ddx;
    return;
  }

  // No incoming ddx: substitute an all-zero tensor shaped like x.
  auto &dev_ctx = ctx.template device_context<DeviceContext>();
  *ddx_safe = ctx.AllocateTmpTensor<T, DeviceContext>(x->dims(), dev_ctx);
  math::SetConstant<DeviceContext, T> fill_with_zero;
  fill_with_zero(dev_ctx, ddx_safe, static_cast<T>(0));
}
} // namespace operators
} // namespace paddle