|
|
|
@ -26,14 +26,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
template <typename DeviceContext, typename T, typename AttrType = T>
|
|
|
|
|
class NormKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
|
|
|
|
|
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
|
|
|
|
|
auto* out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
T epsilon = context.Attr<T>("epsilon");
|
|
|
|
|
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
int batch_size = in_x->dims()[0];
|
|
|
|
|
int channels = in_x->dims()[1];
|
|
|
|
@ -82,7 +82,7 @@ class NormKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
template <typename DeviceContext, typename T, typename AttrType = T>
|
|
|
|
|
class NormGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
@ -90,7 +90,7 @@ class NormGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
|
|
|
|
|
const framework::Tensor* out_grad =
|
|
|
|
|
context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
T epsilon = context.Attr<T>("epsilon");
|
|
|
|
|
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
|
|
|
|
|
framework::Tensor* in_x_grad =
|
|
|
|
|
context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
in_x_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|