broadcast, slice, scatter_nd ops optimizer.

pull/4048/head
linqingke 5 years ago
parent 645f11fa59
commit fb405ee6f4

@ -182,30 +182,59 @@ class ArrayReduceGpuKernel : public GpuKernel {
void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) { void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
std::vector<int> inputA; std::vector<int> inputA;
std::vector<size_t> outputC_shape = output_shape; std::vector<size_t> outputC_shape = output_shape;
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &inputA); ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0], CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
inputA[1], inputA[2], inputA[3]), inputA[0], inputA[1], inputA[2], inputA[3]),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
for (auto dim : input_shape) {
inputA.emplace_back(SizeToInt(dim));
}
}
if (axis_[0] == -1) { if (axis_[0] == -1) {
outputC_shape.resize(input_shape.size(), 1);
if (outputC_shape.size() <= split_dim) {
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1), cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) { } else {
all_match_ = true; CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
}
for (auto dim : inputA) {
if (dim != 1) {
return;
}
} }
all_match_ = true;
return; return;
} }
std::vector<int> outputC;
if (!keep_dims_) { if (!keep_dims_) {
for (auto i : axis_) { for (auto i : axis_) {
(void)(outputC_shape.insert(outputC_shape.begin() + i, 1)); (void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
} }
} }
std::vector<int> outputC;
if (outputC_shape.size() <= split_dim) {
ShapeNdTo4d(outputC_shape, &outputC); ShapeNdTo4d(outputC_shape, &outputC);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
outputC[0], outputC[1], outputC[2], outputC[3]), outputC[0], outputC[1], outputC[2], outputC[3]),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
for (auto dim : outputC_shape) {
outputC.emplace_back(SizeToInt(dim));
}
}
if (inputA == outputC) { if (inputA == outputC) {
all_match_ = true; all_match_ = true;
} }

@ -69,6 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
memcpy_flag_ = true; memcpy_flag_ = true;
} }
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemsetAsync(output, static_cast<T>(0.0), output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemSet failed in ScatterNdGpuFwdKernel::Launch.");
const size_t input_size = input_size_ / sizeof(T); const size_t input_size = input_size_ / sizeof(T);
const size_t output_size = output_size_ / sizeof(T); const size_t output_size = output_size_ / sizeof(T);

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
@ -107,69 +108,97 @@ __device__ __forceinline__ int Index(const int &index, const int &dim) { return
template <typename T, typename S, typename Func> template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
const int &r0, const int &r1, const int &r2, const int &r3, const int &l4, const int &l5, const int &l6, const int &r0,
const int &d0, const int &d1, const int &d2, const int &d3, const int &r1, const int &r2, const int &r3, const int &r4,
const T *input0, const T *input1, S *output) { const int &r5, const int &r6, const int &d0, const int &d1,
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { const int &d2, const int &d3, const int &d4, const int &d5,
int i = pos / (d1 * d2 * d3) % d0; const int &d6, const T *input0, const T *input1, S *output) {
int j = pos / (d2 * d3) % d1; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
int k = pos / d3 % d2; pos += blockDim.x * gridDim.x) {
int l = pos % d3; int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
int k = pos / (d3 * d4 * d5 * d6) % d2;
int l = pos / (d4 * d5 * d6) % d3;
int m = pos / (d5 * d6) % d4;
int n = pos / d6 % d5;
int o = pos % d6;
int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
output[pos] = Func()(input0[l_index], input1[r_index]); output[pos] = Func()(input0[l_index], input1[r_index]);
} }
} }
template <typename T, typename S> template <typename T, typename S>
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
enum BroadcastOpType op, const T *input0, const T *input1, S *output) { const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
const T *input1, S *output) {
switch (op) { switch (op) {
case BROADCAST_TYPE_GREATER: case BROADCAST_TYPE_GREATER:
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_LESS: case BROADCAST_TYPE_LESS:
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
output); d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MINIMUM: case BROADCAST_TYPE_MINIMUM:
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MAXIMUM: case BROADCAST_TYPE_MAXIMUM:
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_POWER: case BROADCAST_TYPE_POWER:
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_REALDIV: case BROADCAST_TYPE_REALDIV:
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MUL: case BROADCAST_TYPE_MUL:
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
output); d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_SUB: case BROADCAST_TYPE_SUB:
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
output); d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ADD: case BROADCAST_TYPE_ADD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
output); d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_FLOORDIV: case BROADCAST_TYPE_FLOORDIV:
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD: case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
output); d2, d3, d4, d5, d6, input0, input1, output);
} }
} }
template <typename T, typename S> template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
const T *input0, const T *input1, S *output, cudaStream_t stream) { S *output, cudaStream_t stream) {
int size = d0 * d1 * d2 * d3; int size = 1;
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, for (auto d : output_shape) {
input0, input1, output); size *= d;
}
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3],
lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4],
rhs_shape[5], rhs_shape[6], output_shape[0],
output_shape[1], output_shape[2], output_shape[3],
output_shape[4], output_shape[5], output_shape[6],
op, input0, input1, output);
} }
template <typename T, typename S, typename Func> template <typename T, typename S, typename Func>
@ -236,30 +265,24 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
output_addr); output_addr);
} }
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
enum BroadcastOpType op, const float *input0, const float *input1, bool *output, const float *input1, bool *output, cudaStream_t stream);
cudaStream_t stream); template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, const float *input1, float *output, cudaStream_t stream);
enum BroadcastOpType op, const float *input0, const float *input1, float *output, template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
cudaStream_t stream); const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const half *input1, bool *output, cudaStream_t stream);
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
enum BroadcastOpType op, const half *input0, const half *input1, bool *output, const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
cudaStream_t stream); const half *input1, half *output, cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
enum BroadcastOpType op, const half *input0, const half *input1, half *output, const int *input1, int *output, cudaStream_t stream);
cudaStream_t stream); template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, const int *input1, bool *output, cudaStream_t stream);
enum BroadcastOpType op, const int *input0, const int *input1, int *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
bool *output, cudaStream_t stream); bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#include <vector>
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
enum BroadcastOpType { enum BroadcastOpType {
@ -35,9 +36,9 @@ enum BroadcastOpType {
}; };
template <typename T, typename S> template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
const T *input0, const T *input1, S *output, cudaStream_t stream); S *output, cudaStream_t stream);
template <typename T, typename S> template <typename T, typename S>
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,

