changed default max_output_length to 1000000 change docstring fix ci change max_output_length to maxlenpull/10427/head
parent
03e655f14a
commit
507cc4ab15
@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cstdint>
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Range,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
DynamicRangeGpuKernel, float)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(Range,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
DynamicRangeGpuKernel, double)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(Range,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
DynamicRangeGpuKernel, int32_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(Range,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
DynamicRangeGpuKernel, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class DynamicRangeGpuKernel : public GpuKernel {
|
||||
public:
|
||||
DynamicRangeGpuKernel() { ResetResource(); }
|
||||
~DynamicRangeGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *range_start = GetDeviceAddress<T>(inputs, 0);
|
||||
T *range_end = GetDeviceAddress<T>(inputs, 1);
|
||||
T *range_delta = GetDeviceAddress<T>(inputs, 2);
|
||||
T *output_device_address = GetDeviceAddress<T>(outputs, 0);
|
||||
int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0);
|
||||
|
||||
stream_ptr_ = stream_ptr;
|
||||
|
||||
CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
|
||||
max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
// use workspace[0] for actual output shape, we know it must be 1d
|
||||
CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
|
||||
cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
|
||||
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void PostExecute() override {
|
||||
// required synchronize for PostExecute
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
|
||||
"cudaStreamSynchronize failed");
|
||||
|
||||
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
|
||||
std::vector<std::vector<size_t>> output_shape = {{(size_t)output_shape_}};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, c_node_ptr_.get());
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
stream_ptr_ = nullptr;
|
||||
c_node_ptr_ = nullptr;
|
||||
output_shape_ = 0;
|
||||
max_output_length_ = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 3) {
|
||||
MS_LOG(ERROR) << input_count << " inputs were provided, but DynamicRangeGpuKernel expects 3.";
|
||||
return false;
|
||||
}
|
||||
|
||||
max_output_length_ = GetAttr<int64_t>(kernel_node, "maxlen");
|
||||
c_node_ptr_ = kernel_node;
|
||||
InitSizeLists();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
output_size_list_.push_back(max_output_length_ * sizeof(T));
|
||||
|
||||
// this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape.
|
||||
workspace_size_list_.push_back(sizeof(int64_t));
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
void *stream_ptr_;
|
||||
CNodePtr c_node_ptr_;
|
||||
int64_t output_shape_;
|
||||
int64_t max_output_length_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
|
@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dynamic_range_impl.cuh"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__device__ void CheckInputs(const T &start, const T &end, const T &delta) {
|
||||
if (delta == 0) {
|
||||
asm("trap;");
|
||||
}
|
||||
|
||||
if (start < end && delta < 0) {
|
||||
asm("trap;");
|
||||
}
|
||||
|
||||
if (start > end && delta > 0) {
|
||||
asm("trap;");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output,
|
||||
int64_t *output_shape, const int64_t max_output_size) {
|
||||
T start = range_start[0];
|
||||
T end = range_end[0];
|
||||
T delta = range_delta[0];
|
||||
|
||||
CheckInputs(start, end, delta);
|
||||
|
||||
int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
|
||||
if (real_output_shape > max_output_size) {
|
||||
asm("trap;");
|
||||
}
|
||||
*output_shape = real_output_shape;
|
||||
|
||||
size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) {
|
||||
output[gt_id] = gt_id * delta + start;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream) {
|
||||
Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta,
|
||||
output, output_shape, max_output_size);
|
||||
}
|
||||
|
||||
template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
|
||||
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
|
||||
template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
|
||||
int64_t *output, int64_t *output_shape, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output,
|
||||
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
|
||||
double *output, int64_t *output_shape, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
template <typename T>
|
||||
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
|
@ -0,0 +1,93 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class RangeNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RangeNet, self).__init__()
|
||||
self.range = P.Range()
|
||||
|
||||
def construct(self, s, e, d):
|
||||
return self.range(s, e, d)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_range_int():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
|
||||
np_expected = np.array([2, 3, 4])
|
||||
np.testing.assert_array_equal(ms_out, np_expected)
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(-24, mstype.int32), Tensor(1, mstype.int32), Tensor(4, mstype.int32)).asnumpy()
|
||||
np_expected = np.array([-24, -20, -16, -12, -8, -4, 0])
|
||||
np.testing.assert_array_equal(ms_out, np_expected)
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(8, mstype.int32), Tensor(1, mstype.int32), Tensor(-1, mstype.int32)).asnumpy()
|
||||
np_expected = np.array([8, 7, 6, 5, 4, 3, 2])
|
||||
np.testing.assert_array_equal(ms_out, np_expected)
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(3, mstype.int32), Tensor(-11, mstype.int32), Tensor(-5, mstype.int32)).asnumpy()
|
||||
np_expected = np.array([3, -2, -7])
|
||||
np.testing.assert_array_equal(ms_out, np_expected)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_range_float():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(2.3, mstype.float32), Tensor(5.5, mstype.float32), Tensor(1.2, mstype.float32)).asnumpy()
|
||||
np_expected = np.array([2.3, 3.5, 4.7])
|
||||
np.testing.assert_array_almost_equal(ms_out, np_expected)
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(-4, mstype.float32), Tensor(-1, mstype.float32), Tensor(1.5, mstype.float32)).asnumpy()
|
||||
np_expected = np.array([-4.0, -2.5])
|
||||
np.testing.assert_array_almost_equal(ms_out, np_expected)
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(8.0, mstype.float32), Tensor(1.0, mstype.float32), Tensor(-1.0, mstype.float32)).asnumpy()
|
||||
np_expected = np.array([8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0])
|
||||
np.testing.assert_array_almost_equal(ms_out, np_expected)
|
||||
|
||||
range_net = RangeNet()
|
||||
ms_out = range_net(Tensor(1.5, mstype.float32), Tensor(-1, mstype.float32), Tensor(-18.9, mstype.float32)).asnumpy()
|
||||
np_expected = np.array([1.5])
|
||||
np.testing.assert_array_almost_equal(ms_out, np_expected)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_range_invalid_max_output_length():
|
||||
with pytest.raises(ValueError):
|
||||
_ = P.Range(0)
|
||||
_ = P.Range(-1)
|
||||
_ = P.Range(None)
|
||||
_ = P.Range('5')
|
Loading…
Reference in new issue