parent
3dbd4087fe
commit
50bee83f71
@ -0,0 +1,75 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
class InstanceNormOpConverter : public OpConverter {
|
||||
public:
|
||||
void operator()(const framework::proto::OpDesc& op,
|
||||
const framework::Scope& scope, bool test_mode) override {
|
||||
VLOG(4) << "convert fluid prelu op to tensorrt instance norm layer";
|
||||
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
|
||||
|
||||
float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
|
||||
|
||||
auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
|
||||
auto* bias_var = scope.FindVar(op_desc.Input("Bias")[0]);
|
||||
PADDLE_ENFORCE_NOT_NULL(
|
||||
scale_var,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input [Scale] of instance_norm op converter should not be null"));
|
||||
PADDLE_ENFORCE_NOT_NULL(
|
||||
bias_var,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input [Bias] of instance_norm op converter should not be null"));
|
||||
auto* scale_tensor = scale_var->GetMutable<framework::LoDTensor>();
|
||||
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
|
||||
PADDLE_ENFORCE_EQ(
|
||||
scale_tensor->numel(), bias_tensor->numel(),
|
||||
platform::errors::InvalidArgument(
|
||||
"Num of input [Scale] and [Bias] of instance_norm op converter "
|
||||
"should be equal. Got Scale num = %ld, but Bias num = %ld",
|
||||
scale_tensor->numel(), bias_tensor->numel()));
|
||||
auto* scale_d = scale_tensor->data<float>();
|
||||
auto* bias_d = bias_tensor->data<float>();
|
||||
|
||||
std::vector<float> scale_v;
|
||||
std::vector<float> bias_v;
|
||||
for (int i = 0; i < scale_tensor->numel(); i++) {
|
||||
scale_v.push_back(scale_d[i]);
|
||||
bias_v.push_back(bias_d[i]);
|
||||
}
|
||||
|
||||
plugin::InstanceNormPlugin* plugin =
|
||||
new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
|
||||
plugin->getPluginType();
|
||||
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
|
||||
|
||||
auto output_name = op_desc.Output("Y")[0];
|
||||
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
// Associates InstanceNormOpConverter with the op type name `instance_norm`
// through the REGISTER_TRT_OP_CONVERTER macro (defined outside this file).
REGISTER_TRT_OP_CONVERTER(instance_norm, InstanceNormOpConverter);
|
@ -1,5 +1,5 @@
|
||||
nv_library(tensorrt_plugin
|
||||
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
|
||||
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
|
||||
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
|
||||
DEPS enforce tensorrt_engine prelu)
|
||||
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
|
||||
DEPS enforce tensorrt_engine prelu tensor)
|
||||
|
@ -0,0 +1,119 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "paddle/fluid/platform/cudnn_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
namespace plugin {
|
||||
|
||||
// Maps a TensorRT element type onto the matching cuDNN element type.
//
// trt_dtype:   TensorRT data type to translate.
// cudnn_dtype: output; written only when the translation succeeds.
//
// Returns CUDNN_STATUS_SUCCESS for kFLOAT and kHALF, and
// CUDNN_STATUS_BAD_PARAM (leaving *cudnn_dtype untouched) for anything else.
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
                                      cudnnDataType_t *cudnn_dtype) {
  if (trt_dtype == nvinfer1::DataType::kFLOAT) {
    *cudnn_dtype = CUDNN_DATA_FLOAT;
    return CUDNN_STATUS_SUCCESS;
  }
  if (trt_dtype == nvinfer1::DataType::kHALF) {
    *cudnn_dtype = CUDNN_DATA_HALF;
    return CUDNN_STATUS_SUCCESS;
  }
  // Every other TensorRT type has no mapping here.
  return CUDNN_STATUS_BAD_PARAM;
}
|
||||
|
||||
InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
|
||||
size_t length) {
|
||||
return new InstanceNormPlugin(buffer, length);
|
||||
}
|
||||
REGISTER_TRT_PLUGIN("instance_norm_plugin",
|
||||
CreateInstanceNormPluginDeserialize);
|
||||
|
||||
int InstanceNormPlugin::initialize() {
|
||||
platform::dynload::cudnnCreate(&handle_);
|
||||
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
|
||||
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
|
||||
platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Reports the shape of the plugin's output, which is a copy of the shape of
// its single input.
//
// index:     output index; asserted to be below getNbOutputs().
// inputDims: array of input shapes; only element 0 is read.
// nbInputs:  number of inputs; asserted to be exactly 1.
nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *inputDims, int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  // Returned by value, so the caller gets its own copy of the input shape.
  return inputDims[0];
}
|
||||
|
||||
int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
|
||||
void **outputs, void *workspace,
|
||||
cudaStream_t stream) {
|
||||
const auto &input_dims = this->getInputDims(0);
|
||||
|
||||
PADDLE_ENFORCE_EQ(input_dims.nbDims, 3,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input Dims should be 3 (except the batch), got %d",
|
||||
input_dims.nbDims));
|
||||
int n = batch_size;
|
||||
int c = input_dims.d[0];
|
||||
int h = input_dims.d[1];
|
||||
int w = input_dims.d[2];
|
||||
|
||||
scale_t.Resize(framework::make_ddim({batch_size, c}));
|
||||
bias_t.Resize(framework::make_ddim({batch_size, c}));
|
||||
int device_id;
|
||||
cudaGetDevice(&device_id);
|
||||
float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
|
||||
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
|
||||
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
cudaMemcpyAsync(scale_d + i * c, scale_.data(), sizeof(float) * c,
|
||||
cudaMemcpyHostToDevice, stream);
|
||||
cudaMemcpyAsync(bias_d + i * c, bias_.data(), sizeof(float) * c,
|
||||
cudaMemcpyHostToDevice, stream);
|
||||
}
|
||||
platform::dynload::cudnnSetTensor4dDescriptor(
|
||||
b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
|
||||
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
nvinfer1::DataType data_type = getDataType();
|
||||
convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
|
||||
platform::dynload::cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
|
||||
cudnn_dtype, 1, n * c, h, w);
|
||||
platform::dynload::cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW,
|
||||
cudnn_dtype, 1, n * c, h, w);
|
||||
float alpha = 1;
|
||||
float beta = 0;
|
||||
platform::dynload::cudnnSetStream(handle_, stream);
|
||||
|
||||
void const *x_ptr = inputs[0];
|
||||
void *y_ptr = outputs[0];
|
||||
platform::dynload::cudnnBatchNormalizationForwardTraining(
|
||||
handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, x_desc_,
|
||||
x_ptr, y_desc_, y_ptr, b_desc_, scale_d, bias_d, 1., nullptr, nullptr,
|
||||
eps_, nullptr, nullptr);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,103 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/inference/tensorrt/engine.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
namespace plugin {
|
||||
|
||||
class InstanceNormPlugin : public PluginTensorRT {
|
||||
private:
|
||||
float eps_;
|
||||
std::vector<float> scale_;
|
||||
std::vector<float> bias_;
|
||||
|
||||
framework::Tensor scale_t;
|
||||
framework::Tensor bias_t;
|
||||
cudnnHandle_t handle_;
|
||||
cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
|
||||
|
||||
protected:
|
||||
size_t getSerializationSize() override {
|
||||
return getBaseSerializationSize() + SerializedSize(eps_) +
|
||||
SerializedSize(scale_) + SerializedSize(bias_);
|
||||
}
|
||||
|
||||
// TRT will call this func when we need to serialize the configuration of
|
||||
// tensorrt.
|
||||
// It should not be called by users.
|
||||
void serialize(void *buffer) override {
|
||||
SerializeValue(&buffer, getPluginType());
|
||||
serializeBase(buffer);
|
||||
SerializeValue(&buffer, eps_);
|
||||
SerializeValue(&buffer, scale_);
|
||||
SerializeValue(&buffer, bias_);
|
||||
}
|
||||
|
||||
public:
|
||||
explicit InstanceNormPlugin(const float eps, const std::vector<float> scale,
|
||||
const std::vector<float> bias)
|
||||
: eps_(eps), scale_(scale), bias_(bias) {
|
||||
PADDLE_ENFORCE_EQ(scale.size(), bias.size(),
|
||||
platform::errors::InvalidArgument(
|
||||
"The instanceNorm's scale and bias should be the "
|
||||
"same size. Got scale size = %d, but bias size = %d",
|
||||
scale.size(), bias.size()));
|
||||
}
|
||||
|
||||
// It was used for tensorrt deserialization.
|
||||
// It should not be called by users.
|
||||
InstanceNormPlugin(void const *serialData, size_t serialLength) {
|
||||
deserializeBase(serialData, serialLength);
|
||||
DeserializeValue(&serialData, &serialLength, &eps_);
|
||||
DeserializeValue(&serialData, &serialLength, &scale_);
|
||||
DeserializeValue(&serialData, &serialLength, &bias_);
|
||||
}
|
||||
|
||||
~InstanceNormPlugin() {}
|
||||
int initialize() override;
|
||||
|
||||
InstanceNormPlugin *clone() const override {
|
||||
return new InstanceNormPlugin(eps_, scale_, bias_);
|
||||
}
|
||||
|
||||
const char *getPluginType() const override { return "instance_norm_plugin"; }
|
||||
int getNbOutputs() const override { return 1; }
|
||||
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
|
||||
int nbInputDims) override;
|
||||
int enqueue(int batchSize, const void *const *inputs, void **outputs,
|
||||
void *workspace, cudaStream_t stream) override;
|
||||
|
||||
bool supportsFormat(nvinfer1::DataType type,
|
||||
nvinfer1::PluginFormat format) const override {
|
||||
return ((type == nvinfer1::DataType::kFLOAT ||
|
||||
type == nvinfer1::DataType::kHALF) &&
|
||||
(format == nvinfer1::PluginFormat::kNCHW));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <glog/logging.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
|
||||
// Smoke test: loads the `instance_norm` model with the TensorRT engine
// enabled, feeds an input filled with ones and checks that ZeroCopyRun()
// succeeds. The output values are not inspected.
TEST(TensorRT, instance_norm) {
  std::string model_dir = FLAGS_infer_model + "/instance_norm";
  AnalysisConfig config;
  int batch_size = 4;
  config.EnableUseGpu(100, 0);
  config.SetModel(model_dir);
  config.SwitchUseFeedFetchOps(false);
  config.EnableTensorRtEngine(1 << 20, batch_size, 0,
                              AnalysisConfig::Precision::kFloat32, false);

  auto predictor = CreatePaddlePredictor(config);

  int length = 4;
  int input_num = batch_size * length;
  float *input = new float[input_num];
  // memset() writes a byte pattern, so it cannot produce the float value 1.0;
  // assign each element explicitly instead.
  for (int i = 0; i < input_num; ++i) {
    input[i] = 1.0f;
  }

  auto input_names = predictor->GetInputNames();
  auto input_t = predictor->GetInputTensor(input_names[0]);
  input_t->Reshape({batch_size, length});
  input_t->copy_from_cpu(input);

  // Capture the result first so the buffer is released even when the
  // assertion fails and leaves the test body early.
  const bool run_ok = predictor->ZeroCopyRun();
  delete[] input;
  ASSERT_TRUE(run_ok);
}
|
||||
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue