!9347 [lite] add reciprocal op and adjust tile and split

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/9347/MERGE
Committed by mindspore-ci-bot via Gitee
commit c04304337a

@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <assert.h>
 #include <math.h>
 #include "nnacl/fp16/arithmetic_self_fp16.h"
@@ -108,3 +109,11 @@ int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size) {
   }
   return NNACL_OK;
 }
+
+int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; ++i) {
+    assert(input[i] != 0.0f);
+    output[i] = 1.f / input[i];
+  }
+  return NNACL_OK;
+}

@@ -48,6 +48,8 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size);
 int ElementCeilFp16(float16_t *input, float16_t *output, int number);
 int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size);
+int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size);
+
 #ifdef __cplusplus
 }
 #endif

@@ -16,6 +16,7 @@
 #include <string.h>
 #include <math.h>
+#include <assert.h>
 #include "nnacl/fp32/arithmetic_self_fp32.h"
 
 // abs:
@@ -128,3 +129,11 @@ int ElementNegative(const float *input, float *output, const int element_size) {
   }
   return NNACL_OK;
 }
+
+int ElementReciprocal(const float *input, float *output, const int element_size) {
+  for (int i = 0; i < element_size; ++i) {
+    assert(input[i] != 0.0f);
+    output[i] = 1.f / input[i];
+  }
+  return NNACL_OK;
+}
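The new fp16 and fp32 kernels are the same elementwise loop, differing only in element type. A standalone sketch of that contract, with a hypothetical driver that is not part of the patch (0 stands in for NNACL_OK, which is 0 in nnacl's error codes):

    #include <assert.h>
    #include <stdio.h>

    // Same contract as the new ElementReciprocal: output[i] = 1 / input[i],
    // with a debug-only assert guarding division by zero.
    int ElementReciprocalSketch(const float *input, float *output, int element_size) {
      for (int i = 0; i < element_size; ++i) {
        assert(input[i] != 0.0f);
        output[i] = 1.f / input[i];
      }
      return 0;
    }

    int main() {
      const float in[4] = {1.f, 2.f, 4.f, 0.5f};
      float out[4];
      ElementReciprocalSketch(in, out, 4);
      for (int i = 0; i < 4; ++i) printf("%g ", out[i]);  // 1 0.5 0.25 2
      return 0;
    }

Note the zero guard is only an assert, so release builds rely on the caller; 1 / 0.0f would otherwise produce inf.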

@@ -51,6 +51,8 @@ int ElementFloor(const float *input, float *output, const int element_size);
 int ElementCeil(const float *input, float *output, const int number);
 int ElementNegative(const float *input, float *output, const int element_size);
+int ElementReciprocal(const float *input, float *output, const int element_size);
+
 #ifdef __cplusplus
 }
 #endif

@@ -15,6 +15,7 @@
  */
 #include <math.h>
+#include <assert.h>
 #include "nnacl/int8/arithmetic_self_int8.h"
 #ifdef ENABLE_NEON
 #include <arm_neon.h>
@@ -278,3 +279,24 @@ int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
   }
   return NNACL_OK;
 }
+
+int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    float input_f32 = input[i] * in_scale + bias;
+    assert(input_f32 != 0.0f);
+    int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+      output[i] = para.output_activation_min_;
+    } else {
+      output[i] = (int8_t)output_tmp;
+    }
+  }
+  return NNACL_OK;
+}
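Int8ElementReciprocal dequantizes each element, inverts it in float, then requantizes and clamps to the activation range. A self-contained sketch of that round trip, assuming the conventional affine mapping real = (q - zp) * scale (the scales and zero points below are invented for illustration); note that round((1/x) / out_scale) equals round(1 / (x * out_scale)), which is the form the patch uses:

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Hypothetical quantization parameters, not taken from the patch.
    struct QuantParams { float scale; int32_t zp; };

    int8_t QuantReciprocal(int8_t q, QuantParams in, QuantParams out,
                           int32_t act_min, int32_t act_max) {
      float x = (q - in.zp) * in.scale;                        // dequantize
      float r = 1.f / x;                                       // reciprocal in real space
      int32_t q_out = (int32_t)round(r / out.scale) + out.zp;  // requantize
      if (q_out > act_max) q_out = act_max;                    // clamp to activation range
      if (q_out < act_min) q_out = act_min;
      return (int8_t)q_out;
    }

    int main() {
      QuantParams in = {0.05f, 0}, out = {0.1f, 0};
      // q = 40 -> x = 2.0 -> 1/x = 0.5 -> round(0.5 / 0.1) = 5
      printf("%d\n", QuantReciprocal(40, in, out, -128, 127));  // prints 5
      return 0;
    }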

@@ -50,6 +50,8 @@ int Int8ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
 int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
 #ifdef __cplusplus
 }
 #endif

@@ -253,7 +253,8 @@ union PrimitiveType {
   All,
   Assert,
   Adder,
-  SparseSoftmaxCrossEntropy
+  SparseSoftmaxCrossEntropy,
+  Reciprocal,
 }
 
 enum QuantType: int {

@@ -1203,3 +1203,6 @@ table All {
 table Assert {
   summarize : int;
 }
+
+table Reciprocal {
+}

@@ -375,11 +375,11 @@ void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
 
 int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   if (inputs_.size() != 2 && inputs_.size() != 3) {
-    MS_LOG(ERROR) << "Add should has two or three inputs";
+    MS_LOG(ERROR) << "Conv2d should has two or three inputs";
     return RET_ERROR;
   }
   if (outputs_.size() != 1) {
-    MS_LOG(ERROR) << "Add should has one outputs";
+    MS_LOG(ERROR) << "Conv2d should has one outputs";
     return RET_ERROR;
   }
   auto *input_tensor = inputs_.front();

@@ -47,6 +47,7 @@ Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateArithmeticSelf);
 Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf);
 Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf);
 Registry RoundParameterRegistry(schema::PrimitiveType_Round, PopulateArithmeticSelf);
+Registry ReciprocalParameterRegistry(schema::PrimitiveType_Reciprocal, PopulateArithmeticSelf);
 }  // namespace lite
 }  // namespace mindspore

@@ -31,7 +31,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive) {
   memset(split_param, 0, sizeof(SplitParameter));
   auto param = reinterpret_cast<mindspore::lite::Split *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   split_param->op_parameter_.type_ = primitive->Type();
-  split_param->num_split_ = param->GetNumberSplit();
+  split_param->num_split_ = param->num_split();
   if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
     MS_LOG(ERROR) << "The value of split_param->num_split_ is too big";
     return nullptr;
@@ -44,7 +44,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive) {
   }
   memset(split_sizes, 0, split_param->num_split_ * sizeof(int));
   split_param->split_sizes_ = split_sizes;
-  auto split_sizes_vector_ = param->GetSizeSplits();
+  auto split_sizes_vector_ = param->size_splits();
   int i = 0;
   for (int &iter : split_sizes_vector_) {
     split_param->split_sizes_[i++] = iter;
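The num_split_ bound check kept above exists because the split-size buffer is allocated as num_split_ * sizeof(int); comparing against INT_MAX / sizeof(int) first rules out integer overflow in that product. The pattern in isolation (AllocIntArray is a hypothetical helper):

    #include <cstdlib>
    #include <limits>

    int *AllocIntArray(int n) {
      // Reject counts whose byte size would overflow int arithmetic.
      if (n <= 0 || n > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
        return nullptr;
      }
      return static_cast<int *>(malloc(n * sizeof(int)));
    }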

@@ -43,8 +43,10 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) {
   for (size_t i = 0; i < kDimension_4d; ++i) {
     tile_param->multiples_[i] = 1;
   }
-  for (size_t i = 0; i < dims.size(); ++i) {
-    tile_param->multiples_[dims.at(i)] = multiples.at(i);
+  if (!dims.empty() && !multiples.empty()) {
+    for (size_t i = 0; i < dims.size(); ++i) {
+      tile_param->multiples_[dims[i]] = multiples[i];
+    }
   }
 #endif
   return reinterpret_cast<OpParameter *>(tile_param);
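The populate step scatters the (dims, multiples) pairs into a dense per-axis array that defaults to 1, so dims = {0, 2} with multiples = {2, 3} on a 4-D tensor yields {2, 1, 3, 1}; the new emptiness check simply skips the scatter when either attribute is absent. A small sketch of that expansion:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> dims = {0, 2}, multiples = {2, 3};
      int dense[4] = {1, 1, 1, 1};            // default: repeat each axis once
      if (!dims.empty() && !multiples.empty()) {
        for (size_t i = 0; i < dims.size(); ++i) {
          dense[dims[i]] = multiples[i];      // scatter into the named axes
        }
      }
      for (int m : dense) printf("%d ", m);   // 2 1 3 1
      return 0;
    }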

@@ -148,6 +148,7 @@
 #include "src/ops/while.h"
 #include "src/ops/oneslike.h"
 #include "src/ops/unsorted_segment_sum.h"
+#include "src/ops/reciprocal.h"
 
 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -888,6 +889,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) Quant(primitive);
     case schema::PrimitiveType_OnnxInt8Dequantize:
       return new (std::nothrow) Dequant(primitive);
+    case schema::PrimitiveType_Reciprocal:
+      return new (std::nothrow) Reciprocal(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:

@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/reciprocal.h"
+
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+
+namespace mindspore {
+namespace lite {
+#ifndef PRIMITIVE_WRITEABLE
+PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) {
+  return PrimitiveC::NewPrimitiveC<Reciprocal>(primitive);
+}
+Registry ReciprocalRegistry(schema::PrimitiveType_Reciprocal, ReciprocalCreator);
+#endif
+}  // namespace lite
+}  // namespace mindspore

@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+
+#include "src/ops/arithmetic_self.h"
+
+namespace mindspore {
+namespace lite {
+class Reciprocal : public ArithmeticSelf {
+ public:
+  Reciprocal() = default;
+  ~Reciprocal() = default;
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf);
+  explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+#else
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
+    MS_ASSERT(nullptr != primitive);
+    MS_ASSERT(nullptr != fbb);
+    auto val_offset = schema::CreateReciprocal(*fbb);
+    auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reciprocal, val_offset.o);
+    fbb->Finish(prim_offset);
+    return RET_OK;
+  }
+#endif
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
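In the non-writeable build, the two new files lean on the creator/registry pattern: a free function constructs the op, and a static Registry object records it against the schema type during static initialization, so no central switch needs editing. A simplified sketch of the idea (this is not MindSpore's actual Registry class):

    #include <cstdio>
    #include <functional>
    #include <map>

    using Creator = std::function<void *()>;

    std::map<int, Creator> &Table() {
      static std::map<int, Creator> table;  // keyed by primitive type id
      return table;
    }

    struct Registry {
      // Registration happens as a side effect of constructing a static object.
      Registry(int type, Creator c) { Table()[type] = std::move(c); }
    };

    // Analogous to: Registry ReciprocalRegistry(PrimitiveType_Reciprocal, ReciprocalCreator);
    static Registry kDemoRegistry(42, []() -> void * {
      puts("creator for type 42 invoked");
      return nullptr;
    });

    int main() {
      Table()[42]();  // look up and invoke the registered creator
      return 0;
    }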

@@ -24,7 +24,7 @@ namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
 int Split::GetNumberSplit() const { return this->primitive_->value.AsSplit()->numberSplit; }
-std::vector<int> Split::GetSizeSplits() const { return this->primitive_->value.AsSplit()->sizeSplits; }
+std::vector<int> Split::GetSizeSplit() const { return this->primitive_->value.AsSplit()->sizeSplits; }
 int Split::GetSplitDim() const { return this->primitive_->value.AsSplit()->splitDim; }
 
 void Split::SetNumberSplit(int number_split) { this->primitive_->value.AsSplit()->numberSplit = number_split; }
@@ -67,7 +67,7 @@ int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
 #else
 
 int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
-std::vector<int> Split::GetSizeSplits() const {
+std::vector<int> Split::GetSizeSplit() const {
   auto fb_vector = this->primitive_->value_as_Split()->sizeSplits();
   return std::vector<int>(fb_vector->begin(), fb_vector->end());
 }
@@ -108,42 +108,50 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
     MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum;
     return RET_ERROR;
   }
-  auto output = outputs_.front();
-  if (output == nullptr) {
-    MS_LOG(ERROR) << "output null pointer dereferencing.";
+  if (outputs_.empty()) {
+    MS_LOG(ERROR) << "split has no output.";
     return RET_ERROR;
   }
-  int number_split = GetNumberSplit();
-  if (static_cast<int>(outputs_.size()) != number_split) {
-    MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
-    return RET_ERROR;
-  }
-  for (int i = 0; i < number_split; ++i) {
-    outputs_.at(i)->set_data_type(input->data_type());
-    outputs_.at(i)->set_format(input->format());
+  for (auto &output : outputs_) {
+    output->set_data_type(input->data_type());
+    output->set_format(input->format());
   }
+  size_splits_ = GetSizeSplit();
+  num_split_ = GetNumberSplit() == 0 ? static_cast<int>(outputs_.size()) : GetNumberSplit();
   if (!infer_flag()) {
     return RET_INFER_INVALID;
   }
-  size_t split_dim = GetSplitDim() == -1 ? input->shape().size() - 1 : GetSplitDim();
+  size_t split_dim = GetSplitDim() < 0 ? input->shape().size() + GetSplitDim() : GetSplitDim();
   std::vector<int> input_shape = input->shape();
-  std::vector<int> size_split;
-  for (size_t i = 0; i < GetSizeSplits().size(); ++i) {
-    size_split.push_back(GetSizeSplits().at(i));
+  if (split_dim > input_shape.size()) {
+    MS_LOG(ERROR) << "split dim is out of range, which is " << input_shape.size();
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (static_cast<int>(outputs_.size()) != num_split_) {
+    MS_LOG(ERROR) << "outputs number is not equal to " << num_split_;
+    return RET_ERROR;
+  }
+  if (size_splits_.empty()) {
+    if (input_shape[split_dim] % num_split_ != 0) {
+      MS_LOG(ERROR) << "cannot split to equal size, which dim is " << input_shape[split_dim] << ", num split is "
+                    << num_split_;
+      return RET_INPUT_PARAM_INVALID;
+    }
+    for (int i = 0; i < num_split_; ++i) {
+      size_splits_.push_back(input_shape[split_dim] / num_split_);
+    }
   }
-  for (int i = 0; i < number_split; ++i) {
+  for (int i = 0; i < num_split_; ++i) {
     std::vector<int> output_shape;
     output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
     int split_dim_i = input_shape.at(split_dim);
     // support split size is -1 in the end.
-    if (size_split.empty()) {
-      split_dim_i = input_shape.at(split_dim) / number_split;
-    } else if (i == number_split - 1 && size_split.at(i) == -1) {
-      for (size_t j = 0; j < size_split.size() - 1; ++j) {
-        split_dim_i -= size_split.at(j);
+    if (i == num_split_ - 1 && size_splits_[i] == -1) {
+      for (size_t j = 0; j < size_splits_.size() - 1; ++j) {
+        split_dim_i -= size_splits_[j];
       }
     } else {
-      split_dim_i = size_split.at(i);
+      split_dim_i = size_splits_[i];
     }
     output_shape.at(split_dim) = split_dim_i;
     outputs_.at(i)->set_shape(output_shape);
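The reworked InferShape wraps a negative split dim, derives equal sizes when size_splits is empty, and lets a trailing -1 absorb whatever remains of the axis. A sketch of just the size computation (SplitSizes is a hypothetical helper mirroring the logic above):

    #include <cstdio>
    #include <vector>

    // Returns the per-output extents along the split axis, or {} on error.
    std::vector<int> SplitSizes(int dim_extent, int num_split, std::vector<int> size_splits) {
      if (size_splits.empty()) {                 // equal split requires divisibility
        if (num_split <= 0 || dim_extent % num_split != 0) return {};
        return std::vector<int>(num_split, dim_extent / num_split);
      }
      if (size_splits.back() == -1) {            // trailing -1 takes the remainder
        int used = 0;
        for (size_t j = 0; j + 1 < size_splits.size(); ++j) used += size_splits[j];
        size_splits.back() = dim_extent - used;
      }
      return size_splits;
    }

    int main() {
      for (int s : SplitSizes(6, 3, {})) printf("%d ", s);          // 2 2 2
      printf("| ");
      for (int s : SplitSizes(6, 3, {1, 2, -1})) printf("%d ", s);  // 1 2 3
      return 0;
    }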

@@ -42,8 +42,14 @@ class Split : public PrimitiveC {
 #endif
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
   int GetNumberSplit() const;
-  std::vector<int> GetSizeSplits() const;
+  std::vector<int> GetSizeSplit() const;
   int GetSplitDim() const;
+  int num_split() const { return num_split_; }
+  std::vector<int> size_splits() const { return size_splits_; }
+
+ protected:
+  int num_split_ = 0;
+  std::vector<int> size_splits_;
 };
 }  // namespace lite
 }  // namespace mindspore

@@ -139,8 +139,22 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   }
 
   std::vector<int> out_shape;
-  std::vector<int> multiples = GetMultiples();
+  std::vector<int> multiples;
+  if (inputs_.size() == 2) {
+    if (inputs_[1]->data_c() == nullptr) {
+      MS_LOG(INFO) << "Do infer shape in runtime.";
+      return RET_INFER_INVALID;
+    }
+    int data_num = inputs_[1]->ElementsNum();
+    if (data_num > static_cast<int>(input->shape().size())) {
+      MS_LOG(ERROR) << "multiples data num cannot be larger than input shape size.";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+    multiples.resize(data_num);
+    memcpy(multiples.data(), inputs_[1]->data_c(), inputs_[1]->Size());
+  } else {
+    multiples = GetMultiples();
+  }
 #ifdef SUPPORT_TRAIN
   const size_t in_dims = input->shape().size();
   const size_t delta_dims = in_dims - multiples.size();
@@ -156,6 +170,11 @@
   }
 #else
   std::vector<int> dims = GetDims();
+  if (inputs_.size() == 2 && dims.empty()) {
+    for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
+      dims.push_back(dim);
+    }
+  }
   const size_t in_dims = input->shape().size();
   MS_ASSERT(multiples.size() == dims.size());
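With the multiples now supplied as a second input tensor, shape inference reduces to scaling each axis, and when that tensor has no data yet the op defers to runtime via RET_INFER_INVALID. The shape arithmetic in isolation, e.g. input [2, 3] tiled by [2, 2]:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> in_shape = {2, 3};
      std::vector<int> multiples = {2, 2};  // would be read from inputs_[1]
      std::vector<int> out_shape(in_shape.size());
      for (size_t i = 0; i < in_shape.size(); ++i) {
        out_shape[i] = in_shape[i] * multiples[i];  // each axis repeats that many times
      }
      printf("[%d, %d]\n", out_shape[0], out_shape[1]);  // [4, 6]
      return 0;
    }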

@@ -38,7 +38,7 @@ int SplitBaseCPUKernel::ReSize() {
   auto input_shape = in_tensor->shape();
   MS_ASSERT(param);
-  MS_ASSERT(input_shape.size() >= 2 && input_shape.size() <= SPLIT_STRIDES_SIZE);
+  MS_ASSERT(input_shape.size() >= 1 && input_shape.size() <= SPLIT_STRIDES_SIZE);
   param->strides_[input_shape.size() - 1] = 1;
   for (int i = input_shape.size() - 2; i >= 0; i--) {
     param->strides_[i] = param->strides_[i + 1] * input_shape.at(i + 1);
@@ -50,8 +50,8 @@ int SplitBaseCPUKernel::ReSize() {
   param->n_dims_ = input_shape.size();
   if (param->split_sizes_[0] == 0) {
-    MS_ASSERT(param->num_split_ > 0 && static_cast<int>(param->num_split_) < input_shape.size());
-    if (input_shape.at(param->split_dim_) % param->num_split_ != 0) {
+    MS_ASSERT(param->num_split_ > 0 && static_cast<int>(param->num_split_) <= input_shape[param->split_dim_]);
+    if (input_shape[param->split_dim_] % param->num_split_ != 0) {
       MS_LOG(ERROR) << "Default split size is not usable.";
       return RET_ERROR;
     }
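ReSize builds row-major strides back to front: the innermost axis gets stride 1, and each earlier stride is the next stride times the next extent; the loop degenerates gracefully for rank 1, which the relaxed assert now admits. For shape [2, 3, 4] this gives strides [12, 4, 1]:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> shape = {2, 3, 4};
      std::vector<int> strides(shape.size());
      strides[shape.size() - 1] = 1;  // innermost axis is contiguous
      for (int i = static_cast<int>(shape.size()) - 2; i >= 0; i--) {
        strides[i] = strides[i + 1] * shape[i + 1];
      }
      for (int s : strides) printf("%d ", s);  // 12 4 1
      return 0;
    }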

@@ -43,7 +43,8 @@ ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int primitive_type) {
     {mindspore::schema::PrimitiveType_Floor, ElementFloorFp16},
     {mindspore::schema::PrimitiveType_Ceil, ElementCeilFp16},
     {mindspore::schema::PrimitiveType_Round, ElementRoundFp16},
-    {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16}};
+    {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16},
+    {mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocalFp16}};
   for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
     if (type_func_table[i].primitive_type_ == primitive_type) {
       return type_func_table[i].func_;
@@ -139,4 +140,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Floor, CpuArithmeticSelfFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Ceil, CpuArithmeticSelfFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Round, CpuArithmeticSelfFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Neg, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reciprocal, CpuArithmeticSelfFp16KernelCreator)
 }  // namespace mindspore::kernel
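This fp16 kernel and its fp32 counterpart in the next hunk both resolve the element function by linearly scanning a static {primitive type, function pointer} table, so supporting Reciprocal is a one-entry addition plus a REG_KERNEL line. The lookup in isolation (types simplified for illustration):

    #include <cstddef>

    typedef int (*ElementFunc)(const float *in, float *out, int n);

    struct TypeFuncInfo {
      int primitive_type;
      ElementFunc func;
    };

    ElementFunc GetFunc(const TypeFuncInfo *table, size_t n, int primitive_type) {
      for (size_t i = 0; i < n; i++) {  // linear scan; the table is tiny
        if (table[i].primitive_type == primitive_type) return table[i].func;
      }
      return nullptr;  // unknown op
    }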

@@ -41,7 +41,8 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_type) {
     {mindspore::schema::PrimitiveType_Floor, ElementFloor},
     {mindspore::schema::PrimitiveType_Ceil, ElementCeil},
     {mindspore::schema::PrimitiveType_Round, ElementRound},
-    {mindspore::schema::PrimitiveType_Neg, ElementNegative}};
+    {mindspore::schema::PrimitiveType_Neg, ElementNegative},
+    {mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocal}};
   for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
     if (type_func_table[i].primitive_type_ == primitive_type) {
       return type_func_table[i].func_;
@@ -152,4 +153,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Neg, CpuArithmeticSelfFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reciprocal, CpuArithmeticSelfFp32KernelCreator)
 }  // namespace mindspore::kernel

@@ -26,6 +26,7 @@ using mindspore::schema::PrimitiveType_Floor;
 using mindspore::schema::PrimitiveType_Log;
 using mindspore::schema::PrimitiveType_LogicalNot;
 using mindspore::schema::PrimitiveType_Neg;
+using mindspore::schema::PrimitiveType_Reciprocal;
 using mindspore::schema::PrimitiveType_Round;
 using mindspore::schema::PrimitiveType_Rsqrt;
 using mindspore::schema::PrimitiveType_Sin;

@@ -24,6 +24,9 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Tile;
 
 namespace mindspore::kernel {
+namespace {
+constexpr size_t kDoubleInputsSize = 2;
+}
 int TileCPUKernel::Init() {
   if (!InferShapeDone()) {
     return RET_OK;
@@ -42,6 +45,17 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {
 int TileCPUKernel::ReSize() {
   auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
   MS_ASSERT(tile_parameter_);
+  if (in_tensors_.size() == kDoubleInputsSize) {
+    if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {
+      MS_LOG(ERROR) << "tile's input1 data_num cannot be larger than input0's shape_size.";
+      return false;
+    }
+    auto input1_addr = reinterpret_cast<int *>(in_tensors_[1]->data_c());
+    for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
+      tile_parameter_->dims_[i] = i;
+      tile_parameter_->multiples_[i] = input1_addr[i];
+    }
+  }
   tile_parameter_->in_dim_ = in_tensors_.at(0)->shape().size();
   for (int i = 0; i < tile_parameter_->in_dim_; ++i) {
     tile_parameter_->in_shape_[i] = in_tensors_.at(0)->shape().at(i);
@@ -93,4 +107,5 @@ kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector<lite::Tensor *> &
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
 }  // namespace mindspore::kernel

@@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reciprocal, CpuArithmeticSelfInt8KernelCreator)
 }  // namespace mindspore::kernel

@@ -31,6 +31,7 @@ using mindspore::schema::PrimitiveType_Cos;
 using mindspore::schema::PrimitiveType_Floor;
 using mindspore::schema::PrimitiveType_Log;
 using mindspore::schema::PrimitiveType_LogicalNot;
+using mindspore::schema::PrimitiveType_Reciprocal;
 using mindspore::schema::PrimitiveType_Round;
 using mindspore::schema::PrimitiveType_Rsqrt;
 using mindspore::schema::PrimitiveType_Sin;
@@ -80,6 +81,8 @@ class ArithmeticSelfInt8CPUKernel : public LiteKernel {
       case PrimitiveType_LogicalNot:
         arithmeticSelf_run_ = Int8ElementLogicalNot;
         break;
+      case PrimitiveType_Reciprocal:
+        arithmeticSelf_run_ = Int8ElementReciprocal;
       default:
         break;
     }

Some files were not shown because too many files have changed in this diff.
