|
|
|
|
@ -14,12 +14,15 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <algorithm> // for max
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/data_layout.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
@ -35,12 +38,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of elementwise op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"),
|
|
|
|
|
"Input(Y) of elementwise op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of elementwise op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of elementwise op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
|
|
|
|
|
"Input(Y) of elementwise op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of elementwise op should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->GetInputsVarType("Y").front() ==
|
|
|
|
|
@ -49,18 +52,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
|
|
|
|
|
|
|
|
|
|
if (ctx->GetInputsVarType("X").front() ==
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
auto x_dim = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dim = ctx->GetInputDim("Y");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dim.size(), y_dim.size(),
|
|
|
|
|
"ShapeError: the dimension of input X must greater than or equal to "
|
|
|
|
|
"the one of input Y. But received: the shape of input X = [%s], the "
|
|
|
|
|
"dimension of input X = %d, the shape of input Y = [%s], the "
|
|
|
|
|
"dimension of input Y = %d",
|
|
|
|
|
x_dim, x_dim.size(), y_dim, y_dim.size());
|
|
|
|
|
} else if (ctx->GetInputsVarType("X").front() ==
|
|
|
|
|
framework::proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
framework::proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("Y").size(), 1u,
|
|
|
|
|
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
|
|
|
|
|
@ -71,13 +63,31 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
|
|
|
|
|
"), Y must be scalar. But reveived the first dimension of Y = %s",
|
|
|
|
|
ctx->GetInputDim("Y")[0]);
|
|
|
|
|
} else {
|
|
|
|
|
} else if (ctx->GetInputsVarType("X").front() !=
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
|
|
|
|
|
ctx->GetInputsVarType("X").front());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
} else {
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
int max_dim = std::max(x_dims.size(), y_dims.size());
|
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
|
|
|
|
|
std::vector<int> x_dims_array(max_dim);
|
|
|
|
|
std::vector<int> y_dims_array(max_dim);
|
|
|
|
|
std::vector<int> out_dims_array(max_dim);
|
|
|
|
|
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
|
|
|
|
|
y_dims_array.data(), out_dims_array.data(),
|
|
|
|
|
max_dim, axis);
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
|
|
|
|
|
// to do
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
@ -207,26 +217,14 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
auto out_grad_name = framework::GradVarName("Out");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim(out_grad_name);
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dims.size(), y_dims.size(),
|
|
|
|
|
"ShapeError: the dimension of Out@GRAD must greater than or equal to "
|
|
|
|
|
"the one of input Y. But received: the shape of Out@GRAD = [%s], the "
|
|
|
|
|
"dimension of Out@GRAD = %d, the shape of input Y = [%s], the "
|
|
|
|
|
"dimension of of input Y = %d",
|
|
|
|
|
x_dims, x_dims.size(), y_dims, y_dims.size());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
|
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
auto y_grad_name = framework::GradVarName("Y");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
ctx->ShareDim(out_grad_name, /*->*/ x_grad_name);
|
|
|
|
|
ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
|
|
|
|
|
ctx->ShareDim("X", /*->*/ x_grad_name);
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ x_grad_name);
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasOutput(y_grad_name)) {
|
|
|
|
|
ctx->ShareDim("Y", /*->*/ y_grad_name);
|
|
|
|
|
@ -326,32 +324,6 @@ class ElementwiseOpDoubleGradWithoutDXDY
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// For Add, Sub op, the X, Out is not needed.
|
|
|
|
|
class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
|
|
|
|
|
public:
|
|
|
|
|
using operators::ElementwiseOpGrad::ElementwiseOpGrad;
|
|
|
|
|
using operators::ElementwiseOpGrad::GetExpectedKernelType;
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
|
|
|
|
|
ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
|
|
|
|
|
}
|
|
|
|
|
auto y_grad_name = framework::GradVarName("Y");
|
|
|
|
|
if (ctx->HasOutput(y_grad_name)) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("Y", /*->*/ y_grad_name);
|
|
|
|
|
ctx->ShareLoD("Y", /*->*/ y_grad_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ElemwiseGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
@ -372,13 +344,13 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
|
|
|
|
|
framework::GradVarName("X")});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "X",
|
|
|
|
|
"Y");
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
|
|
|
|
|
"Y", "DOut");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name) \
|
|
|
|
|
template <typename T> \
|
|
|
|
|
class kernel_type##GradMaker \
|
|
|
|
|
@ -390,6 +362,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
|
|
|
|
|
std::unique_ptr<T> Apply() const override { \
|
|
|
|
|
auto *op = new T(); \
|
|
|
|
|
op->SetType(#kernel_type "_grad"); \
|
|
|
|
|
op->SetInput("X", this->Input("X")); \
|
|
|
|
|
op->SetInput("Y", this->Input("Y")); \
|
|
|
|
|
op->SetInput(::paddle::framework::GradVarName("Out"), \
|
|
|
|
|
this->OutputGrad("Out")); \
|
|
|
|
|
@ -402,41 +375,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
|
|
|
|
|
class __ElemwiseOp##op_type##Maker__ \
|
|
|
|
|
: public ::paddle::operators::ElementwiseOpMaker { \
|
|
|
|
|
protected: \
|
|
|
|
|
virtual std::string GetName() const { return op_name; } \
|
|
|
|
|
virtual std::string GetEquation() const { return equation; } \
|
|
|
|
|
}; \
|
|
|
|
|
REGISTER_OPERATOR( \
|
|
|
|
|
op_type, ::paddle::operators::ElementwiseOp, \
|
|
|
|
|
__ElemwiseOp##op_type##Maker__, \
|
|
|
|
|
::paddle::operators::ElementwiseOpInferVarType, \
|
|
|
|
|
::paddle::framework::DefaultGradOpMaker<::paddle::framework::OpDesc, \
|
|
|
|
|
true>, \
|
|
|
|
|
::paddle::framework::DefaultGradOpMaker<::paddle::imperative::OpBase, \
|
|
|
|
|
true>); \
|
|
|
|
|
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
|
|
|
|
|
|
|
|
|
|
#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation) \
|
|
|
|
|
class __ElemwiseOp##op_type##Maker__ \
|
|
|
|
|
: public ::paddle::operators::ElementwiseOpMaker { \
|
|
|
|
|
protected: \
|
|
|
|
|
virtual std::string GetName() const { return op_name; } \
|
|
|
|
|
virtual std::string GetEquation() const { return equation; } \
|
|
|
|
|
}; \
|
|
|
|
|
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
|
|
|
|
|
__ElemwiseOp##op_type##Maker__, \
|
|
|
|
|
::paddle::operators::ElementwiseOpInferVarType, \
|
|
|
|
|
op_type##GradMaker<::paddle::framework::OpDesc>, \
|
|
|
|
|
op_type##GradMaker<::paddle::imperative::OpBase>, \
|
|
|
|
|
::paddle::operators::ElementwiseOpInplace); \
|
|
|
|
|
REGISTER_OPERATOR(op_type##_grad, \
|
|
|
|
|
::paddle::operators::ElementwiseOpExplicitGrad, \
|
|
|
|
|
::paddle::operators::ElementwiseGradOpInplace, \
|
|
|
|
|
::paddle::operators::ElementwiseGradNoBufVarsInference)
|
|
|
|
|
|
|
|
|
|
#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name) \
|
|
|
|
|
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
|
|
|
|
|
::paddle::operators::Elementwise##op_name##OpMaker, \
|
|
|
|
|
|