diff --git a/mindspore/lite/nnacl/fp32/sparse_to_dense.c b/mindspore/lite/nnacl/fp32/sparse_to_dense.c new file mode 100644 index 0000000000..c22f1c4677 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/sparse_to_dense.c @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/sparse_to_dense.h" + +void SparseToDense(int **sparse_indices, int *output_shape, + float *sparse_values, float default_value, float *output, + bool isScalar, int index_start, int index_end, int out_width) { + for (int i = index_start; i < index_end; i++) { + for (int j = 0; j < out_width; j++) { + output[i * out_width + j] = default_value; + } + } + + int d1 = output_shape[1] * output_shape[2] * output_shape[3]; + int d2 = output_shape[2] * output_shape[3]; + int d3 = output_shape[3]; + + int index; + if (isScalar == true) { + for (int i = index_start; i < index_end; i++) { + index = d1 * sparse_indices[i][0] + d2 * sparse_indices[i][1] + + d3 * sparse_indices[i][2] + sparse_indices[i][3]; + output[index] = sparse_values[0]; + } + } else { + for (int i = index_start; i < index_end; i++) { + index = d1 * sparse_indices[i][0] + d2 * sparse_indices[i][1] + + d3 * sparse_indices[i][2] + sparse_indices[i][3]; + output[index] = sparse_values[i]; + } + } +} diff --git a/mindspore/lite/nnacl/fp32/sparse_to_dense.h b/mindspore/lite/nnacl/fp32/sparse_to_dense.h new file mode 100644 index 0000000000..15c2345a76 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/sparse_to_dense.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_FP32_SPARSETODENSE_H_ +#define MINDSPORE_LITE_NNACL_FP32_SPARSETODENSE_H_ + +#include "nnacl/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SparseToDense(int **sparse_indices_vect, int *output_shape, + float *sparse_values, float default_value, float *output, + bool isScalar, int index_start, int index_end, int out_width); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP32_SPARSETODENSE_H_ diff --git a/mindspore/lite/nnacl/sparse_to_dense.c b/mindspore/lite/nnacl/sparse_to_dense.c deleted file mode 100644 index 95d5c4d239..0000000000 --- a/mindspore/lite/nnacl/sparse_to_dense.c +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * -// * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "nnacl/sparse_to_dense.h" - -void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, - SparseToDenseParameter *s2d_param_, int task_id) { - int m; - for (int i = task_id; i < output_shape_[0]; i += s2d_param_->op_parameter_.thread_num_) { - for (int j = 0; j < output_shape_[1]; j++) { - m = i * output_shape_[1] + j; - output[m] = dnum[0]; - } - } - - for (int j = 0; j < sp_num; j++) { - int temp = j * 2; - int temp1 = j * 2 + 1; - int tempout1 = input[temp] * output_shape_[1] + input[temp1]; - output[tempout1] = snum[j]; - } -} diff --git a/mindspore/lite/nnacl/sparse_to_dense.h b/mindspore/lite/nnacl/sparse_to_dense_parameter.h similarity index 66% rename from mindspore/lite/nnacl/sparse_to_dense.h rename to mindspore/lite/nnacl/sparse_to_dense_parameter.h index aa2205c8e1..5dfcff139c 100644 --- a/mindspore/lite/nnacl/sparse_to_dense.h +++ b/mindspore/lite/nnacl/sparse_to_dense_parameter.h @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_NNACL_SPARSETODENSE_H_ -#define MINDSPORE_LITE_NNACL_SPARSETODENSE_H_ + +#ifndef MINDSPORE_LITE_NNACL_SPARSE_TO_DENSE_PARAMETER_H_ +#define MINDSPORE_LITE_NNACL_SPARSE_TO_DENSE_PARAMETER_H_ #include "nnacl/op_base.h" @@ -22,16 +23,6 @@ typedef struct SparseToDenseParameter { OpParameter op_parameter_; bool validate_indices_; int thread_num_; - int count_; } SparseToDenseParameter; -#ifdef __cplusplus -extern "C" { -#endif -void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, - SparseToDenseParameter *s2d_param_, int task_id); -#ifdef __cplusplus -} -#endif - -#endif // MINDSPORE_LITE_NNACL_SPARSETODENCE_H_ +#endif // MINDSPORE_LITE_NNACL_SPARSE_TO_DENSE_PARAMETER_H_ diff --git a/mindspore/lite/src/ops/sparse_to_dense.cc b/mindspore/lite/src/ops/sparse_to_dense.cc index 130b148aea..04a745843d 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.cc +++ b/mindspore/lite/src/ops/sparse_to_dense.cc @@ -42,5 +42,34 @@ int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb return RET_OK; } #endif + +int SparseToDense::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive_ != nullptr); + MS_ASSERT(output_shape != nullptr); + auto output = outputs_.front(); + if (output == nullptr) { + MS_LOG(ERROR) << "output null pointer dereferencing."; + return RET_ERROR; + } + auto input2 = inputs_.at(2); + outputs_[0]->set_data_type(input2->data_type()); + outputs_[0]->SetFormat(input2->GetFormat()); + + if (!GetInferFlag()) { + return RET_OK; + } + if (this->primitive_ == nullptr) { + return RET_NULL_PTR; + } + + auto input1 = inputs_.at(1); + int *input1_data = reinterpret_cast(input1->MutableData()); + std::vector output_shape; + for (int i = 0; i < input1->ElementsNum(); i++) { + output_shape.push_back(input1_data[i]); + } + outputs_[0]->set_shape(output_shape); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/sparse_to_dense.h b/mindspore/lite/src/ops/sparse_to_dense.h index d98a843975..9e08335085 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.h +++ b/mindspore/lite/src/ops/sparse_to_dense.h @@ -45,6 +45,7 @@ class SparseToDense : public PrimitiveC { std::vector GetSparseValue() const; std::vector GetDefaultValue() const; bool GetValidateIndices() const; + int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 57efab75b3..2e8b0b516a 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -170,7 +170,7 @@ #include "nnacl/fp32/embedding_lookup.h" #include "nnacl/fp32/elu.h" #include "nnacl/leaky_relu_parameter.h" -#include "nnacl/sparse_to_dense.h" +#include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" #include "nnacl/l2_norm_parameter.h" #include "nnacl/detection_post_process_parameter.h" #include "nnacl/fp32/exp.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc index 1384d18ef2..fdcecccbd3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc @@ -14,12 +14,14 @@ * limitations under the License. */ #include "src/runtime/kernel/arm/fp32/sparse_to_dense.h" + #include + +#include "include/errorcode.h" +#include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" #include "schema/model_generated.h" #include "schema/ops_generated.h" -#include "nnacl/sparse_to_dense.h" #include "src/kernel_registry.h" -#include "include/errorcode.h" #include "src/runtime/runtime_api.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -30,12 +32,45 @@ using mindspore::schema::PrimitiveType_SparseToDense; namespace mindspore::kernel { int SparseToDenseCPUKernel::Init() { - s2d_param_->op_parameter_.thread_num_ = thread_count_; + auto input2 = in_tensors_.at(2); + auto input3 = in_tensors_.at(3); + sparse_values = reinterpret_cast(input2->MutableData()); + default_value = reinterpret_cast(input3->MutableData())[0]; + if (input2->ElementsNum() == 1) { + isScalar = true; + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int SparseToDenseCPUKernel::ReSize() { + auto output0 = out_tensors_.at(0); + std::vector out_shape_tensor = output0->shape(); + auto output_shape_tmp = reinterpret_cast(out_shape_tensor.data()); + int output_dim = output0->shape().size(); + for (int i = 0; i < DIMENSION_4D - output_dim; i++) { + output_shape[i] = 1; + } + for (int i = 0; i < output_dim; i++) { + output_shape[i + DIMENSION_4D - output_dim] = output_shape_tmp[i]; + } + output_num = output0->ElementsNum(); return RET_OK; } int SparseToDenseCPUKernel::DoExcute(int task_id) { - SparseToDense(input_data_, output_shape_, snum_, dnum_, sp_num_, output_data, s2d_param_, task_id); + int real_dst_count = MSMIN(index_num - task_id * count_unit_, count_unit_); + if (real_dst_count <= 0) { + return RET_OK; + } + int index_start = task_id * count_unit_; + int index_end = index_start + real_dst_count; + int out_width = output_num / index_num; + SparseToDense(sparse_indices_vect, output_shape, sparse_values, + default_value, output_data, isScalar, + index_start, index_end, out_width); return RET_OK; } @@ -43,38 +78,117 @@ int SparseToDenseRun(void *cdata, int task_id) { auto s2ddata = reinterpret_cast(cdata); auto ret = s2ddata->DoExcute(task_id); if (ret != RET_OK) { - MS_LOG(ERROR) << "SparseToDenseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + MS_LOG(ERROR) << "SparseToDenseRun error task_id[" << task_id + << "] error_code[" << ret << "]"; return RET_ERROR; } return RET_OK; } + +int SparseToDenseCPUKernel::GenerateIndices() { + auto input0 = in_tensors_.at(0); + index_dim = input0->shape().size(); + index_num = input0->shape()[0]; + int *sparse_indices = reinterpret_cast(input0->MutableData()); + sparse_indices_vect = reinterpret_cast(ctx_->allocator->Malloc(sizeof(int *) * index_num)); + if (sparse_indices_vect == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: sparse_indices_vect."; + return RET_ERROR; + } + switch (index_dim) { + case 0: + case 1: { + for (int i = 0; i < index_num; i++) { + sparse_indices_vect[i] = new int[DIMENSION_4D]; + if (sparse_indices_vect[i] == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: sparse_indices_vect[" << i << "]."; + return RET_ERROR; + } + for (int j = 0; j < DIMENSION_4D - 1; j++) { + sparse_indices_vect[i][j] = 0; + } + sparse_indices_vect[i][DIMENSION_4D - 1] = sparse_indices[i]; + } + break; + } + case 2: { + int true_dims = input0->shape()[1]; + MS_ASSERT(true_dims <= DIMENSION_4D); + for (int i = 0; i < index_num; i++) { + sparse_indices_vect[i] = new int[DIMENSION_4D]; + if (sparse_indices_vect[i] == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: sparse_indices_vect[" << i << "]."; + return RET_ERROR; + } + for (int j = 0; j < DIMENSION_4D - true_dims; j++) { + sparse_indices_vect[i][j] = 0; + } + for (int j = 0; j < true_dims; j++) { + sparse_indices_vect[i][j + DIMENSION_4D - true_dims] = sparse_indices[i * true_dims + j]; + } + } + break; + } + default: { + MS_LOG(ERROR) << "Indices dimensions is " << index_dim << ", which must be 0, 1 or 2"; + return RET_ERROR; + } + } + return RET_OK; +} + +int SparseToDenseCPUKernel::IndicesValidCheck() { + int d1 = output_shape[1] * output_shape[2] * output_shape[3]; + int d2 = output_shape[2] * output_shape[3]; + int d3 = output_shape[3]; + int index_before = -1; + for (int i = 0; i < index_num; i++) { + int index = d1 * sparse_indices_vect[i][0] + d2 * sparse_indices_vect[i][1] + + d3 * sparse_indices_vect[i][2] + sparse_indices_vect[i][3]; + if (index <= index_before) { + return RET_ERROR; + } + index_before = index; + } + return RET_OK; +} + int SparseToDenseCPUKernel::Run() { auto ret = Prepare(); if (ret != RET_OK) { MS_LOG(ERROR) << "Prepare failed."; return RET_ERROR; } - auto input = in_tensors_.at(0); - auto input1 = in_tensors_.at(1); - auto input2 = in_tensors_.at(2); - auto input3 = in_tensors_.at(3); - auto output0 = out_tensors_.at(0); - - input_data_ = reinterpret_cast(input->MutableData()); - total_number_ = reinterpret_cast(input1->MutableData()); - snum_ = reinterpret_cast(input2->MutableData()); - dnum_ = reinterpret_cast(input3->MutableData()); - sp_num_ = static_cast(input->ElementsNum() / 2); - + auto ret1 = GenerateIndices(); + if (ret1 != RET_OK) { + MS_LOG(ERROR) << "Generate Indices failed."; + return RET_ERROR; + } + if (s2d_param->validate_indices_ == true) { + auto ret2 = IndicesValidCheck(); + if (ret2 != RET_OK) { + MS_LOG(ERROR) << "The sparse indices is not valid."; + return RET_ERROR; + } + } output_data = reinterpret_cast(out_tensors_.at(0)->MutableData()); - std::vector temp_shape = output0->shape(); - output_shape_ = reinterpret_cast(temp_shape.data()); - - ret = ParallelLaunch(THREAD_POOL_DEFAULT, SparseToDenseRun, this, s2d_param_->thread_num_); + count_unit_ = thread_count_ > 1 ? UP_DIV(index_num, thread_count_) : index_num; + ret = ParallelLaunch(THREAD_POOL_DEFAULT, SparseToDenseRun, this, + s2d_param->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "SparseToDenseRun error: error_code[" << ret << "]"; return RET_ERROR; } + + for (int i = 0; i < index_num; i++) { + if (sparse_indices_vect[i] != nullptr) { + delete sparse_indices_vect[i]; + } + } + if (sparse_indices_vect != nullptr) { + ctx_->allocator->Free(sparse_indices_vect); + sparse_indices_vect = nullptr; + } return RET_OK; } @@ -88,20 +202,25 @@ kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ + << ", type: " + << schema::EnumNamePrimitiveType( + static_cast( + opParameter->type_)); delete kernel; return nullptr; } return kernel; } -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, CpuSparseToDenseFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, + CpuSparseToDenseFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h index 6f429760ac..34c2d56061 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h @@ -20,7 +20,7 @@ #include "src/lite_kernel.h" #include "include/context.h" -#include "nnacl/sparse_to_dense.h" +#include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" #include "src/runtime/kernel/arm/base/layout_transform.h" using mindspore::lite::Context; @@ -32,28 +32,34 @@ class SparseToDenseCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { - s2d_param_ = (reinterpret_cast(op_parameter_)); + s2d_param = (reinterpret_cast(op_parameter_)); + s2d_param->thread_num_ = thread_count_; } ~SparseToDenseCPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; int DoExcute(int task_id); + int GenerateIndices(); + int IndicesValidCheck(); protected: const Context *ctx_; int thread_count_; - SparseToDenseParameter *s2d_param_; + SparseToDenseParameter *s2d_param; private: - int *input_data_; - int *total_number_; - int sp_num_; - float *snum_; - float *dnum_; - float *output_data; - int *output_shape_; + int **sparse_indices_vect = nullptr; + float *sparse_values = nullptr; + float default_value; + bool isScalar = false; + int index_num; + int index_dim; + float *output_data = nullptr; + int output_shape[4]; + int output_num; + int64_t count_unit_; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc new file mode 100644 index 0000000000..4104b296d7 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc @@ -0,0 +1,452 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/tensor.h" + +namespace mindspore { + +class TestSparseToDenseFp32 : public mindspore::CommonTest { + public: + TestSparseToDenseFp32() {} +}; + +TEST_F(TestSparseToDenseFp32, SparseToDense_test1) { + std::vector input1 = {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}; + std::vector shape1 = {6, 2}; + std::vector input2 = {6, 10}; + std::vector shape2 = {2}; + std::vector input3 = {1}; + std::vector shape3 = {1}; + std::vector input4 = {0}; + std::vector shape4 = {1}; + + TypeId tid = kNumberTypeFloat32; + lite::Tensor *input_tensor1 = new lite::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->set_data_type(tid); + + lite::Tensor *input_tensor2 = new lite::Tensor; + input_tensor2->SetData(input2.data()); + input_tensor2->set_shape(shape2); + input_tensor2->set_data_type(tid); + + lite::Tensor *input_tensor3 = new lite::Tensor; + input_tensor3->SetData(input3.data()); + input_tensor3->set_shape(shape3); + input_tensor3->set_data_type(tid); + + lite::Tensor *input_tensor4 = new lite::Tensor; + input_tensor4->SetData(input4.data()); + input_tensor4->set_shape(shape4); + input_tensor4->set_data_type(tid); + + std::vector inputs_tensor(4); + inputs_tensor[0] = input_tensor1; + inputs_tensor[1] = input_tensor2; + inputs_tensor[2] = input_tensor3; + inputs_tensor[3] = input_tensor4; + + const int output_size = 60; + float output[60]; + std::vector output_shape = {6, 10}; + + lite::Tensor *output0_tensor = new lite::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->set_data_type(tid); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + SparseToDenseParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 3; + op_param.validate_indices_ = false; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + input_tensor2->SetData(nullptr); + input_tensor3->SetData(nullptr); + input_tensor4->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete input_tensor2; + delete input_tensor3; + delete input_tensor4; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestSparseToDenseFp32, SparseToDense_test2) { + std::vector input1 = {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}; + std::vector shape1 = {6, 2}; + std::vector input2 = {6, 10}; + std::vector shape2 = {2}; + std::vector input3 = {1, 2, 3, 4, 5, 6}; + std::vector shape3 = {6}; + std::vector input4 = {0}; + std::vector shape4 = {1}; + + TypeId tid = kNumberTypeFloat32; + lite::Tensor *input_tensor1 = new lite::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->set_data_type(tid); + + lite::Tensor *input_tensor2 = new lite::Tensor; + input_tensor2->SetData(input2.data()); + input_tensor2->set_shape(shape2); + input_tensor2->set_data_type(tid); + + lite::Tensor *input_tensor3 = new lite::Tensor; + input_tensor3->SetData(input3.data()); + input_tensor3->set_shape(shape3); + input_tensor3->set_data_type(tid); + + lite::Tensor *input_tensor4 = new lite::Tensor; + input_tensor4->SetData(input4.data()); + input_tensor4->set_shape(shape4); + input_tensor4->set_data_type(tid); + + std::vector inputs_tensor(4); + inputs_tensor[0] = input_tensor1; + inputs_tensor[1] = input_tensor2; + inputs_tensor[2] = input_tensor3; + inputs_tensor[3] = input_tensor4; + + const int output_size = 60; + float output[60]; + std::vector output_shape = {6, 10}; + + lite::Tensor *output0_tensor = new lite::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->set_data_type(tid); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + SparseToDenseParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.validate_indices_ = false; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 6}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + input_tensor2->SetData(nullptr); + input_tensor3->SetData(nullptr); + input_tensor4->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete input_tensor2; + delete input_tensor3; + delete input_tensor4; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestSparseToDenseFp32, SparseToDense_test3) { + std::vector input1 = {1, 3, 4}; + std::vector shape1 = {3}; + std::vector input2 = {1, 10}; + std::vector shape2 = {2}; + std::vector input3 = {1}; + std::vector shape3 = {1}; + std::vector input4 = {0}; + std::vector shape4 = {1}; + + TypeId tid = kNumberTypeFloat32; + lite::Tensor *input_tensor1 = new lite::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->set_data_type(tid); + + lite::Tensor *input_tensor2 = new lite::Tensor; + input_tensor2->SetData(input2.data()); + input_tensor2->set_shape(shape2); + input_tensor2->set_data_type(tid); + + lite::Tensor *input_tensor3 = new lite::Tensor; + input_tensor3->SetData(input3.data()); + input_tensor3->set_shape(shape3); + input_tensor3->set_data_type(tid); + + lite::Tensor *input_tensor4 = new lite::Tensor; + input_tensor4->SetData(input4.data()); + input_tensor4->set_shape(shape4); + input_tensor4->set_data_type(tid); + + std::vector inputs_tensor(4); + inputs_tensor[0] = input_tensor1; + inputs_tensor[1] = input_tensor2; + inputs_tensor[2] = input_tensor3; + inputs_tensor[3] = input_tensor4; + + const int output_size = 10; + float output[10]; + std::vector output_shape = {1, 10}; + + lite::Tensor *output0_tensor = new lite::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->set_data_type(tid); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + SparseToDenseParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.validate_indices_ = true; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {0, 1, 0, 1, 1, 0, 0, 0, 0, 0}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + input_tensor2->SetData(nullptr); + input_tensor3->SetData(nullptr); + input_tensor4->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete input_tensor2; + delete input_tensor3; + delete input_tensor4; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestSparseToDenseFp32, SparseToDense_test4) { + std::vector input1 = {5}; + std::vector shape1 = {1}; + std::vector input2 = {10}; + std::vector shape2 = {1}; + std::vector input3 = {1}; + std::vector shape3 = {1}; + std::vector input4 = {0}; + std::vector shape4 = {1}; + + TypeId tid = kNumberTypeFloat32; + lite::Tensor *input_tensor1 = new lite::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->set_data_type(tid); + + lite::Tensor *input_tensor2 = new lite::Tensor; + input_tensor2->SetData(input2.data()); + input_tensor2->set_shape(shape2); + input_tensor2->set_data_type(tid); + + lite::Tensor *input_tensor3 = new lite::Tensor; + input_tensor3->SetData(input3.data()); + input_tensor3->set_shape(shape3); + input_tensor3->set_data_type(tid); + + lite::Tensor *input_tensor4 = new lite::Tensor; + input_tensor4->SetData(input4.data()); + input_tensor4->set_shape(shape4); + input_tensor4->set_data_type(tid); + + std::vector inputs_tensor(4); + inputs_tensor[0] = input_tensor1; + inputs_tensor[1] = input_tensor2; + inputs_tensor[2] = input_tensor3; + inputs_tensor[3] = input_tensor4; + + const int output_size = 10; + float output[10]; + std::vector output_shape = {1, 10}; + + lite::Tensor *output0_tensor = new lite::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->set_data_type(tid); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + SparseToDenseParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.validate_indices_ = true; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {0, 0, 0, 0, 0, 1, 0, 0, 0, 0}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + input_tensor2->SetData(nullptr); + input_tensor3->SetData(nullptr); + input_tensor4->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete input_tensor2; + delete input_tensor3; + delete input_tensor4; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestSparseToDenseFp32, SparseToDense_test5) { + std::vector input1 = {0, 0, 1, 2, 2, 3, 2, 3, 4, 7, 5, 9}; + std::vector shape1 = {6, 2}; + std::vector input2 = {6, 10}; + std::vector shape2 = {2}; + std::vector input3 = {1, 2, 3, 4, 5, 6}; + std::vector shape3 = {6}; + std::vector input4 = {0}; + std::vector shape4 = {1}; + + TypeId tid = kNumberTypeFloat32; + lite::Tensor *input_tensor1 = new lite::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->set_data_type(tid); + + lite::Tensor *input_tensor2 = new lite::Tensor; + input_tensor2->SetData(input2.data()); + input_tensor2->set_shape(shape2); + input_tensor2->set_data_type(tid); + + lite::Tensor *input_tensor3 = new lite::Tensor; + input_tensor3->SetData(input3.data()); + input_tensor3->set_shape(shape3); + input_tensor3->set_data_type(tid); + + lite::Tensor *input_tensor4 = new lite::Tensor; + input_tensor4->SetData(input4.data()); + input_tensor4->set_shape(shape4); + input_tensor4->set_data_type(tid); + + std::vector inputs_tensor(4); + inputs_tensor[0] = input_tensor1; + inputs_tensor[1] = input_tensor2; + inputs_tensor[2] = input_tensor3; + inputs_tensor[3] = input_tensor4; + + const int output_size = 60; + float output[60]; + std::vector output_shape = {6, 10}; + + lite::Tensor *output0_tensor = new lite::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->set_data_type(tid); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + SparseToDenseParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.validate_indices_ = true; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 6}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + input_tensor2->SetData(nullptr); + input_tensor3->SetData(nullptr); + input_tensor4->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete input_tensor2; + delete input_tensor3; + delete input_tensor4; + delete output0_tensor; + delete ctx; +} +} // namespace mindspore