|
|
|
@ -32,17 +32,18 @@ class CosSimOp : public framework::OperatorWithKernel {
|
|
|
|
|
// shape check
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto y_dims = ctx.Input<Tensor>("Y")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
|
|
|
|
|
"All dimensions except 1st of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, y_dims.size()),
|
|
|
|
|
"All dimensions except the 1st of Input(X) and Input(Y) "
|
|
|
|
|
"must be equal.");
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"1st dimension of Input(Y) must be equal to Input(X) or "
|
|
|
|
|
"just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
"The 1st dimension of Input(Y) must be equal to Input(X) or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
|
|
|
|
|
// resize tensor
|
|
|
|
|
ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
|
|
|
|
@ -58,8 +59,14 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "The 1st input of cos_sim op.");
|
|
|
|
|
AddInput("Y", "The 2nd input of cos_sim op.");
|
|
|
|
|
AddOutput("Out", "The output of cos_sim op.");
|
|
|
|
|
AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
|
|
|
|
|
AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
|
|
|
|
|
AddOutput("XNorm",
|
|
|
|
|
"Norm of the first input, reduced along the 1st "
|
|
|
|
|
"dimension.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("YNorm",
|
|
|
|
|
"Norm of the second input, reduced along the 1st "
|
|
|
|
|
"dimension.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Cosine Similarity Operator.
|
|
|
|
@ -95,29 +102,32 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
// shape check
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto y_dims = ctx.Input<Tensor>("Y")->dims();
|
|
|
|
|
PADDLE_ENFORCE_GE(framework::arity(x_dims), framework::arity(y_dims),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
|
|
|
|
|
"All dimensions except 1st of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"1st dimension of Input(Y) must be equal to Input(X) or "
|
|
|
|
|
"just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(xnorm_dims, framework::make_ddim({x_dims[0], 1}),
|
|
|
|
|
"Shape of Input(XNorm) must be [X.Dim(0), 1].");
|
|
|
|
|
auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(ynorm_dims, framework::make_ddim({y_dims[0], 1}),
|
|
|
|
|
"Shape of Input(YNorm) must be [Y.Dim(0), 1].");
|
|
|
|
|
auto out_dims = ctx.Input<Tensor>("Out")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims, framework::make_ddim({x_dims[0], 1}),
|
|
|
|
|
"Shape of Input(Out) must be [X.Dim(0), 1].");
|
|
|
|
|
auto out_grad_dims =
|
|
|
|
|
ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_grad_dims, framework::make_ddim({x_dims[0], 1}),
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, y_dims.size()),
|
|
|
|
|
"All dimensions except the 1st of Input(X) and Input(Y) "
|
|
|
|
|
"must be equal.");
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"The 1st dimension of Input(Y) must be equal to Input(X) or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1}),
|
|
|
|
|
auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1}),
|
|
|
|
|
PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims,
|
|
|
|
|
"Shape of Input(XNorm) must be [X.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims,
|
|
|
|
|
"Shape of Input(YNorm) must be [Y.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims,
|
|
|
|
|
"Shape of Input(Out) must be [X.Dim(0), 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims,
|
|
|
|
|
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
|
|
|
|
|
|
|
|
|
|
// resize tensor
|
|
|
|
|