From 01267949b1006f005c2802899c58f7f18ce4da6f Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Wed, 25 Nov 2020 16:49:25 +0800 Subject: [PATCH] Fix onehot input 3 convert --- mindspore/lite/src/ops/one_hot.cc | 9 ++++++--- .../src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc | 2 +- .../lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc | 7 +++++-- .../lite/src/runtime/kernel/arm/int8/squeeze_int8.cc | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index b3bd9c9672..8591234db7 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -82,7 +82,8 @@ Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator); namespace { constexpr size_t kOneHotInputNum = 4; -} +constexpr size_t kOneHotInputNumOpt = 3; +} // namespace int OneHot::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { if (this->primitive_ == nullptr) { return RET_NULL_PTR; @@ -90,8 +91,10 @@ int axis = GetAxis(); // indices, depth, on_value, off_value - if (inputs.size() != kOneHotInputNum) { - MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum; + // indices, depth, on_off_value(contain 2 values); + if (inputs.size() != kOneHotInputNum && inputs.size() != kOneHotInputNumOpt) { + MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum << " or " << + kOneHotInputNumOpt; return RET_ERROR; } auto depth_tensor = inputs.at(1); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc index 3ad85db272..d48013ac62 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc @@ -43,7 +43,7 @@ int SpaceToBatchCPUKernel::ReSize() { MS_ASSERT(input_tensor); auto 
output_tensor = out_tensors_.at(0); MS_ASSERT(output_tensor); - MS_ASSERT(param); + MS_ASSERT(param_); for (size_t i = 0; i < DIMENSION_4D; i++) { param_->input_shape_[i] = input_tensor->shape().at(i); param_->output_shape_[i] = output_tensor->shape().at(i); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc index 25321322ab..658b0eab03 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc @@ -34,15 +34,18 @@ int SqueezeCPUKernel::ReSize() { return RET_OK; } int SqueezeCPUKernel::Run() { mindspore::lite::STATUS ret = RET_ERROR; size_t data_size = in_tensors_.front()->Size(); - MS_ASSERT(input_ptr); - MS_ASSERT(output_ptr); + if (in_tensors_.front()->data_type() == kNumberTypeInt32) { auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->MutableData()); auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData()); + MS_ASSERT(input_ptr); + MS_ASSERT(output_ptr); ret = DoSqueezeInt32(input_ptr, output_ptr, data_size); } else { auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->MutableData()); auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->MutableData()); + MS_ASSERT(input_ptr); + MS_ASSERT(output_ptr); ret = DoSqueeze(input_ptr, output_ptr, data_size); } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc index 7cbbcda3ae..866a539899 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc @@ -61,7 +61,7 @@ int SqueezeInt8CPUKernel::Init() { return RET_ERROR; } auto in_quant_args = in_tensors_.front()->quant_params(); - MS_ASSERT(quant_args.size() > 0); + MS_ASSERT(in_quant_args.size() > 0); quant_squeeze_param_->in_quant_args_->scale_ = in_quant_args.front().scale; quant_squeeze_param_->in_quant_args_->zp_ = 
in_quant_args.front().zeroPoint;