|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
@ -30,6 +31,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(graph,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Pointer to graph argument should not be NULL."));
|
|
|
|
|
std::unordered_map<std::string, std::string> original_output_names;
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
|
|
|
|
|
"mkldnn_inplace"};
|
|
|
|
@ -40,72 +42,136 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(3) << "Start to handle MKL-DNN In-Place pass";
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op, inplace_to_be_op,
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(current_op, inplace_to_be_op, mkldnn_inplace);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(current_op_in, inplace_to_be_op_in,
|
|
|
|
|
mkldnn_inplace);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_in, inplace_to_be_op_in,
|
|
|
|
|
mkldnn_inplace);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_out, inplace_to_be_op_out,
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(current_op_out, inplace_to_be_op_out,
|
|
|
|
|
mkldnn_inplace);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, mkldnn_inplace);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(next_op_out, next_op_out, mkldnn_inplace);
|
|
|
|
|
|
|
|
|
|
if ((inplace_to_be_op->Op()->HasAttr("use_mkldnn") == false) ||
|
|
|
|
|
(boost::get<bool>(inplace_to_be_op->Op()->GetAttr("use_mkldnn")) ==
|
|
|
|
|
false)) {
|
|
|
|
|
if ((current_op->Op()->HasAttr("use_mkldnn") == false) ||
|
|
|
|
|
(boost::get<bool>(current_op->Op()->GetAttr("use_mkldnn")) == false)) {
|
|
|
|
|
VLOG(3) << "do not perform mkl-dnn inplace: use_mkldnn missing or set to "
|
|
|
|
|
"false";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto& infer_inplace = OpInfoMap::Instance()
|
|
|
|
|
.Get(inplace_to_be_op->Op()->Type())
|
|
|
|
|
.infer_inplace_;
|
|
|
|
|
auto& infer_inplace =
|
|
|
|
|
OpInfoMap::Instance().Get(current_op->Op()->Type()).infer_inplace_;
|
|
|
|
|
if (!infer_inplace) {
|
|
|
|
|
VLOG(3) << "do not perform mkl-dnn inplace: missing InplaceInferer";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Enable more ops
|
|
|
|
|
if (inplace_to_be_op->Op()->Type() != "softmax") {
|
|
|
|
|
VLOG(3)
|
|
|
|
|
<< "Curently works for softmax only. TODO(jczaja): support other ops";
|
|
|
|
|
VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") "
|
|
|
|
|
<< "Curr Node In: " << current_op_in->Name()
|
|
|
|
|
<< " Curr Node out: " << current_op_out->Name();
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") "
|
|
|
|
|
<< " next Node out: " << next_op_out->Name();
|
|
|
|
|
|
|
|
|
|
auto inputs = current_op->Op()->Inputs();
|
|
|
|
|
auto outputs = current_op->Op()->Outputs();
|
|
|
|
|
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
|
|
|
|
|
VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") "
|
|
|
|
|
<< in_to_outs.begin()->first << ": "
|
|
|
|
|
<< inputs[in_to_outs.begin()->first][0] << " "
|
|
|
|
|
<< in_to_outs.begin()->second << ": "
|
|
|
|
|
<< outputs[in_to_outs.begin()->second][0];
|
|
|
|
|
// If InferInplace pattern does not contain input node then skip
|
|
|
|
|
auto inplace_input_vec = inputs[in_to_outs.begin()->first];
|
|
|
|
|
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
|
|
|
|
|
current_op_in->Name()) == inplace_input_vec.end()) {
|
|
|
|
|
VLOG(3) << "DNNL in-place pass SKIP pattern ";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Iterate over all nodes that are ops
|
|
|
|
|
// and check if in-place to be var is part of inputs
|
|
|
|
|
// if positive then do not perform inplace
|
|
|
|
|
for (const Node* n : graph->Nodes()) {
|
|
|
|
|
if (n->IsOp()) {
|
|
|
|
|
// Avoid searchin in op that is to be inplace
|
|
|
|
|
if ((n->id() != inplace_to_be_op->id())) {
|
|
|
|
|
auto* op = n->Op();
|
|
|
|
|
auto inputs = op->Inputs();
|
|
|
|
|
auto in_place_input = inplace_to_be_op_in->Name();
|
|
|
|
|
for (auto& it : inputs) {
|
|
|
|
|
for (auto& var_name : it.second) {
|
|
|
|
|
if (var_name == in_place_input) {
|
|
|
|
|
VLOG(3) << "MKL-DNN in-place pass: in-place var cannot be an "
|
|
|
|
|
"input to more than one operator";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Checking if this particular node (to be inplaced, overwritten)
|
|
|
|
|
// is used anywhere else apart from inplaced op
|
|
|
|
|
auto input_consumers = current_op_in->outputs;
|
|
|
|
|
if (input_consumers.size() > 1) {
|
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
|
|
|
|
|
"be an input to multiple operators";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If this op was alrady inplaced in previous pass placements
|
|
|
|
|
// then we need to update input of next op
|
|
|
|
|
// but original name to be changed is gone, so we need to remember it
|
|
|
|
|
// on first time given op is to be inplaced
|
|
|
|
|
if (current_op_in->Name() != current_op_out->Name()) {
|
|
|
|
|
original_output_names[current_op->Name() + current_op_in->Name()] =
|
|
|
|
|
current_op_out->Name();
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "DNNL Inplace: Current op already inplaced! ";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// It may be that next op is reusing some of vars, we need to
|
|
|
|
|
// make sure that unwanted inplace is not created
|
|
|
|
|
// TODO(jczaja): Make UT for that one
|
|
|
|
|
for (auto& n : current_op_out->outputs) {
|
|
|
|
|
auto& n_op_infer_inplace =
|
|
|
|
|
OpInfoMap::Instance().Get(n->Op()->Type()).infer_inplace_;
|
|
|
|
|
if ((n_op_infer_inplace == nullptr)) {
|
|
|
|
|
for (auto& m : n->outputs) {
|
|
|
|
|
if (m->Name() == current_op_in->Name()) {
|
|
|
|
|
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
|
|
|
|
|
"be an output to non-inplaced next op";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto original_name = inplace_to_be_op_out->Name();
|
|
|
|
|
inplace_to_be_op_out->RenameVar(inplace_to_be_op_in->Name());
|
|
|
|
|
auto original_name =
|
|
|
|
|
original_output_names[current_op->Name() + current_op_in->Name()];
|
|
|
|
|
current_op_out->RenameVar(current_op_in->Name());
|
|
|
|
|
|
|
|
|
|
// Get mapping of input to output
|
|
|
|
|
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
|
|
|
|
|
// TODO(jczaja): Support more complex situations
|
|
|
|
|
auto out_name = in_to_outs.begin()->second;
|
|
|
|
|
inplace_to_be_op->Op()->SetOutput(
|
|
|
|
|
out_name, std::vector<std::string>({inplace_to_be_op_out->Name()}));
|
|
|
|
|
next_op->Op()->RenameInput(original_name, inplace_to_be_op_out->Name());
|
|
|
|
|
current_op->Op()->SetOutput(
|
|
|
|
|
out_name, std::vector<std::string>({current_op_out->Name()}));
|
|
|
|
|
|
|
|
|
|
// If next op in a line is doing inplace
|
|
|
|
|
// then we need to update its output as well
|
|
|
|
|
|
|
|
|
|
// Get inferer of next op
|
|
|
|
|
// If no inferer then we are done
|
|
|
|
|
auto& next_op_infer_inplace =
|
|
|
|
|
OpInfoMap::Instance().Get(next_op->Op()->Type()).infer_inplace_;
|
|
|
|
|
if (next_op_infer_inplace) {
|
|
|
|
|
auto in_to_outs = next_op_infer_inplace(false);
|
|
|
|
|
auto out_name = in_to_outs.begin()->second;
|
|
|
|
|
auto* op = next_op->Op();
|
|
|
|
|
auto inputs = op->Inputs();
|
|
|
|
|
auto outputs = op->Outputs();
|
|
|
|
|
// Check if in-place happened
|
|
|
|
|
// for variable we changed (original name)
|
|
|
|
|
// TODO(jczaja): make recursive propagation of inplace
|
|
|
|
|
auto next_op_inplace_inputs = inputs[in_to_outs.begin()->first];
|
|
|
|
|
if ((next_op_inplace_inputs == outputs[in_to_outs.begin()->second]) &&
|
|
|
|
|
(std::find(next_op_inplace_inputs.begin(),
|
|
|
|
|
next_op_inplace_inputs.end(),
|
|
|
|
|
original_name) != next_op_inplace_inputs.end())) {
|
|
|
|
|
VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its "
|
|
|
|
|
"input "
|
|
|
|
|
"and output var!";
|
|
|
|
|
next_op->Op()->SetOutput(
|
|
|
|
|
out_name, std::vector<std::string>({current_op_out->Name()}));
|
|
|
|
|
next_op_out->RenameVar(current_op_in->Name());
|
|
|
|
|
// Get ops that next_op_out is linked to and update their input
|
|
|
|
|
auto next_op_out_consumers = next_op_out->outputs; // Has to be ops
|
|
|
|
|
for (auto& c : next_op_out_consumers) {
|
|
|
|
|
c->Op()->RenameInput(original_name, current_op_out->Name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
next_op->Op()->RenameInput(original_name, current_op_out->Name());
|
|
|
|
|
|
|
|
|
|
found_inplace_count++;
|
|
|
|
|
VLOG(3) << "MKL-DNN InPlace applied!";
|
|
|
|
|
VLOG(3) << "DNNL InPlace applied!";
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|