type support for faster rcnn gpu kernels

addressed code review comments

fix cpplint and pylint

trying to fix python ut

fix smoke test
pull/4246/head
Peilin Wang 5 years ago
parent f37a2fa402
commit 3cb3a5c7d8

@ -27,5 +27,14 @@ MS_REG_GPU_KERNEL_ONE(Concat,
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ConcatV2GpuFwdKernel, half)
// Concat registration for int16 tensors (input and output attrs are both kNumberTypeInt16),
// bound to the C++ element type `short`.
MS_REG_GPU_KERNEL_ONE(Concat,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ConcatV2GpuFwdKernel, short) // NOLINT
// Concat registration for uint8 tensors, bound to plain `char`.
// NOTE(review): the signedness of `char` is implementation-defined; presumably harmless if the
// kernel only copies elements -- confirm against ConcatV2GpuFwdKernel.
MS_REG_GPU_KERNEL_ONE(Concat,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ConcatV2GpuFwdKernel, char)
// Concat registration for bool tensors, bound to `bool`.
MS_REG_GPU_KERNEL_ONE(Concat,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ConcatV2GpuFwdKernel, bool)
} // namespace kernel
} // namespace mindspore

@ -1,33 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherNdGpuFwdKernel, int, int)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherNdGpuFwdKernel, int, int)
// GatherNd registration: first input and output are int16 (`short`), second input is int32 (`int`).
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherNdGpuFwdKernel, short, int) // NOLINT
// GatherNd registration: first input and output are uint8, bound to plain `char`; second input is int32.
// NOTE(review): `char` signedness is implementation-defined; presumably fine for a pure gather -- confirm.
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherNdGpuFwdKernel, char, int)
// GatherNd registration: first input and output are bool; second input is int32.
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherNdGpuFwdKernel, bool, int)
} // namespace kernel
} // namespace mindspore

@ -24,5 +24,11 @@ MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16
StridedSliceGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
StridedSliceGpuKernel, int)
// StridedSlice registration for int16 tensors, bound to `short`.
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
StridedSliceGpuKernel, short) // NOLINT
// StridedSlice registration for uint8 tensors, bound to plain `char`.
// NOTE(review): `char` signedness is implementation-defined; presumably fine for element copies -- confirm.
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
StridedSliceGpuKernel, char)
// StridedSlice registration for bool tensors.
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
StridedSliceGpuKernel, bool)
} // namespace kernel
} // namespace mindspore

@ -24,5 +24,11 @@ MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFlo
StridedSliceGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
StridedSliceGradGpuKernel, int)
// StridedSliceGrad registration for int16 tensors, bound to `short`.
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
StridedSliceGradGpuKernel, short) // NOLINT
// StridedSliceGrad registration for uint8 tensors, bound to plain `char`.
// NOTE(review): `char` signedness is implementation-defined; confirm the grad kernel does no
// sign-sensitive arithmetic on the elements.
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
StridedSliceGradGpuKernel, char)
// StridedSliceGrad registration for bool tensors.
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
StridedSliceGradGpuKernel, bool)
} // namespace kernel
} // namespace mindspore

@ -36,6 +36,26 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m
return;
}
/**
 * CheckValidKernel overload for 8-bit unsigned data that is passed through `char` pointers.
 *
 * Each box occupies four consecutive elements of `box`: [left_x, left_y, right_x, right_y].
 * Box i is marked valid when
 *   img_metas[0] * img_metas[2] - 1 >= right_x   and
 *   img_metas[1] * img_metas[2] - 1 >= right_y.
 * The left_x >= 0 / left_y >= 0 conditions are always satisfied for unsigned data, so they are
 * not evaluated.
 *
 * The bytes are reinterpreted as `unsigned char` before being widened. Plain `char` may be signed,
 * in which case a direct cast to an unsigned integer would sign-extend values above 127 into huge
 * numbers. The limits are computed in signed `int`, so a zero product yields -1 (no box is valid)
 * instead of wrapping around to UINT_MAX.
 *
 * @param size      number of boxes; `box` must hold 4 * size elements and `valid` size elements.
 * @param box       box coordinates, 4 per box, stored as unsigned 8-bit values.
 * @param img_metas at least 3 unsigned 8-bit values; elements 0 and 1 are each multiplied by element 2.
 * @param valid     output flags, one per box.
 *
 * Works for any 1-D launch configuration: every thread strides over the boxes by the total number
 * of launched threads.
 */
