|
|
|
|
@ -40,7 +40,9 @@ class SplitPlugin : public PluginTensorRT {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SplitPlugin* clone() const override {
|
|
|
|
|
return new SplitPlugin(axis_, output_length_, with_fp16_);
|
|
|
|
|
auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
|
|
|
|
|
ptr->shareData(this);
|
|
|
|
|
return ptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char* getPluginType() const override { return "split_plugin"; }
|
|
|
|
|
@ -50,6 +52,7 @@ class SplitPlugin : public PluginTensorRT {
|
|
|
|
|
int num_inputs) override;
|
|
|
|
|
|
|
|
|
|
int initialize() override;
|
|
|
|
|
void terminate() override;
|
|
|
|
|
int enqueue(int batchSize, const void* const* inputs, void** outputs,
|
|
|
|
|
void* workspace, cudaStream_t stream) override;
|
|
|
|
|
|
|
|
|
|
@ -75,6 +78,9 @@ class SplitPlugin : public PluginTensorRT {
|
|
|
|
|
std::vector<int> segment_offsets_;
|
|
|
|
|
thrust::device_vector<int> d_segment_offsets_;
|
|
|
|
|
thrust::device_vector<float*> d_output_ptrs_;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void shareData(const SplitPlugin* another);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#if IS_TRT_VERSION_GE(6000)
|
|
|
|
|
|