tf_parser-4

pull/11172/head
yefeng 4 years ago
parent b8a3a539bc
commit 257aa5fb97

@@ -45,6 +45,7 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodePr
  }
  if (iter->ints().data() == nullptr) {
    MS_LOG(ERROR) << "origin ints data in onnx is nullptr";
    delete[] param_data;
    return RET_NULL_PTR;
  }
  if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) {

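The hunk above plugs a memory leak: param_data must be freed before the early return that is taken when the source ints are missing. As a point of comparison only (not part of the commit), a minimal sketch that owns the scratch buffer with std::unique_ptr makes every early-return path leak-free by construction; std::memcpy stands in for the converter's memcpy_s, and CopyTensorData is a hypothetical helper name:

#include <cstddef>
#include <cstring>
#include <memory>

// Hypothetical sketch, not converter code: the unique_ptr releases the buffer
// automatically on every return path, so no explicit delete[] is needed.
std::unique_ptr<char[]> CopyTensorData(const void *src, std::size_t data_size) {
  if (src == nullptr || data_size == 0) {
    return nullptr;  // nothing to copy, and nothing was leaked
  }
  auto param_data = std::make_unique<char[]>(data_size);
  std::memcpy(param_data.get(), src, data_size);
  return param_data;
}
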
@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_argmax_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFArgMaxParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                             PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(DEBUG) << "TF ArgMaxParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or output_size is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::ArgMaxT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new ArgMaxT failed";
    return RET_NULL_PTR;
  }
  // The reduction axis is carried by the last input, which is expected to be a Const node.
  tensorflow::AttrValue attr_value;
  auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1));
  if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
    MS_LOG(ERROR) << "The attr value should be specified.";
    return RET_ERROR;
  }
  auto &axis_tensor = attr_value.tensor();
  attr->axis = axis_tensor.int_val(0);
  attr->outMaxValue = false;
  primitive->value.type = schema::PrimitiveType_ArgMax;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  auto status = AddOpInput(tf_op, 0, inputs);
  if (status != RET_OK) {
    return status;
  }
  return RET_OK;
}
TFNodeRegistrar g_tfArgMaxParser("ArgMax", new TFArgMaxParser());
} // namespace lite
} // namespace mindspore

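For context, the op this parser maps to simply returns the index of the maximum element along the chosen axis (outMaxValue = false means only indices are produced). A standalone illustration of that semantics for a row-major rows x cols matrix, independent of the converter code:

#include <vector>

// Minimal reference: ArgMax over axis 1 returns, for each row, the column
// index of its largest element; axis 0 does the same per column.
std::vector<int> ArgMax2D(const std::vector<float> &data, int rows, int cols, int axis) {
  std::vector<int> result;
  if (axis == 1) {
    for (int r = 0; r < rows; ++r) {
      int best = 0;
      for (int c = 1; c < cols; ++c) {
        if (data[r * cols + c] > data[r * cols + best]) best = c;
      }
      result.push_back(best);
    }
  } else {  // axis == 0
    for (int c = 0; c < cols; ++c) {
      int best = 0;
      for (int r = 1; r < rows; ++r) {
        if (data[r * cols + c] > data[best * cols + c]) best = r;
      }
      result.push_back(best);
    }
  }
  return result;
}
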
@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARGMAX_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARGMAX_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFArgMaxParser : public TFNodeParser {
 public:
  TFArgMaxParser() = default;
  ~TFArgMaxParser() override = default;
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARGMAX_PARSER_H_

@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_argmin_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFArgMinParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                             PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(DEBUG) << "TF ArgMinParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or output_size is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::ArgMinT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new ArgMinT failed";
    return RET_NULL_PTR;
  }
  // The reduction axis is carried by the last input, which is expected to be a Const node.
  tensorflow::AttrValue attr_value;
  auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1));
  if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
    MS_LOG(ERROR) << "The attr value should be specified.";
    return RET_ERROR;
  }
  auto &axis_tensor = attr_value.tensor();
  attr->axis = axis_tensor.int_val(0);
  attr->outMaxValue = false;
  primitive->value.type = schema::PrimitiveType_ArgMin;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  auto status = AddOpInput(tf_op, 0, inputs);
  if (status != RET_OK) {
    return status;
  }
  return RET_OK;
}
TFNodeRegistrar g_tfArgMinParser("ArgMin", new TFArgMinParser());
} // namespace lite
} // namespace mindspore

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARGMIN_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARGMIN_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFArgMinParser : public TFNodeParser {
 public:
  TFArgMinParser() = default;
  ~TFArgMinParser() override = default;
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARGMIN_PARSER_H_

@@ -0,0 +1,110 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_conv_depthwise_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/parser/tf/tf_util.h"
namespace mindspore {
namespace lite {
STATUS TFConvDepthwiseParser::Parse(const tensorflow::NodeDef &tf_op,
                                    const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                    PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(DEBUG) << "TF ConvDepthwiseParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or output_size is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
    return RET_NULL_PTR;
  }
  attr->format = TensorFlowUtils::ParseNodeFormat(tf_op);
  if (attr->format == schema::Format_NCHW) {
    MS_LOG(ERROR) << "TF DepthwiseConv2dNative with data_format=NCHW is not supported now";
    return RET_ERROR;
  }
  std::vector<int64_t> dilations(2);
  auto status = ParseDilations(tf_op, attr->format, &dilations);
  if (status != RET_OK) {
    return status;
  }
  attr->dilateH = dilations[0];
  attr->dilateW = dilations[1];
  std::vector<int64_t> strides(2);
  status = ParseStrides(tf_op, attr->format, &strides);
  if (status != RET_OK) {
    return status;
  }
  attr->strideH = strides[0];
  attr->strideW = strides[1];
  auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
  if (weight_node != nullptr) {
    std::vector<int64_t> kernels(4);
    status = ParseKernels(*weight_node, attr->format, &kernels);
    if (status != RET_OK) {
      return status;
    }
    attr->kernelH = kernels[0];
    attr->kernelW = kernels[1];
    attr->channelIn = kernels[2];
    attr->channelMultiplier = kernels[3];
  } else {
    // The weight input is not a const node here, so the kernel dims are filled in later.
    attr->kernelH = -1;
    attr->kernelW = -1;
    attr->channelIn = -1;
    attr->channelMultiplier = -1;
    MS_LOG(WARNING) << "parsing of kernelH/W and channelIn/Multiplier is delayed";
  }
  status = ParsePadMode(tf_op, &attr->padMode);
  if (status != RET_OK) {
    return status;
  }
  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  status = AddOpInput(tf_op, 0, inputs);
  if (status != RET_OK) {
    return status;
  }
  status = AddOpInput(tf_op, 1, inputs);  // weights
  return status;
}
TFNodeRegistrar g_tfConvDepthwiseParser("DepthwiseConv2dNative", new TFConvDepthwiseParser());
} // namespace lite
} // namespace mindspore

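For reference, the attr fields filled above mirror TensorFlow's DepthwiseConv2dNative filter layout [kernelH, kernelW, channelIn, channelMultiplier], where each input channel yields channelMultiplier output channels. A standalone sketch of that semantics (NHWC input with N = 1, VALID padding, stride 1, no dilation), independent of the converter code:

#include <vector>

// Reference depthwise convolution: output channel count is inC * mult, and the
// filter is indexed as [fh][fw][c][m], matching kernelH/kernelW/channelIn/channelMultiplier.
std::vector<float> DepthwiseConvValid(const std::vector<float> &in, int H, int W, int inC,
                                      const std::vector<float> &filter, int kH, int kW, int mult) {
  const int outH = H - kH + 1;
  const int outW = W - kW + 1;
  const int outC = inC * mult;
  std::vector<float> out(outH * outW * outC, 0.0f);
  for (int oh = 0; oh < outH; ++oh) {
    for (int ow = 0; ow < outW; ++ow) {
      for (int c = 0; c < inC; ++c) {
        for (int m = 0; m < mult; ++m) {
          float acc = 0.0f;
          for (int fh = 0; fh < kH; ++fh) {
            for (int fw = 0; fw < kW; ++fw) {
              acc += in[((oh + fh) * W + (ow + fw)) * inC + c] *
                     filter[((fh * kW + fw) * inC + c) * mult + m];
            }
          }
          out[(oh * outW + ow) * outC + c * mult + m] = acc;
        }
      }
    }
  }
  return out;
}
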
@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_DEPTHWISE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_DEPTHWISE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_conv_base_parser.h"
namespace mindspore {
namespace lite {
class TFConvDepthwiseParser : public TFConvBaseParser {
 public:
  TFConvDepthwiseParser() = default;
  ~TFConvDepthwiseParser() override = default;
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_DEPTHWISE_PARSER_H_

@@ -653,7 +653,7 @@ STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type,
    *filterW = oriDims.at(lite::HWCK_W);
    *filterC = oriDims.at(lite::HWCK_C);
    *filterK = oriDims.at(lite::HWCK_K);
  } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
  } else if (type == kHWKC2KCHW || type == kHWKC2CKHW || type == kHWKC2KHWC) {
    *filterH = oriDims.at(lite::HWKC_H);
    *filterW = oriDims.at(lite::HWKC_W);
    *filterK = oriDims.at(lite::HWKC_K);
@@ -693,7 +693,8 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3
    tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
  } else if (type == kKHWC2CHWK) {
    tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC || type == kHWCK2KHWC) {
  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC || type == kHWCK2KHWC ||
             type == kHWKC2KHWC) {
    tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
@@ -836,6 +837,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
        }
      } break;
      case kHWKC2KCHW:
      case kHWKC2KHWC:
      case kHWKC2CKHW: {
        for (int h = 0; h < filterH; ++h) {
          for (int w = 0; w < filterW; ++w) {
@@ -1009,6 +1011,20 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
        return RET_ERROR;
      }
      break;
    case schema::Format::Format_HWKC:
      if (data_type == kNumberTypeFloat32) {
        status = TransFilterFormat<float>(tensor, kHWKC2KHWC);
      } else if (data_type == kNumberTypeUInt8) {
        status = TransFilterFormat<uint8_t>(tensor, kHWKC2KHWC);
      } else if (data_type == kNumberTypeInt8) {
        status = TransFilterFormat<int8_t>(tensor, kHWKC2KHWC);
      } else if (data_type == kNumberTypeFloat16) {
        status = TransFilterFormat<float16>(tensor, kHWKC2KHWC);
      } else {
        MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
        return RET_ERROR;
      }
      break;
    case schema::Format::Format_KHWC:
      return RET_OK;
    default:

@@ -103,7 +103,8 @@ enum kTransFilterType {
  kCKHW2KCHW,
  kCHWK2KCHW,
  kKCHW2CKHW,  // 20
  kHWCK2KHWC
  kHWCK2KHWC,
  kHWKC2KHWC
};
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,

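The new kHWKC2KHWC entry, together with the dispatch added under Format_HWKC above, amounts to the index mapping dst[k][h][w][c] = src[h][w][k][c]. A minimal standalone sketch of that transpose (the converter's templated TransFilterData presumably walks the same indices for each supported element type):

#include <vector>

// Transpose a flat HWKC filter buffer into KHWC order.
template <typename T>
std::vector<T> TransHWKC2KHWC(const std::vector<T> &src, int H, int W, int K, int C) {
  std::vector<T> dst(src.size());
  for (int h = 0; h < H; ++h) {
    for (int w = 0; w < W; ++w) {
      for (int k = 0; k < K; ++k) {
        for (int c = 0; c < C; ++c) {
          dst[((k * H + h) * W + w) * C + c] = src[((h * W + w) * K + k) * C + c];
        }
      }
    }
  }
  return dst;
}
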
@@ -191,6 +191,8 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const AnfNodePtr &conv_node,
      if (op_type == schema::PrimitiveType_Conv2D) {
        param_value->set_format(schema::Format::Format_HWCK);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_HWKC);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
                      << ", node: " << conv_node->fullname_with_scope();
