fix bug in encoder weight quant

fix bug in Where parameter populate
pull/13134/head
mengyuanli 4 years ago
parent e149d64695
commit 3b359c4ad9

@@ -23,7 +23,6 @@ typedef struct GatherParameter {
   // Primitive parameter
   OpParameter op_parameter_;
   int axis_;
-  int quant_type_;
 } GatherParameter;
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_

@@ -25,8 +25,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
   const TensorC *indices = inputs[1];
   TensorC *output = outputs[0];
   output->data_type_ = input->data_type_;
-  GatherParameter *param = (GatherParameter *)parameter;
-  if (param->quant_type_ == QuantType_WeightQuant) {
+  if (parameter->quant_type_ == QuantType_WeightQuant) {
     output->data_type_ = kNumberTypeFloat32;
   }
   output->format_ = input->format_;

@@ -80,6 +80,7 @@ typedef struct OpParameter {
   bool infer_flag_;
   int type_;
   int thread_num_;
+  int quant_type_;
 } OpParameter;

 typedef struct QuantArg {
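
Taken together, the three hunks above move quant_type_ out of the Gather-specific parameter and into the shared OpParameter header, so GatherInferShape can read it through the generic OpParameter pointer without an op-specific cast. Below is a self-contained sketch of that embedded-header layout; the struct and field names with the "Sketch" suffix are illustrative stand-ins, not the real NNAcl definitions.

#include <cstdio>

// Stand-in for OpParameter: the header every op-specific parameter embeds as its first member.
struct OpParameterSketch {
  int type_;
  int quant_type_;  // the field this commit adds to the shared header
};

// Stand-in for GatherParameter: op-specific fields follow the embedded header.
struct GatherParameterSketch {
  OpParameterSketch op_parameter_;
  int axis_;
};

int main() {
  GatherParameterSketch gather{{/*type_=*/0, /*quant_type_=*/2}, /*axis_=*/0};
  // Generic code holds only the header pointer, yet can still test quant_type_,
  // which is why the GatherParameter cast in GatherInferShape is no longer needed.
  OpParameterSketch *generic = &gather.op_parameter_;
  std::printf("quant_type_ = %d\n", generic->quant_type_);
  return 0;
}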

@@ -49,7 +49,7 @@ int MindrtExecutor::Prepare(const std::vector<kernel::LiteKernel *> &kernels) {
     for (size_t j = 0; j < outTensorSize; j++) {
       auto data =
-        std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->in_tensors()[j], static_cast<int>(j));
+        std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->out_tensors()[j], static_cast<int>(j));
       outputData_.emplace_back(data);
     }
   }

@@ -14,19 +14,20 @@
  * limitations under the License.
  */
 #include "src/ops/populate/populate_register.h"
+#include "nnacl/where_parameter.h"
 namespace mindspore {
 namespace lite {
 namespace {
 OpParameter *PopulateWhereParameter(const void *prim) {
-  OpParameter *where_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
+  WhereParameter *where_parameter = reinterpret_cast<WhereParameter *>(malloc(sizeof(WhereParameter)));
   if (where_parameter == nullptr) {
     MS_LOG(ERROR) << "malloc Where parameter failed.";
     return nullptr;
   }
   memset(where_parameter, 0, sizeof(OpParameter));
   auto primitive = static_cast<const schema::Primitive *>(prim);
-  where_parameter->type_ = primitive->value_type();
+  where_parameter->op_parameter_.type_ = primitive->value_type();
   return reinterpret_cast<OpParameter *>(where_parameter);
 }
 } // namespace
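
The populate fix above sizes the allocation for the whole WhereParameter rather than only the OpParameter header, and writes the primitive type into the embedded header instead of the raw pointer. Below is a self-contained sketch of why the original sizeof(OpParameter) allocation is unsafe; the "Sketch" structs and the condition_num_ field are hypothetical stand-ins, not the real definitions.

#include <cstdlib>
#include <cstring>

struct OpParameterSketch {  // stand-in for OpParameter
  int type_;
};

struct WhereParameterSketch {       // stand-in for WhereParameter
  OpParameterSketch op_parameter_;  // embedded header, must stay the first member
  int condition_num_;               // hypothetical op-specific field
};

int main() {
  // Fixed pattern: allocate and zero the full op-specific struct, then fill the embedded header.
  // Allocating only sizeof(OpParameterSketch) would make any later access to condition_num_
  // read or write past the end of the buffer.
  auto *param = static_cast<WhereParameterSketch *>(std::malloc(sizeof(WhereParameterSketch)));
  if (param == nullptr) {
    return 1;
  }
  std::memset(param, 0, sizeof(WhereParameterSketch));
  param->op_parameter_.type_ = 94;  // placeholder for primitive->value_type()
  param->condition_num_ = 0;        // safe now that the allocation covers the whole struct
  std::free(param);
  return 0;
}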

@@ -135,6 +135,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_i
     MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << PrimitiveTypeName(GetPrimitiveType(primitive));
     return RET_ERROR;
   }
+  parameter->quant_type_ = node->quant_type_;
+  op_parameters_[node->output_indices_.at(0)] = parameter;
   parameter->infer_flag_ = !(*infer_shape_interrupt);
   auto ret = KernelInferShape(inputs, &outputs, parameter);

@@ -81,7 +81,7 @@ class QuantParamHolder : public Value {
   }
   void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
-    if (index > this->input_quant_param_.size()) {
+    if (index >= this->input_quant_param_.size()) {
       std::vector<schema::QuantParamT> place_quant(1);
       this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(),
                                       place_quant);
@@ -94,7 +94,7 @@ class QuantParamHolder : public Value {
   }
   void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
-    if (index > this->output_quant_param_.size()) {
+    if (index >= this->output_quant_param_.size()) {
       std::vector<schema::QuantParamT> place_quant(1);
       this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(),
                                        place_quant);
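
Both QuantParamHolder hunks fix the same off-by-one check: when index equals the current vector size, the old "index > size()" test skips the placeholder insertion and the subsequent write at that index lands out of range. A self-contained sketch mirroring the resize logic, with illustrative names rather than the real holder:

#include <cassert>
#include <vector>

int main() {
  std::vector<std::vector<int>> quant_params(2);  // stands in for input_quant_param_, size() == 2
  const size_t index = 2;                         // writing the one-past-the-end slot

  // Fixed condition: ">=" grows the vector first; the old ">" (2 > 2) would skip this branch.
  if (index >= quant_params.size()) {
    std::vector<int> place_quant(1);
    quant_params.insert(quant_params.end(), index + 1 - quant_params.size(), place_quant);
  }
  quant_params.at(index) = {42};  // in range after the insertion
  assert(quant_params.size() == 3);
  return 0;
}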
