From 8761d2c874957016bb6b19d41a8d407cf008fb64 Mon Sep 17 00:00:00 2001
From: liujiahan
Date: Fri, 5 Feb 2021 17:49:47 +0800
Subject: [PATCH] add relu, sigmoid and log grad computing units.

modified file format

change the header file in activation_fp16_grad.h

change the year of Copyright info and added LiteKernelCreator.
---
 .../lite/nnacl/fp16_grad/activation_grad.c     |  72 +++++++
 .../lite/nnacl/fp16_grad/activation_grad.h     |  43 ++++
 mindspore/lite/nnacl/optimize/CMakeLists.txt   |   5 +
 mindspore/lite/schema/ops.fbs                  |   3 +-
 .../src/runtime/kernel/arm/CMakeLists.txt      |   4 +
 .../arm/fp16_grad/activation_fp16_grad.cc      |  93 ++++++++
 .../arm/fp16_grad/activation_fp16_grad.h       |  46 ++++
 mindspore/lite/test/CMakeLists.txt             |   6 +
 .../fp16_grad/activation_grad_fp16_test.cc     | 199 ++++++++++++++++++
 .../test_data/activationGrad/log_out_50.bin    |   1 +
 .../arm/test_data/activationGrad/log_x_50.bin  | Bin 0 -> 200 bytes
 .../test_data/activationGrad/log_yt_50.bin     | Bin 0 -> 200 bytes
 12 files changed, 471 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/lite/nnacl/fp16_grad/activation_grad.c
 create mode 100644 mindspore/lite/nnacl/fp16_grad/activation_grad.h
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_out_50.bin
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_x_50.bin
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_yt_50.bin
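Note (commentary only; this paragraph and the sketch below sit before the first diff header, so git am ignores them): the three kernels added here compute the usual elementwise gradients, dst = dy * f'(x), with the NEON paths mirroring the scalar fallbacks. A minimal fp32 reference sketch of the same math follows; the ReluGradRef/SigmoidGradRef/LogGradRef names are illustrative only and are not part of the patch.

    /* fp32 reference for the fp16 kernels below (illustrative only).
     * dy: incoming gradient, x: forward input, y: forward output. */
    #include <math.h>
    #include <stdio.h>

    static float ReluGradRef(float dy, float x) { return (x > 0.0f) ? dy : 0.0f; }
    static float SigmoidGradRef(float dy, float y) { return dy * y * (1.0f - y); }
    /* same log(10) scaling as Fp16LogGrad, i.e. the derivative of log10(x) */
    static float LogGradRef(float dy, float x) { return dy / (x * logf(10.0f)); }

    int main(void) {
      /* LogGradRef(1, 1) = 1 / ln(10), roughly 0.434294 */
      printf("%f %f %f\n", ReluGradRef(1.0f, -0.5f), SigmoidGradRef(1.0f, 0.7f), LogGradRef(1.0f, 1.0f));
      return 0;
    }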
diff --git a/mindspore/lite/nnacl/fp16_grad/activation_grad.c b/mindspore/lite/nnacl/fp16_grad/activation_grad.c
new file mode 100644
index 0000000000..a9406e0262
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16_grad/activation_grad.c
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl/op_base.h"
+#include "nnacl/fp16_grad/activation_grad.h"
+#include "nnacl/errorcode.h"
+
+int Fp16ReluGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
+  int i = 0;
+#ifdef ENABLE_NEON
+  float16x8_t zero_8 = vdupq_n_f16(0);
+  for (; i <= (int)length - 8; i += 8) {
+    float16x8_t src0_8 = vld1q_f16(src0 + i);
+    float16x8_t src1_8 = vld1q_f16(src1 + i);
+    uint16x8_t mask_8 = vcgtq_f16(src1_8, zero_8);
+    float16x8_t dst_8 = vbslq_f16(mask_8, src0_8, zero_8);
+    vst1q_f16(dst + i, dst_8);
+  }
+#endif
+  for (; i < length; i++) {
+    dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
+  }
+  return NNACL_OK;
+}
+
+int Fp16SigmoidGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
+  int i = 0;
+#ifdef ENABLE_NEON
+  float16x8_t one_8 = vdupq_n_f16(1);
+  for (; i <= (int)length - 8; i += 8) {
+    float16x8_t src0_8 = vld1q_f16(src0 + i);
+    float16x8_t src1_8 = vld1q_f16(src1 + i);
+    float16x8_t dst_8 = vmulq_f16(src0_8, vmulq_f16(src1_8, vsubq_f16(one_8, src1_8)));
+    vst1q_f16(dst + i, dst_8);
+  }
+#endif
+  for (; i < length; i++) {
+    dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
+  }
+  return NNACL_OK;
+}
+
+int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
+  int i = 0;
+#ifdef ENABLE_NEON
+  float16x8_t log_10 = vdupq_n_f16(log(10));
+  for (; i <= (int)length - 8; i += 8) {
+    float16x8_t src0_8 = vld1q_f16(src0 + i);
+    float16x8_t src1_8 = vld1q_f16(src1 + i);
+    float16x8_t dst_8 = vmulq_f16(src0_8, vrecpeq_f16(vmulq_f16(src1_8, log_10)));
+    vst1q_f16(dst + i, dst_8);
+  }
+#endif
+  for (; i < length; i++) {
+    dst[i] = src0[i] * 1.0f / (src1[i] * log(10));
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/nnacl/fp16_grad/activation_grad.h b/mindspore/lite/nnacl/fp16_grad/activation_grad.h
new file mode 100644
index 0000000000..985708bb15
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16_grad/activation_grad.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
+#define MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <math.h>
+#include "nnacl/op_base.h"
+#include "mindspore/lite/nnacl/int8/fixed_point.h"
+
+typedef struct ActivationGradParameterFp16 {
+  OpParameter op_parameter;
+  int type_;
+  float alpha_;
+} ActivationGradParameterFp16;
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int Fp16ReluGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
+int Fp16SigmoidGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
+int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
diff --git a/mindspore/lite/nnacl/optimize/CMakeLists.txt b/mindspore/lite/nnacl/optimize/CMakeLists.txt
index ea7a14499e..77dc3f7452 100644
--- a/mindspore/lite/nnacl/optimize/CMakeLists.txt
+++ b/mindspore/lite/nnacl/optimize/CMakeLists.txt
@@ -17,6 +17,11 @@ list(APPEND SDOT_FILES ${SDOT_SRC})
 list(APPEND FP16_FILES ${FP16_C_SRC})
 list(APPEND FP16_FILES ${FP16_NEON_SRC})
 
+if(SUPPORT_TRAIN)
+    file(GLOB FP16_TRAIN_SRC ${NNACL_DIR}/fp16_grad/*.c)
+    list(APPEND FP16_FILES ${FP16_TRAIN_SRC})
+endif()
+
 string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
 set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 60dd7d6603..8024c8a757 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -99,7 +99,8 @@ enum ActivationGradType : byte {
     HSIGMOID = 13,
     THRESHOLDRELU = 14,
     LINEAR = 15,
-    UNKNOWN = 16
+    UNKNOWN = 16,
+    LOG = 17
 }
 enum ReduceType : byte {
     REDUCE_MAX = 0,
diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt
index 99c4e1f295..324ce653ab 100644
--- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt
+++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt
@@ -9,6 +9,8 @@ file(GLOB KERNEL_SRC
 list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)
 
 if(SUPPORT_TRAIN)
     file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc)
+    file(GLOB TRAIN_KERNEL_FP16_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
+    list(APPEND TRAIN_KERNEL_SRC ${TRAIN_KERNEL_FP16_SRC})
     set(KERNEL_SRC ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})
 endif()
@@ -19,6 +21,10 @@ add_dependencies(cpu_kernel_mid fbs_src)
 if(PLATFORM_ARM64)
     if(ENABLE_FP16)
         file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
+        if(SUPPORT_TRAIN)
+            file(GLOB FP16_KERNEL_GRAD_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
+            list(APPEND FP16_KERNEL_SRC ${FP16_KERNEL_GRAD_SRC})
+        endif()
         add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC})
         add_dependencies(cpu_fp16_kernel_mid fbs_src)
     endif()
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.cc
new file mode 100644
index 0000000000..87f8fe4887
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.cc
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h"
+#include "nnacl/fp16_grad/activation_grad.h"
+#include "schema/model_generated.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/runtime_api.h"
+#include "include/errorcode.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::ActivationType_RELU;
+using mindspore::schema::ActivationType_SIGMOID;
+using mindspore::schema::PrimitiveType_ActivationGrad;
+
+namespace mindspore::kernel {
+int ActivationGradCPUKernelFp16::Init() {
+  if (in_tensors_.size() != 2) {
+    MS_LOG(ERROR) << "ActivationGrad should have 2 input tensors";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int ActivationGradCPUKernelFp16::ReSize() { return RET_OK; }
+
+int ActivationGradCPUKernelFp16::DoActivation(int task_id) {
+  auto yt_addr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->MutableData());
+  auto input_addr = reinterpret_cast<float16_t *>(in_tensors_.at(1)->MutableData());
+  auto output_addr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
+  int length = in_tensors_.at(0)->ElementsNum();
+
+  int stride = UP_DIV(length, thread_count_);
+  int count = MSMIN(stride, length - stride * task_id);
+  int start = stride * task_id;
+
+  auto error_code = RET_OK;
+
+  if (param_act_grad_->type_ == schema::ActivationGradType_RELU) {
+    error_code = Fp16ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+  } else if (param_act_grad_->type_ == schema::ActivationGradType_SIGMOID) {
+    // Sigmoid gets the input tensors in reverse order!
+    error_code = Fp16SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
+  } else if (param_act_grad_->type_ == schema::ActivationGradType_LOG) {
+    error_code = Fp16LogGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+  } else {
+    MS_LOG(ERROR) << "Activation type error";
+    return RET_ERROR;
+  }
+  if (error_code != RET_OK) {
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int ActivationGradRunFp16(void *cdata, int task_id) {
+  MS_ASSERT(cdata != nullptr);
+  auto activationGrad_kernel = reinterpret_cast<ActivationGradCPUKernelFp16 *>(cdata);
+  auto error_code = activationGrad_kernel->DoActivation(task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "ActivationGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int ActivationGradCPUKernelFp16::Run() {
+  int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRunFp16, this, thread_count_);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ActivationGrad, LiteKernelCreator<ActivationGradCPUKernelFp16>)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h
new file mode 100644
index 0000000000..f92faa31f4
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_ACTIVATION_FP16_GRAD_H
+#define MINDSPORE_ACTIVATION_FP16_GRAD_H
+
+#include <vector>
+#include "src/lite_kernel.h"
+#include "nnacl/fp16_grad/activation_grad.h"
+
+namespace mindspore::kernel {
+class ActivationGradCPUKernelFp16 : public LiteKernel {
+ public:
+  explicit ActivationGradCPUKernelFp16(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
+                                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                                       const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
+    param_act_grad_ = reinterpret_cast<ActivationGradParameterFp16 *>(param);
+  }
+  ~ActivationGradCPUKernelFp16() override = default;
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+  int DoActivation(int task_id);
+
+ private:
+  ActivationGradParameterFp16 *param_act_grad_;
+  int thread_count_;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_ACTIVATION_FP16_GRAD_H
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 413179be4a..ed60150f40 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -323,6 +323,12 @@ if(ENABLE_FP16)
     )
 endif()
 
+if(SUPPORT_TRAIN)
+    file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
+            ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
+    list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
+endif()
+
 add_executable(lite-test ${TEST_SRC})
 add_dependencies(lite-test fbs_src)
 target_link_libraries(lite-test dl mindspore::gtest)
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc
new file mode 100644
index 0000000000..73ff167f68
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cmath>
+#include <iostream>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "src/common/log_adapter.h"
+#include "common/common_test.h"
+#include "src/common/file_utils.h"
+#include "nnacl/fp16_grad/activation_grad.h"
+
+namespace mindspore {
+class TestActGradFp16 : public mindspore::CommonTest {
+ public:
+  TestActGradFp16() {}
+  float error_bound = 1e-3;
+};
+
+TEST_F(TestActGradFp16, ReluGradFp16) {
+  size_t output_data_size = 50;
+  size_t input_size;
+  std::string input_path = "./test_data/activationGrad/relu_y_50.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  ASSERT_NE(input_data, nullptr);
+  EXPECT_EQ(input_size, output_data_size * sizeof(float));
+
+  std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin";
+  auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
+  ASSERT_NE(yt_data, nullptr);
+  EXPECT_EQ(input_size, output_data_size * sizeof(float));
+
+  std::string output_path = "./test_data/activationGrad/relu_out_50.bin";
+  auto ref_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
+  ASSERT_NE(ref_data, nullptr);
+  EXPECT_EQ(input_size, output_data_size * sizeof(float));
+
+  auto yt_buf = new float16_t[output_data_size];
+  auto input_buf = new float16_t[output_data_size];
+  auto output_buf = new float16_t[output_data_size];
+
+  std::cout << "======yt_buf======" << std::endl;
+  for (int i = 0; i < output_data_size; i++) {
+    yt_buf[i] = (float16_t)yt_data[i];
+    input_buf[i] = (float16_t)input_data[i];
+  }
+
+  Fp16ReluGrad(yt_buf, input_buf, 50, output_buf);
+
+  int res = 0;
+  float error = 0;
+  std::cout << "======Compare with reference data======" << std::endl;
+  for (int i = 0; i < output_data_size; i++) {
+    float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
+    if (diff > 0.00001) {
+      error += diff;
+    }
+  }
+  error /= static_cast<float>(output_data_size);
+  if (error > error_bound) {
+    printf("error=%f while error_bound=%f\n", error, error_bound);
+    res = 1;
+  }
+
+  EXPECT_EQ(res, 0);
+
+  delete[] output_buf;
+  delete[] yt_buf;
+  delete[] input_buf;
+  delete[] ref_data;
+  delete[] yt_data;
+  delete[] input_data;
+
+  MS_LOG(INFO) << "ReluGradFp16 passed";
+}
+
+TEST_F(TestActGradFp16, SigmoidGradFp16) {
+  size_t output_data_size = 50;
+  size_t input_size;
+  std::string input_path = "./test_data/activationGrad/sigmoid_y_50.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  ASSERT_NE(input_data, nullptr);
+
+  std::string yt_path = "./test_data/activationGrad/sigmoid_yt_50.bin";
+  auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
+  ASSERT_NE(yt_data, nullptr);
+
+  std::string output_path = "./test_data/activationGrad/sigmoid_out_50.bin";
+  auto ref_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
+  ASSERT_NE(ref_data, nullptr);
+  EXPECT_EQ(input_size, output_data_size * sizeof(float));
+
+  auto yt_buf = new float16_t[output_data_size];
+  auto input_buf = new float16_t[output_data_size];
+  auto output_buf = new float16_t[output_data_size];
+
+  std::cout << "======yt_buf======" << std::endl;
+  for (int i = 0; i < output_data_size; i++) {
+    yt_buf[i] = (float16_t)yt_data[i];
+    input_buf[i] = (float16_t)input_data[i];
+  }
+
+  Fp16SigmoidGrad(yt_buf, input_buf, 50, output_buf);
+
+  int res = 0;
+  float error = 0;
+  std::cout << "======Compare with reference data======" << std::endl;
+  for (int i = 0; i < output_data_size; i++) {
+    float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
+    if (diff > 0.00001) {
+      error += diff;
+    }
+  }
+  error /= static_cast<float>(output_data_size);
+  if (error > error_bound) {
+    printf("error=%f while error_bound=%f\n", error, error_bound);
+    res = 1;
+  }
+
+  EXPECT_EQ(res, 0);
+
+  delete[] output_buf;
+  delete[] yt_buf;
+  delete[] input_buf;
+  delete[] ref_data;
+  delete[] yt_data;
+  delete[] input_data;
+
+  MS_LOG(INFO) << "SigmoidGradFp16 passed";
+}
+
+TEST_F(TestActGradFp16, LogGradFp16) {
+  size_t output_data_size = 50;
+  size_t input_size;
+  std::string input_path = "./test_data/activationGrad/log_x_50.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  ASSERT_NE(input_data, nullptr);
+
+  std::string yt_path = "./test_data/activationGrad/log_yt_50.bin";
+  auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
+  ASSERT_NE(yt_data, nullptr);
+
+  std::string output_path = "./test_data/activationGrad/log_out_50.bin";
+  auto ref_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
+  ASSERT_NE(ref_data, nullptr);
+  EXPECT_EQ(input_size, output_data_size * sizeof(float));
+
+  auto yt_buf = new float16_t[output_data_size];
+  auto input_buf = new float16_t[output_data_size];
+  auto output_buf = new float16_t[output_data_size];
+
+  for (int i = 0; i < output_data_size; i++) {
+    yt_buf[i] = (float16_t)yt_data[i];
+    input_buf[i] = (float16_t)input_data[i];
+  }
+
+  Fp16LogGrad(yt_buf, input_buf, 50, output_buf);
+
+  int res = 0;
+  float error = 0;
+  std::cout << "======Compare with reference data======" << std::endl;
+  for (int i = 0; i < output_data_size; i++) {
+    float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
+    if (diff > 0.00001) {
+      error += diff;
+    }
+  }
+  error /= static_cast<float>(output_data_size);
+  if (error > error_bound) {
+    printf("error=%f while error_bound=%f\n", error, error_bound);
+    res = 1;
+  }
+
+  EXPECT_EQ(res, 0);
+
+  delete[] output_buf;
+  delete[] yt_buf;
+  delete[] input_buf;
+  delete[] ref_data;
+  delete[] yt_data;
+  delete[] input_data;
+
+  MS_LOG(INFO) << "LogGradFp16 passed";
+}
+
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_out_50.bin
new file mode 100644
index 0000000000..683161be13
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_out_50.bin
@@ -0,0 +1 @@
+Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>Ø[Þ>
\ No newline at end of file
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_x_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_x_50.bin
new file mode 100644
index 0000000000000000000000000000000000000000..f34607f1acbb57e713fe477143341338ee408ba4
GIT binary patch
literal 200
OcmZQzXs~A(1{eT`#3fGv

literal 0
HcmV?d00001

diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/log_yt_50.bin
new file mode 100644
index 0000000000000000000000000000000000000000..f34607f1acbb57e713fe477143341338ee408ba4
GIT binary patch
literal 200
OcmZQzXs~A(1{eT`#3fGv

literal 0
HcmV?d00001