diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 82d8c3df4e..d6d4c2f8a2 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
   // Initialize the inference network, so that TensorRT layers can add to this
   // network.
   void InitNetwork() {
-    infer_builder_.reset(createInferBuilder(logger_));
+    infer_builder_.reset(createInferBuilder(&logger_));
     infer_network_.reset(infer_builder_->createNetwork());
   }
   // After finishing adding ops, freeze this network and creates the executation
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
index 796283d325..2b402cce60 100644
--- a/paddle/fluid/inference/tensorrt/helper.h
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
 // The following two API are implemented in TensorRT's header file, cannot load
 // from the dynamic library. So create our own implementation and directly
 // trigger the method from the dynamic library.
-static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
 
 // A logger for create TensorRT infer builder.
@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
     return *x;
   }
 
-  virtual ~NaiveLogger() override {}
+  ~NaiveLogger() override {}
 };
 
 }  // namespace tensorrt
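Reviewer note: the two factory functions now take `nvinfer1::ILogger*` instead of a non-const reference, following the Google C++ style rule (enforced by Paddle's cpplint) that a parameter the callee may write through is passed as a pointer, so the mutation is visible at the call site. A minimal sketch of the convention, with illustrative names only:

```cpp
#include <string>

// Before: nothing at the call site hints that `log` can be modified.
void AppendRef(std::string& log);   // cpplint: non-const reference parameter

// After: callers must write AppendPtr(&log), making the out-parameter explicit.
void AppendPtr(std::string* log);
```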
diff --git a/paddle/fluid/inference/tensorrt/test_tensorrt.cc b/paddle/fluid/inference/tensorrt/test_tensorrt.cc
index aed5b5e1a2..a075379857 100644
--- a/paddle/fluid/inference/tensorrt/test_tensorrt.cc
+++ b/paddle/fluid/inference/tensorrt/test_tensorrt.cc
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <cuda.h>
+#include <cuda_runtime_api.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include "NvInfer.h"
-#include "cuda.h"
-#include "cuda_runtime_api.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"
 
 namespace dy = paddle::platform::dynload;
@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {
 
 class ScopedWeights {
  public:
-  ScopedWeights(float value) : value_(value) {
+  explicit ScopedWeights(float value) : value_(value) {
     w.type = nvinfer1::DataType::kFLOAT;
     w.values = &value_;
     w.count = 1;
@@ -58,13 +58,13 @@ class ScopedWeights {
 // The following two API are implemented in TensorRT's header file, cannot load
 // from the dynamic library. So create our own implementation and directly
 // trigger the method from the dynamic library.
-nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
 
 const char* kInputTensor = "input";
@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";
 nvinfer1::IHostMemory* CreateNetwork() {
   Logger logger;
   // Create the engine.
-  nvinfer1::IBuilder* builder = createInferBuilder(logger);
+  nvinfer1::IBuilder* builder = createInferBuilder(&logger);
   ScopedWeights weights(2.);
   ScopedWeights bias(3.);
 
@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
   return model;
 }
 
-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
              float* output) {
-  const nvinfer1::ICudaEngine& engine = context.getEngine();
+  const nvinfer1::ICudaEngine& engine = context->getEngine();
   // Two binds, input and output
   ASSERT_EQ(engine.getNbBindings(), 2);
   const int input_index = engine.getBindingIndex(kInputTensor);
@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
   // Copy the input to the GPU, execute the network, and copy the output back.
   ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                                cudaMemcpyHostToDevice, stream));
-  context.enqueue(1, buffers, stream, nullptr);
+  context->enqueue(1, buffers, stream, nullptr);
   ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                                cudaMemcpyDeviceToHost, stream));
   cudaStreamSynchronize(stream);
@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
 
   // Use the model to create an engine and an execution context.
   Logger logger;
-  nvinfer1::IRuntime* runtime = createInferRuntime(logger);
+  nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
   nvinfer1::ICudaEngine* engine =
       runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
   model->destroy();
@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
   // Execute the network.
   float input = 1234;
   float output;
-  Execute(*context, &input, &output);
+  Execute(context, &input, &output);
   EXPECT_EQ(output, input * 2 + 3);
 
   // Destroy the engine.
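Reviewer note: marking the single-argument `ScopedWeights(float)` constructor `explicit` stops the compiler from silently converting a bare `float` into a `ScopedWeights`. A hedged sketch of the failure mode this prevents (hypothetical names, not part of the patch):

```cpp
class Weights {
 public:
  explicit Weights(float value) : value_(value) {}

 private:
  float value_;
};

void AddLayer(const Weights& w) { /* consume weights */ }

int main() {
  // AddLayer(2.0f);         // no longer compiles: implicit conversion gone
  AddLayer(Weights(2.0f));   // the conversion must now be spelled out
  return 0;
}
```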
diff --git a/paddle/fluid/operators/math/pooling.cc b/paddle/fluid/operators/math/pooling.cc
index 97a2e81c84..b871851798 100644
--- a/paddle/fluid/operators/math/pooling.cc
+++ b/paddle/fluid/operators/math/pooling.cc
@@ -11,8 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
 #include "paddle/fluid/operators/math/pooling.h"
+#include <algorithm>
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -27,9 +28,10 @@ template <typename PoolProcess, typename T>
 class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_process, framework::Tensor* output) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_process,
+                  framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -63,11 +65,11 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
           T ele = pool_process.initial();
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              pool_process.compute(ele, input_data[h * input_width + w]);
+              pool_process.compute(input_data[h * input_width + w], &ele);
             }
           }
           int pool_size = (hend - hstart) * (wend - wstart);
-          pool_process.finalize(ele, (static_cast<T>(pool_size)));
+          pool_process.finalize(static_cast<T>(pool_size), &ele);
           output_data[ph * output_width + pw] = ele;
         }
       }
@@ -86,13 +88,12 @@
 template <typename PoolProcess, class T>
 class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_grad_process,
-                  framework::Tensor* input_grad) {
+  void operator()(
+      const platform::CPUDeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& output, const framework::Tensor& output_grad,
+      const std::vector<int>& ksize, const std::vector<int>& strides,
+      const std::vector<int>& paddings, PoolProcess pool_grad_process,
+      framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -131,8 +132,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
                 input_data[h * input_width + w],
                 output_data[ph * output_width + pw],
                 output_grad_data[ph * output_width + pw],
-                input_grad_data[h * input_width + w],
-                static_cast<T>(scale));
+                static_cast<T>(scale),
+                input_grad_data + h * input_width + w);
             }
           }
         }
@@ -154,12 +155,11 @@
 template <class T>
 class MaxPool2dGradFunctor<platform::CPUDeviceContext, T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  framework::Tensor* input_grad) {
+  void operator()(
+      const platform::CPUDeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& output, const framework::Tensor& output_grad,
+      const std::vector<int>& ksize, const std::vector<int>& strides,
+      const std::vector<int>& paddings, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -246,9 +246,10 @@ template <typename PoolProcess, class T>
 class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_process, framework::Tensor* output) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_process,
+                  framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -293,14 +294,14 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
                 pool_process.compute(
-                    ele,
-                    input_data[(d * input_height + h) * input_width + w]);
+                    input_data[(d * input_height + h) * input_width + w],
+                    &ele);
               }
             }
           }
           int pool_size =
               (dend - dstart) * (hend - hstart) * (wend - wstart);
-          pool_process.finalize(ele, static_cast<T>(pool_size));
+          pool_process.finalize(static_cast<T>(pool_size), &ele);
           output_data[output_idx] = ele;
         }
       }
@@ -320,13 +321,12 @@
 template <typename PoolProcess, class T>
 class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_grad_process,
-                  framework::Tensor* input_grad) {
+  void operator()(
+      const platform::CPUDeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& output, const framework::Tensor& output_grad,
+      const std::vector<int>& ksize, const std::vector<int>& strides,
+      const std::vector<int>& paddings, PoolProcess pool_grad_process,
+      framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -379,8 +379,8 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
                     (pd * output_height + ph) * output_width + pw;
                 pool_grad_process.compute(
                     input_data[input_idx], output_data[output_idx],
-                    output_grad_data[output_idx],
-                    input_grad_data[input_idx], static_cast<T>(scale));
+                    output_grad_data[output_idx], static_cast<T>(scale),
+                    input_grad_data + input_idx);
               }
             }
           }
@@ -404,12 +404,11 @@
 template <class T>
 class MaxPool3dGradFunctor<platform::CPUDeviceContext, T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  framework::Tensor* input_grad) {
+  void operator()(
+      const platform::CPUDeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& output, const framework::Tensor& output_grad,
+      const std::vector<int>& ksize, const std::vector<int>& strides,
+      const std::vector<int>& paddings, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -510,9 +509,10 @@ template <typename T1, typename T2>
 class MaxPool2dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  framework::Tensor* output, framework::Tensor* mask) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output,
+                  framework::Tensor* mask) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -576,8 +576,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& output_grad,
-                  const framework::Tensor& mask, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& mask, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad) {
     const int batch_size = input_grad->dims()[0];
     const int input_height = input_grad->dims()[2];
@@ -628,9 +629,10 @@ template <typename T1, typename T2>
 class MaxPool3dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  framework::Tensor* output, framework::Tensor* mask) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output,
+                  framework::Tensor* mask) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -708,8 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& output_grad,
-                  const framework::Tensor& mask, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& mask, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad) {
     const int batch_size = input_grad->dims()[0];
     const int input_depth = input_grad->dims()[2];
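Reviewer note: throughout pooling.cc the `ksize`/`strides`/`paddings` vectors become `const std::vector<int>&`. Besides documenting that the functors never modify them, a const reference also lets call sites pass temporaries, which a non-const lvalue reference rejects. An illustrative sketch (not Paddle code):

```cpp
#include <vector>

// With std::vector<int>& this call would be ill-formed; with const& it is fine.
int Volume(const std::vector<int>& ksize) {
  int v = 1;
  for (int k : ksize) v *= k;  // read-only access is all the functors need
  return v;
}

// Volume({2, 2});  // OK: a braced temporary binds to const&
```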
diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu
index 267f8c409d..b1c76350d1 100644
--- a/paddle/fluid/operators/math/pooling.cu
+++ b/paddle/fluid/operators/math/pooling.cu
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
+#include <vector>
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
@@ -47,11 +49,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
     T ele = pool_process.initial();
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        pool_process.compute(ele, input_data[h * input_width + w]);
+        pool_process.compute(input_data[h * input_width + w], &ele);
       }
     }
     int pool_size = (hend - hstart) * (wend - wstart);
-    pool_process.finalize(ele, (static_cast<T>(pool_size)));
+    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
   }
 }
@@ -96,8 +98,8 @@ __global__ void KernelPool2DGrad(
         int pool_size = (hend - hstart) * (wend - wstart);
         int output_sub_idx = ph * output_width + pw;
         pool_process.compute(input, output_data[output_sub_idx],
-                             output_grad[output_sub_idx], gradient,
-                             static_cast<T>(1.0 / pool_size));
+                             output_grad[output_sub_idx],
+                             static_cast<T>(1.0 / pool_size), &gradient);
       }
     }
     input_grad[index] = gradient;
@@ -158,9 +160,10 @@ template <typename PoolProcess, typename T>
 class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_process, framework::Tensor* output) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_process,
+                  framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -201,9 +204,11 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_process, framework::Tensor* input_grad) {
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_process,
+                  framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -246,8 +251,10 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
@@ -340,12 +347,12 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
           pool_process.compute(
-              ele, input_data[(d * input_height + h) * input_width + w]);
+              input_data[(d * input_height + h) * input_width + w], &ele);
         }
       }
     }
     int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-    pool_process.finalize(ele, static_cast<T>(pool_size));
+    pool_process.finalize(static_cast<T>(pool_size), &ele);
     output_data[index] = ele;
   }
 }
@@ -405,8 +412,8 @@ __global__ void KernelPool3DGrad(
           int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
           int output_sub_idx = (pd * output_height + ph) * output_width + pw;
           pool_process.compute(input, output_data[output_sub_idx],
-                               output_grad[output_sub_idx], gradient,
-                               static_cast<T>(1.0 / pool_size));
+                               output_grad[output_sub_idx],
+                               static_cast<T>(1.0 / pool_size), &gradient);
         }
       }
     }
@@ -474,9 +481,10 @@ template <typename PoolProcess, class T>
 class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_process, framework::Tensor* output) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_process,
+                  framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -525,9 +533,11 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_process, framework::Tensor* input_grad) {
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_process,
+                  framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -578,8 +588,10 @@ class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
@@ -736,9 +748,10 @@ template <typename T1, typename T2>
 class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  framework::Tensor* output, framework::Tensor* mask) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output,
+                  framework::Tensor* mask) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -779,8 +792,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& output_grad,
-                  const framework::Tensor& mask, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& mask, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad) {
     const int batch_size = input_grad->dims()[0];
     const int input_channels = input_grad->dims()[1];
@@ -937,9 +951,10 @@ template <typename T1, typename T2>
 class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  framework::Tensor* output, framework::Tensor* mask) {
+                  const framework::Tensor& input, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output,
+                  framework::Tensor* mask) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -987,8 +1002,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& output_grad,
-                  const framework::Tensor& mask, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& mask, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad) {
     const int batch_size = input_grad->dims()[0];
     const int input_channels = input_grad->dims()[1];
diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h
index 74cb42f0d0..2538d739cc 100644
--- a/paddle/fluid/operators/math/pooling.h
+++ b/paddle/fluid/operators/math/pooling.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -23,8 +24,8 @@ namespace operators {
 namespace math {
 
 #define FLT_MAX \
-  __FLT_MAX__  // It might need to be placed in another file, but I'm still
-               // wondering where to put it.
+  __FLT_MAX__  // TODO(zcd) :It might need to be placed in another file, but I'm
+               // still wondering where to put it.
 
 /*
  * \brief Extracting simple operations from pooling.
@@ -40,33 +41,33 @@ template <class T>
 class MaxPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
-  DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
-  DEVICE inline void finalize(T& y, const T& pool_field) {}
+  DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
+  DEVICE inline void finalize(const T& pool_field, T* y) {}
 };
 
 template <class T>
 class AvgPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(0); }
-  DEVICE inline void compute(T& y, const T& x) { y += x; }
-  DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; }
+  DEVICE inline void compute(const T& x, T* y) { *y += x; }
+  DEVICE inline void finalize(const T& pool_field, T* y) { *y /= pool_field; }
 };
 
 template <class T>
 class MaxPoolGrad {
  public:
-  DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
-                             T scale) {
-    dx += dy * (x == y);
+  DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
+                             T* dx) {
+    *dx += dy * (x == y);
   }
 };
 
 template <class T>
 class AvgPoolGrad {
  public:
-  DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
-                             T scale) {
-    dx += (scale * dy);
+  DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
+                             T* dx) {
+    *dx += (scale * dy);
   }
 };
@@ -88,8 +89,9 @@ template <typename DeviceContext, typename PoolProcess, typename T>
 class Pool2dFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_compute,
                   framework::Tensor* output);
 };
@@ -98,9 +100,11 @@ class Pool2dGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute, framework::Tensor* input_grad);
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_compute,
+                  framework::Tensor* input_grad);
 };
 
 template <typename DeviceContext, class T>
@@ -108,8 +112,10 @@ class MaxPool2dGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad);
 };
@@ -117,8 +123,9 @@ template <typename DeviceContext, typename PoolProcess, typename T>
 class Pool3dFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_compute,
                   framework::Tensor* output);
 };
@@ -127,9 +134,11 @@ class Pool3dGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute, framework::Tensor* input_grad);
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, PoolProcess pool_compute,
+                  framework::Tensor* input_grad);
 };
 
 template <typename DeviceContext, class T>
@@ -137,8 +146,10 @@ class MaxPool3dGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad);
 };
@@ -153,8 +164,9 @@ template <typename DeviceContext, typename T1, typename T2>
 class MaxPool2dWithIndexFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, framework::Tensor* output,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output,
                   framework::Tensor* mask);
 };
@@ -163,8 +175,9 @@ class MaxPool2dWithIndexGradFunctor {
  public:
   void operator()(const DeviceContext& context,
                   const framework::Tensor& output_grad,
-                  const framework::Tensor& mask, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& mask, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad);
 };
@@ -172,8 +185,9 @@ template <typename DeviceContext, typename T1, typename T2>
 class MaxPool3dWithIndexFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, framework::Tensor* output,
+                  const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output,
                   framework::Tensor* mask);
 };
@@ -182,8 +196,9 @@ class MaxPool3dWithIndexGradFunctor {
  public:
   void operator()(const DeviceContext& context,
                   const framework::Tensor& output_grad,
-                  const framework::Tensor& mask, std::vector<int>& ksize,
-                  std::vector<int>& strides, std::vector<int>& paddings,
+                  const framework::Tensor& mask, const std::vector<int>& ksize,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
                   framework::Tensor* input_grad);
 };
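Reviewer note: the `compute`/`finalize` hooks now receive the accumulator through a `T*` out-parameter instead of a mutable reference. A self-contained host-side sketch of the calling protocol the CPU/CUDA kernels follow after this patch; `MaxPoolLike` mirrors `math::MaxPool<T>`, but all names here are illustrative:

```cpp
#include <cfloat>
#include <vector>

template <class T>
struct MaxPoolLike {
  T initial() { return static_cast<T>(-FLT_MAX); }
  // The accumulator is now an explicit out-parameter.
  void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
  void finalize(const T& pool_field, T* y) {}  // max pooling: nothing to do
};

// The shape of the loop the kernels run once per output element.
template <class T, class PoolProcess>
T PoolWindow(const std::vector<T>& window, PoolProcess pool) {
  T ele = pool.initial();
  for (const T& x : window) pool.compute(x, &ele);  // write via pointer
  pool.finalize(static_cast<T>(window.size()), &ele);
  return ele;
}

// PoolWindow<float>({1.f, 5.f, 3.f}, MaxPoolLike<float>{}) yields 5.f.
```

Each `pool.compute(x, &ele)` call now makes the write explicit, matching the pointer-out-parameter rule applied across the rest of this patch.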
diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc
index 38bd3b9975..d63c6c4ed5 100644
--- a/paddle/fluid/operators/math/sequence_padding.cc
+++ b/paddle/fluid/operators/math/sequence_padding.cc
@@ -22,7 +22,7 @@ template <typename T>
 class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor& seq, framework::Tensor& padding,
+                  const framework::LoDTensor& seq, framework::Tensor* padding,
                   bool norm_by_times) {
     auto lod = seq.lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
@@ -37,7 +37,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                       "The first dimension of LoDTensor seq should be "
                       "equal to the sum of all sequences's length.");
 
-    auto padding_dims = padding.dims();
+    auto padding_dims = padding->dims();
     PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                       "The input padding should be a 3-D Tensor of shape "
                       "[max_sequence_length, num_sequences, sequence_width].");
@@ -58,7 +58,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                       "width of sequence in LoDTensor seq.");
 
     const T* seq_data = seq.data<T>();
-    T* padding_data = padding.data<T>();
+    T* padding_data = padding->data<T>();
     for (int64_t i = 0; i < max_sequence_length; ++i) {
       for (int64_t j = 0; j < num_sequences; ++j) {
         int64_t start_pos = abs_offset_lod[level][j];
@@ -84,16 +84,16 @@ template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  framework::LoDTensor& seq, const framework::Tensor& padding,
+                  framework::LoDTensor* seq, const framework::Tensor& padding,
                   bool norm_by_times) {
-    auto lod = seq.lod();
+    auto lod = seq->lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
                       "The LoD of LoDTensor seq should not be null.");
 
     const size_t level = 0;
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
 
-    auto seq_dims = seq.dims();
+    auto seq_dims = seq->dims();
     PADDLE_ENFORCE_EQ(seq_dims[0],
                       static_cast<int64_t>(abs_offset_lod[level].back()),
                       "The first dimension of LoDTensor seq should be "
@@ -114,13 +114,13 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                       "The second dimension of Tensor padding should be "
                       "the number of sequences in LoDTensor seq.");
 
-    const int64_t sequence_width = seq.numel() / seq_dims[0];
+    const int64_t sequence_width = seq->numel() / seq_dims[0];
     PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                       "The third dimension of Tensor padding should be the "
                       "width of sequence in LoDTensor seq.");
 
     const T* padding_data = padding.data<T>();
-    T* seq_data = seq.data<T>();
+    T* seq_data = seq->data<T>();
     for (int64_t i = 0; i < num_sequences; ++i) {
       int64_t start_pos = abs_offset_lod[level][i];
       int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu
index c044e6fc32..0956a0c17d 100644
--- a/paddle/fluid/operators/math/sequence_padding.cu
+++ b/paddle/fluid/operators/math/sequence_padding.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
 #include "paddle/fluid/operators/math/sequence_padding.h"
 
 namespace paddle {
@@ -61,7 +62,7 @@ template <typename T>
 class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::LoDTensor& seq, framework::Tensor& padding,
+                  const framework::LoDTensor& seq, framework::Tensor* padding,
                   bool norm_by_times) {
     auto lod = seq.lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
@@ -76,7 +77,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                       "The first dimension of LoDTensor seq should be "
                       "equal to the sum of all sequences's length.");
 
-    auto padding_dims = padding.dims();
+    auto padding_dims = padding->dims();
     PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                       "The input padding should be a 3-D Tensor of shape "
                       "[max_sequence_length, num_sequences, sequence_width].");
@@ -97,8 +98,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                       "width of sequence in LoDTensor seq.");
 
     if (!norm_by_times && num_sequences == 1UL) {
-      TensorCopy(seq, context.GetPlace(), context, &padding);
-      padding.Resize(padding_dims);
+      TensorCopy(seq, context.GetPlace(), context, padding);
+      padding->Resize(padding_dims);
       return;
     }
 
@@ -117,7 +118,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* seq_data = seq.data<T>();
-    T* padding_data = padding.data<T>();
+    T* padding_data = padding->data<T>();
     if (norm_by_times) {
       SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, const_cast<T*>(seq_data),
@@ -136,16 +137,16 @@ template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  framework::LoDTensor& seq, const framework::Tensor& padding,
+                  framework::LoDTensor* seq, const framework::Tensor& padding,
                   bool norm_by_times) {
-    auto lod = seq.lod();
+    auto lod = seq->lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
                       "The lod of LoDTensor seq should not be null.");
 
     const size_t level = 0;
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
 
-    auto seq_dims = seq.dims();
+    auto seq_dims = seq->dims();
     PADDLE_ENFORCE_EQ(seq_dims[0],
                       static_cast<int64_t>(abs_offset_lod[level].back()),
                       "The first dimension of LoDTensor seq should be "
@@ -166,14 +167,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                       "The second dimension of Tensor padding should be "
                       "the number of sequences in LoDTensor seq.");
 
-    const int64_t sequence_width = seq.numel() / seq_dims[0];
+    const int64_t sequence_width = seq->numel() / seq_dims[0];
     PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                       "The third dimension of Tensor padding should be the "
                       "width of sequence in LoDTensor seq.");
 
     if (!norm_by_times && num_sequences == 1UL) {
-      TensorCopy(padding, context.GetPlace(), context, &seq);
-      seq.Resize(seq_dims);
+      TensorCopy(padding, context.GetPlace(), context, seq);
+      seq->Resize(seq_dims);
       return;
     }
 
@@ -192,7 +193,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* padding_data = padding.data<T>();
-    T* seq_data = seq.data<T>();
+    T* seq_data = seq->data<T>();
     if (norm_by_times) {
       SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
           const_cast<T*>(padding_data), seq_data,
diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h
index 17f044b9d6..b56e6db1eb 100644
--- a/paddle/fluid/operators/math/sequence_padding.h
+++ b/paddle/fluid/operators/math/sequence_padding.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <algorithm>
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/platform/device_context.h"
 
@@ -64,13 +65,13 @@ template <typename DeviceContext, typename T>
 class PaddingLoDTensorFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
-                  framework::Tensor& padding, bool norm_by_times);
+                  framework::Tensor* padding, bool norm_by_times);
 };
 
 template <typename DeviceContext, typename T>
 class UnpaddingLoDTensorFunctor {
  public:
-  void operator()(const DeviceContext& context, framework::LoDTensor& seq,
+  void operator()(const DeviceContext& context, framework::LoDTensor* seq,
                   const framework::Tensor& padding, bool norm_by_times);
 };
diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc
index e3d6214485..b9a1b9ae4d 100644
--- a/paddle/fluid/operators/math/sequence_padding_test.cc
+++ b/paddle/fluid/operators/math/sequence_padding_test.cc
@@ -54,12 +54,12 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
                            static_cast<int64_t>(sequence_width)});
   padding.mutable_data<T>(padding_dims, *place);
   paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, seq, padding, false);
+      *context, seq, &padding, false);
 
   seq_back.set_lod(lod);
   seq_back.mutable_data<T>(seq_dims, *place);
   paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, seq_back, padding, false);
+      *context, &seq_back, padding, false);
 
   if (paddle::platform::is_cpu_place(*place)) {
     cpu_seq_back = seq_back;
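Reviewer note: as the enforce messages above state, `padding` is a 3-D tensor of shape [max_sequence_length, num_sequences, sequence_width], while `seq` packs all time steps back to back. A sketch of the index math, under the assumption of the absolute-offset LoD the functors use (illustrative, not Paddle code):

```cpp
#include <cstdint>

// Illustrative index math only. With absolute offsets lod[] (lod[0] == 0,
// lod[j + 1] - lod[j] == length of sequence j) and sequence width w, element
// (step i, sequence j) of the padded tensor maps into the packed buffer at:
int64_t PackedOffset(const int64_t* lod, int64_t i, int64_t j, int64_t w) {
  return (lod[j] + i) * w;  // valid while i < lod[j + 1] - lod[j]
}
// Example: lod = {0, 2, 5} gives lengths {2, 3}, so the padded shape is
// [3, 2, w], and positions with i beyond a sequence's length are zero-filled.
```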
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); - VLOG(4) << "before attr"; - f::AttributeMap attrs; attrs.insert({"endpoint", std::string("127.0.0.1:0")}); attrs.insert({"Fanin", 1}); @@ -144,16 +141,19 @@ void StartServerNet(bool is_sparse) { VLOG(4) << "before init op"; listen_and_serv_op = f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); - VLOG(4) << "before run op"; + *initialized = true; listen_and_serv_op->Run(scope, place); LOG(INFO) << "server exit"; } TEST(SendRecvOp, CPUDense) { - std::thread server_thread(StartServerNet, false); - // wait server to start + std::atomic initialized{false}; + std::thread server_thread(StartServerNet, false, &initialized); + while (!initialized) { + } static_cast(listen_and_serv_op.get()) ->WaitServerReady(); + // local net f::Scope scope; p::CPUPlace place; @@ -162,9 +162,11 @@ TEST(SendRecvOp, CPUDense) { scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; - selected_port = static_cast( - listen_and_serv_op.get()) - ->GetSelectedPort(); + auto *listen_and_serv_op_ptr = + static_cast( + listen_and_serv_op.get()); + ASSERT_TRUE(listen_and_serv_op_ptr != nullptr); + selected_port = listen_and_serv_op_ptr->GetSelectedPort(); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); attrs.insert({"endpoints", std::vector({endpoint})}); attrs.insert({"epmap", std::vector({endpoint})}); @@ -191,9 +193,14 @@ TEST(SendRecvOp, CPUDense) { } TEST(SendRecvOp, CPUSparse) { - std::thread server_thread(StartServerNet, true); + std::atomic initialized; + initialized = false; + std::thread server_thread(StartServerNet, true, &initialized); + while (!initialized) { + } static_cast(listen_and_serv_op.get()) ->WaitServerReady(); + // local net f::Scope scope; p::CPUPlace place; @@ -201,9 +208,11 @@ TEST(SendRecvOp, CPUSparse) { InitSelectedRowsInScope(place, &scope); scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; - selected_port = static_cast( - listen_and_serv_op.get()) - ->GetSelectedPort(); + auto *listen_and_serv_op_ptr = + static_cast( + listen_and_serv_op.get()); + ASSERT_TRUE(listen_and_serv_op_ptr != nullptr); + selected_port = listen_and_serv_op_ptr->GetSelectedPort(); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); attrs.insert({"endpoints", std::vector({endpoint})}); attrs.insert({"epmap", std::vector({endpoint})}); diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 85131d0025..705cc894c0 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -162,7 +162,7 @@ class WarpCTCKernel : public framework::OpKernel { static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); math::PaddingLoDTensorFunctor()( - ctx.template device_context(), *logits, warpctc_logits, + ctx.template device_context(), *logits, &warpctc_logits, false); const T* warpctc_logits_data = warpctc_logits.data(); @@ -217,7 +217,7 @@ class WarpCTCGradKernel : public framework::OpKernel { logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), *logits_grad, + ctx.template device_context(), logits_grad, *warpctc_grad, norm_by_times); const T* loss_grad_data = loss_grad->data(); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index dcd711a33f..93b09ed692 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ 
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h
index dcd711a33f..93b09ed692 100644
--- a/paddle/fluid/pybind/tensor_py.h
+++ b/paddle/fluid/pybind/tensor_py.h
@@ -107,7 +107,7 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) {
     return self.data<T>()[offset];
   } else {
     std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
-    framework::TensorCopy(self, platform::CPUPlace(), dst.get());
+    framework::TensorCopySync(self, platform::CPUPlace(), dst.get());
     return dst->data<T>()[offset];
   }
 }
@@ -117,9 +117,9 @@ template <typename T>
 void TensorSetElement(framework::Tensor *self, size_t offset, T elem) {
   if (platform::is_gpu_place(self->place())) {
     std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
-    framework::TensorCopy(*self, platform::CPUPlace(), dst.get());
+    framework::TensorCopySync(*self, platform::CPUPlace(), dst.get());
     dst->data<T>()[offset] = elem;
-    framework::TensorCopy(*dst.get(), self->place(), self);
+    framework::TensorCopySync(*dst.get(), self->place(), self);
   } else if (platform::is_cpu_place(self->place())) {
     self->data<T>()[offset] = elem;
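Reviewer note: `TensorCopy` enqueues the GPU copy asynchronously on the device's stream, so `TensorGetElement`/`TensorSetElement` could touch the staging tensor before the copy had landed; `TensorCopySync` waits for completion. A sketch of the underlying race in plain CUDA (illustrative only; this is not the Paddle implementation):

```cpp
#include <cuda_runtime_api.h>

// Why a synchronous copy is needed before the host reads a device value.
float ReadDeviceScalar(const float* dev_ptr, cudaStream_t stream) {
  float host_value = 0.0f;
  // Buggy async pattern: cudaMemcpyAsync only enqueues work on the stream,
  // so returning host_value immediately would race with the copy:
  //   cudaMemcpyAsync(&host_value, dev_ptr, sizeof(float),
  //                   cudaMemcpyDeviceToHost, stream);
  //   return host_value;  // may still hold 0.0f
  // Correct: block until the copy completes (what TensorCopySync achieves).
  cudaMemcpyAsync(&host_value, dev_ptr, sizeof(float), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);
  return host_value;
}
```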