diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index d8d86b611b..dad863ece9 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -278,6 +278,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap MS_LOG(ERROR) << "Add fusion pass failed."; return nullptr; } + status = AddGraphPass(optimizer, config); if (status != RET_OK) { MS_LOG(ERROR) << "Add graph pass failed."; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.cc new file mode 100644 index 0000000000..48c5734e12 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_deconv_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/parser/tf/tf_util.h" + +namespace mindspore { +namespace lite { +STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF DeConvParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + attr->group = 1; + attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); + + std::vector dilations(2); + auto status = ParseDilations(tf_op, attr->format, &dilations); + if (status != RET_OK) { + return status; + } + attr->dilateH = dilations[0]; + attr->dilateW = dilations[1]; + + std::vector strides(2); + status = ParseStrides(tf_op, attr->format, &strides); + if (status != RET_OK) { + return status; + } + attr->strideH = strides[0]; + attr->strideW = strides[1]; + + auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (weight_node != nullptr) { + std::vector kernels(4); + status = ParseKernels(*weight_node, attr->format, &kernels); + if (status != RET_OK) { + return status; + } + attr->kernelH = kernels[0]; + attr->kernelW = kernels[1]; + attr->channelIn = kernels[2]; + attr->channelOut = kernels[3]; + } else { + attr->kernelH = -1; + attr->kernelW = -1; + attr->channelIn = -1; + attr->channelOut = -1; + MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed"; + } + + status = ParsePadMode(tf_op, &attr->padMode); + if (status != RET_OK) { + return status; + } + + primitive->value.type = schema::PrimitiveType_DeConv2D; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + status = AddOpInput(tf_op, 2, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); // weights + return status; +} +TFNodeRegistrar g_tf_deconv_parser("Conv2DBackpropInput", new TFDeconvParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.h new file mode 100644 index 0000000000..9d2fde610e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_conv_base_parser.h" +namespace mindspore { +namespace lite { +class TFDeconvParser : public TFConvBaseParser { + public: + TFDeconvParser() = default; + ~TFDeconvParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index 1b27adca31..164fb702ed 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -221,16 +221,22 @@ void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const Pa delete[] tmp_weight_data; return; } - auto group = primc->GetGroup(); - auto cin_group = weight_tensor->tensor_shape()[0] / group; - int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3]; - int cout_size = kernel_num * area_size; - for (int k = 0; k < cin_group; ++k) { - for (int i = 0; i < kernel_num; ++i) { - auto row_addr = weight_data + k * cout_size + i * area_size; - auto new_row_addr = tmp_weight_data + k * cout_size + i * area_size; - for (int j = 0; j < area_size; j++) { - new_row_addr[j] = row_addr[j] * trans_scale[i]; + if (this->fmk_type_ == lite::converter::FmkType_TF) { + for (int i = 0; i < weight_shape_size; i++) { + tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num]; + } + } else { + auto group = primc->GetGroup(); + auto cin_group = weight_tensor->tensor_shape()[0] / group; + int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3]; + int cout_size = kernel_num * area_size; + for (int k = 0; k < cin_group; ++k) { + for (int i = 0; i < kernel_num; ++i) { + auto row_addr = weight_data + k * cout_size + i * area_size; + auto new_row_addr = tmp_weight_data + k * cout_size + i * area_size; + for (int j = 0; j < area_size; j++) { + new_row_addr[j] = row_addr[j] * trans_scale[i]; + } } } } diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc index 0f38cc9fac..af8d1ad173 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -193,6 +193,8 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const AnfNodePtr &conv_node, param_value->set_format(schema::Format::Format_HWCK); } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { param_value->set_format(schema::Format::Format_HWKC); + } else if (op_type == schema::PrimitiveType_DeConv2D) { + param_value->set_format(schema::Format::Format_HWCK); } else { MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " << conv_node->fullname_with_scope(); @@ -248,7 +250,7 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { return false; } if (status != lite::RET_OK) { - MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope(); + MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope(); return false; } }