Support generating code for grad_op (#21066)
* Add the definition of operation in fusion_group. * Use operations in OperationMap to detect fusion_group of elementwise pattern. * Add namespace fusion_group in code_generator. * Use operations recorded in OperationMap to generate code. * Remove implementation codes to .cc file. * Refine Operation and CodeGenerator to make it easier to generate code for grad_op. Refine the unittest for better reuse. * Avoid recording the template's keyword in a array. * Support the generating of code for grad_op and add unittest. test=develop * Remove replaced_element_in_order and use use number instead. test=developcustom_op_abi
parent
1cd6721873
commit
9091f8cdf9
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,115 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace fusion_group {
|
||||
|
||||
OperationMap* OperationMap::map = nullptr;
|
||||
|
||||
OperationMap::OperationMap() {
|
||||
InsertUnaryElementwiseOperations();
|
||||
InsertBinaryElementwiseOperations();
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) {
|
||||
std::unordered_set<std::string> res;
|
||||
for (auto& t : operations_) {
|
||||
if ((t.second.type == type) &&
|
||||
(num_operands < 0 || t.second.num_operands == num_operands)) {
|
||||
res.insert(t.first);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void OperationMap::Insert(int type, int num_operands, std::string op_type,
|
||||
std::string expr,
|
||||
std::vector<std::string> grad_exprs) {
|
||||
Operation op(type, num_operands, op_type, {expr});
|
||||
PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type);
|
||||
operations_[op_type] = op;
|
||||
|
||||
if (grad_exprs.size() > 0U) {
|
||||
std::string grad_op_type = op_type + "_grad";
|
||||
Operation grad_op(type, num_operands, grad_op_type, grad_exprs);
|
||||
PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.",
|
||||
grad_op_type);
|
||||
operations_[grad_op_type] = grad_op;
|
||||
}
|
||||
}
|
||||
|
||||
void OperationMap::InsertUnaryElementwiseOperations() {
|
||||
int type = 0;
|
||||
int num_oprands = 1;
|
||||
// For unary elementwise operations:
|
||||
// ${0} - x
|
||||
// ${1} - out
|
||||
// ${2} - dout
|
||||
|
||||
// relu:
|
||||
// out = f(x) = x > 0 ? x : 0
|
||||
// dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0)
|
||||
Insert(type, num_oprands, "relu", "real_max(${0}, 0)",
|
||||
{"${0} > 0 ? ${2} : 0"});
|
||||
// sigmoid:
|
||||
// out = f(x) = 1.0 / (1.0 + exp(-x))
|
||||
// dx = dout * out * (1 - out)
|
||||
Insert(type, num_oprands, "sigmoid", "1.0 / (1.0 + real_exp(- ${0}))",
|
||||
{"${2} * ${1} * (1.0 - ${1})"});
|
||||
// tanh:
|
||||
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
|
||||
// dx = dout * (1 - out * out)
|
||||
Insert(type, num_oprands, "tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
|
||||
{"${2} * (1.0 - ${1} * ${1})"});
|
||||
}
|
||||
|
||||
void OperationMap::InsertBinaryElementwiseOperations() {
|
||||
int type = 0;
|
||||
int num_oprands = 2;
|
||||
// For binary elementwise oprations:
|
||||
// ${0} - x
|
||||
// ${1} - y
|
||||
// ${2} - out
|
||||
// ${3} - dout
|
||||
|
||||
// elementwise_add:
|
||||
// out = x + y
|
||||
// dx = dout * 1
|
||||
// dy = dout * 1
|
||||
Insert(type, num_oprands, "elementwise_add", "${0} + ${1}", {"${3}", "${3}"});
|
||||
// elementwise_sub:
|
||||
// out = x - y
|
||||
// dx = dout * 1
|
||||
// dy = dout * (-1)
|
||||
Insert(type, num_oprands, "elementwise_sub", "${0} - ${1}",
|
||||
{"${3}", "- ${3}"});
|
||||
// elementwise_mul:
|
||||
// out = x * y
|
||||
// dx = dout * y
|
||||
// dy = dout * x
|
||||
Insert(type, num_oprands, "elementwise_mul", "${0} * ${1}",
|
||||
{"${3} * ${1}", "${3} * ${0}"});
|
||||
Insert(type, num_oprands, "elementwise_div", "${0} / ${1}", {});
|
||||
Insert(type, num_oprands, "elementwise_min", "real_min(${0}, ${1})", {});
|
||||
Insert(type, num_oprands, "elementwise_max", "real_max(${0}, ${1})", {});
|
||||
}
|
||||
|
||||
} // namespace fusion_group
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,99 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace fusion_group {
|
||||
|
||||
struct Operation {
|
||||
Operation() {}
|
||||
Operation(int t, int n, std::string o, std::vector<std::string> e)
|
||||
: type(t), num_operands(n), op_type(o), exprs(e) {}
|
||||
|
||||
bool IsGradOp() {
|
||||
std::string suffix = "_grad";
|
||||
return op_type.rfind(suffix) == (op_type.length() - suffix.length());
|
||||
}
|
||||
|
||||
bool IsValid() {
|
||||
if (!IsGradOp() && exprs.size() != 1U) {
|
||||
return false;
|
||||
}
|
||||
if (IsGradOp() && exprs.size() != static_cast<size_t>(num_operands)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int type;
|
||||
int num_operands;
|
||||
std::string op_type;
|
||||
std::vector<std::string> exprs;
|
||||
};
|
||||
|
||||
class OperationMap {
|
||||
public:
|
||||
OperationMap();
|
||||
|
||||
static OperationMap& Instance() {
|
||||
PADDLE_ENFORCE_NOT_NULL(map, "Need to initialize OperationMap first!");
|
||||
return *map;
|
||||
}
|
||||
|
||||
static OperationMap& Init() {
|
||||
if (map == nullptr) {
|
||||
map = new OperationMap();
|
||||
}
|
||||
return *map;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> Find(int type, int num_operands = -1);
|
||||
|
||||
bool Has(std::string op_type) {
|
||||
return operations_.find(op_type) != operations_.end();
|
||||
}
|
||||
|
||||
Operation& Get(std::string op_type) {
|
||||
auto iter = operations_.find(op_type);
|
||||
PADDLE_ENFORCE_NE(iter, operations_.end(),
|
||||
"Operation %s is not supported yet.", op_type);
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
private:
|
||||
void Insert(int type, int num_operands, std::string op_type, std::string expr,
|
||||
std::vector<std::string> grad_exprs);
|
||||
|
||||
void InsertUnaryElementwiseOperations();
|
||||
void InsertBinaryElementwiseOperations();
|
||||
|
||||
private:
|
||||
static OperationMap* map;
|
||||
std::unordered_map<std::string, Operation> operations_;
|
||||
DISABLE_COPY_AND_ASSIGN(OperationMap);
|
||||
};
|
||||
|
||||
} // namespace fusion_group
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue