Add test cases
Update test cases
Update test cases
Update operator implementation: hard-code alpha to 1.0
Update operator implementation: hard-code alpha to 1
Update operator registration
Update operator implementation: dy and y were swapped
Update test code
Update operator implementation
Update operator implementation: give float16 its own branch
Update operator implementation: convert the front end's float to double
Update operator implementation: revert that change
Update operator implementation: convert float16 to Eigen's half
Update operator implementation: convert float16 to Eigen's half
Update operator implementation: convert float16 to Eigen's half
Update operator implementation: implement EluGrad in a separate file
Update operator implementation: implement EluGrad in a separate file
Update operator implementation: implement EluGrad in a separate file
Update operator implementation: implement EluGrad in a separate file
Format code
Fix pylint errors
Update comments
Remove unused private variable
Remove unused private variable
pull/11615/head
parent
fcc4a4eaea
commit
d86a8ea367
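The commit log above converges on a simple formula. As a minimal NumPy sketch of what the merged kernel computes, with alpha fixed at 1 and the dy/y argument order the log mentions fixing (the helper name elu_grad_ref is ours, for illustration only):

import numpy as np

def elu_grad_ref(dy, y, alpha=1.0):
    # For x < 0 the forward pass stores y = alpha * (exp(x) - 1),
    # so the backward pass is dy * (y + alpha); otherwise it is just dy.
    return np.where(y < 0, dy * (y + alpha), dy)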
@ -0,0 +1,85 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <string>
#include <thread>
#include "backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
template <typename T>
void EluGradCPUKernel::EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end) {
  // input0 is dy (the incoming gradient), input1 is y (the forward ELU output).
  const T alpha = static_cast<T>(1);
  for (size_t i = start; i < end; i++) {
    out[i] = (input1[i] < static_cast<T>(0)) ? input0[i] * (input1[i] + alpha) : input0[i];
  }
}
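The negative branch relies on an identity: for y = elu(x) = alpha * (exp(x) - 1) with x < 0, the derivative is alpha * exp(x) = y + alpha, so the gradient can be formed from the forward output alone. A standalone finite-difference check of that identity (a sketch, not part of the diff):

import numpy as np

alpha = 1.0
x = np.linspace(-3.0, -0.1, 7)
y = alpha * (np.exp(x) - 1.0)  # forward ELU on the negative branch
h = 1e-6
numeric = ((np.exp(x + h) - 1.0) - (np.exp(x - h) - 1.0)) * alpha / (2.0 * h)
assert np.allclose(numeric, y + alpha, atol=1e-5)  # d(elu)/dx == y + alpha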

void EluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
    MS_LOG(EXCEPTION) << "Input0 and input1 must have the same data type";
  }
}

bool EluGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                              const std::vector<kernel::AddressPtr> & /*workspace*/,
                              const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << " is not supported.";
  }
  return true;
}

template <typename T>
void EluGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  T *input0 = reinterpret_cast<T *>(inputs[0]->addr);
  T *input1 = reinterpret_cast<T *>(inputs[1]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);

  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
  auto max_thread_num = std::thread::hardware_concurrency();
  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
  std::vector<std::thread> threads;
  if (thread_num < 1) {
    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
    return;
  }
  threads.reserve(thread_num);
  size_t start = 0;
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  if (once_compute_size < 1) {
    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
    return;
  }
  while (start < lens) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    threads.emplace_back(std::thread(&EluGradCPUKernel::EluGrad<T>, this, input0, input1, output, start, end));
    start += once_compute_size;
  }
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i].join();
  }
}
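The sharding above caps the thread count at the hardware concurrency, targets roughly 128 elements per thread, and hands each thread a half-open chunk of ceil(lens / thread_num) elements, clipping the last chunk to lens. The same arithmetic in a Python sketch (the function name is ours, for illustration):

import math

def partition(lens, max_thread_num, per_thread=128):
    thread_num = math.ceil(lens / per_thread) if lens < per_thread * max_thread_num else max_thread_num
    once = (lens + thread_num - 1) // thread_num  # ceiling division, mirrors once_compute_size
    chunks, start = [], 0
    while start < lens:
        chunks.append((start, min(start + once, lens)))
        start += once
    return chunks

print(partition(1000, 8))  # 8 chunks of 125: [(0, 125), (125, 250), ..., (875, 1000)]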
} // namespace kernel
} // namespace mindspore
@ -0,0 +1,56 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class EluGradCPUKernel : public CPUKernel {
 public:
  EluGradCPUKernel() = default;
  ~EluGradCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  template <typename T>
  void EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end);
  TypeId dtype_{kTypeUnknown};
};

MS_REG_CPU_KERNEL(
  EluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  EluGradCPUKernel);
MS_REG_CPU_KERNEL(
  EluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  EluGradCPUKernel);
MS_REG_CPU_KERNEL(
  EluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat),
  EluGradCPUKernel);
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
|
@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class NetEluGrad(nn.Cell):
    def __init__(self):
        super(NetEluGrad, self).__init__()
        self.elu_grad = G.EluGrad()

    def construct(self, dy, y):
        return self.elu_grad(dy, y)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_elu_grad_fp32():
    y = Tensor(np.array([[[[-0.3, 1, 2],
                           [1, -0.6, 1],
                           [2, 1, -2]]]]).astype(np.float32))
    dy = Tensor(np.array([[[[-11, 2, 4],
                            [-1, 1, -1],
                            [-4, 4, -4]]]]).astype(np.float32))

    expect = np.array([[[[-7.7, 2, 4],
                         [-1, 0.4, -1],
                         [-4, 4, 4]]]]).astype(np.float32)

    error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6

    elu_grad = NetEluGrad()
    output = elu_grad(dy, y)
    print(output)
    diff = np.abs(output.asnumpy() - expect)
    # Divide by |expect| so negative expected values cannot make the check pass trivially.
    double_check = diff / np.abs(expect)
    assert np.all(double_check < error)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_elu_grad_fp16():
    y = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
    error = np.ones(shape=[2, 3]) * 1.0e-3

    elu_grad = NetEluGrad()
    output = elu_grad(dy, y)
    print(output)
    diff = np.abs(output.asnumpy() - expect)
    double_check = diff / np.abs(expect)
    assert np.all(double_check < error)
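As a cross-check, the expect arrays in both tests can be reproduced with a NumPy reference of the same formula (a verification sketch, not part of the diff; elu_grad_ref is the illustrative helper from earlier):

import numpy as np

def elu_grad_ref(dy, y, alpha=1.0):
    return np.where(y < 0, dy * (y + alpha), dy)

y32 = np.array([[[[-0.3, 1, 2], [1, -0.6, 1], [2, 1, -2]]]], dtype=np.float32)
dy32 = np.array([[[[-11, 2, 4], [-1, 1, -1], [-4, 4, -4]]]], dtype=np.float32)
print(elu_grad_ref(dy32, y32))  # e.g. -11 * (-0.3 + 1) = -7.7, matching expect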