int64 support and typo fix

fixed typo

fix pylint
pull/9096/head
Peilin Wang 4 years ago
parent bd8522aff7
commit 1dd302ae93

@ -19,13 +19,28 @@
namespace mindspore {
namespace kernel {
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, int32_t)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, bool)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, int32_t, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, half, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, float, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, bool, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, int32_t, int64_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, half, int64_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, float, int64_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, bool, int64_t)
} // namespace kernel
} // namespace mindspore
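
Read as (input type T, shape-element type S) pairs, the new table keeps the four int32-output registrations and adds an int64 counterpart for each input type. A minimal sketch of the user-visible effect, assuming the standard MindSpore Python API of this era (not part of this commit):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # a float16 input is matched by the Float16 registrations above; the shape
    # comes back as a 1-D integer tensor holding [2, 4]
    x = Tensor(np.zeros((2, 4)).astype(np.float16))
    print(P.DynamicShape()(x))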

@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
-template <typename T>
+template <typename T, typename S>
class DynamicShapeGpuKernel : public GpuKernel {
 public:
  DynamicShapeGpuKernel() { ResetResource(); }
@ -38,8 +38,8 @@ class DynamicShapeGpuKernel : public GpuKernel {
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    int *output_device_address = GetDeviceAddress<int>(outputs, 0);
-    size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(int);
+    S *output_device_address = GetDeviceAddress<S>(outputs, 0);
+    size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(S);
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(output_device_address, prev_node_output_shape_.data(), prev_node_output_shape_size,
                      cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
@ -58,9 +58,10 @@ class DynamicShapeGpuKernel : public GpuKernel {
    input_size_ = 1;
    for (const size_t &e : prev_node_output_shape_tmp) {
      input_size_ *= e;
-      // shapes are Tensors with elements of type int32, but GetPrevNodeOutputInferShape returns vector of size_t,
-      // so we use an int* for allocated output memory and cast to an int here, otherwise the memcpy will fail with a
-      // silently.
+      // shapes are Tensors with elements of type S (int32, or int64) but
+      // GetPrevNodeOutputInferShape returns vector of size_t, so we use
+      // an S* for allocated output memory and cast to an integral type here,
+      // otherwise the memcpy will fail silently.
      prev_node_output_shape_.push_back(e);
    }
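
The rewritten comment describes a real hazard: GetPrevNodeOutputInferShape hands back 64-bit size_t values, and copying those raw bytes into a buffer of 32-bit elements would interleave the values with zero words rather than raise an error. A small numpy illustration of that byte layout (hypothetical, just to show why the per-element narrowing in push_back matters):

    import numpy as np

    shape64 = np.array([2, 3, 5], dtype=np.int64)  # how size_t values sit in host memory
    print(shape64.view(np.int32))                  # raw reinterpretation: [2 0 3 0 5 0] (little-endian)
    print(shape64.astype(np.int32))                # element-wise cast: [2 3 5]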
@ -83,13 +84,13 @@ class DynamicShapeGpuKernel : public GpuKernel {
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
-    output_size_list_.push_back(output_size_ * sizeof(int));
+    output_size_list_.push_back(output_size_ * sizeof(S));
  }

 private:
  size_t input_size_;
  size_t output_size_;
-  std::vector<int> prev_node_output_shape_;
+  std::vector<S> prev_node_output_shape_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;

@ -0,0 +1,117 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
import mindspore.nn as nn
import mindspore.context as context


class DynamicShapeNet(nn.Cell):
    def __init__(self):
        super(DynamicShapeNet, self).__init__()
        self.convert_to_dynamic_shape_op = inner.GpuConvertToDynamicShape()
        self.dynamic_shape_op = P.DynamicShape()

    def construct(self, x):
        x_dynamic_shape = self.convert_to_dynamic_shape_op(x)
        return self.dynamic_shape_op(x_dynamic_shape)
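
# Usage sketch (not part of the committed file): DynamicShapeNet returns the
# runtime shape of its input as a 1-D integer tensor, e.g.
#   net = DynamicShapeNet()
#   net(Tensor(np.zeros((2, 4)).astype(np.float32))).asnumpy()  # -> array([2, 4])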


def dynamic_shape(np_type):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_shape_net = DynamicShapeNet()

    shape = (1,)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (7,)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (1, 1)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (1, 7)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (3, 1)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (2, 4)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (1, 1, 1)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (1, 5, 3)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)

    shape = (2, 3, 1, 3, 1)
    x = Tensor(np.zeros(shape).astype(np_type))
    ms_out = dynamic_shape_net(x).asnumpy()
    expected = np.array(shape)
    np.testing.assert_array_equal(ms_out, expected)
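
# The nine cases above could also be driven by a single loop; an equivalent
# sketch (hypothetical helper, not part of this commit):
def dynamic_shape_looped(np_type):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_shape_net = DynamicShapeNet()
    for shape in [(1,), (7,), (1, 1), (1, 7), (3, 1), (2, 4),
                  (1, 1, 1), (1, 5, 3), (2, 3, 1, 3, 1)]:
        x = Tensor(np.zeros(shape).astype(np_type))
        np.testing.assert_array_equal(dynamic_shape_net(x).asnumpy(), np.array(shape))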


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_shape_int32():
    dynamic_shape(np.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_shape_float16():
    dynamic_shape(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_shape_float32():
    dynamic_shape(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_shape_bool():
    dynamic_shape(np.bool)
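
The tests follow the repository's GPU ST-test marker conventions. To run them locally (the filename below is assumed; this diff view does not show the file's path):

    pytest -q -m level0 test_dynamic_shape_op.py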