strided_slice_with_axes

pull/8201/head
zhaozhenlong 4 years ago
parent bedc733e42
commit 2afed72140

@ -104,8 +104,12 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5; dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5;
if (param->data_type == kDataTypeFloat) { if (param->data_type == kDataTypeFloat) {
*((float *)out_data + out_offset) = *((float *)in_data + in_offset); *((float *)out_data + out_offset) = *((float *)in_data + in_offset);
} else { } else if (param->data_type == kDataTypeInt8) {
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset);
} else if (param->data_type == kDataTypeInt) {
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset);
} else {
return NNACL_ERR;
} }
out_offset++; out_offset++;
} }

@ -87,7 +87,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
} }
auto in_data = reinterpret_cast<int *>(in_tensor->data_c()); auto in_data = reinterpret_cast<int *>(in_tensor->data_c());
if (in_data == nullptr) { if (in_data == nullptr) {
MS_LOG(ERROR) << "Input data is nullptr"; MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet.";
return RET_INFER_INVALID; return RET_INFER_INVALID;
} }
int size = in_tensor->ElementsNum(); int size = in_tensor->ElementsNum();

@ -50,6 +50,7 @@ class Registry {
} }
}; };
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive); OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive);
OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *primitive);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif #endif

@ -41,6 +41,7 @@ OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *pr
memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int));
auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape(); auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape();
memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
strided_slice_param->in_shape_length_ = static_cast<int>(in_shape.size());
return reinterpret_cast<OpParameter *>(strided_slice_param); return reinterpret_cast<OpParameter *>(strided_slice_param);
} }

