|
|
|
@ -34,18 +34,9 @@ class HuberLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
|
|
|
|
|
"The rank of Input(X) must be 2 and the shape is "
|
|
|
|
|
"[batch_size, 1].");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
if (ctx->IsRuntime() ||
|
|
|
|
|
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same");
|
|
|
|
|
} else {
|
|
|
|
|
if (x_dims[0] != -1 && y_dims[0] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
|
|
|
|
|
"The dim 0 of X and Y must be the same.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (x_dims[1] != -1 && y_dims[1] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[1],
|
|
|
|
|
"The dim 1 of X and Y must be the same.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], 1,
|
|
|
|
|