|
|
|
@ -79,9 +79,16 @@ class AffineChannelOp : public framework::OperatorWithKernel {
|
|
|
|
|
: x_dims[x_dims.size() - 1]);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dims[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dims[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], C);
|
|
|
|
|
} else {
|
|
|
|
|
if (scale_dims[0] > 0 && b_dims[0] > 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dims[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], C);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|