|
|
|
@ -40,7 +40,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
|
|
|
|
|
"inputs must be same.");
|
|
|
|
|
|
|
|
|
|
int rank = framework::arity(x_dims);
|
|
|
|
|
PADDLE_ENFORCE(rank >= 2, "Tensor rank should be at least equal to 2.");
|
|
|
|
|
PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(x_dims) / x_dims[0],
|
|
|
|
|
framework::product(y_dims) / y_dims[0],
|
|
|
|
|
"Product of dimensions expcet the first dimension of "
|
|
|
|
@ -87,7 +87,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
void InferShape(const framework::InferShapeContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Gradient of Out should not be null");
|
|
|
|
|
// check out grad dimensions
|
|
|
|
|
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto y_dims = ctx.Input<Tensor>("Y")->dims();
|
|
|
|
|