diff --git a/CMakeLists.txt b/CMakeLists.txt
index cd8c54e24e..32b369bec5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -77,7 +77,6 @@ option(WITH_INFERENCE_API_TEST   "Test fluid inference high-level api interface"
 option(WITH_SYSTEM_BLAS   "Use system blas library"  OFF)
 option(PY_VERSION         "Compile PaddlePaddle with python3 support"     ${PY_VERSION})
 option(WITH_FAST_MATH     "Make use of fast math library, might affect the precision to some extent" ON)
-option(WITH_PREBUILD_OPENBLAS  "Make use of the pre-built openblas library" ${WIN32})
 
 # PY_VERSION
 if(NOT PY_VERSION)
diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake
index a35a1a066b..aeb976b840 100644
--- a/cmake/external/openblas.cmake
+++ b/cmake/external/openblas.cmake
@@ -31,64 +31,66 @@ IF(NOT ${CBLAS_FOUND})
 
     ADD_DEFINITIONS(-DPADDLE_USE_OPENBLAS)
 
-    IF (WITH_PREBUILD_OPENBLAS)
+    IF (WIN32)
         SET(CBLAS_FOUND true)
-        MESSAGE(STATUS, "Use prebuild openblas, please put it at " ${CBLAS_INSTALL_DIR})
-    ELSE(WITH_PREBUILD_OPENBLAS)
-        SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable")
-        SET(OPENBLAS_COMMIT "v0.2.20")
+        MESSAGE(WARNING "On Windows, OpenBLAS only supports the MSVC build; please build it manually and put it at " ${CBLAS_INSTALL_DIR})
+    ENDIF(WIN32)
 
-        IF(CMAKE_CROSSCOMPILING)
-            SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER})
-            GET_FILENAME_COMPONENT(CROSS_SUFFIX ${CMAKE_C_COMPILER} DIRECTORY)
-            SET(CROSS_SUFFIX ${CROSS_SUFFIX}/)
-            IF(ANDROID)
-                IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
-                    # use softfp
-                    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
-                ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
-                    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0)
-                ENDIF()
-            ELSEIF(IOS)
-                IF(CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
-                    SET(OPENBLAS_CC "${OPENBLAS_CC} ${CMAKE_C_FLAGS} -isysroot ${CMAKE_OSX_SYSROOT}")
-                    SET(OPENBLAS_CC "${OPENBLAS_CC} -arch arm64")
-                    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0 CROSS_SUFFIX=${CROSS_SUFFIX})
-                ELSE()
-                    MESSAGE(FATAL_ERROR "OpenBLAS only support arm64 architectures on iOS. "
-                            "You can set IOS_USE_VECLIB_FOR_BLAS=ON or USE_EIGEN_FOR_BLAS=ON to use other blas library instead.")
-                ENDIF()
-            ELSEIF(RPI)
-                # use hardfp
-                SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 USE_THREAD=0)
-            ENDIF()
-        ELSE()
-            IF(APPLE)
-                SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -isysroot ${CMAKE_OSX_SYSROOT}")
-            ENDIF()
-            SET(OPTIONAL_ARGS "")
-            IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$")
-                SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64)
-            ENDIF()
-        ENDIF()
+    IF (NOT WIN32)
+        SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable")
+        SET(OPENBLAS_COMMIT "v0.2.20")
+
+        IF(CMAKE_CROSSCOMPILING)
+            SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER})
+            GET_FILENAME_COMPONENT(CROSS_SUFFIX ${CMAKE_C_COMPILER} DIRECTORY)
+            SET(CROSS_SUFFIX ${CROSS_SUFFIX}/)
+            IF(ANDROID)
+                IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
+                    # use softfp
+                    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
+                ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
+                    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0)
+                ENDIF()
+            ELSEIF(IOS)
+                IF(CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
+                    SET(OPENBLAS_CC "${OPENBLAS_CC} ${CMAKE_C_FLAGS} -isysroot ${CMAKE_OSX_SYSROOT}")
+                    SET(OPENBLAS_CC "${OPENBLAS_CC} -arch arm64")
+                    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0 CROSS_SUFFIX=${CROSS_SUFFIX})
+                ELSE()
+                    MESSAGE(FATAL_ERROR "OpenBLAS only support arm64 architectures on iOS. "
" + "You can set IOS_USE_VECLIB_FOR_BLAS=ON or USE_EIGEN_FOR_BLAS=ON to use other blas library instead.") ENDIF() + ELSEIF(RPI) + # use hardfp + SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 USE_THREAD=0) ENDIF() + ELSE() + IF(APPLE) + SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -isysroot ${CMAKE_OSX_SYSROOT}") + ENDIF() + SET(OPTIONAL_ARGS "") + IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$") + SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64) + ENDIF() + ENDIF() - SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs) - ExternalProject_Add( - extern_openblas - ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git - GIT_TAG ${OPENBLAS_COMMIT} - PREFIX ${CBLAS_SOURCES_DIR} - INSTALL_DIR ${CBLAS_INSTALL_DIR} - BUILD_IN_SOURCE 1 - BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS} - INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX= - && rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig - UPDATE_COMMAND "" - CONFIGURE_COMMAND "" - ) - ENDIF (WITH_PREBUILD_OPENBLAS) - + SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs) + ExternalProject_Add( + extern_openblas + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git + GIT_TAG ${OPENBLAS_COMMIT} + PREFIX ${CBLAS_SOURCES_DIR} + INSTALL_DIR ${CBLAS_INSTALL_DIR} + BUILD_IN_SOURCE 1 + BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS} + INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX= + && rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + ) + ELSE() + ENDIF(NOT WIN32) SET(CBLAS_PROVIDER openblas) IF(WITH_C_API) INSTALL(DIRECTORY ${CBLAS_INC_DIR} DESTINATION third_party/openblas) diff --git a/doc/v2/dev/contribute_to_paddle_en.md b/doc/v2/dev/contribute_to_paddle_en.md index c97564d93a..7272339644 120000 --- a/doc/v2/dev/contribute_to_paddle_en.md +++ b/doc/v2/dev/contribute_to_paddle_en.md @@ -1 +1 @@ -../../../CONTRIBUTING.md \ No newline at end of file +../../../CONTRIBUTING.md diff --git a/paddle/fluid/framework/data_type_transform.cu b/paddle/fluid/framework/data_type_transform.cu index f46491293e..7dd9cb5cfd 120000 --- a/paddle/fluid/framework/data_type_transform.cu +++ b/paddle/fluid/framework/data_type_transform.cu @@ -1 +1,15 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
 data_type_transform.cc
\ No newline at end of file
diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
index dfef9f381b..ecefab32bb 100644
--- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -211,12 +211,12 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
   VLOG(30) << "LSTMWeight resized to " << out->dims();
 
   float* out_data = out->mutable_data<float>(platform::CPUPlace());
-  std::array<const float*, 4> tensors =
+  std::array<const float*, 4> tensors{
       {W_forget_w0.data<float>(), W_input_w0.data<float>(),
-       W_output_w0.data<float>(), W_cell_w0.data<float>()};
-  std::array<const float*, 4> tensors1 =
+       W_output_w0.data<float>(), W_cell_w0.data<float>()}};
+  std::array<const float*, 4> tensors1{
       {W_forget_w1.data<float>(), W_input_w1.data<float>(),
-       W_output_w1.data<float>(), W_cell_w1.data<float>()};
+       W_output_w1.data<float>(), W_cell_w1.data<float>()}};
 
   for (int row = 0; row < D; row++) {
     for (int col = 0; col < 4; col++) {
@@ -238,9 +238,9 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
 void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                      const LoDTensor& B_output, const LoDTensor& B_cell,
                      LoDTensor* out) {
-  std::array<const float*, 4> tensors =
+  std::array<const float*, 4> tensors{
       {B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
-       B_cell.data<float>()};
+       B_cell.data<float>()}};
 
   PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
   int D = B_forget.dims()[0];
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 8d699146bd..5f7cea65d9 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -207,7 +207,7 @@ struct PassRegistrar : public Registrar {
     return 0;                                                       \
   }                                                                 \
   static ::paddle::framework::ir::PassRegistrar<pass_class>         \
-      &__pass_tmp_registrar_##pass_type##__ __UNUSED__() =          \
+      &__pass_tmp_registrar_##pass_type##__ UNUSED =                \
          __pass_registrar_##pass_type##__
 
 #define USE_PASS(pass_type)                                         \
@@ -215,7 +215,7 @@ struct PassRegistrar : public Registrar {
       __use_pass_itself_##pass_type,                                \
       "USE_PASS must be called in global namespace");               \
   extern int TouchPassRegistrar_##pass_type();                      \
-  static int use_pass_itself_##pass_type##_ __UNUSED__() =          \
+  static int use_pass_itself_##pass_type##_ UNUSED =                \
      TouchPassRegistrar_##pass_type()
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/tensor_util.cu b/paddle/fluid/framework/tensor_util.cu
index edd88c4e54..251c3a5e40 120000
--- a/paddle/fluid/framework/tensor_util.cu
+++ b/paddle/fluid/framework/tensor_util.cu
@@ -1 +1,15 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 tensor_util.cc
\ No newline at end of file
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index a3440cfc78..d55303a51e 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -113,7 +113,9 @@ void Analyzer::Run(Argument* argument) {
     passes.push_back("infer_clean_graph_pass");
     passes.push_back("graph_viz_pass");  // add graphviz for debug.
     for (auto& pass : ir_passes_) {
-      if (!disabled_ir_passes_.count(pass)) {
+      // Skip MKLDNN-related passes when use_mkldnn_ is false.
+      bool skip_pass = (!use_mkldnn_) && pass.find("mkldnn") != std::string::npos;
+      if (!disabled_ir_passes_.count(pass) && !skip_pass) {
         passes.push_back(pass);
         passes.push_back("graph_viz_pass");  // add graphviz for debug.
       }
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index cbe03c163f..a6360a884d 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -150,4 +150,4 @@ struct NCCLContextMap {
 }  // namespace platform
 }  // namespace paddle
 
-#endif
\ No newline at end of file
+#endif
diff --git a/paddle/fluid/platform/port.h b/paddle/fluid/platform/port.h
index 347622f212..d3a6e28549 100644
--- a/paddle/fluid/platform/port.h
+++ b/paddle/fluid/platform/port.h
@@ -24,42 +24,38 @@
 #include "glog/logging.h"
 
 #if !defined(_WIN32)
-  #define UNUSED __attribute__((unused))
-  #include <dlfcn.h>     // dladdr
-  #include <execinfo.h>  // backtrace
-  #include <sys/stat.h>
-  #include <numeric>     // std::accumulate
+  #include <dlfcn.h>     // dladdr
+  #include <execinfo.h>  // backtrace
+  #include <sys/stat.h>
+  #include <numeric>     // std::accumulate
 #else
-  #include <windows.h>
-  #include <io.h>        // _popen, _pclose
-  #include <stdio.h>
-  #include <numeric>     // std::accumulate in msvc
-  // windows version of __attribute__((unused))
-  #define UNUSED __pragma(warning(suppress : 4100))
-
-  #ifndef S_ISDIR  // windows port for sys/stat.h
-  #define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR)
-  #endif  // S_ISDIR
-
-  static void *dlsym(void *handle, const char *symbol_name) {
-    FARPROC found_symbol;
-    found_symbol = GetProcAddress((HMODULE)handle, symbol_name);
-
-    if (found_symbol == NULL) {
-      throw std::runtime_error(std::string(symbol_name) + " not found.");
-    }
-    return reinterpret_cast<void *>(found_symbol);
+  #include <windows.h>
+  #include <io.h>        // _popen, _pclose
+  #include <stdio.h>
+  #include <numeric>     // std::accumulate in msvc
+  #ifndef S_ISDIR  // windows port for sys/stat.h
+  #define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR)
+  #endif  // S_ISDIR
+
+  static void *dlsym(void *handle, const char *symbol_name) {
+    FARPROC found_symbol;
+    found_symbol = GetProcAddress((HMODULE)handle, symbol_name);
+
+    if (found_symbol == NULL) {
+      throw std::runtime_error(std::string(symbol_name) + " not found.");
     }
+    return reinterpret_cast<void *>(found_symbol);
+  }
 
-  static void *dlopen(const char *filename, int flag) {
-    std::string file_name(filename);
-    file_name.replace(0, file_name.size() - 1, '/', '\\');
-    HMODULE hModule = LoadLibrary(file_name.c_str());
-    if (!hModule) {
-      throw std::runtime_error(file_name + " not found.");
-    }
-    return reinterpret_cast<void *>(hModule);
+  static void *dlopen(const char *filename, int flag) {
+    std::string file_name(filename);
+    file_name.replace(0, file_name.size() - 1, '/', '\\');
+    HMODULE hModule = LoadLibrary(file_name.c_str());
+    if (!hModule) {
+      throw std::runtime_error(file_name + " not found.");
     }
+    return reinterpret_cast<void *>(hModule);
+  }
 #endif  // !_WIN32
diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h
index 3cd1628a0b..0e88a439cf 100644
--- a/paddle/fluid/platform/stream_callback_manager.h
+++ b/paddle/fluid/platform/stream_callback_manager.h
@@ -18,8 +18,8 @@
 #include <cuda_runtime.h>
 #include <functional>
 #include <memory>
+#include "ThreadPool.h"
 #include "paddle/fluid/platform/enforce.h"
-#include "third_party/threadpool/src/extern_threadpool/ThreadPool.h"
 
 namespace paddle {
 namespace platform {
diff --git a/paddle/fluid/platform/variant.h b/paddle/fluid/platform/variant.h
index fb6a8bb96f..1b10db8669 100644
--- a/paddle/fluid/platform/variant.h
+++ b/paddle/fluid/platform/variant.h
@@ -45,8 +45,8 @@ limitations under the License. */
 
 // some platform-independent defintion
 #if defined(_WIN32)
-#define __UNUSED__()
+#define UNUSED
 #define __builtin_expect(EXP, C) (EXP)
 #else
-#define __UNUSED__() __attribute__((unused))
-#endif
\ No newline at end of file
+#define UNUSED __attribute__((unused))
+#endif
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index c4a5421cdb..2e1b4b2ead 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -35,6 +35,7 @@ from . import regularizer
 from . import average
 from . import metrics
 from . import transpiler
+from . import distribute_lookup_table
 from .param_attr import ParamAttr, WeightNormParamAttr
 from .data_feeder import DataFeeder
 from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
@@ -111,11 +112,10 @@ def __bootstrap__():
         os.environ['OMP_NUM_THREADS'] = str(num_threads)
 
     read_env_flags = [
-        'use_pinned_memory', 'check_nan_inf', 'benchmark',
-        'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
-        'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
-        'dist_threadpool_size', 'eager_delete_tensor_gb',
-        'reader_queue_speed_test_mode'
+        'use_pinned_memory', 'check_nan_inf', 'benchmark', 'eager_delete_scope',
+        'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem',
+        'free_idle_memory', 'paddle_num_threads', 'dist_threadpool_size',
+        'eager_delete_tensor_gb', 'reader_queue_speed_test_mode'
     ]
     if os.name != 'nt':
         read_env_flags.append('warpctc_dir')
diff --git a/python/paddle/fluid/distribute_lookup_table.py b/python/paddle/fluid/distribute_lookup_table.py
new file mode 100644
index 0000000000..52d9ce75f8
--- /dev/null
+++ b/python/paddle/fluid/distribute_lookup_table.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+LOOKUP_TABLE_TYPE = "lookup_table"
+
+
+def find_distributed_lookup_table(program):
+    """
+    Find the distributed lookup table in the program.
+    Only one distributed lookup table is supported for now.
+    :param program: the program to search
+    :return: table_name or None
+    """
+    table_name = None
+
+    for op in program.global_block().ops:
+        if op.type == LOOKUP_TABLE_TYPE:
+            if op.attr('is_distributed') is True:
+                if table_name is None:
+                    table_name = op.input("W")[0]
+                if table_name != op.input("W")[0]:
+                    raise RuntimeError("all distributed lookup_table_ops"
+                                       " should have only one table")
+            else:
+                if table_name is not None:
+                    assert op.input("W")[0] != table_name
+
+    return table_name
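For context on the new module above: `find_distributed_lookup_table` scans the program's global block for `lookup_table` ops marked `is_distributed` and returns the name of the single shared table, or `None`. A minimal usage sketch follows; the toy network is hypothetical, written only to produce such an op, and assumes the fluid 1.x layer API:

```python
import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    # A sparse, distributed embedding lowers to a lookup_table op
    # with the is_distributed attribute set.
    ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
    emb = fluid.layers.embedding(
        input=ids, size=[100000, 64], is_sparse=True, is_distributed=True)

# Returns the embedding's parameter name (the op's "W" input),
# or None when no distributed lookup_table op exists.
table_name = find_distributed_lookup_table(main)
```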
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index d50f6744df..a9075045a2 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -348,6 +348,7 @@ def _copy_reader_create_op_(block, op):
 
 
 if os.name != 'nt':
+
     @templatedoc(op_type='create_recordio_file_reader')
     def open_recordio_file(filename,
                            shapes,
@@ -405,8 +406,8 @@ if os.name != 'nt':
         startup_var.desc.set_dtypes(dtypes)
         startup_var.persistable = True
 
-        main_prog_var = _copy_reader_var_(default_main_program().current_block(),
-                                          startup_var)
+        main_prog_var = _copy_reader_var_(
+            default_main_program().current_block(), startup_var)
 
         if pass_num > 1:
             main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index c0278efb60..4b9264bfb6 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -342,6 +342,7 @@ def embedding(input,
 
 
 if os.name != 'nt':
+
     @templatedoc(op_type="lstm")
     def dynamic_lstm(input,
                      size,
@@ -961,6 +962,7 @@ def linear_chain_crf(input, label, param_attr=None):
 
 
 if os.name != 'nt':
+
     @templatedoc()
     def crf_decoding(input, param_attr, label=None):
         """
@@ -988,9 +990,11 @@ if os.name != 'nt':
             dtype=helper.input_dtype())
         helper.append_op(
             type='crf_decoding',
-            inputs={"Emission": [input],
-                    "Transition": transition,
-                    "Label": label},
+            inputs={
+                "Emission": [input],
+                "Transition": transition,
+                "Label": label
+            },
             outputs={"ViterbiPath": [viterbi_path]})
 
         return viterbi_path
@@ -5530,8 +5534,13 @@ def label_smooth(label,
 
 
 if os.name != 'nt':
+
     @templatedoc()
-    def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
+    def roi_pool(input,
+                 rois,
+                 pooled_height=1,
+                 pooled_width=1,
+                 spatial_scale=1.0):
         """
         ${comment}
diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py
index df52b7042f..66eb1229aa 100644
--- a/python/paddle/fluid/layers/ops.py
+++ b/python/paddle/fluid/layers/ops.py
@@ -105,7 +105,6 @@ if os.name != 'nt':
 
     _cum_sum_ = generate_layer_fn('cumsum')
 
-
     def cumsum(x, axis=None, exclusive=None, reverse=None):
         locals_var = locals().keys()
         kwargs = dict()
@@ -115,7 +114,6 @@ if os.name != 'nt':
             kwargs[name] = val
         return _cum_sum_(**kwargs)
 
-
     cumsum.__doc__ = _cum_sum_.__doc__ + """
 Examples:
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 7e2364a5a8..da92826d41 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -13,21 +13,23 @@
 # limitations under the License.
 
 from __future__ import print_function
-import re
-import sys
+
 from collections import defaultdict
+from contextlib import contextmanager
+
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
+from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
+
 from . import framework
 from . import layers
+from . import unique_name
 from .backward import append_backward
+from .clip import append_gradient_clip_ops, error_clip_callback
 from .framework import program_guard
-from . import unique_name
 from .initializer import Constant
 from .layer_helper import LayerHelper
-from .regularizer import append_regularization_ops
-from .clip import append_gradient_clip_ops, error_clip_callback
-from contextlib import contextmanager
 from .layers import ops
+from .regularizer import append_regularization_ops
 
 __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
@@ -85,7 +87,7 @@ class Optimizer(object):
                 name=unique_name.generate("learning_rate"),
                 shape=[1],
                 value=float(self._learning_rate),
-                dtype='float32' if self._dtype == None else self._dtype,
+                dtype='float32' if self._dtype is None else self._dtype,
                 persistable=True)
 
     def _global_learning_rate(self, program=None):
@@ -245,6 +247,50 @@ class Optimizer(object):
             end = len(global_block.ops)
             return global_block._slice_ops(start, end)
 
+    def _process_distribute_lookuptable(self, param_grads, loss,
+                                        startup_program):
+        """
+        Because the distributed lookup table currently supports only the SGD
+        optimizer, with no regularization, we single the table parameter out,
+        skip regularization and the other per-parameter ops for it, and append
+        an independent SGD optimize op for it.
+        :param param_grads(list((Var, Var))): list of (param, grad) pairs.
+        :param loss: the loss variable.
+        :param startup_program: the startup program
+        """
+        program = loss.block.program
+        table_name = find_distributed_lookup_table(program)
+        table_param = None
+        table_grad = None
+        new_param_grads = []
+        for p, g in param_grads:
+            if p.name == table_name:
+                if table_param is not None:
+                    raise RuntimeError(
+                        "multi dist table var found, only support one now!")
+                table_param = p
+                table_grad = g
+            else:
+                new_param_grads.append((p, g))
+        sgd_op = None
+        if table_param is not None:
+            with program_guard(program, startup_program):
+                param_and_grad = [table_param, table_grad]
+                with table_param.block.program._optimized_guard(param_and_grad), \
+                        framework.name_scope("optimizer"):
+                    self._create_global_learning_rate()
+                    # create the optimize op
+                    sgd_op = loss.block.append_op(
+                        type='sgd',
+                        inputs={
+                            "Param": table_param,
+                            "Grad": table_grad,
+                            "LearningRate":
+                            self._create_param_lr(param_and_grad)
+                        },
+                        outputs={"ParamOut": param_and_grad[0]})
+        return new_param_grads, (table_param, table_grad), sgd_op
+
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -260,6 +306,9 @@ class Optimizer(object):
 
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
 
+        params_grads, table_param_and_grad, table_optimize_op = \
+            self._process_distribute_lookuptable(params_grads, loss, startup_program)
+
         params_grads = append_gradient_clip_ops(params_grads)
 
         # Add regularization if any
@@ -268,6 +317,9 @@ class Optimizer(object):
 
         optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                       startup_program)
+        if table_optimize_op is not None:
+            optimize_ops.append(table_optimize_op)
+            params_grads.append(table_param_and_grad)
 
         return optimize_ops, params_grads
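The optimizer change above hinges on one idea: the distributed lookup table's `(param, grad)` pair must bypass the regular clip/regularize/optimize pipeline and receive a dedicated SGD op. A simplified, standalone sketch of just the split step (the function name is hypothetical; the real method additionally builds the SGD op under `program_guard`):

```python
def split_out_table(param_grads, table_name):
    """Separate the distributed lookup table pair from the rest,
    mirroring the filter in Optimizer._process_distribute_lookuptable."""
    table_pair, rest = None, []
    for p, g in param_grads:
        if p.name == table_name:
            if table_pair is not None:
                raise RuntimeError(
                    "multi dist table var found, only support one now!")
            table_pair = (p, g)
        else:
            rest.append((p, g))
    return rest, table_pair
```

Gradient clipping and regularization then run only on `rest`; the table pair is optimized by the separately appended `sgd` op.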
diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
index f63387a906..42ab9b2311 100644
--- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py
+++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
@@ -38,7 +38,7 @@ depth = 8
 mix_hidden_lr = 1e-3
 
 IS_SPARSE = True
-PASS_NUM = 10
+PASS_NUM = 1
 BATCH_SIZE = 10
 
 embedding_name = 'emb'
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index 3a5b6b5cb8..d132dd3c48 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -567,7 +567,6 @@ class TestDistLookupTable(TestDistLookupTableBase):
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
-            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'uniform_random',
             'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier',
             'concat', 'fake_init'
@@ -639,7 +638,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
         # 5 save table
         self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
 
-        trainer, _ = self.get_trainer(config)
+        trainer, trainer_startup = self.get_trainer(config)
         self.assertEqual(len(trainer.blocks), 1)
         ops = [
             'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
@@ -653,6 +652,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
             'recv', 'concat'
         ]
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+        startup_ops = [
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'uniform_random',
+            'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
+            'fake_init'
+        ]
+        self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
+                         startup_ops)
 
 
 class TestDistLookupTableSliceSize(TestDistLookupTableBase):
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 094eaeb59c..89bc248027 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -31,18 +31,17 @@ Steps to transpile pserver:
 """
 
 import math
-import sys
 import numpy as np
 import collections
-import six
 import logging
 
-from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
+from .ps_dispatcher import RoundRobin, PSDispatcher
 from .. import core, framework, unique_name
 from ..framework import Program, default_main_program, \
     default_startup_program, Block, \
     Parameter, grad_var_name
 from .details import *
+from ..distribute_lookup_table import find_distributed_lookup_table
 from functools import reduce
 
 LOOKUP_TABLE_TYPE = "lookup_table"
@@ -292,7 +291,8 @@ class DistributeTranspiler(object):
         self.optimize_ops, self.params_grads = self._get_optimize_pass()
 
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
-        self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.table_name = find_distributed_lookup_table(self.origin_program)
+        self.has_distributed_lookup_table = self.table_name is not None
         self.param_name_to_grad_name = dict()
         self.grad_name_to_param_name = dict()
         for param_var, grad_var in self.params_grads:
@@ -966,28 +966,6 @@ to transpile() call.")
 
     # ====================== private transpiler functions =====================
 
-    def _has_distributed_lookup_table(self):
-        # process lookup_table_op
-        # 1. check all lookup_table_op is distributed
-        # 2. check all lookup_table_op share the same table.
-        distributed_lookup_table_ops = []
-        # support only one distributed_lookup_table now
-        self.table_name = None
-        for op in self.origin_program.global_block().ops:
-            if op.type == LOOKUP_TABLE_TYPE:
-                if op.attr('is_distributed') is True:
-                    if self.table_name is None:
-                        self.table_name = op.input("W")[0]
-                    if self.table_name != op.input("W")[0]:
-                        raise RuntimeError("all distributed lookup_table_ops"
-                                           " should have only one table")
-                    distributed_lookup_table_ops.append(op)
-                else:
-                    if self.table_name is not None:
-                        assert op.input("W")[0] != self.table_name
-
-        return len(distributed_lookup_table_ops) > 0
-
     def _update_dist_lookup_table_vars(self, param_list, grad_list,
                                        params_grads):
         # TODO(wuyi): put find a way to put dist lookup table stuff all together.
@@ -1341,7 +1319,6 @@ to transpile() call.")
         """
         create a new block to handle save checkpoint.
         """
-        import os
 
         pserver_program.global_block().create_var(
             name="kLookupTablePath",
diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py
index b5cde7bac7..1e961b936f 100644
--- a/python/paddle/trainer_config_helpers/networks.py
+++ b/python/paddle/trainer_config_helpers/networks.py
@@ -1719,7 +1719,7 @@ def inputs(layers, *args):
     if len(args) != 0:
         layers.extend(args)
 
-    Inputs(*[l.name for l in layers])
+    Inputs(* [l.name for l in layers])
 
 
 def outputs(layers, *args):
@@ -1769,7 +1769,7 @@ def outputs(layers, *args):
     assert len(layers) > 0
 
     if HasInputsSet():  # input already set
-        Outputs(*[l.name for l in layers])
+        Outputs(* [l.name for l in layers])
         return  # just return outputs.
 
     if len(layers) != 1:
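Taken together, the transpiler-side change mirrors the optimizer-side one: both now obtain the table through the shared helper instead of a private scan, so the removed `_has_distributed_lookup_table` method reduces to roughly this (a sketch restating the diff's resulting behavior, not new API):

```python
# Shared helper replaces DistributeTranspiler._has_distributed_lookup_table();
# it returns the table name directly instead of setting it as a side effect.
self.table_name = find_distributed_lookup_table(self.origin_program)
self.has_distributed_lookup_table = self.table_name is not None
```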