@ -25,10 +25,10 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m
const size_t right_y = i * 4 + 3; const size_t right_y = i * 4 + 3;
S valid_flag = false; S valid_flag = false;
valid_flag |= !(box[left_x] >= 0.f); valid_flag |= !(box[left_x] >= static_cast<T>(0.0));
valid_flag |= !(box[left_y] >= 0.f); valid_flag |= !(box[left_y] >= static_cast<T>(0.0));
valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]); valid_flag |= !(img_metas[1] * img_metas[2] - static_cast<T>(1.0) >= box[right_x]);
valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]); valid_flag |= !(img_metas[0] * img_metas[2] - static_cast<T>(1.0) >= box[right_y]);
valid[i] = !valid_flag; valid[i] = !valid_flag;
} }
@ -43,3 +43,5 @@ void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid,
template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid, template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid,
cudaStream_t cuda_stream);

@ -16,27 +16,26 @@
#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
template <typename T> __device__ float CoordinateMax(const float a, const float b) {
__device__ T CoordinateMax(const T a, const T b) {
return (a > b ? a : b); return (a > b ? a : b);
} }
template <typename T> __device__ float CoordinateMin(const float a, const float b) {
__device__ T CoordinateMin(const T a, const T b) {
return (a < b ? a : b); return (a < b ? a : b);
} }
template <typename T> template <typename T>
__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode, __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
const size_t input_len_0) { const size_t input_len_0) {
T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION]; float location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
T overlaps_coordinate[IOU_DIMENSION]; float overlaps_coordinate[IOU_DIMENSION];
const T epsilon = 1e-10; const float epsilon = 1e-10;
const float offset = 1.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
for (size_t j = 0; j < IOU_DIMENSION; j++) { for (size_t j = 0; j < IOU_DIMENSION; j++) {
location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j]; location_coordinate[0][j] = static_cast<float>(box1[(i % input_len_0) * IOU_DIMENSION + j]);
location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j]; location_coordinate[1][j] = static_cast<float>(box2[(i / input_len_0) * IOU_DIMENSION + j]);
} }
overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]); overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
@ -44,18 +43,18 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]); overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]); overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);
T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1); float overlaps_w = CoordinateMax(0.0, overlaps_coordinate[2] - overlaps_coordinate[0] + offset);
T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1); float overlaps_h = CoordinateMax(0.0, overlaps_coordinate[3] - overlaps_coordinate[1] + offset);
T overlaps = overlaps_w * overlaps_h; float overlaps = overlaps_w * overlaps_h;
T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] - float area1 = (location_coordinate[0][2] - location_coordinate[0][0] + offset) * (location_coordinate[0][3] -
location_coordinate[0][1] + 1); location_coordinate[0][1] + offset);
T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] - float area2 = (location_coordinate[1][2] - location_coordinate[1][0] + offset) * (location_coordinate[1][3] -
location_coordinate[1][1] + 1); location_coordinate[1][1] + offset);
if (mode == 0) { if (mode == 0) {
iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon); iou_results[i] = static_cast<T>(overlaps / (area1 + area2 - overlaps + epsilon));
} else { } else {
iou_results[i] = overlaps / (area2 + epsilon); iou_results[i] = static_cast<T>(overlaps / (area2 + epsilon));
} }
} }
@ -70,3 +69,5 @@ void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const
template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode, template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode,
const size_t &input_len_0, cudaStream_t cuda_stream); const size_t &input_len_0, cudaStream_t cuda_stream);
template void IOU(const size_t &size, const half *box1, const half *box2, half *iou_results, const size_t &mode,
const size_t &input_len_0, cudaStream_t cuda_stream);

