|
|
|
@ -43,8 +43,25 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"),
|
|
|
|
|
"If weights are provided, must specify both "
|
|
|
|
|
"inside and outside weights.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims);
|
|
|
|
|
auto dims = ctx->GetInputDim("InsideWeight");
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dims, x_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dims = ctx->GetInputDim("OutsideWeight");
|
|
|
|
|
check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dims, x_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Diff", x_dims);
|
|
|
|
|