diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 86bcaabed4..cb1bad7031 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -53,7 +53,7 @@ union PrimitiveType { Activation, Conv2D, FusedBatchNorm, - CaffeBatchNorm, + BatchNorm, BiasAdd, Pooling, DepthwiseConv2D, diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 2fc1024c5e..78278b0c3a 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -212,8 +212,8 @@ table Conv2DGradInput { spatial: int = 1; } -table CaffeBatchNorm { - epsilon: float; // eg. epsilon=0.001 +table BatchNorm { + epsilon: float = 0.00001; // eg. epsilon=0.001 } table BiasGrad { diff --git a/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc index 5f110a419d..d53f4ad266 100644 --- a/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc +++ b/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc @@ -37,7 +37,7 @@ constexpr const float POW_NUM = 0.5; bool IsBatchNode(const BaseRef &n) { if (utils::isa(n) || utils::isa(n)) { auto type = opt::GetCNodeType(n); - return type == schema::PrimitiveType_CaffeBatchNorm || type == schema::PrimitiveType_FusedBatchNorm; + return type == schema::PrimitiveType_BatchNorm || type == schema::PrimitiveType_FusedBatchNorm; } return false; } @@ -115,12 +115,12 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern AnfNodePtr bn_bias_node = nullptr; float eps = 0; auto primitiveT_value = GetValueNode>(bn_node->input(0)); - if (GetCNodeType(bn_node) == schema::PrimitiveType_CaffeBatchNorm) { + if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) { bn_mean_node = bn_node->input(kCaffeBNMeanIndex); bn_variance_node = bn_node->input(kCaffeBNVarIndex); CheckIfNodeIsParam(bn_mean_node); CheckIfNodeIsParam(bn_variance_node); - eps = primitiveT_value->GetPrimitiveT()->value.AsCaffeBatchNorm()->epsilon; + eps = 
primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon; } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) { bn_scale_node = bn_node->input(kTFBNScaleIndex); bn_bias_node = bn_node->input(kTFBNBiasIndex); diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc index 078bae0f41..eef3682043 100644 --- a/mindspore/lite/src/model_impl.cc +++ b/mindspore/lite/src/model_impl.cc @@ -90,8 +90,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { return new lite::DepthwiseConv2D(const_cast(srcPrim)); case schema::PrimitiveType_FusedBatchNorm: return new lite::FusedBatchNorm(const_cast(srcPrim)); - case schema::PrimitiveType_CaffeBatchNorm: - return new lite::CaffeBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_BatchNorm: + return new lite::BatchNorm(const_cast(srcPrim)); case schema::PrimitiveType_FullConnection: return new lite::FullConnection(const_cast(srcPrim)); case schema::PrimitiveType_Power: diff --git a/mindspore/lite/src/ops/ops.cc b/mindspore/lite/src/ops/ops.cc index 85c20267ee..06da5561f5 100644 --- a/mindspore/lite/src/ops/ops.cc +++ b/mindspore/lite/src/ops/ops.cc @@ -39,8 +39,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { return new lite::DepthwiseConv2D(const_cast(primitive)); case schema::PrimitiveType_FusedBatchNorm: return new lite::FusedBatchNorm(const_cast(primitive)); - case schema::PrimitiveType_CaffeBatchNorm: - return new lite::CaffeBatchNorm(const_cast(primitive)); + case schema::PrimitiveType_BatchNorm: + return new lite::BatchNorm(const_cast(primitive)); case schema::PrimitiveType_FullConnection: return new lite::FullConnection(const_cast(primitive)); case schema::PrimitiveType_Power: diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h index 3942dd854f..f199b93c5b 100644 --- a/mindspore/lite/src/ops/ops.h +++ b/mindspore/lite/src/ops/ops.h @@ -90,10 +90,10 @@ class Pooling : public Primitive { int 
pad_r_ = 0; }; -class CaffeBatchNorm : public Primitive { +class BatchNorm : public Primitive { public: - explicit CaffeBatchNorm(schema::Primitive *primitive) : Primitive(primitive) {} - const schema::CaffeBatchNorm *GetAttribute() const { return this->primitive->value_as_CaffeBatchNorm(); } + explicit BatchNorm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BatchNorm *GetAttribute() const { return this->primitive->value_as_BatchNorm(); } }; class FusedBatchNorm : public Primitive { diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 521ae28867..a85682d94e 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -39,6 +39,7 @@ #include "src/runtime/kernel/arm/opclib/fp32/activation.h" #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" #include "src/runtime/kernel/arm/opclib/fused_batchnorm.h" +#include "src/runtime/kernel/arm/opclib/fp32/batchnorm.h" #include "src/runtime/kernel/arm/opclib/power.h" #include "src/runtime/kernel/arm/opclib/fp32/range.h" #include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" @@ -70,6 +71,18 @@ #include "src/runtime/kernel/arm/opclib/fp32/lstm.h" namespace mindspore::kernel { +OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) { + BatchNormParameter *batch_norm_param = new (std::nothrow) BatchNormParameter(); + if (batch_norm_param == nullptr) { + MS_LOG(ERROR) << "new BatchNormParameter failed."; + return nullptr; + } + batch_norm_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_BatchNorm(); + batch_norm_param->epsilon_ = param->epsilon(); + return reinterpret_cast(batch_norm_param); +} + OpParameter *PopulateFillParameter(const lite::Primitive *primitive) { auto param = primitive->Value()->value_as_Fill(); FillParameter *fill_param = new (std::nothrow) FillParameter(); @@ -1190,6 +1203,7 @@ 
PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_DeDepthwiseConv2D] = PopulateDeconvDwParameter; populate_parameter_funcs_[schema::PrimitiveType_DeConv2D] = PopulateDeconvParameter; populate_parameter_funcs_[schema::PrimitiveType_FusedBatchNorm] = PopulateFusedBatchNorm; + populate_parameter_funcs_[schema::PrimitiveType_BatchNorm] = PopulateBatchNorm; populate_parameter_funcs_[schema::PrimitiveType_FullConnection] = PopulateFullconnectionParameter; populate_parameter_funcs_[schema::PrimitiveType_Power] = PopulatePowerParameter; populate_parameter_funcs_[schema::PrimitiveType_LocalResponseNormalization] = PopulateLocalResponseNormParameter; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc new file mode 100644 index 0000000000..063fd4f6aa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/runtime/kernel/arm/fp32/batchnorm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BatchNorm; + +namespace mindspore::kernel { +int BatchnormCPUKernel::Init() { return RET_OK; } + +int BatchnormCPUKernel::ReSize() { return RET_OK; } + +int BatchnormCPUKernel::DoExecute(int tid) { + int count = MSMIN(thread_unit_, units_ - tid * thread_unit_); + if (count <= 0) { + return RET_OK; + } + int offset = tid * thread_unit_ * channel_; + BatchNorm(in_addr_ + offset, mean_addr_, var_addr_, count, channel_, batchnorm_param_->epsilon_, out_addr_ + offset); + return RET_OK; +} + +int BatchNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoExecute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "BatchnormRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int BatchnormCPUKernel::Run() { + in_addr_ = reinterpret_cast(inputs_.at(0)->Data()); + mean_addr_ = reinterpret_cast(inputs_.at(1)->Data()); + var_addr_ = reinterpret_cast(inputs_.at(2)->Data()); + out_addr_ = reinterpret_cast(outputs_.at(0)->Data()); + auto input_shapes = inputs_[0]->shape(); + channel_ = input_shapes[3]; + units_ = 1; + for (int i = 0; i < 3; i++) { + units_ *= input_shapes[i]; + } + thread_count_ = MSMIN(thread_count_, units_); + thread_unit_ = UP_DIV(units_, thread_count_); + int ret = LiteBackendParallelLaunch(BatchNormRun, this, thread_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuBatchnormKernelCreator(const std::vector &inputs, + 
const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_BatchNorm); + auto *kernel = new (std::nothrow) BatchnormCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BatchNormCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchNorm, CpuBatchnormKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h new file mode 100644 index 0000000000..2aadad2c09 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCHNORM_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/fp32/batchnorm.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class BatchnormCPUKernel : public LiteKernel { + public: + BatchnormCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + batchnorm_param_ = reinterpret_cast(parameter); + } + ~BatchnormCPUKernel() override { delete batchnorm_param_; } + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tid); + + private: + int thread_count_; + int thread_unit_; + int units_; + int channel_; + float *in_addr_; + float *mean_addr_; + float *var_addr_; + float *out_addr_; + const Context *ctx_; + BatchNormParameter *batchnorm_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCHNORM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.cc new file mode 100644 index 0000000000..57430f7c3e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/batchnorm.h" + +void BatchNorm(const float *input_ptr, const float *mean_ptr, const float *variance_ptr, int units, int channel, + float epsilon, float *output_ptr) { + for (int u = 0; u < units; u++) { + for (int c = 0; c < channel; c++) { + auto variance_sqrt = sqrt(variance_ptr[c] + epsilon); + output_ptr[u * channel + c] = (input_ptr[u * channel + c] - mean_ptr[c]) / variance_sqrt; + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.h new file mode 100644 index 0000000000..ae3feb174c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCHNORM_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct BatchNormParameter { + OpParameter op_parameter_; + float epsilon_; +}; + +void BatchNorm(const float *input_ptr, const float *mean_ptr, const float *variance_ptr, int units, int channel, + float epsilon, float *output_ptr); + +#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCHNORM_H_ diff --git a/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc index 1906e71b14..e2ce0d9e90 100644 --- a/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc +++ b/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc @@ -96,8 +96,8 @@ MetaGraphTptr BuildCaffeGraph(schema::PrimitiveType conv_type) { bn_node->inputIndex = {2, 3, 4}; bn_node->outputIndex = {5}; bn_node->primitive = std::make_unique(); - bn_node->primitive->value.type = schema::PrimitiveType_CaffeBatchNorm; - auto prim2 = new schema::CaffeBatchNormT; + bn_node->primitive->value.type = schema::PrimitiveType_BatchNorm; + auto prim2 = new schema::BatchNormT; bn_node->primitive->value.value = prim2; bn_node->name = "bn"; meta_graph->nodes.emplace_back(std::move(bn_node)); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc new file mode 100644 index 0000000000..8f48b472fe --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batchnorm.h" +#include "mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/common/file_utils.h" + +namespace mindspore { + +class TestBatchnormFp32 : public mindspore::Common { + public: + TestBatchnormFp32() {} +}; + +TEST_F(TestBatchnormFp32, BNTest) { + std::vector in_data = {0.0669681, 0.959215, 0.252686, 0.613594, 0.811776, 0.139469, 0.322848, 0.118354, + 0.082978, 0.399467, 0.961267, 0.0247456, 0.0714259, 0.0791484, 0.0648625, 0.561612, + 0.412069, 0.311492, 0.46109, 0.377125, 0.369283, 0.0332446, 0.696142, 0.715973, + 0.525524, 0.477265, 0.0336351, 0.751577, 0.377548, 0.964603, 0.0196834, 0.174865}; + std::vector in_data1 = {0.855446, 0.821765, 0.281008, 0.0798653, 0.22294, 0.793782, 0.963222, 0.17851, + 0.667549, 0.274381, 0.592842, 0.216552, 0.190274, 0.237873, 0.610063, 0.307559, + 0.830007, 0.760957, 0.583265, 0.763793, 0.456372, 0.391378, 0.547915, 0.862198, + 0.510794, 0.826776, 0.515894, 0.30071, 0.404987, 0.184773}; + std::vector in_data2 = {0.712438, 0.4927, 0.078419, 0.310429, 0.546871, 0.0667141, 0.874321, 0.0265647, + 0.685165, 0.732586, 0.952889, 0.506402, 0.540784, 0.131119, 0.357713, 0.678992, + 0.960839, 0.340706, 0.697678, 0.398146, 0.313321, 0.6485, 0.739153, 0.00190134, + 0.536842, 0.996873, 0.445276, 0.371212, 0.420397, 0.0930115}; + 
std::vector in_data3(32, 1); + std::vector in_data4(32, 0); + std::vector inputs_tensor; + std::vector outputs_tensor; + + BatchNormParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_BatchNorm; + op_param.epsilon_ = 0.001f; + + std::vector in_shape = {1, 2, 4, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + lite::tensor::Tensor input2_tensor; + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + inputs_tensor.push_back(&input2_tensor); + input0_tensor.SetData(in_data.data()); + input1_tensor.SetData(in_data1.data()); + input2_tensor.SetData(in_data2.data()); + input0_tensor.set_shape(in_shape); + + std::vector output(32); + std::vector corr_out(32); + std::vector output_shape = {1, 2, 4, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_BatchNorm}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 7; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + FusedBatchNorm(in_data.data(), in_data3.data(), in_data4.data(), in_data1.data(), in_data2.data(), in_shape.data(), + 0.001f, corr_out.data()); + + printf("==================output data=================\n"); + for (int i = 0; i < 1 * 28; i++) { + std::cout << output[i] << " ,"; + } + std::cout << std::endl; + CompareOutputData(output.data(), corr_out.data(), 32, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + input2_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git 
a/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc index f8e20ef1e1..3ab24737b4 100644 --- a/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc @@ -50,7 +50,7 @@ STATUS ConvBNFusionPass::DefinePattern() { convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; auto bnOp = std::make_shared(); bnOp->id = DST_NAME; - bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_CaffeBatchNorm}; + bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm}; bnOp->left = convOp; std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvBatchNormFusion")); @@ -208,8 +208,8 @@ STATUS ConvBNFusionPass::GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr MS_ASSERT(bnNode != nullptr); if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) { eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon; - } else if (bnNode->primitive->value.type == schema::PrimitiveType_CaffeBatchNorm) { - eps = bnNode->primitive->value.AsCaffeBatchNorm()->epsilon; + } else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) { + eps = bnNode->primitive->value.AsBatchNorm()->epsilon; } else { MS_LOG(ERROR) << "match pattern has error, " << bnNode->name.c_str() << " not BatchNorm node"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc index 2e8e78fa37..7b22fd2d95 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc @@ -28,13 +28,11 @@ static const int CAFFE_BATCHNORMAL_TOP_SIZE = 1; namespace mindspore { namespace lite { using STATUS = int; -STATUS 
CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, - const caffe::LayerParameter &weight, - schema::CNodeT *op, - std::vector *weightVec) { +STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { op->name = proto.name(); // caffe batch norm attr - std::unique_ptr attr(new FusedBatchNormT()); + std::unique_ptr attr(new schema::BatchNormT()); const caffe::BatchNormParameter batchNormParam = proto.batch_norm_param(); // check bottom size @@ -98,7 +96,7 @@ STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, weightVec->push_back(beta); op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + op->primitive->value.type = schema::PrimitiveType_BatchNorm; op->primitive->value.value = attr.release(); return RET_OK; @@ -107,5 +105,3 @@ STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, CaffeNodeRegistrar g_caffeBatchNormParser("BatchNorm", new CaffeBatchNormParser()); } // namespace lite } // namespace mindspore - - diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc old mode 100755 new mode 100644 index 8992d2da1f..7378bb1f99 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -61,7 +61,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const caffe::NetParameter weight; if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) { - MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; + MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; return nullptr; } @@ -88,14 +88,13 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const SetAllTensors(tensorCache, 
subGraphDef.get()); graph = move(subGraphDef); - ConvertCaffeBatchNorm(graph.get()); + // ConvertCaffeBatchNorm(graph.get()); return graph.release(); -// return Fb2Anf(graph.release()); + // return Fb2Anf(graph.release()); } -STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, - schema::CNodeT *op, +STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache) { for (int i = 0; i < layer.bottom_size(); i++) { int index = tensorCache->FindTensor(layer.bottom(i)); @@ -109,8 +108,7 @@ STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, return RET_OK; } -STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, - schema::CNodeT *op, +STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache) { for (int i = 0; i < layer.top_size(); i++) { std::unique_ptr msTensor(new schema::TensorT()); @@ -183,7 +181,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff } msTensor->nodeType = schema::NodeType_ValueNode; msTensor->refCount = 1; - msTensor->dataType = kNumberTypeFloat32; + msTensor->dataType = kNumberTypeFloat32; tensorCache->AddTensor(layer.top(0), msTensor.release(), GRAPH_INPUT); } else { if (skipedLayerType.find(layer.type()) != skipedLayerType.end()) { @@ -240,7 +238,7 @@ STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorC msTensor->dims.push_back(proto.input_dim(j)); } msTensor->refCount = schema::NodeType_ValueNode; - msTensor->dataType = kNumberTypeFloat32; + msTensor->dataType = kNumberTypeFloat32; tensorCache->AddTensor(proto.input(i), msTensor.release(), GRAPH_INPUT); } @@ -251,7 +249,7 @@ STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorC msTensor->dims.push_back(shape.dim(j)); } msTensor->refCount = schema::NodeType_ValueNode; - msTensor->dataType = kNumberTypeFloat32; + 
msTensor->dataType = kNumberTypeFloat32; tensorCache->AddTensor(proto.input(i), msTensor.release(), GRAPH_INPUT); } return RET_OK; @@ -279,7 +277,7 @@ void CaffeModelParser::ConvertCaffeBatchNorm(schema::MetaGraphT *meta_graph) { scaleTensor->dataType = TypeId::kNumberTypeFloat32; scaleTensor->data.resize(shapeSize * sizeof(float)); auto scaleData = reinterpret_cast(scaleTensor->data.data()); - for (size_t i = 0 ; i < shapeSize; i++) { + for (size_t i = 0; i < shapeSize; i++) { scaleData[i] = 1; } @@ -291,7 +289,7 @@ void CaffeModelParser::ConvertCaffeBatchNorm(schema::MetaGraphT *meta_graph) { biasTensor->dataType = TypeId::kNumberTypeInt32; biasTensor->data.resize(shapeSize * sizeof(int32_t)); auto biasData = reinterpret_cast(biasTensor->data.data()); - for (size_t i = 0 ; i < shapeSize; i++) { + for (size_t i = 0; i < shapeSize; i++) { biasData[i] = 0; } @@ -304,4 +302,3 @@ void CaffeModelParser::ConvertCaffeBatchNorm(schema::MetaGraphT *meta_graph) { } } // namespace lite } // namespace mindspore -