Add TensorRT support for pruned transformer models; fix the bug that TensorRT does not support DeletePass (#28517)

* skip_layernorm_op done
* add unittest
* slice op converter supports trt < 6
* skip_layernorm only works in ernie
* fix unittest

release/2.0-rc
parent 2318fb0e77
commit 530deb144e
@@ -0,0 +1,139 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/fluid/inference/tests/api/trt_test_helper.h"

namespace paddle {
namespace inference {

void run(const AnalysisConfig& config, std::vector<float>* out_data) {
  auto predictor = CreatePaddlePredictor(config);
  auto input_names = predictor->GetInputNames();

  int run_batch = 1;
  const int run_seq_len = 128;

  std::vector<int64_t> tmp_input;
  std::vector<float> tmp_four_input;
  tmp_input.reserve(run_batch * run_seq_len);
  tmp_four_input.reserve(run_batch * run_seq_len);

  int64_t i0[run_seq_len] = {
      1,    3558, 4,   75,  491, 89, 340, 313, 93,   4,   255,   10, 75,    321,
      4095, 1902, 4,   134, 49,  75, 311, 14,  44,   178, 543,   15, 12043, 2,
      75,   201,  340, 9,   14,  44, 486, 218, 1140, 279, 12043, 2};
  int64_t i1[run_seq_len] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  int64_t i2[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                             10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                             20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                             30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
  float i3[run_seq_len] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                           1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                           1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                           1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};

  // first input
  auto input_t = predictor->GetInputTensor(input_names[0]);
  input_t->Reshape({run_batch, run_seq_len, 1});
  input_t->copy_from_cpu(i0);

  // second input
  auto input_t2 = predictor->GetInputTensor(input_names[1]);
  input_t2->Reshape({run_batch, run_seq_len, 1});
  input_t2->copy_from_cpu(i1);

  // third input
  auto input_t3 = predictor->GetInputTensor(input_names[2]);
  input_t3->Reshape({run_batch, run_seq_len, 1});
  input_t3->copy_from_cpu(i2);

  auto input_t4 = predictor->GetInputTensor(input_names[3]);
  input_t4->Reshape({run_batch, run_seq_len, 1});
  input_t4->copy_from_cpu(i3);

  ASSERT_TRUE(predictor->ZeroCopyRun());

  auto output_names = predictor->GetOutputNames();
  auto output_t = predictor->GetOutputTensor(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  out_data->resize(out_num);
  output_t->copy_to_cpu(out_data->data());
}

void trt_ernie(bool with_fp16, std::vector<float> result) {
  AnalysisConfig config;
  std::string model_dir = FLAGS_infer_model;
  SetConfig(&config, model_dir, true);

  config.SwitchUseFeedFetchOps(false);

  int batch = 32;
  int min_seq_len = 1;
  int max_seq_len = 128;
  int opt_seq_len = 128;

  std::vector<int> min_shape = {1, min_seq_len, 1};
  std::vector<int> max_shape = {batch, max_seq_len, 1};
  std::vector<int> opt_shape = {batch, opt_seq_len, 1};
  // Set the input's min, max, opt shape
  std::map<std::string, std::vector<int>> min_input_shape = {
      {"read_file_0.tmp_0", min_shape},
      {"read_file_0.tmp_1", min_shape},
      {"read_file_0.tmp_2", min_shape},
      {"read_file_0.tmp_3", min_shape}};
  std::map<std::string, std::vector<int>> max_input_shape = {
      {"read_file_0.tmp_0", max_shape},
      {"read_file_0.tmp_1", max_shape},
      {"read_file_0.tmp_2", max_shape},
      {"read_file_0.tmp_3", max_shape}};
  std::map<std::string, std::vector<int>> opt_input_shape = {
      {"read_file_0.tmp_0", opt_shape},
      {"read_file_0.tmp_1", opt_shape},
      {"read_file_0.tmp_2", opt_shape},
      {"read_file_0.tmp_3", opt_shape}};

  auto precision = AnalysisConfig::Precision::kFloat32;
  if (with_fp16) {
    precision = AnalysisConfig::Precision::kHalf;
  }
  config.EnableTensorRtEngine(1 << 30, 1, 12, precision, false, false);
  config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                opt_input_shape);
  std::vector<float> out_data;
  run(config, &out_data);

  for (size_t i = 0; i < out_data.size(); i++) {
    EXPECT_NEAR(result[i], out_data[i], 1e-4);
  }
}

TEST(AnalysisPredictor, no_fp16) {
  std::vector<float> result = {0.498667, 0.501333};
  trt_ernie(false, result);
}

}  // namespace inference
}  // namespace paddle
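A note on the engine setup in the test above: the two calls that move the ernie subgraph onto TensorRT are EnableTensorRtEngine and SetTRTDynamicShapeInfo. Below is an annotated restatement of those calls, with the parameter meanings as I read this release's AnalysisConfig; treat it as a reading aid, not new behavior.

  // Annotated sketch of the test's engine setup (parameter names follow my
  // reading of the AnalysisConfig API in this release).
  config.EnableTensorRtEngine(
      1 << 30,    // workspace_size: device scratch memory TensorRT may use
      1,          // max_batch_size: superseded by max_input_shape below
      12,         // min_subgraph_size: smaller subgraphs stay on native ops
      precision,  // kFloat32, or kHalf when with_fp16 is set
      false,      // use_static: do not serialize the engine to disk
      false);     // use_calib_mode: no INT8 calibration in this test
  // Dynamic shapes: TensorRT builds an optimization profile from the
  // min/max/opt shape registered for each named input ("read_file_0.tmp_0"
  // through "read_file_0.tmp_3").
  config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                opt_input_shape);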
@@ -0,0 +1,91 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace operators {

class SkipLayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
                      platform::errors::InvalidArgument(
                          "Input(Y) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasInput("Scale"), true,
        platform::errors::InvalidArgument(
            "Input(Scale) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasInput("Bias"), true,
        platform::errors::InvalidArgument(
            "Input(Bias) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of SkipLayerNorm should not be null."));

    auto dim_input = context->GetInputDim("X");
    context->SetOutputDim("Out", dim_input);
    context->ShareLoD("X", "Out");
  }
};

class SkipLayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The X input of SkipLayerNorm op");
    AddInput("Y", "The Y input of SkipLayerNorm op");
    AddInput("Scale", "The scale input of SkipLayerNorm op");
    AddInput("Bias", "The bias input of SkipLayerNorm op");
    AddOutput("Out", "The output of SkipLayerNorm op");
    AddAttr<float>("epsilon",
                   "param epsilon of layer_norm op in "
                   "skip_layernorm_fuse_pass");
    AddAttr<int>("begin_norm_axis",
                 "param begin_norm_axis of "
                 "layer_norm op in skip_layernorm_fuse_pass");
    AddComment(R"DOC(
SkipLayerNorm Operator.

This op is used by skip_layernorm_fuse_pass, which fuses the following
op pattern:

      |           |                        |           |
  other_op1   other_op2               other_op1   other_op2
      |           |          fuse          \         /
      |----elementwise_add     ->       skip_layernorm
              |                                |
          layer_norm                       other_op3
              |                                |
          other_op3
              |
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(skip_layernorm, ops::SkipLayerNormOp,
                             ops::SkipLayerNormOpMaker);
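As the doc comment above says, the fused op replaces an elementwise_add followed by a layer_norm, so it computes Out = LayerNorm(X + Y) with per-element Scale and Bias over the last axis. The host-side reference below is a minimal sketch of that semantics; the function name and flattened row layout are my own, and the PR's real GPU path lives in math::SkipLayerNormFunctor.

  #include <cmath>
  #include <vector>

  // Minimal host reference for skip_layernorm semantics (hypothetical
  // helper, not part of the PR): out = Scale * normalize(X + Y) + Bias,
  // normalizing each row of `hidden` elements independently.
  std::vector<float> SkipLayerNormRef(const std::vector<float> &x,
                                      const std::vector<float> &y,
                                      const std::vector<float> &scale,
                                      const std::vector<float> &bias,
                                      int hidden, float epsilon) {
    std::vector<float> out(x.size());
    const int rows = static_cast<int>(x.size()) / hidden;
    for (int r = 0; r < rows; ++r) {
      const float *xr = &x[r * hidden];
      const float *yr = &y[r * hidden];
      // Mean and variance of the residual sum over the hidden axis.
      float mean = 0.f, var = 0.f;
      for (int i = 0; i < hidden; ++i) mean += xr[i] + yr[i];
      mean /= hidden;
      for (int i = 0; i < hidden; ++i) {
        const float d = xr[i] + yr[i] - mean;
        var += d * d;
      }
      var /= hidden;
      const float inv_std = 1.f / std::sqrt(var + epsilon);
      for (int i = 0; i < hidden; ++i) {
        out[r * hidden + i] =
            scale[i] * (xr[i] + yr[i] - mean) * inv_std + bias[i];
      }
    }
    return out;
  }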
@@ -0,0 +1,66 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class SkipLayerNormKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    using Tensor = framework::Tensor;
    auto *X = context.Input<Tensor>("X");
    auto *Y = context.Input<Tensor>("Y");
    auto *scale = context.Input<Tensor>("Scale");
    auto *bias = context.Input<Tensor>("Bias");

    auto *X_d = X->data<T>();
    auto *Y_d = Y->data<T>();
    auto *scale_d = scale->data<T>();
    auto *bias_d = bias->data<T>();
    float epsilon = context.Attr<float>("epsilon");
    int begin_norm_axis = context.Attr<int>("begin_norm_axis");

    auto *out = context.Output<Tensor>("Out");
    out->Resize(X->dims());
    auto *output_d = out->mutable_data<T>(context.GetPlace());

    // Total element count; the functor treats the input as rows of
    // `hidden` elements each.
    size_t num = 1;
    for (int i = 0; i < X->dims().size(); i++) {
      num *= X->dims()[i];
    }
    int hidden = X->dims()[2];
    auto &device_ctx = context.template device_context<DeviceContext>();
    operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;

    skip_layer_norm_func(num, hidden, X_d, Y_d, scale_d, bias_d, output_d,
                         epsilon, device_ctx.stream());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    skip_layernorm,
    ops::SkipLayerNormKernel<paddle::platform::CUDADeviceContext, float>);
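The kernel above delegates the math to math::SkipLayerNormFunctor, whose body is not part of this diff. As a rough mental model only, the device-side work amounts to the naive one-thread-per-row kernel below; names and thread mapping are my own, and the production functor presumably uses a tuned block-per-row reduction instead.

  #include <cuda_runtime.h>
  #include <math.h>

  // Naive sketch of the functor's math (hypothetical, not the PR's code):
  // each thread normalizes one row of `hidden` elements of X + Y.
  __global__ void SkipLayerNormNaive(int num, int hidden, const float *x,
                                     const float *y, const float *scale,
                                     const float *bias, float *out,
                                     float epsilon) {
    const int rows = num / hidden;
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) return;
    const float *xr = x + row * hidden;
    const float *yr = y + row * hidden;
    float mean = 0.f, var = 0.f;
    for (int i = 0; i < hidden; ++i) mean += xr[i] + yr[i];
    mean /= hidden;
    for (int i = 0; i < hidden; ++i) {
      const float d = xr[i] + yr[i] - mean;
      var += d * d;
    }
    var /= hidden;
    const float inv_std = rsqrtf(var + epsilon);
    for (int i = 0; i < hidden; ++i) {
      out[row * hidden + i] =
          scale[i] * (xr[i] + yr[i] - mean) * inv_std + bias[i];
    }
  }

Launched as, e.g., SkipLayerNormNaive<<<(rows + 255) / 256, 256, 0, stream>>>(...), this matches the host reference shown after the op definition.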