|
|
|
@ -36,21 +36,22 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::arity(x_dims), framework::arity(y_dims),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) and Input(X) of SquaredL2DistanceOp should ",
|
|
|
|
|
"have same dimensions.",
|
|
|
|
|
"But received X's shape = [%s] and Y's shape = [%s],",
|
|
|
|
|
"the dimensions are %d and %d respectively", x_dims, y_dims,
|
|
|
|
|
framework::arity(x_dims), framework::arity(y_dims)));
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) and Input(X) of SquaredL2DistanceOp should "
|
|
|
|
|
"have same dimensions. "
|
|
|
|
|
"But received X's shape = [%s] and Y's shape = [%s], "
|
|
|
|
|
"the dimensions are %d and %d respectively",
|
|
|
|
|
x_dims, y_dims, framework::arity(x_dims),
|
|
|
|
|
framework::arity(y_dims)));
|
|
|
|
|
|
|
|
|
|
int rank = framework::arity(x_dims);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
rank, 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input dimensions of SquaredL2DistanceOp should be ", "at least 2.",
|
|
|
|
|
"But received shape = [%s] and dimension is %d.", x_dims, rank));
|
|
|
|
|
"Input dimensions of SquaredL2DistanceOp should be at least 2."
|
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
|
x_dims, rank));
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
|
|
|
|
@ -60,11 +61,12 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
product(x_dims) / x_dims[0], product(y_dims) / y_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) and Input(Y) of SquaredL2DistanceOp should ",
|
|
|
|
|
"have same dimensions.",
|
|
|
|
|
"But received X's shape = [%s] and Y's shape = [%s]",
|
|
|
|
|
", the products are %d and %d respectively", x_dims, y_dims,
|
|
|
|
|
product(x_dims) / x_dims[0], product(y_dims) / y_dims[0]));
|
|
|
|
|
"Input(X) and Input(Y) of SquaredL2DistanceOp should "
|
|
|
|
|
"have same dimensions."
|
|
|
|
|
"But received X's shape = [%s] and Y's shape = [%s]"
|
|
|
|
|
", the products are %d and %d respectively",
|
|
|
|
|
x_dims, y_dims, product(x_dims) / x_dims[0],
|
|
|
|
|
product(y_dims) / y_dims[0]));
|
|
|
|
|
}
|
|
|
|
|
check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) && (y_dims[0] <= 0 || x_dims[0] <= 0)) {
|
|
|
|
@ -74,11 +76,11 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
y_dims[0] == 1 || y_dims[0] == x_dims[0], true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"First dimension of Input(Y) of SquaredL2DistanceOp ",
|
|
|
|
|
"must be equal to 1", "or to first dimension of Input(X).",
|
|
|
|
|
"But received X's shape = [%s] and Y's shape = [%s],",
|
|
|
|
|
"the first dimensions are %d and %d respectively", x_dims, y_dims,
|
|
|
|
|
x_dims[0], y_dims[0]));
|
|
|
|
|
"First dimension of Input(Y) of SquaredL2DistanceOp "
|
|
|
|
|
"must be equal to 1 or to first dimension of Input(X)."
|
|
|
|
|
"But received X's shape = [%s] and Y's shape = [%s],"
|
|
|
|
|
"the first dimensions are %d and %d respectively",
|
|
|
|
|
x_dims, y_dims, x_dims[0], y_dims[0]));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]});
|
|
|
|
|
ctx->SetOutputDim("Out", {x_dims[0], 1});
|
|
|
|
@ -152,17 +154,18 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_dims[0], x_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"First dimension of output gradient and Input(X) ",
|
|
|
|
|
"of SquaredL2DistanceGradOp must be equal",
|
|
|
|
|
"But received X's shape = [%s] and grad's shape = [%s],",
|
|
|
|
|
"the first dimensions are %d and %d respectively", x_dims,
|
|
|
|
|
out_dims, x_dims[0], out_dims[0]));
|
|
|
|
|
"First dimension of output gradient and Input(X) "
|
|
|
|
|
"of SquaredL2DistanceGradOp must be equal "
|
|
|
|
|
"But received X's shape = [%s] and grad's shape = [%s], "
|
|
|
|
|
"the first dimensions are %d and %d respectively",
|
|
|
|
|
x_dims, out_dims, x_dims[0], out_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims[1], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Second dimension of output gradient of ",
|
|
|
|
|
"SquaredL2DistanceGradOp must be 1."
|
|
|
|
|
"But received grad's shape = [%s],",
|
|
|
|
|
"with first dimensions %d", out_dims, out_dims[1]));
|
|
|
|
|
"Second dimension of output gradient of "
|
|
|
|
|
"SquaredL2DistanceGradOp must be 1. "
|
|
|
|
|
"But received grad's shape = [%s], "
|
|
|
|
|
"with second dimension %d",
|
|
|
|
|
out_dims, out_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
auto y_grad_name = framework::GradVarName("Y");
|
|
|
|
|