!9321 Add Dropout Grad for CPU

From: @xutianming1985
Reviewed-by: 
Signed-off-by:
pull/9321/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 747bc87ab3

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/dropout_grad_kernel.h"
namespace mindspore {
namespace kernel {
void DropoutGradCpuBwdKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto input_mask_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (input_shape.size() != input_mask_shape.size()) {
MS_LOG(EXCEPTION) << "Input size " << input_shape.size() << " and mask size " << input_mask_shape.size()
<< " is not match";
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
keep_prob_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "keep_prob");
if (keep_prob_ == 0) {
MS_LOG(EXCEPTION) << "The keep_prob is zero.";
}
}
bool DropoutGradCpuBwdKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
DropoutBackwardKernel<float16>(inputs, outputs, num_count_, keep_prob_);
} else if (dtype_ == kNumberTypeFloat32) {
DropoutBackwardKernel<float>(inputs, outputs, num_count_, keep_prob_);
}
return true;
}
template <typename T>
void DropoutGradCpuBwdKernel::DropoutBackwardKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs, size_t num_count,
float keep_prob) {
auto dx = reinterpret_cast<T *>(outputs[0]->addr);
auto dy = reinterpret_cast<T *>(inputs[0]->addr);
auto mask = reinterpret_cast<T *>(inputs[1]->addr);
float scale = 1.f / keep_prob;
for (size_t i = 0; i < num_count; i += 1) {
dx[i] = (T)(scale * static_cast<float>(dy[i] * mask[i]));
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_DROPOUT_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_DROPOUT_GRAD_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class DropoutGradCpuBwdKernel : public CPUKernel {
public:
DropoutGradCpuBwdKernel() = default;
~DropoutGradCpuBwdKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
float keep_prob_{0.0};
size_t num_count_{1};
TypeId dtype_;
template <typename T>
void DropoutBackwardKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
size_t num_count, float keep_prob);
};
MS_REG_CPU_KERNEL(
DropoutGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DropoutGradCpuBwdKernel);
MS_REG_CPU_KERNEL(
DropoutGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DropoutGradCpuBwdKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GRAD_KERNEL_H_

@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dropout """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.ops.operations import _grad_ops as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self, keep_prob=0.5):
super(Net, self).__init__()
self.dropout_grad = P.DropoutGrad(keep_prob)
def construct(self, output, mask):
return self.dropout_grad(output, mask)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dropout_grad_001():
in_tensor = Tensor(np.array([[[3., 1., 2.]], \
[[4., 1., 4.]]]), mstype.float32)
in_mask = Tensor(np.array([[[1., 0, 0]], [[1., 1., 0]]]), mstype.float32)
dropout_grad = Net()
output = dropout_grad(in_tensor, in_mask)
print("output:\n", output)
expect = np.array([[[6., 0., 0.]], [[8., 2., 0.]]]).astype(np.float32)
error = np.ones(shape=[2, 3]) * 1.0e-6
diff = np.abs(output.asnumpy() - expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dropout_grad_002():
in_tensor = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]), mstype.float16)
in_mask = Tensor(np.array([[[1., 0, 0]], [[1., 1., 0]]]), mstype.float16)
dropout_grad = Net()
output = dropout_grad(in_tensor, in_mask)
print("output:\n", output)
expect = np.array([[[6., 0., 0.]], [[8., 2., 0.]]]).astype(np.float16)
error = np.ones(shape=[2, 3]) * 1.0e-6
diff = np.abs(output.asnumpy() - expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dropout_grad_003():
in_tensor = Tensor(np.array([[[3., 1., 2.], [3., 1., 2.]], \
[[4., 1., 4.], [4., 1., 4.]]]), mstype.float16)
in_mask = Tensor(np.array([[[1., 0, 0], [1., 0, 0]], \
[[1., 1., 0], [1., 1., 0]]]), mstype.float16)
dropout_grad = Net()
output = dropout_grad(in_tensor, in_mask)
print("output:\n", output)
expect = np.array([[[6., 0., 0.], [6., 0., 0.]], \
[[8., 2., 0.], [8., 2., 0.]]]).astype(np.float16)
error = np.ones(shape=[2, 2, 3]) * 1.0e-6
diff = np.abs(output.asnumpy() - expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dropout_grad_004():
in_tensor = Tensor(np.array([[6.]]), mstype.float32)
in_mask = Tensor(np.array([[1.]]), mstype.float32)
dropout_grad = Net(1.)
output = dropout_grad(in_tensor, in_mask)
print("output:\n", output)
expect = np.array([[6.]]).astype(np.float32)
error = np.ones(shape=[1]) * 1.0e-6
diff = np.abs(output.asnumpy() - expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dropout_grad_005():
in_tensor = Tensor(np.array([[]]), mstype.float32)
in_mask = Tensor(np.array([[]]), mstype.float32)
dropout_grad = Net(1.)
output = dropout_grad(in_tensor, in_mask)
print("output:\n", output)
expect = np.array([[]]).astype(np.float32)
error = np.ones(shape=[]) * 1.0e-6
diff = np.abs(output.asnumpy() - expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dropout_grad_006():
in_tensor = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]), mstype.float16)
in_mask = Tensor(np.array([[[1., 0, 0]], [[0., 0., 1.]]]), mstype.float16)
dropout_grad = Net(0.3333333333)
output = dropout_grad(in_tensor, in_mask)
print("output:\n", output)
expect = np.array([[[9., 0., 0.]], [[0., 0., 12.]]]).astype(np.float16)
error = np.ones(shape=[2, 3]) * 1.0e-6
diff = np.abs(output.asnumpy() - expect)
assert np.all(np.abs(diff) < error)
Loading…
Cancel
Save