|
|
|
// Functor computing the gradient of minimum(x1, x2) for one element pair.
// The upstream gradient dy flows to whichever input the forward op selected:
// dx1 when x1 < x2, otherwise dx2 (ties go to x2, matching the >= test).
// grad_x1 / grad_x2 gate whether each input actually requires a gradient, so
// the (atomic) accumulation is skipped entirely for inputs that do not.
// atomicAdd is required because, under broadcasting, many output positions
// may map onto the same input element.
template <typename T>
struct MinimumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                             const T &dy, T *dx1, T *dx2) {
    if (grad_x1 && x1 < x2) {
      atomicAdd(dx1, dy);
    } else if (grad_x2 && x1 >= x2) {
      atomicAdd(dx2, dy);
    }
  }
};
// Functor computing the gradient of maximum(x1, x2) for one element pair.
// Mirror image of MinimumGradFunc: dy flows to dx1 when x1 > x2, otherwise
// to dx2 (ties go to x2 via the <= test). grad_x1 / grad_x2 gate whether a
// given input needs its gradient accumulated at all. atomicAdd is required
// because broadcasting can map many output positions onto one input element.
template <typename T>
struct MaximumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                             const T &dy, T *dx1, T *dx2) {
    if (grad_x1 && x1 > x2) {
      atomicAdd(dx1, dy);
    } else if (grad_x2 && x1 <= x2) {
      atomicAdd(dx2, dy);
    }
  }
};
// Grid-stride device loop applying Func to every element of the broadcast
// output shape (d0, d1, d2, d3), scattering gradients back to the two input
// shapes (l0..l3) and (r0..r3).
// Index(coord, dim) maps an output coordinate onto a possibly-broadcast input
// dimension — defined elsewhere; presumably it returns 0 when dim == 1 so a
// size-1 axis receives every output position's gradient (TODO confirm against
// the helper's definition).
// grad_x1 / grad_x2 are forwarded to Func to skip accumulation for inputs
// that do not require a gradient.
template <typename T, typename Func>
__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                      const int &r0, const int &r1, const int &r2, const int &r3,
                                                      const int &d0, const int &d1, const int &d2, const int &d3,
                                                      const bool &grad_x1, const bool &grad_x2, const T *x1,
                                                      const T *x2, const T *dy, T *dx1, T *dx2) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
    // Decompose the flat output position into 4-D coordinates (i, j, k, l).
    int i = pos / (d1 * d2 * d3) % d0;
    int j = pos / (d2 * d3) % d1;
    int k = pos / d3 % d2;
    int l = pos % d3;
    // Flat offsets into each input, collapsing broadcast axes via Index().
    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
    Func()(x1[l_index], x2[r_index], grad_x1, grad_x2, dy[pos], dx1 + l_index, dx2 + r_index);
  }
}
|
|
|
|
|
// Kernel dispatching the broadcast gradient loop to the functor selected by
// op (minimum or maximum). All shape/flag parameters are forwarded verbatim
// to BroadcastGradOperator; an unknown op value is a silent no-op.
template <typename T>
__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
                                    const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
                                    const bool grad_x1, const bool grad_x2, enum BroadcastGradOpType op, const T *x1,
                                    const T *x2, const T *dy, T *dx1, T *dx2) {
  switch (op) {
    case BROADCAST_GRAD_TYPE_MINIMUM:
      return BroadcastGradOperator<T, MinimumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, grad_x1,
                                                          grad_x2, x1, x2, dy, dx1, dx2);
    case BROADCAST_GRAD_TYPE_MAXIMUM:
      return BroadcastGradOperator<T, MaximumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, grad_x1,
                                                          grad_x2, x1, x2, dy, dx1, dx2);
  }
}
|
|
|
|
|
// Host launcher for the broadcast min/max gradient kernel on the given
// stream. Launch geometry is sized from the output element count via the
// project's GET_BLOCKS / GET_THREADS helpers. Asynchronous: the caller owns
// stream synchronization and any error checking after the launch.
template <typename T>
void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                   const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const T *x1, const T *x2,
                   const T *dy, T *dx1, T *dx2, cudaStream_t stream) {
  int size = d0 * d1 * d2 * d3;
  BroadcastGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3,
                                                                    grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2);
}
|
|
|
|
|
// Grid-stride device loop for the non-broadcast case: inputs, gradients and
// output all share the same flat length nums, so position pos indexes every
// array directly. grad_x1 / grad_x2 are forwarded to Func to skip
// accumulation for inputs that do not require a gradient.
template <typename T, typename Func>
__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const bool &grad_x1, const bool &grad_x2,
                                                    const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
    Func()(x1[pos], x2[pos], grad_x1, grad_x2, dy[pos], dx1 + pos, dx2 + pos);
  }
}
|
|
|
|
|
// Kernel dispatching the same-shape gradient loop to the functor selected by
// op (minimum or maximum). An unknown op value is a silent no-op.
template <typename T>
__global__ void NoBroadcastGradKernel(const int nums, const bool grad_x1, const bool grad_x2,
                                      enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                                      T *dx2) {
  switch (op) {
    case BROADCAST_GRAD_TYPE_MINIMUM:
      return NoBroadcastOperator<T, MinimumGradFunc<T>>(nums, grad_x1, grad_x2, x1, x2, dy, dx1, dx2);
    case BROADCAST_GRAD_TYPE_MAXIMUM:
      return NoBroadcastOperator<T, MaximumGradFunc<T>>(nums, grad_x1, grad_x2, x1, x2, dy, dx1, dx2);
  }
}
|
|
|
|
|
// Host launcher for the same-shape min/max gradient kernel on the given
// stream. Launch geometry is sized from nums via GET_BLOCKS / GET_THREADS.
// Asynchronous: the caller owns stream synchronization and error checking.
template <typename T>
void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                     const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, cudaStream_t stream) {
  NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2);
}
|
|
|
|
|
// Explicit instantiations for the two element types this translation unit
// supports (float and int), so callers in other translation units can link
// against the grad_x1/grad_x2-aware entry points.
// NOTE(review): atomicAdd(int*)/atomicAdd(float*) are both natively
// supported, so no extra compute-capability requirement is introduced here.
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                              const float *x1, const float *x2, const float *dy, float *dx1, float *dx2,
                              cudaStream_t stream);
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                              const int *x1, const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
                            const float *x2, const float *dy, float *dx1, float *dx2, cudaStream_t stream);
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int *x1,
                            const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);