@ -15,6 +15,7 @@
*/ */
#include "src/ops/strided_slice.h" #include "src/ops/strided_slice.h"
#include <algorithm>
#ifndef PRIMITIVE_WRITEABLE #ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h" #include "src/ops/ops_register.h"
@ -172,7 +173,8 @@ Registry StridedSliceRegistry(schema::PrimitiveType_StridedSlice, StridedSliceCr
namespace { namespace {
constexpr size_t kStridedSliceOutputNum = 1; constexpr size_t kStridedSliceOutputNum = 1;
constexpr size_t kStridedSliceInputNum = 1; constexpr size_t kStridedSliceInputNum = 1;
constexpr size_t kStridedSliceMultiInputNum = 4; constexpr size_t kStridedSliceMultiInputNumMin = 3;
constexpr size_t kStridedSliceMultiInputNumMax = 5;
} // namespace } // namespace
void StridedSlice::ApplyNewAxisMask() { void StridedSlice::ApplyNewAxisMask() {
@ -251,13 +253,91 @@ void StridedSlice::TransIndexToPositive() {
} }
} }
int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs) {
  // When the axes input exists, the input order is: data, begin, end, axes(opt), stride(opt).
  // Expands begins_/ends_/strides_ to the full input rank, clamping out-of-range begin/end values.
  // Returns RET_OK on success, RET_INFER_ERR when required tensor data is not available yet.
  auto input_tensor = inputs.at(0);
  MS_ASSERT(input_tensor != nullptr);
  auto begin_tensor = inputs.at(1);
  MS_ASSERT(begin_tensor != nullptr);
  int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
  auto end_tensor = inputs.at(2);
  MS_ASSERT(end_tensor != nullptr);
  int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
  if (begin_data == nullptr || end_data == nullptr) {
    return RET_INFER_ERR;
  }
  // When input contains axes, begins/ends/strides will be expanded to the same length as input rank.
  ndim_ = static_cast<int>(input_tensor->shape().size());
  int begin_ndim = begin_tensor->ElementsNum();
  int *axes_data = nullptr;
  auto axes_tensor = inputs.at(3);
  if (axes_tensor->ElementsNum() != 0) {
    MS_ASSERT(axes_tensor->ElementsNum() == begin_ndim);
    axes_data = reinterpret_cast<int *>(axes_tensor->MutableData());
    if (axes_data == nullptr) {
      return RET_INFER_ERR;
    }
  }
  // stride input is optional; stride_data stays nullptr when it is absent/empty.
  int *stride_data = nullptr;
  auto stride_tensor = inputs.at(4);
  if (stride_tensor->ElementsNum() != 0) {
    MS_ASSERT(stride_tensor->ElementsNum() == begin_ndim);
    stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
    if (stride_data == nullptr) {
      return RET_INFER_ERR;
    }
  }
  std::vector<int> axes;
  if (axes_data == nullptr) {
    // Default axes are 0..begin_ndim-1.
    // BUGFIX: the original indexed an empty vector (axes[i] = i) — undefined behavior.
    axes.resize(begin_ndim);
    for (int i = 0; i < begin_ndim; ++i) {
      axes[i] = i;
    }
  } else {
    axes.assign(axes_data, axes_data + begin_ndim);
    // Normalize negative axes to positive indices.
    for (int i = 0; i < begin_ndim; ++i) {
      if (axes[i] < 0) {
        axes[i] += ndim_;
      }
    }
  }
  in_shape_.assign(ndim_, 0);
  begins_.assign(ndim_, 0);
  ends_.assign(ndim_, 0);
  strides_.assign(ndim_, 0);
  auto input_shape = input_tensor->shape();
  for (int i = 0; i < ndim_; ++i) {
    in_shape_[i] = input_shape.at(i);
  }
  for (int i = 0; i < ndim_; ++i) {
    auto axes_it = std::find(axes.begin(), axes.end(), i);
    if (axes_it != axes.end()) {
      auto axis = axes_it - axes.begin();
      // begins or ends that exceed the limit will be clamped to the limit.
      begins_[i] = std::max(std::min(begin_data[axis], input_shape[i] - 1), -input_shape[i]);
      ends_[i] = std::max(std::min(end_data[axis], input_shape[i]), -input_shape[i] - 1);
      // BUGFIX: stride is optional; default to 1 instead of dereferencing a null stride_data.
      strides_[i] = (stride_data != nullptr) ? stride_data[axis] : 1;
    } else {
      // Axis not selected: take the full dimension.
      begins_[i] = 0;
      ends_[i] = input_shape[i];
      strides_[i] = 1;
    }
  }
  return RET_OK;
}
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr); MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kStridedSliceOutputNum) { if (outputs.size() != kStridedSliceOutputNum) {
MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) { if (inputs.size() != kStridedSliceInputNum &&
!(inputs.size() <= kStridedSliceMultiInputNumMax && inputs.size() >= kStridedSliceMultiInputNumMin)) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size(); MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
@ -268,6 +348,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
auto input_shape = input->shape(); auto input_shape = input->shape();
auto inferflag = GetInferFlag(); auto inferflag = GetInferFlag();
in_shape_.clear();
begins_.clear();
ends_.clear();
strides_.clear();
if (inputs.size() == kStridedSliceInputNum) { if (inputs.size() == kStridedSliceInputNum) {
ndim_ = static_cast<int>(GetBegin().size()); ndim_ = static_cast<int>(GetBegin().size());
@ -279,7 +363,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
ends_.emplace_back((GetEnd())[i]); ends_.emplace_back((GetEnd())[i]);
strides_.emplace_back((GetStride())[i]); strides_.emplace_back((GetStride())[i]);
} }
} else { }
if (inputs.size() == 4) {
// input order: input, begins, ends, strides.
auto begin_tensor = inputs.at(1); auto begin_tensor = inputs.at(1);
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData()); int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
auto end_tensor = inputs.at(2); auto end_tensor = inputs.at(2);
@ -299,6 +385,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
strides_.emplace_back(stride_data[i]); strides_.emplace_back(stride_data[i]);
} }
} }
if (inputs.size() == 5) {
// input order: input, begins, end, axes, strides
auto ret = HandleAxesInputExist(inputs);
if (ret != RET_OK) {
return ret;
}
}
// set all mask to original input shape // set all mask to original input shape
begins_mask_.resize(ndim_); begins_mask_.resize(ndim_);
@ -333,7 +426,12 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
if (i < ndim_ && new_axis_mask_.at(i)) { if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1; output_shape.at(i) = 1;
} else { } else {
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); if (strides_.at(i) == 0) {
MS_LOG(ERROR) << "strides should not be 0.";
return RET_INFER_ERR;
}
output_shape.at(i) =
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
} }
} }

