diff --git a/mindspore/lite/nnacl/fp32/one_hot_fp32.c b/mindspore/lite/nnacl/fp32/one_hot_fp32.c index 7d4825b674..ebe9f2a0c4 100644 --- a/mindspore/lite/nnacl/fp32/one_hot_fp32.c +++ b/mindspore/lite/nnacl/fp32/one_hot_fp32.c @@ -31,14 +31,15 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par int i, j, k; for (i = tid; i < outer_size; i += thread_num) { float *output_ptr = output + i * depth * inner_size; - for (k = 0; k < inner_size; k++) { - int index = indices[i * inner_size + k]; - if (index >= depth) { - return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; - } - for (j = 0; j < depth; j++) { + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { *output_ptr = off_value; - if (index == j) { + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { *output_ptr = on_value; } output_ptr++; diff --git a/mindspore/lite/nnacl/fp32/one_hot_fp32.h b/mindspore/lite/nnacl/fp32/one_hot_fp32.h index caf104a486..7b2039bb08 100644 --- a/mindspore/lite/nnacl/fp32/one_hot_fp32.h +++ b/mindspore/lite/nnacl/fp32/one_hot_fp32.h @@ -32,6 +32,7 @@ typedef struct OneHotParameter { float off_value_; int outer_size_; int inner_size_; + bool support_neg_index_; // if true, support neg index in indices tensor; if false, set off_value on neg index. } OneHotParameter; #ifdef __cplusplus diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot_fp32.cc index 1eeb8afdfa..490bdff93d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot_fp32.cc @@ -134,6 +134,8 @@ int OneHotCPUKernel::GetParams() { one_hot_param->depth_ = *depth; if (in_tensors_.size() == kInputNum) { + // 4 inputs: indices, depth, on_value, off_value + one_hot_param->support_neg_index_ = false; auto on_value_tensor = in_tensors_.at(2); if (on_value_tensor == nullptr) { MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; @@ -156,12 +158,14 @@ int OneHotCPUKernel::GetParams() { } one_hot_param->off_value_ = *off_value; } else { + // 3 inputs: indices, depth, off_on_value + one_hot_param->support_neg_index_ = true; auto off_on_tensor = in_tensors_.at(2); if (off_on_tensor == nullptr) { MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; return RET_NULL_PTR; } - const int64_t *off_on_values = static_cast(off_on_tensor->MutableData()); + const float *off_on_values = static_cast(off_on_tensor->MutableData()); // need to support int type if (off_on_values == nullptr) { MS_LOG(ERROR) << "OneHot input[2] data is nullptr"; return RET_NULL_PTR; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc new file mode 100644 index 0000000000..d7c7a240fa --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/tensor.h" +#include "common/common_test.h" +#include "nnacl/fp32/one_hot_fp32.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "schema/ops_generated.h" + +namespace mindspore { + +class TestOneHotFp32 : public mindspore::CommonTest { + public: + TestOneHotFp32() = default; + void Prepare(const std::vector &indices_shape, int *indices_data, int *depth, float *off_on_value, + const int axis, const std::vector &output_shape, float *output_data, const int thread_num); + + void TearDown() override; + + public: + float err_tol = 1e-5; + lite::Tensor indices_tensor_; + lite::Tensor depth_tensor_; + lite::Tensor off_on_tensor_; + lite::Tensor out_tensor_; + OneHotParameter *param_; + std::vector inputs_{&indices_tensor_, &depth_tensor_, &off_on_tensor_}; + std::vector outputs_{&out_tensor_}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt32, schema::PrimitiveType_OneHot}; + lite::InnerContext ctx_ = lite::InnerContext(); + kernel::KernelCreator creator_ = nullptr; + kernel::LiteKernel *kernel_ = nullptr; +}; + +void TestOneHotFp32::TearDown() { + indices_tensor_.set_data(nullptr); + depth_tensor_.set_data(nullptr); + off_on_tensor_.set_data(nullptr); + out_tensor_.set_data(nullptr); + delete (kernel_); +} + +void TestOneHotFp32::Prepare(const std::vector &indices_shape, int *indices_data, int *depth, float *off_on_value, + const int axis, const std::vector &output_shape, float *output_data, + const int thread_num) { + indices_tensor_.set_data_type(kNumberTypeInt32); + indices_tensor_.set_shape(indices_shape); + indices_tensor_.set_data(indices_data); + + depth_tensor_.set_data(depth); + off_on_tensor_.set_data_type(kNumberTypeFloat32); + off_on_tensor_.set_data(off_on_value); + + out_tensor_.set_shape(output_shape); + out_tensor_.set_data(output_data); + + param_ = reinterpret_cast(malloc(sizeof(OneHotParameter))); + param_->axis_ = axis; + ctx_ = lite::InnerContext(); + ctx_.thread_num_ = thread_num; + ctx_.Init(); + creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc); + kernel_ = creator_(inputs_, outputs_, reinterpret_cast(param_), &ctx_, desc, nullptr); +} + +// 3 3 axis -1 -> 3 3 4 +TEST_F(TestOneHotFp32, Test1) { + std::vector indices_shape{3, 3}; + int indices[9] = {0, 0, 1, 0, 0, 2, 0, 1, 2}; + int depth[1] = {4}; + float off_on[2] = {0, 1}; + std::vector output_shape{3, 3, 4}; + float out_data[36] = {0}; + + Prepare(indices_shape, indices, depth, off_on, -1, output_shape, out_data, 2); + + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + std::vector expect{1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f}; + ASSERT_EQ(0, CompareOutputData(out_data, expect.data(), 36, err_tol)); +} + +// 3 3 axis 1 -> 3 4 3 +TEST_F(TestOneHotFp32, Test2) { + std::vector indices_shape{3, 3}; + int indices[9] = {0, 0, 1, 0, 0, 2, 0, 1, 2}; + int depth[1] = {4}; + float off_on[2] = {0, 1}; + std::vector output_shape{3, 4, 3}; + float out_data[36] = {0}; + + Prepare(indices_shape, indices, depth, off_on, 1, output_shape, out_data, 2); + + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + std::vector expect{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}; + ASSERT_EQ(0, CompareOutputData(out_data, expect.data(), 36, err_tol)); +} + +// 3 3 axis 0 -> 4 3 3 +TEST_F(TestOneHotFp32, Test3) { + std::vector indices_shape{3, 3}; + int indices[9] = {0, 0, 1, 0, 0, 2, 0, 1, 2}; + int depth[1] = {4}; + float off_on[2] = {0, 1}; + std::vector output_shape{4, 3, 3}; + float out_data[36] = {0}; + + Prepare(indices_shape, indices, depth, off_on, 0, output_shape, out_data, 2); + + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + std::vector expect{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + ASSERT_EQ(0, CompareOutputData(out_data, expect.data(), 36, err_tol)); +} + +} // namespace mindspore