MKLDNN conv + elementwise_add fusion: UT for missing bias added. UTs refactored. Some minor changes in the pass

ce
Tomasz Patejko 7 years ago
parent 4e72ab411e
commit ce2464fd98

@ -68,8 +68,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate();
auto conv_op_has_bias = [](const Node& conv_op,
const Scope& scope) -> std::pair<bool, Node*> {
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
auto bias_input_names = conv_op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias");
@ -116,7 +115,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope());
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});

@ -1014,7 +1014,7 @@ PDNode *patterns::Conv::operator()() {
->AsOutput()
->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, /*bias_var,*/ filter_var});
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;

@ -617,7 +617,6 @@ struct Conv : public PatternBase {
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
PATTERN_DECL_NODE(conv_bias);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_residual_data);
PATTERN_DECL_NODE(conv_output);

Loading…
Cancel
Save