@ -81,6 +81,7 @@ class StridedSlice : public PrimitiveC {
std::vector<bool> new_axis_mask_; std::vector<bool> new_axis_mask_;
std::vector<bool> shrink_axis_mask_; std::vector<bool> shrink_axis_mask_;
void TransIndexToPositive(); void TransIndexToPositive();
int HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -16,11 +16,11 @@
#include "src/runtime/kernel/arm/base/strided_slice.h" #include "src/runtime/kernel/arm/base/strided_slice.h"
#include <vector> #include <vector>
#include "nnacl/strided_slice.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/ops/populate/populate_register.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@ -44,16 +44,16 @@ int StridedSliceCPUKernel::Init() {
} }
int StridedSliceCPUKernel::ReSize() { int StridedSliceCPUKernel::ReSize() {
auto input = in_tensors_.at(0); if (op_parameter_ != nullptr) {
auto parameter = reinterpret_cast<StridedSliceParameter *>(op_parameter_); free(op_parameter_);
MS_ASSERT(input); op_parameter_ = nullptr;
MS_ASSERT(parameter); }
parameter->data_type = input->data_type() == kNumberTypeInt8 ? kDataTypeInt8 : kDataTypeFloat; op_parameter_ = PopulateStridedSliceParameter(primitive_);
auto input_shape = input->shape(); if (op_parameter_ == nullptr) {
for (size_t i = 0; i < input_shape.size(); ++i) { MS_LOG(ERROR) << "Malloc parameter failed";
parameter->in_shape_[i] = input_shape[i]; return RET_ERROR;
} }
parameter->in_shape_length_ = static_cast<int>(input_shape.size()); param_ = reinterpret_cast<StridedSliceParameter *>(op_parameter_);
return RET_OK; return RET_OK;
} }
@ -62,8 +62,7 @@ int StridedSliceCPUKernel::HandleMultiInputs() {
MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size(); MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size();
return RET_ERROR; return RET_ERROR;
} }
auto param = reinterpret_cast<StridedSliceParameter *>(op_parameter_); if (param_ == nullptr) {
if (param == nullptr) {
MS_LOG(ERROR) << "StridedSliceParamater cast nullptr"; MS_LOG(ERROR) << "StridedSliceParamater cast nullptr";
return RET_ERROR; return RET_ERROR;
} }
@ -74,35 +73,49 @@ int StridedSliceCPUKernel::HandleMultiInputs() {
MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num; MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num;
return RET_ERROR; return RET_ERROR;
} }
memcpy(param->begins_, begins->MutableData(), axis_num * sizeof(int)); memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int));
auto ends = in_tensors_.at(kEndsIndex); auto ends = in_tensors_.at(kEndsIndex);
MS_ASSERT(ends != nullptr); MS_ASSERT(ends != nullptr);
MS_ASSERT(axis_num == ends->ElementsNum()); MS_ASSERT(axis_num == ends->ElementsNum());
memcpy(param->ends_, ends->MutableData(), axis_num * sizeof(int)); memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int));
auto strides = in_tensors_.at(kStridesInex); auto strides = in_tensors_.at(kStridesInex);
MS_ASSERT(strides != nullptr); MS_ASSERT(strides != nullptr);
MS_ASSERT(axis_num == strides->ElementsNum()); MS_ASSERT(axis_num == strides->ElementsNum());
memcpy(param->strides_, strides->MutableData(), axis_num * sizeof(int)); memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int));
param->num_axes_ = axis_num; param_->num_axes_ = axis_num;
return RET_OK; return RET_OK;
} }
int StridedSliceCPUKernel::Run() { int StridedSliceCPUKernel::Run() {
auto input = in_tensors_.at(0); auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);
MS_ASSERT(input); MS_ASSERT(input);
switch (input->data_type()) {
case kNumberTypeInt8:
param_->data_type = kDataTypeInt8;
break;
case kNumberTypeFloat32:
param_->data_type = kDataTypeFloat;
break;
case kNumberTypeInt32:
param_->data_type = kDataTypeInt;
break;
default:
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
}
auto output = out_tensors_.at(0);
MS_ASSERT(output); MS_ASSERT(output);
// inputs order: input, begin, end, stride
if (in_tensors().size() == kMultiInputsSize) { if (in_tensors().size() == kMultiInputsSize) {
auto ret = HandleMultiInputs(); auto ret = HandleMultiInputs();
if (ret != RET_OK) { if (ret != RET_OK) {
return ret; return ret;
} }
} }
auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_);
reinterpret_cast<StridedSliceParameter *>(op_parameter_));
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";
return RET_ERROR; return RET_ERROR;

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_ #define MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_
#include <vector> #include <vector>
#include "nnacl/strided_slice.h"
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
namespace mindspore::kernel { namespace mindspore::kernel {
@ -27,7 +27,9 @@ class StridedSliceCPUKernel : public LiteKernel {
StridedSliceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, StridedSliceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive) const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {} : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<StridedSliceParameter *>(parameter);
}
~StridedSliceCPUKernel() override = default; ~StridedSliceCPUKernel() override = default;
int Init() override; int Init() override;
@ -36,6 +38,9 @@ class StridedSliceCPUKernel : public LiteKernel {
private: private:
int HandleMultiInputs(); int HandleMultiInputs();
private:
StridedSliceParameter *param_;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -25,9 +25,3 @@ psenet_lite_mbv2.onnx;1,32,32,3
super-resolution-10.onnx;1,224,224,1 super-resolution-10.onnx;1,224,224,1
tinyyolov2-8.onnx;1,416,416,3 tinyyolov2-8.onnx;1,416,416,3
ml_2012_ocr_cn.onnx ml_2012_ocr_cn.onnx
ml_2012_ocr_cn_noLSTM.onnx
candy-9.onnx
mosaic-9.onnx
pointilism-9.onnx
rain-princess-9.onnx
udnie-9.onnx

@ -62,7 +62,7 @@ function Run_Converter() {
if [ $? = 0 ]; then if [ $? = 0 ]; then
converter_result='converter onnx '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} converter_result='converter onnx '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else else
converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file} converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file};return 1
fi fi
done < ${models_onnx_config} done < ${models_onnx_config}

