diff --git a/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c index 3296b5bdd5..7e8ea2b320 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c +++ b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include "nnacl/fp16/arithmetic_self_fp16.h" @@ -108,3 +109,11 @@ int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size) { } return NNACL_OK; } + +int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + assert(input[i] != 0.0f); + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h index 21590a6b24..5979c51ad0 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h +++ b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h @@ -48,6 +48,8 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size); int ElementCeilFp16(float16_t *input, float16_t *output, int number); int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size); + +int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c index caa048bf8b..55a05e568f 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c +++ b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c @@ -16,6 +16,7 @@ #include #include +#include #include "nnacl/fp32/arithmetic_self_fp32.h" // abs: @@ -128,3 +129,11 @@ int ElementNegative(const float *input, float *output, const int element_size) { } return NNACL_OK; } + +int ElementReciprocal(const float *input, float *output, const int element_size) { + for (int i = 0; i < 
element_size; ++i) { + assert(input[i] != 0.0f); + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h index d29aaa5a16..a6d53f65c7 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h +++ b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h @@ -51,6 +51,8 @@ int ElementFloor(const float *input, float *output, const int element_size); int ElementCeil(const float *input, float *output, const int number); int ElementNegative(const float *input, float *output, const int element_size); + +int ElementReciprocal(const float *input, float *output, const int element_size); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/int8/arithmetic_self_int8.c b/mindspore/lite/nnacl/int8/arithmetic_self_int8.c index c3576843f2..16a4c198ca 100644 --- a/mindspore/lite/nnacl/int8/arithmetic_self_int8.c +++ b/mindspore/lite/nnacl/int8/arithmetic_self_int8.c @@ -15,6 +15,7 @@ */ #include +#include #include "nnacl/int8/arithmetic_self_int8.h" #ifdef ENABLE_NEON #include @@ -278,3 +279,24 @@ int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, Arith } return NNACL_OK; } + +int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + assert(input_f32 != 0.0f); + int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} diff --git 
a/mindspore/lite/nnacl/int8/arithmetic_self_int8.h b/mindspore/lite/nnacl/int8/arithmetic_self_int8.h index 3eb7b57d54..e792443d43 100644 --- a/mindspore/lite/nnacl/int8/arithmetic_self_int8.h +++ b/mindspore/lite/nnacl/int8/arithmetic_self_int8.h @@ -50,6 +50,8 @@ int Int8ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelf int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); +int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 327950f487..c1fa7f080d 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -253,7 +253,8 @@ union PrimitiveType { All, Assert, Adder, - SparseSoftmaxCrossEntropy + SparseSoftmaxCrossEntropy, + Reciprocal, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index d97a183a09..1a622e0891 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1203,3 +1203,6 @@ table All { table Assert { summarize : int; } + +table Reciprocal { +} \ No newline at end of file diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index e2637e8ca4..c24ea68e5a 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -375,11 +375,11 @@ void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output int Conv2D::InferShape(std::vector inputs_, std::vector outputs_) { if (inputs_.size() != 2 && inputs_.size() != 3) { - MS_LOG(ERROR) << "Add should has two or three inputs"; + MS_LOG(ERROR) << "Conv2d should has two or three inputs"; return RET_ERROR; } if (outputs_.size() != 1) { - MS_LOG(ERROR) << "Add should has one outputs"; + MS_LOG(ERROR) << "Conv2d should has one outputs"; return RET_ERROR; } auto *input_tensor = inputs_.front(); diff --git 
a/mindspore/lite/src/ops/populate/arithmetic_self_populate.cc b/mindspore/lite/src/ops/populate/arithmetic_self_populate.cc index 580cb6673d..7f651587c3 100644 --- a/mindspore/lite/src/ops/populate/arithmetic_self_populate.cc +++ b/mindspore/lite/src/ops/populate/arithmetic_self_populate.cc @@ -47,6 +47,7 @@ Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateA Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf); Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf); Registry RoundParameterRegistry(schema::PrimitiveType_Round, PopulateArithmeticSelf); +Registry ReciprocalParameterRegistry(schema::PrimitiveType_Reciprocal, PopulateArithmeticSelf); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/split_populate.cc b/mindspore/lite/src/ops/populate/split_populate.cc index 275f607282..880a5ef854 100644 --- a/mindspore/lite/src/ops/populate/split_populate.cc +++ b/mindspore/lite/src/ops/populate/split_populate.cc @@ -31,7 +31,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive memset(split_param, 0, sizeof(SplitParameter)); auto param = reinterpret_cast(const_cast(primitive)); split_param->op_parameter_.type_ = primitive->Type(); - split_param->num_split_ = param->GetNumberSplit(); + split_param->num_split_ = param->num_split(); if (split_param->num_split_ > std::numeric_limits::max() / static_cast(sizeof(int))) { MS_LOG(ERROR) << "The value of split_param->num_split_ is too big"; return nullptr; @@ -44,7 +44,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive } memset(split_sizes, 0, split_param->num_split_ * sizeof(int)); split_param->split_sizes_ = split_sizes; - auto split_sizes_vector_ = param->GetSizeSplits(); + auto split_sizes_vector_ = param->size_splits(); int i = 0; for (int &iter : split_sizes_vector_) { split_param->split_sizes_[i++] = iter; diff --git 
a/mindspore/lite/src/ops/populate/tile_populate.cc b/mindspore/lite/src/ops/populate/tile_populate.cc index 4fd09fe0f9..6dd170ed0e 100644 --- a/mindspore/lite/src/ops/populate/tile_populate.cc +++ b/mindspore/lite/src/ops/populate/tile_populate.cc @@ -43,8 +43,10 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) for (size_t i = 0; i < kDimension_4d; ++i) { tile_param->multiples_[i] = 1; } - for (size_t i = 0; i < dims.size(); ++i) { - tile_param->multiples_[dims.at(i)] = multiples.at(i); + if (!dims.empty() && !multiples.empty()) { + for (size_t i = 0; i < dims.size(); ++i) { + tile_param->multiples_[dims[i]] = multiples[i]; + } } #endif return reinterpret_cast(tile_param); diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 6c263899a7..c3fc2841b4 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -148,6 +148,7 @@ #include "src/ops/while.h" #include "src/ops/oneslike.h" #include "src/ops/unsorted_segment_sum.h" +#include "src/ops/reciprocal.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -888,6 +889,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Quant(primitive); case schema::PrimitiveType_OnnxInt8Dequantize: return new (std::nothrow) Dequant(primitive); + case schema::PrimitiveType_Reciprocal: + return new (std::nothrow) Reciprocal(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/src/ops/reciprocal.cc b/mindspore/lite/src/ops/reciprocal.cc new file mode 100644 index 0000000000..86966a584c --- /dev/null +++ b/mindspore/lite/src/ops/reciprocal.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/reciprocal.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ReciprocalRegistry(schema::PrimitiveType_Reciprocal, ReciprocalCreator); +#endif + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/reciprocal.h b/mindspore/lite/src/ops/reciprocal.h new file mode 100644 index 0000000000..2af5b5d230 --- /dev/null +++ b/mindspore/lite/src/ops/reciprocal.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_ + +#include "src/ops/arithmetic_self.h" + +namespace mindspore { +namespace lite { +class Reciprocal : public ArithmeticSelf { + public: + Reciprocal() = default; + ~Reciprocal() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf); + explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateReciprocal(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reciprocal, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; + } +#endif +}; + +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_ diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index 45914ae0d4..ed5502c55a 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE int Split::GetNumberSplit() const { return this->primitive_->value.AsSplit()->numberSplit; } -std::vector Split::GetSizeSplits() const { return this->primitive_->value.AsSplit()->sizeSplits; } +std::vector Split::GetSizeSplit() const { return this->primitive_->value.AsSplit()->sizeSplits; } int Split::GetSplitDim() const { return this->primitive_->value.AsSplit()->splitDim; } void Split::SetNumberSplit(int number_split) { this->primitive_->value.AsSplit()->numberSplit = number_split; } @@ -67,7 +67,7 @@ int Split::UnPackAttr(const Primitive &prim, const std::vector &inpu #else int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); } -std::vector Split::GetSizeSplits() const { +std::vector Split::GetSizeSplit() const { auto 
fb_vector = this->primitive_->value_as_Split()->sizeSplits(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -108,42 +108,50 @@ int Split::InferShape(std::vector inputs_, std::vector outpu MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum; return RET_ERROR; } - auto output = outputs_.front(); - if (output == nullptr) { - MS_LOG(ERROR) << "output null pointer dereferencing."; + if (outputs_.empty()) { + MS_LOG(ERROR) << "split has no output."; return RET_ERROR; } - int number_split = GetNumberSplit(); - if (static_cast(outputs_.size()) != number_split) { - MS_LOG(ERROR) << "outputs number is not equal to " << number_split; - return RET_ERROR; - } - for (int i = 0; i < number_split; ++i) { - outputs_.at(i)->set_data_type(input->data_type()); - outputs_.at(i)->set_format(input->format()); + for (auto &output : outputs_) { + output->set_data_type(input->data_type()); + output->set_format(input->format()); } + size_splits_ = GetSizeSplit(); + num_split_ = GetNumberSplit() == 0 ? static_cast(outputs_.size()) : GetNumberSplit(); if (!infer_flag()) { return RET_INFER_INVALID; } - size_t split_dim = GetSplitDim() == -1 ? input->shape().size() - 1 : GetSplitDim(); + size_t split_dim = GetSplitDim() < 0 ? 
input->shape().size() + GetSplitDim() : GetSplitDim(); std::vector input_shape = input->shape(); - std::vector size_split; - for (size_t i = 0; i < GetSizeSplits().size(); ++i) { - size_split.push_back(GetSizeSplits().at(i)); + if (split_dim >= input_shape.size()) { /* valid axis indices are 0 .. size-1; split_dim is used below as input_shape[split_dim] */ + MS_LOG(ERROR) << "split dim is out of range, which is " << input_shape.size(); + return RET_INPUT_PARAM_INVALID; + } + if (static_cast(outputs_.size()) != num_split_) { + MS_LOG(ERROR) << "outputs number is not equal to " << num_split_; + return RET_ERROR; + } + if (size_splits_.empty()) { + if (input_shape[split_dim] % num_split_ != 0) { + MS_LOG(ERROR) << "cannot split to equal size, which dim is " << input_shape[split_dim] << ", num split is " + << num_split_; + return RET_INPUT_PARAM_INVALID; + } + for (int i = 0; i < num_split_; ++i) { + size_splits_.push_back(input_shape[split_dim] / num_split_); + } } - for (int i = 0; i < number_split; ++i) { + for (int i = 0; i < num_split_; ++i) { std::vector output_shape; output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); int split_dim_i = input_shape.at(split_dim); // support split size is -1 in the end.
- if (size_split.empty()) { - split_dim_i = input_shape.at(split_dim) / number_split; - } else if (i == number_split - 1 && size_split.at(i) == -1) { - for (size_t j = 0; j < size_split.size() - 1; ++j) { - split_dim_i -= size_split.at(j); + if (i == num_split_ - 1 && size_splits_[i] == -1) { + for (size_t j = 0; j < size_splits_.size() - 1; ++j) { + split_dim_i -= size_splits_[j]; } } else { - split_dim_i = size_split.at(i); + split_dim_i = size_splits_[i]; } output_shape.at(split_dim) = split_dim_i; outputs_.at(i)->set_shape(output_shape); diff --git a/mindspore/lite/src/ops/split.h b/mindspore/lite/src/ops/split.h index 959bb9a86a..bbdf7515d3 100644 --- a/mindspore/lite/src/ops/split.h +++ b/mindspore/lite/src/ops/split.h @@ -42,8 +42,14 @@ class Split : public PrimitiveC { #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNumberSplit() const; - std::vector GetSizeSplits() const; + std::vector GetSizeSplit() const; int GetSplitDim() const; + int num_split() const { return num_split_; } + std::vector size_splits() const { return size_splits_; } + + protected: + int num_split_ = 0; + std::vector size_splits_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 3eb3f064e2..319f0489be 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -139,8 +139,22 @@ int Tile::InferShape(std::vector inputs_, std::vector output } std::vector out_shape; - std::vector multiples = GetMultiples(); - + std::vector multiples; + if (inputs_.size() == 2) { + if (inputs_[1]->data_c() == nullptr) { + MS_LOG(INFO) << "Do infer shape in runtime."; + return RET_INFER_INVALID; + } + int data_num = inputs_[1]->ElementsNum(); + if (data_num > static_cast(input->shape().size())) { + MS_LOG(ERROR) << "multiples data num cannot be larger than input shape size."; + return RET_INPUT_TENSOR_ERROR; + } + multiples.resize(data_num); + memcpy(multiples.data(), 
inputs_[1]->data_c(), inputs_[1]->Size()); + } else { + multiples = GetMultiples(); + } #ifdef SUPPORT_TRAIN const size_t in_dims = input->shape().size(); const size_t delta_dims = in_dims - multiples.size(); @@ -156,6 +170,11 @@ int Tile::InferShape(std::vector inputs_, std::vector output } #else std::vector dims = GetDims(); + if (inputs_.size() == 2 && dims.empty()) { + for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) { + dims.push_back(dim); + } + } const size_t in_dims = input->shape().size(); MS_ASSERT(multiples.size() == dims.size()); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc index 6e62eb4295..aed5a8d140 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc @@ -38,7 +38,7 @@ int SplitBaseCPUKernel::ReSize() { auto input_shape = in_tensor->shape(); MS_ASSERT(param); - MS_ASSERT(input_shape.size() >= 2 && input_shape.size() <= SPLIT_STRIDES_SIZE); + MS_ASSERT(input_shape.size() >= 1 && input_shape.size() <= SPLIT_STRIDES_SIZE); param->strides_[input_shape.size() - 1] = 1; for (int i = input_shape.size() - 2; i >= 0; i--) { param->strides_[i] = param->strides_[i + 1] * input_shape.at(i + 1); @@ -50,8 +50,8 @@ int SplitBaseCPUKernel::ReSize() { param->n_dims_ = input_shape.size(); if (param->split_sizes_[0] == 0) { - MS_ASSERT(param->num_split_ > 0 && static_cast(param->num_split_) < input_shape.size()); - if (input_shape.at(param->split_dim_) % param->num_split_ != 0) { + MS_ASSERT(param->num_split_ > 0 && static_cast(param->num_split_) <= input_shape[param->split_dim_]); + if (input_shape[param->split_dim_] % param->num_split_ != 0) { MS_LOG(ERROR) << "Default split size is not usable."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc index 567f863fd5..b9dcead122 100644 
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc @@ -43,7 +43,8 @@ ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int {mindspore::schema::PrimitiveType_Floor, ElementFloorFp16}, {mindspore::schema::PrimitiveType_Ceil, ElementCeilFp16}, {mindspore::schema::PrimitiveType_Round, ElementRoundFp16}, - {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16}}; + {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16}, + {mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocalFp16}}; for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) { if (type_func_table[i].primitive_type_ == primitive_type) { return type_func_table[i].func_; @@ -139,4 +140,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Floor, CpuArithmeticSelfFp16K REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Ceil, CpuArithmeticSelfFp16KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Round, CpuArithmeticSelfFp16KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Neg, CpuArithmeticSelfFp16KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reciprocal, CpuArithmeticSelfFp16KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc index 5d77e788be..2da568d754 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc @@ -41,7 +41,8 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t {mindspore::schema::PrimitiveType_Floor, ElementFloor}, {mindspore::schema::PrimitiveType_Ceil, ElementCeil}, {mindspore::schema::PrimitiveType_Round, ElementRound}, - {mindspore::schema::PrimitiveType_Neg, ElementNegative}}; + 
{mindspore::schema::PrimitiveType_Neg, ElementNegative}, + {mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocal}}; for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) { if (type_func_table[i].primitive_type_ == primitive_type) { return type_func_table[i].func_; @@ -152,4 +153,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32K REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Neg, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reciprocal, CpuArithmeticSelfFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h index 5d6a775653..e88ecec4db 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h @@ -26,6 +26,7 @@ using mindspore::schema::PrimitiveType_Floor; using mindspore::schema::PrimitiveType_Log; using mindspore::schema::PrimitiveType_LogicalNot; using mindspore::schema::PrimitiveType_Neg; +using mindspore::schema::PrimitiveType_Reciprocal; using mindspore::schema::PrimitiveType_Round; using mindspore::schema::PrimitiveType_Rsqrt; using mindspore::schema::PrimitiveType_Sin; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc index da68b6405a..16da523978 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc @@ -24,6 +24,9 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Tile; namespace mindspore::kernel { +namespace { +constexpr size_t kDoubleInputsSize = 2; +} 
int TileCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; @@ -42,6 +45,17 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) { int TileCPUKernel::ReSize() { auto tile_parameter_ = reinterpret_cast(op_parameter_); MS_ASSERT(tile_parameter_); + if (in_tensors_.size() == kDoubleInputsSize) { + if (in_tensors_[1]->ElementsNum() > static_cast(in_tensors_[0]->shape().size())) { + MS_LOG(ERROR) << "tile's input1 data_num cannot be larger than input0's shape_size."; + return mindspore::lite::RET_ERROR; /* was 'return false': converts to int 0, presumably the same value as RET_OK, hiding the failure */ + } + auto input1_addr = reinterpret_cast(in_tensors_[1]->data_c()); + for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) { + tile_parameter_->dims_[i] = i; + tile_parameter_->multiples_[i] = input1_addr[i]; + } + } tile_parameter_->in_dim_ = in_tensors_.at(0)->shape().size(); for (int i = 0; i < tile_parameter_->in_dim_; ++i) { tile_parameter_->in_shape_[i] = in_tensors_.at(0)->shape().at(i); @@ -93,4 +107,5 @@ kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector & } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Tile, CpuTileFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Tile, CpuTileFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc index bdfc349ec4..b27557181b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc @@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reciprocal,
CpuArithmeticSelfInt8KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h index 5fc64fa3f5..49e3f8274b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h @@ -31,6 +31,7 @@ using mindspore::schema::PrimitiveType_Cos; using mindspore::schema::PrimitiveType_Floor; using mindspore::schema::PrimitiveType_Log; using mindspore::schema::PrimitiveType_LogicalNot; +using mindspore::schema::PrimitiveType_Reciprocal; using mindspore::schema::PrimitiveType_Round; using mindspore::schema::PrimitiveType_Rsqrt; using mindspore::schema::PrimitiveType_Sin; @@ -80,6 +81,9 @@ class ArithmeticSelfInt8CPUKernel : public LiteKernel { case PrimitiveType_LogicalNot: arithmeticSelf_run_ = Int8ElementLogicalNot; break; + case PrimitiveType_Reciprocal: + arithmeticSelf_run_ = Int8ElementReciprocal; + break; /* explicit break, matching the LogicalNot case above, so a case added below cannot be reached by fallthrough */ default: break; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index e72ffa3626..86e02cb2c3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -690,6 +690,29 @@ STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No op->primitive->value.value = attr.release(); return RET_OK; } + +STATUS OnnxReciprocalParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx ReciprocalParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + auto
attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_Reciprocal; + op->primitive->value.value = attr.release(); + return RET_OK; +} OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); @@ -720,5 +743,6 @@ OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser()); OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser()); OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser()); OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser()); +OnnxNodeRegistrar g_onnxReciprocalParser("Reciprocal", new OnnxReciprocalParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h index b6533f5188..a6c635c088 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -217,6 +217,13 @@ class OnnxRoundParser : public OnnxNodeParser { ~OnnxRoundParser() override = default; STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; }; + +class OnnxReciprocalParser : public OnnxNodeParser { + public: + OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {} + ~OnnxReciprocalParser() override = default; + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc 
b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc index d42813c6f1..d57b6f1719 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -41,8 +41,11 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "to") { - attr->dstT = static_cast( - OnnxModelParser::GetDataTypeFromOnnx(static_cast(onnx_node_attr.i()))); + auto dst_type = OnnxModelParser::GetDataTypeFromOnnx(static_cast(onnx_node_attr.i())); + if (dst_type == kNumberTypeInt64) { + dst_type = kNumberTypeInt32; + } + attr->dstT = static_cast(dst_type); } } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index a82db65116..cd52428771 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -105,7 +105,7 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st MS_LOG(ERROR) << "new tensor failed"; return RET_ERROR; } - tensor->dataType = data_type; + tensor->dataType = data_type == kNumberTypeInt64 ? 
kNumberTypeInt32 : data_type; tensor->dims = GetDimsFromOnnxValue(proto); tensor->format = schema::Format::Format_NCHW; tensor->nodeType = schema::NodeType::NodeType_ValueNode; @@ -370,7 +370,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name; return; } - quant_param->inited = true; int argNum = 0; for (const auto &onnx_node_attr : node.attribute()) { if (onnx_node_attr.name() == "Y_scale") { @@ -382,11 +381,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const } } if (argNum != 2) { - quant_param->scale = FLT_MAX; - quant_param->zeroPoint = 0; - quant_param->min = FLT_MAX; - quant_param->max = FLT_MAX; - quant_param->inited = false; + continue; } dst_tensor->quantParams.emplace_back(std::move(quant_param)); if (argNum == 2) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 5672819be9..5fa0d40d63 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -15,7 +15,9 @@ */ #include "tools/converter/parser/onnx/onnx_slice_parser.h" +#include #include +#include #include #include @@ -46,6 +48,35 @@ STATUS OnnxSliceParser::InsertTensor(const std::vector &onnx_val, const std return RET_OK; } +STATUS OnnxSliceParser::GetInputTensor(std::vector *onnx_val, const std::string &name) { + if (onnx_val == nullptr) { + MS_LOG(ERROR) << "input vector is nullptr."; + return RET_ERROR; + } + if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) { + MS_LOG(ERROR) << "cannot get tensorcache."; + return RET_ERROR; + } + int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name); + if (index == -1) { + MS_LOG(ERROR) << "can not find node: " << name; + return RET_ERROR; + } + auto 
input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; + if (input_tensor->data.empty()) { + MS_LOG(DEBUG) << "data is empty."; + return RET_NO_CHANGE; + } + int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies()); + onnx_val->resize(data_num); + if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) != + EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + return RET_OK; +} + STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SliceParser"; @@ -97,6 +128,36 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } } + int status = RET_OK; + switch (onnx_node.input_size()) { + case 5: { + if (steps.empty()) { + status = GetInputTensor(&steps, onnx_node.input(4)); + } + } + case 4: { + if (status != RET_ERROR && axes.empty()) { + status = GetInputTensor(&axes, onnx_node.input(3)); + } + } + case 3: { + if (status != RET_ERROR && ends.empty()) { + status = GetInputTensor(&ends, onnx_node.input(2)); + } + } + case 2: { + if (status != RET_ERROR && starts.empty()) { + status = GetInputTensor(&starts, onnx_node.input(1)); + } + } + default: { + if (status == RET_ERROR) { + MS_LOG(ERROR) << "onnx slice inputs are invalid."; + return RET_INPUT_TENSOR_ERROR; + } + } + } + if (axes.empty()) { for (size_t i = 0; i < starts.size(); ++i) { axes.push_back(i); @@ -112,7 +173,6 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } int insert_num = 5 - onnx_node.input_size(); - int status = RET_OK; switch (insert_num) { case 4: { std::string name = "slice/starts/"; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h index 83fde3ea95..7bf60dfcc2 100644 --- 
a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h @@ -32,6 +32,7 @@ class OnnxSliceParser : public OnnxNodeParser { STATUS InsertTensor(const std::vector &onnx_val, const std::string &name, onnx::NodeProto *onnx_node); STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + STATUS GetInputTensor(std::vector *onnx_val, const std::string &name); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc index c9655773a2..8ee86e99c5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc @@ -38,6 +38,7 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No return RET_NULL_PTR; } + attr->splitDim = 0; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axis") { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc index eaa93acba2..bdde1605f8 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -16,9 +16,7 @@ #include "tools/converter/parser/onnx/onnx_tile_parser.h" #include -#include #include -#include "tools/converter/parser/onnx/onnx_tensor_parser.h" namespace mindspore { namespace lite { @@ -39,26 +37,6 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - const auto &onnx_tile_multiple = onnx_node.input(1); - int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_tile_multiple); - if (index == -1) { - 
MS_LOG(ERROR) << "can not find node: " << onnx_tile_multiple; - return RET_ERROR; - } - auto tile_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; - if (tile_attr->data.data() == nullptr) { - MS_LOG(ERROR) << "power's attr pow can't be obtained."; - return RET_INVALID_OP_ATTR; - } - int element_size = std::accumulate(tile_attr->dims.begin(), tile_attr->dims.end(), 1, std::multiplies()); - std::vector multiples; - std::vector dims; - for (int i = 0; i < element_size; ++i) { - multiples.push_back(reinterpret_cast(tile_attr->data.data())[i]); - dims.push_back(i); - } - attr->multiples = multiples; - attr->dims = dims; op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.value = attr.release(); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc index 70b8f2d1b3..3a5dc26ace 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -23,7 +23,6 @@ namespace mindspore { namespace lite { PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model) { - auto &tflite_subgraph = tflite_model->subgraphs.front(); auto primitive = std::make_unique(); if (primitive == nullptr) { MS_LOG(ERROR) << "primitive is null"; @@ -35,16 +34,6 @@ PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptrinputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->multiples)) { - MS_LOG(ERROR) << "get tile -> multiples failed"; - return nullptr; - } - std::vector dims(attr->multiples.size(), 0); - for (size_t i = 0; i < dims.size(); ++i) { - dims[i] = i; - } - attr->dims = dims; primitive->value.type = schema::PrimitiveType_Tile; primitive->value.value = attr.release(); return PrimitiveC::Create(primitive.release()); diff --git 
a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 25ff498afe..cf5c843e04 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -133,6 +133,10 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_no if (output_tensors.size() != 1) { for (size_t k = 0; k < output_tensors.size(); k++) { auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k); + if (used_node_list->empty()) { + MS_LOG(DEBUG) << "this output don't be used by other node."; + continue; + } if (used_node_list->size() != 1) { MS_LOG(ERROR) << " output must tuple_getitem"; return lite::RET_ERROR;