fix code review

pull/9444/head
sunsuodong 4 years ago
parent 3da8cc98c5
commit 3d7e9b0c79

@ -660,7 +660,7 @@ table NetOutput {
}
table MatMul {
broadcast : bool = false;
broadcast : bool = false; // DEPRECATED
transposeA : bool = false;
transposeB : bool = false;
}

@ -189,7 +189,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
attr->channelMultiplier = channel_mutiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto input_node = inputs[kAnfPopulaterInputNumOne];
auto input_node = inputs.at(kAnfPopulaterInputNumOne);
MS_ASSERT(input_node != nullptr);
if (input_node->isa<Parameter>()) {
auto param_node = input_node->cast<ParameterPtr>();
@ -201,7 +201,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
attr->channelIn = dims[kAnfPopulaterInputNumOne];
attr->channelIn = dims.at(kAnfPopulaterInputNumOne);
}
}
} else if (input_node->isa<CNode>()) {

@ -128,7 +128,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
attr->channelMultiplier = channel_multiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto inputNode = inputs[kAnfPopulaterInputNumOne];
auto inputNode = inputs.at(kAnfPopulaterInputNumOne);
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
@ -139,7 +139,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
attr->channelIn = dims[kAnfPopulaterInputNumOne];
attr->channelIn = dims.at(kAnfPopulaterInputNumOne);
}
}
}

@ -42,9 +42,6 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
float beta = 1.0f;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "broadcast") {
attr->broadcast = static_cast<bool>(onnx_node_attr.i());
}
if (attribute_name == "transA") {
attr->transposeA = static_cast<bool>(onnx_node_attr.i());
} else if (attribute_name == "transB") {

@ -199,7 +199,6 @@ STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr,
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->broadcast = false;
attr->transposeA = false;
attr->transposeB = false;
op->primitive->value.type = schema::PrimitiveType_MatMul;

@ -36,7 +36,6 @@ PrimitiveC *TfliteMatMulParser::ParseLitePrimitive(const std::unique_ptr<tflite:
const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions();
attr->transposeA = tflite_attr->adj_x;
attr->transposeB = tflite_attr->adj_y;
attr->broadcast = false;
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();

Loading…
Cancel
Save