commit 8580b7a130

@@ -0,0 +1,80 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * PRelu converter from fluid to TensorRT.
 */
class PReluOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    int input_num = op_desc.Input("X").size();
    PADDLE_ENFORCE(input_num == 1);
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    // Get output
    size_t output_num = op_desc.Output("Out").size();
    PADDLE_ENFORCE(output_num == 1);
    // Get attrs
    std::string mode = boost::get<std::string>(op_desc.GetAttr("mode"));
    // Get the alpha weights
    auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
    PADDLE_ENFORCE_NOT_NULL(alpha_var);
    auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
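
    // Copy alpha to the GPU: the plugin reads alpha_.get().values as a
    // device pointer inside enqueue(), so the data must live in device
    // memory for the engine's lifetime.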
    platform::CUDAPlace place;
    std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
        new framework::LoDTensor());
    alpha_tensor_device->Resize(alpha_tensor->dims());
    TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
    float* alpha_data = alpha_tensor_device->mutable_data<float>(place);

    // Transform alpha to TensorRTEngine::Weight
    TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
                                    static_cast<void*>(alpha_data),
                                    alpha_tensor_device->numel());
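    // PRelu is lowered to a custom plugin here; the builtin TensorRT
    // activation layers of this era (ReLU/sigmoid/tanh) cannot express the
    // per-channel or per-element alpha modes.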
    PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
    nvinfer1::IPluginLayer* layer =
        engine_->AddPlugin(&input, input_num, plugin);
    // Keep the alpha tensor alive to avoid releasing its memory
    engine_->weight_map[op_desc.Input("Alpha")[0]] =
        std::move(alpha_tensor_device);

    std::string layer_name = "prelu (Output: ";
    auto output_name = op_desc.Output("Out")[0];
    layer->getOutput(0)->setName(output_name.c_str());
    engine_->SetITensor(output_name, layer->getOutput(0));
    layer_name += output_name;
    if (test_mode) {
      engine_->DeclareOutput(output_name);
    }
    layer->setName((layer_name + ")").c_str());
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter);
@@ -0,0 +1,94 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {
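
// The converter supports three alpha-broadcast modes, exercised below:
//   "channel": one alpha per channel, alpha shape (C, 1, 1)
//   "element": one alpha per element, alpha shape (N, C, H, W)
//   "all":     a single scalar alpha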
TEST(prelu_op, test_channel_wise) {
  std::unordered_set<std::string> parameters({"prelu_alpha"});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
  validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("prelu");
  desc.SetInput("X", {"prelu_input"});
  desc.SetInput("Alpha", {"prelu_alpha"});
  desc.SetOutput("Out", {"prelu_out"});

  desc.SetAttr("mode", std::string("channel"));

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

TEST(prelu_op, test_element_wise) {
  std::unordered_set<std::string> parameters({"prelu_alpha"});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
  validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("prelu");
  desc.SetInput("X", {"prelu_input"});
  desc.SetInput("Alpha", {"prelu_alpha"});
  desc.SetOutput("Out", {"prelu_out"});

  desc.SetAttr("mode", std::string("element"));

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

TEST(prelu_op, test_scalar) {
  std::unordered_set<std::string> parameters({"prelu_alpha"});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
  validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("prelu");
  desc.SetInput("X", {"prelu_input"});
  desc.SetInput("Alpha", {"prelu_alpha"});
  desc.SetOutput("Out", {"prelu_out"});

  desc.SetAttr("mode", std::string("all"));

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// USE_OP(prelu);
USE_CPU_ONLY_OP(prelu);
@@ -1 +1 @@
-nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
+nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce)
@@ -0,0 +1,131 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
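
// Launch geometry shared by the kernels below: one thread block per
// (batch, channel) plane; the threads of a block stride over the spatial
// positions of that plane.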
__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
                                       float *output, int channel,
                                       size_t spatial_size) {
  size_t offset = blockIdx.x * spatial_size;
  const float *in = input + offset;
  float *out = output + offset;
  float scale = alpha[blockIdx.x % channel];

  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
    float x = in[i];
    out[i] = (x > 0) ? x : scale * x;
  }
}

__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
                                       float *output, size_t spatial_size) {
  size_t offset = blockIdx.x * spatial_size;
  const float *in = input + offset;
  const float *scale = alpha + offset;
  float *out = output + offset;

  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
    float x = in[i];
    out[i] = (x > 0) ? x : scale[i] * x;
  }
}

__global__ void PReluScalarKernel(const float *input, const float *alpha,
                                  float *output, size_t spatial_size) {
  size_t offset = blockIdx.x * spatial_size;
  const float *in = input + offset;
  float scale = *alpha;
  float *out = output + offset;

  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
    float x = in[i];
    out[i] = (x > 0) ? x : scale * x;
  }
}
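
// The launchers below assume batch_size * channels fits in a single 1-D
// grid; CHECK_LT enforces this against the 65535-block limit.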
static inline void PReluChannelWise(cudaStream_t stream, const float *input,
                                    const float *alpha, float *output,
                                    int batch_size,
                                    const nvinfer1::Dims &dims) {
  size_t unroll = batch_size * dims.d[0];
  size_t spatial_size = dims.d[1] * dims.d[2];
  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, dims.d[0], spatial_size);
}

static inline void PReluElementWise(cudaStream_t stream, const float *input,
                                    const float *alpha, float *output,
                                    int batch_size,
                                    const nvinfer1::Dims &dims) {
  size_t unroll = batch_size * dims.d[0];
  size_t spatial_size = dims.d[1] * dims.d[2];
  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, spatial_size);
}

static inline void PReluScalar(cudaStream_t stream, const float *input,
                               const float *alpha, float *output,
                               int batch_size, const nvinfer1::Dims &dims) {
  size_t unroll = batch_size * dims.d[0];
  size_t spatial_size = dims.d[1] * dims.d[2];
  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, spatial_size);
}

nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const &input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}
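
// Per the IPlugin contract, enqueue() returns 0 on success and nonzero on
// failure; any pending CUDA error is folded into the return code.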
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
                         void **outputs, void *workspace, cudaStream_t stream) {
  // Input dims is CHW.
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
  const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
  float *output = reinterpret_cast<float **>(outputs)[0];
  if (mode_ == "channel") {
    PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
  } else if (mode_ == "element") {
    PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
  } else {
    PReluScalar(stream, input, alpha, output, batchSize, input_dims);
  }
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
@@ -0,0 +1,68 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class PReluPlugin : public PluginTensorRT {
  TensorRTEngine::Weight alpha_;
  std::string mode_;

 protected:
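  // Serialization is stubbed out for now: the size is reported as 0 and
  // serialize()/the deserializing constructor are no-ops, so a plan that
  // contains this plugin cannot be round-tripped yet.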
  size_t getSerializationSize() override {
    // return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
    return 0;
  }

  // TRT calls this function when it needs to serialize the engine
  // configuration. It should not be called by users.
  void serialize(void *buffer) override {
    // serializeBase(buffer);
    // SerializeValue(&buffer, alpha_);
    // SerializeValue(&buffer, mode_);
  }

 public:
  PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
      : alpha_(alpha), mode_(mode) {}

  // Used for TensorRT deserialization.
  // It should not be called by users.
  PReluPlugin(void const *serialData, size_t serialLength) {
    // deserializeBase(serialData, serialLength);
    // DeserializeValue(&serialData, &serialLength, &alpha_);
    // DeserializeValue(&serialData, &serialLength, &mode_);
  }

  PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }

  const char *getPluginType() const override { return "prelu"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
@@ -0,0 +1,195 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/warpctc_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7001
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedCTCLossDescriptor = platform::ScopedCTCLossDescriptor;
using DataLayout = platform::DataLayout;

template <typename DeviceContext, typename T>
class CudnnCTCKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // ===================== Code copied from warpctc =========================
    auto* logits = ctx.Input<LoDTensor>("Logits");
    auto* label = ctx.Input<LoDTensor>("Label");
    auto* warpctc_grad = ctx.Output<LoDTensor>("WarpCTCGrad");
    auto* loss = ctx.Output<LoDTensor>("Loss");

    const size_t level = 0;

    auto logits_lod = framework::ToAbsOffset(logits->lod());
    auto logits_dims = logits->dims();
    PADDLE_ENFORCE_EQ(logits_dims[0],
                      static_cast<int64_t>(logits_lod[level].back()),
                      "The first dimension of Input(Logits) should be equal to "
                      "the sum of all sequences' lengths.");

    auto label_lod = framework::ToAbsOffset(label->lod());
    auto label_dims = label->dims();
    PADDLE_ENFORCE_EQ(
        label_dims[0], label->numel(),
        "The width of each timestep in Input(Label) should be 1.");

    const size_t num_sequences = logits_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
                      "The number of sequences of Input(Logits) should be "
                      "equal to that of Input(Label).");
    PADDLE_ENFORCE_LE(num_sequences, 256,
                      "The labelLengths must be less than 256 for the cudnn "
                      "call.");

    const size_t sequence_width = logits->numel() / logits_dims[0];
    auto loss_dims =
        framework::make_ddim({static_cast<int64_t>(num_sequences), 1});

    // NOTE: cudnn takes softmax input, so compute softmax first, then pad.
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    LoDTensor softmax_logits;
    softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
    softmax_logits.set_lod(logits_lod);
    int rank = logits->dims().size();
    Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
    Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
    math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);

    // CTC needs sequence data stored in transposed padding format:
    // logits and grad use padded data of layout 'TNC', where
    // T: max_sequence_length
    // N: batch_size (num_sequences)
    // C: width
    LoDTensor warpctc_logits;
    const size_t max_sequence_length =
        math::MaximumSequenceLength(logits_lod[level]);
    auto warpctc_logits_dims =
        framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                              static_cast<int64_t>(num_sequences),
                              static_cast<int64_t>(sequence_width)});
    warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());

    LoDTensor cpu_pad_value;
    T* pad_value_data =
        cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
    *pad_value_data = static_cast<T>(0);
    LoDTensor pad_value;
    if (platform::is_cpu_place(ctx.GetPlace())) {
      pad_value = cpu_pad_value;
    } else {
      TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
    }

    math::PaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), softmax_logits,
        &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
        math::kLengthBatchWidth);
    const T* warpctc_logits_data = warpctc_logits.data<T>();

    std::vector<int> warpctc_label_lengths(num_sequences);
    std::vector<int> warpctc_logits_lengths(num_sequences);

    for (size_t i = 0; i < num_sequences; ++i) {
      warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
      warpctc_logits_lengths[i] =
          logits_lod[level][i + 1] - logits_lod[level][i];
    }

    T* warpctc_grad_data =
        warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());

    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), warpctc_grad,
        static_cast<T>(0));
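
    // cudnnCTCLoss reads the label values and length arrays from host
    // memory, hence the copy of Label back to CPUPlace here.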
    Tensor warpctc_label;
    TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
    const int* warpctc_label_data = warpctc_label.data<int>();
    // ========================================================================

    ScopedTensorDescriptor logits_desc;
    ScopedTensorDescriptor grad_desc;
    ScopedCTCLossDescriptor ctcloss_desc;
    // The layout here has no effect.
    DataLayout layout = DataLayout::kNCHW;

    auto cu_logits_desc = logits_desc.descriptor<T>(
        layout, framework::vectorize2int(warpctc_logits.dims()));
    auto cu_grad_desc = grad_desc.descriptor<T>(
        layout, framework::vectorize2int(warpctc_grad->dims()));
    auto cu_ctcloss_desc = ctcloss_desc.descriptor<T>();

    auto handle = dev_ctx.cudnn_handle();
    size_t workspace_size;
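
    // Query the required scratch size first; the workspace handle then
    // supplies a buffer of that size when running the actual loss kernel.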
    CUDNN_ENFORCE(platform::dynload::cudnnGetCTCLossWorkspaceSize(
        handle, cu_logits_desc, cu_grad_desc, warpctc_label_data,
        warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
        CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));

    T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());

    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    auto cudnn_func = [&](void* cudnn_workspace) {
      CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss(
          handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data,
          warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
          loss_data, cu_grad_desc, warpctc_grad_data,
          CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace,
          workspace_size));
    };
    workspace_handle.RunFunc(cudnn_func, workspace_size);
  }
};

template <typename DeviceContext, typename T>
class CudnnCTCGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
    auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
    const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));

    logits_grad->mutable_data<T>(ctx.GetPlace());
    bool norm_by_times = ctx.Attr<bool>("norm_by_times");
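    // The gradient cached in the forward pass is padded in TNC layout;
    // unpad it back to the LoD layout of Logits, then scale each sequence
    // by the incoming d(Loss).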
    math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), *warpctc_grad,
        logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);

    const T* loss_grad_data = loss_grad->data<T>();
    math::ScaleLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), loss_grad_data,
        logits_grad);
  }
};

#endif
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#if CUDNN_VERSION >= 7001
REGISTER_OP_KERNEL(
    warpctc, CUDNN, plat::CUDAPlace,
    ops::CudnnCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_KERNEL(
    warpctc_grad, CUDNN, plat::CUDAPlace,
    ops::CudnnCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
#endif