@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cuda_fp16.h>
+#include <algorithm>
 #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
 
 namespace paddle {
@@ -19,6 +21,52 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 
+// copied from operators::math::SplitFunctor
+template <typename T>
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int* out_cols,
+                            int out_cols_size, T** outputs_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int curr_segment = 0;
+  int curr_offset = out_cols[0];
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset = out_cols[curr_segment + 1];
+    while (curr_col_offset <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+      curr_col_offset = out_cols[curr_segment + 1];
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs_data[curr_segment];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * segment_width + local_col] =
+            input_data[tid_y * in_col + tid_x];
+    }
+  }
+}
+
+template <typename T>
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int fixed_out_col,
+                            T** outputs_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / fixed_out_col;
+    int in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = outputs_data[split];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * fixed_out_col + in_offset] =
+            input_data[tid_y * in_col + tid_x];
+    }
+  }
+}
+
 nvinfer1::Dims SplitPlugin::getOutputDimensions(
     int index, const nvinfer1::Dims* input_dims, int num_inputs) {
   PADDLE_ENFORCE_EQ(num_inputs, 1);
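The variable-width SplitKernel above maps each input column to (output index, local column) by scanning the prefix sums in out_cols, advancing curr_segment only forward as tid_x grows. A minimal host-side sketch of that mapping, not part of the patch; the output widths 2, 3 and 1 are made-up example values:

// Host-side sketch (not part of the patch): the column-to-segment mapping that
// the variable-width SplitKernel performs per thread. The output widths
// {2, 3, 1} below are made-up example values.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> out_cols = {0, 2, 5, 6};  // prefix sums of output widths
  const int in_col = out_cols.back();        // total input columns

  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (int tid_x = 0; tid_x < in_col; ++tid_x) {  // kernel: grid-stride loop
    int curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {  // advance to the segment owning tid_x
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }
    int local_col = tid_x - curr_offset;  // column inside that output
    std::printf("input col %d -> output %d, local col %d\n", tid_x,
                curr_segment, local_col);
  }
  return 0;
}

The fixed-width overload below it skips the search entirely and computes the owning output with a division, which is why Split() dispatches to it when all outputs have the same length.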
@@ -31,48 +79,95 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(
 
 int SplitPlugin::initialize() {
   PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS);
-
+  // notice input dims is [C, H, W]
+  nvinfer1::Dims dims = this->getInputDims(0);
+  outer_rows_ = 1;
+  inner_cols_ = 1;
+  for (int i = 0; i < axis_; ++i) {
+    outer_rows_ *= dims.d[i];
+  }
+  for (int i = axis_ + 1; i < dims.nbDims; ++i) {
+    inner_cols_ *= dims.d[i];
+  }
+  same_shape_ = true;
   std::vector<int> segment_offsets(1, 0);
   for (int i = 0; i < this->getNbOutputs(); ++i) {
-    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
+    if (output_length_[i] != output_length_[0]) {
+      same_shape_ = false;
+    }
+    segment_offsets.push_back(segment_offsets.back() +
+                              output_length_[i] * inner_cols_);
   }
-  segment_offsets_ = segment_offsets;
-  nvinfer1::Dims dims = this->getInputDims(0);
-  nx_ = 1;
-  for (int i = dims.nbDims - 1; i > axis_; --i) {
-    nx_ *= dims.d[i];
-  }
-  ny_ = dims.d[axis_];
-  nz_ = 1;
-  for (int i = axis_ - 1; i >= 0; --i) {
-    nz_ *= dims.d[i];
-  }
+  inner_cols_ *= dims.d[axis_];
+  d_segment_offsets_ = segment_offsets;
+  segment_offsets_ = std::move(segment_offsets);
+  d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
   return 0;
 }
 
+template <typename T>
+inline void Split(cudaStream_t stream, const bool same_shape,
+                  const int outer_rows, const int inner_cols,
+                  const std::vector<int>& segment_offsets,
+                  const int* d_segment_offsets, const T* input, T** outputs) {
+  const int kThreadsPerBlock = 1024;
+  const int kMaxBlocks = 65535;
+  int block_cols = kThreadsPerBlock;
+  if (inner_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+    block_cols = ((inner_cols + 31) >> 5) << 5;
+  }
+  int block_rows = kThreadsPerBlock / block_cols;
+  dim3 block_size = dim3(block_cols, block_rows, 1);
+
+  int grid_cols =
+      std::min((inner_cols + block_cols - 1) / block_cols, kMaxBlocks);
+  int grid_rows =
+      std::min(kMaxBlocks / grid_cols, std::max(outer_rows / block_rows, 1));
+  dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+  if (same_shape) {
+    SplitKernel<<<grid_size, block_size, 0, stream>>>(
+        input, outer_rows, inner_cols, segment_offsets[1], outputs);
+  } else {
+    SplitKernel<<<grid_size, block_size, 0, stream>>>(
+        input, outer_rows, inner_cols, d_segment_offsets,
+        static_cast<int>(segment_offsets.size()), outputs);
+  }
+}
+
 int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream) {
-  auto const& input_dims = this->getInputDims(0);
-  int input_size = 0;
-  float const* idata = reinterpret_cast<float const*>(inputs[0]);
-  float** odatas = reinterpret_cast<float**>(outputs);
-
-  // kernel impl here.
-  int inputBatchOffset = nx_ * ny_ * nz_;
-  for (size_t i = 0; i < this->getNbOutputs(); i++) {
-    for (size_t j = 0; j < batchSize; j++) {
-      cudaMemcpyAsync(
-          odatas[i] +
-              j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
-                  sizeof(float),
-          inputs[0] +
-              (inputBatchOffset * j + segment_offsets_[i] * nx_) *
-                  sizeof(float),
-          (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
-          cudaMemcpyDeviceToDevice, stream);
-    }
-  }
+  float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
+  if (axis_ == -1 && this->getNbOutputs() < 10) {
+    float** output_ptrs = reinterpret_cast<float**>(outputs);
+    int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
+                             ? sizeof(float)
+                             : sizeof(__half);
+    for (int i = 0; i < this->getNbOutputs(); ++i) {
+      PADDLE_ENFORCE(
+          cudaMemcpyAsync(
+              output_ptrs[i], input_ptr + segment_offsets_[i],
+              (segment_offsets_[i + 1] - segment_offsets_[i]) * data_type_size,
+              cudaMemcpyDeviceToDevice, stream) == cudaSuccess);
+    }
+  } else {
+    outer_rows_ *= batchSize;
+    const int* d_segment_offsets_ptr =
+        thrust::raw_pointer_cast(&d_segment_offsets_[0]);
+    float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
+    PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, outputs,
+                                   this->getNbOutputs() * sizeof(float*),
+                                   cudaMemcpyHostToDevice,
+                                   stream) == cudaSuccess);
+    if (this->getDataType() == nvinfer1::DataType::kFLOAT) {
+      Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
+            d_segment_offsets_ptr, input_ptr, output_ptrs);
+    } else {
+      Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
+            d_segment_offsets_ptr, (__half*)input_ptr,  // NOLINT
+            (__half**)output_ptrs);                     // NOLINT
+    }
+  }
 
   return cudaGetLastError() != cudaSuccess;
 }
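For the kernel path, initialize() flattens the input to an [outer_rows_ x inner_cols_] matrix whose columns are cut at segment_offsets_, and Split() rounds the block shape so one block holds at most 1024 threads. A stand-alone sketch of that arithmetic, not part of the patch; the input shape [6, 4, 5], axis 0 and output lengths {2, 4} are made-up example values:

// Stand-alone sketch (not part of the patch): the shape bookkeeping from
// SplitPlugin::initialize() and the block rounding from Split(), on the CPU.
// The shape [6, 4, 5] (CHW), axis 0 and output lengths {2, 4} are made-up
// example values.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> dims = {6, 4, 5};        // input dims, no batch dim
  const int axis = 0;                             // split along C
  const std::vector<int> output_length = {2, 4};  // 6 channels -> 2 + 4

  int outer_rows = 1, inner_cols = 1;
  for (int i = 0; i < axis; ++i) outer_rows *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i)
    inner_cols *= dims[i];

  std::vector<int> segment_offsets(1, 0);
  for (int len : output_length)
    segment_offsets.push_back(segment_offsets.back() + len * inner_cols);
  inner_cols *= dims[axis];  // kernel view: [outer_rows x inner_cols] matrix

  std::printf("outer_rows=%d inner_cols=%d offsets={%d, %d, %d}\n", outer_rows,
              inner_cols, segment_offsets[0], segment_offsets[1],
              segment_offsets[2]);  // -> outer_rows=1 inner_cols=120 {0,40,120}

  // Block rounding as in Split(): at most 1024 threads, x rounded up to 32.
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (inner_cols < kThreadsPerBlock) {
    block_cols = ((inner_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  std::printf("block = (%d, %d)\n", block_cols, block_rows);  // -> (128, 8)
  return 0;
}

On top of this, enqueue() keeps two paths: a plain per-output cudaMemcpyAsync when axis_ == -1 and each output is a contiguous slice of the input, and the SplitKernel launch for the general case, with the segment offsets and output pointers staged on the device via d_segment_offsets_ and d_output_ptrs_.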