From 3b359c4ad919aba276bcce0e2b2aa9f129c50403 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Thu, 11 Mar 2021 09:33:06 +0800 Subject: [PATCH] fix bug of encoder weight quant fix some bug of where parameter populate --- mindspore/lite/nnacl/gather_parameter.h | 1 - mindspore/lite/nnacl/infer/gather_infer.c | 3 +-- mindspore/lite/nnacl/op_base.h | 1 + mindspore/lite/src/mindrt_executor.cc | 2 +- mindspore/lite/src/ops/populate/where_populate.cc | 5 +++-- mindspore/lite/src/scheduler.cc | 2 ++ mindspore/lite/tools/converter/quant_param_holder.h | 4 ++-- 7 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mindspore/lite/nnacl/gather_parameter.h b/mindspore/lite/nnacl/gather_parameter.h index 0a2d907b6d..6ac16dcaab 100644 --- a/mindspore/lite/nnacl/gather_parameter.h +++ b/mindspore/lite/nnacl/gather_parameter.h @@ -23,7 +23,6 @@ typedef struct GatherParameter { // Primitive parameter OpParameter op_parameter_; int axis_; - int quant_type_; } GatherParameter; #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_ diff --git a/mindspore/lite/nnacl/infer/gather_infer.c b/mindspore/lite/nnacl/infer/gather_infer.c index 4bd9fb5a9f..229faf81d9 100644 --- a/mindspore/lite/nnacl/infer/gather_infer.c +++ b/mindspore/lite/nnacl/infer/gather_infer.c @@ -25,8 +25,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * const TensorC *indices = inputs[1]; TensorC *output = outputs[0]; output->data_type_ = input->data_type_; - GatherParameter *param = (GatherParameter *)parameter; - if (param->quant_type_ == QuantType_WeightQuant) { + if (parameter->quant_type_ == QuantType_WeightQuant) { output->data_type_ = kNumberTypeFloat32; } output->format_ = input->format_; diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h index a6dadef202..d7551c3f30 100644 --- a/mindspore/lite/nnacl/op_base.h +++ b/mindspore/lite/nnacl/op_base.h @@ -80,6 +80,7 @@ typedef struct OpParameter { bool infer_flag_; int type_; int thread_num_; + int quant_type_; } OpParameter; typedef struct QuantArg { diff --git a/mindspore/lite/src/mindrt_executor.cc b/mindspore/lite/src/mindrt_executor.cc index b2830e3ca5..732d3f1a39 100644 --- a/mindspore/lite/src/mindrt_executor.cc +++ b/mindspore/lite/src/mindrt_executor.cc @@ -49,7 +49,7 @@ int MindrtExecutor::Prepare(const std::vector &kernels) { for (size_t j = 0; j < outTensorSize; j++) { auto data = - std::make_shared>(opActors_[i]->GetAID(), kernels[i]->in_tensors()[j], static_cast(j)); + std::make_shared>(opActors_[i]->GetAID(), kernels[i]->out_tensors()[j], static_cast(j)); outputData_.emplace_back(data); } } diff --git a/mindspore/lite/src/ops/populate/where_populate.cc b/mindspore/lite/src/ops/populate/where_populate.cc index 5f1790f839..75f624efb6 100644 --- a/mindspore/lite/src/ops/populate/where_populate.cc +++ b/mindspore/lite/src/ops/populate/where_populate.cc @@ -14,19 +14,20 @@ * limitations under the License. */ #include "src/ops/populate/populate_register.h" +#include "nnacl/where_parameter.h" namespace mindspore { namespace lite { namespace { OpParameter *PopulateWhereParameter(const void *prim) { - OpParameter *where_parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + WhereParameter *where_parameter = reinterpret_cast(malloc(sizeof(WhereParameter))); if (where_parameter == nullptr) { MS_LOG(ERROR) << "malloc Where parameter failed."; return nullptr; } memset(where_parameter, 0, sizeof(OpParameter)); auto primitive = static_cast(prim); - where_parameter->type_ = primitive->value_type(); + where_parameter->op_parameter_.type_ = primitive->value_type(); return reinterpret_cast(where_parameter); } } // namespace diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index ca676f101e..1a4a672431 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -135,6 +135,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_i MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << PrimitiveTypeName(GetPrimitiveType(primitive)); return RET_ERROR; } + parameter->quant_type_ = node->quant_type_; + op_parameters_[node->output_indices_.at(0)] = parameter; parameter->infer_flag_ = !(*infer_shape_interrupt); auto ret = KernelInferShape(inputs, &outputs, parameter); diff --git a/mindspore/lite/tools/converter/quant_param_holder.h b/mindspore/lite/tools/converter/quant_param_holder.h index 830e3dc543..c368d9eed3 100644 --- a/mindspore/lite/tools/converter/quant_param_holder.h +++ b/mindspore/lite/tools/converter/quant_param_holder.h @@ -81,7 +81,7 @@ class QuantParamHolder : public Value { } void set_input_quant_param(const size_t &index, const std::vector &input_quant_param) { - if (index > this->input_quant_param_.size()) { + if (index >= this->input_quant_param_.size()) { std::vector place_quant(1); this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(), place_quant); @@ -94,7 +94,7 @@ class QuantParamHolder : public Value { } void set_output_quant_param(const size_t &index, const std::vector &output_quant_param) { - if (index > this->output_quant_param_.size()) { + if (index >= this->output_quant_param_.size()) { std::vector place_quant(1); this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(), place_quant);