/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
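// The fused JIT layer_norm kernel is only pulled in for CPU-only builds on
// non-Windows, non-macOS platforms (see the matching #if in Compute below).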
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
    !defined(__OSX__)
#include "paddle/fluid/operators/math/jit_kernel.h"
#endif
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Wrap RowwiseMean and ColwiseMean.
// Reuse the CPU code and replace the GPU code with cublas_gemv, which is
// significantly faster. Unlike RowwiseMean and ColwiseMean, these wrappers
// only handle 2D input.
template <typename DeviceContext, typename T>
struct RowwiseMean2D {
  RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx);

  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* vec);
};

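// GPU specialization: the row-wise mean of the [left x right] input is
// computed with a single cuBLAS GEMV against a constant vector filled with
// 1 / right.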
#ifdef PADDLE_WITH_CUDA
template <typename T>
class RowwiseMean2D<platform::CUDADeviceContext, T> {
 public:
  RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx)
      : left_(left), right_(right) {
    framework::DDim ones_dim({right_});
    divisor_.mutable_data<T>(ones_dim, dev_ctx.GetPlace());
    math::set_constant(dev_ctx, &divisor_, 1.0 / right);
  }
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* out) {
    math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
        false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
        out->data<T>());
  }

 private:
  int left_;
  int right_;
  framework::Tensor divisor_;
};
#endif

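// CPU specialization: forwards to math::RowwiseMean.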
template <typename T>
class RowwiseMean2D<platform::CPUDeviceContext, T> {
 public:
  RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) {}

  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* out) {
    row_mean_(context, input, out);
  }

 private:
  math::RowwiseMean<platform::CPUDeviceContext, T> row_mean_;
};

template <typename DeviceContext, typename T>
struct ColwiseSum2D {
  ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx);

  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* vec);
};

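// GPU specialization: the column-wise sum is computed as a GEMV of the
// transposed input with a vector of ones.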
#ifdef PADDLE_WITH_CUDA
template <typename T>
class ColwiseSum2D<platform::CUDADeviceContext, T> {
 public:
  ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx)
      : left_(left), right_(right) {
    framework::DDim ones_dim({left_});
    divisor_.mutable_data<T>(ones_dim, dev_ctx.GetPlace());
    math::set_constant(dev_ctx, &divisor_, 1.0);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* out) {
    math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
        true, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
        out->data<T>());
  }

 private:
  int left_;
  int right_;
  framework::Tensor divisor_;
};
#endif

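// CPU specialization: forwards to math::ColwiseSum.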
template <typename T>
class ColwiseSum2D<platform::CPUDeviceContext, T> {
 public:
  ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) {}

  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* out) {
    col_wise_(context, input, out);
  }

 private:
  math::ColwiseSum<platform::CPUDeviceContext, T> col_wise_;
};

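// Element-wise functors plugged into ElementwiseComputeEx by the kernels
// below.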
template <typename T>
struct SubAndSquareFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};

template <typename T>
struct DivAndSqrtFunctor {
  explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
  inline HOSTDEVICE T operator()(T a, T b) const {
    return a / (sqrt(b + epsilon_));
  }

 private:
  T epsilon_;
};

template <typename T>
struct MulFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};

template <typename T>
struct AddFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct SubFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};

template <typename T>
struct MulInvVarFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const {
    return a * std::sqrt(1.0 / b);
  }
};

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

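// Forward kernel: X is flattened to a [left x right] matrix at
// begin_norm_axis, and each row is normalized as
//   Y = Scale * (X - mean) / sqrt(variance + epsilon) + Bias.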
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto x = *ctx.Input<Tensor>("X");

    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x.dims();

    y->mutable_data<T>(ctx.GetPlace());
    mean->mutable_data<T>(ctx.GetPlace());
    var->mutable_data<T>(ctx.GetPlace());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);
    framework::DDim matrix_shape({left, right});

    x.Resize(matrix_shape);
    Tensor out;
    out.ShareDataWith(*y);
    out.Resize(matrix_shape);

#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
    defined(__OSX__)
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    RowwiseMean2D<DeviceContext, T> row_mean(left, right, ctx.device_context());

    // get mean
    row_mean(dev_ctx, x, mean);

    // get variance
    ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
        ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), &out);
    row_mean(dev_ctx, out, var);

    // get x_norm
    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
        ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &out);
    ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
        ctx, &out, var, /*axis*/ 0,
        DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &out);

    if (scale) {
      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
          ctx, &out, scale, /*axis*/ 1, MulFunctor<T>(), &out);
    }
    if (bias) {
      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
    }
#else
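    // Fused CPU path: the JIT layer_norm kernel computes mean, variance,
    // normalization, scale and bias in one pass, so Scale and Bias are
    // required here.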
    PADDLE_ENFORCE_EQ(mean->numel(), left);
    PADDLE_ENFORCE_EQ(var->numel(), left);
    PADDLE_ENFORCE_EQ(scale->numel(), right);
    PADDLE_ENFORCE_EQ(bias->numel(), right);

    const auto& ker = math::jitkernel::KernelPool::Instance()
                          .template Get<math::jitkernel::LayerNormKernel<T>>(
                              static_cast<int>(right));
    ker->Compute(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
                 scale->data<T>(), bias->data<T>(), static_cast<int>(left),
                 static_cast<const float>(epsilon));
#endif
  }
};

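// Backward kernel: d_Bias = colwise_sum(d_Y), d_Scale = colwise_sum(d_Y *
// x_norm), and d_X is assembled from d_Y, Scale, Mean and Variance using the
// same 2D row-mean / column-sum helpers.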
template <typename DeviceContext, typename T>
class LayerNormGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto x = *ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* mean = ctx.Input<Tensor>("Mean");
    auto* var = ctx.Input<Tensor>("Variance");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    // init output
    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    const auto& x_dims = x.dims();
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);
    framework::DDim matrix_shape({left, right});

    d_y.Resize(matrix_shape);
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    ColwiseSum2D<DeviceContext, T> colwise_sum(left, right,
                                               ctx.device_context());

    Tensor temp;
    Tensor temp_norm;
    if (d_scale || d_x) {
      x.Resize(matrix_shape);
      temp.mutable_data<T>(matrix_shape, ctx.GetPlace());

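      // If Scale and Bias are not both present, Y is reused as x_norm;
      // otherwise x_norm is recomputed from X, Mean and Variance.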
      if (!(bias && scale)) {
        temp_norm.ShareDataWith(*y);
        temp_norm.Resize(matrix_shape);
      } else {
        temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
        // get x_norm
        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
            ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
        ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
            ctx, &temp_norm, var, /*axis*/ 0,
            DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
      }
    }

    if (d_bias) {
      d_bias->mutable_data<T>(ctx.GetPlace());
      colwise_sum(dev_ctx, d_y, d_bias);
    }
    if (d_scale) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
          ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
      colwise_sum(dev_ctx, temp, d_scale);
    }

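    // d_X is assembled from three terms: the direct d_Y term, the term
    // through the mean, and the term through the variance, then divided by
    // sqrt(variance + epsilon).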
    if (d_x) {
      framework::DDim vec_shape({left});
      d_x->mutable_data<T>(ctx.GetPlace());
      auto dx_dim = d_x->dims();
      Tensor temp_vec;
      temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());

      RowwiseMean2D<DeviceContext, T> row_mean(left, right,
                                               ctx.device_context());

      if (d_scale) {
        // dy_dx
        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
            ctx, &d_y, scale, /*axis*/ 1, MulFunctor<T>(), &temp);
        framework::TensorCopy(temp, ctx.GetPlace(), ctx.device_context(), d_x);

        // dy_dmean_dx
        row_mean(dev_ctx, temp, &temp_vec);
        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
            ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);

        // dy_var_dx
        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
            ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
      } else {
        // dy_dx
        framework::TensorCopy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);

        // dy_dmean_dx
        row_mean(dev_ctx, d_y, &temp_vec);
        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
            ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);

        // dy_var_dx
        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
            ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
      }
      // dy_var_dx
      row_mean(dev_ctx, temp, &temp_vec);
      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
          ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp);
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, d_x, &temp, /*axis*/ 0, SubFunctor<T>(), d_x);

      ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
          ctx, d_x, var, /*axis*/ 0,
          DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
      d_x->Resize(dx_dim);
    }
  }
};

}  // namespace operators
}  // namespace paddle