parent 3dbd4087fe
commit 50bee83f71

@@ -0,0 +1,75 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class InstanceNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid instance_norm op to tensorrt instance_norm "
               "plugin layer";

    framework::OpDesc op_desc(op, nullptr);
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);

    float eps = boost::get<float>(op_desc.GetAttr("epsilon"));

    auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
    auto* bias_var = scope.FindVar(op_desc.Input("Bias")[0]);
    PADDLE_ENFORCE_NOT_NULL(
        scale_var,
        platform::errors::InvalidArgument(
            "Input [Scale] of instance_norm op converter should not be null"));
    PADDLE_ENFORCE_NOT_NULL(
        bias_var,
        platform::errors::InvalidArgument(
            "Input [Bias] of instance_norm op converter should not be null"));
    auto* scale_tensor = scale_var->GetMutable<framework::LoDTensor>();
    auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
    PADDLE_ENFORCE_EQ(
        scale_tensor->numel(), bias_tensor->numel(),
        platform::errors::InvalidArgument(
            "Num of input [Scale] and [Bias] of instance_norm op converter "
            "should be equal. Got Scale num = %ld, but Bias num = %ld",
            scale_tensor->numel(), bias_tensor->numel()));
    auto* scale_d = scale_tensor->data<float>();
    auto* bias_d = bias_tensor->data<float>();
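
    // Copy the weights out of the framework tensors into plain host vectors;
    // the plugin keeps its own copy so it can be serialized with the engine.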
    std::vector<float> scale_v;
    std::vector<float> bias_v;
    for (int i = 0; i < scale_tensor->numel(); i++) {
      scale_v.push_back(scale_d[i]);
      bias_v.push_back(bias_d[i]);
    }

    plugin::InstanceNormPlugin* plugin =
        new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
    nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);

    auto output_name = op_desc.Output("Y")[0];
    RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(instance_norm, InstanceNormOpConverter);

@@ -1,5 +1,5 @@
 nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
   prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
-  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
-  DEPS enforce tensorrt_engine prelu)
+  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
+  DEPS enforce tensorrt_engine prelu tensor)

@@ -0,0 +1,119 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
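
// Map a TensorRT data type onto the matching cuDNN data type; the plugin
// only supports FP32 and FP16.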
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
                                      cudnnDataType_t *cudnn_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      *cudnn_dtype = CUDNN_DATA_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *cudnn_dtype = CUDNN_DATA_HALF;
      break;
    default:
      return CUDNN_STATUS_BAD_PARAM;
  }
  return CUDNN_STATUS_SUCCESS;
}

InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
                                                        size_t length) {
  return new InstanceNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("instance_norm_plugin",
                    CreateInstanceNormPluginDeserialize);
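
// Create the cuDNN handle and tensor descriptors once; every enqueue() call
// reuses them.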
int InstanceNormPlugin::initialize() {
  platform::dynload::cudnnCreate(&handle_);
  platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
  platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
  platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
  return 0;
}

nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *inputDims, int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const &input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}
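
// Instance norm is implemented by folding the batch dimension into the
// channel dimension (a 1 x (N*C) x H x W tensor) and running cuDNN batch
// norm over it, so every (n, c) slice is normalized independently.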
int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
                                void **outputs, void *workspace,
                                cudaStream_t stream) {
  const auto &input_dims = this->getInputDims(0);

  PADDLE_ENFORCE_EQ(input_dims.nbDims, 3,
                    platform::errors::InvalidArgument(
                        "Input Dims should be 3 (except the batch), got %d",
                        input_dims.nbDims));
  int n = batch_size;
  int c = input_dims.d[0];
  int h = input_dims.d[1];
  int w = input_dims.d[2];

  scale_t.Resize(framework::make_ddim({batch_size, c}));
  bias_t.Resize(framework::make_ddim({batch_size, c}));
  int device_id;
  cudaGetDevice(&device_id);
  float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
  float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
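
  // Replicate the per-channel scale/bias once per batch instance to match
  // the folded N*C channel layout.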
  for (int i = 0; i < batch_size; i++) {
    cudaMemcpyAsync(scale_d + i * c, scale_.data(), sizeof(float) * c,
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(bias_d + i * c, bias_.data(), sizeof(float) * c,
                    cudaMemcpyHostToDevice, stream);
  }
  platform::dynload::cudnnSetTensor4dDescriptor(
      b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);

  cudnnDataType_t cudnn_dtype;
  nvinfer1::DataType data_type = getDataType();
  convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
  platform::dynload::cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
                                                cudnn_dtype, 1, n * c, h, w);
  platform::dynload::cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW,
                                                cudnn_dtype, 1, n * c, h, w);
  float alpha = 1;
  float beta = 0;
  platform::dynload::cudnnSetStream(handle_, stream);

  void const *x_ptr = inputs[0];
  void *y_ptr = outputs[0];
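  // ForwardTraining (rather than ForwardInference) is used on purpose: it
  // computes the mean/variance of the current mini-batch on the fly, which
  // over the folded layout is exactly the per-instance statistics that
  // instance norm needs.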
  platform::dynload::cudnnBatchNormalizationForwardTraining(
      handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, x_desc_,
      x_ptr, y_desc_, y_ptr, b_desc_, scale_d, bias_d, 1., nullptr, nullptr,
      eps_, nullptr, nullptr);
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

@@ -0,0 +1,103 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class InstanceNormPlugin : public PluginTensorRT {
 private:
  float eps_;
  std::vector<float> scale_;
  std::vector<float> bias_;
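
  // Device-side staging buffers; enqueue() fills them with the scale/bias
  // replicated once per batch instance.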
  framework::Tensor scale_t;
  framework::Tensor bias_t;
  cudnnHandle_t handle_;
  cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;

 protected:
  size_t getSerializationSize() override {
    return getBaseSerializationSize() + SerializedSize(eps_) +
           SerializedSize(scale_) + SerializedSize(bias_);
  }

  // TRT calls this function when it needs to serialize the plugin's
  // configuration into the engine. It should not be called by users.
  void serialize(void *buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, eps_);
    SerializeValue(&buffer, scale_);
    SerializeValue(&buffer, bias_);
  }

 public:
  explicit InstanceNormPlugin(const float eps, const std::vector<float> scale,
                              const std::vector<float> bias)
      : eps_(eps), scale_(scale), bias_(bias) {
    PADDLE_ENFORCE_EQ(scale.size(), bias.size(),
                      platform::errors::InvalidArgument(
                          "The instanceNorm's scale and bias should be the "
                          "same size. Got scale size = %d, but bias size = %d",
                          scale.size(), bias.size()));
  }

  // Used by TensorRT for deserialization.
  // It should not be called by users.
  InstanceNormPlugin(void const *serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &eps_);
    DeserializeValue(&serialData, &serialLength, &scale_);
    DeserializeValue(&serialData, &serialLength, &bias_);
  }

  ~InstanceNormPlugin() {}
  int initialize() override;
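
  // TensorRT calls clone() when it needs its own copy of the plugin, e.g.
  // while building the engine.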
  InstanceNormPlugin *clone() const override {
    return new InstanceNormPlugin(eps_, scale_, bias_);
  }

  const char *getPluginType() const override { return "instance_norm_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;

  bool supportsFormat(nvinfer1::DataType type,
                      nvinfer1::PluginFormat format) const override {
    return ((type == nvinfer1::DataType::kFLOAT ||
             type == nvinfer1::DataType::kHALF) &&
            (format == nvinfer1::PluginFormat::kNCHW));
  }
};

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

@@ -0,0 +1,50 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include <vector>

#include "paddle/fluid/inference/tests/api/trt_test_helper.h"

namespace paddle {
namespace inference {

TEST(TensorRT, instance_norm) {
  std::string model_dir = FLAGS_infer_model + "/instance_norm";
  AnalysisConfig config;
  int batch_size = 4;
  config.EnableUseGpu(100, 0);
  config.SetModel(model_dir);
  config.SwitchUseFeedFetchOps(false);
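  // Args: workspace size (1 MB), max batch size, min subgraph size, engine
  // precision, and whether to use a serialized (static) engine.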
  config.EnableTensorRtEngine(1 << 20, batch_size, 0,
                              AnalysisConfig::Precision::kFloat32, false);

  auto predictor = CreatePaddlePredictor(config);

  int length = 4;
  int input_num = batch_size * length;
  // Fill the input with 1.0f. (The original used memset with a float value,
  // which sets bytes rather than floats.)
  std::vector<float> input(input_num, 1.0f);

  auto input_names = predictor->GetInputNames();
  auto input_t = predictor->GetInputTensor(input_names[0]);
  input_t->Reshape({batch_size, length});
  input_t->copy_from_cpu(input.data());

  ASSERT_TRUE(predictor->ZeroCopyRun());
}

}  // namespace inference
}  // namespace paddle