commit 67562a6fcd
@@ -0,0 +1,95 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// LeakyRelu converter from fluid to TensorRT.
class LeakyReluOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid leaky_relu op to tensorrt layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs.
    size_t input_num = op_desc.Input("X").size();
    PADDLE_ENFORCE(input_num == 1);
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    // Get output.
    size_t output_num = op_desc.Output("Out").size();
    PADDLE_ENFORCE(output_num == 1);
    // Get attrs.
    float alpha = boost::get<float>(op_desc.GetAttr("alpha"));

    platform::CPUPlace place;
    std::unique_ptr<framework::LoDTensor> alpha_tensor(
        new framework::LoDTensor());
    alpha_tensor->Resize(framework::make_ddim({2}));
    float* alpha_data = alpha_tensor->mutable_data<float>(place);
    alpha_data[0] = alpha;
    alpha_data[1] = 1.f - alpha;
    // The leaky relu formula y = (x > 0) ? x : alpha * x is equivalent to
    // y = alpha * x + ((x > 0) ? (1 - alpha) * x : 0).
    TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT, &alpha_data[0], 1};
    TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0};
    TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0};
    // y_scale = alpha * x
    auto* scale_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM, shift.get(),
        scale.get(), power.get());
    PADDLE_ENFORCE(nullptr != scale_layer);
    // y_relu = (x > 0) ? x : 0
    auto* relu_layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input,
                                            nvinfer1::ActivationType::kRELU);
    PADDLE_ENFORCE(nullptr != relu_layer);
    // y_scale_relu = (1 - alpha) * y_relu
    TensorRTEngine::Weight sub_scale{nvinfer1::DataType::kFLOAT, &alpha_data[1],
                                     1};
    auto* scale_relu_layer =
        TRT_ENGINE_ADD_LAYER(engine_, Scale, *(relu_layer->getOutput(0)),
                             nvinfer1::ScaleMode::kUNIFORM, shift.get(),
                             sub_scale.get(), power.get());
    PADDLE_ENFORCE(nullptr != scale_relu_layer);
    // y = y_scale + y_scale_relu
    auto* output_layer =
        TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *(scale_layer->getOutput(0)),
                             *(scale_relu_layer->getOutput(0)),
                             nvinfer1::ElementWiseOperation::kSUM);
    PADDLE_ENFORCE(nullptr != output_layer);
    // Keep the alpha tensor in the weight map so its memory is not released
    // while the engine is alive.
    std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
    PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
                   engine_->weight_map.end());
    engine_->weight_map[alpha_name] = std::move(alpha_tensor);

    std::string layer_name = "leaky_relu (Output: ";
    auto output_name = op_desc.Output("Out")[0];
    output_layer->getOutput(0)->setName(output_name.c_str());
    engine_->SetITensor(output_name, output_layer->getOutput(0));
    layer_name += output_name;
    if (test_mode) {
      engine_->DeclareOutput(output_name);
    }
    output_layer->setName((layer_name + ")").c_str());
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(leaky_relu, LeakyReluOpConverter);
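The three TensorRT layers above (Scale, Activation, ElementWise) implement the rewrite stated in the converter's comment: for x > 0 both forms reduce to x, and for x <= 0 the ReLU term vanishes, leaving alpha * x. A minimal standalone check of that identity, not part of the commit (the file name and sample values are illustrative only):

// check_leaky_relu_identity.cc -- standalone sanity check, not part of the
// commit. Verifies (x > 0 ? x : alpha * x) == alpha * x + (1 - alpha) * relu(x).
#include <cassert>
#include <cmath>

int main() {
  const float alpha = 0.1f;  // same attribute value the unit test below uses
  for (float x : {-2.f, -0.5f, 0.f, 0.5f, 2.f}) {
    float reference = x > 0 ? x : alpha * x;    // canonical leaky relu
    float relu = x > 0 ? x : 0.f;               // the Activation(kRELU) layer
    float decomposed = alpha * x                // the first Scale layer
                       + (1.f - alpha) * relu;  // second Scale + kSUM
    assert(std::fabs(reference - decomposed) < 1e-6f);
  }
  return 0;
}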
@@ -0,0 +1,48 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(leaky_relu_op, test_leaky_relu) {
  std::unordered_set<std::string> parameters;
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclOutputVar("leaky_relu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description.
  framework::OpDesc desc;
  desc.SetType("leaky_relu");
  desc.SetInput("X", {"leaky_relu_input"});
  desc.SetOutput("Out", {"leaky_relu_out"});

  desc.SetAttr("alpha", 0.1f);

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

USE_OP(leaky_relu);
@@ -1,3 +1,3 @@
 nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
-  DEPS enforce device_context)
+  DEPS enforce tensorrt_engine)
@@ -0,0 +1,201 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <mkldnn/include/mkldnn.hpp>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#include "paddle/fluid/platform/mkldnn_helper.h"

#include "paddle/fluid/operators/math/jit_kernel.h"
#include "xbyak.h"
#include "xbyak_util.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;

static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
  std::transform(format.begin(), format.end(), format.begin(), ::tolower);

  if (!format.compare("nchw")) {
    return memory::format::nchw;
  } else if (!format.compare("nchw16c")) {
    return memory::format::nChw16c;
  } else if (!format.compare("nchw8c")) {
    return memory::format::nChw8c;
  } else if (!format.compare("nhwc")) {
    return memory::format::nhwc;
  } else {
    return memory::format::any;
  }
}

static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                             framework::Tensor* tensor, const char* attribute) {
  if (ctx.op().HasAttr(attribute)) {
    auto format_as_string = ctx.Attr<std::string>(attribute);
    auto format = StringToMKLDNNFormat(format_as_string);
    if (format != memory::format::any) {
      tensor->set_format(format);
    }
  }
}

template <typename T>
static void ReorderInput(framework::Tensor* tensor,
                         const platform::Place& place,
                         const mkldnn::engine& engine, bool isFourDim) {
  using platform::to_void_cast;
  auto dims = paddle::framework::vectorize2int(tensor->dims());
  framework::Tensor out_tensor;
  out_tensor.Resize(tensor->dims());
  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
  out_tensor.set_layout(tensor->layout());
  mkldnn::memory input_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
      to_void_cast<T>(tensor->data<T>())};
  mkldnn::memory output_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
  platform::Reorder(input_memory, output_memory);
  tensor->ShareDataWith(out_tensor);
}

template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    int axis = ctx.Attr<int>("axis");
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* z = ctx.Output<Tensor>("Out");
    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* z_data = z->mutable_data<T>(ctx.GetPlace());

    auto x_dims = x->dims();
    auto y_dims_untrimmed = y->dims();
    auto x_int_dims = paddle::framework::vectorize2int(x_dims);

    UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
    UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");

    Xbyak::util::Cpu cpu;
    const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
    const bool are_dims_divisible = !(x_int_dims[1] % 16);
    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
    const bool is_y_format_correct = y->format() == memory::format::nc;
    if (is_x_format_correct && is_y_format_correct && are_dims_divisible &&
        is_avx512_enabled) {
      int pre, n, post;
      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);

      if (post == 1) {
        PADDLE_THROW("Not implemented when post is 1");
      } else {
        // Just check whether it works for SE-ResNeXt.
        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");

        int n = x_dims[0];
        int c = x_dims[1];
        int h = x_dims[2];
        int w = x_dims[3];

        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
                       "Y should be in nc format");

        constexpr int simd_width = 16;
        int C = c / simd_width;  // number of 16-channel blocks in nChw16c

        const auto& multiply =
            math::jitkernel::KernelPool::Instance()
                .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);

#pragma omp parallel for collapse(2)
        for (int ni = 0; ni < n; ni++) {
          for (int ci = 0; ci < C; ci++) {
            auto ptr_x =
                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;

            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
            auto ptr_z =
                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;

            multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
          }
        }
      }

      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    } else {
      // Fall back to the naive version:
      const bool are_inputs_in_same_format = x->format() == y->format();
      const bool is_x_nchw = x->format() == memory::format::nchw;
      const bool is_x_nc = x->format() == memory::format::nc;
      const bool is_y_nchw = y->format() == memory::format::nchw;
      const bool is_y_nc = y->format() == memory::format::nc;
      if (!are_inputs_in_same_format) {
        using platform::MKLDNNDeviceContext;
        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
        const auto& mkldnn_engine = dev_ctx.GetEngine();
        if (!(is_x_nchw || is_x_nc))
          ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
                          x->dims().size() == 4);
        if (!(is_y_nchw || is_y_nc))
          ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
                          y->dims().size() == 4);
      }

      auto mul_func = [](T a, T b) -> T { return a * b; };

      TransformFunctor<decltype(mul_func), T,
                       paddle::platform::CPUDeviceContext, T>
          functor(
              x, y, z,
              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
              mul_func);

      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                     "Axis should be in range [0, x_dims)");

      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      // View x as (pre, n, post), where y supplies the n middle elements.
      int pre, n, post;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

      if (post == 1) {
        functor.RunRowWise(n, pre);
      } else {
        functor.RunMidWise(n, pre, post);
      }
      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::ElementwiseMulMKLDNNKernel<float>);
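For intuition about the fast path's pointer arithmetic: X is laid out as nChw16c (channels grouped in blocks of 16, with the 16 lanes innermost), Y as plain nc, and each block of h * w * 16 x-values is multiplied by the 16 matching y-values, broadcast over the spatial positions. Below is a scalar reference sketch of that indexing, assuming the same n/c/h/w meanings as the kernel above; the helper and its name are hypothetical and it is not the commit's JIT kernel:

// Scalar reference for the nChw16c * nc multiply, for illustration only.
static void NaiveMulNChw16cNC(const float* x, const float* y, float* z,
                              int n, int c, int h, int w) {
  const int simd_width = 16;     // channel block size of nChw16c
  const int C = c / simd_width;  // number of channel blocks
  for (int ni = 0; ni < n; ni++) {
    for (int ci = 0; ci < C; ci++) {
      for (int hw = 0; hw < h * w; hw++) {
        for (int i = 0; i < simd_width; i++) {
          // x and z offsets: sample, then channel block, then spatial, then lane.
          int xz_off = ((ni * C + ci) * h * w + hw) * simd_width + i;
          // y is plain nc: sample, then channel.
          int y_off = ni * c + ci * simd_width + i;
          z[xz_off] = x[xz_off] * y[y_off];
        }
      }
    }
  }
}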