|
|
|
@ -210,17 +210,35 @@ class LayerNormKernel : public framework::OpKernel<T> {
|
|
|
|
|
ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ENFORCE_EQ(mean->numel(), left);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->numel(), left);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale->numel(), right);
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias->numel(), right);
|
|
|
|
|
PADDLE_ENFORCE_EQ(mean->numel(), left,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"mean's length (%d) is not equal with expected (%d).",
|
|
|
|
|
mean->numel(), left));
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->numel(), left,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"var's length (%d) is not equal with expected (%d).",
|
|
|
|
|
var->numel(), left));
|
|
|
|
|
if (scale) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
scale->numel(), right,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"scale's length (%d) is not equal with expected (%d).",
|
|
|
|
|
scale->numel(), right));
|
|
|
|
|
}
|
|
|
|
|
if (bias) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias->numel(), right,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias's length (%d) is not equal with expected (%d).",
|
|
|
|
|
bias->numel(), right));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ker =
|
|
|
|
|
jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
|
|
|
|
|
.At(right);
|
|
|
|
|
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
|
|
|
|
|
scale->data<T>(), bias->data<T>(), static_cast<int>(left),
|
|
|
|
|
static_cast<const float>(epsilon), right);
|
|
|
|
|
scale ? scale->data<T>() : nullptr, bias ? bias->data<T>() : nullptr,
|
|
|
|
|
static_cast<int>(left), static_cast<const float>(epsilon), right);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|