add GPU support to RandomChoiceWithMask

pull/3608/head
TFbunny 5 years ago
parent ade60ad3d3
commit ad8a786b07

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh
@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
#define BLOCKSIZE 256
#define MAX_DIMENSION 5
template <typename T, typename S>
void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff,
S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream);
int RcwmRoundUpPower2(int v);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
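The matching .cu implementation is not included in this view. RcwmRoundUpPower2 presumably returns the smallest power of two greater than or equal to v, which is how ceil_power2_ in the kernel header uses it to size the power-of-two workspaces. A minimal sketch of such a helper (an assumption for illustration, not the committed implementation):

int RcwmRoundUpPower2(int v) {
  // classic bit-twiddling round-up: spread the highest set bit into all lower bits, then add one
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}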

mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc
@@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
RandomChoiceWithMask,
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
RandomChoiceWithMaskGpuKernel, bool, int)
}  // namespace kernel
} // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h
@@ -0,0 +1,129 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
#include <cmath>
#include <ctime>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class RandomChoiceWithMaskGpuKernel : public GpuKernel {
public:
RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {}
~RandomChoiceWithMaskGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
S *output_index = GetDeviceAddress<S>(outputs, 0);
T *output_mask = GetDeviceAddress<T>(outputs, 1);
S *index_buff = GetDeviceAddress<S>(workspaces, 0);
S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
S *rank_buff = GetDeviceAddress<S>(workspaces, 2);
S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3);
S *tmp_buff = GetDeviceAddress<S>(workspaces, 4);
void *States = GetDeviceAddress<void *>(workspaces, 5);
curandState *devStates = reinterpret_cast<curandState *>(States);
CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2],
input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask,
index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape_size_ = input_shape.size();
if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) {
MS_LOG(ERROR) << "Input is " << input_shape_size_
<< "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs.";
return false;
}
// copy the input shape (size_t) into an int vector
for (auto i = 0; i < input_shape_size_; i++) {
input_shape_5D_.push_back(input_shape[i]);
}
// pad the shape to 5-D by prepending leading 1s
while (input_shape_5D_.size() != MAX_DIMENSION) {
input_shape_5D_.insert(input_shape_5D_.begin(), 1);
}
// init seedc_: seed2 takes priority over seed; fall back to the current time
int seed = GetAttr<int>(kernel_node, "seed");
int seed2 = GetAttr<int>(kernel_node, "seed2");
if (seed2 != 0) {
seedc_ = seed2;
} else if (seed != 0) {
seedc_ = seed;
} else {
seedc_ = time(NULL);
}
// total number of elements in the input
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
count_ = GetAttr<int>(kernel_node, "count");
// round input_size_ up to the next power of two; several workspaces below are sized with it
ceil_power2_ = RcwmRoundUpPower2(input_size_);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
// output 0: index tensor of shape [count, input rank]; output 1: mask tensor of shape [count]
output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S));
output_size_list_.push_back(count_ * sizeof(T));
// workspaces 0-5 back index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff and the curand
// states used in Launch(), in that order
workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE);
workspace_size_list_.push_back(blocknum * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState));
}
private:
int input_shape_size_;
int seedc_;
int input_size_;
int count_;
int ceil_power2_;
std::vector<int> input_shape_5D_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
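The curandState workspace (workspace 5, one state per ceil_power2_ slot) and seedc_ are both forwarded to CalRandomChoiceWithMask; the device-side code that consumes them lives in the .cu file, which this view does not show. A typical initialization kernel for such a state buffer, given here only as a hedged sketch of the usual cuRAND pattern rather than the committed code:

#include <curand_kernel.h>

__global__ void InitRandStateSketch(const int seedc, const int n, curandState *state) {
  // grid-stride loop: one independent generator per slot, all derived from the same host-chosen seed
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    curand_init(seedc, i, 0, &state[i]);
  }
}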

@@ -348,13 +348,13 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
         seed2 (int): Random seed2. Default: 0.

     Inputs:
-        - **input_x** (Tensor[bool]) - The input tensor.
+        - **input_x** (Tensor[bool]) - The input tensor. The input tensor rank should be >= 1 and <= 5.

     Outputs:
         Two tensors, the first one is the index tensor and the other one is the mask tensor.

-        - **index** (Tensor) - The output has shape between 2-D and 5-D.
-        - **mask** (Tensor) - The output has shape 1-D.
+        - **index** (Tensor) - The output shape is 2-D.
+        - **mask** (Tensor) - The output shape is 1-D.

     Examples:
         >>> rnd_choice_mask = P.RandomChoiceWithMask()
@@ -372,6 +372,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
         return ([self.count, len(x_shape)], [self.count])

     def infer_dtype(self, x_dtype):

@@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class RCWM_count_in(nn.Cell):
def __init__(self):
super(RCWM_count_in, self).__init__()
self.RCWM_count_in = P.RandomChoiceWithMask(count=4, seed=1)
def construct(self, x):
return self.RCWM_count_in(x)
class RCWM_count_out(nn.Cell):
def __init__(self):
super(RCWM_count_out, self).__init__()
self.RCWM_count_out = P.RandomChoiceWithMask(count=10, seed=1)
def construct(self, x):
return self.RCWM_count_out(x)
class RCWM_3D(nn.Cell):
def __init__(self):
super(RCWM_3D, self).__init__()
self.RCWM_3D = P.RandomChoiceWithMask(count=10, seed=1)
def construct(self, x):
return self.RCWM_3D(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_3D():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool))
expect1 = [[0, 1, 1], [0, 2, 1], [0, 2, 2], [1, 0, 1], [0, 1, 3], [0, 3, 0], [1, 3, 2], \
[0, 0, 0], [1, 1, 2], [1, 3, 4]]
expect2 = [True, True, True, True, True, True, True, True, True, True]
rcwm = RCWM_3D()
output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_count_out():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
expect1 = [[0, 2], [2, 2], [2, 1], [2, 0], [0, 0], [3, 3], [2, 3], [1, 3], [0, 0], [0, 0]]
expect2 = [True, True, True, True, True, True, True, True, False, False]
rcwm = RCWM_count_out()
output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_count_in():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
expect1 = [[0, 2], [2, 2], [2, 1], [2, 0]]
expect2 = [True, True, True, True]
rcwm = RCWM_count_in()
output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)