@ -84,6 +84,40 @@ class GpuKernel : public KernelMod {
} }
} }
// set the tensor descriptor for cudnn/cublas
void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
cudnnDataType_t data_type) {
if (shape.size() < 3) {
MS_EXCEPTION(ValueError) << "cudnnSetTensorNdDescriptor don't support" << shape.size() << "D.";
}
const int nbDims = shape.size();
int *dim = new (std::nothrow) int[nbDims];
if (dim == nullptr) {
MS_LOG(EXCEPTION) << "malloc dim failed.";
}
int *stride = new (std::nothrow) int[nbDims];
if (stride == nullptr) {
MS_LOG(EXCEPTION) << "malloc stride failed.";
}
for (int i = 0; i < nbDims; i++) {
dim[i] = SizeToInt(shape[i]);
stride[i] = 1;
}
for (int i = nbDims - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * SizeToInt(shape[i + 1]);
}
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(descriptor, data_type, nbDims, dim, stride),
"cudnnSetTensorNdDescriptor failed");
delete[] dim;
dim = nullptr;
delete[] stride;
stride = nullptr;
}
// choose the suitable datatype for cudnn/cublas // choose the suitable datatype for cudnn/cublas
inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
auto type = kCudnnDtypeMap.find(Type); auto type = kCudnnDtypeMap.find(Type);

@ -27,6 +27,7 @@
#include "backend/kernel_compiler/gpu/kernel_constants.h" #include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T, typename S> template <typename T, typename S>
class BroadcastOpGpuKernel : public GpuKernel { class BroadcastOpGpuKernel : public GpuKernel {
public: public:
@ -45,9 +46,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
S *output = GetDeviceAddress<S>(outputs, 0); S *output = GetDeviceAddress<S>(outputs, 0);
if (need_broadcast_) { if (need_broadcast_) {
Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2], Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs, reinterpret_cast<cudaStream_t>(stream_ptr));
rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
} else { } else {
NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr)); NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
} }
@ -60,10 +60,13 @@ class BroadcastOpGpuKernel : public GpuKernel {
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
need_broadcast_ = IsBroadcast(shape1, shape2); need_broadcast_ = IsBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 4) { if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
} }
lhs_shape_.resize(MAX_DIMS, 1);
rhs_shape_.resize(MAX_DIMS, 1);
output_shape_.resize(MAX_DIMS, 1);
for (size_t i = 0; i < shape3.size(); i++) { for (size_t i = 0; i < shape3.size(); i++) {
output_shape_[i] = shape3[i]; output_shape_[i] = shape3[i];
output_num_ *= shape3[i]; output_num_ *= shape3[i];
@ -127,9 +130,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
int input1_num_; int input1_num_;
int input2_num_; int input2_num_;
int output_num_; int output_num_;
int lhs_shape_[4] = {1, 1, 1, 1}; std::vector<int> lhs_shape_;
int rhs_shape_[4] = {1, 1, 1, 1}; std::vector<int> rhs_shape_;
int output_shape_[4] = {1, 1, 1, 1}; std::vector<int> output_shape_;
std::vector<size_t> input_size_list_; std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;

@ -83,12 +83,19 @@ class ActivationGpuFwdKernel : public GpuKernel {
return true; return true;
} }
std::vector<int> shape; std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
"cudnnSetActivationDescriptor failed"); "cudnnSetActivationDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]), shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}
InitSizeLists(); InitSizeLists();
return true; return true;
} }

@ -90,12 +90,18 @@ class ActivationGradGpuKernel : public GpuKernel {
return true; return true;
} }
std::vector<int> shape; std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
"SetActivationDescriptor failed"); "SetActivationDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]), shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed"); "SetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}
InitSizeLists(); InitSizeLists();
return true; return true;

@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
CheckValid, CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, float, bool) CheckValidGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, half, bool)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

@ -21,5 +21,8 @@ namespace kernel {
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
IOUGpuKernel, float) IOUGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
IOUGpuKernel, half)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

@ -37,8 +37,8 @@ def test_floor_div():
y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32) y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32) x2_np = np.random.randint(1, 5, (2, 1, 1, 4, 9)).astype(np.float32)
y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) y2_np = np.random.randint(1, 5, (2, 3, 4, 4, 9)).astype(np.float32)
x3_np = np.random.randint(1, 5, 1).astype(np.float32) x3_np = np.random.randint(1, 5, 1).astype(np.float32)
y3_np = np.random.randint(1, 5, 1).astype(np.float32) y3_np = np.random.randint(1, 5, 1).astype(np.float32)
x4_np = np.array(768).astype(np.float32) x4_np = np.array(768).astype(np.float32)

@ -70,7 +70,7 @@ x11 = np.random.rand(1, 1, 1, 1).astype(np.float32)
axis11 = (0, 1, 2, 3) axis11 = (0, 1, 2, 3)
keep_dims11 = False keep_dims11 = False
x12 = np.random.rand(2, 3, 4, 4).astype(np.float32) x12 = np.random.rand(2, 3, 4, 4, 5, 6).astype(np.float32)
axis12 = -2 axis12 = -2
keep_dims12 = False keep_dims12 = False

Loading…
Cancel
Save