From a24691a2a9299e3ee3055aa309dc3d3749572aaa Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 31 Oct 2018 20:49:56 +0800 Subject: [PATCH 01/23] add nearest neighbor interpolation operator cpu kernel --- .../operators/nearest_neighbor_interp_op.cc | 115 ++++++++++ .../operators/nearest_neighbor_interp_op.cu | 210 ++++++++++++++++++ .../operators/nearest_neighbor_interp_op.h | 130 +++++++++++ 3 files changed, 455 insertions(+) create mode 100644 paddle/fluid/operators/nearest_neighbor_interp_op.cc create mode 100644 paddle/fluid/operators/nearest_neighbor_interp_op.cu create mode 100644 paddle/fluid/operators/nearest_neighbor_interp_op.h diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cc b/paddle/fluid/operators/nearest_neighbor_interp_op.cc new file mode 100644 index 0000000000..4e29fe5ac3 --- /dev/null +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/nearest_neighbor_interp_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class NearestNeighborInterpOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of BilinearInterOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of BilinearInterOp should not be null."); + + auto dim_x = ctx->GetInputDim("X"); // NCHW format + int out_h = ctx->Attrs().Get("out_h"); + int out_w = ctx->Attrs().Get("out_w"); + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); + + if (ctx->HasInput("OutSize")) { + auto out_size_dim = ctx->GetInputDim("OutSize"); + PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, + "OutSize's dimension size must be 1"); + PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); + } + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + } +}; + +class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of nearest neighbor interpolation, " + "This is a 4-D tensor with shape of (N x C x h x w)"); + AddInput("OutSize", + "This is a 1-D tensor with two number. " + "The first number is height and the second number is width.") + .AsDispensable(); + AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)"); + + AddAttr("out_h", "output height of bilinear interpolation op."); + AddAttr("out_w", "output width of bilinear interpolation op."); + AddComment(R"DOC( + Nearest neighbor interpolation is to perform nearest neighbor interpolation + in bot the 3rd dimention(in height direction) and the 4th dimention(in width + direction) on input tensor. + + For details, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation + )DOC"); + } +}; + +class NearestNeighborInterpOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dim_x = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), dim_x); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(nearest_neighbor_interp, ops::NearestNeighborInterpOp, + ops::NearestNeighborInterpOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(nearest_neighbor_interp_grad, + ops::NearestNeighborInterpOpGrad); +REGISTER_OP_CPU_KERNEL(nearest_neighbor_interp, + ops::NearestNeighborInterpKernel, + ops::NearestNeighborInterpKernel); +REGISTER_OP_CPU_KERNEL(nearest_neighbor_interp_grad, + ops::NearestNeighborInterpGradKernel); diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cu b/paddle/fluid/operators/nearest_neighbor_interp_op.cu new file mode 100644 index 0000000000..11002d2e1e --- /dev/null +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.cu @@ -0,0 +1,210 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/nearest_neighbor_interp_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +template +using EigenTensor = framework::EigenTensor; +using framework::Tensor; + +template +__global__ void KeBilinearInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const T ratio_h, const T ratioW) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratioW * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratioW * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + + w1lambda * in_pos[h_id * in_img_w + w_id]); + } +} + +template +__global__ void KeBilinearInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const T ratio_h, const T ratioW) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratioW * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratioW * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + const T* out_pos = &out[out_id_h * output_w + out_id_w]; + atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); + atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); + atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); + atomicAdd(&in_pos[h_id * in_img_w + w_id], + h1lambda * w1lambda * out_pos[0]); + } +} + +template +class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto* input_t = ctx.Input("X"); // float tensor + auto* output_t = ctx.Output("Out"); // float tensor + auto* input = input_t->data(); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_dims = output_t->dims(); + auto out_size_t = ctx.Input("OutSize"); + if (out_size_t != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + auto* output = output_t->mutable_data( + {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); + + int batch_size = input_t->dims()[0]; + int channels = input_t->dims()[1]; + int in_h = input_t->dims()[2]; + int in_w = input_t->dims()[3]; + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = channels * in_hw; + int out_chw = channels * out_hw; + + T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + memcpy(output, input, input_t->numel() * sizeof(T)); + } else { + int threadNum = batch_size * out_chw; + int blocks = (threadNum + 1024 - 1) / 1024; + + KeBilinearInterpFw< + T><<>>( + input, in_h, in_w, batch_size, in_chw, output, out_h, out_w, + batch_size, out_chw, channels, ratio_h, ratio_w); + } + } +}; + +template +class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_input_t = ctx.Output(framework::GradVarName("X")); + auto* d_output_t = ctx.Input(framework::GradVarName("Out")); + auto* d_output = d_output_t->data(); + auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); + + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, d_input_t, static_cast(0.0)); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + + auto out_size_t = ctx.Input("OutSize"); + if (out_size_t != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + int batch_size = d_input_t->dims()[0]; + int channels = d_input_t->dims()[1]; + int in_h = d_input_t->dims()[2]; + int in_w = d_input_t->dims()[3]; + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = channels * in_hw; + int out_chw = channels * out_hw; + + T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); + } else { + int threadNum = batch_size * out_chw; + int blocks = (threadNum + 1024 - 1) / 1024; + + KeBilinearInterpBw< + T><<>>( + d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w, + batch_size, out_chw, channels, ratio_h, ratio_w); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp, + ops::NearestNeighborInterpOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(nearest_neighborinterp_grad, + ops::NearestNeighborInterpGradOpCUDAKernel); diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.h b/paddle/fluid/operators/nearest_neighbor_interp_op.h new file mode 100644 index 0000000000..5ba12eaa7c --- /dev/null +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.h @@ -0,0 +1,130 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +using EigenTensor = framework::EigenTensor; +using Tensor = framework::Tensor; + +template +class NearestNeighborInterpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + + const int in_n = input->dims()[0]; + const int in_c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + output->mutable_data({in_n, in_c, out_h, out_w}, ctx.GetPlace()); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, output, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + auto input_t = EigenTensor::From(*input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + for (int l = 0; l < out_w; l++) { + int in_k = static_cast(round(ratio_h * k)); + int in_l = static_cast(round(ratio_w * l)); + for (int i = 0; i < in_n; i++) { // loop for batches + for (int j = 0; j < in_c; j++) { // loop for channels + output_t(i, j, k, l) = input_t(i, j, in_k, in_l); + } + } + } + } + } +}; + +template +class NearestNeighborInterpGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + + const int in_n = input_grad->dims()[0]; + const int in_c = input_grad->dims()[1]; + const int in_h = input_grad->dims()[2]; + const int in_w = input_grad->dims()[3]; + + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(*output_grad); + for (int k = 0; k < out_h; k++) { // loop for images + for (int l = 0; l < out_w; l++) { + int in_k = static_cast(round(ratio_h * k)); + int in_l = static_cast(round(ratio_w * l)); + for (int i = 0; i < in_n; i++) { // loop for batches + for (int j = 0; j < in_c; j++) { // loop for channels + input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle From 9755611938eb7f3aaa61cf8ffc66648fc6f7c801 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 1 Nov 2018 13:51:55 +0800 Subject: [PATCH 02/23] add unittest for nearest_neighbor_interp_op --- .../operators/nearest_neighbor_interp_op.cc | 2 +- .../operators/nearest_neighbor_interp_op.cu | 2 +- .../operators/nearest_neighbor_interp_op.h | 30 ++-- .../test_nearest_neighbor_interp_op.py | 158 ++++++++++++++++++ 4 files changed, 176 insertions(+), 16 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cc b/paddle/fluid/operators/nearest_neighbor_interp_op.cc index 4e29fe5ac3..b50648d617 100644 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.cc +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cu b/paddle/fluid/operators/nearest_neighbor_interp_op.cu index 11002d2e1e..16acc694ab 100644 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.cu +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.h b/paddle/fluid/operators/nearest_neighbor_interp_op.h index 5ba12eaa7c..a37cc703b1 100644 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.h +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -37,12 +37,12 @@ class NearestNeighborInterpKernel : public framework::OpKernel { out_w = out_size_data[1]; } - const int in_n = input->dims()[0]; - const int in_c = input->dims()[1]; + const int n = input->dims()[0]; + const int c = input->dims()[1]; const int in_h = input->dims()[2]; const int in_w = input->dims()[3]; - output->mutable_data({in_n, in_c, out_h, out_w}, ctx.GetPlace()); + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); auto& device_ctx = ctx.template device_context(); math::SetConstant zero; @@ -61,11 +61,11 @@ class NearestNeighborInterpKernel : public framework::OpKernel { auto input_t = EigenTensor::From(*input); auto output_t = EigenTensor::From(*output); for (int k = 0; k < out_h; k++) { // loop for images + int in_k = static_cast(round(ratio_h * k)); for (int l = 0; l < out_w; l++) { - int in_k = static_cast(round(ratio_h * k)); int in_l = static_cast(round(ratio_w * l)); - for (int i = 0; i < in_n; i++) { // loop for batches - for (int j = 0; j < in_c; j++) { // loop for channels + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels output_t(i, j, k, l) = input_t(i, j, in_k, in_l); } } @@ -78,6 +78,7 @@ template class NearestNeighborInterpGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Out")); @@ -90,11 +91,12 @@ class NearestNeighborInterpGradKernel : public framework::OpKernel { out_w = out_size_data[1]; } - const int in_n = input_grad->dims()[0]; - const int in_c = input_grad->dims()[1]; - const int in_h = input_grad->dims()[2]; - const int in_w = input_grad->dims()[3]; + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); auto& device_ctx = ctx.template device_context(); math::SetConstant zero; @@ -113,11 +115,11 @@ class NearestNeighborInterpGradKernel : public framework::OpKernel { auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(*output_grad); for (int k = 0; k < out_h; k++) { // loop for images + int in_k = static_cast(round(ratio_h * k)); for (int l = 0; l < out_w; l++) { - int in_k = static_cast(round(ratio_h * k)); int in_l = static_cast(round(ratio_w * l)); - for (int i = 0; i < in_n; i++) { // loop for batches - for (int j = 0; j < in_c; j++) { // loop for channels + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); } } diff --git a/python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py new file mode 100644 index 0000000000..78ad3b98f5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py @@ -0,0 +1,158 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + n, c, in_h, in_w = X.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + ratio_h = (in_h - 1.0) / (out_h - 1.0) + if out_w > 1: + ratio_w = (in_w - 1.0) / (out_w - 1.0) + + out = np.zeros((n, c, out_h, out_w)) + for i in range(out_h): + in_i = int(round(ratio_h * i)) + for j in range(out_w): + in_j = int(round(ratio_w * j)) + out[:, :, i, j] = X[:, :, in_i, in_j] + + return out.astype(X.dtype) + + +class TestBilinearInterpOp(OpTest): + def setUp(self): + self.out_size = None + self.init_test_case() + self.op_type = "nearest_neighbor_interp" + input_np = np.random.random(self.input_shape).astype("float32") + output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, + self.out_size) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.input_shape = [2, 3, 4, 4] + self.out_h = 2 + self.out_w = 2 + self.out_size = np.array([3, 3]).astype("int32") + + +class TestCase1(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + + +class TestCase2(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + + +class TestCase3(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + + +class TestCase4(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestCase5(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestCase6(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + +class TestBilinearInterpOpUint8(OpTest): + def setUp(self): + self.out_size = None + self.init_test_case() + self.op_type = "nearest_neighbor_interp" + input_np = np.random.randint( + low=0, high=256, size=self.input_shape).astype("uint8") + output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, + self.out_size) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(place=core.CPUPlace(), atol=1) + + def init_test_case(self): + self.input_shape = [1, 3, 9, 6] + self.out_h = 10 + self.out_w = 9 + + +class TestCase1Uint8(TestBilinearInterpOpUint8): + def init_test_case(self): + self.input_shape = [2, 3, 128, 64] + self.out_h = 120 + self.out_w = 50 + + +class TestCase2Uint8(TestBilinearInterpOpUint8): + def init_test_case(self): + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.out_size = np.array([6, 15]).astype("int32") + + +if __name__ == "__main__": + unittest.main() From df4a3544aa50ccd6d62c724fe53683e0ad2ac483 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 1 Nov 2018 16:28:51 +0800 Subject: [PATCH 03/23] nearest neighbor interp add cuda kernel. test=develop --- paddle/fluid/API.spec | 1 + .../operators/nearest_neighbor_interp_op.cc | 9 +- .../operators/nearest_neighbor_interp_op.cu | 149 ++++++++---------- python/paddle/fluid/layers/nn.py | 35 +++- .../fluid/tests/unittests/test_layers.py | 10 ++ 5 files changed, 111 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 3bbe7c2b8c..65436cdd98 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -121,6 +121,7 @@ paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], vararg paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cc b/paddle/fluid/operators/nearest_neighbor_interp_op.cc index b50648d617..54c0198255 100644 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.cc +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.cc @@ -25,9 +25,9 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of BilinearInterOp should not be null."); + "Input(X) of NearestNeighborInterOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of BilinearInterOp should not be null."); + "Output(Out) of NearestNeighborInterOp should not be null."); auto dim_x = ctx->GetInputDim("X"); // NCHW format int out_h = ctx->Attrs().Get("out_h"); @@ -64,8 +64,9 @@ class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)"); - AddAttr("out_h", "output height of bilinear interpolation op."); - AddAttr("out_w", "output width of bilinear interpolation op."); + AddAttr("out_h", + "output height of nearest neighbor interpolation op."); + AddAttr("out_w", "output width of nearest neighbor interpolation op."); AddComment(R"DOC( Nearest neighbor interpolation is to perform nearest neighbor interpolation in bot the 3rd dimention(in height direction) and the 4th dimention(in width diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cu b/paddle/fluid/operators/nearest_neighbor_interp_op.cu index 16acc694ab..d403f772fc 100644 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.cu +++ b/paddle/fluid/operators/nearest_neighbor_interp_op.cu @@ -15,17 +15,14 @@ namespace paddle { namespace operators { -template -using EigenTensor = framework::EigenTensor; using framework::Tensor; template -__global__ void KeBilinearInterpFw( +__global__ void KeNearestNeighborInterpFw( const T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratioW) { + const size_t num_channels, const T ratio_h, const T ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < nthreads) { @@ -36,34 +33,22 @@ __global__ void KeBilinearInterpFw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; + int in_img_idy = static_cast(round(ratio_h * out_img_idy)); int out_img_idx = tid % out_img_w; - int in_img_idx = ratioW * out_img_idx; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratioW * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; - - const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - - // bilinear interpolation - out[out_id_h * output_w + out_id_w] = - h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + - h1lambda * (w2lambda * in_pos[h_id * in_img_w] + - w1lambda * in_pos[h_id * in_img_w + w_id]); + int in_img_idx = static_cast(round(ratio_w * out_img_idx)); + + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; } } template -__global__ void KeBilinearInterpBw( +__global__ void KeNearestNeighborInterpBw( T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, const T* out, const size_t out_img_h, const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratioW) { + const size_t num_channels, const T ratio_h, const T ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < nthreads) { @@ -74,25 +59,15 @@ __global__ void KeBilinearInterpBw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; + int in_img_idy = static_cast(round(ratio_h * out_img_idy)); int out_img_idx = tid % out_img_w; - int in_img_idx = ratioW * out_img_idx; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratioW * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; + int in_img_idx = static_cast(round(ratio_w * out_img_idx)); T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idy * in_img_w + in_img_idx]; - const T* out_pos = &out[out_id_h * output_w + out_id_w]; - atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); - atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); - atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); - atomicAdd(&in_pos[h_id * in_img_w + w_id], - h1lambda * w1lambda * out_pos[0]); + const T out_pos = out[out_id_h * output_w + out_id_w]; + atomicAdd(in_pos, out_pos); } } @@ -102,48 +77,49 @@ class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto* input_t = ctx.Input("X"); // float tensor - auto* output_t = ctx.Output("Out"); // float tensor - auto* input = input_t->data(); + auto* input = ctx.Input("X"); // float tensor + auto* output = ctx.Output("Out"); // float tensor + auto* input_data = input->data(); int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); - auto out_dims = output_t->dims(); - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { Tensor sizes; - framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); auto size_data = sizes.data(); out_h = size_data[0]; out_w = size_data[1]; } - auto* output = output_t->mutable_data( - {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); - int batch_size = input_t->dims()[0]; - int channels = input_t->dims()[1]; - int in_h = input_t->dims()[2]; - int in_w = input_t->dims()[3]; + int n = input->dims()[0]; + int c = input->dims()[1]; + int in_h = input->dims()[2]; + int in_w = input->dims()[3]; + + auto* output_data = + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); int in_hw = in_h * in_w; int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; + int in_chw = c * in_hw; + int out_chw = c * out_hw; T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; if (in_h == out_h && in_w == out_w) { - memcpy(output, input, input_t->numel() * sizeof(T)); - } else { - int threadNum = batch_size * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeBilinearInterpFw< - T><<>>( - input, in_h, in_w, batch_size, in_chw, output, out_h, out_w, - batch_size, out_chw, channels, ratio_h, ratio_w); + memcpy(output_data, input_data, input->numel() * sizeof(T)); + return; } + + int threadNum = n * out_chw; + int blocks = (threadNum + 1024 - 1) / 1024; + + KeNearestNeighborInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w); } }; @@ -151,52 +127,53 @@ template class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_output = d_output_t->data(); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* output_grad_data = output_grad->data(); + auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); auto& device_ctx = ctx.template device_context(); math::SetConstant zero; - zero(device_ctx, d_input_t, static_cast(0.0)); + zero(device_ctx, input_grad, static_cast(0.0)); int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { Tensor sizes; - framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); auto size_data = sizes.data(); out_h = size_data[0]; out_w = size_data[1]; } - int batch_size = d_input_t->dims()[0]; - int channels = d_input_t->dims()[1]; - int in_h = d_input_t->dims()[2]; - int in_w = d_input_t->dims()[3]; + int n = input_grad->dims()[0]; + int c = input_grad->dims()[1]; + int in_h = input_grad->dims()[2]; + int in_w = input_grad->dims()[3]; int in_hw = in_h * in_w; int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; + int in_chw = c * in_hw; + int out_chw = c * out_hw; T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; if (in_h == out_h && in_w == out_w) { - memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); - } else { - int threadNum = batch_size * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeBilinearInterpBw< - T><<>>( - d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w, - batch_size, out_chw, channels, ratio_h, ratio_w); + memcpy(input_grad, output_grad, input_grad->numel() * sizeof(T)); + return; } + + int threadNum = n * out_chw; + int blocks = (threadNum + 1024 - 1) / 1024; + + KeNearestNeighborInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, + n, out_chw, c, ratio_h, ratio_w); } }; @@ -206,5 +183,5 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp, ops::NearestNeighborInterpOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(nearest_neighborinterp_grad, +REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp_grad, ops::NearestNeighborInterpGradOpCUDAKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 110e6d5ab2..f4d8308e7c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -101,6 +101,7 @@ __all__ = [ 'image_resize', 'image_resize_short', 'resize_bilinear', + 'resize_nearest', 'gather', 'scatter', 'sequence_scatter', @@ -5584,6 +5585,7 @@ def image_resize(input, Supporting resample methods: 'BILINEAR' : Bilinear interpolation + 'NEAREST' : Nearest neighbor interpolation Args: input (Variable): The input tensor of image resize layer, @@ -5610,13 +5612,17 @@ def image_resize(input, out = fluid.layers.image_resize(input, out_shape=[12, 12]) """ - resample_methods = {'BILINEAR': 'bilinear_interp'} + resample_methods = { + 'BILINEAR': 'bilinear_interp', + 'NEAREST': 'nearest_neighbor_interp' + } if resample not in resample_methods: raise ValueError( - "The 'resample' of image_resize can only be 'BILINEAR' currently.") + "The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently." + ) if out_shape is None and scale is None: raise ValueError("One of out_shape and scale must not be None") - helper = LayerHelper('bilinear_interp', **locals()) + helper = LayerHelper(resample_methods[resample], **locals()) dtype = helper.input_dtype() def _is_list_or_turple_(data): @@ -5672,6 +5678,29 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): return image_resize(input, out_shape, scale, name, 'BILINEAR') +@templatedoc(op_type="bilinear_interp") +def resize_nearest(input, out_shape=None, scale=None, name=None): + """ + ${comment} + + Args: + input(${x_type}): ${x_comment}. + + out_shape(${out_size_type}): ${out_size_comment}. + + scale(float|None): The multiplier for the input height or width. At + least one of out_shape or scale must be set. And out_shape has + a higher priority than scale. Default: None. + + name(str|None): The output variable name. + + Returns: + ${out_comment}. + """ + + return image_resize(input, out_shape, scale, name, 'NEAREST') + + def image_resize_short(input, out_short_len, resample='BILINEAR'): """ Resize a batch of images. The short edge of input images will be diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 50de468dba..0390938901 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -485,6 +485,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_resize_bilinear(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[3, 9, 6], dtype="float32") + output = layers.resize_nearest(x, out_shape=[12, 12]) + self.assertIsNotNone(output) + output = layers.resize_nearest(x, scale=3) + self.assertIsNotNone(output) + print(str(program)) + def test_polygon_box_transform(self): program = Program() with program_guard(program): From e46f03e19dd59a7ca36d4a1491f57d4bafd06741 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 5 Nov 2018 13:20:16 +0800 Subject: [PATCH 04/23] Add TESTING_DEBUG_MODE to support debug info in daily CI test test=develop --- paddle/scripts/paddle_build.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d7676f89ab..2f5fef36c4 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -367,7 +367,12 @@ function run_test() { Running unit tests ... ======================================== EOF - ctest --output-on-failure + if [ ${TESTING_DEBUG_MODE:-OFF} == "ON" ] ; then + ctest -V + else + ctest --output-on-failure + fi + # make install should also be test when unittest make install -j `nproc` pip install ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl From 34bfae243a7d4ba7085bf9c337a65f6464fe2c5c Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 2 Nov 2018 12:09:55 +0800 Subject: [PATCH 05/23] Add Interpolate operation. test=develop --- paddle/fluid/operators/bilinear_interp_op.cc | 116 ------- paddle/fluid/operators/bilinear_interp_op.cu | 207 ------------ paddle/fluid/operators/bilinear_interp_op.h | 163 ---------- ...eighbor_interp_op.cc => interpolate_op.cc} | 70 +++-- paddle/fluid/operators/interpolate_op.cu | 286 +++++++++++++++++ paddle/fluid/operators/interpolate_op.h | 236 ++++++++++++++ .../operators/nearest_neighbor_interp_op.cu | 187 ----------- .../operators/nearest_neighbor_interp_op.h | 132 -------- python/paddle/fluid/layers/nn.py | 20 +- .../unittests/test_bilinear_interp_op.py | 168 ---------- .../tests/unittests/test_interpolate_op.py | 294 ++++++++++++++++++ .../fluid/tests/unittests/test_layers.py | 2 +- .../test_nearest_neighbor_interp_op.py | 158 ---------- 13 files changed, 874 insertions(+), 1165 deletions(-) delete mode 100644 paddle/fluid/operators/bilinear_interp_op.cc delete mode 100644 paddle/fluid/operators/bilinear_interp_op.cu delete mode 100644 paddle/fluid/operators/bilinear_interp_op.h rename paddle/fluid/operators/{nearest_neighbor_interp_op.cc => interpolate_op.cc} (55%) create mode 100644 paddle/fluid/operators/interpolate_op.cu create mode 100644 paddle/fluid/operators/interpolate_op.h delete mode 100644 paddle/fluid/operators/nearest_neighbor_interp_op.cu delete mode 100644 paddle/fluid/operators/nearest_neighbor_interp_op.h delete mode 100644 python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_interpolate_op.py delete mode 100644 python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py diff --git a/paddle/fluid/operators/bilinear_interp_op.cc b/paddle/fluid/operators/bilinear_interp_op.cc deleted file mode 100644 index 2dc3399da1..0000000000 --- a/paddle/fluid/operators/bilinear_interp_op.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/bilinear_interp_op.h" -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -class BilinearInterpOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of BilinearInterOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of BilinearInterOp should not be null."); - - auto dim_x = ctx->GetInputDim("X"); // NCHW format - int out_h = ctx->Attrs().Get("out_h"); - int out_w = ctx->Attrs().Get("out_w"); - PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); - - if (ctx->HasInput("OutSize")) { - auto out_size_dim = ctx->GetInputDim("OutSize"); - PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, - "OutSize's dimension size must be 1"); - PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); - } - std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); - ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); - } -}; - -class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "The input tensor of bilinear interpolation, " - "This is a 4-D tensor with shape of (N x C x h x w)"); - AddInput("OutSize", - "This is a 1-D tensor with two number. " - "The first number is height and the second number is width.") - .AsDispensable(); - AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)"); - - AddAttr("out_h", "output height of bilinear interpolation op."); - AddAttr("out_w", "output width of bilinear interpolation op."); - AddComment(R"DOC( - Bilinear interpolation is an extension of linear interpolation for - interpolating functions of two variables (e.g. H-direction and - W-direction in this op) on a rectilinear 2D grid. - - The key idea is to perform linear interpolation first in one - direction, and then again in the other direction. - - For details, please refer to Wikipedia: - https://en.wikipedia.org/wiki/Bilinear_interpolation - )DOC"); - } -}; - -class BilinearInterpOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto dim_x = ctx->GetInputDim("X"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), dim_x); - } - } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp, - ops::BilinearInterpOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad); -REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel, - ops::BilinearInterpKernel); -REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, - ops::BilinearInterpGradKernel); diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu deleted file mode 100644 index 4c19715384..0000000000 --- a/paddle/fluid/operators/bilinear_interp_op.cu +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/bilinear_interp_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -__global__ void KeBilinearInterpFw( - const T* in, const size_t in_img_h, const size_t in_img_w, - const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratioW) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; - int channel_id = out_id_w / out_img_size; - - int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; - - int out_img_idx = tid % out_img_w; - int in_img_idx = ratioW * out_img_idx; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratioW * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; - - const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - - // bilinear interpolation - out[out_id_h * output_w + out_id_w] = - h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + - h1lambda * (w2lambda * in_pos[h_id * in_img_w] + - w1lambda * in_pos[h_id * in_img_w + w_id]); - } -} - -template -__global__ void KeBilinearInterpBw( - T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, - const size_t input_w, const T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratioW) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; - int channel_id = out_id_w / out_img_size; - - int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; - - int out_img_idx = tid % out_img_w; - int in_img_idx = ratioW * out_img_idx; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratioW * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; - - T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - const T* out_pos = &out[out_id_h * output_w + out_id_w]; - atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); - atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); - atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); - atomicAdd(&in_pos[h_id * in_img_w + w_id], - h1lambda * w1lambda * out_pos[0]); - } -} - -template -class BilinearInterpOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); - auto* input_t = ctx.Input("X"); // float tensor - auto* output_t = ctx.Output("Out"); // float tensor - auto* input = input_t->data(); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_dims = output_t->dims(); - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - auto* output = output_t->mutable_data( - {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); - - int batch_size = input_t->dims()[0]; - int channels = input_t->dims()[1]; - int in_h = input_t->dims()[2]; - int in_w = input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(output, input, input_t->numel() * sizeof(T)); - } else { - int threadNum = batch_size * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeBilinearInterpFw< - T><<>>( - input, in_h, in_w, batch_size, in_chw, output, out_h, out_w, - batch_size, out_chw, channels, ratio_h, ratio_w); - } - } -}; - -template -class BilinearInterpGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_output = d_output_t->data(); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); - - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, d_input_t, static_cast(0.0)); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - - int batch_size = d_input_t->dims()[0]; - int channels = d_input_t->dims()[1]; - int in_h = d_input_t->dims()[2]; - int in_w = d_input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); - } else { - int threadNum = batch_size * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeBilinearInterpBw< - T><<>>( - d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w, - batch_size, out_chw, channels, ratio_h, ratio_w); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(bilinear_interp, - ops::BilinearInterpOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad, - ops::BilinearInterpGradOpCUDAKernel); diff --git a/paddle/fluid/operators/bilinear_interp_op.h b/paddle/fluid/operators/bilinear_interp_op.h deleted file mode 100644 index 70847cb8c1..0000000000 --- a/paddle/fluid/operators/bilinear_interp_op.h +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class BilinearInterpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input_t = ctx.Input("X"); // float tensor - auto* output_t = ctx.Output("Out"); // float tensor - auto out_dims = output_t->dims(); - auto* input = input_t->data(); - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - auto out_size_data = out_size_t->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - auto* output = output_t->mutable_data( - {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); - int batch_size = input_t->dims()[0]; - int channels = input_t->dims()[1]; - int in_h = input_t->dims()[2]; - int in_w = input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(output, input, input_t->numel() * sizeof(T)); - } else { - for (int k = 0; k < batch_size; ++k) { // loop for batches - for (int i = 0; i < out_h; ++i) { // loop for images - int h = ratio_h * i; - int hid = (h < in_h - 1) ? 1 : 0; - float h1lambda = ratio_h * i - h; - float h2lambda = 1.f - h1lambda; - - for (int j = 0; j < out_w; ++j) { - int w = ratio_w * j; - int wid = (w < in_w - 1) ? 1 : 0; - float w1lambda = ratio_w * j - w; - float w2lambda = 1.f - w1lambda; - // calculate four position for bilinear interpolation - const T* in_pos = &input[k * in_chw + h * in_w + w]; - T* out_pos = &output[k * out_chw + i * out_w + j]; - - for (int c = 0; c < channels; ++c) { // loop for channels - // bilinear interpolation - out_pos[0] = static_cast( - h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) + - h1lambda * (w2lambda * in_pos[hid * in_w] + - w1lambda * in_pos[hid * in_w + wid])); - in_pos += in_hw; - out_pos += out_hw; - } - } - } - } - } - } -}; - -template -class BilinearInterpGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_output = d_output_t->data(); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, d_input_t, static_cast(0.0)); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - auto out_size_data = out_size_t->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - - int batch_size = d_input_t->dims()[0]; - int channels = d_input_t->dims()[1]; - int in_h = d_input_t->dims()[2]; - int in_w = d_input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); - } else { - for (int k = 0; k < batch_size; ++k) { // loop for batches - for (int i = 0; i < out_h; ++i) { // loop for images - int h = ratio_h * i; - int hid = (h < in_h - 1) ? 1 : 0; - float h1lambda = ratio_h * i - h; - float h2lambda = 1 - h1lambda; - - for (int j = 0; j < out_w; ++j) { - int w = ratio_w * j; - int wid = (w < in_w - 1) ? 1 : 0; - float w1lambda = ratio_w * j - w; - float w2lambda = 1 - w1lambda; - T* in_pos = &d_input[k * in_chw + h * in_w + w]; - const T* out_pos = &d_output[k * out_chw + i * out_w + j]; - - for (int c = 0; c < channels; ++c) { // loop for channels - in_pos[0] += static_cast(h2lambda * w2lambda * out_pos[0]); - in_pos[wid] += static_cast(h2lambda * w1lambda * out_pos[0]); - in_pos[hid * in_w] += - static_cast(h1lambda * w2lambda * out_pos[0]); - in_pos[hid * in_w + wid] += - static_cast(h1lambda * w1lambda * out_pos[0]); - in_pos += in_hw; - out_pos += out_hw; - } - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cc b/paddle/fluid/operators/interpolate_op.cc similarity index 55% rename from paddle/fluid/operators/nearest_neighbor_interp_op.cc rename to paddle/fluid/operators/interpolate_op.cc index 54c0198255..e2000d0e0c 100644 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -9,7 +9,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/nearest_neighbor_interp_op.h" +#include "paddle/fluid/operators/interpolate_op.h" +#include #include #include "paddle/fluid/framework/op_registry.h" @@ -18,16 +19,21 @@ namespace operators { using framework::Tensor; -class NearestNeighborInterpOp : public framework::OperatorWithKernel { +class InterpolateOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of NearestNeighborInterOp should not be null."); + "Input(X) of InterpolateOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of NearestNeighborInterOp should not be null."); + "Output(Out) of InterpolationOp should not be null."); + + auto interp_method = ctx->Attrs().Get("interp_method"); + PADDLE_ENFORCE( + "bilinear" == interp_method || "nearest" == interp_method, + "Interpolation method can only be \"bilinear\" or \"nearest\"."); auto dim_x = ctx->GetInputDim("X"); // NCHW format int out_h = ctx->Attrs().Get("out_h"); @@ -52,33 +58,53 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel { } }; -class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker { +class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "The input tensor of nearest neighbor interpolation, " - "This is a 4-D tensor with shape of (N x C x h x w)"); + "The input tensor of interpolate operator, " + "This is a 4-D tensor with shape of [N, C, H, w]."); AddInput("OutSize", - "This is a 1-D tensor with two number. " + "This is a 1-D tensor with two numbers to specify output size. " "The first number is height and the second number is width.") .AsDispensable(); - AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)"); + AddOutput("Out", + "The output tensor of interpolate operator, " + "This is a 4-D tensor with shape of [N, C, H, W]."); - AddAttr("out_h", - "output height of nearest neighbor interpolation op."); - AddAttr("out_w", "output width of nearest neighbor interpolation op."); + AddAttr("out_h", "output height of interpolate op."); + AddAttr("out_w", "output width of interpolate op."); + AddAttr( + "interp_method", + "(string), interpolation method, can be \"bilinear\" for " + "bilinear interpolation and \"nearest\" for nearest " + "neighbor interpolation."); AddComment(R"DOC( + This operator samples input X to given output shape by using specified + interpolation method, the interpolation methods can be \"nearest\" + for nearest neighbor interpolation and \"bilinear\" for bilinear + interpolation. + Nearest neighbor interpolation is to perform nearest neighbor interpolation in bot the 3rd dimention(in height direction) and the 4th dimention(in width direction) on input tensor. - For details, please refer to Wikipedia: + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation + + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation )DOC"); } }; -class NearestNeighborInterpOpGrad : public framework::OperatorWithKernel { +class InterpolateOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -104,13 +130,11 @@ class NearestNeighborInterpOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(nearest_neighbor_interp, ops::NearestNeighborInterpOp, - ops::NearestNeighborInterpOpMaker, +REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(nearest_neighbor_interp_grad, - ops::NearestNeighborInterpOpGrad); -REGISTER_OP_CPU_KERNEL(nearest_neighbor_interp, - ops::NearestNeighborInterpKernel, - ops::NearestNeighborInterpKernel); -REGISTER_OP_CPU_KERNEL(nearest_neighbor_interp_grad, - ops::NearestNeighborInterpGradKernel); +REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad); +REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel, + ops::InterpolateKernel, + ops::InterpolateKernel); +REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel, + ops::InterpolateGradKernel); diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu new file mode 100644 index 0000000000..3b9ece4830 --- /dev/null +++ b/paddle/fluid/operators/interpolate_op.cu @@ -0,0 +1,286 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/fluid/operators/interpolate_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +__global__ void KeNearestNeighborInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = static_cast(ratio_h * out_img_idy + 0.5); + + int out_img_idx = tid % out_img_w; + int in_img_idx = static_cast(ratio_w * out_img_idx + 0.5); + + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } +} + +template +__global__ void KeNearestNeighborInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = static_cast(ratio_h * out_img_idy + 0.5); + + int out_img_idx = tid % out_img_w; + int in_img_idx = static_cast(ratio_w * out_img_idx + 0.5); + + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + const T out_pos = out[out_id_h * output_w + out_id_w]; + platform::CudaAtomicAdd(in_pos, out_pos); + } +} + +template +__global__ void KeBilinearInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratio_w * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + + w1lambda * in_pos[h_id * in_img_w + w_id]); + } +} + +template +__global__ void KeBilinearInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const T ratio_h, const T ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratio_w * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + const T* out_pos = &out[out_id_h * output_w + out_id_w]; + platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[h_id * in_img_w], + h1lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id], + h1lambda * w1lambda * out_pos[0]); + } +} + +template +class InterpolateOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + auto* input_data = input->data(); + + auto interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + int n = input->dims()[0]; + int c = input->dims()[1]; + int in_h = input->dims()[2]; + int in_w = input->dims()[3]; + + auto* output_data = + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + + int threadNum = n * out_chw; + int blocks = (threadNum + 1024 - 1) / 1024; + + if ("nearest" == interp_method) { + KeNearestNeighborInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w); + } else if ("bilinear" == interp_method) { + KeBilinearInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w); + } + } +}; + +template +class InterpolateGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* output_grad_data = output_grad->data(); + auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + auto interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + int n = input_grad->dims()[0]; + int c = input_grad->dims()[1]; + int in_h = input_grad->dims()[2]; + int in_w = input_grad->dims()[3]; + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + int threadNum = n * out_chw; + int blocks = (threadNum + 1024 - 1) / 1024; + + if ("nearest" == interp_method) { + KeNearestNeighborInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, + out_w, n, out_chw, c, ratio_h, ratio_w); + } else if ("bilinear" == interp_method) { + KeBilinearInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, + out_w, n, out_chw, c, ratio_h, ratio_w); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(interpolate_grad, + ops::InterpolateGradOpCUDAKernel, + ops::InterpolateGradOpCUDAKernel); diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h new file mode 100644 index 0000000000..7fdb3e1f5a --- /dev/null +++ b/paddle/fluid/operators/interpolate_op.h @@ -0,0 +1,236 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +using EigenTensor = framework::EigenTensor; +using Tensor = framework::Tensor; + +template +static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, + const float ratio_h, const float ratio_w, + const int n, const int c, + const int out_h, const int out_w) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = static_cast(ratio_h * k + 0.5); + + for (int l = 0; l < out_w; l++) { + int in_l = static_cast(ratio_w * l + 0.5); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + output_t(i, j, k, l) = input_t(i, j, in_k, in_l); + } + } + } + } +} + +template +static void BilinearInterpolation(const Tensor& input, Tensor* output, + const float ratio_h, const float ratio_w, + const int in_h, const int in_w, const int n, + const int c, const int out_h, + const int out_w) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + int y_n = static_cast(ratio_h * k); + int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); + float d_n = ratio_h * k - y_n; + float d_s = 1.f - d_n; + + for (int l = 0; l < out_w; l++) { + int x_w = static_cast(ratio_w * l); + int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); + float d_w = ratio_w * l - x_w; + float d_e = 1.f - d_w; + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + // bilinear interpolation + output_t(i, j, k, l) = input_t(i, j, y_n, x_w) * d_s * d_e + + input_t(i, j, y_s, x_w) * d_n * d_e + + input_t(i, j, y_n, x_e) * d_s * d_w + + input_t(i, j, y_s, x_e) * d_n * d_w; + } + } + } + } +} + +template +static void NearestNeighborInterpolateGrad(const Tensor& output_grad, + Tensor* input_grad, + const float ratio_h, + const float ratio_w, const int n, + const int c, const int out_h, + const int out_w) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = static_cast(ratio_h * k + 0.5); + + for (int l = 0; l < out_w; l++) { + int in_l = static_cast(ratio_w * l + 0.5); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); + } + } + } + } +} + +template +static void BilinearInterpolationGrad(const Tensor& output_grad, + Tensor* input_grad, const float ratio_h, + const float ratio_w, const int in_h, + const int in_w, const int n, const int c, + const int out_h, const int out_w) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + for (int k = 0; k < out_h; k++) { // loop for images + int y_n = static_cast(ratio_h * k); + int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); + float d_n = ratio_h * k - y_n; + float d_s = 1.f - d_n; + + for (int l = 0; l < out_w; l++) { + int x_w = static_cast(ratio_w * l); + int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); + float d_w = ratio_w * l - x_w; + float d_e = 1.f - d_w; + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + // bilinear interpolation grad + const T grad = output_grad_t(i, j, k, l); + input_grad_t(i, j, y_n, x_w) += static_cast(grad * d_s * d_e); + input_grad_t(i, j, y_s, x_w) += static_cast(grad * d_n * d_e); + input_grad_t(i, j, y_n, x_e) += static_cast(grad * d_s * d_w); + input_grad_t(i, j, y_s, x_e) += static_cast(grad * d_n * d_w); + } + } + } + } +} + +template +class InterpolateKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + std::string interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, output, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if ("bilinear" == interp_method) { + BilinearInterpolation(*input, output, ratio_h, ratio_w, in_h, in_w, n, + c, out_h, out_w); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolate(*input, output, ratio_h, ratio_w, n, c, + out_h, out_w); + } + } +}; + +template +class InterpolateGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + + std::string interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if ("bilinear" == interp_method) { + BilinearInterpolationGrad(*output_grad, input_grad, ratio_h, ratio_w, + in_h, in_w, n, c, out_h, out_w); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolateGrad(*output_grad, input_grad, ratio_h, + ratio_w, n, c, out_h, out_w); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.cu b/paddle/fluid/operators/nearest_neighbor_interp_op.cu deleted file mode 100644 index d403f772fc..0000000000 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.cu +++ /dev/null @@ -1,187 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/nearest_neighbor_interp_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -__global__ void KeNearestNeighborInterpFw( - const T* in, const size_t in_img_h, const size_t in_img_w, - const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratio_w) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; - int channel_id = out_id_w / out_img_size; - - int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = static_cast(round(ratio_h * out_img_idy)); - - int out_img_idx = tid % out_img_w; - int in_img_idx = static_cast(round(ratio_w * out_img_idx)); - - out[tid] = in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - } -} - -template -__global__ void KeNearestNeighborInterpBw( - T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, - const size_t input_w, const T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratio_w) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; - int channel_id = out_id_w / out_img_size; - - int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = static_cast(round(ratio_h * out_img_idy)); - - int out_img_idx = tid % out_img_w; - int in_img_idx = static_cast(round(ratio_w * out_img_idx)); - - T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - const T out_pos = out[out_id_h * output_w + out_id_w]; - atomicAdd(in_pos, out_pos); - } -} - -template -class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); - auto* input = ctx.Input("X"); // float tensor - auto* output = ctx.Output("Out"); // float tensor - auto* input_data = input->data(); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - - int n = input->dims()[0]; - int c = input->dims()[1]; - int in_h = input->dims()[2]; - int in_w = input->dims()[3]; - - auto* output_data = - output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = c * in_hw; - int out_chw = c * out_hw; - - T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(output_data, input_data, input->numel() * sizeof(T)); - return; - } - - int threadNum = n * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeNearestNeighborInterpFw< - T><<>>( - input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, - out_chw, c, ratio_h, ratio_w); - } -}; - -template -class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* output_grad = ctx.Input(framework::GradVarName("Out")); - auto* output_grad_data = output_grad->data(); - auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, input_grad, static_cast(0.0)); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - - int n = input_grad->dims()[0]; - int c = input_grad->dims()[1]; - int in_h = input_grad->dims()[2]; - int in_w = input_grad->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = c * in_hw; - int out_chw = c * out_hw; - - T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(input_grad, output_grad, input_grad->numel() * sizeof(T)); - return; - } - - int threadNum = n * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeNearestNeighborInterpBw< - T><<>>( - input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, - n, out_chw, c, ratio_h, ratio_w); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp, - ops::NearestNeighborInterpOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp_grad, - ops::NearestNeighborInterpGradOpCUDAKernel); diff --git a/paddle/fluid/operators/nearest_neighbor_interp_op.h b/paddle/fluid/operators/nearest_neighbor_interp_op.h deleted file mode 100644 index a37cc703b1..0000000000 --- a/paddle/fluid/operators/nearest_neighbor_interp_op.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" - -namespace paddle { -namespace operators { - -template -using EigenTensor = framework::EigenTensor; -using Tensor = framework::Tensor; - -template -class NearestNeighborInterpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - auto out_size_data = out_size->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; - - output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, output, static_cast(0.0)); - - if (in_h == out_h && in_w == out_w) { - framework::TensorCopy(*input, ctx.GetPlace(), output); - return; - } - - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - auto input_t = EigenTensor::From(*input); - auto output_t = EigenTensor::From(*output); - for (int k = 0; k < out_h; k++) { // loop for images - int in_k = static_cast(round(ratio_h * k)); - for (int l = 0; l < out_w; l++) { - int in_l = static_cast(round(ratio_w * l)); - for (int i = 0; i < n; i++) { // loop for batches - for (int j = 0; j < c; j++) { // loop for channels - output_t(i, j, k, l) = input_t(i, j, in_k, in_l); - } - } - } - } - } -}; - -template -class NearestNeighborInterpGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* output_grad = ctx.Input(framework::GradVarName("Out")); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - auto out_size_data = out_size->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; - - input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, input_grad, static_cast(0.0)); - - if (in_h == out_h && in_w == out_w) { - framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); - return; - } - - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - auto input_grad_t = EigenTensor::From(*input_grad); - auto output_grad_t = EigenTensor::From(*output_grad); - for (int k = 0; k < out_h; k++) { // loop for images - int in_k = static_cast(round(ratio_h * k)); - for (int l = 0; l < out_w; l++) { - int in_l = static_cast(round(ratio_w * l)); - for (int i = 0; i < n; i++) { // loop for batches - for (int j = 0; j < c; j++) { // loop for channels - input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f4d8308e7c..3b65825b96 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5612,17 +5612,14 @@ def image_resize(input, out = fluid.layers.image_resize(input, out_shape=[12, 12]) """ - resample_methods = { - 'BILINEAR': 'bilinear_interp', - 'NEAREST': 'nearest_neighbor_interp' - } + resample_methods = {'BILINEAR': 'bilinear', 'NEAREST': 'nearest'} if resample not in resample_methods: raise ValueError( "The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently." ) if out_shape is None and scale is None: raise ValueError("One of out_shape and scale must not be None") - helper = LayerHelper(resample_methods[resample], **locals()) + helper = LayerHelper('interpolate', **locals()) dtype = helper.input_dtype() def _is_list_or_turple_(data): @@ -5647,15 +5644,18 @@ def image_resize(input, out = helper.create_variable_for_type_inference(dtype) helper.append_op( - type=resample_methods[resample], + type='interpolate', inputs=inputs, outputs={"Out": out}, - attrs={"out_h": out_h, - "out_w": out_w}) + attrs={ + "out_h": out_h, + "out_w": out_w, + "interp_method": resample_methods[resample] + }) return out -@templatedoc(op_type="bilinear_interp") +@templatedoc(op_type="interpolate") def resize_bilinear(input, out_shape=None, scale=None, name=None): """ ${comment} @@ -5678,7 +5678,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): return image_resize(input, out_shape, scale, name, 'BILINEAR') -@templatedoc(op_type="bilinear_interp") +@templatedoc(op_type="interpolate") def resize_nearest(input, out_shape=None, scale=None, name=None): """ ${comment} diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py deleted file mode 100644 index bed847c3c1..0000000000 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -from op_test import OpTest -import paddle.fluid.core as core - - -def bilinear_interp_np(input, out_h, out_w, out_size): - if out_size is not None: - out_h = out_size[0] - out_w = out_size[1] - batch_size, channel, in_h, in_w = input.shape - if out_h > 1: - ratio_h = (in_h - 1.0) / (out_h - 1.0) - else: - ratio_h = 0.0 - if out_w > 1: - ratio_w = (in_w - 1.0) / (out_w - 1.0) - else: - ratio_w = 0.0 - - out = np.zeros((batch_size, channel, out_h, out_w)) - for i in range(out_h): - h = int(ratio_h * i) - hid = 1 if h < in_h - 1 else 0 - h1lambda = ratio_h * i - h - h2lambda = 1.0 - h1lambda - for j in range(out_w): - w = int(ratio_w * j) - wid = 1 if w < in_w - 1 else 0 - w1lambda = ratio_w * j - w - w2lambda = 1.0 - w1lambda - - out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + - w1lambda*input[:, :, h, w+wid]) + \ - h1lambda*(w2lambda*input[:, :, h+hid, w] + - w1lambda*input[:, :, h+hid, w+wid]) - return out.astype(input.dtype) - - -class TestBilinearInterpOp(OpTest): - def setUp(self): - self.out_size = None - self.init_test_case() - self.op_type = "bilinear_interp" - input_np = np.random.random(self.input_shape).astype("float32") - output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, - self.out_size) - self.inputs = {'X': input_np} - if self.out_size is not None: - self.inputs['OutSize'] = self.out_size - self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} - self.outputs = {'Out': output_np} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) - - def init_test_case(self): - self.input_shape = [2, 3, 4, 4] - self.out_h = 2 - self.out_w = 2 - self.out_size = np.array([3, 3]).astype("int32") - - -class TestCase1(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - - -class TestCase2(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - - -class TestCase3(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - - -class TestCase4(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - self.out_size = np.array([2, 2]).astype("int32") - - -class TestCase5(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - self.out_size = np.array([11, 11]).astype("int32") - - -class TestCase6(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - self.out_size = np.array([65, 129]).astype("int32") - - -class TestBilinearInterpOpUint8(OpTest): - def setUp(self): - self.out_size = None - self.init_test_case() - self.op_type = "bilinear_interp" - input_np = np.random.randint( - low=0, high=256, size=self.input_shape).astype("uint8") - output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, - self.out_size) - self.inputs = {'X': input_np} - if self.out_size is not None: - self.inputs['OutSize'] = self.out_size - self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} - self.outputs = {'Out': output_np} - - def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) - - def init_test_case(self): - self.input_shape = [1, 3, 9, 6] - self.out_h = 10 - self.out_w = 9 - - -class TestCase1Uint8(TestBilinearInterpOpUint8): - def init_test_case(self): - self.input_shape = [2, 3, 128, 64] - self.out_h = 120 - self.out_w = 50 - - -class TestCase2Uint8(TestBilinearInterpOpUint8): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 5 - self.out_w = 13 - self.out_size = np.array([6, 15]).astype("int32") - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_interpolate_op.py b/python/paddle/fluid/tests/unittests/test_interpolate_op.py new file mode 100644 index 0000000000..a90f4aace2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_interpolate_op.py @@ -0,0 +1,294 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + n, c, in_h, in_w = X.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + ratio_h = (in_h - 1.0) / (out_h - 1.0) + if out_w > 1: + ratio_w = (in_w - 1.0) / (out_w - 1.0) + + out = np.zeros((n, c, out_h, out_w)) + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, i, j] = X[:, :, in_i, in_j] + + return out.astype(X.dtype) + + +def bilinear_interp_np(input, out_h, out_w, out_size): + """bilinear interpolation implement in shape [N, C, H, W]""" + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + batch_size, channel, in_h, in_w = input.shape + if out_h > 1: + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 0.0 + if out_w > 1: + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 0.0 + + out = np.zeros((batch_size, channel, out_h, out_w)) + for i in range(out_h): + h = int(ratio_h * i) + hid = 1 if h < in_h - 1 else 0 + h1lambda = ratio_h * i - h + h2lambda = 1.0 - h1lambda + for j in range(out_w): + w = int(ratio_w * j) + wid = 1 if w < in_w - 1 else 0 + w1lambda = ratio_w * j - w + w2lambda = 1.0 - w1lambda + + out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + + w1lambda*input[:, :, h, w+wid]) + \ + h1lambda*(w2lambda*input[:, :, h+hid, w] + + w1lambda*input[:, :, h+hid, w+wid]) + return out.astype(input.dtype) + + +INTERPOLATE_FUNCS = { + 'bilinear': bilinear_interp_np, + 'nearest': nearest_neighbor_interp_np, +} + + +class TestInterpolateOp(OpTest): + def setUp(self): + self.out_size = None + self.init_test_case() + self.op_type = "interpolate" + input_np = np.random.random(self.input_shape).astype("float32") + + output_np = INTERPOLATE_FUNCS[self.interp_method]( + input_np, self.out_h, self.out_w, self.out_size) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 4, 4] + self.out_h = 2 + self.out_w = 2 + self.out_size = np.array([3, 3]).astype("int32") + + +class TestBilinearInterpCase1(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + + +class TestBilinearInterpCase2(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + + +class TestBilinearInterpCase3(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + + +class TestBilinearInterpCase4(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestBilinearInterpCase5(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestBilinearInterpCase6(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + +# class TestBilinearInterpBigScale(TestInterpolateOp): +# def init_test_case(self): +# self.interp_method = 'bilinear' +# self.input_shape = [32, 16, 128, 64] +# self.out_h = 200 +# self.out_w = 100 +# self.out_size = np.array([201, 101]).astype('int32') + + +class TestInterpolateOpUint8(OpTest): + def setUp(self): + self.out_size = None + self.init_test_case() + self.op_type = "interpolate" + input_np = np.random.randint( + low=0, high=256, size=self.input_shape).astype("uint8") + output_np = INTERPOLATE_FUNCS[self.interp_method]( + input_np, self.out_h, self.out_w, self.out_size) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(place=core.CPUPlace(), atol=1) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 3, 9, 6] + self.out_h = 10 + self.out_w = 9 + + +class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 128, 64] + self.out_h = 120 + self.out_w = 50 + + +class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.out_size = np.array([6, 15]).astype("int32") + + +class TestNearestNeighborInterpCase1(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + + +class TestNearestNeighborInterpCase2(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + + +class TestNearestNeighborInterpCase3(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + + +class TestNearestNeighborInterpCase4(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestNearestNeighborInterpCase5(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestNearestNeighborInterpCase6(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + +class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 128, 64] + self.out_h = 120 + self.out_w = 50 + + +class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.out_size = np.array([6, 15]).astype("int32") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 0390938901..30e87793a6 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -485,7 +485,7 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) - def test_resize_bilinear(self): + def test_resize_nearest(self): program = Program() with program_guard(program): x = layers.data(name='x', shape=[3, 9, 6], dtype="float32") diff --git a/python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py deleted file mode 100644 index 78ad3b98f5..0000000000 --- a/python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -from op_test import OpTest -import paddle.fluid.core as core - - -def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None): - """nearest neighbor interpolation implement in shape [N, C, H, W]""" - if out_size is not None: - out_h = out_size[0] - out_w = out_size[1] - n, c, in_h, in_w = X.shape - - ratio_h = ratio_w = 0.0 - if out_h > 1: - ratio_h = (in_h - 1.0) / (out_h - 1.0) - if out_w > 1: - ratio_w = (in_w - 1.0) / (out_w - 1.0) - - out = np.zeros((n, c, out_h, out_w)) - for i in range(out_h): - in_i = int(round(ratio_h * i)) - for j in range(out_w): - in_j = int(round(ratio_w * j)) - out[:, :, i, j] = X[:, :, in_i, in_j] - - return out.astype(X.dtype) - - -class TestBilinearInterpOp(OpTest): - def setUp(self): - self.out_size = None - self.init_test_case() - self.op_type = "nearest_neighbor_interp" - input_np = np.random.random(self.input_shape).astype("float32") - output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, - self.out_size) - self.inputs = {'X': input_np} - if self.out_size is not None: - self.inputs['OutSize'] = self.out_size - self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} - self.outputs = {'Out': output_np} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) - - def init_test_case(self): - self.input_shape = [2, 3, 4, 4] - self.out_h = 2 - self.out_w = 2 - self.out_size = np.array([3, 3]).astype("int32") - - -class TestCase1(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - - -class TestCase2(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - - -class TestCase3(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - - -class TestCase4(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - self.out_size = np.array([2, 2]).astype("int32") - - -class TestCase5(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - self.out_size = np.array([11, 11]).astype("int32") - - -class TestCase6(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - self.out_size = np.array([65, 129]).astype("int32") - - -class TestBilinearInterpOpUint8(OpTest): - def setUp(self): - self.out_size = None - self.init_test_case() - self.op_type = "nearest_neighbor_interp" - input_np = np.random.randint( - low=0, high=256, size=self.input_shape).astype("uint8") - output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, - self.out_size) - self.inputs = {'X': input_np} - if self.out_size is not None: - self.inputs['OutSize'] = self.out_size - self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} - self.outputs = {'Out': output_np} - - def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) - - def init_test_case(self): - self.input_shape = [1, 3, 9, 6] - self.out_h = 10 - self.out_w = 9 - - -class TestCase1Uint8(TestBilinearInterpOpUint8): - def init_test_case(self): - self.input_shape = [2, 3, 128, 64] - self.out_h = 120 - self.out_w = 50 - - -class TestCase2Uint8(TestBilinearInterpOpUint8): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 5 - self.out_w = 13 - self.out_size = np.array([6, 15]).astype("int32") - - -if __name__ == "__main__": - unittest.main() From fef2faa709008d681477f4ef5d7dc77e063de392 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 5 Nov 2018 19:07:59 +0800 Subject: [PATCH 06/23] limit CUDA kernel parallel threads max number to 4096. test=develop --- paddle/fluid/operators/interpolate_op.cu | 30 +++++++++++-------- .../tests/unittests/test_interpolate_op.py | 23 +++++++++----- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 3b9ece4830..190afbdac4 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw( const size_t num_channels, const float ratio_h, const float ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw( const size_t num_channels, const float ratio_h, const float ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw( const size_t num_channels, const float ratio_h, const float ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw( const size_t num_channels, const T ratio_h, const T ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { return; } - int threadNum = n * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; if ("nearest" == interp_method) { KeNearestNeighborInterpFw< - T><<>>( + T><<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } else if ("bilinear" == interp_method) { KeBilinearInterpFw< - T><<>>( + T><<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } @@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { return; } - int threadNum = n * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; if ("nearest" == interp_method) { KeNearestNeighborInterpBw< - T><<>>( + T><<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } else if ("bilinear" == interp_method) { KeBilinearInterpBw< - T><<>>( + T><<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } diff --git a/python/paddle/fluid/tests/unittests/test_interpolate_op.py b/python/paddle/fluid/tests/unittests/test_interpolate_op.py index a90f4aace2..dd3bf5fd5c 100644 --- a/python/paddle/fluid/tests/unittests/test_interpolate_op.py +++ b/python/paddle/fluid/tests/unittests/test_interpolate_op.py @@ -167,13 +167,13 @@ class TestBilinearInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") -# class TestBilinearInterpBigScale(TestInterpolateOp): -# def init_test_case(self): -# self.interp_method = 'bilinear' -# self.input_shape = [32, 16, 128, 64] -# self.out_h = 200 -# self.out_w = 100 -# self.out_size = np.array([201, 101]).astype('int32') +class TestBilinearInterpBigScale(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 4, 64, 32] + self.out_h = 100 + self.out_w = 50 + self.out_size = np.array([101, 51]).astype('int32') class TestInterpolateOpUint8(OpTest): @@ -273,6 +273,15 @@ class TestNearestNeighborInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") +class TestNearestNeighborInterpBigScale(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 4, 64, 32] + self.out_h = 100 + self.out_w = 50 + self.out_size = np.array([101, 51]).astype('int32') + + class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8): def init_test_case(self): self.interp_method = 'nearest' From 8b47d90f5d89cd0ef38d06960e4f06f7bb7dd383 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 7 Nov 2018 11:17:19 +0800 Subject: [PATCH 07/23] add 'actual_shape' attribute. test=develop --- paddle/fluid/API.spec | 6 +- paddle/fluid/operators/interpolate_op.cc | 6 +- python/paddle/fluid/layers/nn.py | 124 +++++++++++++++--- .../tests/unittests/test_interpolate_op.py | 40 +++++- 4 files changed, 148 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 65436cdd98..1948c6ecc9 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -118,10 +118,10 @@ paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) -paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) +paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None)) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) -paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) -paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None)) +paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index e2000d0e0c..8f979e05d3 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -40,11 +40,13 @@ class InterpolateOp : public framework::OperatorWithKernel { int out_w = ctx->Attrs().Get("out_w"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); - if (ctx->HasInput("OutSize")) { + if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { auto out_size_dim = ctx->GetInputDim("OutSize"); PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, "OutSize's dimension size must be 1"); PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); + ctx->ShareLoD("X", "Out"); + return; } std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); @@ -86,7 +88,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { interpolation. Nearest neighbor interpolation is to perform nearest neighbor interpolation - in bot the 3rd dimention(in height direction) and the 4th dimention(in width + in both the 3rd dimention(in height direction) and the 4th dimention(in width direction) on input tensor. Bilinear interpolation is an extension of linear interpolation for diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3b65825b96..46ce401b17 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5575,7 +5575,8 @@ def image_resize(input, out_shape=None, scale=None, name=None, - resample='BILINEAR'): + resample='BILINEAR', + actual_shape=None): """ **Resize a Batch of Images** @@ -5600,25 +5601,50 @@ def image_resize(input, Default: None name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. - resample(str): The resample method. It can only be 'BILINEAR' currently. + resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' + currently. Default: 'BILINEAR' + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None Returns: Variable: The output is a 4-D tensor of the shape (num_batches, channls, out_h, out_w). + Raises: + TypeError: out_shape should be a list or tuple or Variable. + TypeError: actual_shape should either be Variable or None. + ValueError: The 'resample' of image_resize can only be 'BILINEAR' + or 'NEAREST' currently. + ValueError: One of out_shape and scale must not be None. + ValueError: out_shape length should be 2. + Examples: .. code-block:: python out = fluid.layers.image_resize(input, out_shape=[12, 12]) """ - resample_methods = {'BILINEAR': 'bilinear', 'NEAREST': 'nearest'} + resample_methods = { + 'BILINEAR': 'bilinear', + 'NEAREST': 'nearest', + } if resample not in resample_methods: raise ValueError( - "The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently." + "The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently." ) if out_shape is None and scale is None: - raise ValueError("One of out_shape and scale must not be None") + raise ValueError("One of out_shape and scale must not be None.") helper = LayerHelper('interpolate', **locals()) dtype = helper.input_dtype() @@ -5629,19 +5655,28 @@ def image_resize(input, out_w = 0 inputs = {"X": input} if out_shape is not None: - if not (_is_list_or_turple_(out_shape) and - len(out_shape) == 2) and not isinstance(out_shape, Variable): - raise ValueError('out_shape should be a list or tuple or variable') - if _is_list_or_turple_(out_shape): - out_shape = list(map(int, out_shape)) - out_h = out_shape[0] - out_w = out_shape[1] - else: + if isinstance(out_shape, Variable): + warnings.warn("out_shape as Variable type is deprecated, \ + it is recommended to use actual_shape instead of \ + out_shape to specify output shape dynamically.") inputs['OutSize'] = out_shape + elif not (_is_list_or_turple_(out_shape)): + raise TypeError("out_shape should be a list or tuple or Variable.") + elif len(out_shape) != 2: + raise ValueError("out_shape length should be 2.") + + out_shape = list(map(int, out_shape)) + out_h = out_shape[0] + out_w = out_shape[1] else: out_h = int(input.shape[2] * scale) out_w = int(input.shape[3] * scale) + if isinstance(actual_shape, Variable): + inputs["OutSize"] = actual_shape + elif actual_shape is not None: + raise TypeError("actual_shape should either be Variable or None.") + out = helper.create_variable_for_type_inference(dtype) helper.append_op( type='interpolate', @@ -5656,9 +5691,24 @@ def image_resize(input, @templatedoc(op_type="interpolate") -def resize_bilinear(input, out_shape=None, scale=None, name=None): +def resize_bilinear(input, + out_shape=None, + scale=None, + name=None, + actual_shape=None): """ - ${comment} + Resize input by performing bilinear interpolation based on given + output shape which specified by actual_shape, out_shape and scale + in priority order. + + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation Args: input(${x_type}): ${x_comment}. @@ -5670,18 +5720,41 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): a higher priority than scale. Default: None. name(str|None): The output variable name. + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None Returns: ${out_comment}. """ - return image_resize(input, out_shape, scale, name, 'BILINEAR') + return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape) @templatedoc(op_type="interpolate") -def resize_nearest(input, out_shape=None, scale=None, name=None): +def resize_nearest(input, + out_shape=None, + scale=None, + name=None, + actual_shape=None): """ - ${comment} + Resize input by performing nearest neighbor interpolation in both the + 3rd dimention(in height direction) and the 4th dimention(in width + direction) based on given output shape which specified by actual_shape, + out_shape and scale in priority order. + + For details of nearest neighbor interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation Args: input(${x_type}): ${x_comment}. @@ -5693,12 +5766,25 @@ def resize_nearest(input, out_shape=None, scale=None, name=None): a higher priority than scale. Default: None. name(str|None): The output variable name. + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None Returns: ${out_comment}. """ - return image_resize(input, out_shape, scale, name, 'NEAREST') + return image_resize(input, out_shape, scale, name, 'NEAREST', actual_shape) def image_resize_short(input, out_short_len, resample='BILINEAR'): diff --git a/python/paddle/fluid/tests/unittests/test_interpolate_op.py b/python/paddle/fluid/tests/unittests/test_interpolate_op.py index dd3bf5fd5c..9748d094cd 100644 --- a/python/paddle/fluid/tests/unittests/test_interpolate_op.py +++ b/python/paddle/fluid/tests/unittests/test_interpolate_op.py @@ -20,11 +20,18 @@ from op_test import OpTest import paddle.fluid.core as core -def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None): +def nearest_neighbor_interp_np(X, + out_h, + out_w, + out_size=None, + actual_shape=None): """nearest neighbor interpolation implement in shape [N, C, H, W]""" if out_size is not None: out_h = out_size[0] out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] n, c, in_h, in_w = X.shape ratio_h = ratio_w = 0.0 @@ -43,11 +50,14 @@ def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None): return out.astype(X.dtype) -def bilinear_interp_np(input, out_h, out_w, out_size): +def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None): """bilinear interpolation implement in shape [N, C, H, W]""" if out_size is not None: out_h = out_size[0] out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] batch_size, channel, in_h, in_w = input.shape if out_h > 1: ratio_h = (in_h - 1.0) / (out_h - 1.0) @@ -86,15 +96,18 @@ INTERPOLATE_FUNCS = { class TestInterpolateOp(OpTest): def setUp(self): self.out_size = None + self.actual_shape = None self.init_test_case() self.op_type = "interpolate" input_np = np.random.random(self.input_shape).astype("float32") output_np = INTERPOLATE_FUNCS[self.interp_method]( - input_np, self.out_h, self.out_w, self.out_size) + input_np, self.out_h, self.out_w, self.out_size, self.actual_shape) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape self.attrs = { 'out_h': self.out_h, 'out_w': self.out_w, @@ -167,6 +180,15 @@ class TestBilinearInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") +class TestBilinearInterpActualShape(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.out_size = np.array([66, 40]).astype("int32") + + class TestBilinearInterpBigScale(TestInterpolateOp): def init_test_case(self): self.interp_method = 'bilinear' @@ -179,12 +201,13 @@ class TestBilinearInterpBigScale(TestInterpolateOp): class TestInterpolateOpUint8(OpTest): def setUp(self): self.out_size = None + self.actual_shape = None self.init_test_case() self.op_type = "interpolate" input_np = np.random.randint( low=0, high=256, size=self.input_shape).astype("uint8") output_np = INTERPOLATE_FUNCS[self.interp_method]( - input_np, self.out_h, self.out_w, self.out_size) + input_np, self.out_h, self.out_w, self.out_size, self.actual_shape) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -273,6 +296,15 @@ class TestNearestNeighborInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") +class TestNearestNeighborInterpActualShape(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.out_size = np.array([66, 40]).astype("int32") + + class TestNearestNeighborInterpBigScale(TestInterpolateOp): def init_test_case(self): self.interp_method = 'nearest' From f3eafec19d8b43ecedfff6b8ddb2c2b3acefe6eb Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 8 Nov 2018 13:29:44 +0800 Subject: [PATCH 08/23] fix pserver weight decay multi inputs test=develop --- .../fluid/transpiler/distribute_transpiler.py | 52 +++++++++++++------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 7c7fba7671..094eaeb59c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1706,13 +1706,27 @@ to transpile() call.") outputs=outputs, attrs=opt_op.all_attrs()) - def _is_splited_grad_var(self, var, var_dict): + def _get_pserver_grad_param_var(self, var, var_dict): + """ + Return pserver side grad/param variable, return None + if the variable is not grad/param, e.g. + + a@GRAD -> a@GRAD.block0 + a@GRAD -> a@GRAD (a is not splited) + fc_0.w_0 -> fc_0.w_0.block_0 + fc_0.w_0 -> fc_0.w_0 (weight is not splited) + _generated_var_123 -> None + """ grad_block = None for _, g in six.iteritems(var_dict): if self._orig_varname(g.name) == self._orig_varname(var.name): + # skip per trainer vars if g.name.find(".trainer_") == -1: - grad_block = g - break + # only param or grads have splited blocks + if self._orig_varname(g.name) in self.grad_name_to_param_name or\ + self._orig_varname(g.name) in self.param_name_to_grad_name: + grad_block = g + break return grad_block def _clone_lr_op(self, program, block, op): @@ -1745,32 +1759,38 @@ to transpile() call.") for key, varlist in six.iteritems(inputs): if not isinstance(varlist, list): varlist = [varlist] - for var in varlist: - # for ops like clipping and weight decay, get the splited var + for i in range(len(varlist)): + var = varlist[i] + # for ops like clipping and weight decay, get the splited var (xxx.block0) # for inputs/outputs - grad_block = self._is_splited_grad_var( + grad_block = self._get_pserver_grad_param_var( var, program.global_block().vars) if grad_block: - inputs[key] = grad_block + varlist[i] = grad_block elif var.name not in program.global_block().vars: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + tmpvar = program.global_block()._clone_variable(var) + varlist[i] = tmpvar + else: + varlist[i] = program.global_block().vars[var.name] + inputs[key] = varlist outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op) for key, varlist in six.iteritems(outputs): if not isinstance(varlist, list): varlist = [varlist] - for var in varlist: - grad_block = self._is_splited_grad_var( + for i in range(len(varlist)): + var = varlist[i] + grad_block = self._get_pserver_grad_param_var( var, program.global_block().vars) if grad_block: - outputs[key] = grad_block + varlist[i] = grad_block elif var.name not in program.global_block().vars: - program.global_block()._clone_variable(var) + tmpvar = program.global_block()._clone_variable(var) + varlist[i] = tmpvar + else: + varlist[i] = program.global_block().vars[var.name] + outputs[key] = varlist return optimize_block.append_op( type=opt_op.type, From 03e11f3fc9dc53d4925f341f15bec0f2393da80a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 8 Nov 2018 06:01:25 +0000 Subject: [PATCH 09/23] add vscal jitcode --- paddle/fluid/operators/math/jit_code.cc | 35 +++++ paddle/fluid/operators/math/jit_code.h | 30 +++- paddle/fluid/operators/math/jit_kernel.h | 3 +- .../fluid/operators/math/jit_kernel_blas.cc | 143 +++++++++--------- paddle/fluid/operators/math/jit_kernel_exp.cc | 15 +- .../fluid/operators/math/jit_kernel_test.cc | 9 +- 6 files changed, 150 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index a92e5d351e..f853497804 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -96,6 +96,41 @@ void VVVJitCode::generate() { } ret(); } + +bool VScalJitCode::init(int d) { return MayIUse(avx); } + +void VScalJitCode::generate() { + int offset = 0; + vbroadcastss(ymm_src1, ptr[param1]); + for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + vmovups(ymm_src2, ptr[param2 + offset]); + vmulps(ymm_dst, ymm_src1, ymm_src2); + vmovups(ptr[param3 + offset], ymm_dst); + offset += sizeof(float) * AVX_FLOAT_BLOCK; + } + int rest = num_ % AVX_FLOAT_BLOCK; + if (rest >= 4) { + vmovups(xmm_src2, ptr[param2 + offset]); + vmulps(xmm_dst, xmm_src1, xmm_src2); + vmovups(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * 4; + rest -= 4; + } + if (rest >= 2) { + vmovq(xmm_src2, ptr[param2 + offset]); + vmulps(xmm_dst, xmm_src1, xmm_src2); + vmovq(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * 2; + rest -= 2; + } + if (rest > 0) { + vmovss(xmm_src2, ptr[param2 + offset]); + vmulss(xmm_dst, xmm_src1, xmm_src2); + vmovss(ptr[param3 + offset], xmm_dst); + } + ret(); +} + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 73692ebc67..d87831c579 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -29,9 +29,9 @@ using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; using Label = Xbyak::Label; -// function: vec = Operand(vec, vec) (maybe with relu) typedef enum { mul = 0, add } operand_type; +// function: vec = Operand(vec, vec) (maybe with relu) class VVVJitCode : public JitCode { public: const char* name() const override { @@ -41,7 +41,7 @@ class VVVJitCode : public JitCode { } else if (type_ == operand_type::add) { base += "_Add"; } - base += (with_relu_ ? "_relu" : ""); + base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } explicit VVVJitCode(int d, operand_type type, bool with_relu, @@ -72,6 +72,32 @@ class VVVJitCode : public JitCode { ymm_t ymm_zero = ymm_t(2); }; +class VScalJitCode : public JitCode { + public: + DECLARE_JIT_CODE(VScalJitCode); + explicit VScalJitCode(int d, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), num_(d) {} + static bool init(int d); + void generate() override; + + private: + int num_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + reg64_t param3{abi_param3}; + + xmm_t xmm_src1 = xmm_t(0); + xmm_t xmm_src2 = xmm_t(1); + xmm_t xmm_dst = xmm_t(1); + xmm_t xmm_zero = xmm_t(2); + + ymm_t ymm_src1 = ymm_t(0); + ymm_t ymm_src2 = ymm_t(1); + ymm_t ymm_dst = ymm_t(1); + ymm_t ymm_zero = ymm_t(2); +}; + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 04e0b81d3e..6ee651b988 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -83,8 +83,7 @@ class VAddReluKernel : public Kernel { template class VScalKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; - virtual void Compute(const T a, T *x) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index f976953a24..a9537ab096 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) { } } +template +void VScalRefer(const T* a, const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = a[0] * x[i]; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -83,6 +90,28 @@ template <> void VAddMKL(const double* x, const double* y, double* z, int n) { platform::dynload::vdAdd(n, x, y, z); } + +template +void VScalMKL(const T* a, const T* x, T* y, int n); + +template <> +void VScalMKL(const float* a, const float* x, float* y, int n) { + if (x == y) { + platform::dynload::cblas_sscal(n, *a, y, 1); + } else { + VScalRefer(a, x, y, n); + } +} + +template <> +void VScalMKL(const double* a, const double* x, double* y, int n) { + if (x == y) { + platform::dynload::cblas_dscal(n, *a, y, 1); + } else { + VScalRefer(a, x, y, n); + } +} + #endif #define DECLARE_STATIC_FUNC \ @@ -226,87 +255,60 @@ bool VAddReluKernelImpl::useJIT(int d) { } #endif -#undef DECLARE_STATIC_FUNC - -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); - -/* VSCAL JitKernel */ -template +/* VScal JitKernel */ +template class VScalKernelImpl : public VScalKernel { public: - explicit VScalKernelImpl(int d) : VScalKernel() { this->num_ = d; } - void Compute(const T a, const T* x, T* y) const override { - for (int i = 0; i < this->num_; ++i) { - y[i] = a * x[i]; - } - } - void Compute(const T a, T* x) const override { - for (int i = 0; i < this->num_; ++i) { - x[i] = a * x[i]; + DECLARE_STATIC_FUNC; + explicit VScalKernelImpl(int d) : VScalKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; } - } -}; - +#endif #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VScalKernelImpl::Compute(const float a, float* x) \ - const { \ - platform::dynload::cblas_sscal(this->num_, a, x, 1); \ - } - -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VScalKernelImpl::Compute(const double a, double* x) \ - const { \ - platform::dynload::cblas_dscal(this->num_, a, x, 1); \ - } - -FOR_EACH_ISA(MKL_FLOAT, kGT16); -FOR_EACH_ISA_BLOCK(MKL_DOUBLE); + if (useMKL(d)) { + this->Compute = VScalMKL; + return; + } #endif - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VScalKernelImpl::Compute( \ - const float a, const float* x, float* y) const { \ - __m256 tmp; \ - __m256 scalar = _mm256_set1_ps(a); \ - tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_mul_ps(tmp, scalar); \ - _mm256_storeu_ps(y, tmp); \ - } -#define INTRI8_INPLACE_FLOAT(isa) \ - template <> \ - void VScalKernelImpl::Compute(const float a, float* x) \ - const { \ - __m256 tmp; \ - __m256 scalar = _mm256_set1_ps(a); \ - tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_mul_ps(tmp, scalar); \ - _mm256_storeu_ps(x, tmp); \ + this->Compute = VScalRefer; } +#ifdef PADDLE_WITH_XBYAK -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI8_INPLACE_FLOAT(jit::avx); + private: + std::unique_ptr jitcode_{nullptr}; #endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI8_INPLACE_FLOAT(jit::avx2); +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VScalKernelImpl::useJIT(int d) { + return gen::VScalJitCode::init(d); +} #endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI8_INPLACE_FLOAT(jit::avx512f); + +#ifdef PADDLE_WITH_MKLML +template <> +bool VScalKernelImpl::useMKL(int d) { + return d > 512; +} +template <> +bool VScalKernelImpl::useMKL(int d) { + return true; +} #endif -// TODO(TJ): eq16 test and complete avx512 -#undef INTRI8_FLOAT -#undef INTRI8_INPLACE_FLOAT -#undef MKL_FLOAT -#undef MKL_DOUBLE +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); /* VAddBias JitKernel */ template @@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index d7c177e678..07a77086da 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -409,9 +409,10 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_ = KernelPool::Instance().template Get>(d); } void Compute(const T* x, T* y) const override { - vscal_->Compute(static_cast(2), x, y); + const T a = static_cast(2); + vscal_->Compute(&a, x, y, this->num_); vsigmoid_->Compute(y, y); - vscal_->Compute(static_cast(2), y); + vscal_->Compute(&a, y, y, this->num_); vaddbias_->Compute(static_cast(-1), y, y); } @@ -472,9 +473,10 @@ class VTanhKernelImpl : public VTanhKernel { _mm256_storeu_ps(y, tmp); \ x += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \ - vscal_->Compute(2.f, x, y); \ + const float a = 2.f; \ + vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ - vscal_->Compute(2.f, y); \ + vscal_->Compute(&a, y, y, this->num_); \ vaddbias_->Compute(-1.f, y, y); \ } @@ -502,9 +504,10 @@ class VTanhKernelImpl : public VTanhKernel { } \ x += this->end_; \ y += this->end_; \ - vscal_->Compute(2.f, x, y); \ + const float a = 2.f; \ + vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ - vscal_->Compute(2.f, y); \ + vscal_->Compute(&a, y, y, this->num_); \ vaddbias_->Compute(-1.f, y, y); \ } diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 9a19424691..04a199faae 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -281,9 +281,10 @@ void vtanh_better( const paddle::operators::math::jitkernel::VAddBiasKernel>& vaddbias, const int n, const float* x, float* y) { - vscal->Compute(2.f, x, y); + const float tmp1 = 2.f; + vscal->Compute(&tmp1, x, y, n); vsigmoid->Compute(y, y); - vscal->Compute(2.f, y); + vscal->Compute(&tmp1, y, y, n); vaddbias->Compute(-1.f, y, y); } @@ -531,12 +532,12 @@ TEST(JitKernel, vscal) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(a, x_data, ztgt_data); + ker->Compute(&a, x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); auto ttgts1 = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(a, y_data); + ker->Compute(&a, y_data, y_data, d); } auto ttgte1 = GetCurrentUS(); VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat From 3d950a812ddd5a0d75555b36fa605a404ef04232 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 8 Nov 2018 06:57:17 +0000 Subject: [PATCH 10/23] combine jitcode of vscal --- paddle/fluid/operators/math/jit_code.cc | 77 ++++++++----------- paddle/fluid/operators/math/jit_code.h | 49 ++++-------- .../fluid/operators/math/jit_kernel_blas.cc | 25 +++--- 3 files changed, 58 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index f853497804..6b3eecfbd1 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -24,21 +24,30 @@ namespace gen { using namespace platform::jit; // NOLINT -bool VVVJitCode::init(int d) { +bool VXXJitCode::init(int d, int scalar_index) { // It's not necessary to use avx512 since it would slow down the frequency // and this kernel is not compute bound. - return MayIUse(avx); + return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2; } -void VVVJitCode::generate() { +void VXXJitCode::generate() { // do not need push stack, and do not need save avx512reg if do not use avx512 int offset = 0; if (with_relu_) { vxorps(ymm_zero, ymm_zero, ymm_zero); } + if (scalar_index_ == 1) { + vbroadcastss(ymm_src1, ptr[param1]); + } else if (scalar_index_ == 2) { + vbroadcastss(ymm_src2, ptr[param2]); + } for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { - vmovups(ymm_src1, ptr[param1 + offset]); - vmovups(ymm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(ymm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(ymm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulps(ymm_dst, ymm_src1, ymm_src2); } else if (type_ == operand_type::add) { @@ -52,8 +61,12 @@ void VVVJitCode::generate() { } int rest = num_ % AVX_FLOAT_BLOCK; if (rest >= 4) { - vmovups(xmm_src1, ptr[param1 + offset]); - vmovups(xmm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulps(xmm_dst, xmm_src1, xmm_src2); } else if (type_ == operand_type::add) { @@ -67,8 +80,12 @@ void VVVJitCode::generate() { rest -= 4; } if (rest >= 2) { - vmovq(xmm_src1, ptr[param1 + offset]); - vmovq(xmm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulps(xmm_dst, xmm_src1, xmm_src2); } else if (type_ == operand_type::add) { @@ -82,8 +99,12 @@ void VVVJitCode::generate() { rest -= 2; } if (rest > 0) { - vmovss(xmm_src1, ptr[param1 + offset]); - vmovss(xmm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulss(xmm_dst, xmm_src1, xmm_src2); } else if (type_ == operand_type::add) { @@ -97,40 +118,6 @@ void VVVJitCode::generate() { ret(); } -bool VScalJitCode::init(int d) { return MayIUse(avx); } - -void VScalJitCode::generate() { - int offset = 0; - vbroadcastss(ymm_src1, ptr[param1]); - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { - vmovups(ymm_src2, ptr[param2 + offset]); - vmulps(ymm_dst, ymm_src1, ymm_src2); - vmovups(ptr[param3 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; - } - int rest = num_ % AVX_FLOAT_BLOCK; - if (rest >= 4) { - vmovups(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); - vmovups(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - vmovq(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); - vmovq(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - vmovss(xmm_src2, ptr[param2 + offset]); - vmulss(xmm_dst, xmm_src1, xmm_src2); - vmovss(ptr[param3 + offset], xmm_dst); - } - ret(); -} - } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index d87831c579..939d9897e6 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -31,11 +31,11 @@ using Label = Xbyak::Label; typedef enum { mul = 0, add } operand_type; -// function: vec = Operand(vec, vec) (maybe with relu) -class VVVJitCode : public JitCode { +// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu) +class VXXJitCode : public JitCode { public: const char* name() const override { - std::string base = "VVVJitCode"; + std::string base = "VXXJitCode"; if (type_ == operand_type::mul) { base += "_Mul"; } else if (type_ == operand_type::add) { @@ -44,18 +44,21 @@ class VVVJitCode : public JitCode { base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } - explicit VVVJitCode(int d, operand_type type, bool with_relu, - size_t code_size = 256 * 1024, void* code_ptr = nullptr) + explicit VXXJitCode(int d, operand_type type, int scalar_index, + bool with_relu, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) : JitCode(code_size, code_ptr), num_(d), type_(type), + scalar_index_(scalar_index), with_relu_(with_relu) {} - static bool init(int d); + static bool init(int d, int scalar_index = 0); void generate() override; private: int num_; operand_type type_; + int scalar_index_; bool with_relu_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; @@ -63,39 +66,13 @@ class VVVJitCode : public JitCode { xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(1); - xmm_t xmm_zero = xmm_t(2); + xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_zero = xmm_t(3); ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(1); - ymm_t ymm_zero = ymm_t(2); -}; - -class VScalJitCode : public JitCode { - public: - DECLARE_JIT_CODE(VScalJitCode); - explicit VScalJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} - static bool init(int d); - void generate() override; - - private: - int num_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - reg64_t param3{abi_param3}; - - xmm_t xmm_src1 = xmm_t(0); - xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(1); - xmm_t xmm_zero = xmm_t(2); - - ymm_t ymm_src1 = ymm_t(0); - ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(1); - ymm_t ymm_zero = ymm_t(2); + ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_zero = ymm_t(3); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index a9537ab096..ead4385cdb 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -131,7 +131,7 @@ class VMulKernelImpl : public VMulKernel { if (useJIT(d)) { // roughly estimate the size of code size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false, + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -150,14 +150,14 @@ class VMulKernelImpl : public VMulKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VMulKernelImpl::useJIT(int d) { - return gen::VVVJitCode::init(d); + return gen::VXXJitCode::init(d); } #endif @@ -182,7 +182,7 @@ class VAddKernelImpl : public VAddKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false, + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -200,14 +200,14 @@ class VAddKernelImpl : public VAddKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VAddKernelImpl::useJIT(int d) { - return gen::VVVJitCode::init(d); + return gen::VXXJitCode::init(d); } #endif @@ -232,7 +232,7 @@ class VAddReluKernelImpl : public VAddReluKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true, + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -244,14 +244,14 @@ class VAddReluKernelImpl : public VAddReluKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VAddReluKernelImpl::useJIT(int d) { - return gen::VVVJitCode::init(d); + return gen::VXXJitCode::init(d); } #endif @@ -264,7 +264,8 @@ class VScalKernelImpl : public VScalKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; @@ -281,14 +282,14 @@ class VScalKernelImpl : public VScalKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VScalKernelImpl::useJIT(int d) { - return gen::VScalJitCode::init(d); + return gen::VXXJitCode::init(d, 1); } #endif From 7fd640b88205d5c6c7d99e47ed6a40209225aae2 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Thu, 8 Nov 2018 07:41:10 +0100 Subject: [PATCH 11/23] added additional call to graph_viz_pass test=develop --- paddle/fluid/inference/analysis/analyzer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index ef4142f334..559b3b6d21 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -101,6 +101,7 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } void Analyzer::Run(Argument* argument) { std::vector passes; + passes.push_back("graph_viz_pass"); // add graphviz for debug. #ifdef PADDLE_WITH_MKLDNN if (use_mkldnn_) { VLOG(3) << "Adding MKL-DNN placement pass"; @@ -110,13 +111,13 @@ void Analyzer::Run(Argument* argument) { // infer_clean_graph_pass should be the first default pass // after mkldnn_placement_pass. passes.push_back("infer_clean_graph_pass"); + passes.push_back("graph_viz_pass"); // add graphviz for debug. for (auto& pass : ir_passes_) { if (!disabled_ir_passes_.count(pass)) { passes.push_back(pass); passes.push_back("graph_viz_pass"); // add graphviz for debug. } } - passes.push_back("graph_viz_pass"); argument->Set(kFluidToIrPassesAttr, new std::vector(passes)); for (auto& x : data_) { From 5e64244f250376666814816fc333c614cc8c085d Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 8 Nov 2018 07:32:39 +0000 Subject: [PATCH 12/23] add vaddbias jitcode test=develop --- paddle/fluid/operators/math/jit_code.h | 12 ++- paddle/fluid/operators/math/jit_kernel.h | 4 +- .../fluid/operators/math/jit_kernel_blas.cc | 84 ++++++++----------- paddle/fluid/operators/math/jit_kernel_exp.cc | 12 +-- .../fluid/operators/math/jit_kernel_test.cc | 10 +-- 5 files changed, 62 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 939d9897e6..aaedb0ae10 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -31,16 +31,26 @@ using Label = Xbyak::Label; typedef enum { mul = 0, add } operand_type; -// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu) +// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VXXJitCode : public JitCode { public: const char* name() const override { std::string base = "VXXJitCode"; + if (scalar_index_ == 1) { + base += "_Scalar"; + } else { + base += "_Vec"; + } if (type_ == operand_type::mul) { base += "_Mul"; } else if (type_ == operand_type::add) { base += "_Add"; } + if (scalar_index_ == 2) { + base += "_Scalar"; + } else { + base += "_Vec"; + } base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 6ee651b988..e9b259282c 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -83,13 +83,15 @@ class VAddReluKernel : public Kernel { template class VScalKernel : public Kernel { public: + // y = a.*x void (*Compute)(const T *, const T *, T *, int); }; template class VAddBiasKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; + // y = a.+x + void (*Compute)(const T *, const T *, T *, int); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 1f468a7fe3..d5e45cf7f4 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -60,6 +60,13 @@ void VScalRefer(const T* a, const T* x, T* y, int n) { } } +template +void VAddBiasRefer(const T* a, const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = a[0] + x[i]; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -300,62 +307,46 @@ bool VScalKernelImpl::useMKL(int d) { } #endif -#undef DECLARE_STATIC_FUNC - -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vscal, VScalKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); - /* VAddBias JitKernel */ -template +template class VAddBiasKernelImpl : public VAddBiasKernel { public: - explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { this->num_ = d; } - void Compute(const T a, const T* x, T* y) const override { - for (int i = 0; i < this->num_; ++i) { - y[i] = x[i] + a; + DECLARE_STATIC_FUNC; + explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, + sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; } - } -}; - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddBiasKernelImpl::Compute( \ - const float a, const float* x, float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \ - _mm256_storeu_ps(y, tmp); \ - } +#endif -#define INTRI16_FLOAT(isa) \ - template <> \ - void VAddBiasKernelImpl::Compute( \ - const float a, const float* x, float* y) const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \ - tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ + this->Compute = VAddBiasRefer; } +#ifdef PADDLE_WITH_XBYAK -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); + private: + std::unique_ptr jitcode_{nullptr}; #endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VAddBiasKernelImpl::useJIT(int d) { + return gen::VXXJitCode::init(d, 1); +} #endif -// TODO(TJ): eq16 test and complete avx512 -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); /* VRelu JitKernel */ template @@ -466,7 +457,6 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 5df17c11b4..fd507808cd 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -409,11 +409,11 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_ = KernelPool::Instance().template Get>(d); } void Compute(const T* x, T* y) const override { - const T a = static_cast(2); + const T a = static_cast(2), b = static_cast(-1); vscal_->Compute(&a, x, y, this->num_); vsigmoid_->Compute(y, y); vscal_->Compute(&a, y, y, this->num_); - vaddbias_->Compute(static_cast(-1), y, y); + vaddbias_->Compute(&b, y, y, this->num_); } private: @@ -473,11 +473,11 @@ class VTanhKernelImpl : public VTanhKernel { _mm256_storeu_ps(y, tmp); \ x += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \ - const float a = 2.f; \ + const float a = 2.f, b = -1.f; \ vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(-1.f, y, y); \ + vaddbias_->Compute(&b, y, y, this->num_); \ } #define INTRI_GT16_FLOAT(isa, expisa) \ @@ -504,11 +504,11 @@ class VTanhKernelImpl : public VTanhKernel { } \ x += this->end_; \ y += this->end_; \ - const float a = 2.f; \ + const float a = 2.f, b = -1.f; \ vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(-1.f, y, y); \ + vaddbias_->Compute(&b, y, y, this->num_); \ } #ifndef __WIN32 diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 04a199faae..596bd3b2d3 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(a, x_data, ztgt_data); + ker->Compute(&a, x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -281,11 +281,11 @@ void vtanh_better( const paddle::operators::math::jitkernel::VAddBiasKernel>& vaddbias, const int n, const float* x, float* y) { - const float tmp1 = 2.f; - vscal->Compute(&tmp1, x, y, n); + const float a = 2.f, b = -1.f; + vscal->Compute(&a, x, y, n); vsigmoid->Compute(y, y); - vscal->Compute(&tmp1, y, y, n); - vaddbias->Compute(-1.f, y, y); + vscal->Compute(&a, y, y, n); + vaddbias->Compute(&b, y, y, n); } TEST(JitKernel, vtanh) { From 381bea0a16937a6c28b03aef04937688873fed3d Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Thu, 8 Nov 2018 16:09:28 +0800 Subject: [PATCH 13/23] fix test_analysis_predictor test=develop --- paddle/fluid/inference/api/CMakeLists.txt | 4 ++-- paddle/fluid/inference/api/analysis_predictor_tester.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 49a9ebe3dd..fd05c96777 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -37,8 +37,8 @@ if(WITH_TESTING) ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book) set_tests_properties(test_api_impl PROPERTIES DEPENDS test_image_classification) endif() -cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} paddle_inference_api - ARGS --dirname=${PYTHON_TESTS_DIR}/book) +cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} + ARGS --dirname=${WORD2VEC_MODEL_DIR}) if(WITH_GPU AND TENSORRT_FOUND) cc_library(paddle_inference_tensorrt_subgraph_engine diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 13c25da1b5..f75c45f3a0 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -24,7 +24,7 @@ using contrib::AnalysisConfig; TEST(AnalysisPredictor, ZeroCopy) { AnalysisConfig config; - config.model_dir = FLAGS_dirname + "/word2vec.inference.model"; + config.model_dir = FLAGS_dirname; config.use_feed_fetch_ops = false; auto predictor = CreatePaddlePredictor(config); From ba8b5619a3e1fff66fdecffbb9faf34a675a730c Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Thu, 8 Nov 2018 16:53:29 +0800 Subject: [PATCH 14/23] Revert "cherry picked windows patches." --- CMakeLists.txt | 4 +- cmake/cuda.cmake | 13 +-- cmake/cudnn.cmake | 3 +- cmake/external/boost.cmake | 2 +- cmake/external/gflags.cmake | 9 +- cmake/external/glog.cmake | 2 +- cmake/external/gtest.cmake | 2 +- cmake/external/openblas.cmake | 1 - cmake/flags.cmake | 7 +- cmake/generic.cmake | 9 -- cmake/inference_lib.cmake | 26 +---- cmake/version.cmake | 2 +- doc/fluid/dev/contribute_to_paddle_cn.md | 1 - doc/fluid/dev/contribute_to_paddle_en.md | 1 - .../development/contribute_to_paddle.md | 1 - .../development/cpu_profiling_cn.md | 1 - .../development/host_memory_profiling_cn.md | 1 - .../advanced_usage/development/new_op.md | 1 - .../advanced_usage/development/timeline_cn.md | 1 - doc/v2/dev/contribute_to_paddle_en.md | 2 +- paddle/fluid/framework/executor.cc | 15 --- paddle/fluid/framework/executor.h | 4 +- paddle/fluid/framework/ir/node.cc | 5 - paddle/fluid/framework/ir/node.h | 4 - paddle/fluid/framework/ir/pass.h | 27 ----- paddle/fluid/framework/tensor.h | 5 - paddle/fluid/inference/CMakeLists.txt | 8 -- paddle/fluid/inference/analysis/argument.h | 2 +- paddle/fluid/inference/analysis/helper.h | 26 ++++- paddle/fluid/inference/api/CMakeLists.txt | 1 - paddle/fluid/inference/api/api.cc | 1 + paddle/fluid/inference/api/api_impl.cc | 31 ++++-- paddle/fluid/inference/api/api_impl.h | 2 +- .../inference/api/demo_ci/CMakeLists.txt | 55 ++++++----- .../inference/api/demo_ci/inference_icnet.cc | 99 ------------------- paddle/fluid/inference/api/helper.h | 26 +++-- paddle/fluid/inference/api/timer.h | 39 -------- paddle/fluid/memory/detail/buddy_allocator.cc | 3 +- paddle/fluid/memory/detail/meta_cache.cc | 2 - .../fluid/memory/detail/system_allocator.cc | 1 - paddle/fluid/operators/CMakeLists.txt | 12 ++- paddle/fluid/operators/accuracy_op.h | 1 - paddle/fluid/operators/cast_op.h | 1 - .../detection/roi_perspective_transform_op.cu | 4 +- .../fluid/operators/elementwise_op_function.h | 1 + paddle/fluid/operators/load_combine_op.cc | 29 +----- paddle/fluid/operators/load_op.cc | 28 +----- paddle/fluid/operators/lstm_unit_op.h | 2 +- paddle/fluid/operators/math/CMakeLists.txt | 7 +- paddle/fluid/operators/math/cpu_vec.h | 4 + .../math/detail/activation_functions.h | 5 +- .../fluid/operators/math/jit_kernel_blas.cc | 4 + .../operators/math/jit_kernel_crf_decode.cc | 5 +- paddle/fluid/operators/math/jit_kernel_exp.cc | 20 ++-- paddle/fluid/operators/math/jit_kernel_rnn.cc | 4 + .../operators/math/selected_rows_functor.cu | 1 - .../fluid/operators/math/sequence_pooling.cu | 3 +- paddle/fluid/operators/print_op.cc | 1 - paddle/fluid/operators/save_combine_op.cc | 29 +----- paddle/fluid/operators/save_op.cc | 43 ++------ paddle/fluid/operators/split_lod_tensor_op.cc | 1 - paddle/fluid/operators/tensorrt_engine_op.h | 4 +- paddle/fluid/platform/cpu_info.h | 12 --- paddle/fluid/platform/cudnn_helper.h | 11 --- paddle/fluid/platform/device_context.cc | 3 +- paddle/fluid/platform/device_context.h | 15 +-- paddle/fluid/platform/enforce.h | 2 +- paddle/fluid/platform/init.cc | 2 - paddle/fluid/platform/macros.h | 13 --- paddle/fluid/platform/port.h | 6 +- 70 files changed, 199 insertions(+), 519 deletions(-) delete mode 120000 doc/fluid/dev/contribute_to_paddle_cn.md delete mode 120000 doc/fluid/dev/contribute_to_paddle_en.md delete mode 120000 doc/fluid/new_docs/advanced_usage/development/contribute_to_paddle.md delete mode 120000 doc/fluid/new_docs/advanced_usage/development/cpu_profiling_cn.md delete mode 120000 doc/fluid/new_docs/advanced_usage/development/host_memory_profiling_cn.md delete mode 120000 doc/fluid/new_docs/advanced_usage/development/new_op.md delete mode 120000 doc/fluid/new_docs/advanced_usage/development/timeline_cn.md delete mode 100644 paddle/fluid/inference/api/demo_ci/inference_icnet.cc delete mode 100644 paddle/fluid/inference/api/timer.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4625516458..ed704585d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,6 @@ message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: " "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") if(WIN32) set(CMAKE_STATIC_LIBRARY_PREFIX lib) - set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "/MT") #create multithread dynamic library endif(WIN32) if(NOT CMAKE_CROSSCOMPILING) @@ -34,6 +33,7 @@ if(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING) find_package(Git REQUIRED) find_package(Threads REQUIRED) + include(simd) ################################ Configurations ####################################### @@ -178,10 +178,10 @@ include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 include(external/cares) include(external/cub) +include(external/xxhash) # download xxhash if (NOT WIN32) # there is no official support of snappystream, warpctc, nccl, cupti in windows -include(external/xxhash) # download xxhash include(external/snappy) # download snappy include(external/snappystream) # download snappystream include(external/warpctc) # download, build, install warpctc diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 564878131c..f507bb41a1 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -169,21 +169,18 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. - if (NOT WIN32) # windows msvc2015 support c++11 natively. -# -std=c++11 -fPIC not recoginize by msvc +# -std=c++11 -fPIC not recoginize by msvc, -Xcompiler will be added by cmake. list(APPEND CUDA_NVCC_FLAGS "-std=c++11") -# in cuda9, suppress cuda warning on eigen with "-w" -list(APPEND CUDA_NVCC_FLAGS "-w" "-Xcompiler -fPIC") -else(NOT WIN32) -list(APPEND CUDA_NVCC_FLAGS "-w" "-Xcompiler -fPIC" "-Xcompiler /w") +list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") endif(NOT WIN32) if(WITH_FAST_MATH) # Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") -endif(WITH_FAST_MATH) - +endif() +# in cuda9, suppress cuda warning on eigen +list(APPEND CUDA_NVCC_FLAGS "-w") # Set :expt-relaxed-constexpr to suppress Eigen warnings list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 813611b032..cd51533926 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -48,6 +48,7 @@ find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a NO_DEFAULT_PATH DOC "Path to cuDNN library.") + if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY) set(CUDNN_FOUND ON) else() @@ -82,7 +83,7 @@ if(CUDNN_FOUND) if(NOT CUDNN_MAJOR_VERSION) set(CUDNN_VERSION "???") - else() + else() math(EXPR CUDNN_VERSION "${CUDNN_MAJOR_VERSION} * 1000 + ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}") diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake index 65f55b64ca..ada61de8eb 100644 --- a/cmake/external/boost.cmake +++ b/cmake/external/boost.cmake @@ -48,7 +48,7 @@ ExternalProject_Add( DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR} DOWNLOAD_COMMAND wget --no-check-certificate ${BOOST_URL} -c -q -O ${BOOST_TAR}.tar.gz && tar zxf ${BOOST_TAR}.tar.gz -DOWNLOAD_NO_PROGRESS 1 + DOWNLOAD_NO_PROGRESS 1 PREFIX ${BOOST_SOURCES_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index 0d4cecd4de..cf58cc3976 100644 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -35,9 +35,7 @@ ExternalProject_Add( CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} - -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} - -DBUILD_STATIC_LIBS=ON -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR} -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_TESTING=OFF @@ -47,10 +45,6 @@ ExternalProject_Add( -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} ) - -ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL) -SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES}) -ADD_DEPENDENCIES(gflags extern_gflags) IF(WIN32) IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib") add_custom_command(TARGET extern_gflags POST_BUILD @@ -58,6 +52,9 @@ IF(WIN32) ) ENDIF() ENDIF(WIN32) +ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES}) +ADD_DEPENDENCIES(gflags extern_gflags) LIST(APPEND external_project_dependencies gflags) diff --git a/cmake/external/glog.cmake b/cmake/external/glog.cmake index a205d4ec77..25ef2970ac 100644 --- a/cmake/external/glog.cmake +++ b/cmake/external/glog.cmake @@ -34,6 +34,7 @@ ELSE() SET(GLOG_REPOSITORY "https://github.com/google/glog.git") SET(GLOG_TAG "v0.3.5") ENDIF() + ExternalProject_Add( extern_glog ${EXTERNAL_PROJECT_LOG_ARGS} @@ -45,7 +46,6 @@ ExternalProject_Add( CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} - -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake index bfb04916dc..d335298742 100644 --- a/cmake/external/gtest.cmake +++ b/cmake/external/gtest.cmake @@ -51,7 +51,6 @@ IF(WITH_TESTING) -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} - -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_GMOCK=ON @@ -71,5 +70,6 @@ IF(WITH_TESTING) ADD_LIBRARY(gtest_main STATIC IMPORTED GLOBAL) SET_PROPERTY(TARGET gtest_main PROPERTY IMPORTED_LOCATION ${GTEST_MAIN_LIBRARIES}) ADD_DEPENDENCIES(gtest_main extern_gtest) + LIST(APPEND external_project_dependencies gtest gtest_main) ENDIF(WITH_TESTING) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index abc906d31f..755dbd610c 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -124,7 +124,6 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c) FILE(WRITE ${dummyfile} "const char *dummy_cblas = \"${dummyfile}\";") - ADD_LIBRARY(cblas STATIC ${dummyfile}) IF("${CBLAS_PROVIDER}" STREQUAL "MKLML") diff --git a/cmake/flags.cmake b/cmake/flags.cmake index a652b844c6..343e44ab4b 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -144,14 +144,11 @@ set(GPU_COMMON_FLAGS -Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=array-bounds # Warnings in Eigen::array ) + else(NOT WIN32) set(COMMON_FLAGS - -fPIC - -fno-omit-frame-pointer "/w") #disable all warnings. set(GPU_COMMON_FLAGS - -fPIC - -fno-omit-frame-pointer "/w") #disable all warnings endif(NOT WIN32) @@ -167,8 +164,8 @@ endif(APPLE) if(LINUX) set(GPU_COMMON_FLAGS -Wall - -Werror -Wextra + -Werror ${GPU_COMMON_FLAGS}) endif(LINUX) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 7421a012a1..62227c6784 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -238,7 +238,6 @@ function(cc_library TARGET_NAME) # add libxxx.lib prefix in windows set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") endif(WIN32) - if(cc_library_SRCS) if(cc_library_SHARED OR cc_library_shared) # build *.so add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) @@ -308,11 +307,7 @@ function(cc_test TARGET_NAME) set(multiValueArgs SRCS DEPS ARGS) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_executable(${TARGET_NAME} ${cc_test_SRCS}) - if(WIN32) # in windows deps. shlwapi library. - target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog shlwapi) - else(WIN32) target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) - endif(WIN32) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} ${cc_test_ARGS} @@ -383,11 +378,7 @@ function(nv_test TARGET_NAME) set(multiValueArgs SRCS DEPS) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) - if(WIN32) - target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog shlwapi) - else(WIN32) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) - endif(WIN32) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_test(${TARGET_NAME} ${TARGET_NAME}) if (nv_test_SERIAL) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 72ce7070c8..efdb093a7b 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -31,31 +31,10 @@ function(copy TARGET) foreach(index RANGE ${len}) list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_DSTS ${index} dst) - if (WIN32) - # windows cmd shell will not expand wildcard automatically. - # below expand the files,libs and copy them by rules. - file(GLOB header_files ${src} "*.h") - file(GLOB static_lib_files ${src} "*.lib") - file(GLOB dll_lib_files ${src} "*.dll") - set(src_files ${header_files} ${static_lib_files} ${dll_lib_files}) - - if (NOT "${src_files}" STREQUAL "") - list(REMOVE_DUPLICATES src_files) - endif() - add_custom_command(TARGET ${TARGET} PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}" - ) - foreach(src_file ${src_files}) - add_custom_command(TARGET ${TARGET} PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${src_file}" "${dst}" - COMMENT "copying ${src_file} -> ${dst}") - endforeach() - else(WIN32) # not windows - add_custom_command(TARGET ${TARGET} PRE_BUILD + add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}" COMMAND cp -r "${src}" "${dst}" COMMENT "copying ${src} -> ${dst}") - endif(WIN32) endforeach() endfunction() @@ -87,14 +66,13 @@ copy(boost_lib DSTS ${dst_dir} DEPS boost ) -if(NOT WIN32) + set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/xxhash") copy(xxhash_lib SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES} DSTS ${dst_dir} ${dst_dir}/lib DEPS xxhash ) -endif(NOT WIN32) if(NOT PROTOBUF_FOUND) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf") diff --git a/cmake/version.cmake b/cmake/version.cmake index fbf559f76b..ac10bdf067 100644 --- a/cmake/version.cmake +++ b/cmake/version.cmake @@ -44,5 +44,5 @@ while ("${PADDLE_VERSION}" STREQUAL "") endif() endwhile() -add_definitions(-DPADDLE_VERSION="${PADDLE_VERSION}") +add_definitions(-DPADDLE_VERSION=${PADDLE_VERSION}) message(STATUS "Paddle version is ${PADDLE_VERSION}") diff --git a/doc/fluid/dev/contribute_to_paddle_cn.md b/doc/fluid/dev/contribute_to_paddle_cn.md deleted file mode 120000 index bcb71b3da1..0000000000 --- a/doc/fluid/dev/contribute_to_paddle_cn.md +++ /dev/null @@ -1 +0,0 @@ -../../v2/dev/contribute_to_paddle_cn.md diff --git a/doc/fluid/dev/contribute_to_paddle_en.md b/doc/fluid/dev/contribute_to_paddle_en.md deleted file mode 120000 index 16679a4063..0000000000 --- a/doc/fluid/dev/contribute_to_paddle_en.md +++ /dev/null @@ -1 +0,0 @@ -../../v2/dev/contribute_to_paddle_en.md diff --git a/doc/fluid/new_docs/advanced_usage/development/contribute_to_paddle.md b/doc/fluid/new_docs/advanced_usage/development/contribute_to_paddle.md deleted file mode 120000 index 9f1af6133f..0000000000 --- a/doc/fluid/new_docs/advanced_usage/development/contribute_to_paddle.md +++ /dev/null @@ -1 +0,0 @@ -../../../dev/contribute_to_paddle_cn.md diff --git a/doc/fluid/new_docs/advanced_usage/development/cpu_profiling_cn.md b/doc/fluid/new_docs/advanced_usage/development/cpu_profiling_cn.md deleted file mode 120000 index 8c13564629..0000000000 --- a/doc/fluid/new_docs/advanced_usage/development/cpu_profiling_cn.md +++ /dev/null @@ -1 +0,0 @@ -../../../howto/optimization/cpu_profiling_cn.md diff --git a/doc/fluid/new_docs/advanced_usage/development/host_memory_profiling_cn.md b/doc/fluid/new_docs/advanced_usage/development/host_memory_profiling_cn.md deleted file mode 120000 index 5501686e98..0000000000 --- a/doc/fluid/new_docs/advanced_usage/development/host_memory_profiling_cn.md +++ /dev/null @@ -1 +0,0 @@ -../../../howto/optimization/host_memory_profiling_cn.md diff --git a/doc/fluid/new_docs/advanced_usage/development/new_op.md b/doc/fluid/new_docs/advanced_usage/development/new_op.md deleted file mode 120000 index a0d1af57ba..0000000000 --- a/doc/fluid/new_docs/advanced_usage/development/new_op.md +++ /dev/null @@ -1 +0,0 @@ -../../../dev/new_op_cn.md diff --git a/doc/fluid/new_docs/advanced_usage/development/timeline_cn.md b/doc/fluid/new_docs/advanced_usage/development/timeline_cn.md deleted file mode 120000 index 1a782fd363..0000000000 --- a/doc/fluid/new_docs/advanced_usage/development/timeline_cn.md +++ /dev/null @@ -1 +0,0 @@ -../../../howto/optimization/timeline_cn.md diff --git a/doc/v2/dev/contribute_to_paddle_en.md b/doc/v2/dev/contribute_to_paddle_en.md index 7272339644..c97564d93a 120000 --- a/doc/v2/dev/contribute_to_paddle_en.md +++ b/doc/v2/dev/contribute_to_paddle_en.md @@ -1 +1 @@ -../../../CONTRIBUTING.md +../../../CONTRIBUTING.md \ No newline at end of file diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 93624b76ec..8ed0ba1dfa 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include - #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/feed_fetch_method.h" @@ -48,7 +46,6 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; } -#ifndef _WIN32 template static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, GarbageCollector* gc, @@ -83,7 +80,6 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, gc->Add(erase_tensors); } } -#endif Executor::Executor(const platform::Place& place) : place_(place) {} @@ -371,7 +367,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, CreateVariables(ctx->prog_, local_scope, ctx->block_id_); } -#ifndef _WIN32 int64_t max_memory_size = GetEagerDeletionThreshold(); std::unique_ptr> gc; // WhileOp would set keep_kids to false @@ -413,16 +408,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } else { platform::DeviceContextPool::Instance().Get(place_)->Wait(); } -#else // WIN32 - for (auto& op : ctx->ops_) { - op->Run(*local_scope, place_); - if (FLAGS_benchmark) { - VLOG(2) << "Memory used after operator " + op->Type() + " running: " - << memory::memory_usage(place_); - } - } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); -#endif // NOT WIN32 if (local_scope != scope) { scope->DeleteScope(local_scope); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index a2a6c6bfb1..36b36d49c2 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -17,14 +17,12 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" -#ifndef _WIN32 -#include "paddle/fluid/framework/garbage_collector.h" -#endif namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 30879b1f36..9277abe8c1 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -17,12 +17,7 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { -// msvc15 don't support constexpr in correct way. -#if !defined(_WIN32) constexpr char Node::kControlDepVarName[]; -#else -const char Node::kControlDepVarName[] = "__control_var"; -#endif int Node::count_ = 0; std::unique_ptr CreateNodeForTest(const std::string& name, diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index a3be133344..d6d42f5e92 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -28,11 +28,7 @@ namespace ir { class Node { public: enum class Type { kOperation, kVariable }; -#if !defined(_WIN32) // msvc not support constexpr correctly. static constexpr char kControlDepVarName[] = "__control_var"; -#else - static const char kControlDepVarName[]; -#endif Type NodeType() const { return type_; } diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index ddbe0ddc12..9570c59cff 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/variant.h" namespace paddle { @@ -196,7 +195,6 @@ struct PassRegistrar : public Registrar { __test_global_namespace_##uniq_name##__>::value, \ msg) -#if !defined(_WIN32) // Register a new pass that can be applied on the IR. #define REGISTER_PASS(pass_type, pass_class) \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ @@ -219,32 +217,7 @@ struct PassRegistrar : public Registrar { extern int TouchPassRegistrar_##pass_type(); \ static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \ TouchPassRegistrar_##pass_type() -#else -// windows version of __attribute__((unused)) -#define UNUSED(x) __pragma(warning(suppress : 4100)) x -#define REGISTER_PASS(pass_type, pass_class) \ - STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ - __reg_pass__##pass_type, \ - "REGISTER_PASS must be called in global namespace"); \ - static ::paddle::framework::ir::PassRegistrar \ - __pass_registrar_##pass_type##__(#pass_type); \ - int TouchPassRegistrar_##pass_type() { \ - __pass_registrar_##pass_type##__.Touch(); \ - return 0; \ - } \ - static ::paddle::framework::ir::PassRegistrar UNUSED( \ - &__pass_tmp_registrar_##pass_type##__) = \ - __pass_registrar_##pass_type##__ - -#define USE_PASS(pass_type) \ - STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ - __use_pass_itself_##pass_type, \ - "USE_PASS must be called in global namespace"); \ - extern int TouchPassRegistrar_##pass_type(); \ - static int UNUSED(use_pass_itself_##pass_type##_) = \ - TouchPassRegistrar_##pass_type() -#endif // !_WIN32 } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index dd984445db..f1d2685485 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -20,11 +20,6 @@ limitations under the License. */ #include #include -#if defined(_WIN32) -#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h -#define GOOGLE_GLOG_DLL_DECL -#endif - #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/memory/memory.h" diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index ad023ec46c..e5678cf607 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -16,10 +16,6 @@ cc_library(paddle_fluid_api DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) -get_property(fluid_third_partys GLOBAL PROPERTY FLUID_THRID_PARTYS) -if (WIN32) -list(APPEND fluid_third_partys gflags glog protobuf cblas) -endif(WIN32) # paddle_fluid_origin exclude inference api interface cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api) @@ -37,11 +33,7 @@ if (WITH_GPU AND TENSORRT_FOUND) endif() # Create static library -if (WIN32) -cc_library(paddle_fluid DEPS ${fluid_modules} ${fluid_third_partys} paddle_fluid_api paddle_inference_api) -else(WIND32) cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array) -endif(WIN32) if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 3242aced39..e8fb0775b4 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -26,7 +26,6 @@ #include #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" namespace paddle { @@ -103,6 +102,7 @@ struct Argument { std::unordered_map> attr_deleters_; }; +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ if (UNLIKELY(!(field__))) { \ LOG(ERROR) << "field " << #field__ << " should be set."; \ diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index e20ddfa24f..5151e2b69a 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -25,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/port.h" namespace paddle { namespace inference { @@ -124,6 +124,20 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) { return *var->GetMutable(); } +static void ExecShellCommand(const std::string &cmd, std::string *message) { + char buffer[128]; + std::shared_ptr pipe(popen(cmd.c_str(), "r"), pclose); + if (!pipe) { + LOG(ERROR) << "error running command: " << cmd; + return; + } + while (!feof(pipe.get())) { + if (fgets(buffer, 128, pipe.get()) != nullptr) { + *message += buffer; + } + } +} + static framework::proto::ProgramDesc LoadProgramDesc( const std::string &model_path) { std::ifstream fin(model_path, std::ios::in | std::ios::binary); @@ -145,6 +159,16 @@ static bool FileExists(const std::string &filepath) { return exists; } +static bool PathExists(const std::string &path) { + struct stat statbuf; + if (stat(path.c_str(), &statbuf) != -1) { + if (S_ISDIR(statbuf.st_mode)) { + return true; + } + } + return false; +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 5e55acf892..49a9ebe3dd 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -24,7 +24,6 @@ if(WITH_GPU AND TENSORRT_FOUND) endif() cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) -cc_library(helper SRCS helper.cc DEPS reset_tensor_array lod_tensor scope) cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index 20fab8078f..01ea942d3c 100644 --- a/paddle/fluid/inference/api/api.cc +++ b/paddle/fluid/inference/api/api.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle_inference_api.h" namespace paddle { diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 27f272f2d8..d06ab8f8c8 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include #include #include #include @@ -25,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" -#include "paddle/fluid/inference/api/timer.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -33,6 +31,16 @@ DEFINE_bool(profile, false, "Turn on profiler for fluid"); DECLARE_int32(paddle_num_threads); namespace paddle { +namespace { +using paddle::inference::Timer; + +template +std::string num2str(T a) { + std::stringstream istr; + istr << a; + return istr.str(); +} +} // namespace void NativePaddlePredictor::PrepareFeedFetch() { for (auto *op : inference_program_->Block(0).AllOps()) { @@ -55,6 +63,7 @@ void NativePaddlePredictor::PrepareFeedFetch() { bool NativePaddlePredictor::Init( std::shared_ptr parent_scope) { + VLOG(3) << "Predictor::init()"; #if !defined(_WIN32) if (FLAGS_profile) { LOG(WARNING) << "Profiler is actived, might affect the performance"; @@ -82,21 +91,21 @@ bool NativePaddlePredictor::Init( paddle::framework::InitDevices(false); scope_.reset(new paddle::framework::Scope()); } + executor_.reset(new paddle::framework::Executor(place_)); + // Initialize the inference program if (!config_.model_dir.empty()) { // Parameters are saved in separate files sited in // the specified `dirname`. inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(), config_.model_dir); - } else if (!config_.prog_file.empty() && !config_.param_file.empty()) { // All parameters are saved in a single file. // The file names should be consistent with that used // in Python API `fluid.io.save_inference_model`. inference_program_ = paddle::inference::Load( executor_.get(), scope_.get(), config_.prog_file, config_.param_file); - } else { LOG(ERROR) << "fail to load inference model from " << config_.model_dir; return false; @@ -126,7 +135,7 @@ NativePaddlePredictor::~NativePaddlePredictor() { bool NativePaddlePredictor::Run(const std::vector &inputs, std::vector *output_data, int batch_size) { - using Timer = paddle::inference::Timer; + VLOG(3) << "Predictor::predict"; Timer timer; timer.tic(); // set feed variable @@ -138,9 +147,11 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, } // Run the inference program // if share variables, we need not create variables + VLOG(4) << "Run prepared context"; executor_->RunPreparedContext(ctx_.get(), scope, false, /* don't create local scope each time*/ false /* don't create variable each time */); + VLOG(4) << "Finish prepared context"; // get fetch variable if (!GetFetch(output_data, scope)) { LOG(ERROR) << "fail to get fetches"; @@ -155,6 +166,7 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, } std::unique_ptr NativePaddlePredictor::Clone() { + VLOG(3) << "Predictor::clone"; std::unique_ptr cls(new NativePaddlePredictor(config_)); if (!dynamic_cast(cls.get())->Init(scope_)) { @@ -172,6 +184,7 @@ std::unique_ptr NativePaddlePredictor::Clone() { bool NativePaddlePredictor::SetFeed(const std::vector &inputs, framework::Scope *scope) { + VLOG(3) << "Predictor::set_feed"; if (inputs.size() != feeds_.size()) { LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get " << inputs.size(); @@ -231,6 +244,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch, bool NativePaddlePredictor::GetFetch(std::vector *outputs, framework::Scope *scope) { + VLOG(3) << "Predictor::get_fetch"; outputs->resize(fetchs_.size()); for (size_t i = 0; i < fetchs_.size(); ++i) { int idx = boost::get(fetchs_[i]->GetAttr("col")); @@ -255,22 +269,25 @@ bool NativePaddlePredictor::GetFetch(std::vector *outputs, template <> std::unique_ptr CreatePaddlePredictor< NativeConfig, PaddleEngineKind::kNative>(const NativeConfig &config) { + VLOG(3) << "create NativePaddlePredictor"; if (config.use_gpu) { // 1. GPU memeroy PADDLE_ENFORCE_GT( config.fraction_of_gpu_memory, 0.f, - "fraction_of_gpu_memory in the config should be set to range (0.,1.]"); + "fraction_of_gpu_memory in the config should be set to range (0., 1.]"); PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device); std::vector flags; if (config.fraction_of_gpu_memory >= 0.0f || config.fraction_of_gpu_memory <= 0.95f) { flags.push_back("dummpy"); std::string flag = "--fraction_of_gpu_memory_to_use=" + - std::to_string(config.fraction_of_gpu_memory); + num2str(config.fraction_of_gpu_memory); flags.push_back(flag); + VLOG(3) << "set flag: " << flag; framework::InitGflags(flags); } } + std::unique_ptr predictor(new NativePaddlePredictor(config)); if (!dynamic_cast(predictor.get())->Init(nullptr)) { return nullptr; diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h index ed3bdd8de7..4e4ab47ca9 100644 --- a/paddle/fluid/inference/api/api_impl.h +++ b/paddle/fluid/inference/api/api_impl.h @@ -31,10 +31,10 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/profiler.h" -#include "paddle_inference_api.h" // NOLINT namespace paddle { diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index a742ba71ee..49683eab07 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -6,13 +6,13 @@ option(WITH_STATIC_LIB "Compile demo with static/shared library, default use sta option(USE_TENSORRT "Compile demo with TensorRT." OFF) macro(safe_set_static_flag) - foreach(flag_var - CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE - CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) - if(${flag_var} MATCHES "/MD") - string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") - endif(${flag_var} MATCHES "/MD") - endforeach(flag_var) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) endmacro() if (WIN32) @@ -37,25 +37,26 @@ if(NOT DEFINED DEMO_NAME) endif() -if(WITH_GPU) # default gpu path +if(WITH_GPU) if(NOT WIN32) set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library") else() if(CUDA_LIB STREQUAL "") - set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") + set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") endif() endif(NOT WIN32) endif() +include_directories("D:/Paddle/") include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") if (NOT WIN32) - include_directories("${PADDLE_LIB}/third_party/install/snappy/include") - include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") - include_directories("${PADDLE_LIB}/third_party/install/zlib/include") +include_directories("${PADDLE_LIB}/third_party/install/snappy/include") +include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") +include_directories("${PADDLE_LIB}/third_party/install/zlib/include") endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") @@ -63,15 +64,15 @@ include_directories("${PADDLE_LIB}/third_party/eigen3") if (NOT WIN32) if (USE_TENSORRT AND WITH_GPU) - include_directories("${TENSORRT_INCLUDE_DIR}") - link_directories("${TENSORRT_LIB_DIR}") + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") endif() endif(NOT WIN32) if (NOT WIN32) - link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") - link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") - link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") +link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") +link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") +link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") endif(NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") @@ -85,7 +86,7 @@ add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) include_directories("${PADDLE_LIB}/third_party/install/mklml/include") set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} - ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") if(EXISTS ${MKLDNN_PATH}) include_directories("${MKLDNN_PATH}/include") @@ -98,25 +99,25 @@ endif() # Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a if(WITH_STATIC_LIB) set(DEPS - ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) + ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) else() set(DEPS - ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) + ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() if (NOT WIN32) - set(EXTERNAL_LIB "-lrt -ldl -lpthread") - set(DEPS ${DEPS} +set(EXTERNAL_LIB "-lrt -ldl -lpthread") +set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) else() - set(DEPS ${DEPS} +set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} ${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf ${EXTERNAL_LIB}) - # NOTE(dzhwinter) shlwapi will be deprecated. - set(DEPS ${DEPS} libcmt shlwapi) +# NOTE(dzhwinter) shlwapi is deprecated. +set(DEPS ${DEPS} libcmt shlwapi) endif(NOT WIN32) if(WITH_GPU) @@ -128,8 +129,8 @@ if(WITH_GPU) set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) - set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) - set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) endif() endif() diff --git a/paddle/fluid/inference/api/demo_ci/inference_icnet.cc b/paddle/fluid/inference/api/demo_ci/inference_icnet.cc deleted file mode 100644 index 88e220c0b6..0000000000 --- a/paddle/fluid/inference/api/demo_ci/inference_icnet.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#define GOOGLE_GLOG_DLL_DECL -#include -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#include "paddle/fluid/inference/paddle_inference_api.h" - -namespace paddle { - -NativeConfig GetConfig() { - NativeConfig config; - config.prog_file = "hs_lb_without_bn_cudnn/__model__"; - config.param_file = "hs_lb_without_bn_cudnn/__params__"; - config.fraction_of_gpu_memory = 0.0; - config.use_gpu = true; - config.device = 0; - return config; -} - -using Time = decltype(std::chrono::high_resolution_clock::now()); -Time TimeNow() { return std::chrono::high_resolution_clock::now(); } -double TimeDiff(Time t1, Time t2) { - typedef std::chrono::microseconds ms; - auto diff = t2 - t1; - ms counter = std::chrono::duration_cast(diff); - return counter.count() / 1000.0; -} - -std::vector PrepareData() { - int height = 449; - int width = 581; - std::vector data; - for (int i = 0; i < 3 * height * width; ++i) { - data.push_back(0.0); - } - PaddleTensor tensor; - tensor.shape = std::vector({batch_size, 3, height, width}); - tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width); - std::copy(data.begin(), data.end(), static_cast(tensor.data.data())); - tensor.dtype = PaddleDType::FLOAT32; - std::vector paddle_tensor_feeds(1, tensor); - return std::move(paddle_tensor_feeds); -} - -void TestNaive(int batch_size, int thread_num) { - NativeConfig config = GetConfig(); - - int num_jobs = thread_num; // parallel jobs. - constexpr int epoches = 10; // each job run epoches. - std::vector threads; - std::vector> predictors; - for (int tid = 0; tid < num_jobs; ++tid) { - auto& pred = CreatePaddlePredictor(config); - predictors.emplace_back(std::move(pred)); - } - - auto time1 = TimeNow(); - for (int tid = 0; tid < num_jobs; ++tid) { - threads.emplace_back([&, tid]() { - auto& predictor = predictors[tid]; - PaddleTensor tensor_out; - std::vector outputs(1, tensor_out); - for (size_t i = 0; i < epoches; i++) { - ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); - VLOG(3) << "tid : " << tid << " run: " << i << "finished"; - ASSERT_EQ(outputs.size(), 1UL); - } - }); - } - for (int i = 0; i < num_jobs; ++i) { - threads[i].join(); - } - auto time2 = TimeNow(); - VLOG(3) << "Thread num " << thread_num << "total time cost" - << (time2 - time1); -} -} // namespace paddle - -int main(int argc, char** argv) { - paddle::TestNaive(1, 1); // single thread. - paddle::TestNaive(1, 5); // 5 threads. - return 0; -} diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index f5c83bcd54..e46dc13269 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -14,22 +14,36 @@ #pragma once -#define GLOG_NO_ABBREVIATED_SEVERITIES -#define GOOGLE_GLOG_DLL_DECL #include -#include +#include #include // NOLINT -#include #include #include #include #include -#include "paddle/fluid/inference/api/timer.h" -#include "paddle_inference_api.h" //NOLINT +#include "paddle/fluid/string/printf.h" +#include "paddle_inference_api.h" namespace paddle { namespace inference { +// Timer for timer +class Timer { + public: + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point startu; + + void tic() { start = std::chrono::high_resolution_clock::now(); } + double toc() { + startu = std::chrono::high_resolution_clock::now(); + std::chrono::duration time_span = + std::chrono::duration_cast>(startu - + start); + double used_time_ms = static_cast(time_span.count()) * 1000.0; + return used_time_ms; + } +}; + static void split(const std::string &str, char sep, std::vector *pieces) { pieces->clear(); diff --git a/paddle/fluid/inference/api/timer.h b/paddle/fluid/inference/api/timer.h deleted file mode 100644 index 2df5274dc1..0000000000 --- a/paddle/fluid/inference/api/timer.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include // NOLINT - -namespace paddle { -namespace inference { - -// Timer for timer -class Timer { - public: - std::chrono::high_resolution_clock::time_point start; - std::chrono::high_resolution_clock::time_point startu; - - void tic() { start = std::chrono::high_resolution_clock::now(); } - double toc() { - startu = std::chrono::high_resolution_clock::now(); - std::chrono::duration time_span = - std::chrono::duration_cast>(startu - - start); - double used_time_ms = static_cast(time_span.count()) * 1000.0; - return used_time_ms; - } -}; - -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index ce283f0621..26ef27c3ca 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -11,8 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define GLOG_NO_ABBREVIATED_SEVERITIES -#define GOOGLE_GLOG_DLL_DECL + #include "paddle/fluid/memory/detail/buddy_allocator.h" #include "glog/logging.h" diff --git a/paddle/fluid/memory/detail/meta_cache.cc b/paddle/fluid/memory/detail/meta_cache.cc index 2a283733f5..b86e4f38c4 100644 --- a/paddle/fluid/memory/detail/meta_cache.cc +++ b/paddle/fluid/memory/detail/meta_cache.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define GLOG_NO_ABBREVIATED_SEVERITIES -#define GOOGLE_GLOG_DLL_DECL #include "glog/logging.h" #include "paddle/fluid/memory/detail/memory_block.h" #include "paddle/fluid/platform/assert.h" diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 92849bc2c0..1b96798d23 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #define GLOG_NO_ABBREVIATED_SEVERITIES -#define GOOGLE_GLOG_DLL_DECL #include "paddle/fluid/memory/detail/system_allocator.h" diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c43f0a2159..919ad96f7a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -86,7 +86,7 @@ function(op_library TARGET) # remove windows unsupported op, because windows has no nccl, no warpctc such ops. foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op" "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op" - "fusion_seqconv_eltadd_relu_op" "hash_op") + "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op") if ("${TARGET}" STREQUAL "${windows_unsupport_op}") return() endif() @@ -284,10 +284,12 @@ op_library(array_to_lod_tensor_op DEPS lod_rank_table_op) op_library(max_sequence_len_op DEPS lod_rank_table) op_library(sequence_conv_op DEPS context_project) op_library(sequence_pool_op DEPS sequence_pooling) -op_library(lstm_op DEPS sequence2batch lstm_compute) -op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) -op_library(lstmp_op DEPS sequence2batch lstm_compute) -op_library(gru_op DEPS sequence2batch gru_compute) +if (NOT WIN32) + op_library(lstm_op DEPS sequence2batch lstm_compute) + op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) + op_library(lstmp_op DEPS sequence2batch lstm_compute) + op_library(gru_op DEPS sequence2batch gru_compute) +endif(NOT WIN32) op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) diff --git a/paddle/fluid/operators/accuracy_op.h b/paddle/fluid/operators/accuracy_op.h index 8d3313db96..803244dd48 100644 --- a/paddle/fluid/operators/accuracy_op.h +++ b/paddle/fluid/operators/accuracy_op.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once #include - #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index ea710aaad5..8fa0416049 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -54,7 +54,6 @@ class CastOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); - framework::VisitDataType( static_cast( context.Attr("out_dtype")), diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu index e70945a2bd..c82930cc49 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -31,12 +31,12 @@ namespace operators { template __device__ bool GT_E(T a, T b) { - return (a > b) || fabsf(static_cast(a - b)) < 1e-4; + return (a > b) || fabs(a - b) < 1e-4; } template __device__ bool LT_E(T a, T b) { - return (a < b) || fabsf(static_cast(a - b)) < 1e-4; + return (a < b) || fabs(a - b) < 1e-4; } template diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index 29276955fe..93204216f9 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 59f44b112c..0522a94195 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include -#include #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h" @@ -33,15 +32,9 @@ class LoadCombineOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto load_as_fp16 = Attr("load_as_fp16"); - auto format = Attr("format"); - std::unique_ptr fin; - if (format == "windows") { - fin.reset(new std::ifstream(filename, - std::ios_base::in | std::ios_base::binary)); - } else { - fin.reset(new std::ifstream(filename)); - } - PADDLE_ENFORCE(static_cast(*fin), + + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load_combine op", filename); auto out_var_names = Outputs("Out"); @@ -61,11 +54,11 @@ class LoadCombineOp : public framework::OperatorBase { auto *tensor = out_var->GetMutable(); // Error checking - PADDLE_ENFORCE(static_cast(*fin), "Cannot read more from file %s", + PADDLE_ENFORCE(static_cast(fin), "Cannot read more from file %s", filename); // Get data from fin to tensor - DeserializeFromStream(*fin, tensor, dev_ctx); + DeserializeFromStream(fin, tensor, dev_ctx); auto in_dtype = framework::ToDataType(tensor->type()); auto out_dtype = @@ -110,18 +103,6 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { "LoDTensors will be loaded from \"file_path\".") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddAttr("format", - R"DOC((windows|linux)" "saved model file format - windows and linux file newline symbol is -different. windows(newline is \n\r) or linux(newline is \r) -So if you set attribute format to windows, then we saved model file in binary. -It can be used both linux and windows. If you set format to linux, -it will save file in normal file, newline symbol is \r. Need to note -that these two format is not inter-compatible.)DOC") - .SetDefault("linux") - .AddCustomChecker([](const std::string &s) { - return s == "windows" || s == "linux"; - }); AddComment(R"DOC( LoadCombine Operator. diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index e0e2c3dc4f..51219504ff 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include -#include #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/op_registry.h" @@ -35,15 +34,8 @@ class LoadOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. auto filename = Attr("file_path"); - auto format = Attr("format"); - std::unique_ptr fin; - if (format == "windows") { - fin.reset(new std::ifstream(filename, - std::ios_base::in | std::ios_base::binary)); - } else { - fin.reset(new std::ifstream(filename)); - } - PADDLE_ENFORCE(static_cast(*fin), "Cannot open file %s for load op", + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", filename); auto out_var_name = Output("Out"); @@ -52,9 +44,9 @@ class LoadOp : public framework::OperatorBase { out_var_name); if (out_var->IsType()) { - LoadLodTensor(*fin, place, out_var); + LoadLodTensor(fin, place, out_var); } else if (out_var->IsType()) { - LoadSelectedRows(*fin, place, out_var); + LoadSelectedRows(fin, place, out_var); } else { PADDLE_ENFORCE( false, @@ -118,18 +110,6 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { R"(Variable will be loaded from "file_path")") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddAttr("format", - R"DOC((windows|linux)" "saved model file format - windows and linux file newline symbol is -different. windows(newline is \n\r) or linux(newline is \r) -So if you set attribute format to windows, then we saved model file in binary. -It can be used both linux and windows. If you set format to linux, -it will save file in normal file, newline symbol is \r. Need to note -that these two format is not inter-compatible.)DOC") - .SetDefault("linux") - .AddCustomChecker([](const std::string &s) { - return s == "windows" || s == "linux"; - }); AddComment( "Load operator will load a LoDTensor / SelectedRows variable from disk " "file."); diff --git a/paddle/fluid/operators/lstm_unit_op.h b/paddle/fluid/operators/lstm_unit_op.h index 5d1d667fe1..4ead9c2293 100644 --- a/paddle/fluid/operators/lstm_unit_op.h +++ b/paddle/fluid/operators/lstm_unit_op.h @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index f2f398b8a1..868a7a7064 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -57,6 +57,9 @@ math_library(sequence_padding) math_library(sequence_pooling DEPS math_function) math_library(sequence_scale) math_library(softmax DEPS math_function) +if (NOT WIN32) + math_library(matrix_bit_code) +endif (NOT WIN32) math_library(unpooling) math_library(vol2col) @@ -72,9 +75,7 @@ if(WITH_GPU) endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) -if (NOT WIN32) - math_library(matrix_bit_code) -endif (NOT WIN32) + set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce) if(WITH_XBYAK) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 38df5776bf..0aed253c80 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -18,6 +18,10 @@ limitations under the License. */ #include #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" +#ifdef __AVX__ +#include +#endif + #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" #endif diff --git a/paddle/fluid/operators/math/detail/activation_functions.h b/paddle/fluid/operators/math/detail/activation_functions.h index 24df1f93ed..b127fbe8c8 100644 --- a/paddle/fluid/operators/math/detail/activation_functions.h +++ b/paddle/fluid/operators/math/detail/activation_functions.h @@ -15,10 +15,13 @@ limitations under the License. */ #pragma once #include #include -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/hostdevice.h" +#ifdef __AVX__ +#include +#endif + namespace paddle { namespace operators { namespace math { diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 73089a4f0c..f976953a24 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -25,6 +25,10 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/mklml.h" #endif +#ifdef __AVX__ +#include +#endif + namespace paddle { namespace operators { namespace math { diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc index 4626ff5cb3..a4861c347e 100644 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/math/jit_kernel_macro.h" +#ifdef __AVX__ +#include +#endif namespace paddle { namespace operators { @@ -260,7 +263,6 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { } \ } -#ifndef _WIN32 // commented out crf decoding #ifdef __AVX__ INTRIAVX_FLOAT(kEQ8); INTRIAVX_FLOAT(kGT8LT16); @@ -273,7 +275,6 @@ INTRIAVX2_FLOAT(jit::avx2, kGT8LT16); INTRIAVX2_FLOAT(jit::avx2, kEQ16); INTRIAVX2_FLOAT(jit::avx2, kGT16); #endif -#endif // WIN32 #ifdef __AVX512F__ INTRIAVX2_FLOAT(jit::avx512f, kEQ8); INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16); diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 131c226589..d7c177e678 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -20,6 +20,10 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/mklml.h" #endif +#ifdef __AVX__ +#include +#endif + namespace paddle { namespace operators { namespace math { @@ -62,18 +66,14 @@ namespace detail { #ifdef __AVX__ -#if defined(_WIN32) -#define ALIGN32 __declspec(align(32)) -#else #define ALIGN32 __attribute__((aligned(32))) -#endif // _WIN32 #define _PS256_CONST(Name, Val) \ - static const float ALIGN32 _ps256_##Name[8] = {Val, Val, Val, Val, \ + static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ Val, Val, Val, Val} #define _PI256_CONST(Name, Val) \ - static const int ALIGN32 _pi256_##Name[8] = {Val, Val, Val, Val, \ + static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ Val, Val, Val, Val} _PI256_CONST(0x7f, 0x7f); @@ -98,7 +98,7 @@ typedef union imm_xmm_union { #define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \ { \ - imm_xmm_union ALIGN32 u; \ + imm_xmm_union u ALIGN32; \ u.imm = imm_; \ xmm0_ = u.xmm[0]; \ xmm1_ = u.xmm[1]; \ @@ -106,7 +106,7 @@ typedef union imm_xmm_union { #define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \ { \ - imm_xmm_union ALIGN32 u; \ + imm_xmm_union u ALIGN32; \ u.xmm[0] = xmm0_; \ u.xmm[1] = xmm1_; \ imm_ = u.imm; \ @@ -508,14 +508,12 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_->Compute(-1.f, y, y); \ } -#ifndef __WIN32 #ifdef __AVX__ INTRI8_FLOAT(jit::avx, detail::ExpAVX); INTRI16_FLOAT(jit::avx, detail::ExpAVX); INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); -#endif // AVX -#endif // WIN32 +#endif #ifdef __AVX2__ INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index fc6a3caef0..ba3e917377 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -18,6 +18,10 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" +#ifdef __AVX__ +#include +#endif + namespace paddle { namespace operators { namespace math { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index ddd6b2a531..c4fccdbf86 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu index 51da6de26e..0015fafbc8 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ b/paddle/fluid/operators/math/sequence_pooling.cu @@ -16,12 +16,13 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence_pooling.h" #include "paddle/fluid/platform/cuda_primitives.h" -#include "paddle/fluid/platform/macros.h" namespace paddle { namespace operators { namespace math { +#define FLT_MAX __FLT_MAX__ + template struct MaxPoolFunctor { HOSTDEVICE void operator()(const T* input, const size_t start, diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index e18bc17fd6..e7f1caf4d3 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include -#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index f1cd7c6ff6..5b05f757c0 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include #include -#include #include #include #include "paddle/fluid/framework/data_type.h" @@ -42,7 +41,6 @@ class SaveCombineOp : public framework::OperatorBase { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); auto save_as_fp16 = Attr("save_as_fp16"); - auto format = Attr("format"); bool is_present = FileExists(filename); if (is_present && !overwrite) { @@ -51,14 +49,8 @@ class SaveCombineOp : public framework::OperatorBase { } MkDirRecursively(DirName(filename).c_str()); - std::unique_ptr fout; - if (format == "windows") { - fout.reset(new std::ofstream(filename, - std::ios_base::out | std::ios_base::binary)); - } else { - fout.reset(new std::ofstream(filename)); - } - PADDLE_ENFORCE(static_cast(*fout), "Cannot open %s to write", + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); auto inp_var_names = Inputs("X"); @@ -94,11 +86,12 @@ class SaveCombineOp : public framework::OperatorBase { // copy LoD info to the new tensor out.set_lod(tensor.lod()); framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); - framework::SerializeToStream(*fout, out, dev_ctx); + framework::SerializeToStream(fout, out, dev_ctx); } else { - framework::SerializeToStream(*fout, tensor, dev_ctx); + framework::SerializeToStream(fout, tensor, dev_ctx); } } + fout.close(); } }; @@ -131,18 +124,6 @@ to a file on disk. "The \"file_path\" where the LoDTensor variables will be saved.") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddAttr("format", - R"DOC((windows|linux)" "saved model file format - windows and linux file newline symbol is -different. windows(newline is \n\r) or linux(newline is \r) -So if you set attribute format to windows, then we saved model file in binary. -It can be used both linux and windows. If you set format to linux, -it will save file in normal file, newline symbol is \r. Need to note -that these two format is not inter-compatible.)DOC") - .SetDefault("linux") - .AddCustomChecker([](const std::string &s) { - return s == "windows" || s == "linux"; - }); } }; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 9eea9e1a95..e79cffcf49 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include #include -#include #include #include "paddle/fluid/framework/data_type.h" @@ -65,7 +64,6 @@ class SaveOp : public framework::OperatorBase { framework::Variable *var) const { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); - auto format = Attr("format"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -82,14 +80,8 @@ class SaveOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. - std::unique_ptr fout; - if (format == "windows") { - fout.reset(new std::ofstream(filename, - std::ios_base::out | std::ios_base::binary)); - } else { - fout.reset(new std::ofstream(filename)); - } - PADDLE_ENFORCE(static_cast(*fout), "Cannot open %s to write", + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); auto save_as_fp16 = Attr("save_as_fp16"); @@ -103,10 +95,11 @@ class SaveOp : public framework::OperatorBase { framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); // copy LoD info to the new tensor out.set_lod(tensor.lod()); - framework::SerializeToStream(*fout, out, dev_ctx); + framework::SerializeToStream(fout, out, dev_ctx); } else { - framework::SerializeToStream(*fout, tensor, dev_ctx); + framework::SerializeToStream(fout, tensor, dev_ctx); } + fout.close(); } void SaveSelectedRows(const framework::Scope &scope, @@ -117,7 +110,6 @@ class SaveOp : public framework::OperatorBase { lt_var != nullptr, "Can not find variable kLookupTablePath for SaveSelectedRows"); std::string filename = lt_var->data(); - auto format = Attr("format"); VLOG(4) << "SaveSelectedRows get File name: " << filename; MkDirRecursively(DirName(filename).c_str()); @@ -130,16 +122,11 @@ class SaveOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. - std::unique_ptr fout; - if (format == "windows") { - fout.reset(new std::ofstream(filename, - std::ios_base::out | std::ios_base::binary)); - } else { - fout.reset(new std::ofstream(filename)); - } - PADDLE_ENFORCE(static_cast(*fout), "Cannot open %s to write", + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); - framework::SerializeToStream(*fout, selectedRows, dev_ctx); + framework::SerializeToStream(fout, selectedRows, dev_ctx); + fout.close(); } }; @@ -167,18 +154,6 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file "The \"file_path\" where the variable will be saved.") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddAttr("format", - R"DOC((windows|linux)" "saved model file format - windows and linux file newline symbol is -different. windows(newline is \n\r) or linux(newline is \r) -So if you set attribute format to windows, then we saved model file in binary. -It can be used both linux and windows. If you set format to linux, -it will save file in normal file, newline symbol is \r. Need to note -that these two format is not inter-compatible.)DOC") - .SetDefault("linux") - .AddCustomChecker([](const std::string &s) { - return s == "windows" || s == "linux"; - }); } }; diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc index cfe491f4c5..767449cde9 100644 --- a/paddle/fluid/operators/split_lod_tensor_op.cc +++ b/paddle/fluid/operators/split_lod_tensor_op.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/port.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index f30668fd21..673f86da76 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -34,7 +34,7 @@ namespace operators { using FluidDT = framework::proto::VarType_Type; using TRT_DT = nvinfer1::DataType; -namespace { // NOLINT +namespace { TRT_DT FluidDataType2TRT(FluidDT type) { switch (type) { @@ -60,7 +60,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { return nvinfer1::DimsCHW(shape[1], 1, 1); } -} // NOLINT // namespace +} // namespace using inference::Singleton; using inference::tensorrt::TRT_EngineManager; diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index bc0204e579..6810a1651a 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -16,18 +16,6 @@ limitations under the License. */ #include -#ifdef _WIN32 -#if defined(__AVX2__) -#include //avx2 -#elif defined(__AVX__) -#include //avx -#endif // AVX -#else // WIN32 -#ifdef __AVX__ -#include -#endif -#endif // WIN32 - namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 0ec3a2a859..07bb02be19 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -59,7 +59,6 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { #define CUDNN_VERSION_MIN(major, minor, patch) \ (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) -#if !defined(_WIN32) #define CUDNN_ENFORCE(condition) \ do { \ cudnnStatus_t status = condition; \ @@ -67,16 +66,6 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { PADDLE_THROW(::paddle::platform::cudnnGetErrorString(status)); \ } \ } while (false) -#else -// windows -#define CUDNN_ENFORCE(condition) \ - do { \ - cudnnStatus_t status = condition; \ - if (status != CUDNN_STATUS_SUCCESS) { \ - std::cerr << ::paddle::platform::cudnnGetErrorString(status); \ - } \ - } while (false) -#endif enum class DataLayout { // Not use kNHWC, diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index b95e25e2c1..ff49a1d57f 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -55,6 +55,7 @@ DeviceContextPool::DeviceContextPool( for (auto& p : places) { set.insert(p); } + for (auto& p : set) { if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN @@ -204,9 +205,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) << ", Runtime Version: " << runtime_version_ / 1000 << "." << (runtime_version_ % 100) / 10; -#ifndef _WIN32 callback_manager_.reset(new StreamCallbackManager(stream_)); -#endif // NOT WIN32 } CUDADeviceContext::~CUDADeviceContext() { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 51cac83961..df248f9bb1 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -32,7 +32,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/stream_callback_manager.h" #endif #include "unsupported/Eigen/CXX11/Tensor" @@ -173,7 +173,6 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } -#ifndef _WIN32 template void AddStreamCallback(Callback&& callback) const { std::lock_guard guard(callback_mtx_); @@ -184,16 +183,6 @@ class CUDADeviceContext : public DeviceContext { std::lock_guard guard(callback_mtx_); callback_manager_->Wait(); } -#else - template - void AddStreamCallback(Callback&& callback) const { - // ugly empty functor. - } - - void WaitStreamCallback() const { - // ugly empty functor. - } -#endif private: CUDAPlace place_; @@ -212,12 +201,10 @@ class CUDADeviceContext : public DeviceContext { mutable std::mutex mtx_; -#ifndef _WIN32 // This lock is only used by callback // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes mutable std::mutex callback_mtx_; std::unique_ptr callback_manager_; -#endif }; template <> diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 23f64170eb..a251bfcd99 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -127,7 +127,7 @@ struct EOFException : public std::exception { #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #else // there is no equivalent intrinsics in msvc. -#define UNLIKELY(condition) ((condition) == 0) +#define UNLIKELY(condition) (condition == 0) #endif #if !defined(_WIN32) diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index e373a34d1e..2211e55043 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -167,9 +167,7 @@ void InitGLOG(const std::string &prog_name) { // glog will not hold the ARGV[0] inside. // Use strdup to alloc a new string. google::InitGoogleLogging(strdup(prog_name.c_str())); -#if !defined(_WIN32) google::InstallFailureSignalHandler(); -#endif } } // namespace framework diff --git a/paddle/fluid/platform/macros.h b/paddle/fluid/platform/macros.h index 906ed6e825..32b7efc04c 100644 --- a/paddle/fluid/platform/macros.h +++ b/paddle/fluid/platform/macros.h @@ -28,16 +28,3 @@ limitations under the License. */ #if defined(__FLT_MAX__) #define FLT_MAX __FLT_MAX__ #endif // __FLT_MAX__ - -#ifdef _WIN32 -#if defined(PADDLE_COMPILE) -// by default, msvc has predefined macro _LIB for static library -// only shared library need to export and import symbols -// static library export all symbols by default. -#define PADDLE_DLL __declspec(dllexport) -#else -#define PADDLE_DLL __declspec(dllimport) -#endif -#else -#define PADDLE_DLL -#endif diff --git a/paddle/fluid/platform/port.h b/paddle/fluid/platform/port.h index 8f1e3bdd31..cf9f4aa95b 100644 --- a/paddle/fluid/platform/port.h +++ b/paddle/fluid/platform/port.h @@ -15,13 +15,12 @@ #pragma once #include -#include -#include // NOLINT #include + +#include #include #define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h -#define GOOGLE_GLOG_DLL_DECL #include "glog/logging.h" #if !defined(_WIN32) @@ -62,6 +61,7 @@ static void *dlopen(const char *filename, int flag) { } return reinterpret_cast(hModule); } + #endif // !_WIN32 static void ExecShellCommand(const std::string &cmd, std::string *message) { From f1c1acf1ac027f01b9f764734684769cf38b5a26 Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Thu, 8 Nov 2018 11:50:59 +0100 Subject: [PATCH 15/23] Changed hardcoded format to any in convolution and bumped MKL-DNN version to 0.17-rc test=develop --- cmake/external/mkldnn.cmake | 2 +- paddle/fluid/operators/conv_mkldnn_op.cc | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index baf253df27..58d1333f93 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -54,7 +54,7 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS ${MKLDNN_DEPENDS} GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" - GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944" + GIT_TAG "21fb5f2af1dd14e132af4f1b79160977ee487818" PREFIX ${MKLDNN_SOURCES_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 72cac9bc9f..f2cc6642ee 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -375,8 +375,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_md = platform::MKLDNNMemDesc( src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); auto weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), - (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); + weights_tz, platform::MKLDNNGetDataType(), chosen_memory_format); std::vector bias_tz; // TODO(mgallus): avoid empty vector creation. // Currently used whenever bias is != nullptr. auto dst_md = platform::MKLDNNMemDesc( From c5b6573a5a1872b9bb80b00a4f26f5e5a913f398 Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 8 Nov 2018 19:36:10 +0800 Subject: [PATCH 16/23] Fix input (#14208) * fix input test=develop * fix split_ids test=develop * ElementwiseMul should not support SelectedRows * fix scale op test=develop * change GetTensorFromVar() method to GetTensorOrSelectedRowsFromVar() * fix operator * refine MultiOutput * fix MultiOutput test=develop * disable test_dist_save_load test=develop * fix elementwise_op test=develop * add get_sparse_as_op test=develop * add info for check test=develop * rename get_sparse_as_op with extract_rows_as_op. test=develop * elementwise doesn't support selected_rows * fix regularizer * remove extract_rows_as test=develop * fix ci test=develop * add test for sum_op * fix regularizer test=develop * test=develop * fix pserver weight decay multi inputs test=develop --- .../details/multi_devices_graph_pass.cc | 6 + paddle/fluid/framework/operator.cc | 36 +++--- paddle/fluid/framework/operator.h | 10 +- paddle/fluid/operators/CMakeLists.txt | 1 - paddle/fluid/operators/elementwise_add_op.h | 71 ++++++------ paddle/fluid/operators/elementwise_div_op.h | 7 +- paddle/fluid/operators/elementwise_max_op.h | 7 +- paddle/fluid/operators/elementwise_min_op.h | 7 +- paddle/fluid/operators/elementwise_mul_op.h | 7 +- paddle/fluid/operators/elementwise_op.h | 44 +++++--- paddle/fluid/operators/elementwise_sub_op.h | 7 +- paddle/fluid/operators/extract_rows_op.cc | 103 ------------------ .../operators/math/selected_rows_functor.h | 2 + paddle/fluid/operators/scale_op.h | 17 +-- paddle/fluid/operators/split_ids_op.cc | 3 +- paddle/fluid/operators/split_ids_op.h | 4 + paddle/fluid/operators/sum_op.cc | 4 +- paddle/fluid/pybind/const_value.cc | 1 + python/paddle/fluid/regularizer.py | 66 ++++------- .../tests/unittests/test_dist_transpiler.py | 15 +-- .../unittests/test_elementwise_mul_op.py | 51 --------- .../tests/unittests/test_extract_rows_op.py | 60 ---------- .../fluid/tests/unittests/test_regularizer.py | 4 +- .../fluid/tests/unittests/test_sum_op.py | 100 +++++++++++++---- 24 files changed, 240 insertions(+), 393 deletions(-) delete mode 100644 paddle/fluid/operators/extract_rows_op.cc delete mode 100644 python/paddle/fluid/tests/unittests/test_extract_rows_op.py diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 67d29a42d7..3dc177a8cb 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -648,6 +648,12 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID( const ir::Graph &graph, const std::string &varname, const std::unordered_map &sharded_var_device) const { auto got = sharded_var_device.find(varname); + if (got == sharded_var_device.end()) { + auto pos = varname.find(framework::kNewGradSuffix); + if (pos != std::string::npos) { + got = sharded_var_device.find(varname.substr(0, pos)); + } + } return got == sharded_var_device.end() ? -1 : got->second; } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 45fc36c706..73886ed304 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable& var) { return var.IsType() || var.IsType(); } -const Tensor* GetTensorFromVar(const Variable& var) { +const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) { if (var.IsType()) { return static_cast(&(var.Get())); } else if (var.IsType()) { @@ -369,7 +369,7 @@ const Tensor* GetTensorFromVar(const Variable& var) { } } -static Tensor* GetMutableTensorFromVar(Variable* var) { +Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) { if (var->IsType()) { return var->GetMutable(); } else if (var->IsType()) { @@ -414,8 +414,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const { template <> const Tensor* ExecutionContext::Input(const std::string& name) const { - auto* var = InputVar(name); - return var == nullptr ? nullptr : GetTensorFromVar(*var); + return Input(name); } template <> @@ -425,17 +424,21 @@ const std::vector ExecutionContext::MultiInput( std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [&](const std::string& sub_name) { + [&](const std::string& sub_name) -> const Tensor* { auto var = scope_.FindVar(sub_name); - return var == nullptr ? nullptr : GetTensorFromVar(*var); + if (var == nullptr) return nullptr; + PADDLE_ENFORCE( + var->IsType(), + "%s should be LoDTensor, but the received type is %s", + sub_name, var->Type().name()); + return &(var->Get()); }); return res; } template <> Tensor* ExecutionContext::Output(const std::string& name) const { - auto var = OutputVar(name); - return var == nullptr ? nullptr : GetMutableTensorFromVar(var); + return Output(name); } template <> @@ -445,10 +448,14 @@ std::vector ExecutionContext::MultiOutput( std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [&](const std::string& sub_name) { + [&](const std::string& sub_name) -> Tensor* { auto var = scope_.FindVar(sub_name); - return var == nullptr ? nullptr - : GetMutableTensorFromVar(var); + if (var == nullptr) return nullptr; + PADDLE_ENFORCE( + var->IsType(), + "%s should be LoDTensor, but the received type is %s", + sub_name, var->Type().name()); + return var->GetMutable(); }); return res; } @@ -768,11 +775,12 @@ void OperatorWithKernel::TransferInplaceVarsBack( const Scope& transfer_scope) const { for (auto& var_name : inplace_vars) { VLOG(3) << "share inplace var " + var_name + " back to it's original scope"; - auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name)); + auto* original_tensor = + GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name)); auto* var = transfer_scope.FindVar(var_name); PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr", var_name); - auto* transformed_tensor = GetTensorFromVar(*var); + auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var); original_tensor->ShareDataWith(*transformed_tensor); } } @@ -789,7 +797,7 @@ Scope* OperatorWithKernel::TryTransferData( continue; } - auto* tensor_in = GetTensorFromVar(*var); + auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); if (!tensor_in->IsInitialized()) { continue; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 96ad320523..40b0130b26 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD"; /// Variables with this suffix are supposed to be filled up with zeros. constexpr char kZeroVarSuffix[] = "@ZERO"; +/// Variables with this suffix are the new Gradient. +constexpr char kNewGradSuffix[] = "@NEWGRAD@"; + // define some kernel priority /* Define multiple kernel type fallback order*/ extern std::vector> kKernelPriority; @@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) { } proto::VarType::Type GetDataTypeOfVar(const Variable* var); -const Tensor* GetTensorFromVar(const Variable& var); +const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); +Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); class OperatorBase; class ExecutionContext; @@ -224,7 +228,7 @@ class ExecutionContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [&](const std::string& sub_name) { + [&](const std::string& sub_name) -> const T* { auto var = scope_.FindVar(sub_name); return var == nullptr ? nullptr : &var->Get(); }); @@ -237,7 +241,7 @@ class ExecutionContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [&](const std::string& sub_name) { + [&](const std::string& sub_name) -> T* { auto var = scope_.FindVar(sub_name); return var == nullptr ? nullptr : var->GetMutable(); }); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 919ad96f7a..2a7de024bf 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -296,7 +296,6 @@ op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) -op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index c60cb1f92e..9edbdbefe7 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -28,9 +28,9 @@ struct AddFunctor { }; template -void default_elementwise_add(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { +void default_elementwise_add(const framework::ExecutionContext &ctx, + const framework::Tensor *x, + const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, AddFunctor(), z); @@ -40,9 +40,9 @@ template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type -elementwise_add(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { +elementwise_add(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { auto eigen_x = framework::EigenVector::Flatten(*x); auto eigen_y = framework::EigenVector::Flatten(*y); auto eigen_z = framework::EigenVector::Flatten(*z); @@ -55,21 +55,20 @@ template typename std::enable_if< !std::is_floating_point::value || !std::is_same::value>::type -elementwise_add(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { +elementwise_add(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { default_elementwise_add(ctx, x, y, z); } template class ElementwiseAddKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *y = ctx.Input("Y"); + auto *z = ctx.Output("Out"); - const auto x = ctx.Input("X"); - const auto y = ctx.Input("Y"); - auto z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); auto dims_equal = x->dims() == y->dims(); @@ -87,13 +86,13 @@ struct IdentityGrad { }; template -void default_elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, - framework::Tensor* dx, - framework::Tensor* dy) { +void default_elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, + const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, + framework::Tensor *dx, + framework::Tensor *dy) { int axis = ctx.Attr("axis"); ElemwiseExplicitGradCompute, @@ -106,11 +105,11 @@ template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type -elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, framework::Tensor* dx, - framework::Tensor* dy) { +elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, framework::Tensor *dx, + framework::Tensor *dy) { auto blas = math::GetBlas(ctx); if (dx) { @@ -128,27 +127,27 @@ template typename std::enable_if< !std::is_floating_point::value || !std::is_same::value>::type -elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, framework::Tensor* dx, - framework::Tensor* dy) { +elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, framework::Tensor *dx, + framework::Tensor *dy) { default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } template class ElementwiseAddGradKernel : public ElemwiseGradKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { + void Compute(const framework::ExecutionContext &ctx) const override { ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dy = ctx.Output(framework::GradVarName("Y")); // skip out, x, y - auto* out = dout; + auto *out = dout; auto *x = dout, *y = dout; if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr && diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise_div_op.h index 41a7950bf0..cdb1264d29 100644 --- a/paddle/fluid/operators/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise_div_op.h @@ -28,11 +28,10 @@ template class ElementwiseDivKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise_max_op.h index bfb5c93195..367489dd56 100644 --- a/paddle/fluid/operators/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise_max_op.h @@ -29,11 +29,10 @@ template class ElementwiseMaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise_min_op.h index db035ffb52..1bd0a62797 100644 --- a/paddle/fluid/operators/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise_min_op.h @@ -28,11 +28,10 @@ template class ElementwiseMinKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h index b870d08a1a..29e4ab7db1 100644 --- a/paddle/fluid/operators/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise_mul_op.h @@ -60,11 +60,10 @@ template class ElementwiseMulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); if (x->numel() == y->numel()) { elementwise_mul(ctx, x, y, z); diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index 68c6e315cc..5eb4233344 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" + #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -29,7 +31,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; using Tensor = framework::Tensor; - void InferShape(framework::InferShapeContext* ctx) const override { + + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of elementwise op should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), @@ -37,6 +40,17 @@ class ElementwiseOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of elementwise op should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("X").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Y").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Y").front(), ctx->GetInputsVarType("Y").front()); + auto x_dim = ctx->GetInputDim("X"); auto y_dim = ctx->GetInputDim("Y"); PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), @@ -47,9 +61,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("X")->type()); + const framework::ExecutionContext &ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -64,12 +77,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { class ElementwiseOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { auto x_name = op_desc.Input("X")[0]; auto out_name = op_desc.Output("Out")[0]; - auto& x = block->FindRecursiveOrCreateVar(x_name); - auto& out = block->FindRecursiveOrCreateVar(out_name); + auto &x = block->FindRecursiveOrCreateVar(x_name); + auto &out = block->FindRecursiveOrCreateVar(out_name); out.SetType(x.GetType()); out.SetDataType(x.GetDataType()); } @@ -131,6 +144,7 @@ But the output only shares the LoD information with the input $X$. protected: virtual std::string GetName() const = 0; + virtual std::string GetEquation() const = 0; }; @@ -139,7 +153,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; using Tensor = framework::Tensor; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), @@ -165,7 +179,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::ToDataType( ctx.Input(framework::GradVarName("Out"))->type()); @@ -187,7 +201,7 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { using operators::ElementwiseOpGrad::GetExpectedKernelType; using Tensor = framework::Tensor; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); @@ -209,11 +223,11 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { template class ElemwiseGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* dx = + void Compute(const framework::ExecutionContext &context) const override { + auto *dx = context.Output(framework::GradVarName("X")); if (dx != nullptr) { - auto& dout = + auto &dout = *context.Input(framework::GradVarName("Out")); dx->set_lod(dout.lod()); } @@ -234,7 +248,7 @@ class ElemwiseGradKernel : public framework::OpKernel { \ protected: \ std::unique_ptr Apply() const override { \ - auto* op = new paddle::framework::OpDesc(); \ + auto *op = new paddle::framework::OpDesc(); \ op->SetType(#kernel_type "_grad"); \ op->SetInput("Y", Input("Y")); \ op->SetInput(::paddle::framework::GradVarName("Out"), \ diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h index 3385df0897..7204c43464 100644 --- a/paddle/fluid/operators/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise_sub_op.h @@ -28,11 +28,10 @@ template class ElementwiseSubKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, diff --git a/paddle/fluid/operators/extract_rows_op.cc b/paddle/fluid/operators/extract_rows_op.cc deleted file mode 100644 index 3acae3bcdf..0000000000 --- a/paddle/fluid/operators/extract_rows_op.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class ExtractRowsOpInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ExtractRowsOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ExtractRowsOp should not be null."); - PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0], - framework::proto::VarType::SELECTED_ROWS, - "The type of input(X) must be SelectedRows."); - auto in_dims = ctx->GetInputDim("X"); - - ctx->SetOutputDim( - "Out", framework::make_ddim(std::vector{in_dims[0], 1})); - } -}; - -class ExtractRowsOp : public framework::OperatorBase { - public: - ExtractRowsOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &in = scope.FindVar(Input("X"))->Get(); - auto out = scope.FindVar(Output("Out"))->GetMutable(); - - auto &in_rows = in.rows(); - auto out_dim = framework::make_ddim( - std::vector{static_cast(in_rows.size()), 1}); - auto dst_ptr = out->mutable_data(out_dim, in.place()); - - if (paddle::platform::is_gpu_place(in.place())) { -#ifdef PADDLE_WITH_CUDA - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); - auto *dev_ctx = pool.Get(in.place()); - auto src_ptr = in_rows.Data(in.place()); - auto stream = - reinterpret_cast(*dev_ctx) - .stream(); - memory::Copy(boost::get(out->place()), dst_ptr, - boost::get(in.place()), src_ptr, - in_rows.size() * sizeof(int64_t), stream); -#else - PADDLE_THROW("Not compiled with CUDA."); -#endif - } else { - memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(), - in_rows.data(), in_rows.size() * sizeof(int64_t)); - } - } -}; - -class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(SelectedRows). The input tensor of extract_rows operator," - " and its type is SelectedRows."); - AddOutput("Out", "(Tensor). The the rows of input(X)."); - - AddComment(R"DOC( - ExtractRows Operator. - -The function of extract_rows_op is extracting the rows from the input(X) -whose type is SelectedRows. - - )DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker, - ops::ExtractRowsOpInferShape); diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index b24ffb57ac..6d146d39d6 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -64,6 +64,8 @@ struct SelectedRowsSumTo { framework::SelectedRows* input2); }; +// FIXME: The result of SelectedRowsAddToTensor maybe non deterministic, +// because it uses CudaAtomicAdd. // input2 = input1 + input2 template struct SelectedRowsAddToTensor { diff --git a/paddle/fluid/operators/scale_op.h b/paddle/fluid/operators/scale_op.h index d8a199bc2b..96b8b00b42 100644 --- a/paddle/fluid/operators/scale_op.h +++ b/paddle/fluid/operators/scale_op.h @@ -24,19 +24,13 @@ class ScaleKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& ctx) const { auto* in_var = ctx.InputVar("X"); - auto* in = ctx.Input("X"); - - auto* out_var = ctx.OutputVar("Out"); - auto* out = ctx.Output("Out"); - out->mutable_data(in->place()); - - PADDLE_ENFORCE_EQ(in->dims(), out->dims(), - "in and out should have the same dim"); + auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var); auto scale = static_cast(ctx.Attr("scale")); auto bias = static_cast(ctx.Attr("bias")); auto bias_after_scale = ctx.Attr("bias_after_scale"); + auto* out_var = ctx.OutputVar("Out"); if (in_var->IsType() && in_var != out_var) { auto& in_slr = in_var->Get(); auto* out_slr = out_var->GetMutable(); @@ -44,6 +38,13 @@ class ScaleKernel : public framework::OpKernel { out_slr->set_height(in_slr.height()); } + auto* out = + framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var); + out->mutable_data(in->place()); + + PADDLE_ENFORCE_EQ(in->dims(), out->dims(), + "in and out should have the same dim"); + auto eigen_out = framework::EigenVector::Flatten(*out); auto eigen_in = framework::EigenVector::Flatten(*in); auto& dev = *ctx.template device_context().eigen_device(); diff --git a/paddle/fluid/operators/split_ids_op.cc b/paddle/fluid/operators/split_ids_op.cc index 243f81e296..01d432e130 100644 --- a/paddle/fluid/operators/split_ids_op.cc +++ b/paddle/fluid/operators/split_ids_op.cc @@ -64,8 +64,7 @@ class SplitIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.MultiInput("Ids").front()->type()), + framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h index 69ac6c5a6b..c8b0e71521 100644 --- a/paddle/fluid/operators/split_ids_op.h +++ b/paddle/fluid/operators/split_ids_op.h @@ -113,6 +113,10 @@ class SplitIdsOpKernel : public framework::OpKernel { row_width * sizeof(T)); } } + } else { + PADDLE_THROW( + "% should be LoDTensor or SelectedRows, but the received type is %s", + ctx.Inputs("Ids")[0], ids_var->Type().name()); } } }; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index d19ac9839c..7df14158f3 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -85,8 +85,8 @@ class SumOp : public framework::OperatorWithKernel { for (size_t idx = 0; idx < x_vars.size(); ++idx) { PADDLE_ENFORCE(x_vars[idx] != nullptr, "Input var[%s] should not be nullptr", x_vars_name[idx]); - // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. - auto tensor = framework::GetTensorFromVar(*x_vars[idx]); + auto tensor = + framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]); if (tensor->numel() == 0) { continue; } diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 1f61a0e289..06d8b65fb1 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -27,6 +27,7 @@ void BindConstValue(pybind11::module* m) { m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); m->def("kControlDepVarName", [] { return framework::ir::Node::kControlDepVarName; }); + m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; }); auto op_proto_and_checker_maker = m->def_submodule("op_proto_and_checker_maker"); diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index 57185da4d1..d8aace9fdf 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -61,14 +61,25 @@ def append_regularization_ops(parameters_and_grads, regularization=None): params_and_grads.append((param, grad)) continue - assert grad.shape == regularization_term.shape + new_grad = grad + if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization, + # the grad's type and name will be changed. But the gradient's name + # is used in ParallelExecutor Reduce mode, so I add a flag for + # the new_grad here. + new_grad = grad.block.create_var( + name=grad.name + core.kNewGradSuffix(), + dtype=param.dtype, + shape=param.shape, + lod_level=param.lod_level, + type=core.VarDesc.VarType.LOD_TENSOR) grad.block.append_op( - type='elementwise_add', - inputs={"X": grad, - "Y": regularization_term}, - outputs={"Out": grad}) - params_and_grads.append((param, grad)) + type='sum', + inputs={"X": [grad, regularization_term]}, + outputs={"Out": new_grad}) + + params_and_grads.append((param, new_grad)) return params_and_grads @@ -142,26 +153,7 @@ class L2DecayRegularizer(WeightDecayRegularizer): assert isinstance(block, framework.Block) decay = block.create_var( - dtype="float32", shape=param.shape, lod_level=param.lod_level) - - if grad.type == core.VarDesc.VarType.SELECTED_ROWS: - idx = block.create_var( - dtype="int64", - shape=param.shape, - type=core.VarDesc.VarType.LOD_TENSOR) - decay = block.create_var( - dtype="float32", - shape=param.shape, - type=core.VarDesc.VarType.LOD_TENSOR) - block.append_op( - type='extract_rows', inputs={'X': grad}, outputs={'Out': idx}) - block.append_op( - type='lookup_table', - inputs={'W': param, - 'Ids': idx}, - outputs={'Out': decay}, - attrs={'is_sparse': True}) - param = decay + dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) # Append Op to calculate decay block.append_op( @@ -218,27 +210,9 @@ class L1DecayRegularizer(WeightDecayRegularizer): """ assert isinstance(param, framework.Parameter) assert isinstance(block, framework.Block) + decay = block.create_var( - dtype="float32", shape=param.shape, lod_level=param.lod_level) - - if grad.type == core.VarDesc.VarType.SELECTED_ROWS: - idx = block.create_var( - dtype="int64", - shape=param.shape, - type=core.VarDesc.VarType.LOD_TENSOR) - decay = block.create_var( - dtype="float32", - shape=param.shape, - type=core.VarDesc.VarType.LOD_TENSOR) - block.append_op( - type='extract_rows', inputs={'X': grad}, outputs={'Out': idx}) - block.append_op( - type='lookup_table', - inputs={'W': param, - 'Ids': idx}, - outputs={'Out': decay}, - attrs={'is_sparse': True}) - param = decay + dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) # Append sign op block.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 986fdd9ff2..3a5b6b5cb8 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -373,9 +373,8 @@ class TestL2Decay(TranspilerTest): self.assertEqual(len(pserver.blocks), 3) self.assertEqual([op.type for op in pserver.blocks[1].ops], ["sum", "scale", "clip", "sgd"]) - self.assertEqual( - [op.type for op in pserver.blocks[2].ops], - ["sum", "scale", "clip", "scale", "elementwise_add", "sgd"]) + self.assertEqual([op.type for op in pserver.blocks[2].ops], + ["sum", "scale", "clip", "scale", "sum", "sgd"]) # TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer @@ -416,12 +415,10 @@ class TestL2DecayWithPiecewise(TranspilerTest): "logical_and", "conditional_block", "fill_constant", "conditional_block" ]) - self.assertEqual( - [op.type for op in pserver.blocks[7].ops], - ["sum", "scale", "scale", "elementwise_add", "momentum"]) - self.assertEqual( - [op.type for op in pserver.blocks[8].ops], - ["sum", "scale", "scale", "elementwise_add", "momentum"]) + self.assertEqual([op.type for op in pserver.blocks[7].ops], + ["sum", "scale", "scale", "sum", "momentum"]) + self.assertEqual([op.type for op in pserver.blocks[8].ops], + ["sum", "scale", "scale", "sum", "momentum"]) class TestEmptyPserverOptimizeBlocks(TranspilerTest): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 6a129b6df9..53409e436c 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -117,56 +117,5 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): } -class TestElementWiseMulSelectedRows(OpTest): - def setUp(self): - self.rows = [0, 1, 2, 3, 4, 5, 6] - self.feature = 12 - self.height = 100 - self.input_shape = (len(self.rows), self.feature) - - def prepare_input(self, scope, place): - self.input = { - "X": np.random.random(self.input_shape).astype("float32"), - "Y": np.random.random(self.input_shape).astype("float32") - } - - def init_input(in_name): - x_selected_rows = scope.var(in_name).get_selected_rows() - x_selected_rows.set_height(self.height) - x_selected_rows.set_rows(self.rows) - x_array = self.input[in_name] - x_tensor = x_selected_rows.get_tensor() - x_tensor.set(x_array, place) - - init_input("X") - init_input("Y") - - def create_out_selected_row(self, scope): - return scope.var('Out').get_selected_rows() - - def check_result(self, out_selected_rows): - assert out_selected_rows.height() == self.height - assert out_selected_rows.rows() == self.rows - out_tensor = np.array(out_selected_rows.get_tensor()) - assert out_tensor.shape == self.input_shape - - def check_with_place(self, place): - scope = core.Scope() - self.prepare_input(scope, place) - - out_selected_rows = self.create_out_selected_row(scope) - out_selected_rows.set_height(0) - out_selected_rows.set_rows([]) - - elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out') - elementwise_mul.run(scope, place) - self.check_result(out_selected_rows) - - def test_elewisemul_with_selected_rows_input(self): - places = [core.CPUPlace()] - for place in places: - self.check_with_place(place) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_extract_rows_op.py b/python/paddle/fluid/tests/unittests/test_extract_rows_op.py deleted file mode 100644 index 8629bcf0f2..0000000000 --- a/python/paddle/fluid/tests/unittests/test_extract_rows_op.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -import paddle.fluid.core as core -from paddle.fluid.op import Operator -from op_test import OpTest - - -class TestExtractRows(OpTest): - def check_with_place(self, place): - scope = core.Scope() - - # create and initialize Variable - feature_len = 12 - rows = [0, 4, 4, 7] - np_array = np.ones((len(rows), feature_len)).astype("float32") - - in_x = scope.var('X').get_selected_rows() - in_x.set_height(len(rows)) - in_x.set_rows(rows) - in_x_tensor = in_x.get_tensor() - in_x_tensor.set(np_array, place) - - # create Out Variable - out_tensor = scope.var('Out').get_tensor() - - # create and run lookup_table operator - extract_rows_op = Operator("extract_rows", X='X', Out='Out') - extract_rows_op.run(scope, place) - - # get result from Out - result_array = np.array(out_tensor) - result_array = [ele[0] for ele in result_array] - assert result_array == rows - - def test_concat_rows(self): - places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) - for place in places: - self.check_with_place(place) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_regularizer.py b/python/paddle/fluid/tests/unittests/test_regularizer.py index 6727335c60..20f91cf448 100644 --- a/python/paddle/fluid/tests/unittests/test_regularizer.py +++ b/python/paddle/fluid/tests/unittests/test_regularizer.py @@ -55,7 +55,7 @@ class TestL2DecayRegularizer(unittest.TestCase): params_grads = optimizer.append_regularization_ops(params_grads) self.assertEqual(len(params_grads), 1) self.assertEqual(len(block.ops), count_ops + 2) - self.assertEqual(block.ops[-1].type, 'elementwise_add') + self.assertEqual(block.ops[-1].type, 'sum') self.assertEqual(block.ops[-2].type, 'scale') @@ -92,7 +92,7 @@ class TestL1DecayRegularizer(unittest.TestCase): params_grads = optimizer.append_regularization_ops(params_grads) self.assertEqual(len(params_grads), 1) self.assertEqual(len(block.ops), count_ops + 3) - self.assertEqual(block.ops[-1].type, 'elementwise_add') + self.assertEqual(block.ops[-1].type, 'sum') self.assertEqual(block.ops[-2].type, 'scale') self.assertEqual(block.ops[-3].type, 'sign') diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 643878dc5c..0be5be6e97 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -49,11 +49,14 @@ class TestSumOp(OpTest): class TestSelectedRowsSumOp(OpTest): - def check_with_place(self, place, inplace): + def setUp(self): self.height = 10 self.row_numel = 12 self.rows = [0, 1, 2, 3, 4, 5, 6] + self.dtype = np.float32 + self.init_kernel_type() + def check_with_place(self, place, inplace): self.check_input_and_optput(core.Scope(), place, inplace, True, True, True) self.check_input_and_optput(core.Scope(), place, inplace, False, True, @@ -64,12 +67,12 @@ class TestSelectedRowsSumOp(OpTest): False) def init_kernel_type(self): - self.dtype = np.float32 + pass - def _get_array(self, row_num, row_numel): - array = np.ones((row_num, row_numel)).astype(self.dtype) - for i in range(row_num): - array[i] *= i + def _get_array(self, rows, row_numel): + array = np.ones((len(rows), row_numel)).astype(self.dtype) + for i in range(len(rows)): + array[i] *= rows[i] return array def check_input_and_optput(self, @@ -105,7 +108,7 @@ class TestSelectedRowsSumOp(OpTest): self.assertTrue( np.array_equal( np.array(out.get_tensor()), - self._get_array(len(self.rows), self.row_numel) * + self._get_array(self.rows, self.row_numel) * has_data_w_num)) else: self.assertEqual(len(out.rows()), 0) @@ -121,7 +124,7 @@ class TestSelectedRowsSumOp(OpTest): w_selected_rows = var.get_selected_rows() w_selected_rows.set_height(self.height) w_selected_rows.set_rows(rows) - w_array = self._get_array(len(rows), self.row_numel) + w_array = self._get_array(self.rows, self.row_numel) w_tensor = w_selected_rows.get_tensor() w_tensor.set(w_array, place) @@ -136,36 +139,91 @@ class TestSelectedRowsSumOp(OpTest): self.check_with_place(place, inplace) +class TestLoDTensorAndSelectedRowsOp(TestSelectedRowsSumOp): + def setUp(self): + self.height = 10 + self.row_numel = 12 + self.rows = [0, 1, 2, 2, 4, 5, 6] + + def check_with_place(self, place, inplace): + scope = core.Scope() + if inplace: + self.create_lod_tensor(scope, place, "x1") + self.create_selected_rows(scope, place, "x2", True) + out = scope.var("x1").get_tensor() + out_name = "x1" + else: + self.create_selected_rows(scope, place, "x1", True) + self.create_lod_tensor(scope, place, "x2") + out = scope.var("out").get_tensor() + out_name = "out" + + # create and run sum operator + sum_op = Operator("sum", X=["x1", "x2"], Out=out_name) + sum_op.run(scope, place) + + result = np.ones((1, self.height)).astype(np.int32).tolist()[0] + for ele in self.rows: + result[ele] += 1 + + out_t = np.array(out) + self.assertEqual(out_t.shape[0], self.height) + self.assertTrue( + np.array_equal(out_t, + self._get_array([i for i in range( + self.height)], self.row_numel) * np.tile( + np.array(result).reshape(self.height, 1), + self.row_numel))) + + def create_lod_tensor(self, scope, place, var_name): + var = scope.var(var_name) + w_tensor = var.get_tensor() + w_array = self._get_array([i for i in range(self.height)], + self.row_numel) + w_tensor.set(w_array, place) + return var + + +#----------- test fp16 ----------- +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") class TestFP16SumOp(TestSumOp): def init_kernel_type(self): self.dtype = np.float16 def test_check_output(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) # FIXME: Because of the precision fp16, max_relative_error # should be 0.15 here. def test_check_grad(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - if core.is_float16_supported(place): - self.check_grad(['x0'], 'Out', max_relative_error=0.15) + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad(['x0'], 'Out', max_relative_error=0.15) -class TestFP16SelectedRowsSumOp(TestSelectedRowsSumOp): - def init_kernel_type(self): - self.dtype = np.float16 +def create_test_sum_fp16_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestSumFp16Case(parent): + def init_kernel_type(self): + self.dtype = np.float16 - def test_w_is_selected_rows(self): - if core.is_compiled_with_cuda(): + def test_w_is_selected_rows(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): for inplace in [True, False]: self.check_with_place(place, inplace) + cls_name = "{0}_{1}".format(parent.__name__, "SumFp16Test") + TestSumFp16Case.__name__ = cls_name + globals()[cls_name] = TestSumFp16Case + + +create_test_sum_fp16_class(TestSelectedRowsSumOp) +create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp) if __name__ == "__main__": unittest.main() From f8b2680c537428a463b0a7a45a722a5c917f18aa Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 8 Nov 2018 20:27:21 +0800 Subject: [PATCH 17/23] fix test_conv2d (#14330) test=develop --- .../fluid/tests/unittests/test_conv2d_op.py | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index aba3e7139c..6ab13b5106 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -225,29 +225,29 @@ class TestWithInput1x1Filter1x1(TestConv2dOp): #----------------Conv2dCUDNN---------------- -def create_test_cudnn_class(parent, cls_name): +def create_test_cudnn_class(parent): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNCase(parent): def init_kernel_type(self): self.use_cudnn = True - cls_name = "{0}".format(cls_name) + cls_name = "{0}_{1}".format(parent.__name__, "CUDNN") TestCUDNNCase.__name__ = cls_name globals()[cls_name] = TestCUDNNCase -create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp") -create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1") -create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2") -create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3") -create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4") -create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4") +create_test_cudnn_class(TestConv2dOp) +create_test_cudnn_class(TestWithPad) +create_test_cudnn_class(TestWithStride) +create_test_cudnn_class(TestWithGroup) +create_test_cudnn_class(TestWith1x1) +create_test_cudnn_class(TestWithInput1x1Filter1x1) #----------------Conv2dCUDNN---------------- -def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True): +def create_test_cudnn_fp16_class(parent, grad_check=True): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestConv2DCUDNNFp16(parent): @@ -279,23 +279,17 @@ def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True): max_relative_error=0.02, no_grad_set=set(['Input'])) - cls_name = "{0}".format(cls_name) + cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16") TestConv2DCUDNNFp16.__name__ = cls_name globals()[cls_name] = TestConv2DCUDNNFp16 -create_test_cudnn_fp16_class( - TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False) -create_test_cudnn_fp16_class( - TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False) -create_test_cudnn_fp16_class( - TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False) -create_test_cudnn_fp16_class( - TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False) -create_test_cudnn_fp16_class( - TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False) -create_test_cudnn_fp16_class( - TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False) +create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False) +create_test_cudnn_fp16_class(TestWithPad, grad_check=False) +create_test_cudnn_fp16_class(TestWithStride, grad_check=False) +create_test_cudnn_fp16_class(TestWithGroup, grad_check=False) +create_test_cudnn_fp16_class(TestWith1x1, grad_check=False) +create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False) # -------TestDepthwiseConv From 080112276aba3fcccde0767b1a5eb20ec777fedb Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Thu, 8 Nov 2018 13:50:08 +0100 Subject: [PATCH 18/23] Fixed problem with array subscript is above array bounds in MKL-DNN jit_uni_reorder_utils.cpp:prb_simplify function test=develop --- cmake/external/mkldnn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index 58d1333f93..9fea9ca05b 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -45,7 +45,7 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML") ELSE() MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN") ENDIF() -SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result") +SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds") SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value") SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}") From 9735e3016af5ba8a60e3a48db35adefec841ba52 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 8 Nov 2018 20:52:10 +0800 Subject: [PATCH 19/23] fix test the build strategy is finalized after create_passes. So future change of build strategy has no effects. test=develop --- python/paddle/fluid/tests/unittests/test_dist_base.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 45fae63b01..4b8a215190 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -98,17 +98,18 @@ class TestDistRunnerBase(object): strategy.allow_op_delay = False build_stra = fluid.BuildStrategy() - if args.batch_merge_repeat > 1: - pass_builder = build_stra._create_passes_from_strategy() - mypass = pass_builder.insert_pass( - len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") - mypass.set_int("num_repeats", args.batch_merge_repeat) if args.use_reduce: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce else: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce + if args.batch_merge_repeat > 1: + pass_builder = build_stra._create_passes_from_strategy() + mypass = pass_builder.insert_pass( + len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") + mypass.set_int("num_repeats", args.batch_merge_repeat) + exe = fluid.ParallelExecutor( args.use_cuda, loss_name=avg_cost.name, From abe209234fc6660e75c80b0823a24e7f48b0204a Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Fri, 9 Nov 2018 13:24:22 +0800 Subject: [PATCH 20/23] Exhaustive search for cuDNN conv. (#14286) * exhaustive search for cuDNN conv. * Refine code and add unit testing. * Fix model load in fluid/inference and unit testing in conv2d * Follow comments. * Fix compiling test=develop --- .../framework/ir/graph_pattern_detector.cc | 1 + .../fluid/inference/api/analysis_predictor.h | 2 + paddle/fluid/inference/api/api.cc | 1 - paddle/fluid/inference/api/helper.h | 3 +- paddle/fluid/inference/io.cc | 3 +- paddle/fluid/inference/tensorrt/engine.h | 2 +- .../operators/add_position_encoding_op.h | 7 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 207 ++++++++++++++++-- paddle/fluid/operators/conv_cudnn_op_cache.h | 90 ++++++++ paddle/fluid/operators/conv_op.cc | 11 +- paddle/fluid/operators/tensorrt_engine_op.h | 2 +- paddle/fluid/platform/device_context.cc | 5 +- paddle/fluid/platform/dynload/cudnn.h | 93 ++++---- python/paddle/fluid/__init__.py | 3 +- python/paddle/fluid/layers/nn.py | 17 +- .../fluid/tests/unittests/test_conv2d_op.py | 10 +- .../fluid/tests/unittests/test_conv3d_op.py | 6 + 17 files changed, 384 insertions(+), 79 deletions(-) create mode 100644 paddle/fluid/operators/conv_cudnn_op_cache.h diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index b20d701322..fa713fe1dd 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index b7dc206733..a9f4cce6df 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once +#include +#include #include #include #include "paddle/fluid/framework/naive_executor.h" diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index 01ea942d3c..20fab8078f 100644 --- a/paddle/fluid/inference/api/api.cc +++ b/paddle/fluid/inference/api/api.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle_inference_api.h" namespace paddle { diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index e46dc13269..af21c0095c 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -16,13 +16,14 @@ #include #include +#include #include // NOLINT #include #include #include #include +#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/string/printf.h" -#include "paddle_inference_api.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index e246a06fd0..31f43bfdca 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -59,7 +59,8 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) { bool IsPersistable(const framework::VarDesc* var) { if (var->Persistable() && var->GetType() != framework::proto::VarType::FEED_MINIBATCH && - var->GetType() != framework::proto::VarType::FETCH_LIST) { + var->GetType() != framework::proto::VarType::FETCH_LIST && + var->GetType() != framework::proto::VarType::RAW) { return true; } return false; diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index d9d3827321..828181200e 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -134,7 +134,7 @@ class TensorRTEngine : public EngineBase { std::unordered_map> weight_map; - // TODO: (NHZLX) + // TODO(NHZLX) // In the normal case, the paddle-trt exists bug when runing the googlenet. // When there are more than two convolutions of 1 * 1 with the same input, the // paddle-tensorrt will do the merging optimization, which fuse those conv diff --git a/paddle/fluid/operators/add_position_encoding_op.h b/paddle/fluid/operators/add_position_encoding_op.h index 5f371235f1..0b40d3de89 100644 --- a/paddle/fluid/operators/add_position_encoding_op.h +++ b/paddle/fluid/operators/add_position_encoding_op.h @@ -66,9 +66,10 @@ class AddPositionEncodingKernel : public framework::OpKernel { x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i]; for (int j = 0; j < max_length; ++j) { for (int k = 0; k < half_size; ++k) { - const double val = (half_size > 1) - ? j / pow(10000.0, double(k) / (half_size - 1)) - : j / 10000.0; + const double val = + (half_size > 1) + ? j / pow(10000.0, static_cast(k) / (half_size - 1)) + : j / 10000.0; dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta; dst_ptr[half_size + k] = src_ptr[half_size + k] * alpha + cos(val) * beta; diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 76eda51ad4..5f8d510be7 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -15,15 +15,22 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/profiler.h" DEFINE_bool(cudnn_deterministic, false, "Whether allow using an autotuning algorithm for convolution " "operator. The autotuning algorithm may be non-deterministic. If " "true, the algorithm is deterministic."); +DEFINE_uint64(conv_workspace_size_limit, 4096, + "cuDNN convolution workspace limit in MB unit."); +DEFINE_bool(cudnn_exhaustive_search, false, + "Whether enable exhaustive search for cuDNN convolution or " + "not, defalut is False."); namespace paddle { namespace operators { @@ -36,13 +43,25 @@ using DataLayout = platform::DataLayout; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache"; +static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache"; +static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache"; + static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; +static constexpr size_t kNUM_CUDNN_FWD_ALGS = + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; +static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; +static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = + CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + template class CUDNNConvOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); auto* input = ctx.Input("Input"); @@ -55,6 +74,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); const T* input_data = input->data(); const T* filter_data = filter->data(); @@ -120,19 +141,19 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv workspace --------------------- size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; - if (user_workspace_size > 0) { - workspace_size_limit = user_workspace_size * 1024 * 1024; + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { + int64_t max_user_size = + std::max(static_cast(FLAGS_conv_workspace_size_limit), + user_workspace_size); + workspace_size_limit = max_user_size * 1024 * 1024; } + // ------------------- cudnn conv algorithm --------------------- cudnnConvolutionFwdAlgo_t algo; - auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - + bool half_float = false; #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) // Tensor core is supported since the volta GPU and // is only enabled when input and filter data are float16 @@ -143,6 +164,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); // Currently tensor core is only enabled using this algo algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + half_float = true; VLOG(5) << "use cudnn_tensor_op_math"; } else { CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( @@ -151,6 +173,57 @@ class CUDNNConvOpKernel : public framework::OpKernel { } #endif + auto x_dims = framework::vectorize(input->dims()); + auto f_dims = framework::vectorize(filter->dims()); + if ((!exhaustive_search) && (!half_float)) { + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + VLOG(3) << "cuDNN forward algo " << algo; + } else if (exhaustive_search && (!half_float)) { + AlgorithmsCache* algo_cache = nullptr; + if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { + algo_cache = + ctx.scope() + .FindVar(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } else { + algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } + algo = algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + fwd_perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, cudnn_output_desc, + output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + fwd_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); + + VLOG(3) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = fwd_perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return fwd_perf_stat[0].algo; + }); + VLOG(3) << "choose algo " << algo; + } else { + PADDLE_ENFORCE(half_float, + "cuDNN exhaustive search doesn't support half float."); + } + // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, @@ -162,7 +235,6 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); for (int i = 0; i < groups; i++) { auto cudnn_func = [&](void* cudnn_workspace) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( @@ -180,6 +252,7 @@ template class CUDNNConvGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); auto input = ctx.Input("Input"); @@ -198,6 +271,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + if (exhaustive_search && FLAGS_cudnn_deterministic) { + PADDLE_THROW( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time."); + } // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor input_desc; @@ -265,14 +345,66 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnnConvolutionBwdFilterAlgo_t filter_algo; size_t workspace_size_in_bytes = 0, tmp_size = 0; size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; - if (user_workspace_size > 0) { - workspace_size_limit = user_workspace_size * 1024 * 1024; + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { + int64_t max_user_size = + std::max(static_cast(FLAGS_conv_workspace_size_limit), + user_workspace_size); + workspace_size_limit = max_user_size * 1024 * 1024; } - auto& dev_ctx = ctx.template device_context(); + auto x_dims = framework::vectorize(input->dims()); + auto f_dims = framework::vectorize(filter->dims()); auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { - if (!FLAGS_cudnn_deterministic) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + if (exhaustive_search) { + AlgorithmsCache* data_algo_cache; + if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) { + data_algo_cache = + ctx.scope() + .FindVar(kCUDNNBwdDataAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } else { + data_algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNBwdDataAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } + data_algo = data_algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + data_perf_stat; + auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardDataAlgorithmEx( + handle, cudnn_filter_desc, filter_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_input_desc, input_grad_data, + kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, + data_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_bd_data_func, + workspace_size_limit); + + VLOG(3) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = data_perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return data_perf_stat[0].algo; + }); + VLOG(3) << "cuDNN backward data algo " << data_algo; + } else if (FLAGS_cudnn_deterministic) { + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } else { CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( handle, cudnn_filter_desc, @@ -285,10 +417,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnn_input_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &data_algo)); - } else { - data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } - CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( handle, cudnn_filter_desc, cudnn_output_grad_desc, @@ -297,17 +426,54 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } if (filter_grad) { - if (!FLAGS_cudnn_deterministic) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + if (exhaustive_search) { + AlgorithmsCache* f_algo_cache; + if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) { + f_algo_cache = + ctx.scope() + .FindVar(kCUDNNBwdFilterAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } else { + f_algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNBwdFilterAlgoCache) + ->GetMutable< + AlgorithmsCache>(); + } + filter_algo = f_algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + filter_perf_stat; + auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardFilterAlgorithmEx( + handle, cudnn_input_desc, input_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_filter_desc, + filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS, + &returned_algo_count, filter_perf_stat.data(), + cudnn_workspace, workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_bd_f_func, + workspace_size_limit); + return filter_perf_stat[0].algo; + }); + VLOG(3) << "cuDNN backward filter algo " << filter_algo; + } else if (FLAGS_cudnn_deterministic) { + filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } else { CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, cudnn_filter_desc, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &filter_algo)); - } else { - filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } - CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, @@ -317,7 +483,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h new file mode 100644 index 0000000000..4b534321f7 --- /dev/null +++ b/paddle/fluid/operators/conv_cudnn_op_cache.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace operators { + +template +class AlgorithmsCache { + public: + // Caches the best algorithm for a given + // combination of tensor dimensions & compute data type. + TAlgorithm GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, + int algorithmFlags, // can set for different data type + std::function gen_func); + + private: + std::unordered_map hash_; + std::mutex mutex_; +}; + +template +TAlgorithm AlgorithmsCache::GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, int algorithmFlags, + std::function gen_func) { + std::lock_guard lock(mutex_); + int64_t seed = 0; + // Hash all of the inputs, use to try and look up a previously + // discovered algorithm, or fall back to generating a new one. + std::hash hashFn; + // do hash like boost + // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x + for (const auto num : dims1) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + for (const auto num : dims2) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1; + } + + for (const auto num : strides) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 2; + } + + for (const auto num : paddings) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 3; + } + + for (const auto num : dilations) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 4; + } + + seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + + (seed << 6) + (seed >> 2) + 5; + + if (seed == 0) return gen_func(); + + if (hash_.find(seed) == hash_.end()) { + TAlgorithm value = gen_func(); + hash_[seed] = value; + } + return hash_[seed]; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 2cd9979bd3..7401f100d7 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -189,6 +189,11 @@ void Conv2DOpMaker::Make() { "workspace size can increase performance but also requires " "better hardware. This size should be chosen carefully.") .SetDefault(4096); + AddAttr("exhaustive_search", + "(bool, default false) cuDNN has many algorithm to calculation " + "convolution, whether enable exhaustive search ", + "for cuDNN convolution or not, defalut is False.") + .SetDefault(false); AddComment(R"DOC( Convolution Operator. @@ -283,7 +288,11 @@ void Conv3DOpMaker::Make() { "workspace size can increase performance but also requires " "better hardware. This size should be chosen carefully.") .SetDefault(4096); - + AddAttr("exhaustive_search", + "(bool, default false) cuDNN has many algorithm to calculation " + "convolution, whether enable exhaustive search ", + "for cuDNN convolution or not, defalut is False.") + .SetDefault(false); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 673f86da76..b9faac0858 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -34,7 +34,7 @@ namespace operators { using FluidDT = framework::proto::VarType_Type; using TRT_DT = nvinfer1::DataType; -namespace { +namespace { // NOLINT TRT_DT FluidDataType2TRT(FluidDT type) { switch (type) { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index ff49a1d57f..f5541014af 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -204,7 +204,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) << "." << (driver_version_ % 100) / 10 << ", Runtime Version: " << runtime_version_ / 1000 << "." << (runtime_version_ % 100) / 10; - + size_t cudnn_dso_ver = dynload::cudnnGetVersion(); + LOG_FIRST_N(WARNING, 1) << "device: " << place_.device + << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." + << (cudnn_dso_ver % 100) / 10 << "."; callback_manager_.reset(new StreamCallbackManager(stream_)); } diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index d3d754b6f5..c26143d2f2 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -65,51 +65,54 @@ extern void EnforceCUDNNLoaded(const char* fn_name); * include all needed cudnn functions in HPPL * different cudnn version has different interfaces **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreateSpatialTransformerDescriptor); \ - __macro(cudnnSetSpatialTransformerNdDescriptor); \ - __macro(cudnnDestroySpatialTransformerDescriptor); \ - __macro(cudnnSpatialTfGridGeneratorForward); \ - __macro(cudnnSpatialTfGridGeneratorBackward); \ - __macro(cudnnSpatialTfSamplerForward); \ - __macro(cudnnSpatialTfSamplerBackward); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor); \ + __macro(cudnnSetTensor4dDescriptorEx); \ + __macro(cudnnSetTensorNdDescriptor); \ + __macro(cudnnGetTensorNdDescriptor); \ + __macro(cudnnGetConvolutionNdForwardOutputDim); \ + __macro(cudnnGetConvolutionForwardAlgorithm); \ + __macro(cudnnCreateTensorDescriptor); \ + __macro(cudnnDestroyTensorDescriptor); \ + __macro(cudnnCreateFilterDescriptor); \ + __macro(cudnnSetFilter4dDescriptor); \ + __macro(cudnnSetFilterNdDescriptor); \ + __macro(cudnnGetFilterNdDescriptor); \ + __macro(cudnnSetPooling2dDescriptor); \ + __macro(cudnnSetPoolingNdDescriptor); \ + __macro(cudnnGetPoolingNdDescriptor); \ + __macro(cudnnDestroyFilterDescriptor); \ + __macro(cudnnCreateConvolutionDescriptor); \ + __macro(cudnnCreatePoolingDescriptor); \ + __macro(cudnnDestroyPoolingDescriptor); \ + __macro(cudnnSetConvolution2dDescriptor); \ + __macro(cudnnDestroyConvolutionDescriptor); \ + __macro(cudnnSetConvolutionNdDescriptor); \ + __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ + __macro(cudnnCreateSpatialTransformerDescriptor); \ + __macro(cudnnSetSpatialTransformerNdDescriptor); \ + __macro(cudnnDestroySpatialTransformerDescriptor); \ + __macro(cudnnSpatialTfGridGeneratorForward); \ + __macro(cudnnSpatialTfGridGeneratorBackward); \ + __macro(cudnnSpatialTfSamplerForward); \ + __macro(cudnnSpatialTfSamplerBackward); \ + __macro(cudnnCreate); \ + __macro(cudnnDestroy); \ + __macro(cudnnSetStream); \ + __macro(cudnnActivationForward); \ + __macro(cudnnConvolutionForward); \ + __macro(cudnnConvolutionBackwardBias); \ + __macro(cudnnGetConvolutionForwardWorkspaceSize); \ + __macro(cudnnTransformTensor); \ + __macro(cudnnPoolingForward); \ + __macro(cudnnPoolingBackward); \ + __macro(cudnnSoftmaxBackward); \ + __macro(cudnnSoftmaxForward); \ + __macro(cudnnGetVersion); \ + __macro(cudnnFindConvolutionForwardAlgorithmEx); \ + __macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \ + __macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \ __macro(cudnnGetErrorString); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index c4cfd8e468..b79b00846e 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -126,7 +126,8 @@ def __bootstrap__(): if core.is_compiled_with_cuda(): read_env_flags += [ - 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic' + 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', + 'conv_workspace_size_limit', 'cudnn_exhaustive_search' ] core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 595537ab1e..0a39587574 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,6 +27,7 @@ from .tensor import concat from . import utils from .. import unique_name from functools import reduce +from .. import core __all__ = [ 'fc', @@ -1666,6 +1667,20 @@ def conv2d(input, pre_bias = helper.create_variable_for_type_inference(dtype) + if use_cudnn: + helper.create_variable( + name="kCUDNNFwdAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + helper.create_variable( + name="kCUDNNBwdDataAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + helper.create_variable( + name="kCUDNNBwdFilterAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + helper.append_op( type=l_type, inputs={ @@ -1679,7 +1694,7 @@ def conv2d(input, 'dilations': dilation, 'groups': groups, 'use_cudnn': use_cudnn, - 'use_mkldnn': False + 'use_mkldnn': False, }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 6ab13b5106..ebbbf3ab8b 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -67,6 +67,7 @@ class TestConv2dOp(OpTest): def setUp(self): self.op_type = "conv2d" self.use_cudnn = False + self.exhaustive_search = False self.use_cuda = False self.use_mkldnn = False self.data_format = "AnyLayout" @@ -98,7 +99,8 @@ class TestConv2dOp(OpTest): 'dilations': self.dilations, 'use_cudnn': self.use_cudnn, 'use_mkldnn': self.use_mkldnn, - 'data_format': self.data_format + 'data_format': self.data_format, + 'exhaustive_search': self.exhaustive_search } self.outputs = {'Output': output} @@ -361,6 +363,12 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp): self.op_type = "depthwise_conv2d" +class TestCUDNNExhaustiveSearch(TestConv2dOp): + def init_kernel_type(self): + self.use_cudnn = True + self.exhaustive_search = True + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation): diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index ddaf99fe06..69c5ab7a4a 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -335,6 +335,12 @@ class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1): self.check_output_with_place(place, atol=2e-2) +class TestCUDNNExhaustiveSearch(TestCUDNN): + def init_kernel_type(self): + self.use_cudnn = True + self.exhaustive_search = True + + # FIXME(typhoonzero): find a way to determine if # using cudnn > 6 in python # class TestWithDilationCUDNN(TestWithDilation): From d08334011a155f00bc1160adf2e400a00f7c66c3 Mon Sep 17 00:00:00 2001 From: peizhilin Date: Fri, 9 Nov 2018 14:09:27 +0800 Subject: [PATCH 21/23] fix merge issue --- paddle/fluid/framework/ir/pass.h | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index a9199414ba..e1767337ab 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -217,28 +217,6 @@ struct PassRegistrar : public Registrar { extern int TouchPassRegistrar_##pass_type(); \ static int use_pass_itself_##pass_type##_ __UNUSED__() = \ TouchPassRegistrar_##pass_type() -#else -#define REGISTER_PASS(pass_type, pass_class) \ - STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ - __reg_pass__##pass_type, \ - "REGISTER_PASS must be called in global namespace"); \ - static ::paddle::framework::ir::PassRegistrar \ - __pass_registrar_##pass_type##__(#pass_type); \ - int TouchPassRegistrar_##pass_type() { \ - __pass_registrar_##pass_type##__.Touch(); \ - return 0; \ - } \ - static ::paddle::framework::ir::PassRegistrar UNUSED( \ - &__pass_tmp_registrar_##pass_type##__) = \ - __pass_registrar_##pass_type##__ - -#define USE_PASS(pass_type) \ - STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ - __use_pass_itself_##pass_type, \ - "USE_PASS must be called in global namespace"); \ - extern int TouchPassRegistrar_##pass_type(); \ - static int UNUSED(use_pass_itself_##pass_type##_) = \ - TouchPassRegistrar_##pass_type() } // namespace ir } // namespace framework From 4b1f1a878732b920f94f3e42d0cb328c308d4bca Mon Sep 17 00:00:00 2001 From: peizhilin Date: Fri, 9 Nov 2018 14:21:34 +0800 Subject: [PATCH 22/23] fix merge issue --- paddle/fluid/inference/analysis/helper.h | 1 + paddle/fluid/platform/init.cc | 2 ++ paddle/fluid/platform/port.h | 3 +-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 5151e2b69a..ea568a581d 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/port.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 4910baec6a..092585ed2a 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -175,7 +175,9 @@ void InitGLOG(const std::string &prog_name) { // glog will not hold the ARGV[0] inside. // Use strdup to alloc a new string. google::InitGoogleLogging(strdup(prog_name.c_str())); +#ifndef _WIN32 google::InstallFailureSignalHandler(); +#endif } } // namespace framework diff --git a/paddle/fluid/platform/port.h b/paddle/fluid/platform/port.h index cf9f4aa95b..4ff07edc19 100644 --- a/paddle/fluid/platform/port.h +++ b/paddle/fluid/platform/port.h @@ -30,11 +30,10 @@ #include #include // std::accumulate #else +#include #include // _popen, _pclose #include -#if defined(_WIN32) #include // std::accumulate in msvc -#endif // windows version of __attribute__((unused)) #define UNUSED __pragma(warning(suppress : 4100)) From 350f1f397178ac7d6a73f0c9b5cb00c2d65e5e47 Mon Sep 17 00:00:00 2001 From: peizhilin Date: Fri, 9 Nov 2018 14:29:58 +0800 Subject: [PATCH 23/23] remove duplicate function definition --- paddle/fluid/inference/analysis/helper.h | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index ea568a581d..2517f5a373 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -125,20 +125,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) { return *var->GetMutable(); } -static void ExecShellCommand(const std::string &cmd, std::string *message) { - char buffer[128]; - std::shared_ptr pipe(popen(cmd.c_str(), "r"), pclose); - if (!pipe) { - LOG(ERROR) << "error running command: " << cmd; - return; - } - while (!feof(pipe.get())) { - if (fgets(buffer, 128, pipe.get()) != nullptr) { - *message += buffer; - } - } -} - static framework::proto::ProgramDesc LoadProgramDesc( const std::string &model_path) { std::ifstream fin(model_path, std::ios::in | std::ios::binary);