fix reduce parser bug

pull/11402/head
cjh9368 4 years ago
parent 4f4fa5260c
commit 396aaacce4

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -90,9 +90,7 @@ int Select::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
output->set_shape(input->shape()); output->set_shape(input->shape());
output->set_format(input->format()); output->set_format(input->format());
auto data_type = input->data_type(); auto data_type = input->data_type();
if (data_type != kObjectTypeTensorType) { if (data_type == kObjectTypeTensorType) {
continue;
} else {
auto input_tensorlist = reinterpret_cast<TensorList *>(input); auto input_tensorlist = reinterpret_cast<TensorList *>(input);
auto output_tensorlist = reinterpret_cast<TensorList *>(output); auto output_tensorlist = reinterpret_cast<TensorList *>(output);
output_tensorlist->set_element_shape(input_tensorlist->element_shape()); output_tensorlist->set_element_shape(input_tensorlist->element_shape());

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

@ -22,18 +22,41 @@ namespace mindspore {
namespace lite { namespace lite {
PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &proto, PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &proto,
const caffe::LayerParameter &weight) { const caffe::LayerParameter &weight) {
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); auto attr = std::make_unique<schema::ReduceT>();
if (attr == nullptr) { if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed"; MS_LOG(ERROR) << "new op failed";
return nullptr; return nullptr;
} }
const caffe::PReLUParameter &pReluParam = proto.prelu_param(); attr->keepDims = false;
if (pReluParam.has_channel_shared()) {
attr->channelShared = pReluParam.channel_shared(); const caffe::ReductionParameter &reduce_param = proto.reduction_param();
if (reduce_param.has_operation()) {
if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_MEAN) {
attr->mode = schema::ReduceMode_ReduceMean;
} else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUM) {
attr->mode = schema::ReduceMode_ReduceSum;
} else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUMSQ) {
attr->mode = schema::ReduceMode_ReduceSumSquare;
} else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_ASUM) {
attr->mode = schema::ReduceMode_ReduceASum;
} else {
MS_LOG(ERROR) << "nsupported reduce mode: " << reduce_param.operation();
return nullptr;
}
} else {
attr->mode = schema::ReduceMode_ReduceSum;
}
std::vector<int32_t> axes;
if (reduce_param.has_axis()) {
axes.push_back(1);
axes.push_back(reduce_param.axis());
} else { } else {
attr->channelShared = false; axes.push_back(1);
axes.push_back(0);
} }
attr->axes = axes;
auto primitive = std::make_unique<schema::PrimitiveT>(); auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Reduce; primitive->value.type = schema::PrimitiveType_Reduce;

Loading…
Cancel
Save