From e788f00976fe91d2344df6ef0bb944a4442016b4 Mon Sep 17 00:00:00 2001 From: ling Date: Mon, 28 Dec 2020 20:44:26 +0800 Subject: [PATCH] constant of shape support int64 --- mindspore/lite/nnacl/constant_of_shape.c | 31 ++++++++++ ...nt_of_shape_fp32.h => constant_of_shape.h} | 12 ++-- .../lite/nnacl/fp32/constant_of_shape_fp32.c | 39 ------------ mindspore/lite/src/ops/constant_of_shape.cc | 35 ++++++++--- .../populate/constant_of_shape_populate.cc | 15 ++++- .../constant_of_shape.cc} | 62 ++++++++----------- .../constant_of_shape.h} | 18 +++--- .../arm/fp32/constant_of_shape_fp32_test.cc | 5 +- 8 files changed, 114 insertions(+), 103 deletions(-) create mode 100644 mindspore/lite/nnacl/constant_of_shape.c rename mindspore/lite/nnacl/{fp32/constant_of_shape_fp32.h => constant_of_shape.h} (80%) delete mode 100644 mindspore/lite/nnacl/fp32/constant_of_shape_fp32.c rename mindspore/lite/src/runtime/kernel/arm/{fp32/constant_of_shape_fp32.cc => base/constant_of_shape.cc} (60%) rename mindspore/lite/src/runtime/kernel/arm/{fp32/constant_of_shape_fp32.h => base/constant_of_shape.h} (76%) diff --git a/mindspore/lite/nnacl/constant_of_shape.c b/mindspore/lite/nnacl/constant_of_shape.c new file mode 100644 index 0000000000..ce37a11484 --- /dev/null +++ b/mindspore/lite/nnacl/constant_of_shape.c @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/constant_of_shape.h" + +int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +int ConstantOfShapeFp32(float *output, int start, int end, float value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32/constant_of_shape_fp32.h b/mindspore/lite/nnacl/constant_of_shape.h similarity index 80% rename from mindspore/lite/nnacl/fp32/constant_of_shape_fp32.h rename to mindspore/lite/nnacl/constant_of_shape.h index 353172aeeb..12bf757c9d 100644 --- a/mindspore/lite/nnacl/fp32/constant_of_shape_fp32.h +++ b/mindspore/lite/nnacl/constant_of_shape.h @@ -24,17 +24,19 @@ typedef struct ConstantOfShapeParameter { OpParameter op_parameter_; - float value_; + union value_ { + float f32_value_; + int32_t int32_value_; + } value_; int data_type_; - int unit_; - int element_sz_; + int element_size_; } ConstantOfShapeParameter; #ifdef __cplusplus extern "C" { #endif -int ConstantOfShape(float *output, int tid, const ConstantOfShapeParameter *param); -int ConstantOfShapeInt(int32_t *output, int tid, const ConstantOfShapeParameter *param); +int ConstantOfShapeFp32(float *output, int start, int end, float value); +int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/constant_of_shape_fp32.c b/mindspore/lite/nnacl/fp32/constant_of_shape_fp32.c deleted file mode 100644 index 917912809d..0000000000 --- a/mindspore/lite/nnacl/fp32/constant_of_shape_fp32.c +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "nnacl/fp32/constant_of_shape_fp32.h" - -int ConstantOfShape(float *output, int tid, const ConstantOfShapeParameter *param) { - int size = param->unit_; - float data = param->value_; - int ind_st = MSMIN(tid * size, param->element_sz_); - int ind_end = MSMIN(param->element_sz_, (tid + 1) * size); - for (int i = ind_st; i < ind_end; ++i) { - output[i] = data; - } - return NNACL_OK; -} - -int ConstantOfShapeInt(int32_t *output, int tid, const ConstantOfShapeParameter *param) { - int size = param->unit_; - float data = param->value_; - int ind_st = MSMIN(tid * size, param->element_sz_); - int ind_end = MSMIN(param->element_sz_, (tid + 1) * size); - for (int i = ind_st; i < ind_end; ++i) { - output[i] = data; - } - return NNACL_OK; -} diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index 0e73effed6..5e5a78bce7 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ b/mindspore/lite/src/ops/constant_of_shape.cc @@ -78,25 +78,42 @@ int ConstantOfShape::InferShape(std::vector inputs_, std::vectorset_data_type(static_cast(GetDataType())); out_tensor->set_format(in_tensor->format()); - if (!infer_flag()) { - return RET_INFER_INVALID; - } - auto in_data = reinterpret_cast(in_tensor->data_c()); - if (in_data == nullptr) { - MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet."; + + if (!infer_flag() || in_tensor->data_c() == nullptr) { return RET_INFER_INVALID; } + int size = in_tensor->ElementsNum(); std::vector out_shape(size); - for (int i = 0; i < size; ++i) { - out_shape[i] = in_data[i]; + + switch (in_tensor->data_type()) { + case kNumberTypeInt32: { + int32_t *in_data = reinterpret_cast(in_tensor->data_c()); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + MS_ASSERT(out_shape[i] > 0); + } + break; + } + case kNumberTypeInt64: { + int64_t *in_data = reinterpret_cast(in_tensor->data_c()); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + MS_ASSERT(out_shape[i] > 0); + } + break; + } + default: + MS_LOG(INFO) << "Invalid input data type!"; + return RET_INFER_INVALID; } - out_tensor->set_shape(out_shape); + out_tensor->set_shape(out_shape); return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc index 2ead683888..4a04ef0812 100644 --- a/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc +++ b/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc @@ -19,7 +19,7 @@ #include "src/tensor.h" #include "src/ops/primitive_c.h" #include "src/ops/populate/populate_register.h" -#include "nnacl/fp32/constant_of_shape_fp32.h" +#include "nnacl/constant_of_shape.h" namespace mindspore::lite { namespace { @@ -34,13 +34,22 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC } memset(param, 0, sizeof(ConstantOfShapeParameter)); param->op_parameter_.type_ = primitive->Type(); + param->data_type_ = attr->GetDataType(); auto value = attr->GetValue(); if (value.empty() || value.size() > 1) { MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; } else { - param->value_ = attr->GetValue().at(0); + switch (param->data_type_) { + case kNumberTypeFloat32: + param->value_.f32_value_ = attr->GetValue().at(0); + break; + case kNumberTypeInt32: + param->value_.int32_value_ = attr->GetValue().at(0); + break; + default: + MS_LOG(ERROR) << "The value of constant of shape is invalid"; + } } - param->data_type_ = attr->GetDataType(); return reinterpret_cast(param); } Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc similarity index 60% rename from mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.cc rename to mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc index 6231736428..7ed1d51a9d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc @@ -14,11 +14,9 @@ * limitations under the License. */ -#include "src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h" -#include +#include "src/runtime/kernel/arm/base/constant_of_shape.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "include/errorcode.h" #include "src/runtime/runtime_api.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -28,30 +26,6 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_ConstantOfShape; namespace mindspore::kernel { -int ConstantOfShapeCPUKernel::Init() { return RET_OK; } - -int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; } - -int ConstantOfShapeCPUKernel::DoExecute(int task_id) { - int ret = RET_ERROR; - switch (param_->data_type_) { - case kNumberTypeFloat32: - ret = ConstantOfShape(reinterpret_cast(out_ptr_), task_id, param_); - break; - case kNumberTypeInt32: - ret = ConstantOfShapeInt(reinterpret_cast(out_ptr_), task_id, param_); - break; - default: - MS_LOG(ERROR) << "Constant of shape does not support the output data type."; - return RET_ERROR; - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]"; - return ret; - } - return RET_OK; -} - int ConstantOfShapeRun(void *cdata, int task_id) { auto g_kernel = reinterpret_cast(cdata); auto ret = g_kernel->DoExecute(task_id); @@ -62,23 +36,38 @@ int ConstantOfShapeRun(void *cdata, int task_id) { return RET_OK; } -int ConstantOfShapeCPUKernel::Run() { - param_->element_sz_ = out_tensors_.front()->ElementsNum(); - int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_); - param_->unit_ = UP_DIV(param_->element_sz_, thread_num); - param_->op_parameter_.thread_num_ = thread_num; +int ConstantOfShapeCPUKernel::DoExecute(int task_id) { + int start = task_id * thread_stride_; + int current_stride = MSMIN(thread_stride_, param_->element_size_ - start); + if (current_stride < 0) { + return RET_OK; + } + switch (param_->data_type_) { case kNumberTypeFloat32: - out_ptr_ = reinterpret_cast(out_tensors_.front()->MutableData()); + ConstantOfShapeFp32(reinterpret_cast(output_ptr_), start, start + current_stride, + param_->value_.f32_value_); break; case kNumberTypeInt32: - out_ptr_ = reinterpret_cast(out_tensors_.front()->MutableData()); + ConstantOfShapeInt32(reinterpret_cast(output_ptr_), start, start + current_stride, + param_->value_.int32_value_); break; default: - MS_LOG(ERROR) << "Constant of shape does not support the output data type."; + MS_LOG(ERROR) << "Invalid datatype in ConstantOfShapeRun"; return RET_ERROR; } - auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_num); + return RET_OK; +} + +int ConstantOfShapeCPUKernel::Run() { + auto output = out_tensors_.front(); + param_->data_type_ = output->data_type(); + param_->element_size_ = output->ElementsNum(); + output_ptr_ = output->data_c(); + int thread_count = MSMIN(op_parameter_->thread_num_, param_->element_size_); + thread_stride_ = UP_DIV(param_->element_size_, thread_count); + + auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_count); if (ret != RET_OK) { MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]"; return ret; @@ -88,4 +77,5 @@ int ConstantOfShapeCPUKernel::Run() { REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h b/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h similarity index 76% rename from mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h rename to mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h index 682ae1b9d7..e03c9f762f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h @@ -13,15 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_ #include +#include "include/errorcode.h" #include "src/lite_kernel.h" #include "include/context.h" -#include "nnacl/fp32/constant_of_shape_fp32.h" - -using mindspore::lite::InnerContext; +#include "nnacl/constant_of_shape.h" namespace mindspore::kernel { class ConstantOfShapeCPUKernel : public LiteKernel { @@ -34,15 +33,16 @@ class ConstantOfShapeCPUKernel : public LiteKernel { } ~ConstantOfShapeCPUKernel() override = default; - int Init() override; - int ReSize() override; + int Init() override { return lite::RET_OK; } + int ReSize() override { return lite::RET_OK; } int Run() override; int DoExecute(int task_id); private: ConstantOfShapeParameter *param_ = nullptr; - void *out_ptr_ = nullptr; + void *output_ptr_ = nullptr; + int thread_stride_; }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc index 407b97895e..f84a7e5ede 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc @@ -15,7 +15,7 @@ */ #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h" +#include "mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h" #include "src/kernel_registry.h" #include "src/lite_kernel.h" @@ -47,7 +47,8 @@ TEST_F(TestConstantOfShapeFp32, Simple) { std::vector inputs_; std::vector outputs_; auto param = new ConstantOfShapeParameter(); - param->value_ = 1; + param->value_.f32_value_ = 1; + param->data_type_ = kNumberTypeFloat32; float a[] = {1, 2, 3, 4}; std::vector a_shape = {4, 1, 1, 1}; // std::vector c_shape = {2, 2, 2, 1};