resolve encoder model parser problems

pull/10446/head
cjh9368 4 years ago
parent 6849173681
commit 94047f6360

@@ -257,6 +257,39 @@ int ElementOptSub(const float *input0, const float *input1, float *output, const
return NNACL_OK;
}
int ElementOptSubInt(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param) {
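// Optimized int32 subtraction where one operand is a scalar: in_elements_num0_ == 1
// means input0 is the scalar, otherwise input1 is. The NEON path handles 4 lanes per
// step; the scalar loops below finish the tail.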
#ifdef ENABLE_NEON
int32x4_t vin0_opt = vdupq_n_s32(input0[0]);
int32x4_t vin1_opt = vdupq_n_s32(input1[0]);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= element_size - 4; index += C4NUM) {
int32x4_t vin1 = vld1q_s32(input1 + index);
int32x4_t vout = vsubq_s32(vin0_opt, vin1);
vst1q_s32(output + index, vout);
}
#endif
for (; index < element_size; index++) {
output[index] = input0[0] - input1[index];
}
} else {
#ifdef ENABLE_NEON
for (; index <= element_size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(input0 + index);
int32x4_t vout = vsubq_s32(vin0, vin1_opt);
vst1q_s32(output + index, vout);
}
#endif
for (; index < element_size; index++) {
output[index] = input0[index] - input1[0];
}
}
return NNACL_OK;
}
int ElementOptSubRelu(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
@@ -1007,6 +1040,22 @@ int ElementMaximum(const float *input0, const float *input1, float *output, cons
return NNACL_OK;
}
int ElementMaximumInt(const int *input0, const int *input1, int *output, const int element_size) {
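// int32 counterpart of ElementMaximum: vmaxq_s32 covers 4 lanes at a time, with a
// scalar tail loop for the remainder.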
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(input0 + index);
int32x4_t vin1 = vld1q_s32(input1 + index);
int32x4_t vout = vmaxq_s32(vin0, vin1);
vst1q_s32(output + index, vout);
}
#endif
for (; index < element_size; index++) {
output[index] = input0[index] > input1[index] ? input0[index] : input1[index];
}
return NNACL_OK;
}
int BroadcastMaximum(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
int element_size, ArithmeticParameter *param) {
TileDimensions(input0, input1, tile_input0, tile_input1, param);

@@ -36,6 +36,8 @@ int ElementOptAddRelu6(const float *input0, const float *input1, float *output,
const ArithmeticParameter *param);
int ElementOptSub(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param);
int ElementOptSubInt(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param);
int ElementOptSubRelu(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param);
int ElementOptSubRelu6(const float *input0, const float *input1, float *output, const int element_size,
@@ -102,6 +104,7 @@ int BroadcastLogicalOr(const float *input0, const float *input1, float *tile_inp
int element_size, ArithmeticParameter *param);
int ElementMaximum(const float *input0, const float *input1, float *output, const int element_size);
int ElementMaximumInt(const int *input0, const int *input1, int *output, const int element_size);
int BroadcastMaximum(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
int element_size, ArithmeticParameter *param);

@@ -102,6 +102,27 @@ int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_da
}
return NNACL_OK;
}
int IntReduceMax(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
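// src_data is viewed as [outer, axis, inner]; threads stride over the outer
// dimension, and each output element is the running max along the axis dimension.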
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const int *outer_src = src_data + j * axis_size * inner_size;
int *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int tmp = INT_MIN;
for (i = 0; i < axis_size; i++) {
tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {

@@ -28,6 +28,8 @@ int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_da
int thread_num);
int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num);
int IntReduceMax(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
int thread_num);
int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num);
int IntReduceMin(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,

@@ -261,6 +261,7 @@ union PrimitiveType {
Reciprocal,
Merge,
Mod,
If,
GeLU,
Gru,
}

@@ -1154,6 +1154,9 @@ table While {
bodySubgraphIndex : int;
}
table If {
}
table UnsortedSegmentSum {
numSegments : int;
}

@@ -56,6 +56,26 @@ PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator);
#endif
template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
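// Shape values follow the usual Reshape convention: -1 marks the one dimension to
// infer from the total element count, 0 copies the corresponding dimension of the
// first input; every other entry multiplies into size so the -1 slot can be filled
// as input_count / size afterwards.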
int input_count = inputs[0]->ElementsNum();
int index = 0;
int size = 1;
for (int i = 0; i < shape_size; i++) {
if (static_cast<int>(data[i]) == -1) {
index = i;
} else if (static_cast<int>(data[i]) == 0) {
size *= inputs[0]->shape().at(i);
} else {
size *= data[i];
}
out_shape->push_back(data[i]);
}
if (static_cast<int>(data[index]) == -1) {
(*out_shape).at(index) = input_count / size;
}
}
int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
@@ -64,7 +84,7 @@ int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
MS_LOG(ERROR) << "Fill input or output is null!";
return RET_ERROR;
}
- if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
+ if ((inputs_.size() != kSingleNum && inputs_.size() != kDoubleNum) || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
@@ -74,11 +94,54 @@ int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_INFER_INVALID;
}
- std::vector<int> output_shape;
- for (size_t i = 0; i < GetDims().size(); i++) {
- output_shape.push_back(GetDims().at(i));
+ std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
if (shape_tensor->IsConst()) {
if (shape_tensor->data_c() == nullptr || (shape_tensor->shape().size() == 1 && shape_tensor->shape()[0] == 0)) {
MS_LOG(DEBUG) << "reshape to a scalar.";
output->set_shape(out_shape);
return RET_OK;
}
}
if (shape_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
size_t shape_size = shape_tensor->ElementsNum();
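// Different frontends may store the shape tensor with different dtypes, so each
// case reinterprets the raw buffer accordingly before delegating to CalShape.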
switch (shape_tensor->data_type()) {
case kNumberTypeInt8: {
auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt32: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt64: {
auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
CalShape<int64_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
CalShape<float>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeUInt32: {
auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
} break;
default: {
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
return RET_INFER_ERR;
}
}
} else {
for (size_t i = 0; i < GetDims().size(); i++) {
out_shape.push_back(GetDims().at(i));
}
}
- output->set_shape(output_shape);
+ output->set_shape(out_shape);
return RET_OK;
}
} // namespace lite

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/primitive_c.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_IF_H_
#define LITE_MINDSPORE_LITE_C_OPS_IF_H_
namespace mindspore {
namespace lite {
class If : public PrimitiveC {
public:
MS_DECLARE_PARENT(If, PrimitiveC);
If() = default;
~If() = default;
explicit If(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_IF_H_

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateWhereParameter(const mindspore::lite::PrimitiveC *primitive) {
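// Where carries no op-specific attributes, so a zero-initialized OpParameter with
// only the type filled in is enough.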
OpParameter *where_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (where_parameter == nullptr) {
MS_LOG(ERROR) << "malloc Where parameter failed.";
return nullptr;
}
memset(where_parameter, 0, sizeof(OpParameter));
where_parameter->type_ = primitive->Type();
return where_parameter;
}
Registry WhereParameterRegistry(schema::PrimitiveType_Where, PopulateWhereParameter);
} // namespace lite
} // namespace mindspore

@@ -160,6 +160,8 @@
#include "src/ops/merge.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"
#include "src/ops/if.h"
#include "src/ops/select.h"
#include "src/ops/gelu.h"
#include "src/ops/gru.h"
@@ -996,6 +998,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) AssertOP(primitive);
case schema::PrimitiveType_GeLU:
return new (std::nothrow) GeLU(primitive);
case schema::PrimitiveType_If:
return new (std::nothrow) If(primitive);
case schema::PrimitiveType_Select:
return new (std::nothrow) Select(primitive);
case schema::PrimitiveType_Gru:
return new (std::nothrow) Gru(primitive);
#ifdef SUPPORT_TRAIN

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/primitive_c.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_SELECT_H_
#define LITE_MINDSPORE_LITE_C_OPS_SELECT_H_
namespace mindspore {
namespace lite {
class Select : public PrimitiveC {
public:
MS_DECLARE_PARENT(Select, PrimitiveC);
Select() = default;
~Select() = default;
explicit Select(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_SELECT_H_

@@ -184,7 +184,8 @@ bool StridedSlice::CheckInputs(std::vector<lite::Tensor *> inputs_) {
return false;
}
}
- return true;
+ return ndim_ <= in_shape_.size();
}
void StridedSlice::ApplyNewAxisMask() {
@@ -237,7 +238,7 @@ void StridedSlice::ApplyEllipsisMask() {
}
void StridedSlice::ApplyBeginMask() {
- for (int i = 0; i < ndim_; i++) {
+ for (size_t i = 0; i < ndim_; i++) {
if (begins_mask_.at(i)) {
begins_.at(i) = 0;
}
@@ -245,7 +246,7 @@ void StridedSlice::ApplyEndMask() {
}
void StridedSlice::ApplyEndMask() {
- for (int i = 0; i < ndim_; i++) {
+ for (size_t i = 0; i < ndim_; i++) {
if (ends_mask_.at(i)) {
ends_.at(i) = in_shape_.at(i);
}
@@ -320,10 +321,10 @@ int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs
ends_.assign(ndim_, 0);
strides_.assign(ndim_, 0);
auto input_shape = input_tensor->shape();
- for (int i = 0; i < ndim_; ++i) {
+ for (size_t i = 0; i < ndim_; ++i) {
in_shape_.at(i) = input_shape.at(i);
}
- for (int i = 0; i < ndim_; ++i) {
+ for (size_t i = 0; i < ndim_; ++i) {
auto axes_it = std::find(axes.begin(), axes.end(), i);
if (axes_it != axes.end()) {
auto axis = axes_it - axes.begin();
@@ -369,7 +370,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
if (inputs.size() == kStridedSliceInputNum) {
ndim_ = static_cast<int>(GetBegin().size());
- for (int i = 0; i < ndim_; i++) {
+ for (size_t i = 0; i < ndim_; i++) {
begins_.emplace_back((GetBegin()).at(i));
ends_.emplace_back((GetEnd()).at(i));
strides_.emplace_back((GetStride()).at(i));
@@ -391,7 +392,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
return RET_INFER_ERR;
}
ndim_ = begin_tensor->ElementsNum();
- for (int i = 0; i < ndim_; ++i) {
+ for (size_t i = 0; i < ndim_; ++i) {
begins_.emplace_back(begin_data[i]);
ends_.emplace_back(end_data[i]);
strides_.emplace_back(stride_data[i]);
@@ -413,7 +414,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
shrink_axis_mask_.resize(ndim_);
// convert bit to vector
- for (int i = 0; i < ndim_; i++) {
+ for (size_t i = 0; i < ndim_; i++) {
begins_mask_.at(i) = static_cast<uint32_t>(GetBeginMask()) & (1 << i);
ends_mask_.at(i) = static_cast<uint32_t>(GetEndMask()) & (1 << i);
ellipsis_mask_.at(i) = static_cast<uint32_t>(GetEllipsisMask()) & (1 << i);
@@ -432,7 +433,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
std::vector<int> output_shape(in_shape_);
TransIndexToPositive();
- for (int i = 0; i < ndim_; i++) {
+ for (size_t i = 0; i < ndim_; i++) {
if (strides_.at(i) == 0) {
MS_LOG(ERROR) << "strides should not be 0.";
return RET_INFER_ERR;

@@ -70,7 +70,7 @@ class StridedSlice : public PrimitiveC {
std::vector<int> GetStrides() { return this->strides_; }
protected:
- int ndim_ = 0;
+ size_t ndim_ = 0;
std::vector<int> in_shape_;
std::vector<int> begins_;
std::vector<int> ends_;

@@ -137,7 +137,9 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
return RET_ERROR;
}
auto tensor_index = input0->GetTensor(index_);
- MS_ASSERT(tensor_index != nullptr);
+ if (tensor_index == nullptr) {
+ return RET_INFER_INVALID;
+ }
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (tensor_index->data_type() != kTypeUnknown) {

@@ -105,7 +105,7 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
}
if (input0->data_c() == nullptr) {
MS_LOG(ERROR) << "input0->data_c() is nullptr";
- return RET_NULL_PTR;
+ return RET_INFER_INVALID;
}
auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c());
@@ -122,7 +122,7 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
}
if (input1->data_c() == nullptr) {
MS_LOG(ERROR) << "input1->data_c() is nullptr";
- return RET_NULL_PTR;
+ return RET_INFER_INVALID;
}
int num_elements = reinterpret_cast<int *>(input1->data_c())[0];
auto output = reinterpret_cast<TensorList *>(outputs_[0]);

@@ -66,15 +66,31 @@ int Where::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
- if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
+ if (inputs_.size() == kSingleNum) {
auto input0 = inputs_.at(0);
if (input0->data_c() == nullptr) {
MS_LOG(ERROR) << "input0 is empty, tensor cannot be inferred yet";
return RET_INFER_INVALID;
}
int dim_size = input0->shape().size();
auto data_ptr = reinterpret_cast<bool *>(input0->data_c());
int true_num = 0;
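// Single-input form: Where returns the coordinates of the true elements, so the
// output is sized [number of true values, input rank].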
for (int i = 0; i < input0->ElementsNum(); i++) {
if (data_ptr[i]) {
true_num++;
}
}
std::vector<int> output_shape = {true_num, dim_size};
outputs_.at(0)->set_shape(output_shape);
return RET_OK;
}
if (inputs_.size() < kMultiNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "where input or output number invalid, Input size:" << inputs_.size()
<< ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
if (inputs_.size() < 3) {
MS_LOG(ERROR) << "Input shape tensors should b";
return RET_INPUT_TENSOR_ERROR;
}
auto input0 = inputs_.at(0);
auto input1 = inputs_.at(1);
auto input2 = inputs_.at(2);

@@ -116,4 +116,5 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Less, LiteKernelCreator<Arithme
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, LiteKernelCreator<ArithmeticCompareCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, LiteKernelCreator<ArithmeticCompareCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, LiteKernelCreator<ArithmeticCompareCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_GreaterEqual, LiteKernelCreator<ArithmeticCompareCPUKernel>)
} // namespace mindspore::kernel

@@ -164,6 +164,7 @@ void ArithmeticCPUKernel::InitRunFunction() {
break;
case PrimitiveType_Maximum:
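// Maximum now has an int32 element kernel alongside the float one; which variant
// runs is chosen by the input tensor's data type.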
arithmetic_run_ = ElementMaximum;
arithmetic_run_int_ = ElementMaximumInt;
break;
case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimum;
@@ -252,6 +253,7 @@ void ArithmeticCPUKernel::InitOptRunFunction() {
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSub;
arithmetic_opt_run_int_ = ElementOptSubInt;
break;
}
break;
@@ -509,6 +511,8 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_LogicalAnd, LiteKernelCreator<A
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Maximum, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Minimum, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorDiv, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorDiv, LiteKernelCreator<ArithmeticCPUKernel>)

@@ -62,6 +62,13 @@ int ExpandDimsCPUKernel::DoExpandDims(int task_id) {
MS_LOG(ERROR) << "ExpandDimsRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
} else if (this->in_tensors_.at(0)->data_type() == kNumberTypeInt32) {
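// Mirrors the float branch; the last argument to ExpandDims is a byte count,
// hence the sizeof(int32_t) scaling.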
int ret = ExpandDims(reinterpret_cast<int32_t *>(in_ptr_) + offset, reinterpret_cast<int32_t *>(out_ptr_) + offset,
size * sizeof(int32_t));
if (ret != RET_OK) {
MS_LOG(ERROR) << "ExpandDimsRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
}
return RET_OK;
}
@@ -87,6 +94,7 @@ int ExpandDimsCPUKernel::Run() {
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator<ExpandDimsCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator<ExpandDimsCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ExpandDims, LiteKernelCreator<ExpandDimsCPUKernel>)
} // namespace mindspore::kernel

@@ -58,6 +58,7 @@ int ReduceCPUKernel::Init() {
}
case static_cast<int>(ReduceMode_ReduceMax): {
reducer_ = ReduceMax;
int_reducer_ = IntReduceMax;
break;
}
case static_cast<int>(ReduceMode_ReduceMin): {

@@ -49,6 +49,7 @@ int ShapeCPUKernel::Run() {
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Shape, LiteKernelCreator<ShapeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Shape, LiteKernelCreator<ShapeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Shape, LiteKernelCreator<ShapeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Shape, LiteKernelCreator<ShapeCPUKernel>)

@@ -127,5 +127,6 @@ kernel::LiteKernel *CpuTensorListFromTensorFp32KernelCreator(const std::vector<l
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator)
} // namespace mindspore::kernel

@@ -37,5 +37,6 @@ int TensorListReserveCPUKernel::Run() {
int TensorListReserveCPUKernel::ReSize() { return RET_OK; }
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>)
} // namespace mindspore::kernel

@@ -90,5 +90,6 @@ int WhereCPUKernel::Run() {
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Where, LiteKernelCreator<WhereCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Where, LiteKernelCreator<WhereCPUKernel>)
} // namespace mindspore::kernel

@@ -16,6 +16,7 @@
#include "src/tensorlist.h"
#include <utility>
#include <algorithm>
#include "include/ms_tensor.h"
#include "src/common/log_adapter.h"
#include "schema/model_generated.h"
@@ -277,6 +278,21 @@ STATUS TensorList::Decode(const int *data) {
for (int j = 0; j < data[1]; ++j) {
element_shape_.push_back(data[2 + j]);
}
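// After the element shape, the flat buffer carries the tensor count, then each
// tensor's rank followed by its dims; a placeholder Tensor of the list's shared
// dtype is allocated for every entry.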
int tensors_num = data[2 + data[1]];
tensors_.resize(tensors_num);
int tensor_index = 2 + data[1] + 1;
for (int i = 0; i < tensors_num; i++) {
int tensor_dims_size = data[tensor_index++];
std::vector<int> shape(tensor_dims_size);
for (int j = 0; j < tensor_dims_size; j++) {
shape[j] = data[tensor_index++];
}
tensors_[i] = new (std::nothrow) Tensor(tensors_data_type_, shape);
if (tensors_[i] == nullptr) {
MS_LOG(ERROR) << "new Tensor failed";
return RET_NULL_PTR;
}
}
return RET_OK;
}
