[Ernie GPU Optimize]: Embedding_eltwise_layernorm Fuse (#22494)
* 1. add embedding eltwise layernorm fuse 2. add embedding eltwise layernorm op 3. refine inplace_add_relu 4. refine fc_eltwise_layernorm test=develop
* 1. refine fc test=develop
* fix comments test=develop
* fix comments test=develop
parent 4ff2915d1f
commit 8d6dc102fe
File diff suppressed because it is too large
@@ -0,0 +1,95 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct EmbeddingEltwiseLayerNormPattern : public PatternBase {
  EmbeddingEltwiseLayerNormPattern(PDPattern* pattern,
                                   const std::string& name_scope)
      : PatternBase(pattern, name_scope, "embedding_eltwise_layernorm") {}

  PDNode* operator()();

  // Ids fed into the three lookup_table ops.
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table2_x);
  PATTERN_DECL_NODE(lookup_table3_x);

  // Embedding weight tables of the three lookup_table ops.
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table2_w);
  PATTERN_DECL_NODE(lookup_table3_w);

  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table2);
  PATTERN_DECL_NODE(lookup_table3);

  PATTERN_DECL_NODE(lookup_table1_out);
  PATTERN_DECL_NODE(lookup_table2_out);
  PATTERN_DECL_NODE(lookup_table3_out);

  PATTERN_DECL_NODE(eltwise_add_12);
  PATTERN_DECL_NODE(eltwise_add_12_out);

  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);

  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  // The mean and variance nodes are deleted from the graph when the
  // subgraph is fused.
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};
}  // namespace patterns

// The EmbeddingEltwiseLayerNormFusePass detects the following pattern:
//
// inputs                            operator           output
// --------------------------------------------------------------------
// (word, weights_0)                 lookup_table    -> word_emb
// (pos, weights_1)                  lookup_table    -> pos_emb
// (sent, weights_2)                 lookup_table    -> sent_emb
// (word_emb, pos_emb)               elementwise_add -> elementwise_out_0
// (elementwise_out_0, sent_emb)     elementwise_add -> elementwise_out_1
// (elementwise_out_1, scale, bias)  layer_norm      -> layer_norm_out
//
// and then converts the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
//  scale, bias) embedding_eltwise_layernorm -> layer_norm_out

class EmbeddingEltwiseLayerNormFusePass : public FusePassBase {
 public:
  virtual ~EmbeddingEltwiseLayerNormFusePass() {}

 protected:
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"embedding_eltwise_layernorm_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
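The pass implementation (.cc) is the file diff suppressed above. For orientation, a minimal sketch of the boilerplate such a pass file would contain, assuming Paddle's usual FusePassBase/GraphPatternDetector/REGISTER_PASS conventions; the include path is assumed and the actual subgraph rewrite lives in the suppressed file:

// Sketch only, not the suppressed file's exact body.
#include "paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h"  // assumed path

namespace paddle {
namespace framework {
namespace ir {

void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  GraphPatternDetector gpd;
  patterns::EmbeddingEltwiseLayerNormPattern pattern(gpd.mutable_pattern(),
                                                     name_scope_);
  pattern();  // build the subgraph description declared in the header
  // For each match: insert one fused_embedding_eltwise_layernorm op node,
  // relink its inputs/outputs, and erase the matched lookup_table,
  // elementwise_add, and layer_norm nodes (including mean/variance).
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass,
              paddle::framework::ir::EmbeddingEltwiseLayerNormFusePass);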
@@ -0,0 +1,177 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace operators {

class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE_EQ(context->HasInput("WordId"), true,
                      platform::errors::InvalidArgument(
                          "Input(WordId) of EmbeddingEltWiseLayerNormOp should "
                          "not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasInput("PosId"), true,
        platform::errors::InvalidArgument(
            "Input(PosId) of EmbeddingEltWiseLayerNormOp should not be null."));
    PADDLE_ENFORCE_EQ(context->HasInput("SentId"), true,
                      platform::errors::InvalidArgument(
                          "Input(SentId) of EmbeddingEltWiseLayerNormOp should "
                          "not be null."));

    PADDLE_ENFORCE_EQ(context->HasInput("WordEmb"), true,
                      platform::errors::InvalidArgument(
                          "Input(WordEmb) of EmbeddingEltWiseLayerNormOp "
                          "should not be null."));
    PADDLE_ENFORCE_EQ(context->HasInput("PosEmb"), true,
                      platform::errors::InvalidArgument(
                          "Input(PosEmb) of EmbeddingEltWiseLayerNormOp should "
                          "not be null."));
    PADDLE_ENFORCE_EQ(context->HasInput("SentEmb"), true,
                      platform::errors::InvalidArgument(
                          "Input(SentEmb) of EmbeddingEltWiseLayerNormOp "
                          "should not be null."));

    PADDLE_ENFORCE_EQ(
        context->HasInput("Bias"), true,
        platform::errors::InvalidArgument(
            "Input(Bias) of EmbeddingEltWiseLayerNormOp should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasInput("Scale"), true,
        platform::errors::InvalidArgument(
            "Input(Scale) of EmbeddingEltWiseLayerNormOp should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of EmbeddingEltWiseLayerNormOp should not be null."));

    // batch * seq_len * 1
    auto dims_word_id = context->GetInputDim("WordId");
    // word_num * hidden
    auto dims_word_emb = context->GetInputDim("WordEmb");
    auto dims_pos_emb = context->GetInputDim("PosEmb");
    auto dims_sent_emb = context->GetInputDim("SentEmb");
    // hidden
    auto dims_bias = context->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(
        dims_word_emb[1], dims_bias[0],
        platform::errors::InvalidArgument(
            "The second dim (%d) of WordEmb should be equal to the Bias's "
            "size (%d).",
            dims_word_emb[1], dims_bias[0]));
    PADDLE_ENFORCE_EQ(dims_word_emb.size(), 2,
                      platform::errors::InvalidArgument(
                          "The WordEmb dim's size should be 2, but found %d.",
                          dims_word_emb.size()));
    PADDLE_ENFORCE_EQ(dims_pos_emb.size(), 2,
                      platform::errors::InvalidArgument(
                          "The PosEmb dim's size should be 2, but found %d.",
                          dims_pos_emb.size()));
    PADDLE_ENFORCE_EQ(dims_sent_emb.size(), 2,
                      platform::errors::InvalidArgument(
                          "The SentEmb dim's size should be 2, but found %d.",
                          dims_sent_emb.size()));
    PADDLE_ENFORCE_EQ(
        dims_word_emb[1], dims_pos_emb[1],
        platform::errors::InvalidArgument(
            "The WordEmb second dim size (%d) should equal the PosEmb "
            "one (%d).",
            dims_word_emb[1], dims_pos_emb[1]));
    PADDLE_ENFORCE_EQ(
        dims_word_emb[1], dims_sent_emb[1],
        platform::errors::InvalidArgument(
            "The WordEmb second dim size (%d) should equal the SentEmb "
            "one (%d).",
            dims_word_emb[1], dims_sent_emb[1]));

    int batch = dims_word_id[0];
    int seq_len = dims_word_id[1];
    int hidden = dims_word_emb[1];
    auto dim_output = framework::make_ddim({batch, seq_len, hidden});
    context->SetOutputDim("Out", dim_output);
    context->ShareLoD("WordId", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "WordEmb");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class EmbeddingEltWiseLayerNormOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("WordId", "The word id input of EmbeddingEltWiseLayerNorm op");
    AddInput("PosId", "The position id input of EmbeddingEltWiseLayerNorm op");
    AddInput("SentId", "The sentence id input of EmbeddingEltWiseLayerNorm op");
    AddInput("WordEmb",
             "The word embedding input of EmbeddingEltWiseLayerNorm op");
    AddInput("PosEmb",
             "The position embedding input of EmbeddingEltWiseLayerNorm op");
    AddInput("SentEmb",
             "The sentence embedding input of EmbeddingEltWiseLayerNorm op");
    AddInput("Bias", "The LayerNorm Bias of EmbeddingEltWiseLayerNorm op");
    AddInput("Scale", "The LayerNorm Scale of EmbeddingEltWiseLayerNorm op");
    AddOutput("Out", "The output of EmbeddingEltWiseLayerNorm op");
    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float& epsilon) {
          PADDLE_ENFORCE_GE(
              epsilon, 0.0f,
              platform::errors::InvalidArgument(
                  "'epsilon' is %f, but it should be between 0.0 and 0.001.",
                  epsilon));
          PADDLE_ENFORCE_LE(
              epsilon, 0.001f,
              platform::errors::InvalidArgument(
                  "'epsilon' is %f, but it should be between 0.0 and 0.001.",
                  epsilon));
        });
    AddComment(R"DOC(
EmbeddingEltWiseLayerNorm Operator.

This op is used to optimize the following structure in the Ernie model:

    wordid -> lookup_table_op -> word
    posid  -> lookup_table_op -> pos
    sentid -> lookup_table_op -> sent
    word + pos + sent -> Y
    Y -> layer_norm -> Out

Using this op outside models that share the Ernie structure is not
recommended.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fused_embedding_eltwise_layernorm,
                             ops::EmbeddingEltWiseLayerNormOp,
                             ops::EmbeddingEltWiseLayerNormOpMaker);
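This op definition is the target of the fuse pass: when the pattern in the header matches, the pass builds one replacement node whose IO names follow the Maker above. A hedged sketch of that construction, where the Node* handles (word_id, word_emb, layer_norm, ...) are hypothetical names for the pattern-match results, not identifiers from this diff:

// Sketch only: how a Paddle fuse pass typically emits the fused node.
framework::OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("WordId", {word_id->Name()});
new_op_desc.SetInput("PosId", {pos_id->Name()});
new_op_desc.SetInput("SentId", {sent_id->Name()});
new_op_desc.SetInput("WordEmb", {word_emb->Name()});
new_op_desc.SetInput("PosEmb", {pos_emb->Name()});
new_op_desc.SetInput("SentEmb", {sent_emb->Name()});
new_op_desc.SetInput("Bias", {layer_norm_bias->Name()});
new_op_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_op_desc.SetOutput("Out", {layer_norm_out->Name()});
// Carry the layer_norm epsilon over so InferShape/kernel see the same value.
new_op_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
auto* fused_node = graph->CreateOpNode(&new_op_desc);
// ...then relink the matched inputs/outputs to fused_node and erase the
// old subgraph.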
@@ -0,0 +1,165 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <cub/cub.cuh>  // NOLINT
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {

template <typename T>
using kvp = cub::KeyValuePair<T, T>;

template <typename T>
using cv2 = cub::CubVector<T, 2>;

// Normalizes one row of `output` in place. thread_data holds this thread's
// partial (E[x], E[x^2]) pair; the block reduction yields the row statistics.
template <typename T, int TPB>
__device__ inline void LayerNorm(const cv2<T> &thread_data, const int ld,
                                 const int offset, const float *bias,
                                 const float *scale, T *output, float eps) {
  using BlockReduce = cub::BlockReduce<cv2<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  if (threadIdx.x == 0) {
    mu = sum_kv.x;
    rsigma = rsqrt(sum_kv.y - mu * mu + eps);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < ld; i += TPB) {
    const int idx = offset + i;
    const T val = output[idx];
    const T g(scale[i]);
    const T b(bias[i]);
    output[idx] = g * (val - mu) * rsigma + b;
  }
}

template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(
    int hidden, const int64_t *word_id_d, const int64_t *pos_id_d,
    const int64_t *sent_id_d, const T *scale, const T *bias, const T *word_emb,
    const T *pos_emb, const T *sent_emb, T *output, float eps) {
  cub::Sum pair_sum;
  // blockIdx.x: position in the sequence
  // blockIdx.y: batch
  // gridDim.x: Seq
  // gridDim.y: Batch
  __shared__ int64_t word_id;
  __shared__ int64_t pos_id;
  __shared__ int64_t sent_id;

  const T rhidden = T(1.f) / T(hidden);
  const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
  if (threadIdx.x == 0) {
    word_id = word_id_d[seq_pos];
    pos_id = pos_id_d[seq_pos];
    sent_id = sent_id_d[seq_pos];
  }
  __syncthreads();

  // load word, pos, sentence embeddings and add them together
  const int64_t woffset = word_id * hidden;
  const int64_t poffset = pos_id * hidden;
  const int64_t soffset = sent_id * hidden;
  const int64_t out_offset = seq_pos * hidden;

  cv2<T> thread_data;
  thread_data.x = 0;
  thread_data.y = 0;

#pragma unroll
  for (int it = threadIdx.x; it < hidden; it += TPB) {
    const T w(word_emb[woffset + it]);
    const T p(pos_emb[poffset + it]);
    const T s(sent_emb[soffset + it]);
    const T val = w + s + p;

    output[out_offset + it] = val;
    const T rhiddenval = rhidden * val;
    cv2<T> temp_data;
    temp_data.x = rhiddenval;        // contributes to E[x]
    temp_data.y = rhiddenval * val;  // contributes to E[x^2]

    thread_data = pair_sum(thread_data, temp_data);
  }
  LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
}

template <typename DeviceContext, typename T>
class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    using Tensor = framework::Tensor;
    auto *word_id = context.Input<framework::Tensor>("WordId");
    auto *pos_id = context.Input<framework::Tensor>("PosId");
    auto *sent_id = context.Input<framework::Tensor>("SentId");

    auto *word_emb = context.Input<framework::Tensor>("WordEmb");
    auto *pos_emb = context.Input<framework::Tensor>("PosEmb");
    auto *sent_emb = context.Input<framework::Tensor>("SentEmb");

    auto *bias = context.Input<framework::Tensor>("Bias");
    auto *scale = context.Input<framework::Tensor>("Scale");
    auto *out = context.Output<framework::Tensor>("Out");

    auto *word_id_d = word_id->data<int64_t>();
    auto *pos_id_d = pos_id->data<int64_t>();
    auto *sent_id_d = sent_id->data<int64_t>();

    auto *word_emb_d = word_emb->data<T>();
    auto *pos_emb_d = pos_emb->data<T>();
    auto *sent_emb_d = sent_emb->data<T>();

    auto *bias_d = bias->data<T>();
    auto *scale_d = scale->data<T>();
    auto *output_d = out->mutable_data<T>(context.GetPlace());
    // fused embedding sum + layer norm, one CUDA block per token
    auto &device_ctx = context.template device_context<DeviceContext>();
    float eps = context.Attr<float>("epsilon");

    // should be (B * S * hidden)
    auto word_id_dims = word_id->dims();
    auto word_emb_dims = word_emb->dims();

    int batch = word_id_dims[0];
    int seq_len = word_id_dims[1];
    int hidden = word_emb_dims[1];

    const unsigned tpb = 256;
    const dim3 grid(seq_len, batch, 1);
    const dim3 block(tpb, 1, 1);
    EmbEltwiseLayernormKernel<T, tpb><<<grid, block, 0, device_ctx.stream()>>>(
        hidden, word_id_d, pos_id_d, sent_id_d, scale_d, bias_d, word_emb_d,
        pos_emb_d, sent_emb_d, output_d, eps);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_embedding_eltwise_layernorm,
                        ops::EmbeddingEltWiseLayerNormKernel<
                            paddle::platform::CUDADeviceContext, float>);
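As a cross-check on the kernel above, here is a plain-C++ reference of what one block computes for a single (batch, seq) position, assuming float data; EmbEltwiseLayerNormRef is a hypothetical helper for illustration, mirroring the kernel's accumulation of E[x] and E[x^2] followed by the in-place normalization:

// Hypothetical host-side reference, not part of this diff.
#include <cmath>

void EmbEltwiseLayerNormRef(int hidden, int64_t word_id, int64_t pos_id,
                            int64_t sent_id, const float *word_emb,
                            const float *pos_emb, const float *sent_emb,
                            const float *scale, const float *bias,
                            float *out /* this token's row */, float eps) {
  // Pass 1: sum the three embedding rows, accumulating E[x] and E[x^2].
  float mean = 0.f, sq_mean = 0.f;
  const float rhidden = 1.f / hidden;
  for (int i = 0; i < hidden; ++i) {
    const float v = word_emb[word_id * hidden + i] +
                    pos_emb[pos_id * hidden + i] +
                    sent_emb[sent_id * hidden + i];
    out[i] = v;
    mean += rhidden * v;
    sq_mean += rhidden * v * v;
  }
  // Pass 2: layer norm, Var[x] = E[x^2] - E[x]^2, then scale and shift.
  const float rsigma = 1.f / std::sqrt(sq_mean - mean * mean + eps);
  for (int i = 0; i < hidden; ++i) {
    out[i] = scale[i] * (out[i] - mean) * rsigma + bias[i];
  }
}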
@@ -0,0 +1,78 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest
import paddle.fluid as fluid
import paddle.fluid.core as core


class EmbEltwiseLayerNormFusePassTest(PassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            word_id = fluid.layers.data(
                name="word_id",
                shape=[1, 128, 1],
                dtype="int64",
                append_batch_size=False)
            pos_id = fluid.layers.data(
                name="pos_id",
                shape=[1, 128, 1],
                dtype="int64",
                append_batch_size=False)
            sent_id = fluid.layers.data(
                name="sent_id",
                shape=[1, 128, 1],
                dtype="int64",
                append_batch_size=False)
            word_emb = fluid.layers.embedding(
                input=word_id, size=(128, 768), dtype='float32')
            pos_emb = fluid.layers.embedding(
                input=pos_id, size=(128, 768), dtype='float32')
            sent_emb = fluid.layers.embedding(
                input=sent_id, size=(128, 768), dtype='float32')
            add1 = fluid.layers.elementwise_add(word_emb, pos_emb)
            add2 = fluid.layers.elementwise_add(add1, sent_emb)
            hidden1 = fluid.layers.layer_norm(input=add2, begin_norm_axis=2)

        self.feeds = {
            "word_id": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "pos_id": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "sent_id": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
        }
        self.fetch_list = [hidden1]
        self.pass_names = "embedding_eltwise_layernorm_fuse_pass"
        self.fused_op_type = "fused_embedding_eltwise_layernorm"
        self.num_fused_ops = 1

    def test_check_output(self):
        if not core.is_compiled_with_cuda():
            return
        self.pass_attrs = {
            "embedding_eltwise_layernorm_fuse_pass": {
                "use_gpu": True
            }
        }
        place = fluid.CUDAPlace(0)
        self.check_output_with_place(place, startup_on_cpu=True)


if __name__ == "__main__":
    unittest.main()