Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop
test=develop
local_add_cudnn_lstm
commit 5857fb3014
@@ -1,3 +1,4 @@
nv_library(tensorrt_plugin
  SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
  avg_pool_op_plugin.cu
  DEPS enforce tensorrt_engine)
@@ -0,0 +1,64 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/operators/math/pooling.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* inputDims, int nbInputs) {
  assert(nbInputs == 1);
  assert(index == 0);
  assert(inputDims[0].nbDims == 3);
  nvinfer1::Dims const& input_dims = inputDims[0];

  nvinfer1::Dims output_dims = input_dims;

  output_dims.d[1] = output_shape_[1];
  output_dims.d[2] = output_shape_[2];
  return output_dims;
}

int AvgPoolPlugin::enqueue(int batchSize, const void* const* inputs,
                           void** outputs, void* workspace,
                           cudaStream_t stream) {
  auto const& input_dims = this->getInputDims(0);
  int input_size = 0;
  float const* idata = reinterpret_cast<float const*>(inputs[0]);
  float** odatas = reinterpret_cast<float**>(outputs);

  paddle::operators::math::AvgPool<float> pool_process;
  paddle::operators::math::Pool2dDirectCUDAFunctor<
      paddle::operators::math::AvgPool<float>, float>
      pool2d_forward;

  std::vector<int> input_shape = input_shape_;
  std::vector<int> output_shape = output_shape_;
  input_shape.insert(input_shape.begin(), batchSize);
  output_shape.insert(output_shape.begin(), batchSize);

  pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_,
                 pool_process, true, odatas[0], stream);

  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
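For reference, the `Pool2dDirectCUDAFunctor` instantiated with `AvgPool<float>` computes an ordinary 2-D average pooling over NCHW data, with the `true` argument selecting exclusive averaging (divide by the number of in-bounds elements in each window). A minimal CPU sketch of the same arithmetic, assuming plain contiguous float buffers (illustrative only, not the actual CUDA functor):

```cpp
#include <algorithm>
#include <vector>

// CPU reference for the 2-D average pooling dispatched by enqueue().
// Layout is NCHW; `out` must already be sized N*C*out_h*out_w.
void AvgPool2dRef(const std::vector<float>& in, std::vector<float>& out,
                  int N, int C, int H, int W, int out_h, int out_w,
                  int ksize_h, int ksize_w, int stride_h, int stride_w,
                  int pad_h, int pad_w) {
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int oh = 0; oh < out_h; ++oh) {
        for (int ow = 0; ow < out_w; ++ow) {
          // Pooling window in input coordinates, clipped to the feature map.
          int hstart = std::max(oh * stride_h - pad_h, 0);
          int wstart = std::max(ow * stride_w - pad_w, 0);
          int hend = std::min(oh * stride_h - pad_h + ksize_h, H);
          int wend = std::min(ow * stride_w - pad_w + ksize_w, W);
          float sum = 0.f;
          for (int h = hstart; h < hend; ++h)
            for (int w = wstart; w < wend; ++w)
              sum += in[((n * C + c) * H + h) * W + w];
          // Exclusive averaging: divide by the count of valid elements,
          // matching the `true` flag passed to pool2d_forward above.
          int count = (hend - hstart) * (wend - wstart);
          out[((n * C + c) * out_h + oh) * out_w + ow] = sum / count;
        }
      }
    }
  }
}
```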
@@ -0,0 +1,111 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cassert>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class AvgPoolPlugin : public PluginTensorRT {
 private:
  bool ceil_mode_;
  std::vector<int> ksize_;
  std::vector<int> strides_;
  std::vector<int> paddings_;
  std::vector<int> input_shape_;
  std::vector<int> output_shape_;

 protected:
  size_t getSerializationSize() override {
    return SerializedSize(ceil_mode_) + SerializedSize(ksize_) +
           SerializedSize(strides_) + SerializedSize(paddings_) +
           SerializedSize(input_shape_) + getBaseSerializationSize();
  }

  // TensorRT calls this function when it needs to serialize the plugin's
  // configuration. It should not be called by users.
  void serialize(void *buffer) override {
    serializeBase(buffer);
    SerializeValue(&buffer, ceil_mode_);
    SerializeValue(&buffer, ksize_);
    SerializeValue(&buffer, strides_);
    SerializeValue(&buffer, paddings_);
    SerializeValue(&buffer, input_shape_);
  }

 public:
  AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize,
                std::vector<int> strides, std::vector<int> paddings,
                std::vector<int> input_shape)
      : ceil_mode_(ceil_mode),
        ksize_(ksize),
        strides_(strides),
        paddings_(paddings),
        input_shape_(input_shape) {
    int output_h, output_w;
    output_shape_ = input_shape_;
    if (!ceil_mode_) {
      output_h =
          (input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1;
      output_w =
          (input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1;
    } else {
      output_h =
          (input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) /
              strides_[0] +
          1;
      output_w =
          (input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) /
              strides_[1] +
          1;
    }
    output_shape_[1] = output_h;
    output_shape_[2] = output_w;
  }

  // Used by TensorRT for deserialization.
  // It should not be called by users.
  AvgPoolPlugin(void const *serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &ceil_mode_);
    DeserializeValue(&serialData, &serialLength, &ksize_);
    DeserializeValue(&serialData, &serialLength, &strides_);
    DeserializeValue(&serialData, &serialLength, &paddings_);
    DeserializeValue(&serialData, &serialLength, &input_shape_);
  }

  AvgPoolPlugin *clone() const override {
    return new AvgPoolPlugin(ceil_mode_, ksize_, strides_, paddings_,
                             input_shape_);
  }

  const char *getPluginType() const override { return "avg_pool"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int initialize() override { return 0; }
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;
};

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
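The constructor precomputes the output spatial size, and the two branches differ only in rounding: ceil mode adds `stride - 1` to the numerator so the integer division rounds up. As a concrete check, a 7-wide input with a 2-wide window, stride 2, and no padding gives (7 − 2 + 0)/2 + 1 = 3 in floor mode but (7 − 2 + 0 + 1)/2 + 1 = 4 in ceil mode. A standalone sketch of the same arithmetic (an illustrative helper, not part of the plugin):

```cpp
#include <cassert>

// Output size of one pooled dimension; mirrors the arithmetic in the
// AvgPoolPlugin constructor above.
int PooledSize(int input, int ksize, int pad, int stride, bool ceil_mode) {
  int numerator = input - ksize + 2 * pad;
  if (ceil_mode) numerator += stride - 1;  // rounds the division up
  return numerator / stride + 1;
}

int main() {
  assert(PooledSize(7, 2, 0, 2, /*ceil_mode=*/false) == 3);  // floor: 5/2 + 1
  assert(PooledSize(7, 2, 0, 2, /*ceil_mode=*/true) == 4);   // ceil:  6/2 + 1
  return 0;
}
```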
@@ -0,0 +1,201 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <mkldnn/include/mkldnn.hpp>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#include "paddle/fluid/platform/mkldnn_helper.h"

#include "paddle/fluid/operators/math/jit_kernel.h"
#include "xbyak.h"
#include "xbyak_util.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;

static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
  std::transform(format.begin(), format.end(), format.begin(), ::tolower);

  if (!format.compare("nchw")) {
    return memory::format::nchw;
  } else if (!format.compare("nchw16c")) {
    return memory::format::nChw16c;
  } else if (!format.compare("nchw8c")) {
    return memory::format::nChw8c;
  } else if (!format.compare("nhwc")) {
    return memory::format::nhwc;
  } else {
    return memory::format::any;
  }
}

static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                             framework::Tensor* tensor, const char* attribute) {
  if (ctx.op().HasAttr(attribute)) {
    auto format_as_string = ctx.Attr<std::string>(attribute);
    auto format = StringToMKLDNNFormat(format_as_string);
    if (format != memory::format::any) {
      tensor->set_format(format);
    }
  }
}

template <typename T>
static void ReorderInput(framework::Tensor* tensor,
                         const platform::Place& place,
                         const mkldnn::engine& engine, bool isFourDim) {
  using platform::to_void_cast;
  auto dims = paddle::framework::vectorize2int(tensor->dims());
  framework::Tensor out_tensor;
  out_tensor.Resize(tensor->dims());
  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
  out_tensor.set_layout(tensor->layout());
  mkldnn::memory input_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
      to_void_cast<T>(tensor->data<T>())};
  mkldnn::memory output_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
  platform::Reorder(input_memory, output_memory);
  tensor->ShareDataWith(out_tensor);
}
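`ReorderInput` rewrites a tensor into the plain `nchw` (or `nc`) layout via an MKL-DNN reorder. Conceptually a reorder is just an index remapping; for example, converting the blocked `nChw16c` layout back to `nchw` amounts to the loop below. This is an illustrative scalar sketch, not the MKL-DNN implementation, and it assumes C divisible by 16 and a `dst` buffer already sized N*C*H*W:

```cpp
#include <vector>

// Scalar sketch of an nChw16c -> nchw reorder. In nChw16c, channels are
// grouped into blocks of 16 with the in-block channel index innermost:
// src[n][c/16][h][w][c%16].
void ReorderNChw16cToNchw(const std::vector<float>& src,
                          std::vector<float>& dst,
                          int N, int C, int H, int W) {
  const int block = 16;
  const int Cb = C / block;  // number of 16-channel blocks
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w) {
          int src_idx =
              (((n * Cb + c / block) * H + h) * W + w) * block + c % block;
          dst[((n * C + c) * H + h) * W + w] = src[src_idx];
        }
}
```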

template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    int axis = ctx.Attr<int>("axis");
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* z = ctx.Output<Tensor>("Out");
    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* z_data = z->mutable_data<T>(ctx.GetPlace());

    auto x_dims = x->dims();
    auto y_dims_untrimmed = y->dims();
    auto x_int_dims = paddle::framework::vectorize2int(x_dims);

    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");

    Xbyak::util::Cpu cpu;
    const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
    const bool are_dims_divisable = !(x_int_dims[1] % 16);
    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
    const bool is_y_format_correct = y->format() == memory::format::nc;
    if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
        is_avx512_enabled) {
      int pre, n, post;
      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);

      if (post == 1) {
        PADDLE_THROW("Not implemented when post is 1");
      } else {
        // Just check whether it works for RE-Resnext.
        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");

        int n = x_dims[0];
        int c = x_dims[1];
        int h = x_dims[2];
        int w = x_dims[3];

        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
                       "Y should be in nc format");

        constexpr int simd_width = 16;
        int C = c / simd_width;

        const auto& multiply =
            math::jitkernel::KernelPool::Instance()
                .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);

#pragma omp parallel for collapse(2)
        for (int ni = 0; ni < n; ni++) {
          for (int ci = 0; ci < C; ci++) {
            auto ptr_x =
                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;

            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
            auto ptr_z =
                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;

            multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
          }
        }
      }

      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    } else {
      // Fallback to naive version:
      const bool are_inputs_in_same_format = x->format() == y->format();
      const bool is_x_nchw = x->format() == memory::format::nchw;
      const bool is_x_nc = x->format() == memory::format::nc;
      const bool is_y_nchw = y->format() == memory::format::nchw;
      const bool is_y_nc = y->format() == memory::format::nc;
      if (!are_inputs_in_same_format) {
        using platform::MKLDNNDeviceContext;
        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
        const auto& mkldnn_engine = dev_ctx.GetEngine();
        if (!(is_x_nchw || is_x_nc))
          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
                          x->dims().size() == 4);
        if (!(is_y_nchw || is_y_nc))
          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
                          y->dims().size() == 4);
      }

      auto mul_func = [](T a, T b) -> T { return a * b; };

      TransformFunctor<decltype(mul_func), T,
                       paddle::platform::CPUDeviceContext, T>
          functor(
              x, y, z,
              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
              mul_func);

      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                     "Axis should be in range [0, x_dims)");

      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      int pre, n, post;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

      if (post == 1) {
        functor.RunRowWise(n, pre);
      } else {
        functor.RunMidWise(n, pre, post);
      }
      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::ElementwiseMulMKLDNNKernel<float>)
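In the AVX512 fast path, x and z are in nChw16c while y is in nc, so the pointer arithmetic steps through one 16-channel block of one image at a time, and each jit-kernel call multiplies an h*w*16 slab of x by the 16 matching per-channel scalars from y. A scalar reference of what one `multiply->Compute(ptr_x, ptr_y, ptr_z, h, w)` call does, assuming the (x, y, z, h, w) contract implied by its call site (illustrative; the real kernel is jit-generated):

```cpp
// Scalar reference for one fast-path kernel call. ptr_x/ptr_z address an
// h*w*16 slab (one 16-channel block of one image, nChw16c order); ptr_y
// addresses that block's 16 per-channel scalars (nc order). Each scalar is
// broadcast across its channel's h*w spatial positions.
void EltwiseMulNChw16cNCRef(const float* x, const float* y, float* z,
                            int h, int w) {
  const int block = 16;
  for (int s = 0; s < h * w; ++s)      // spatial position within the slab
    for (int b = 0; b < block; ++b)    // channel within the 16-wide block
      z[s * block + b] = x[s * block + b] * y[b];
}
```

With C = c/16 blocks per image, the ni/ci loops cover the whole tensor, and `#pragma omp parallel for collapse(2)` parallelizes across images and channel blocks.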