|
|
|
|
@ -59,8 +59,9 @@ __global__ void SliceKernel(int num, int dims, const T *input,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Builds a static-shape slice plugin from explicit slice parameters.
//
// starts / ends / axes are stored as-is in the corresponding members.
// with_fp16 is recorded in with_fp16_, which supportsFormat() consults to
// decide whether half-precision tensors are accepted.
//
// with_fp16_ is assigned in the body rather than in the initializer list
// (presumably because it is not a direct member of this class -- confirm
// against the class declaration).
SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
                         std::vector<int> axes, bool with_fp16)
    : starts_(starts), ends_(ends), axes_(axes) {
  with_fp16_ = with_fp16;
  // Create the event and stream held in copy_event_ / copy_stream_.
  // NOTE(review): the cudaError_t results are not checked here.
  cudaEventCreate(&copy_event_);
  cudaStreamCreate(&copy_stream_);
}
|
|
|
|
|
@ -70,7 +71,6 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
|
|
|
|
|
DeserializeValue(&serial_data, &serial_length, &starts_);
|
|
|
|
|
DeserializeValue(&serial_data, &serial_length, &ends_);
|
|
|
|
|
DeserializeValue(&serial_data, &serial_length, &axes_);
|
|
|
|
|
DeserializeValue(&serial_data, &serial_length, &ban_fp16_);
|
|
|
|
|
cudaEventCreate(&copy_event_);
|
|
|
|
|
cudaStreamCreate(&copy_stream_);
|
|
|
|
|
}
|
|
|
|
|
@ -82,19 +82,19 @@ SlicePlugin::~SlicePlugin() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns a newly heap-allocated SlicePlugin constructed from this plugin's
// starts_, ends_, axes_ and with_fp16_ values. The raw pointer is returned
// to the caller.
SlicePlugin *SlicePlugin::clone() const {
  return new SlicePlugin(starts_, ends_, axes_, with_fp16_);
}
|
|
|
|
|
|
|
|
|
|
// Reports whether the plugin accepts the given data type / format pair.
//
// Only the kNCHW format is accepted. kFLOAT is always accepted; kHALF is
// accepted only when with_fp16_ is set. The decision is made at run time
// from with_fp16_ rather than from a compile-time macro.
bool SlicePlugin::supportsFormat(nvinfer1::DataType type,
                                 nvinfer1::PluginFormat format) const {
  if (with_fp16_) {
    return ((type == nvinfer1::DataType::kFLOAT ||
             type == nvinfer1::DataType::kHALF) &&
            (format == nvinfer1::PluginFormat::kNCHW));
  } else {
    return ((type == nvinfer1::DataType::kFLOAT) &&
            (format == nvinfer1::PluginFormat::kNCHW));
  }
}
|
|
|
|
|
|
|
|
|
|
nvinfer1::Dims SlicePlugin::getOutputDimensions(int index,
|
|
|
|
|
@ -170,20 +170,17 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
|
|
|
|
|
int blocks = (out_num + threads - 1) / threads;
|
|
|
|
|
auto input_type = getDataType();
|
|
|
|
|
if (input_type == nvinfer1::DataType::kFLOAT) {
|
|
|
|
|
VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
|
|
|
|
|
const float *input1 = static_cast<const float *>(inputs[0]);
|
|
|
|
|
float *output = static_cast<float *>(outputs[0]);
|
|
|
|
|
SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
|
|
|
|
|
out_num, num_dims, input1, offset_temp_data_, output);
|
|
|
|
|
} else if (input_type == nvinfer1::DataType::kHALF) {
|
|
|
|
|
#ifdef SUPPORTS_CUDA_FP16
|
|
|
|
|
VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
|
|
|
|
|
const half *input1 = static_cast<const half *>(inputs[0]);
|
|
|
|
|
half *output = static_cast<half *>(outputs[0]);
|
|
|
|
|
SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
|
|
|
|
|
out_num, num_dims, input1, offset_temp_data_, output);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW(platform::errors::Fatal(
|
|
|
|
|
"The cuda archs you specific should greater than 600."));
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::Fatal(
|
|
|
|
|
"The Slice TRT Plugin's input type should be float or half."));
|
|
|
|
|
@ -194,7 +191,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
|
|
|
|
|
size_t SlicePlugin::getSerializationSize() {
|
|
|
|
|
return getBaseSerializationSize() + SerializedSize(getPluginType()) +
|
|
|
|
|
SerializedSize(starts_) + SerializedSize(ends_) +
|
|
|
|
|
SerializedSize(axes_) + SerializedSize(ban_fp16_);
|
|
|
|
|
SerializedSize(axes_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SlicePlugin::serialize(void *buffer) {
|
|
|
|
|
@ -203,15 +200,15 @@ void SlicePlugin::serialize(void *buffer) {
|
|
|
|
|
SerializeValue(&buffer, starts_);
|
|
|
|
|
SerializeValue(&buffer, ends_);
|
|
|
|
|
SerializeValue(&buffer, axes_);
|
|
|
|
|
SerializeValue(&buffer, ban_fp16_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Dynamic Plugin below.
|
|
|
|
|
#if IS_TRT_VERSION_GE(6000)
|
|
|
|
|
// Builds a dynamic-shape slice plugin from explicit slice parameters.
//
// starts / ends / axes are stored as-is in the corresponding members.
// with_fp16 is recorded in with_fp16_, which supportsFormatCombination()
// consults when deciding whether half-precision input is accepted.
//
// with_fp16_ is assigned in the body rather than in the initializer list
// (presumably because it is not a direct member of this class -- confirm
// against the class declaration).
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                       std::vector<int> ends,
                                       std::vector<int> axes, bool with_fp16)
    : starts_(starts), ends_(ends), axes_(axes) {
  with_fp16_ = with_fp16;
  // Create the event and stream held in copy_event_ / copy_stream_.
  // NOTE(review): the cudaError_t results are not checked here.
  cudaEventCreate(&copy_event_);
  cudaStreamCreate(&copy_stream_);
}
|
|
|
|
|
@ -221,7 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
|
|
|
|
|
DeserializeValue(&serialData, &serialLength, &starts_);
|
|
|
|
|
DeserializeValue(&serialData, &serialLength, &ends_);
|
|
|
|
|
DeserializeValue(&serialData, &serialLength, &axes_);
|
|
|
|
|
DeserializeValue(&serialData, &serialLength, &ban_fp16_);
|
|
|
|
|
DeserializeValue(&serialData, &serialLength, &with_fp16_);
|
|
|
|
|
cudaEventCreate(&copy_event_);
|
|
|
|
|
cudaStreamCreate(&copy_stream_);
|
|
|
|
|
}
|
|
|
|
|
@ -237,7 +234,7 @@ int SlicePluginDynamic::initialize() { return 0; }
|
|
|
|
|
|
|
|
|
|
size_t SlicePluginDynamic::getSerializationSize() const {
|
|
|
|
|
size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
|
|
|
|
|
SerializedSize(axes_) + SerializedSize(ban_fp16_);
|
|
|
|
|
SerializedSize(axes_) + SerializedSize(with_fp16_);
|
|
|
|
|
|
|
|
|
|
return size;
|
|
|
|
|
}
|
|
|
|
|
@ -246,7 +243,7 @@ void SlicePluginDynamic::serialize(void *buffer) const {
|
|
|
|
|
SerializeValue(&buffer, starts_);
|
|
|
|
|
SerializeValue(&buffer, ends_);
|
|
|
|
|
SerializeValue(&buffer, axes_);
|
|
|
|
|
SerializeValue(&buffer, ban_fp16_);
|
|
|
|
|
SerializeValue(&buffer, with_fp16_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
|
|
|
|
|
@ -278,19 +275,14 @@ bool SlicePluginDynamic::supportsFormatCombination(
|
|
|
|
|
|
|
|
|
|
const nvinfer1::PluginTensorDesc &in = in_out[pos];
|
|
|
|
|
if (pos == 0) {
|
|
|
|
|
#ifdef SUPPORTS_CUDA_FP16
|
|
|
|
|
if (ban_fp16_) {
|
|
|
|
|
return (in.type == nvinfer1::DataType::kFLOAT) &&
|
|
|
|
|
(in.format == nvinfer1::TensorFormat::kLINEAR);
|
|
|
|
|
} else {
|
|
|
|
|
if (with_fp16_) {
|
|
|
|
|
return (in.type == nvinfer1::DataType::kFLOAT ||
|
|
|
|
|
in.type == nvinfer1::DataType::kHALF) &&
|
|
|
|
|
(in.format == nvinfer1::TensorFormat::kLINEAR);
|
|
|
|
|
} else {
|
|
|
|
|
return (in.type == nvinfer1::DataType::kFLOAT) &&
|
|
|
|
|
(in.format == nvinfer1::TensorFormat::kLINEAR);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
return (in.type == nvinfer1::DataType::kFLOAT) &&
|
|
|
|
|
(in.format == nvinfer1::TensorFormat::kLINEAR);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
|
|
|
|
|
// output
|
|
|
|
|
@ -362,20 +354,17 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
|
|
|
|
|
int blocks = (out_num + threads - 1) / threads;
|
|
|
|
|
auto input_type = input_desc[0].type;
|
|
|
|
|
if (input_type == nvinfer1::DataType::kFLOAT) {
|
|
|
|
|
VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
|
|
|
|
|
const float *input1 = static_cast<const float *>(inputs[0]);
|
|
|
|
|
float *output = static_cast<float *>(outputs[0]);
|
|
|
|
|
SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
|
|
|
|
|
out_num, num_dims, input1, offset_temp_data_, output);
|
|
|
|
|
} else if (input_type == nvinfer1::DataType::kHALF) {
|
|
|
|
|
#ifdef SUPPORTS_CUDA_FP16
|
|
|
|
|
VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
|
|
|
|
|
const half *input1 = static_cast<const half *>(inputs[0]);
|
|
|
|
|
half *output = static_cast<half *>(outputs[0]);
|
|
|
|
|
SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
|
|
|
|
|
out_num, num_dims, input1, offset_temp_data_, output);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW(platform::errors::Fatal(
|
|
|
|
|
"The cuda archs you specific should greater than 600."));
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::Fatal(
|
|
|
|
|
"The Slice TRT Plugin's input type should be float or half."));
|
|
|
|
|
|