|
|
@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") "
|
|
|
|
VLOG(3) << "oneDNN Inplace op(" << current_op->id() << ") "
|
|
|
|
<< "Curr Node In: " << current_op_in->Name()
|
|
|
|
<< "Curr Node In: " << current_op_in->Name()
|
|
|
|
<< " Curr Node out: " << current_op_out->Name();
|
|
|
|
<< " Curr Node out: " << current_op_out->Name();
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") "
|
|
|
|
VLOG(3) << "oneDNN Inplace next op(" << next_op->id() << ") "
|
|
|
|
<< " next Node out: " << next_op_out->Name();
|
|
|
|
<< " next Node out: " << next_op_out->Name();
|
|
|
|
|
|
|
|
|
|
|
|
auto inputs = current_op->Op()->Inputs();
|
|
|
|
auto inputs = current_op->Op()->Inputs();
|
|
|
|
auto outputs = current_op->Op()->Outputs();
|
|
|
|
auto outputs = current_op->Op()->Outputs();
|
|
|
|
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
|
|
|
|
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
|
|
|
|
VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") "
|
|
|
|
VLOG(3) << "oneDNN InplaceInferer op(" << current_op->id() << ") "
|
|
|
|
<< in_to_outs.begin()->first << ": "
|
|
|
|
<< in_to_outs.begin()->first << ": "
|
|
|
|
<< inputs[in_to_outs.begin()->first][0] << " "
|
|
|
|
<< inputs[in_to_outs.begin()->first][0] << " "
|
|
|
|
<< in_to_outs.begin()->second << ": "
|
|
|
|
<< in_to_outs.begin()->second << ": "
|
|
|
@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
auto inplace_input_vec = inputs[in_to_outs.begin()->first];
|
|
|
|
auto inplace_input_vec = inputs[in_to_outs.begin()->first];
|
|
|
|
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
|
|
|
|
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
|
|
|
|
current_op_in->Name()) == inplace_input_vec.end()) {
|
|
|
|
current_op_in->Name()) == inplace_input_vec.end()) {
|
|
|
|
VLOG(3) << "DNNL in-place pass SKIP pattern ";
|
|
|
|
VLOG(3) << "oneDNN in-place pass SKIP pattern ";
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
// is used anywhere else apart from inplaced op
|
|
|
|
// is used anywhere else apart from inplaced op
|
|
|
|
auto input_consumers = current_op_in->outputs;
|
|
|
|
auto input_consumers = current_op_in->outputs;
|
|
|
|
if (input_consumers.size() > 1) {
|
|
|
|
if (input_consumers.size() > 1) {
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
|
|
|
|
VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
|
|
|
|
"be an input to multiple operators";
|
|
|
|
"be an input to multiple operators";
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
if ((n->id() != current_op_in->id()) &&
|
|
|
|
if ((n->id() != current_op_in->id()) &&
|
|
|
|
(n->id() != next_op_out->id()) &&
|
|
|
|
(n->id() != next_op_out->id()) &&
|
|
|
|
(n->Name() == current_op_in->Name())) {
|
|
|
|
(n->Name() == current_op_in->Name())) {
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
|
|
|
|
VLOG(3) << "oneDNN in-place pass FAIL var used in diffrent part of "
|
|
|
|
"graph ";
|
|
|
|
"graph ";
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
original_output_names[current_op->Name() + current_op_in->Name()] =
|
|
|
|
original_output_names[current_op->Name() + current_op_in->Name()] =
|
|
|
|
current_op_out->Name();
|
|
|
|
current_op_out->Name();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
VLOG(3) << "DNNL Inplace: Current op already inplaced! ";
|
|
|
|
VLOG(3) << "oneDNN Inplace: Current op already inplaced! ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// It may be that next op is reusing some of vars, we need to
|
|
|
|
// It may be that next op is reusing some of vars, we need to
|
|
|
@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
if ((n_op_infer_inplace == nullptr)) {
|
|
|
|
if ((n_op_infer_inplace == nullptr)) {
|
|
|
|
for (auto& m : n->outputs) {
|
|
|
|
for (auto& m : n->outputs) {
|
|
|
|
if (m->Name() == current_op_in->Name()) {
|
|
|
|
if (m->Name() == current_op_in->Name()) {
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
|
|
|
|
VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
|
|
|
|
"be an output to non-inplaced next op";
|
|
|
|
"be an output to non-inplaced next op";
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
(std::find(next_op_inplace_inputs.begin(),
|
|
|
|
(std::find(next_op_inplace_inputs.begin(),
|
|
|
|
next_op_inplace_inputs.end(),
|
|
|
|
next_op_inplace_inputs.end(),
|
|
|
|
original_name) != next_op_inplace_inputs.end())) {
|
|
|
|
original_name) != next_op_inplace_inputs.end())) {
|
|
|
|
VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its "
|
|
|
|
VLOG(3) << "oneDNN InPlace: Next Op is in-placed , updating its "
|
|
|
|
"input "
|
|
|
|
"input "
|
|
|
|
"and output var!";
|
|
|
|
"and output var!";
|
|
|
|
next_op->Op()->SetOutput(
|
|
|
|
next_op->Op()->SetOutput(
|
|
|
@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
next_op->Op()->RenameInput(original_name, current_op_out->Name());
|
|
|
|
next_op->Op()->RenameInput(original_name, current_op_out->Name());
|
|
|
|
|
|
|
|
|
|
|
|
found_inplace_count++;
|
|
|
|
found_inplace_count++;
|
|
|
|
VLOG(3) << "DNNL InPlace applied!";
|
|
|
|
VLOG(3) << "oneDNN InPlace applied!";
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
// TODO(jczaja): inplace pass does not influence ops inside block ops
|
|
|
|
|
|
|
|
// Predicate deciding whether the in-place pattern detector should be run on
// graph `g`. Scans every node of the graph and returns false as soon as it
// meets an op node whose name is one of the listed control-flow ops;
// returns true when no such op is present.
auto should_inplace = [&](Graph* g) {
|
|
|
|
|
|
|
|
// Op names that cause the whole graph to be rejected.
std::unordered_set<std::string> unwanted_ops(
|
|
|
|
|
|
|
|
{"conditional_block", "While", "while_loop"});
|
|
|
|
|
|
|
|
for (auto& node : g->Nodes()) {
|
|
|
|
|
|
|
|
// Only op nodes are checked; any other node is ignored.
if (node->IsOp() &&
|
|
|
|
|
|
|
|
unwanted_ops.find(node->Name()) != unwanted_ops.end()) {
|
|
|
|
|
|
|
|
// Log the first offending op found, then stop scanning.
VLOG(3) << "oneDNN InPlace FAILED: unsupported op: " << node->Name();
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// No unwanted op was found among the graph's nodes.
return true;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (should_inplace(graph)) gpd(graph, handler);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
} // namespace ir
|
|
|
|