template <typename S>
__global__ void CheckValidKernel(const size_t size, const char *box, const char *img_metas, S *valid) {
  const unsigned char *box_u8 = reinterpret_cast<const unsigned char *>(box);
  const unsigned char *metas_u8 = reinterpret_cast<const unsigned char *>(img_metas);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    // Largest admissible right-edge coordinates; at most 255 * 255 - 1, so `int` cannot overflow.
    const int max_x = static_cast<int>(metas_u8[0]) * static_cast<int>(metas_u8[2]) - 1;
    const int max_y = static_cast<int>(metas_u8[1]) * static_cast<int>(metas_u8[2]) - 1;
    const int right_x = static_cast<int>(box_u8[i * 4 + 2]);
    const int right_y = static_cast<int>(box_u8[i * 4 + 3]);
    valid[i] = (max_x >= right_x) && (max_y >= right_y);
  }
  return;
}
template <typename T, typename S>
void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream) {
CheckValidKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, box, img_metas, valid);
@ -45,3 +65,7 @@ template void CheckValid(const size_t &size, const float *box, const float *img_
cudaStream_t cuda_stream);
template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid,
cudaStream_t cuda_stream);
// Explicit instantiations of CheckValid for `short` and `char` box/img_metas buffers with a bool output.
template void CheckValid(const size_t &size, const short *box, const short *img_metas, bool *valid, // NOLINT
cudaStream_t cuda_stream);
// NOTE(review): presumably the `char` instantiation serves uint8 tensors -- confirm against the
// kernel registration; plain `char` may be signed.
template void CheckValid(const size_t &size, const char *box, const char *img_metas, bool *valid,
cudaStream_t cuda_stream);

@ -67,3 +67,15 @@ template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, half** inputs, half* output,
cudaStream_t cuda_stream);
// Explicit instantiation of ConcatKernel for `short` elements.
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, short** inputs, short* output, // NOLINT
cudaStream_t cuda_stream);
// Explicit instantiation of ConcatKernel for `char` elements.
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, char** inputs, char* output,
cudaStream_t cuda_stream);
// Explicit instantiation of ConcatKernel for `bool` elements.
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, bool** inputs, bool* output,
cudaStream_t cuda_stream);

