|
|
|
@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Pointer to graph argument should not be NULL."));
|
|
|
|
|
std::unordered_map<std::string, std::string> original_output_names;
|
|
|
|
|
std::unordered_set<std::string> inplaced_vars;
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
|
|
|
|
|
"mkldnn_inplace"};
|
|
|
|
@ -95,6 +96,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
|
|
|
|
|
"be an input to multiple operators";
|
|
|
|
|
return;
|
|
|
|
|
} else {
|
|
|
|
|
// We will prevent in-place when
|
|
|
|
|
// input is used in other part of graph, unless it was a result of
|
|
|
|
|
// inplacing
|
|
|
|
|
// Allow to next op out reuse inpuit var, as this is the same chaing
|
|
|
|
|
if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
|
|
|
|
|
for (const Node* n : graph->Nodes()) {
|
|
|
|
|
if ((n->id() != current_op_in->id()) &&
|
|
|
|
|
(n->id() != next_op_out->id()) &&
|
|
|
|
|
(n->Name() == current_op_in->Name())) {
|
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
|
|
|
|
|
"graph ";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If this op was alrady inplaced in previous pass placements
|
|
|
|
@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
auto out_name = in_to_outs.begin()->second;
|
|
|
|
|
current_op->Op()->SetOutput(
|
|
|
|
|
out_name, std::vector<std::string>({current_op_out->Name()}));
|
|
|
|
|
// Record var name
|
|
|
|
|
inplaced_vars.insert(current_op_out->Name());
|
|
|
|
|
|
|
|
|
|
// If next op in a line is doing inplace
|
|
|
|
|
// then we need to update its output as well
|
|
|
|
|