diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 7de97c69e7..38b11d9113 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include #include @@ -23,32 +22,14 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/timer.h" #include "paddle/fluid/platform/profiler.h" DEFINE_bool(profile, false, "Turn on profiler for fluid"); namespace paddle { namespace { - -// Timer for timer -class Timer { - public: - double start; - double startu; - void tic() { - struct timeval tp; - gettimeofday(&tp, NULL); - start = tp.tv_sec; - startu = tp.tv_usec; - } - double toc() { - struct timeval tp; - gettimeofday(&tp, NULL); - double used_time_ms = - (tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0; - return used_time_ms; - } -}; +using paddle::inference::Timer; template std::string num2str(T a) { @@ -80,7 +61,7 @@ void NativePaddlePredictor::PrepareFeedFetch() { bool NativePaddlePredictor::Init( std::shared_ptr parent_scope) { VLOG(3) << "Predictor::init()"; - +#if !defined(_WIN32) if (FLAGS_profile) { LOG(WARNING) << "Profiler is actived, might affect the performance"; LOG(INFO) << "You can turn off by set gflags '-profile false'"; @@ -89,6 +70,7 @@ bool NativePaddlePredictor::Init( : platform::ProfilerState::kCPU; platform::EnableProfiler(tracking_device); } +#endif if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); @@ -133,10 +115,12 @@ bool NativePaddlePredictor::Init( } NativePaddlePredictor::~NativePaddlePredictor() { +#if !defined(_WIN32) if (FLAGS_profile) { platform::DisableProfiler(platform::EventSortingKey::kTotal, "./profile.log"); } +#endif if (sub_scope_) { scope_->DeleteScope(sub_scope_); } diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index a697218377..afb46a7139 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -3,6 +3,11 @@ cmake_minimum_required(VERSION 3.0) project(cpp_inference_demo CXX C) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +if (WIN32) +set(CMAKE_STATIC_LIBRARY_PREFIX "lib") +else() +set(CMAKE_STATIC_LIBRARY_PREFIX "") +endif() if(NOT DEFINED PADDLE_LIB) message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") @@ -32,44 +37,56 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") +if (NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") +link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") +endif(NOT WIN32) + link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") -link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") +link_directories("${PADDLE_LIB}/paddle/fluid/inference") add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) include_directories("${PADDLE_LIB}/third_party/install/mklml/include") - set(MATH_LIB 
${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so - ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5.so) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") if(EXISTS ${MKLDNN_PATH}) include_directories("${MKLDNN_PATH}/include") set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) endif() else() - set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() # Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a if(WITH_STATIC_LIB) set(DEPS - ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.a) + ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) else() set(DEPS - ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so) + ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() -set(EXTERNAL_LIB "-lrt -ldl -lpthread") +if (NOT WIN32) +set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} glog gflags protobuf snappystream snappy z ${EXTERNAL_LIB}) +else() +set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + ${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf + ${EXTERNAL_LIB}) +endif(NOT WIN32) + if(WITH_GPU) - set(DEPS ${DEPS} ${CUDA_LIB}/libcudart.so) + set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() target_link_libraries(${DEMO_NAME} ${DEPS}) diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 19832e8909..bdc9a15d54 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -16,35 +16,15 @@ #include #include -#include #include #include #include #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/api/timer.h" namespace paddle { namespace inference { -// Timer for timer -class Timer { - public: - double start; - double startu; - void tic() { - struct timeval tp; - gettimeofday(&tp, NULL); - start = tp.tv_sec; - startu = tp.tv_usec; - } - double toc() { - struct timeval tp; - gettimeofday(&tp, NULL); - double used_time_ms = - (tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0; - return used_time_ms; - } -}; - static void split(const std::string &str, char sep, std::vector *pieces) { pieces->clear(); diff --git a/paddle/fluid/inference/api/timer.h b/paddle/fluid/inference/api/timer.h new file mode 100644 index 0000000000..2df5274dc1 --- /dev/null +++ b/paddle/fluid/inference/api/timer.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include // NOLINT + +namespace paddle { +namespace inference { + +// Timer for timer +class Timer { + public: + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point startu; + + void tic() { start = std::chrono::high_resolution_clock::now(); } + double toc() { + startu = std::chrono::high_resolution_clock::now(); + std::chrono::duration time_span = + std::chrono::duration_cast>(startu - + start); + double used_time_ms = static_cast(time_span.count()) * 1000.0; + return used_time_ms; + } +}; + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 2f8fadcefe..7ec1e78da4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -178,6 +178,8 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(relu);\n") elseif(${TARGET} STREQUAL "fake_dequantize") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") + elseif(${TARGET} STREQUAL "fake_quantize") + file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n") elseif(${TARGET} STREQUAL "tensorrt_engine_op") message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") elseif(${TARGET} STREQUAL "fc") @@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) +op_library(fake_quantize_op DEPS memory) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index a91e0f520e..e608eba05d 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -14,86 +14,198 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fake_quantize_op.h" #include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { -class FakeQuantizeOp : public framework::OperatorWithKernel { +template +using EigenVectorArrayMap = + Eigen::TensorMap>; + +template +using ConstEigenVectorArrayMap = + Eigen::TensorMap>; + +template +struct FindAbsMaxFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* in, + const int num, T* out) { + Eigen::DSizes idim(num); + Eigen::DSizes odim(1); + Eigen::TensorMap> in_e(in, idim); + Eigen::TensorMap> out_e(out, odim); + + out_e = in_e.abs().maximum(); + } +}; + +template struct FindAbsMaxFunctor; + +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + T s = scale.data()[0]; + platform::Transform trans; + trans(ctx, in.data(), in.data() + in.numel(), + out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); + auto in_e = framework::EigenVector::Flatten(in); + auto out_e = framework::EigenVector::Flatten(*out); + + out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round(); + } +}; + +template struct ClipAndFakeQuantFunctor; + +template +struct FindRangeAbsMaxFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& cur_scale, + const framework::Tensor& last_scale, + const framework::Tensor& iter, const int window_size, + framework::Tensor* scales_arr, framework::Tensor* out_scale) { + T* scale_arr = scales_arr->mutable_data(ctx.GetPlace()); + int64_t it = iter.data()[0]; + int idx = it % window_size; + T removed = scale_arr[idx]; + T cur = cur_scale.data()[0]; + scale_arr[idx] = cur; + + T max = last_scale.data()[0]; + if (max < cur) { + max = cur; + } else if (fabs(removed - max) < 1e-6) { + int size = (it > window_size) ? 
window_size : it; + FindAbsMaxFunctor()(ctx, scale_arr, size, + &max); + } + out_scale->mutable_data(ctx.GetPlace())[0] = max; + } +}; + +template struct FindRangeAbsMaxFunctor; + +class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { public: - FakeQuantizeOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + FakeQuantizeAbsMaxOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FakeQuantizeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FakeQuantizeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"), - "OutMovingScale(Out) of FakeQuantizeOp should not be null"); - // if (ctx->HasInput("InMovingScale")) { - ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale")); - //} - // if (ctx->HasInput("InScales")) { - PADDLE_ENFORCE(ctx->HasOutput("OutScales"), - "OutScales(Out) of FakeQuantizeOp should not be null"); - ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales")); - // PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0], - // ctx->Outputs("OutScales")[0], - // "Mean and MeanOut should share the same memory"); - //} + PADDLE_ENFORCE(ctx->HasOutput("OutScale"), + "Output(Scale) of FakeQuantizeOp should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->SetOutputDim("OutScale", {1}); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; -class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker { +class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor) Input tensor of scale operator."); - AddInput("InScales", "(Tensor) scale buffer, used in static quantization.") - .AsDispensable(); - AddInput("InMovingScale", "Last scale, used in static quantization.") - .AsDispensable(); - AddInput("InCurrentIter", - "Last iteration number, used in static quantization.") - .AsDispensable(); - AddOutput("Out", "(Tensor) Output of quantized low level tensor."); - AddOutput("OutScales", - "(Tensor) scale buffer, used in static quantization.") - .AsDispensable(); - AddOutput("OutMovingScale", " Current scale"); - AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable(); - AddAttr("quantize_type", - "(string, default abs_max)" - "The scaling tpe of the quantize operator.") - .SetDefault("abs_max"); - AddAttr("window_size", "(int, default 10000)").SetDefault(10000); + AddInput("X", "(Tensor) Input is float data type."); + AddOutput("Out", + "(Tensor) Output of quantized low level tensor, " + "but also saved as float data type."); + AddOutput("OutScale", "(Tensor) Current scale"); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) - .AddCustomChecker([](const int &bit_length) { + .AddCustomChecker([](const int& bit_length) { PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, "'bit_length' should be between 1 and 16."); }); - 
AddAttr("is_test", "").SetDefault(false); AddComment(R"DOC( FakeQuantize operator -quantize_type = abs_max: +$$scale = max(abs(X))$$ +$$range = 2^{bit_length - 1} - 1$$ +$$Out = round(X/scale * range)$$ - $$scale = max(abs(x))$$ +)DOC"); + } +}; -quantize_type = range_abs_max: +class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { + public: + FakeQuantizeRangeAbsMaxOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} - $$scale = max(max(abs(x)), history_abs_max)$$ + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FakeQuantizeRangeAbsMaxOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("OutScale"), + "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null"); + if (ctx->HasOutput("OutScales")) { + int window_size = ctx->Attrs().Get("window_size"); + ctx->SetOutputDim("OutScales", {window_size}); + } + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->SetOutputDim("OutScale", {1}); + ctx->ShareLoD("X", /*->*/ "Out"); + } -quantize_type = moving_average_abs_max: + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; - $$scale = 0.1*scale+0.9*new_abs_max)$$ +class FakeQuantizeRangeAbsMaxOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input is float data type."); + AddInput("InScale", "Last scale."); + AddInput("Iter", "Global step iteration.").AsDispensable(); + AddOutput("Out", "(Tensor) Output of quantized low level tensor."); + AddOutput("OutScale", " Current scale"); + AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable(); + AddAttr("window_size", "(int, default 10000) window range size.") + .SetDefault(10000); + AddAttr("bit_length", "(int, default 8), quantization bit number.") + .SetDefault(8) + .AddCustomChecker([](const int& bit_length) { + PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, + "'bit_length' should be between 1 and 16."); + }); + AddAttr("is_test", "").SetDefault(false); + AddComment(R"DOC( +FakeQuantize operator is used in static quantization. 
-$$Out = scale*X$$ +$$scale = max(max(abs(x)), history_abs_max)$$ +$$range = 2^{bit_length - 1} - 1$$ +$$Out = round(X/scale * range)$$ )DOC"); } @@ -103,10 +215,16 @@ $$Out = scale*X$$ } // namespace paddle namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp, + ops::FakeQuantizeAbsMaxOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max, + ops::FakeQuantizeAbsMaxKernel); -REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker, +REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp, + ops::FakeQuantizeRangeAbsMaxOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - fake_quantize, - ops::FakeQuantizeKernel, - ops::FakeQuantizeKernel); +REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max, + ops::FakeQuantizeRangeAbsMaxKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index be0c6730a5..7c65d6dba7 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -20,7 +21,7 @@ namespace paddle { namespace operators { template -__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) { +__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; @@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) { __syncthreads(); for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { shared_max_data[tid] = shared_max_data[tid + i]; } __syncthreads(); @@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) { } } -float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array, - int length) { - float host_max; - int kNumTheads = 1024; - int gridDimx = (kNumTheads - 1 + length) / kNumTheads; - gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx; - framework::Tensor t; - float* device_max = t.mutable_data(framework::make_ddim({gridDimx}), - platform::CUDAPlace()); - FindAbsMaxKernel<<>>(length, array, device_max); - FindAbsMaxKernel< - float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>( - gridDimx, device_max, device_max); - PADDLE_ENFORCE_EQ( - cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost), - cudaSuccess, "cudaMemcpy failed"); - return host_max; -} +template +struct FindAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* in, + const int num, T* out) { + int block = 1024; + int grid = (block - 1 + num) / block; + grid = (grid > block) ? 
block : grid; + + framework::Tensor max; + T* max_data = + max.mutable_data(framework::make_ddim({grid}), ctx.GetPlace()); + FindAbsMaxKernel<<>>( + in, num, max_data); + FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( + max_data, grid, out); + } +}; + +template struct FindAbsMaxFunctor; template -__global__ void ApplySaturateKernel(const int n, const T* in, T* out, - int* num_saturate, const T min, - const T max) { +__global__ void ClipAndQuantKernel(const T* in, const T* scale, + const int bin_cnt, const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; - extern __shared__ int shared_count[]; - shared_count[tid] = 0; + T s = scale[0]; for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - if (in[i] > max) { - out[i] = max; - shared_count[tid] += 1; - } else if (in[i] < min) { - out[i] = min; - shared_count[tid] += 1; - } else { - out[i] = in[i]; - } - } - __syncthreads(); - - for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i) { - shared_count[tid] += shared_count[tid + i]; - } - __syncthreads(); - } - if (tid == 0) { - num_saturate[blockIdx.x] = shared_count[0]; + T x = in[bid]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt / s * v; + out[bid] = round(v); } } template -__global__ void ReduceKernel(const int n, const T* in, T* out) { - int tid = threadIdx.x; - extern __shared__ T shared_sum[]; - if (tid < n) { - shared_sum[tid] = in[tid]; +__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, + const T* last_scale, + const int64_t* iter, + const int window_size, T* scale_arr, + T* out_scale, int* need_find_max, + int* out_size) { + int it = iter[0]; + int idx = it % window_size; + T removed = scale_arr[idx]; + T cur = cur_scale[0]; + scale_arr[idx] = cur; + T max = last_scale[0]; + out_scale[0] = max < cur ? cur : max; + if (fabs(removed - max) < 1e-6) { + need_find_max[0] = 1; + out_size[0] = it > window_size ? window_size : it; } else { - shared_sum[tid] = T(0); - } - __syncthreads(); - // blockDim.x must >= n - for (int i = (n + 1) / 2; i > 0; i >>= 1) { - if (tid < i) { - shared_sum[tid] += shared_sum[tid + i]; - } - __syncthreads(); - } - if (tid == 0) { - out[0] = shared_sum[0]; + need_find_max[0] = 0; } } template -int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n, - const T* in, T* out, const T min, const T max) { - int host_num_saturate; - int kNumTheads = 1024; - int gridDimx = (n + kNumTheads - 1) / kNumTheads; - gridDimx = (gridDimx > kNumTheads) ? 
kNumTheads : gridDimx; - framework::Tensor t; - int* device_num_saturate = t.mutable_data( - framework::make_ddim({gridDimx}), platform::CUDAPlace()); - ApplySaturateKernel< - T><<>>( - n, in, out, device_num_saturate, min, max); - ReduceKernel<<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>( - gridDimx, device_num_saturate, device_num_saturate); - PADDLE_ENFORCE_EQ(cudaSuccess, - cudaMemcpy(&host_num_saturate, device_num_saturate, - sizeof(int), cudaMemcpyDeviceToHost), - "cudaMemcpy failed"); - return host_num_saturate; -} - -template -class FakeQuantizeCUDAKernel : public framework::OpKernel { - public: - T FindRangeAbsMax(const platform::CUDADeviceContext& ctx, - framework::Tensor* scale_list, framework::Tensor* out_scale, - const T& cur_scale, int window_size, - int current_iter) const { - T* sl = scale_list->mutable_data(platform::CPUPlace()); - T remove_tmp = sl[current_iter]; - sl[current_iter] = cur_scale; - T& max_scale = out_scale->mutable_data(platform::CPUPlace())[0]; - if (max_scale < cur_scale) { - max_scale = cur_scale; - } else if (fabs(remove_tmp - max_scale) < 1e-6) { - int size = (current_iter > window_size) ? window_size : current_iter; - max_scale = T(FindAbsMaxGpu(ctx, scale_list->data(), size)); +struct FindRangeAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& cur_scale, + const framework::Tensor& last_scale, + const framework::Tensor& iter, const int window_size, + framework::Tensor* scales_arr, framework::Tensor* out_scale) { + auto& gpu_place = boost::get(ctx.GetPlace()); + T* scale_arr = scales_arr->mutable_data(gpu_place); + T* out_scale_data = out_scale->mutable_data(gpu_place); + + framework::Tensor need_find_max, out_size; + int* find_max = need_find_max.mutable_data(gpu_place); + int* out_size_data = out_size.mutable_data(gpu_place); + + FindRangeAbsMaxAndFillArray<<<1, 1, 0, ctx.stream()>>>( + cur_scale.data(), last_scale.data(), iter.data(), + window_size, scale_arr, out_scale_data, find_max, out_size_data); + + int g_find_max; + memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, + sizeof(int), 0); + if (g_find_max) { + int len; + memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, + sizeof(int), 0); + FindAbsMaxFunctor()(ctx, scale_arr, len, + out_scale_data); } - return max_scale; - } - - T FindMovingAverageAbsMmax(framework::Tensor* in_scale, - framework::Tensor* out_scale, - const T& cur_scale) const { - T* ins = in_scale->mutable_data(platform::CPUPlace()); - T* outs = out_scale->mutable_data(platform::CPUPlace()); - outs[0] = 0.9 * cur_scale + 0.1 * ins[0]; - return T(outs[0]); } +}; - virtual void Compute(const framework::ExecutionContext& context) const { - PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), - "This kernel only runs on GPU device."); - auto& device_ctx = context.cuda_device_context(); - auto* tensor = context.Output("Out"); - auto* in = context.Input("X"); - const bool is_test = context.Attr("is_test"); - tensor->mutable_data(in->place()); - context.Output("OutMovingScale") - ->mutable_data( - context.Input("InMovingScale")->place()); - auto quantize_type = - static_cast(context.Attr("quantize_type")); - if (quantize_type == std::string("range_abs_max")) { - context.Output("OutScales") - ->mutable_data( - context.Input("InScales")->place()); - context.Output("OutCurrentIter") - ->mutable_data( - context.Input("InCurrentIter")->place()); - } - - T scale = T(1); - int window_size = context.Attr("window_size"); - T bin_cnt = (T)((1 << 
(context.Attr("bit_length") - 1)) - 1); - if (quantize_type == std::string("abs_max")) { - auto* saving_scale = context.Output("OutMovingScale"); - scale = (T)FindAbsMaxGpu(device_ctx, in->data(), in->numel()); - saving_scale->mutable_data(platform::CPUPlace())[0] = scale; - - auto& device_ctx = context.template device_context(); - auto* scale_list = context.Output("OutScales"); - math::SetConstant scalar; - scale_list->mutable_data(context.GetPlace()); - scalar(device_ctx, scale_list, static_cast(0)); - auto* iter = context.Output("OutCurrentIter"); - iter->mutable_data(context.GetPlace()); - scalar(device_ctx, iter, static_cast(0)); - } else if (quantize_type == std::string("range_abs_max")) { - auto* moving_scale = const_cast( - context.Input("InMovingScale")); - if (is_test) { - scale = moving_scale->mutable_data(platform::CPUPlace())[0]; - } else { - auto* it = const_cast( - context.Input("InCurrentIter")); - auto* iter = context.Output("OutCurrentIter"); - int* last_iter = it->mutable_data(platform::CPUPlace()); - int* current_iter = iter->mutable_data(platform::CPUPlace()); - auto* scale_list = context.Output("OutScales"); - auto* saving_scale = - context.Output("OutMovingScale"); - scale = (T)FindAbsMaxGpu(device_ctx, in->data(), in->numel()); - scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale, - window_size, current_iter[0]); - (*current_iter) = (*last_iter) + 1; - } - } else if (quantize_type == std::string("moving_average_abs_max")) { - auto* moving_scale = const_cast( - context.Input("InMovingScale")); - if (is_test) { - scale = moving_scale->mutable_data(platform::CPUPlace())[0]; - } else { - scale = (T)FindAbsMaxGpu(device_ctx, in->data(), in->numel()); - auto* saving_scale = - context.Output("OutMovingScale"); - scale = FindMovingAverageAbsMmax( - const_cast(moving_scale), saving_scale, scale); - } - } - - ApplySaturateGpu(device_ctx, in->numel(), in->data(), - tensor->mutable_data(in->place()), -scale, scale); - scale = bin_cnt / scale; +template struct FindRangeAbsMaxFunctor; - auto& dev = - *context.template device_context().eigen_device(); - auto eigen_out = framework::EigenVector::Flatten(*tensor); - auto eigen_in = framework::EigenVector::Flatten(*tensor); - eigen_out.device(dev) = (scale * eigen_in).round(); +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = (block - 1 + num) / block; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, out_data); } }; +template struct ClipAndFakeQuantFunctor; + } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(fake_quantize, - paddle::operators::FakeQuantizeCUDAKernel< - paddle::platform::CUDADeviceContext, float>, - paddle::operators::FakeQuantizeCUDAKernel< - paddle::platform::CUDADeviceContext, double>); +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max, + ops::FakeQuantizeAbsMaxKernel); +REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max, + ops::FakeQuantizeRangeAbsMaxKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 80f71d85dd..7ace7573ec 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ 
b/paddle/fluid/operators/fake_quantize_op.h @@ -17,137 +17,91 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { -using platform::Transform; +template +struct FindAbsMaxFunctor { + void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); +}; template -class FakeQuantizeKernel : public framework::OpKernel { +struct ClipAndFakeQuantFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& in, + const framework::Tensor& scale, const int bin_cnt, + framework::Tensor* out); +}; + +template +struct FindRangeAbsMaxFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale, + const framework::Tensor& last_scale, + const framework::Tensor& iter, const int window_size, + framework::Tensor* scales_arr, framework::Tensor* out_scale); +}; + +template +class FakeQuantizeAbsMaxKernel : public framework::OpKernel { public: - T FindAbsMax(framework::Tensor* in, int n) const { - T* p = in->mutable_data(platform::CPUPlace()); - T abs_max = (T)0.00000001; - for (int i = 0; i < n; i++) { - T tmp = fabs(p[i]); - if (tmp > abs_max) abs_max = tmp; - } - return T(abs_max); - } - T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale, - const T& cur_scale, int window_size, - int current_iter) const { - T* sl = scale_list->mutable_data(platform::CPUPlace()); - T remove_tmp = sl[current_iter]; - sl[current_iter] = cur_scale; - T& max_scale = out_scale->mutable_data(platform::CPUPlace())[0]; - if (max_scale < cur_scale) { - max_scale = cur_scale; - } else if (fabs(remove_tmp - max_scale) < 1e-6) { - int size = (current_iter > window_size) ? 
window_size : current_iter; - max_scale = T(FindAbsMax(scale_list, size)); - } - return max_scale; - } + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); - T FindMovingAverageAbsMmax(framework::Tensor* in_scale, - framework::Tensor* out_scale, - const T& cur_scale) const { - T* ins = in_scale->mutable_data(platform::CPUPlace()); - T* outs = out_scale->mutable_data(platform::CPUPlace()); - outs[0] = 0.9 * cur_scale + 0.1 * ins[0]; - return T(outs[0]); + auto* out = context.Output("Out"); + auto* out_scale = context.Output("OutScale"); + T* out_s = out_scale->mutable_data(context.GetPlace()); + + int bit_length = context.Attr("bit_length"); + int bin_cnt = std::pow(2, bit_length - 1) - 1; + + auto& dev_ctx = context.template device_context(); + const T* in_data = in->data(); + FindAbsMaxFunctor()(dev_ctx, in_data, in->numel(), out_s); + ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, + bin_cnt, out); } +}; - virtual void Compute(const framework::ExecutionContext& context) const { - auto* tensor = context.Output("Out"); +template +class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); - const bool is_test = context.Attr("is_test"); - tensor->mutable_data(in->place()); - - auto* oms_tensor = context.Output("OutMovingScale"); - oms_tensor->mutable_data(in->place()); - - auto quantize_type = - static_cast(context.Attr("quantize_type")); - if (quantize_type == std::string("range_abs_max")) { - auto* oss_tensor = context.Output("OutScales"); - oss_tensor->mutable_data( - context.Input("InScales")->place()); - auto* oci_tensor = context.Output("OutCurrentIter"); - oci_tensor->mutable_data( - context.Input("InCurrentIter")->place()); - } + auto* in_scale = context.Input("InScale"); - T scale = static_cast(1); - int window_size = context.Attr("window_size"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + bool is_test = context.Attr("is_test"); int bit_length = context.Attr("bit_length"); int bin_cnt = std::pow(2, bit_length - 1) - 1; + auto& dev_ctx = context.template device_context(); - auto& dev = - *context.template device_context().eigen_device(); - auto raw_in = framework::EigenVector::Flatten(*in); - if (quantize_type == std::string("abs_max")) { - auto* saving_scale = context.Output("OutMovingScale"); - auto scale_out = framework::EigenVector::Flatten(*saving_scale); - scale_out.device(dev) = raw_in.abs().maximum(); - scale = scale_out(0); - - auto& device_ctx = context.template device_context(); - auto* scale_list = context.Output("OutScales"); - math::SetConstant scalar; - scale_list->mutable_data(context.GetPlace()); - scalar(device_ctx, scale_list, static_cast(0)); - auto* iter = context.Output("OutCurrentIter"); - iter->mutable_data(context.GetPlace()); - scalar(device_ctx, iter, static_cast(0)); - } else if (quantize_type == std::string("range_abs_max")) { - auto* moving_scale = context.Input("InMovingScale"); - if (is_test) { - scale = moving_scale->data()[0]; - } else { - auto* it = context.Input("InCurrentIter"); - auto* iter = context.Output("OutCurrentIter"); - const int* last_iter = it->data(); - int* current_iter = iter->mutable_data(platform::CPUPlace()); - auto* scale_list = context.Output("OutScales"); - auto* saving_scale = - context.Output("OutMovingScale"); - auto scale_out = framework::EigenVector::Flatten(*saving_scale); - scale_out.device(dev) = 
raw_in.abs().maximum(); - scale = saving_scale->mutable_data(platform::CPUPlace())[0]; - scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size, - current_iter[0]); - saving_scale->mutable_data(platform::CPUPlace())[0] = scale; - (*current_iter) = (*last_iter) + 1; - } - } else if (quantize_type == std::string("moving_average_abs_max")) { - auto* moving_scale = context.Input("InMovingScale"); - if (is_test) { - scale = moving_scale->data()[0]; - } else { - auto* saving_scale = - context.Output("OutMovingScale"); - auto scale_out = framework::EigenVector::Flatten(*saving_scale); - scale_out.device(dev) = raw_in.abs().maximum(); - scale = saving_scale->mutable_data(platform::CPUPlace())[0]; - scale = FindMovingAverageAbsMmax( - const_cast(moving_scale), saving_scale, scale); - saving_scale->mutable_data(platform::CPUPlace())[0] = scale; - } + // testing + if (is_test) { + ClipAndFakeQuantFunctor()(dev_ctx, *in, *in_scale, + bin_cnt, out); + return; } - Transform trans; - trans(context.template device_context(), in->data(), - in->data() + in->numel(), tensor->mutable_data(in->place()), - ClipFunctor(-scale, scale)); - auto eigen_out = framework::EigenVector::Flatten(*tensor); - auto eigen_in = framework::EigenVector::Flatten(*tensor); - eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round(); + // training + auto* out_scale = context.Output("OutScale"); + auto* out_scales = context.Output("OutScales"); + auto* iter = context.Input("Iter"); + + int window_size = context.Attr("window_size"); + out_scale->mutable_data(context.GetPlace()); + + framework::Tensor cur_scale; + T* cur_scale_data = cur_scale.mutable_data({1}, context.GetPlace()); + FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), + cur_scale_data); + FindRangeAbsMaxFunctor()(dev_ctx, cur_scale, *in_scale, + *iter, window_size, out_scales, + out_scale); + ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, + bin_cnt, out); } }; diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 01308e416a..133d3f72db 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -53,7 +53,7 @@ class SamplingIdKernel : public framework::OpKernel { static_cast(context.Attr("min")), static_cast(context.Attr("max"))); - std::vector ids(batch_size); + std::vector ids(batch_size); for (int i = 0; i < batch_size; ++i) { T r = dist(engine); int idx = width - 1; @@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel { break; } } - ids[i] = ins_vector[idx]; + ids[i] = int64_t(idx); } std::vector out_dim; diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py index 3d2ef56617..a9b94a2072 100644 --- a/python/paddle/fluid/inferencer.py +++ b/python/paddle/fluid/inferencer.py @@ -98,10 +98,9 @@ class Inferencer(object): raise ValueError( "inputs should be a map of {'input_name': input_var}") - with executor.scope_guard(self.scope): - results = self.exe.run(self.inference_program, - feed=inputs, - fetch_list=[self.predict_var], + with self._prog_and_scope_guard(): + results = self.exe.run(feed=inputs, + fetch_list=[self.predict_var.name], return_numpy=return_numpy) return results diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py index be494a0d34..2e15c224f6 100644 --- 
a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py @@ -16,7 +16,9 @@ from __future__ import print_function import paddle import paddle.fluid as fluid +import paddle.fluid.core as core import numpy +import os import cifar10_small_test_set @@ -89,7 +91,7 @@ def optimizer_func(): return fluid.optimizer.Adam(learning_rate=0.001) -def train(use_cuda, train_program, params_dirname): +def train(use_cuda, train_program, parallel, params_dirname): BATCH_SIZE = 128 EPOCH_NUM = 1 @@ -116,7 +118,10 @@ def train(use_cuda, train_program, params_dirname): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() trainer = fluid.Trainer( - train_func=train_program, optimizer_func=optimizer_func, place=place) + train_func=train_program, + optimizer_func=optimizer_func, + place=place, + parallel=parallel) trainer.train( reader=train_reader, @@ -125,10 +130,13 @@ def train(use_cuda, train_program, params_dirname): feed_order=['pixel', 'label']) -def infer(use_cuda, inference_program, params_dirname=None): +def infer(use_cuda, inference_program, parallel, params_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() inferencer = fluid.Inferencer( - infer_func=inference_program, param_path=params_dirname, place=place) + infer_func=inference_program, + param_path=params_dirname, + place=place, + parallel=parallel) # The input's dimension of conv should be 4-D or 5-D. # Use normilized image pixels as input data, which should be in the range @@ -139,22 +147,34 @@ def infer(use_cuda, inference_program, params_dirname=None): print("infer results: ", results) -def main(use_cuda): +def main(use_cuda, parallel): if use_cuda and not fluid.core.is_compiled_with_cuda(): return save_path = "image_classification_resnet.inference.model" + os.environ['CPU_NUM'] = str(4) train( use_cuda=use_cuda, train_program=train_network, - params_dirname=save_path) + params_dirname=save_path, + parallel=parallel) + # FIXME(zcd): in the inference stage, the number of + # input data is one, it is not appropriate to use parallel. 
+ if parallel and use_cuda: + return + + os.environ['CPU_NUM'] = str(1) infer( use_cuda=use_cuda, inference_program=inference_network, - params_dirname=save_path) + params_dirname=save_path, + parallel=parallel) if __name__ == '__main__': for use_cuda in (False, True): - main(use_cuda=use_cuda) + for parallel in (False, True): + if use_cuda and not core.is_compiled_with_cuda(): + continue + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py index dbc7bc06c9..2f205de1c0 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py @@ -16,7 +16,9 @@ from __future__ import print_function import paddle import paddle.fluid as fluid +import paddle.fluid.core as core import numpy +import os import cifar10_small_test_set @@ -68,7 +70,7 @@ def optimizer_func(): return fluid.optimizer.Adam(learning_rate=0.001) -def train(use_cuda, train_program, params_dirname): +def train(use_cuda, train_program, parallel, params_dirname): BATCH_SIZE = 128 train_reader = paddle.batch( paddle.reader.shuffle( @@ -93,7 +95,10 @@ def train(use_cuda, train_program, params_dirname): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() trainer = fluid.Trainer( - train_func=train_program, place=place, optimizer_func=optimizer_func) + train_func=train_program, + place=place, + optimizer_func=optimizer_func, + parallel=parallel) trainer.train( reader=train_reader, @@ -102,10 +107,13 @@ def train(use_cuda, train_program, params_dirname): feed_order=['pixel', 'label']) -def infer(use_cuda, inference_program, params_dirname=None): +def infer(use_cuda, inference_program, parallel, params_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() inferencer = fluid.Inferencer( - infer_func=inference_program, param_path=params_dirname, place=place) + infer_func=inference_program, + param_path=params_dirname, + place=place, + parallel=parallel) # The input's dimension of conv should be 4-D or 5-D. # Use normilized image pixels as input data, which should be in the range @@ -116,22 +124,31 @@ def infer(use_cuda, inference_program, params_dirname=None): print("infer results: ", results) -def main(use_cuda): - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return +def main(use_cuda, parallel): save_path = "image_classification_vgg.inference.model" + os.environ['CPU_NUM'] = str(4) train( use_cuda=use_cuda, train_program=train_network, - params_dirname=save_path) + params_dirname=save_path, + parallel=parallel) + # FIXME(zcd): in the inference stage, the number of + # input data is one, it is not appropriate to use parallel. 
+ if parallel and use_cuda: + return + os.environ['CPU_NUM'] = str(1) infer( use_cuda=use_cuda, inference_program=inference_network, - params_dirname=save_path) + params_dirname=save_path, + parallel=parallel) if __name__ == '__main__': for use_cuda in (False, True): - main(use_cuda=use_cuda) + for parallel in (False, True): + if use_cuda and not core.is_compiled_with_cuda(): + continue + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py index 187bef1b0c..a5adf68158 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py @@ -64,14 +64,14 @@ def optimizer_func(): return fluid.optimizer.Adam(learning_rate=0.001) -def train(use_cuda, train_program, params_dirname): +def train(use_cuda, train_program, parallel, params_dirname): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() trainer = fluid.Trainer( train_func=train_program, place=place, optimizer_func=optimizer_func, - parallel=True) + parallel=parallel) def event_handler(event): if isinstance(event, fluid.EndEpochEvent): @@ -108,11 +108,14 @@ def train(use_cuda, train_program, params_dirname): feed_order=['img', 'label']) -def infer(use_cuda, inference_program, params_dirname=None): +def infer(use_cuda, inference_program, parallel, params_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() inferencer = fluid.Inferencer( - infer_func=inference_program, param_path=params_dirname, place=place) + infer_func=inference_program, + param_path=params_dirname, + place=place, + parallel=parallel) batch_size = 1 tensor_img = numpy.random.uniform(-1.0, 1.0, @@ -123,20 +126,32 @@ def infer(use_cuda, inference_program, params_dirname=None): print("infer results: ", results[0]) -def main(use_cuda): +def main(use_cuda, parallel): params_dirname = "recognize_digits_conv.inference.model" # call train() with is_local argument to run distributed train + os.environ['CPU_NUM'] = str(4) train( use_cuda=use_cuda, train_program=train_program, - params_dirname=params_dirname) + params_dirname=params_dirname, + parallel=parallel) + + # FIXME(zcd): in the inference stage, the number of + # input data is one, it is not appropriate to use parallel. 
+ if parallel and use_cuda: + return + os.environ['CPU_NUM'] = str(1) infer( use_cuda=use_cuda, inference_program=inference_program, - params_dirname=params_dirname) + params_dirname=params_dirname, + parallel=parallel) if __name__ == '__main__': - # for use_cuda in (False, True): - main(use_cuda=core.is_compiled_with_cuda()) + for use_cuda in (False, True): + for parallel in (False, True): + if use_cuda and not core.is_compiled_with_cuda(): + continue + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py index b95e7db122..e7d8b23b32 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py @@ -16,6 +16,7 @@ from __future__ import print_function import argparse import paddle.fluid as fluid +import paddle.fluid.core as core import paddle import sys import numpy @@ -50,11 +51,14 @@ def optimizer_func(): return fluid.optimizer.Adam(learning_rate=0.001) -def train(use_cuda, train_program, params_dirname): +def train(use_cuda, train_program, params_dirname, parallel): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() trainer = fluid.Trainer( - train_func=train_program, place=place, optimizer_func=optimizer_func) + train_func=train_program, + place=place, + optimizer_func=optimizer_func, + parallel=parallel) def event_handler(event): if isinstance(event, fluid.EndEpochEvent): @@ -86,11 +90,14 @@ def train(use_cuda, train_program, params_dirname): feed_order=['img', 'label']) -def infer(use_cuda, inference_program, params_dirname=None): +def infer(use_cuda, inference_program, parallel, params_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() inferencer = fluid.Inferencer( - infer_func=inference_program, param_path=params_dirname, place=place) + infer_func=inference_program, + param_path=params_dirname, + place=place, + parallel=parallel) batch_size = 1 tensor_img = numpy.random.uniform(-1.0, 1.0, @@ -101,20 +108,32 @@ def infer(use_cuda, inference_program, params_dirname=None): print("infer results: ", results[0]) -def main(use_cuda): +def main(use_cuda, parallel): params_dirname = "recognize_digits_mlp.inference.model" # call train() with is_local argument to run distributed train + os.environ['CPU_NUM'] = str(4) train( use_cuda=use_cuda, train_program=train_program, - params_dirname=params_dirname) + params_dirname=params_dirname, + parallel=parallel) + + # FIXME(zcd): in the inference stage, the number of + # input data is one, it is not appropriate to use parallel. 
+ if parallel and use_cuda: + return + os.environ['CPU_NUM'] = str(1) infer( use_cuda=use_cuda, inference_program=inference_program, - params_dirname=params_dirname) + params_dirname=params_dirname, + parallel=parallel) if __name__ == '__main__': - # for use_cuda in (False, True): - main(use_cuda=False) + for use_cuda in (False, True): + for parallel in (False, True): + if use_cuda and not core.is_compiled_with_cuda(): + continue + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index cc0494774a..820ad4af88 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -21,28 +21,41 @@ from op_test import OpTest class TestFakeQuantizeOp(OpTest): def setUp(self): - self.op_type = "fake_quantize" + self.op_type = "fake_quantize_abs_max" + self.attrs = {'bit_length': 8} + self.inputs = {'X': np.random.random((124, 240)).astype("float32"), } + scale = np.max(np.abs(self.inputs['X'])).astype("float32") + self.outputs = { + 'Out': np.round(self.inputs['X'] / scale * ( + (1 << (self.attrs['bit_length'] - 1)) - 1)), + 'OutScale': np.array(scale).astype("float32"), + } + + def test_check_output(self): + self.check_output() + + +class TestFakeQuantizeOp(OpTest): + def setUp(self): + self.op_type = "fake_quantize_range_abs_max" self.attrs = { - 'bit_length': 8, - 'quantize_type': 'abs_max', - 'window_size': 10000 + 'bit_length': int(5), + 'window_size': int(1), + 'is_test': False } self.inputs = { - 'X': np.random.random((10, 10)).astype("float32"), - 'InScales': np.zeros(self.attrs['window_size']).astype("float32"), - 'InCurrentIter': np.zeros(1).astype("float32"), - 'InMovingScale': np.zeros(1).astype("float32") - } - self.scale = { - 'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32") + 'X': np.random.random((8, 16, 7, 7)).astype("float32"), + 'Iter': np.zeros(1).astype("int64"), + 'InScale': np.zeros(1).astype("float32") } + scale = np.max(np.abs(self.inputs['X'])).astype("float32") + out_scales = np.zeros(self.attrs['window_size']).astype("float32") + out_scales[0] = scale self.outputs = { - 'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * ( + 'Out': np.round(self.inputs['X'] / scale * ( (1 << (self.attrs['bit_length'] - 1)) - 1)), - 'OutScales': np.zeros(self.attrs['window_size']).astype("float32"), - 'OutMovingScale': - np.array([self.scale['abs_max']]).astype("float32"), - 'OutCurrentIter': np.zeros(1).astype("float32") + 'OutScale': scale, + 'OutScales': out_scales, } def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 708265b457..674ef2ddf4 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -25,9 +25,9 @@ class TestSamplingIdOp(OpTest): self.op_type = "sampling_id" self.use_mkldnn = False self.init_kernel_type() - self.X = np.random.random((8, 4)).astype('float32') + self.X = np.random.random((100, 10)).astype('float32') self.inputs = {"X": self.X} - self.Y = np.random.random(8).astype('float32') + self.Y = np.random.random(100).astype('int64') self.outputs = {'Out': self.Y} self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1} @@ -36,6 +36,16 @@ class TestSamplingIdOp(OpTest): y1 = self.out self.check_output_customized(self.verify_output) y2 = self.out + + # check 
dtype + assert y1.dtype == np.int64 + assert y2.dtype == np.int64 + + # check output is index ids of inputs + inputs_ids = np.arange(self.X.shape[1]) + assert np.isin(y1, inputs_ids).all() + assert np.isin(y2, inputs_ids).all() + self.assertTrue(np.array_equal(y1, y2)) self.assertEqual(len(y1), len(self.Y))
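
Note (reviewer sketch, not part of the patch): the abs_max fake quantization added above follows scale = max(abs(X)), range = 2^(bit_length - 1) - 1, Out = round(X / scale * range), which is also what the new unit test asserts. The NumPy snippet below only illustrates that math; the helper name fake_quantize_abs_max_reference is hypothetical and does not exist in the codebase.

import numpy as np

def fake_quantize_abs_max_reference(x, bit_length=8):
    # scale is the maximum absolute value of the input tensor
    scale = np.max(np.abs(x)).astype("float32")
    # number of positive quantization bins, e.g. 127 for 8 bits
    bin_cnt = (1 << (bit_length - 1)) - 1
    # values are rounded to integer bins but kept in float storage,
    # matching the operator's "saved as float data type" output
    out = np.round(x / scale * bin_cnt)
    return out, scale

x = np.random.random((124, 240)).astype("float32")
out, scale = fake_quantize_abs_max_reference(x, bit_length=8)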