diff --git a/mindspore/lite/src/ops/populate/select_populate.cc b/mindspore/lite/src/ops/populate/select_populate.cc index c93674a9ae..efee92d035 100644 --- a/mindspore/lite/src/ops/populate/select_populate.cc +++ b/mindspore/lite/src/ops/populate/select_populate.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/lite/src/ops/select.cc b/mindspore/lite/src/ops/select.cc index c5f4bb4dda..1bcb18dd67 100644 --- a/mindspore/lite/src/ops/select.cc +++ b/mindspore/lite/src/ops/select.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,9 +90,7 @@ int Select::InferShape(std::vector inputs_, std::vector outp output->set_shape(input->shape()); output->set_format(input->format()); auto data_type = input->data_type(); - if (data_type != kObjectTypeTensorType) { - continue; - } else { + if (data_type == kObjectTypeTensorType) { auto input_tensorlist = reinterpret_cast(input); auto output_tensorlist = reinterpret_cast(output); output_tensorlist->set_element_shape(input_tensorlist->element_shape()); diff --git a/mindspore/lite/src/ops/select.h b/mindspore/lite/src/ops/select.h index ec19825b1d..02f8ec452d 100644 --- a/mindspore/lite/src/ops/select.h +++ b/mindspore/lite/src/ops/select.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/lite/src/runtime/kernel/arm/base/select.cc b/mindspore/lite/src/runtime/kernel/arm/base/select.cc index 8d393954e2..00deeea53c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/select.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/select.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc index 17f8fbf304..5e890fc357 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc @@ -22,18 +22,41 @@ namespace mindspore { namespace lite { PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return nullptr; } - const caffe::PReLUParameter &pReluParam = proto.prelu_param(); - if (pReluParam.has_channel_shared()) { - attr->channelShared = pReluParam.channel_shared(); + attr->keepDims = false; + + const caffe::ReductionParameter &reduce_param = proto.reduction_param(); + if (reduce_param.has_operation()) { + if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_MEAN) { + attr->mode = schema::ReduceMode_ReduceMean; + } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUM) { + attr->mode = schema::ReduceMode_ReduceSum; + } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUMSQ) { + attr->mode = schema::ReduceMode_ReduceSumSquare; + } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_ASUM) { + attr->mode = schema::ReduceMode_ReduceASum; + } else { + MS_LOG(ERROR) << "nsupported reduce mode: " << reduce_param.operation(); + return nullptr; + } + } else { + attr->mode = schema::ReduceMode_ReduceSum; + } + + std::vector axes; + if (reduce_param.has_axis()) { + axes.push_back(1); + axes.push_back(reduce_param.axis()); } else { - attr->channelShared = false; + axes.push_back(1); + axes.push_back(0); } + attr->axes = axes; auto primitive = std::make_unique(); primitive->value.type = schema::PrimitiveType_Reduce;