|
|
|
@ -26,16 +26,11 @@ class CosSimOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
// notnull check
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of CosSimOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"),
|
|
|
|
|
"Input(Y) of CosSimOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of CosSimOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
|
|
|
|
|
"Output(XNorm) of CosSimOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
|
|
|
|
|
"Output(YNorm) of CosSimOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CosSim");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "CosSim");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CosSim");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("XNorm"), "Output", "XNorm", "CosSim");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("YNorm"), "Output", "YNorm", "CosSim");
|
|
|
|
|
|
|
|
|
|
// shape check
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
@ -48,19 +43,28 @@ class CosSimOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), y_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Ranks of Input(X) [%s] and Input(Y) [%s] must be equal.", x_dims,
|
|
|
|
|
y_dims));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Rank of Input(X) %d must not be less than 2.", x_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, y_dims.size()),
|
|
|
|
|
"All dimensions except the 1st of Input(X) and Input(Y) "
|
|
|
|
|
"must be equal.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"All dimensions except the 1st of Input(X) [%s] and Input(Y) [%s]"
|
|
|
|
|
"must be equal.",
|
|
|
|
|
x_dims, y_dims));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"The 1st dimension of Input(Y) must be equal to Input(X) or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The 1st dimension of Input(Y) %d must be equal to Input(X) %d or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).",
|
|
|
|
|
y_dims[0], x_dims[0]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// resize tensor
|
|
|
|
@ -116,13 +120,13 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
// notnull check
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) must not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CosSimGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "CosSimGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("XNorm"), "Input", "XNorm", "CosSimGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("YNorm"), "Input", "YNorm", "CosSimGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CosSimGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "CosSimGrad");
|
|
|
|
|
|
|
|
|
|
// shape check
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
@ -133,26 +137,48 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Ranks of Input(X) %d and Input(Y) %d must be equal.",
|
|
|
|
|
x_dims.size(), y_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Rank of Input(X) %d must not be less than 2.", x_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, y_dims.size()),
|
|
|
|
|
"All dimensions except the 1st of Input(X) and Input(Y) "
|
|
|
|
|
"must be equal.");
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"The 1st dimension of Input(Y) must be equal to Input(X) or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"All dimensions except the 1st of Input(X) [%s] and Input(Y) [%s] "
|
|
|
|
|
"must be equal.",
|
|
|
|
|
x_dims, y_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
true, x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The 1st dimension of Input(Y) %d must be equal to Input(X) %d or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).",
|
|
|
|
|
y_dims[0], x_dims[0]));
|
|
|
|
|
auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1});
|
|
|
|
|
auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1});
|
|
|
|
|
PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims,
|
|
|
|
|
"Shape of Input(XNorm) must be [X.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims,
|
|
|
|
|
"Shape of Input(YNorm) must be [Y.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims,
|
|
|
|
|
"Shape of Input(Out) must be [X.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims,
|
|
|
|
|
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
xnorm_dims, target_xnorm_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(XNorm) [%s] must be (X.Dim(0), 1) - [%s]",
|
|
|
|
|
xnorm_dims, target_xnorm_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ynorm_dims, target_ynorm_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(YNorm) [%s] must be (Y.Dim(0), 1) - [%s]",
|
|
|
|
|
ynorm_dims, target_ynorm_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_dims, target_xnorm_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(Out) [%s] must be (X.Dim(0), 1) - [%s]", out_dims,
|
|
|
|
|
target_xnorm_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_grad_dims, target_xnorm_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(Out@Grad) [%s] must be (X.Dim(0), 1) - [%s]",
|
|
|
|
|
out_grad_dims, target_xnorm_dims));
|
|
|
|
|
|
|
|
|
|
// resize tensor
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|