Add a pass to fuse fc+elementwise_add+layernorm (#19776)
* Add fc_elementwise_layernorm_fuse pass and unittest.
* Add fused_fc_elementwise_layernorm op and its GPU kernel. test=develop
* Apply fc_elementwise_layernorm_fuse_pass to GPU inference.
* Add the setting of attrs in the definition of binary_op. test=develop
* Add comment.
* Implement the unittest. test=develop
* Change the unittest name of layer_norm. test=develop
parent 8c2c8dc626
commit 3cd985a669
File diff suppressed because it is too large
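For orientation, the subgraph being fused computes fc -> elementwise_add -> layer_norm. Below is a minimal NumPy sketch of the semantics the fused op has to preserve; the function name and the 2-D shapes are illustrative, not part of the patch:

import numpy as np

def fused_fc_elementwise_layernorm_ref(x, w, bias0, y, scale, bias1,
                                       epsilon=1e-5, activation="relu"):
    # fc: X * W + Bias0, with an optional activation
    fc_out = x.dot(w) + bias0
    if activation == "relu":
        fc_out = np.maximum(fc_out, 0.0)
    # elementwise_add with the second input Y
    add_out = fc_out + y
    # layer_norm over the normalized axes (here: the last axis)
    mean = add_out.mean(axis=-1, keepdims=True)
    variance = add_out.var(axis=-1, keepdims=True)
    out = scale * (add_out - mean) / np.sqrt(variance + epsilon) + bias1
    return out, mean.squeeze(-1), variance.squeeze(-1)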
@@ -0,0 +1,33 @@ paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class FCElementwiseLayerNormFusePass : public FusePassBase {
 public:
  virtual ~FCElementwiseLayerNormFusePass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,67 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h"

#include <gtest/gtest.h>
#include <memory>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

TEST(FCElementwiseLayerNormFusePass, basic) {
  // inputs                           operator            output
  // --------------------------------------------------------------------
  // (x, weights_0, bias_0)           fc               -> fc_out_0
  // (fc_out_0, weights_1, bias_1)    fc               -> fc_out_1
  // (fc_out_1, y)                    elementwise_add  -> elementwise_out
  // (elementwise_out, scale, bias_2) layer_norm       ->
  Layers layers;
  auto* x = layers.data("x", {128, 768});
  auto* weights_0 = layers.data("weights_0", {768, 3072}, true);
  auto* bias_0 = layers.data("bias_0", {3072}, true);
  auto* fc_out_0 = layers.fc(x, weights_0, bias_0);  // {128, 3072}
  auto* weights_1 = layers.data("weights_1", {3072, 768}, true);
  auto* bias_1 = layers.data("bias_1", {768}, true);
  auto* fc_out_1 =
      layers.fc(fc_out_0, weights_1, bias_1, 1, "relu");  // {128, 768}
  fc_out_1->SetShape({128, 768});
  auto* y = layers.data("y", {128, 768});
  auto* elementwise_out = layers.elementwise_add(fc_out_1, y);
  auto* scale = layers.data("scale", {768}, true);
  auto* bias_2 = layers.data("bias_2", {768}, true);
  layers.layer_norm(elementwise_out, scale, bias_2);

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  auto pass =
      PassRegistry::Instance().Get("fc_elementwise_layernorm_fuse_pass");
  int num_nodes_before = graph->Nodes().size();
  VLOG(3) << DebugString(graph);

  graph.reset(pass->Apply(graph.release()));
  int num_nodes_after = graph->Nodes().size();
  int num_fused_nodes_after =
      GetNumOpNodes(graph, "fused_fc_elementwise_layernorm");
  VLOG(3) << DebugString(graph);

  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6);
  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(fc_elementwise_layernorm_fuse_pass);
@@ -0,0 +1,185 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class FusedFCElementwiseLayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        "Input(X) of fused_fc_elementwise_layernorm should not be null.");
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("W"), true,
        "Input(W) of fused_fc_elementwise_layernorm should not be null.");
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        "Input(Y) of fused_fc_elementwise_layernorm should not be null.");
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        "Output(Out) of fused_fc_elementwise_layernorm should not be null.");

    auto w_dims = ctx->GetInputDim("W");
    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
                      "The weight of fc is expected to be a 2-D tensor.");

    if (ctx->HasInput("Bias0")) {
      auto bias0_dims = ctx->GetInputDim("Bias0");
      if (bias0_dims.size() == 2) {
        PADDLE_ENFORCE_EQ(bias0_dims[0], 1,
                          "The shape of Bias0 must be [1, dim].");
        PADDLE_ENFORCE_EQ(bias0_dims[1], w_dims[1],
                          "The shape of Bias0 must be [1, dim].");
      } else if (bias0_dims.size() == 1) {
        PADDLE_ENFORCE_EQ(bias0_dims[0], w_dims[1],
                          "The shape of Bias0 must be [dim].");
      }
    }

    auto x_dims = ctx->GetInputDim("X");
    int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
    PADDLE_ENFORCE_GT(
        x_dims.size(), x_num_col_dims,
        "The rank of Input(X) should be larger than x_num_col_dims.");

    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
    PADDLE_ENFORCE_EQ(
        x_mat_dims[1], w_dims[0],
        "The input and weight sizes of fc do not match.");

    std::vector<int64_t> fc_out_dims;
    for (int i = 0; i < x_num_col_dims; ++i) {
      fc_out_dims.push_back(x_dims[i]);
    }
    fc_out_dims.push_back(w_dims[1]);

    auto y_dims = ctx->GetInputDim("Y");
    PADDLE_ENFORCE_EQ(framework::make_ddim(fc_out_dims), y_dims);

    auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
    PADDLE_ENFORCE_LT(
        begin_norm_axis, y_dims.size(),
        "'begin_norm_axis' must be less than the rank of Input(Y).");

    auto y_mat_dim = framework::flatten_to_2d(y_dims, begin_norm_axis);
    int64_t dim_0 = y_mat_dim[0];
    int64_t dim_1 = y_mat_dim[1];
    if (ctx->HasInput("Scale")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);

      if (ctx->IsRuntime()) {
        PADDLE_ENFORCE_EQ(
            ctx->GetInputDim("Scale")[0], dim_1,
            "The first dimension of Scale must equal the normalized "
            "size of Input(Y).");
      }
    }
    if (ctx->HasInput("Bias1")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias1").size(), 1);
      if (ctx->IsRuntime()) {
        PADDLE_ENFORCE_EQ(
            ctx->GetInputDim("Bias1")[0], dim_1,
            "The first dimension of Bias1 must equal the normalized "
            "size of Input(Y).");
      }
    }

    ctx->SetOutputDim("Out", y_dims);
    if (ctx->HasOutput("Mean")) {
      ctx->SetOutputDim("Mean", {dim_0});
    }
    if (ctx->HasOutput("Variance")) {
      ctx->SetOutputDim("Variance", {dim_0});
    }
    ctx->ShareLoD("X", "Out");
  }
};

class FusedFCElementwiseLayerNormOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of fully connected operation");
    AddInput("W",
             "(Tensor), The weight tensor of fully connected operation. It is "
             "a 2-D Tensor with shape (I, O)");
    AddInput("Bias0",
             "(Tensor, optional), The bias tensor of fully connected "
             "operation. It is a 1-D Tensor with shape (O), or a 2-D Tensor "
             "with shape (1, O).")
        .AsDispensable();
    AddInput("Y",
             "(Tensor), The second input tensor of elementwise_add operation. "
             "Note that the shape should be the same as the result tensor of "
             "the fully connected operation.");
    AddInput(
        "Scale",
        "(Tensor, optional), It is a 1-D input Tensor of layer_norm operation.")
        .AsDispensable();
    AddInput(
        "Bias1",
        "(Tensor, optional), It is a 1-D input Tensor of layer_norm operation.")
        .AsDispensable();
    AddOutput("Out",
              "(Tensor), Output after normalization. The shape is the same as "
              "layer_norm's input.");
    AddOutput("Mean", "(Tensor, optional), Mean of the current minibatch")
        .AsDispensable();
    AddOutput("Variance",
              "(Tensor, optional), Variance of the current minibatch")
        .AsDispensable();
    AddAttr<int>("x_num_col_dims",
                 "(int, default 1), This op can take tensors with more than "
                 "two dimensions as its inputs.")
        .SetDefault(1)
        .EqualGreaterThan(1);
    AddAttr<std::string>("activation_type",
                         "Activation type used in fully connected operator.")
        .SetDefault("");
    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE_GE(epsilon, 0.0f,
                            "'epsilon' should be between 0.0 and 0.001.");
          PADDLE_ENFORCE_LE(epsilon, 0.001f,
                            "'epsilon' should be between 0.0 and 0.001.");
        });
    AddAttr<int>("begin_norm_axis",
                 "the axis of `begin_norm_axis ... Rank(Y) - 1` will be "
                 "normalized. `begin_norm_axis` splits the tensor(`Y`) to a "
                 "matrix [N, H]. [default 1].")
        .SetDefault(1)
        .AddCustomChecker([](const int &begin_norm_axis) {
          PADDLE_ENFORCE_GT(begin_norm_axis, 0,
                            "'begin_norm_axis' should be greater than zero.");
        });
    AddComment(R"DOC(
fc_out <= fc(X, W, Bias0)
add_out <= elementwise_add(fc_out, Y)
(out, mean, variance) <= layer_norm(add_out, Scale, Bias1)
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_fc_elementwise_layernorm,
                  ops::FusedFCElementwiseLayerNormOp,
                  ops::FusedFCElementwiseLayerNormOpMaker,
                  paddle::framework::EmptyGradOpMaker);
@@ -0,0 +1,201 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_device_function.h"

namespace paddle {
namespace operators {

template <typename T>
static __device__ __forceinline__ T Relu(T x) {
  return (x > 0) ? x : 0;
}

static __device__ __forceinline__ float RealSqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double RealSqrt(double x) { return sqrt(x); }

template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T& first, const T& second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T>& p1, const PairForLayerNorm<T>& p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

template <typename T, bool DoRelu, int BlockDim>
__global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0,
                                                 const T* bias_1,
                                                 const T* scale, T* out,
                                                 T* mean, T* variance, int M,
                                                 int N, float epsilon) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T shared_mem[BlockDim + 2];

  for (int i = blockIdx.x; i < M; i += gridDim.x) {
    int index = i * N + threadIdx.x;

    // The first BlockDim elements will be saved to shared memory.
    int save_index = threadIdx.x;
    T* save_ptr = shared_mem;

    double sum_i = 0;
    double square_sum_i = 0;
    for (int j = threadIdx.x; j < N; j += blockDim.x) {
      T tmp_0 = out[index];
      // Add bias
      T tmp_1 = bias_0 ? tmp_0 + bias_0[j] : tmp_0;
      // Relu
      T tmp_2 = DoRelu ? Relu(tmp_1) : tmp_1;
      // elementwise_add
      T tmp_3 = tmp_2 + y[index];

      // Save
      save_ptr[save_index] = tmp_3;
      save_ptr = out;

      index += blockDim.x;
      save_index = index;

      // For layer_norm, reduce to calculate mean and std
      sum_i += tmp_3;
      square_sum_i += (tmp_3 * tmp_3);
    }

    auto pair = BlockReduce(temp_storage)
                    .Reduce(PairForLayerNorm<double>(sum_i, square_sum_i),
                            PairForLayerNormAddFunctor<double>());

    if (threadIdx.x == 0) {
      T mean_i = static_cast<T>(pair.first_ / N);
      T variance_i = static_cast<T>(pair.second_ / N - mean_i * mean_i);
      shared_mem[BlockDim] = mean_i;
      shared_mem[BlockDim + 1] = variance_i;
      if (mean) {
        mean[blockIdx.x] = mean_i;
      }
      if (variance) {
        variance[blockIdx.x] = variance_i;
      }
    }
    __syncthreads();
    T mean_i = shared_mem[BlockDim];
    T std_i = static_cast<T>(RealSqrt(shared_mem[BlockDim + 1] + epsilon));

    index = i * N + threadIdx.x;
    // First BlockDim elements loading from shared memory.
    save_index = threadIdx.x;
    save_ptr = shared_mem;

    // For layer_norm, calculate out
    for (int j = threadIdx.x; j < N; j += blockDim.x) {
      T tmp_0 = (save_ptr[save_index] - mean_i) / std_i;
      T tmp_1 = scale ? scale[j] * tmp_0 : tmp_0;
      out[index] = bias_1 ? tmp_1 + bias_1[j] : tmp_1;

      save_ptr = out;
      index += blockDim.x;
      save_index = index;
    }
  }
}

template <typename T>
class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* out = ctx.Output<framework::Tensor>("Out");

    auto w_dims = w->dims();
    int N = w_dims[1];
    int K = w_dims[0];
    int M = framework::product(x->dims()) / K;

    const T* x_data = x->data<T>();
    const T* w_data = w->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), x_data, K, w_data, N,
              static_cast<T>(0.0), out_data, N);

    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* bias_0 = ctx.Input<framework::Tensor>("Bias0");
    auto* bias_1 = ctx.Input<framework::Tensor>("Bias1");
    auto* scale = ctx.Input<framework::Tensor>("Scale");

    const T* y_data = y->data<T>();
    const T* bias_0_data = bias_0 ? bias_0->data<T>() : nullptr;
    const T* bias_1_data = bias_1 ? bias_1->data<T>() : nullptr;
    const T* scale_data = scale ? scale->data<T>() : nullptr;

    auto* mean = ctx.Output<framework::Tensor>("Mean");
    auto* variance = ctx.Output<framework::Tensor>("Variance");

    T* mean_data = mean ? mean->mutable_data<T>(ctx.GetPlace()) : nullptr;
    T* variance_data =
        variance ? variance->mutable_data<T>(ctx.GetPlace()) : nullptr;

    bool with_relu = ctx.Attr<std::string>("activation_type") == "relu";
    float epsilon = ctx.Attr<float>("epsilon");

    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    if (with_relu) {
      switch (platform::RoundToPowerOfTwo(N)) {
        CUDA_LAUNCH_KERNEL_HELPER(
            InplaceAddReluAddLayerNormKernel<
                T, true,
                kPowerOfTwoDim><<<std::max(max_threads / kPowerOfTwoDim, 1),
                                  kPowerOfTwoDim, 0, dev_ctx.stream()>>>(
                y_data, bias_0_data, bias_1_data, scale_data, out_data,
                mean_data, variance_data, M, N, epsilon));
      }
    } else {
      switch (platform::RoundToPowerOfTwo(N)) {
        CUDA_LAUNCH_KERNEL_HELPER(
            InplaceAddReluAddLayerNormKernel<
                T, false,
                kPowerOfTwoDim><<<std::max(max_threads / kPowerOfTwoDim, 1),
                                  kPowerOfTwoDim, 0, dev_ctx.stream()>>>(
                y_data, bias_0_data, bias_1_data, scale_data, out_data,
                mean_data, variance_data, M, N, epsilon));
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_fc_elementwise_layernorm,
                        ops::FusedFCElementwiseLayerNormOpKernel<float>,
                        ops::FusedFCElementwiseLayerNormOpKernel<double>);
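A note on the kernel above: it reuses the GEMM result already stored in Out, rewrites it in place, and derives each row's variance from a single block-reduced pair of sums via the identity Var(x) = E[x^2] - E[x]^2, accumulating in double to limit rounding error. A quick NumPy check of that shortcut (illustrative only):

import numpy as np

row = np.random.rand(1024)
n = row.size
mean = row.sum() / n
# E[x^2] - E[x]^2, the same per-row shortcut the kernel uses
variance = (row * row).sum() / n - mean * mean
assert np.allclose(variance, row.var())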
@@ -0,0 +1,82 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from test_fc_op import fc_refer, MatrixGenerate
from test_layer_norm_op import _reference_layer_norm_naive

np.random.seed(123)


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "Paddle core is not compiled with CUDA")
class TestFusedFCElementwiseLayerNormOp(OpTest):
    def config(self):
        self.matrix = MatrixGenerate(1, 10, 15, 3, 3, 2)
        self.y_shape = [1, 15]
        self.begin_norm_axis = 1

    def setUp(self):
        self.op_type = "fused_fc_elementwise_layernorm"
        self.config()

        # Attr of layer_norm
        epsilon = 0.00001

        # fc
        fc_out = fc_refer(self.matrix, True, True)
        # elementwise_add
        y = np.random.random_sample(self.y_shape).astype(np.float32)
        add_out = fc_out + y
        # layer_norm
        scale_shape = [np.prod(self.y_shape[self.begin_norm_axis:])]
        scale = np.random.random_sample(scale_shape).astype(np.float32)
        bias_1 = np.random.random_sample(scale_shape).astype(np.float32)
        out, mean, variance = _reference_layer_norm_naive(
            add_out, scale, bias_1, epsilon, self.begin_norm_axis)

        self.inputs = {
            "X": self.matrix.input,
            "W": self.matrix.weights,
            "Bias0": self.matrix.bias,
            "Y": y,
            "Scale": scale,
            "Bias1": bias_1
        }
        self.attrs = {
            "activation_type": "relu",
            "epsilon": epsilon,
            "begin_norm_axis": self.begin_norm_axis
        }
        self.outputs = {"Out": out, "Mean": mean, "Variance": variance}

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=2e-3)


class TestFusedFCElementwiseLayerNormOp2(TestFusedFCElementwiseLayerNormOp):
    def config(self):
        self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
        self.y_shape = [4, 6]
        self.begin_norm_axis = 1


if __name__ == '__main__':
    unittest.main()