@ -64,6 +64,9 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
param_value->set_tensor_addr(tensor_data); param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size); param_value->set_tensor_size(size);
parameter->set_default_param(param_value); parameter->set_default_param(param_value);
} else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) ==
meta_graph_->inputIndex.end()) {
parameter->set_default_param(param_value);
} }
AddNode(i, parameter); AddNode(i, parameter);
} }

@ -30,7 +30,7 @@
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "tools/converter/model_parser.h" #include "tools/converter/model_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/common/tensor_util.h" #include "tools/converter/parser/onnx/onnx_tensor_parser.h"
#include "proto/onnx.pb.h" #include "proto/onnx.pb.h"
namespace mindspore { namespace mindspore {
@ -53,42 +53,38 @@ class OnnxModelParser : public ModelParser {
private: private:
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache); STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache); STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index);
TensorCache *tensor_cache, int *index);
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index);
TensorCache *tensor_cache, int *index);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType, schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph);
schema::MetaGraphT *dst_graph);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache, schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type);
const QuantType &quant_type);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const string &onnx_op_type, schema::CNodeT *dst_op); const string &onnx_op_type, schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
schema::TensorT *dst_tensor, TensorCache *tensor_cache); schema::TensorT *dst_tensor);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); const onnx::NodeProto &onnx_node);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); STATUS SetAllTensors(schema::MetaGraphT *graphDef);
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);

