|
|
@ -14,6 +14,7 @@
|
|
|
|
* limitations under the License.
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
|
|
|
|
#include "runtime/device/gpu/cuda_common.h"
|
|
|
|
#include "runtime/device/gpu/cuda_common.h"
|
|
|
|
|
|
|
|
|
|
|
@ -107,69 +108,97 @@ __device__ __forceinline__ int Index(const int &index, const int &dim) { return
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename S, typename Func>
|
|
|
|
template <typename T, typename S, typename Func>
|
|
|
|
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
|
|
|
|
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
|
|
|
|
const int &r0, const int &r1, const int &r2, const int &r3,
|
|
|
|
const int &l4, const int &l5, const int &l6, const int &r0,
|
|
|
|
const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
const int &r1, const int &r2, const int &r3, const int &r4,
|
|
|
|
const T *input0, const T *input1, S *output) {
|
|
|
|
const int &r5, const int &r6, const int &d0, const int &d1,
|
|
|
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
|
|
|
|
const int &d2, const int &d3, const int &d4, const int &d5,
|
|
|
|
int i = pos / (d1 * d2 * d3) % d0;
|
|
|
|
const int &d6, const T *input0, const T *input1, S *output) {
|
|
|
|
int j = pos / (d2 * d3) % d1;
|
|
|
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
|
|
|
|
int k = pos / d3 % d2;
|
|
|
|
pos += blockDim.x * gridDim.x) {
|
|
|
|
int l = pos % d3;
|
|
|
|
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
|
|
|
|
|
|
|
|
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
|
|
|
|
|
|
|
|
int k = pos / (d3 * d4 * d5 * d6) % d2;
|
|
|
|
|
|
|
|
int l = pos / (d4 * d5 * d6) % d3;
|
|
|
|
|
|
|
|
int m = pos / (d5 * d6) % d4;
|
|
|
|
|
|
|
|
int n = pos / d6 % d5;
|
|
|
|
|
|
|
|
int o = pos % d6;
|
|
|
|
|
|
|
|
|
|
|
|
int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
|
|
|
|
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
|
|
|
|
int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
|
|
|
|
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
|
|
|
|
|
|
|
|
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
|
|
|
|
|
|
|
|
l_index += Index(l, l3) * l4 * l5 * l6;
|
|
|
|
|
|
|
|
l_index += Index(m, l4) * l5 * l6;
|
|
|
|
|
|
|
|
l_index += Index(n, l5) * l6;
|
|
|
|
|
|
|
|
l_index += Index(o, l6);
|
|
|
|
|
|
|
|
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
|
|
|
|
|
|
|
|
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
|
|
|
|
|
|
|
|
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
|
|
|
|
|
|
|
|
r_index += Index(l, r3) * r4 * r5 * r6;
|
|
|
|
|
|
|
|
r_index += Index(m, r4) * r5 * r6;
|
|
|
|
|
|
|
|
r_index += Index(n, r5) * r6;
|
|
|
|
|
|
|
|
r_index += Index(o, r6);
|
|
|
|
output[pos] = Func()(input0[l_index], input1[r_index]);
|
|
|
|
output[pos] = Func()(input0[l_index], input1[r_index]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename S>
|
|
|
|
template <typename T, typename S>
|
|
|
|
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
|
|
|
|
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
|
|
|
|
const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
|
|
|
|
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
|
|
|
|
enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
|
|
|
|
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
|
|
|
|
|
|
|
|
const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
|
|
|
|
|
|
|
|
const T *input1, S *output) {
|
|
|
|
switch (op) {
|
|
|
|
switch (op) {
|
|
|
|
case BROADCAST_TYPE_GREATER:
|
|
|
|
case BROADCAST_TYPE_GREATER:
|
|
|
|
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_LESS:
|
|
|
|
case BROADCAST_TYPE_LESS:
|
|
|
|
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
|
|
|
output);
|
|
|
|
d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_MINIMUM:
|
|
|
|
case BROADCAST_TYPE_MINIMUM:
|
|
|
|
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_MAXIMUM:
|
|
|
|
case BROADCAST_TYPE_MAXIMUM:
|
|
|
|
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_POWER:
|
|
|
|
case BROADCAST_TYPE_POWER:
|
|
|
|
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_REALDIV:
|
|
|
|
case BROADCAST_TYPE_REALDIV:
|
|
|
|
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_MUL:
|
|
|
|
case BROADCAST_TYPE_MUL:
|
|
|
|
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
|
|
|
output);
|
|
|
|
d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_SUB:
|
|
|
|
case BROADCAST_TYPE_SUB:
|
|
|
|
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
|
|
|
output);
|
|
|
|
d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_ADD:
|
|
|
|
case BROADCAST_TYPE_ADD:
|
|
|
|
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
|
|
|
output);
|
|
|
|
d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_FLOORDIV:
|
|
|
|
case BROADCAST_TYPE_FLOORDIV:
|
|
|
|
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
case BROADCAST_TYPE_ABSGRAD:
|
|
|
|
case BROADCAST_TYPE_ABSGRAD:
|
|
|
|
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
|
|
|
output);
|
|
|
|
d2, d3, d4, d5, d6, input0, input1, output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename S>
|
|
|
|
template <typename T, typename S>
|
|
|
|
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
|
|
|
|
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
|
|
|
|
const T *input0, const T *input1, S *output, cudaStream_t stream) {
|
|
|
|
S *output, cudaStream_t stream) {
|
|
|
|
int size = d0 * d1 * d2 * d3;
|
|
|
|
int size = 1;
|
|
|
|
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
|
|
|
|
for (auto d : output_shape) {
|
|
|
|
input0, input1, output);
|
|
|
|
size *= d;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3],
|
|
|
|
|
|
|
|
lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
|
|
|
|
|
|
|
|
rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4],
|
|
|
|
|
|
|
|
rhs_shape[5], rhs_shape[6], output_shape[0],
|
|
|
|
|
|
|
|
output_shape[1], output_shape[2], output_shape[3],
|
|
|
|
|
|
|
|
output_shape[4], output_shape[5], output_shape[6],
|
|
|
|
|
|
|
|
op, input0, input1, output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename S, typename Func>
|
|
|
|
template <typename T, typename S, typename Func>
|
|
|
@ -236,30 +265,24 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
|
|
|
|
output_addr);
|
|
|
|
output_addr);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
|
|
|
|
enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
|
|
|
|
const float *input1, bool *output, cudaStream_t stream);
|
|
|
|
cudaStream_t stream);
|
|
|
|
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
const float *input1, float *output, cudaStream_t stream);
|
|
|
|
enum BroadcastOpType op, const float *input0, const float *input1, float *output,
|
|
|
|
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
cudaStream_t stream);
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
const half *input1, bool *output, cudaStream_t stream);
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
|
|
|
|
cudaStream_t stream);
|
|
|
|
const half *input1, half *output, cudaStream_t stream);
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
|
|
|
|
enum BroadcastOpType op, const half *input0, const half *input1, half *output,
|
|
|
|
const int *input1, int *output, cudaStream_t stream);
|
|
|
|
cudaStream_t stream);
|
|
|
|
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
const int *input1, bool *output, cudaStream_t stream);
|
|
|
|
enum BroadcastOpType op, const int *input0, const int *input1, int *output,
|
|
|
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
|
|
|
|
enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
|
|
|
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
|
|
|
bool *output, cudaStream_t stream);
|
|
|
|
bool *output, cudaStream_t stream);
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
|
|
|