@@ -20,7 +20,11 @@ namespace paddle {
 namespace operators {
 
 bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
-  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  PADDLE_ENFORCE_EQ(
+      functor_list.size(), 2,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functor_list.size(), 2));
   static std::unordered_set<std::string> binary_fun = {
       "elementwise_add", "elementwise_mul", "elementwise_add_grad",
       "elementwise_mul_grad"};
@@ -28,7 +32,11 @@ bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
 }
 
 bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
-  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  PADDLE_ENFORCE_EQ(
+      functor_list.size(), 2,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functor_list.size(), 2));
   static std::unordered_set<std::string> InplaceOpSet = {"relu", "relu_grad"};
   bool is_in_place = false;
   for (auto &func_name : functor_list) {
@@ -38,7 +46,11 @@ bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
 }
 
 bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
-  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  PADDLE_ENFORCE_EQ(
+      functor_list.size(), 2,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functor_list.size(), 2));
   static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
   return binary_fun.count(functor_list[0]) != 0 ||
          binary_fun.count(functor_list[1]) != 0;
@@ -50,7 +62,11 @@ bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
  * out.
  */
 static bool IsSupportedCompound(const std::vector<std::string> &functors) {
-  PADDLE_ENFORCE_EQ(functors.size(), 2UL);
+  PADDLE_ENFORCE_EQ(
+      functors.size(), 2UL,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functors.size(), 2));
 
   static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
                                                       "sigmoid"};
@@ -63,11 +79,12 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
   } else if (binary_fun.count(functors[1])) {
     unary_fun_str = functors[0];
   } else {
-    PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
-                 functors[1]);
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s and %s are not included in fused_list.", functors[0], functors[1]));
   }
   PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
-                    "%s is not included in fused_list.", unary_fun_str);
+                    platform::errors::InvalidArgument(
+                        "%s is not included in fused_list.", unary_fun_str));
   return true;
 }
 
@@ -76,15 +93,18 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(
-        ctx->HasInput("X"),
-        "Input(X) of FusedElemwiseActivationOp op should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasInput("Y"),
-        "Input(Y) of FusedElemwiseActivationOp op should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasOutput("Out"),
-        "Output(Out) of FusedElemwiseActivationOp op should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::InvalidArgument(
+            "Input(X) of FusedElemwiseActivationOp op should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Y"), true,
+        platform::errors::InvalidArgument(
+            "Input(Y) of FusedElemwiseActivationOp op should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of FusedElemwiseActivationOp op should not be null."));
 
     auto x_dim = ctx->GetInputDim("X");
     auto y_dim = ctx->GetInputDim("Y");
@@ -97,9 +117,11 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
     std::string out_lod = bcast_y ? "X" : "Y";
 
     if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
-      PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
-                     "Output(IntermediateOut) of FusedElemwiseActivationOp "
-                     "should not be null.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasOutput("IntermediateOut"), true,
+          platform::errors::InvalidArgument(
+              "Output(IntermediateOut) of FusedElemwiseActivationOp "
+              "should not be null."));
 
       if (IsUnaryCompound(
               ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
@@ -139,7 +161,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
                       ctx.Input<framework::Tensor>("Y")->type(),
-                      "The element's type of input should be the same.");
+                      platform::errors::InvalidArgument(
+                          "The element's type of input should be the same."));
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
@@ -173,7 +196,10 @@ class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<std::string>>("functor_list",
                                       "The functors that should be fused.")
         .AddCustomChecker([&](const std::vector<std::string> &functor_list) {
-          PADDLE_ENFORCE(IsSupportedCompound(functor_list));
+          PADDLE_ENFORCE_EQ(
+              IsSupportedCompound(functor_list), true,
+              platform::errors::InvalidArgument(
+                  "the input functors should support compounding."));
         });
 
     AddComment(R"DOC(
@@ -266,18 +292,22 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@Grad) should not be null");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Out@Grad) should not be null."));
 
     auto functor_list =
         ctx->Attrs().Get<std::vector<std::string>>("functor_list");
 
     if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
-      PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
-                     "Input(IntermediateOut) should not be null");
+      PADDLE_ENFORCE_EQ(ctx->HasInput("IntermediateOut"), true,
+                        platform::errors::InvalidArgument(
+                            "Input(IntermediateOut) should not be null."));
     } else {
       if (!InputXCanBeAbsent(functor_list)) {
-        PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+        PADDLE_ENFORCE_EQ(
+            ctx->HasInput("X"), true,
+            platform::errors::InvalidArgument("Input(X) should not be null."));
       }
     }
 
@@ -292,9 +322,11 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
     } else {
       // Currently, only when Binary is elementwise_add or elementwise_sub,
       // the "X" could be absent.
-      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
-                     "Only when BinaryFunctor is elementwise_add, the 'X' "
-                     "could be absent.");
+      PADDLE_ENFORCE_EQ(
+          InputXCanBeAbsent(functor_list), true,
+          platform::errors::InvalidArgument(
+              "Only when BinaryFunctor is elementwise_add, the 'X' "
+              "could be absent."));
 
       // Node: If "X" is absence, the shape of Y should be a continuous
       // subsequence of X, otherwise, we could not infer the shape of dx.
@@ -306,7 +338,9 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
     }
 
     if (ctx->HasOutput(y_grad_name)) {
-      PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("Y"), true,
+          platform::errors::InvalidArgument("Input(Y) should not be null."));
       ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
       ctx->ShareLoD("Y", y_grad_name);
     }
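
Every hunk above applies the same pattern: a bare PADDLE_ENFORCE / PADDLE_THROW call carrying a raw printf-style message becomes PADDLE_ENFORCE_EQ / PADDLE_THROW with the message wrapped in platform::errors::InvalidArgument, so a failed check reports an error class in addition to the formatted text. The fragment below is a minimal sketch of that pattern, not code from this file: the variables rank and expected_rank are hypothetical, and the include is assumed to be the usual Paddle enforce header.

#include "paddle/fluid/platform/enforce.h"  // assumed home of PADDLE_ENFORCE_EQ and platform::errors

// Old style: condition plus a raw format string.
//   PADDLE_ENFORCE_EQ(rank, expected_rank,
//                     "Invalid rank %d, which should be equal to %d.",
//                     rank, expected_rank);

// New style: the same condition, with the message wrapped in a typed error
// (platform::errors::InvalidArgument), matching the hunks above.
PADDLE_ENFORCE_EQ(
    rank, expected_rank,
    platform::errors::InvalidArgument(
        "Invalid rank %d, which should be equal to %d.", rank, expected_rank));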