@ -17,9 +17,35 @@
#include "tools/converter/parser/onnx/onnx_slice_parser.h" #include "tools/converter/parser/onnx/onnx_slice_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
// Materializes a compile-time attribute (starts/ends/axes/steps) as an int32 constant
// tensor in the shared tensor cache and appends its generated name as a new input of
// onnx_node, so downstream code sees a uniform 5-input StridedSlice form.
// NOTE(review): ownership of the TensorT is transferred to the cache via release() —
// assumes AddTensor takes ownership; confirm against TensorCache's contract.
STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name,
onnx::NodeProto *onnx_node) {
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
if (tensor == nullptr) {
MS_LOG(ERROR) << "new tensor failed";
return RET_ERROR;
}
tensor->dataType = mindspore::kNumberTypeInt32;
// 1-D tensor holding one int32 per requested value.
tensor->dims.push_back(onnx_val.size());
tensor->format = schema::Format::Format_NCHW;
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
int data_size = sizeof(int32_t) * onnx_val.size();
tensor->data.resize(data_size);
// Skip the copy entirely for an empty value list (memcpy_s rejects size 0 on some impls).
if (data_size != 0 &&
memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
// Suffix the name with the current cache size to keep generated tensor names unique.
int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size();
std::string tensor_name = name + std::to_string(tensor_num);
OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT);
// Wire the new constant tensor in as an extra input of the ONNX node.
onnx_node->add_input(tensor_name);
return RET_OK;
}
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SliceParser"; MS_LOG(DEBUG) << "onnx SliceParser";
@ -33,15 +59,15 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>(); std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>();
if (attr == nullptr) { if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed"; MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::vector<int> axes;
std::vector<int> starts; std::vector<int> starts;
std::vector<int> ends; std::vector<int> ends;
std::vector<int> axes;
std::vector<int> steps; std::vector<int> steps;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
@ -71,64 +97,49 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
} }
} }
} }
if (axes.empty()) {
if (onnx_node.input_size() > 1) { for (size_t i = 0; i < starts.size(); ++i) {
const auto &starts_name = onnx_node.input(1); axes.push_back(i);
for (const auto &it : onnx_graph.initializer()) {
if (it.name() == starts_name) {
starts.clear();
for (int i = 0; i < it.int32_data_size(); ++i) {
starts.push_back(it.int32_data(i));
}
}
} }
} }
if (steps.empty()) {
if (onnx_node.input_size() > 2) { steps.assign(starts.size(), 1);
const auto &ends_name = onnx_node.input(2);
for (const auto &it : onnx_graph.initializer()) {
if (it.name() == ends_name) {
ends.clear();
for (int i = 0; i < it.int32_data_size(); ++i) {
ends.push_back(it.int32_data(i));
}
}
}
} }
onnx::NodeProto *slice_node = nullptr;
if (onnx_node.input_size() > 3) { for (auto &node : onnx_graph.node()) {
const auto &axes_name = onnx_node.input(3); if (&node == &onnx_node) {
for (const auto &it : onnx_graph.initializer()) { slice_node = const_cast<onnx::NodeProto *>(&node);
if (it.name() == axes_name) {
axes.clear();
for (int i = 0; i < it.int32_data_size(); ++i) {
axes.push_back(it.int32_data(i));
}
}
} }
} }
int insert_num = 5 - onnx_node.input_size();
if (onnx_node.input_size() > 4) { int status = RET_OK;
const auto &steps_name = onnx_node.input(4); switch (insert_num) {
for (const auto &it : onnx_graph.initializer()) { case 4: {
if (it.name() == steps_name) { std::string name = "slice/starts/";
steps.clear(); status = InsertTensor(starts, name, slice_node);
for (int i = 0; i < it.int32_data_size(); ++i) {
steps.push_back(it.int32_data(i));
}
}
} }
case 3:
if (status == RET_OK) {
std::string name = "slice/ends/";
status = InsertTensor(ends, name, slice_node);
}
case 2:
if (status == RET_OK) {
std::string name = "slice/axes/";
status = InsertTensor(axes, name, slice_node);
}
case 1:
if (status == RET_OK) {
std::string name = "slice/steps/";
status = InsertTensor(steps, name, slice_node);
}
default:
if (status != RET_OK) {
MS_LOG(ERROR) << "onnx slice insert tensor failed";
return RET_ERROR;
}
} }
op->primitive->value.type = schema::PrimitiveType_StridedSlice;
std::vector<int> sizes(starts.size(), -1);
for (size_t i = 0; i < starts.size(); ++i) {
sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]);
}
attr->axes = axes;
attr->begin = starts;
attr->size = sizes;
attr->step = steps;
op->primitive->value.type = schema::PrimitiveType_Slice;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

@ -17,8 +17,11 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#include <vector>
#include <string>
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -27,6 +30,7 @@ class OnnxSliceParser : public OnnxNodeParser {
OnnxSliceParser() : OnnxNodeParser("Slice") {} OnnxSliceParser() : OnnxNodeParser("Slice") {}
~OnnxSliceParser() override = default; ~OnnxSliceParser() override = default;
STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node);
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
#include "tools/common/tensor_util.h"
namespace mindspore {
namespace lite {
// Process-wide singleton that owns the TensorCache shared between the ONNX model
// parser and individual node parsers (e.g. the Slice parser inserts constant
// tensors through it). Not copyable by construction; the instance is a Meyers
// singleton, so initialization is thread-safe but the cache itself carries no
// synchronization — assumed to be used from a single conversion thread.
class OnnxTensorParser {
public:
~OnnxTensorParser() = default;
// Returns the single shared instance (lazily constructed on first use).
static OnnxTensorParser *GetInstance() {
static OnnxTensorParser onnxTensorParser;
return &onnxTensorParser;
}
// Accessor for the mutable shared tensor cache; never returns nullptr.
TensorCache *GetTensorCache() { return &tensor_cache_; }
private:
OnnxTensorParser() = default;
TensorCache tensor_cache_;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
Loading…
Cancel
Save