From 824418c8f684aea1353f1bcd999841b6603303c0 Mon Sep 17 00:00:00 2001
From: lishixing3
Date: Thu, 3 Dec 2020 14:21:54 +0800
Subject: [PATCH] add dropout

---
 .../kernel_compiler/cpu/dropout_cpu_kernel.cc | 79 ++++++++++++++++
 .../kernel_compiler/cpu/dropout_cpu_kernel.h  | 60 ++++++++++++
 mindspore/nn/layer/basic.py                   |  6 +-
 tests/st/ops/cpu/test_dropout_op.py           | 93 +++++++++++++++++++
 4 files changed, 235 insertions(+), 3 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.h
 create mode 100644 tests/st/ops/cpu/test_dropout_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc
new file mode 100644
index 0000000000..cab48c08a3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <random>
+#include <vector>
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "backend/kernel_compiler/cpu/dropout_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+void DropoutCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  mask_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 1);
+  keep_prob_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "keep_prob");
+  if (keep_prob_ <= 0.0) {
+    MS_LOG(EXCEPTION) << "Keep_prob is less than or equal to zero, but DropoutCPUKernel needs it greater than 0.";
+  }
+  if (keep_prob_ > 1.0) {
+    MS_LOG(EXCEPTION) << "Keep_prob is greater than one, but DropoutCPUKernel needs it less than or equal to 1.";
+  }
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+  for (const uint64_t &d : input_shape_) {
+    tensor_size_ *= d;
+  }
+}
+
+bool DropoutCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                              const std::vector<kernel::AddressPtr> & /*workspace*/,
+                              const std::vector<kernel::AddressPtr> &outputs) {
+  if (dtype_ == kNumberTypeFloat16) {
+    LaunchKernel<float16>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat32) {
+    LaunchKernel<float>(inputs, outputs);
+  }
+  return true;
+}
+
+template <typename T>
+void DropoutCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
+  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
+  auto mask_addr = reinterpret_cast<T *>(outputs[1]->addr);
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::bernoulli_distribution dis(keep_prob_);
+  T scale = (T)(1.f / keep_prob_);
+  for (uint64_t i = 0; i < tensor_size_; ++i) {
+    mask_addr[i] = (T)dis(gen);
+    output_addr[i] = mask_addr[i] * input_addr[i] * scale;
+  }
+}
+
+void DropoutCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 1) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DropoutCPUKernel needs 1 input.";
+  }
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 2) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DropoutCPUKernel needs 2 outputs.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.h
new file mode 100644
index 0000000000..b4f2eb4677
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class DropoutCPUKernel : public CPUKernel {
+ public:
+  DropoutCPUKernel() = default;
+  ~DropoutCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+ private:
+  void CheckParam(const CNodePtr &kernel_node);
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> output_shape_;
+  std::vector<size_t> mask_shape_;
+  TypeId dtype_{kTypeUnknown};
+  float keep_prob_ = 0.0;
+  uint64_t tensor_size_ = 1;
+};
+
+MS_REG_CPU_KERNEL(
+  Dropout,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  DropoutCPUKernel);
+MS_REG_CPU_KERNEL(
+  Dropout,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  DropoutCPUKernel);
+
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index ce7722628d..911d2e13df 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -73,7 +73,7 @@ class Dropout(Cell):
         Tensor, output tensor with the same shape as the input.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
@@ -102,14 +102,14 @@ class Dropout(Cell):
         self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
         self.dropout_do_mask = P.DropoutDoMask()
         self.cast = P.Cast()
-        self.is_gpu = context.get_context('device_target') in ["GPU"]
+        self.is_ascend = context.get_context('device_target') in ["Ascend"]
         self.dropout = P.Dropout(keep_prob)
 
     def construct(self, x):
         if not self.training:
             return x
 
-        if self.is_gpu:
+        if not self.is_ascend:
             out, _ = self.dropout(x)
             return out
 
diff --git a/tests/st/ops/cpu/test_dropout_op.py b/tests/st/ops/cpu/test_dropout_op.py
new file mode 100644
index 0000000000..4fc1be596f
--- /dev/null
+++ b/tests/st/ops/cpu/test_dropout_op.py
@@ -0,0 +1,93 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.dropout = P.Dropout()
+
+    def construct(self, x):
+        return self.dropout(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_net():
+    x = np.random.randn(3, 3, 4).astype(np.float32)
+    dropout = Net()
+    output, mask = dropout(Tensor(x))
+    print(x)
+    print(output)
+    print(mask)
+
+
+class Net1(nn.Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.dropout = P.Dropout(keep_prob=0.1)
+
+    def construct(self, x):
+        return self.dropout(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_net1():
+    x = np.arange(0, 16).reshape(2, 2, 4).astype(np.float32)
+    dropout = Net1()
+    output, mask = dropout(Tensor(x))
+    print(x)
+    print(output)
+    print(mask)
+
+
+class Net2(nn.Cell):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.dropout = P.Dropout(keep_prob=1.0)
+
+    def construct(self, x):
+        return self.dropout(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_net2():
+    x = np.arange(0, 12).reshape(3, 4).astype(np.float16)
+    dropout = Net2()
+    output, mask = dropout(Tensor(x))
+    print(x)
+    print(output)
+    print(mask)
+
+
+if __name__ == '__main__':
+    test_net()
+    test_net1()
+    test_net2()
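
Note: the kernel above implements "inverted dropout": each element is kept with probability keep_prob (drawn from std::bernoulli_distribution) and kept values are scaled by 1/keep_prob, so the expected value of every element is unchanged. A minimal NumPy sketch of that scheme, for reference only; the helper name dropout_reference is illustrative and not part of this patch:

    import numpy as np

    def dropout_reference(x, keep_prob=0.5, rng=None):
        # 0/1 keep-mask with P(keep) = keep_prob, mirroring std::bernoulli_distribution.
        rng = np.random.default_rng() if rng is None else rng
        mask = (rng.random(x.shape) < keep_prob).astype(x.dtype)
        # Scale kept elements by 1/keep_prob so the output matches x in expectation.
        return x * mask / keep_prob, mask

For keep_prob=1.0 the mask is all ones and the output equals the input, which is the case test_net2 exercises.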