comment fix docString fix added asserts in test file atop np checks lint lint-2 lint3pull/8928/head
parent
adc8e3e707
commit
a17f76dd1d
@ -0,0 +1,32 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/linspace.cuh"
|
||||
#include <iostream>
|
||||
|
||||
template <typename T>
|
||||
__global__ void LinSpaceKernel(const T *start, const T *stop, const size_t value_count, T *output) {
|
||||
T add_value = ((*stop - *start) / (value_count - 1));
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < value_count; i += gridDim.x * blockDim.x) {
|
||||
output[i] = *start + (add_value * i);
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
void calLinSpace(const T *start, const T *stop, const size_t value_count, T *output, cudaStream_t cuda_stream) {
|
||||
LinSpaceKernel<<<GET_BLOCKS(value_count), GET_THREADS, 0, cuda_stream>>>(start, stop, value_count, output);
|
||||
}
|
||||
template void calLinSpace<float>(const float *start, const float *stop, const size_t value_count, float *output,
|
||||
cudaStream_t cuda_stream);
|
@ -0,0 +1,23 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void calLinSpace(const T *start, const T *stop, const size_t value_count, T *output, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/linspace.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(LinSpace,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
LinSpaceGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,102 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/linspace.cuh"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class LinSpaceGpuKernel : public GpuKernel {
|
||||
public:
|
||||
LinSpaceGpuKernel() { ResetResource(); }
|
||||
~LinSpaceGpuKernel() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *start_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *stop_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
calLinSpace(start_addr, stop_addr, value_count_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but DynamicLinSpace needs 3 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but DynamicLinSpace needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
auto input_2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
|
||||
// error checking input data
|
||||
if ((input_1.size() != 0) || (input_2.size() != 0)) {
|
||||
MS_LOG(ERROR) << "For LinShape "
|
||||
<< "both start and end must be 0-D Tensors. Got " << input_1.size() << " and " << input_2.size()
|
||||
<< ".";
|
||||
return false;
|
||||
}
|
||||
auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
if (value_count.size() != 1) {
|
||||
MS_LOG(ERROR) << "For LinShape, output shape incorrect rank. Expect Rank: 1, got Rank: " << value_count.size()
|
||||
<< ".";
|
||||
}
|
||||
value_count_ = value_count[0];
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
value_count_ = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(sizeof(T)); // Scalar tensor
|
||||
input_size_list_.push_back(sizeof(T)); // Scalar tensor
|
||||
output_size_list_.push_back(value_count_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t value_count_;
|
||||
int num_input_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_
|
@ -0,0 +1,99 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.context as context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class LinSpaceNet(Cell):
|
||||
def __init__(self, num):
|
||||
super(LinSpaceNet, self).__init__()
|
||||
self.ls_op = P.LinSpace()
|
||||
self.num = num
|
||||
|
||||
def construct(self, start, stop):
|
||||
output = self.ls_op(start, stop, self.num)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lin_space_1():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
start_np = 5
|
||||
stop_np = 150
|
||||
num_np = 12
|
||||
start = Tensor(start_np, dtype=mstype.float32)
|
||||
stop = Tensor(stop_np, dtype=mstype.float32)
|
||||
num = num_np
|
||||
ls_op = P.LinSpace()
|
||||
result_ms = ls_op(start, stop, num).asnumpy()
|
||||
result_np = np.linspace(start_np, stop_np, num_np)
|
||||
assert np.allclose(result_ms, result_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lin_shape_2():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
start_np = -25
|
||||
stop_np = 147
|
||||
num_np = 10
|
||||
start = Tensor(start_np, dtype=mstype.float32)
|
||||
stop = Tensor(stop_np, dtype=mstype.float32)
|
||||
num = num_np
|
||||
ls_op = P.LinSpace()
|
||||
result_ms = ls_op(start, stop, num).asnumpy()
|
||||
result_np = np.linspace(start_np, stop_np, num_np)
|
||||
assert np.allclose(result_ms, result_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lin_shape_3():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
start_np = 25
|
||||
stop_np = -147
|
||||
num_np = 20
|
||||
start = Tensor(start_np, dtype=mstype.float32)
|
||||
stop = Tensor(stop_np, dtype=mstype.float32)
|
||||
net = LinSpaceNet(num_np)
|
||||
result_ms = net(start, stop).asnumpy()
|
||||
result_np = np.linspace(start_np, stop_np, num_np)
|
||||
assert np.allclose(result_ms, result_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lin_shape_4():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
start_np = -25.3
|
||||
stop_np = -147
|
||||
num_np = 36
|
||||
start = Tensor(start_np, dtype=mstype.float32)
|
||||
stop = Tensor(stop_np, dtype=mstype.float32)
|
||||
net = LinSpaceNet(num_np)
|
||||
result_ms = net(start, stop).asnumpy()
|
||||
result_np = np.linspace(start_np, stop_np, num_np)
|
||||
assert np.allclose(result_ms, result_np)
|
Loading…
Reference in new issue