|
|
|
@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, y_dims.size()),
|
|
|
|
|
"All dimensions except the 1st of Input(X) and Input(Y) "
|
|
|
|
|
"must be equal.");
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"The 1st dimension of Input(Y) must be equal to Input(X) or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Ranks of Input(X) and Input(Y) must be equal.");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Rank of Input(X) must not be less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(x_dims, 1, x_dims.size()),
|
|
|
|
|
framework::slice_ddim(y_dims, 1, y_dims.size()),
|
|
|
|
|
"All dimensions except the 1st of Input(X) and Input(Y) "
|
|
|
|
|
"must be equal.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
x_dims[0] == y_dims[0] || y_dims[0] == 1,
|
|
|
|
|
"The 1st dimension of Input(Y) must be equal to Input(X) or"
|
|
|
|
|
" just 1 (which will be broadcasted to match Input(X)).");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// resize tensor
|
|
|
|
|
ctx->SetOutputDim("Out", {x_dims[0], 1});
|
|
|
|
|