|
|
|
@ -34,34 +34,34 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The input X must be a 2D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "The input Y must be a 2D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(weight_dims.size(), 3,
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input X must be a 2D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input Y must be a 2D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL,
|
|
|
|
|
"The input Weight must be a 3D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_GT(weight_dims[0], 0,
|
|
|
|
|
PADDLE_ENFORCE(weight_dims[0],
|
|
|
|
|
"The first dimension of Weight must be larger than 0.");
|
|
|
|
|
PADDLE_ENFORCE_GT(weight_dims[1], 0,
|
|
|
|
|
PADDLE_ENFORCE(weight_dims[1],
|
|
|
|
|
"The second dimension of Weight must be larger than 0.");
|
|
|
|
|
PADDLE_ENFORCE_GT(weight_dims[2], 0,
|
|
|
|
|
PADDLE_ENFORCE(weight_dims[2],
|
|
|
|
|
"The third dimension of Weight must be larger than 0.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
|
|
|
|
|
"The first dimension(batch_size) of X must be "
|
|
|
|
|
"equal with the first dimension of the Y.");
|
|
|
|
|
"equal to the first dimension of the Y.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
|
|
|
|
|
"The second dimension of X must be equal with the second "
|
|
|
|
|
"The second dimension of X must be equal to the second "
|
|
|
|
|
"dimension of the Weight.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2],
|
|
|
|
|
"The second dimension of Y must be equal with the third "
|
|
|
|
|
"The second dimension of Y must be equal to the third "
|
|
|
|
|
"dimension of the Weight.");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
auto bias_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dims.size(), 2,
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dims.size(), 2UL,
|
|
|
|
|
"The input Bias must have 2 dimensions.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dims[0], 1,
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dims[0], 1UL,
|
|
|
|
|
"The first dimention of input Bias must be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0],
|
|
|
|
|
"The second dimension of Bias must be equal with the "
|
|
|
|
|
"The second dimension of Bias must be equal to the "
|
|
|
|
|
"first dimension of the Weight.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -75,12 +75,12 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
BilinearTensorProductOpMaker(framework::OpProto* proto,
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The first input of BilinearTensorProduct op");
|
|
|
|
|
AddInput("Y", "The second input of BilinearTensorProduct op");
|
|
|
|
|
AddInput("Weight", "The input weight of BilinearTensorProduct op");
|
|
|
|
|
AddInput("Bias", "The input bias of BilinearTensorProduct op")
|
|
|
|
|
AddInput("X", "The first input of BilinearTensorProduct op.");
|
|
|
|
|
AddInput("Y", "The second input of BilinearTensorProduct op.");
|
|
|
|
|
AddInput("Weight", "The input weight of BilinearTensorProduct op.");
|
|
|
|
|
AddInput("Bias", "The input bias of BilinearTensorProduct op.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput("Out", "The output of BilinearTensorProduct op");
|
|
|
|
|
AddOutput("Out", "The output of BilinearTensorProduct op.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Bilinear Tensor Product operator.
|
|
|
|
|
Given input X and Y, a 3D tensor weight, and bias. Each column of the
|
|
|
|
@ -99,30 +99,32 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(Weight) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input (Out@GRAD) should not be null");
|
|
|
|
|
"Input (Out@GRAD) should not be null.");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims.size(), 2, "The Out@GRAD must be a 2D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims.size(), 2UL,
|
|
|
|
|
"The Out@GRAD must be a 2D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[0], out_dims[0],
|
|
|
|
|
"The first dimension(batch_size) of Out@GRAD must be equal with "
|
|
|
|
|
"the first dimension of the X.");
|
|
|
|
|
"The first dimension(batch_size) of Out@GRAD must be equal to "
|
|
|
|
|
"the first dimension of the Input(X).");
|
|
|
|
|
PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1],
|
|
|
|
|
"The second dimension of Out@GRAD must be equal with "
|
|
|
|
|
"the third dimension of the Weight.");
|
|
|
|
|
"The second dimension of Out@GRAD must be equal to "
|
|
|
|
|
"the third dimension of the Input(Weight).");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
auto bias_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1],
|
|
|
|
|
"The second dimension of Bias must be equal with "
|
|
|
|
|
"the second dimension of the Out@GRAD.");
|
|
|
|
|
"The second dimension of Out@GRAD must be equal to "
|
|
|
|
|
"the second dimension of the Input(Bias).");
|
|
|
|
|
auto bias_grad_name = framework::GradVarName("Bias");
|
|
|
|
|
if (ctx->HasOutput(bias_grad_name))
|
|
|
|
|
ctx->SetOutputDim(bias_grad_name, bias_dims);
|
|
|
|
|