parent 15bdb7ef14
commit b969116988
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -1 +1,2 @@
-nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
+nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu
+           avg_pool_op_plugin.cu DEPS enforce pooling)
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
@@ -0,0 +1,62 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/operators/math/pooling.h"

namespace paddle {
namespace inference {
namespace tensorrt {

nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* inputDims, int nbInputs) {
  assert(nbInputs == 1);
  assert(index == 0);
  assert(inputDims[0].nbDims == 3);
  nvinfer1::Dims const& input_dims = inputDims[0];

  nvinfer1::Dims output_dims = input_dims;

  output_dims.d[1] = output_shape_[1];
  output_dims.d[2] = output_shape_[2];
  return output_dims;
}

int AvgPoolPlugin::enqueue(int batchSize, const void* const* inputs,
                           void** outputs, void* workspace,
                           cudaStream_t stream) {
  auto const& input_dims = this->getInputDims(0);
  int input_size = 0;
  float const* idata = reinterpret_cast<float const*>(inputs[0]);
  float** odatas = reinterpret_cast<float**>(outputs);

  paddle::operators::math::AvgPool<float> pool_process;
  paddle::operators::math::Pool2dDirectCUDAFunctor<
      paddle::operators::math::AvgPool<float>, float>
      pool2d_forward;

  std::vector<int> input_shape = input_shape_;
  std::vector<int> output_shape = output_shape_;
  input_shape.insert(input_shape.begin(), batchSize);
  output_shape.insert(output_shape.begin(), batchSize);

  pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_,
                 pool_process, true, odatas[0], stream);

  return cudaGetLastError() != cudaSuccess;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
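For context, an illustration that is not part of the commit: TensorRT's implicit-batch plugin API passes dimensions without the batch axis (hence the assert that the input has 3 dims), so input_shape_ and output_shape_ hold {C, H, W}. enqueue() prepends batchSize before handing the NCHW shapes to Paddle's Pool2dDirectCUDAFunctor. A minimal sketch with hypothetical sizes:

  // Hypothetical sizes, for illustration only.
  std::vector<int> input_shape = {3, 224, 224};   // C, H, W as stored by the plugin
  std::vector<int> output_shape = {3, 112, 112};  // computed in the constructor
  int batchSize = 8;
  input_shape.insert(input_shape.begin(), batchSize);    // -> {8, 3, 224, 224}
  output_shape.insert(output_shape.begin(), batchSize);  // -> {8, 3, 112, 112}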
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
@@ -0,0 +1,109 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cassert>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class AvgPoolPlugin : public PluginTensorRT {
 private:
  bool ceil_mode_;
  std::vector<int> ksize_;
  std::vector<int> strides_;
  std::vector<int> paddings_;
  std::vector<int> input_shape_;
  std::vector<int> output_shape_;

 protected:
  size_t getSerializationSize() override {
    return SerializedSize(ceil_mode_) + SerializedSize(ksize_) +
           SerializedSize(strides_) + SerializedSize(paddings_) +
           SerializedSize(input_shape_) + getBaseSerializationSize();
  }

  // TRT will call this function when it needs to serialize the plugin's
  // configuration. It should not be called by users.
  void serialize(void *buffer) override {
    serializeBase(buffer);
    SerializeValue(&buffer, ceil_mode_);
    SerializeValue(&buffer, ksize_);
    SerializeValue(&buffer, strides_);
    SerializeValue(&buffer, paddings_);
    SerializeValue(&buffer, input_shape_);
  }

 public:
  AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize,
                std::vector<int> strides, std::vector<int> paddings,
                std::vector<int> input_shape)
      : ceil_mode_(ceil_mode),
        ksize_(ksize),
        strides_(strides),
        paddings_(paddings),
        input_shape_(input_shape) {
    int output_h, output_w;
    output_shape_ = input_shape_;
    if (!ceil_mode_) {
      output_h =
          (input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1;
      output_w =
          (input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1;
    } else {
      output_h =
          (input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) /
              strides_[0] +
          1;
      output_w =
          (input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) /
              strides_[1] +
          1;
    }
    output_shape_[1] = output_h;
    output_shape_[2] = output_w;
  }

  // Used by TensorRT for deserialization.
  // It should not be called by users.
  AvgPoolPlugin(void const *serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &ceil_mode_);
    DeserializeValue(&serialData, &serialLength, &ksize_);
    DeserializeValue(&serialData, &serialLength, &strides_);
    DeserializeValue(&serialData, &serialLength, &paddings_);
    DeserializeValue(&serialData, &serialLength, &input_shape_);
  }

  AvgPoolPlugin *clone() const override {
    return new AvgPoolPlugin(ceil_mode_, ksize_, strides_, paddings_,
                             input_shape_);
  }

  const char *getPluginType() const override { return "avg_pool"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int initialize() override { return 0; }
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
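As a sanity check on the constructor's output-size arithmetic (again an illustration, not part of the commit): floor mode and ceil mode differ only in the strides - 1 term added to the numerator before the integer division. A standalone sketch with hypothetical sizes:

  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<int> input_shape = {3, 224, 224};  // C, H, W (batch excluded)
    std::vector<int> ksize = {3, 3}, strides = {2, 2}, paddings = {0, 0};

    // floor mode: (H - k + 2 * p) / s + 1
    int floor_h = (input_shape[1] - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
    // ceil mode adds (s - 1) to the numerator so a partial last window still counts
    int ceil_h = (input_shape[1] - ksize[0] + 2 * paddings[0] + strides[0] - 1) /
                     strides[0] +
                 1;

    std::printf("floor: %d, ceil: %d\n", floor_h, ceil_h);  // floor: 111, ceil: 112
    return 0;
  }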