@ -63,3 +63,12 @@ template void GatherNd<half, int>(half *input, int *indices, half *output, const
template void GatherNd<int, int>(int *input, int *indices, int *output, const size_t &output_dim0,
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
int *batch_strides, cudaStream_t stream);
// Explicit instantiation of GatherNd for `short` data with `int` indices.
template void GatherNd<short, int>(short *input, int *indices, short *output, const size_t &output_dim0, // NOLINT
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
int *batch_strides, cudaStream_t stream);
// Explicit instantiation of GatherNd for `char` data with `int` indices.
template void GatherNd<char, int>(char *input, int *indices, char *output, const size_t &output_dim0,
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
int *batch_strides, cudaStream_t stream);
// Explicit instantiation of GatherNd for `bool` data with `int` indices.
template void GatherNd<bool, int>(bool *input, int *indices, bool *output, const size_t &output_dim0,
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
int *batch_strides, cudaStream_t stream);

@ -163,6 +163,7 @@ template void Slice4DKernel(const int s1, const int s2, const int s3, const int
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, float *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
@ -170,6 +171,7 @@ template void Slice4DKernel(const int s1, const int s2, const int s3, const int
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, half *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
@ -178,6 +180,31 @@ template void CalSliceGrad<int>(const size_t input_size, const int *dy, const st
const std::vector<int> begin, const std::vector<int> size, int *output,
cudaStream_t cuda_stream);
// Explicit instantiations of FillDeviceArray, Slice4DKernel and CalSliceGrad for `short` elements.
// The fill value is passed as `const float` for every element type.
// NOLINTNEXTLINE
template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const short *input, short *output, cudaStream_t stream); // NOLINT
template void CalSliceGrad<short>(const size_t input_size, const short *dy, const std::vector<int> in_shape, // NOLINT
const std::vector<int> begin, const std::vector<int> size, short *output, // NOLINT
cudaStream_t cuda_stream);
// Same three instantiations for `char` elements.
template void FillDeviceArray<char>(const size_t input_size, char *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const char *input, char *output, cudaStream_t stream);
template void CalSliceGrad<char>(const size_t input_size, const char *dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, char *output,
cudaStream_t cuda_stream);
// Same three instantiations for `bool` elements.
template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const bool *input, bool *output, cudaStream_t stream);
template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, bool *output,
cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const float *input,
float *output, cudaStream_t cuda_stream);
@ -187,6 +214,16 @@ template void StridedSlice(const std::vector<size_t> &input_shape, const std::ve
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const int *input,
int *output, cudaStream_t cuda_stream);
// Explicit instantiation of StridedSlice for `short` elements.
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
// NOLINTNEXTLINE
const std::vector<int> &strides, const std::vector<int> &output_shape, const short *input,
short *output, cudaStream_t cuda_stream); // NOLINT
// Explicit instantiation of StridedSlice for `char` elements.
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const char *input,
char *output, cudaStream_t cuda_stream);
// Explicit instantiation of StridedSlice for `bool` elements.
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const bool *input,
bool *output, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const float *dy,
@ -197,3 +234,13 @@ template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vect
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const int *dy,
int *dx, cudaStream_t cuda_stream);
// Explicit instantiation of StridedSliceGrad for `short` elements.
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
// NOLINTNEXTLINE
const std::vector<int> &strides, const std::vector<int> &dx_shape, const short *dy,
short *dx, cudaStream_t cuda_stream); // NOLINT
// Explicit instantiation of StridedSliceGrad for `char` elements.
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const char *dy,
char *dx, cudaStream_t cuda_stream);
// Explicit instantiation of StridedSliceGrad for `bool` elements.
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const bool *dy,
bool *dx, cudaStream_t cuda_stream);

@ -1,30 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, half, bool)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, half, bool)
// CheckValid registration: both inputs int16 (`short`), output bool.
MS_REG_GPU_KERNEL_TWO(
CheckValid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, short, bool) // NOLINT
// CheckValid registration: both inputs uint8, bound to plain `char`; output bool.
// NOTE(review): CheckValid compares element values, and `char` may be signed, so uint8 values
// above 127 need care in the kernel -- verify the device code reads them as unsigned.
MS_REG_GPU_KERNEL_TWO(
CheckValid, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, char, bool)
} // namespace kernel
} // namespace mindspore

@ -1304,7 +1304,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
def __infer__(self, dy, shapex, begin, end, strides):
args = {"dy": dy['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
for idx, item in enumerate(shapex['value']):
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)

@ -879,7 +879,7 @@ class Fill(PrimitiveWithInfer):
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
for idx, item in enumerate(dims['value']):
validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64,
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_type_same({"value": dtype['value']}, valid_types, self.name)

@ -221,7 +221,7 @@ class CheckValid(PrimitiveWithInfer):
return bboxes_shape[:-1]
def infer_dtype(self, bboxes_type, metas_type):
valid_type = [mstype.float32, mstype.float16]
valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name)
validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name)
return mstype.bool_

@ -16,7 +16,6 @@
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@ -31,24 +30,57 @@ class NetCheckValid(nn.Cell):
def construct(self, anchor, image_metas):
return self.valid(anchor, image_metas)
def check_valid(nptype):
    """Run NetCheckValid on a fixed 3-box input of dtype ``nptype`` and compare with the expected flags.

    The check is executed twice on the GPU target: once in GRAPH_MODE and once in PYNATIVE_MODE.

    Args:
        nptype: NumPy dtype used for both the anchor boxes and the image metas. It must be able to
            represent the values used here (including -2 and 2000).
    """
    anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], nptype)
    image_metas = np.array([768, 1280, 1], nptype)
    anchor_box = Tensor(anchor)
    image_metas_box = Tensor(image_metas)
    # np.bool_ instead of np.bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
    expect = np.array([True, False, False], np.bool_)

    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target='GPU')
        boundingbox_decode = NetCheckValid()
        output = boundingbox_decode(anchor_box, image_metas_box)
        assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_check_valid_float32():
check_valid(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_check_valid_float16():
check_valid(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_check_valid_int16():
check_valid(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_boundingbox_decode():
anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], np.float32)
image_metas = np.array([768, 1280, 1], np.float32)
anchor_box = Tensor(anchor, mindspore.float32)
image_metas_box = Tensor(image_metas, mindspore.float32)
expect = np.array([True, False, False], np.bool_)
def test_check_valid_uint8():
anchor = np.array([[5, 0, 10, 70], [2, 2, 8, 10], [1, 2, 30, 200]], np.uint8)
image_metas = np.array([76, 128, 1], np.uint8)
anchor_box = Tensor(anchor)
image_metas_box = Tensor(image_metas)
expect = np.array([True, True, False], np.bool)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
boundingbox_decode = NetCheckValid()
output = boundingbox_decode(anchor_box, image_metas_box)
diff = (output.asnumpy() == expect)
assert (diff == 1).all()
assert np.array_equal(output.asnumpy(), expect)
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
boundingbox_decode = NetCheckValid()
output = boundingbox_decode(anchor_box, image_metas_box)
diff = (output.asnumpy() == expect)
assert (diff == 1).all()
assert np.array_equal(output.asnumpy(), expect)

File diff suppressed because it is too large Load Diff

@ -28,28 +28,50 @@ class GatherNdNet(nn.Cell):
def construct(self, x, indices):
return self.gathernd(x, indices)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd0():
x = Tensor(np.arange(3 * 2, dtype=np.float32).reshape(3, 2))
def gathernd0(nptype):
x = Tensor(np.arange(3 * 2, dtype=nptype).reshape(3, 2))
indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32))
expect = np.array([3., 1.])
expect = np.array([3, 1]).astype(nptype)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gathernd = GatherNdNet()
output = gathernd(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd0_float32():
gathernd0(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd0_float16():
gathernd0(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd1():
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
def test_gathernd0_int32():
gathernd0(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd0_int16():
gathernd0(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd0_uint8():
gathernd0(np.uint8)
def gathernd1(nptype):
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=nptype).reshape(2, 3, 4, 5))
indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)]
for k in range(3)] for l in range(2)], dtype='i4'))
expect = np.array([[[[1., 3., 4.],
@ -80,21 +102,45 @@ def test_gathernd1():
[[101., 103., 104.],
[106., 108., 109.],
[111., 113., 114.],
[116., 118., 119.]]]])
[116., 118., 119.]]]]).astype(nptype)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNdNet()
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd1_float32():
gathernd1(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd1_float16():
gathernd1(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd2():
def test_gathernd1_int32():
gathernd1(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd1_int16():
gathernd1(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd1_uint8():
gathernd1(np.uint8)
def gathernd2(nptype):
x = Tensor(np.array([[4., 5., 4., 1., 5.],
[4., 9., 5., 6., 4.],
[9., 8., 4., 3., 6.],
@ -115,37 +161,48 @@ def test_gathernd2():
gathernd = GatherNdNet()
output = gathernd(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd3():
x = Tensor(np.array([[4, 5, 4, 1, 5],
[4, 9, 5, 6, 4],
[9, 8, 4, 3, 6],
[0, 4, 2, 2, 8],
[1, 8, 6, 2, 8],
[8, 1, 9, 7, 3],
[7, 9, 2, 5, 7],
[9, 8, 6, 8, 5],
[3, 7, 2, 7, 4],
[4, 2, 8, 2, 9]]
).astype(np.int32))
def test_gathernd2_float32():
gathernd2(np.float32)
indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32))
expect = np.array([[0, 0, 0, 0, 0],
[4, 9, 5, 6, 4],
[0, 0, 0, 0, 0]])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd2_float16():
gathernd2(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd2_int32():
gathernd2(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd2_int16():
gathernd2(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd2_uint8():
gathernd2(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gathernd_bool():
x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool))
indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int32))
expect = np.array([True, False, False, False]).astype(np.bool)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gathernd = GatherNdNet()
output = gathernd(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
assert np.array_equal(output.asnumpy(), expect)

@ -22,9 +22,6 @@ from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
class StridedSliceNet(nn.Cell):
def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0):
super(StridedSliceNet, self).__init__()
@ -45,11 +42,11 @@ class GradData(nn.Cell):
def construct(self, x):
return self.grad(self.network)(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad():
x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32))
def strided_slice_grad(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype))
net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
dx = GradData(net)(x)
expect = np.array([[[[0., 0., 0., 0., 0.],
@ -81,7 +78,7 @@ def test_strided_slice_grad():
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]])
[0., 0., 0., 0., 0.]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect)
net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
@ -115,7 +112,7 @@ def test_strided_slice_grad():
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]])
[0., 0., 0., 0., 0.]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect)
@ -150,7 +147,7 @@ def test_strided_slice_grad():
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]])
[0., 0., 0., 0., 0.]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect)
# ME infer fault
@ -253,7 +250,7 @@ def test_strided_slice_grad():
[[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.]]]])
[1., 1., 1., 1., 0.]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect)
x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
@ -272,10 +269,10 @@ def test_strided_slice_grad():
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]])
[0., 0., 0., 0., 0.]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect)
x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(np.float32))
x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(nptype))
net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1))
dx = GradData(net)(x)
expect = np.array([[[[[[[0., 0., 0., 0., 0.],
@ -306,5 +303,29 @@ def test_strided_slice_grad():
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 1., 1., 0.],
[0., 0., 0., 0., 0.]]]]]]])
[0., 0., 0., 0., 0.]]]]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_float32():
strided_slice_grad(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_int16():
strided_slice_grad(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_uint8():
strided_slice_grad(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_bool():
    """Run the shared StridedSliceGrad checks with boolean tensors."""
    # np.bool_ instead of np.bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
    strided_slice_grad(np.bool_)

@ -20,33 +20,29 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
def strided_slice(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_stridedslice():
x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32))
x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype))
y = P.StridedSlice()(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
expect = np.array([[[[62, 63],
[67, 68]],
[[82, 83],
[87, 88]]]])
[87, 88]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect)
y = P.StridedSlice()(x, (1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
expect = np.array([[[[64, 62],
[69, 67]],
[[84, 82],
[89, 87]]]])
[89, 87]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect)
y = P.StridedSlice()(x, (1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
expect = np.array([[[[64, 63, 62],
[69, 68, 67]],
[[84, 83, 82],
[89, 88, 87]]]])
[89, 88, 87]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect)
# ME infer fault
@ -81,20 +77,20 @@ def test_stridedslice():
[[100, 101, 102, 103],
[105, 106, 107, 108],
[110, 111, 112, 113],
[115, 116, 117, 118]]]])
[115, 116, 117, 118]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect)
x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(nptype))
y = P.StridedSlice()(x, (1, 0, 0), (2, -3, 3), (1, 1, 3))
expect = np.array([[[20]]])
expect = np.array([[[20]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect)
x_np = np.arange(0, 4*5).reshape(4, 5).astype(np.float32)
x_np = np.arange(0, 4*5).reshape(4, 5).astype(nptype)
y = Tensor(x_np)[:, ::-1]
expect = x_np[:, ::-1]
assert np.allclose(y.asnumpy(), expect)
x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(np.float32))
x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(nptype))
y = P.StridedSlice()(x, (1, 0, 0, 2, 1, 2, 0), (2, 2, 2, 4, 2, 3, 2), (1, 1, 1, 1, 1, 1, 2))
expect = np.array([[[[[[[1498.]]],
[[[1522.]]]],
@ -103,5 +99,29 @@ def test_stridedslice():
[[[[[1978.]]],
[[[2002.]]]],
[[[[2098.]]],
[[[2122.]]]]]]])
[[[2122.]]]]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_float32():
strided_slice(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int16():
strided_slice(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint8():
strided_slice(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_bool():
    """Run the shared StridedSlice checks with boolean tensors."""
    # np.bool_ instead of np.bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
    strided_slice(np.bool_)

Loading…
Cancel
Save