diff --git a/CMakeLists.txt b/CMakeLists.txt index 8dcf9786e3..efa68c9ba2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -214,6 +214,7 @@ if (NOT WIN32) # there is no official support of warpctc, nccl, cupti in windows include(external/warpctc) # download, build, install warpctc include(cupti) +include(external/gzstream) endif (NOT WIN32) if(WITH_DISTRIBUTE) diff --git a/cmake/external/gzstream.cmake b/cmake/external/gzstream.cmake new file mode 100644 index 0000000000..59d8e93245 --- /dev/null +++ b/cmake/external/gzstream.cmake @@ -0,0 +1,47 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +IF(MOBILE_INFERENCE) + return() +ENDIF() + +include (ExternalProject) + +# NOTE: gzstream is needed when linking with ctr reader. + +SET(GZSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/gzstream) +SET(GZSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gzstream) +SET(GZSTREAM_INCLUDE_DIR "${GZSTREAM_INSTALL_DIR}/include/" CACHE PATH "gzstream include directory." FORCE) + +ExternalProject_Add( + extern_gzstream + GIT_REPOSITORY "https://github.com/jacquesqiao/gzstream.git" + GIT_TAG "" + PREFIX ${GZSTREAM_SOURCES_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + BUILD_COMMAND make -j8 + INSTALL_COMMAND mkdir -p ${GZSTREAM_INSTALL_DIR}/lib/ && mkdir -p ${GZSTREAM_INSTALL_DIR}/include/ + && cp ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/libgzstream.a ${GZSTREAM_INSTALL_DIR}/lib + && cp -r ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/gzstream.h ${GZSTREAM_INSTALL_DIR}/include +) + +ADD_LIBRARY(gzstream STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET gzstream PROPERTY IMPORTED_LOCATION + "${GZSTREAM_INSTALL_DIR}/lib/libgzstream.a") + +include_directories(${GZSTREAM_INCLUDE_DIR}) +ADD_DEPENDENCIES(gzstream extern_gzstream zlib) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 96b38902e8..59605a7e16 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -97,8 +97,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) -paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0)) -paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 
'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) +paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index d6b5ad4570..93288936fe 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -39,11 +39,12 @@ if (WITH_GPU) endif() cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) +cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) -set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass) +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass) if (WITH_GPU) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) endif() diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.cc b/paddle/fluid/framework/details/all_reduce_deps_pass.cc new file mode 100644 index 0000000000..fe21e21bcf --- /dev/null +++ b/paddle/fluid/framework/details/all_reduce_deps_pass.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
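The new all_reduce_deps_pass implemented below collects the program order of every output variable, picks out the all_reduce/reduce op handles, sorts them by the order of their first input variable, and then threads dummy control-dependency variables between consecutive ops, presumably so that every trainer issues its collective operations in the same sequence (NCCL collectives generally must be launched in identical order on all ranks). A minimal standalone sketch of that sort-and-chain idea; `Op`, `DepVar`, and everything else here are hypothetical stand-ins, not Paddle's actual handle classes:

```cpp
// Standalone sketch of the sort-and-chain idea behind all_reduce_deps_pass.
// Op and DepVar are hypothetical stand-ins for Paddle's OpHandleBase and
// DummyVarHandle, kept just small enough to show the technique.
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

struct DepVar {};  // carries no data, only a scheduling edge

struct Op {
  std::string grad_name;                 // variable this all_reduce consumes
  std::vector<DepVar*> inputs, outputs;  // dependency edges
};

// order[v] = position at which variable v is produced in the program;
// ties are broken by name so the resulting sort is deterministic.
void ChainAllReduces(std::vector<Op*>* ops,
                     const std::unordered_map<std::string, int>& order) {
  std::sort(ops->begin(), ops->end(), [&](Op* a, Op* b) {
    int ia = order.at(a->grad_name), ib = order.at(b->grad_name);
    return ia != ib ? ia < ib : a->grad_name < b->grad_name;
  });
  // Thread a dummy variable between consecutive ops: op[i-1] writes it and
  // op[i] reads it, which forces the executor to run them sequentially.
  for (size_t i = 1; i < ops->size(); ++i) {
    auto* dep = new DepVar;  // the real pass hands ownership to the graph
    (*ops)[i - 1]->outputs.push_back(dep);
    (*ops)[i]->inputs.push_back(dep);
  }
}
```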
+ +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/details/all_reduce_deps_pass.h" +#include "paddle/fluid/framework/details/all_reduce_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/op_graph_view.h" +#include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace details { + +static constexpr char kAllOpDescs[] = "all_op_descs"; + +VarHandle* GetValidInput(const OpHandleBase* a) { + for (auto p : a->Inputs()) { + VarHandle* b = dynamic_cast(p); + if (b) { + return b; + } + } + + return nullptr; +} + +std::unique_ptr AllReduceDepsPass::ApplyImpl( + std::unique_ptr graph) const { + auto graph_ops = ir::FilterByNodeWrapper(*graph); + + // get vars order + int order = 0; + std::unordered_map vars; + // TODO(gongwb): use graph topology sort to find the order of operators. + // Note that must assert topology sort is stable + auto& ops = Get>(kAllOpDescs); + for (auto* op_desc : ops) { + auto outputs = op_desc->Outputs(); + for (auto& o_it : outputs) { + for (auto& v : o_it.second) { // values + vars[v] = order; + } + } + order++; + } + + std::vector dist_ops; + // get allreduce ops. + for (auto& op : graph_ops) { + // FIXME(gongwb):add broad cast. + if (op->Name() == "all_reduce" || op->Name() == "reduce") { + dist_ops.push_back(op); + } + } + + VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl; + + std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1, + OpHandleBase* op2) { + VarHandle* i0 = dynamic_cast(GetValidInput(op1)); + VarHandle* i1 = dynamic_cast(GetValidInput(op2)); + + PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error", + op1->DebugString(), op2->DebugString()); + + auto l_it = vars.find(i0->name_); + auto r_it = vars.find(i1->name_); + + if (l_it->second < r_it->second) return true; + + if (l_it->second == r_it->second) { + return i0->name_ < i1->name_; + } + + return false; + }); + + // add dependency. + auto& sorted_ops = dist_ops; + for (size_t i = 1; i < sorted_ops.size(); ++i) { + auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + + auto* pre_op = sorted_ops[i - 1]; + auto* op = sorted_ops[i]; + + pre_op->AddOutput(dep_var); + op->AddInput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + + VLOG(10) << "add all_reduce sequential dependencies between " << pre_op + << " and " << op; + + VLOG(10) << "pre_op:" << pre_op->DebugString() + << ", op:" << op->DebugString(); + } + + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(all_reduce_deps_pass, + paddle::framework::details::AllReduceDepsPass) + .RequirePassAttr(paddle::framework::details::kAllOpDescs); diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.h b/paddle/fluid/framework/details/all_reduce_deps_pass.h new file mode 100644 index 0000000000..e8b9108981 --- /dev/null +++ b/paddle/fluid/framework/details/all_reduce_deps_pass.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +// TODO(gongwb): overlap allreduce with backward computation. +class AllReduceDepsPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 70baced0ad..523f9eadf2 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" +#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/sequential_execution_pass.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -24,6 +25,10 @@ namespace paddle { namespace framework { namespace details { +static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { + return (!strategy.enable_sequential_execution_ && strategy.num_trainers_ > 1); +} + class ParallelExecutorPassBuilder : public ir::PassBuilder { public: explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) @@ -70,6 +75,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Verify that the graph is correct for multi-device executor. 
AppendPass("multi_devices_check_pass"); + if (SeqOnlyAllReduceOps(strategy)) { + AppendPass("all_reduce_deps_pass"); + } + if (strategy_.remove_unnecessary_lock_) { AppendPass("modify_op_lock_and_record_event_pass"); } @@ -124,6 +133,17 @@ std::unique_ptr BuildStrategy::Apply( pass->SetNotOwned("nccl_ctxs", nctx); #endif } else if (pass->Type() == "sequential_execution_pass") { + VLOG(1) << "set enable_sequential_execution:" + << enable_sequential_execution_; + + pass->Erase(kAllOpDescs); + pass->Set>( + kAllOpDescs, + new std::vector(main_program.Block(0).AllOps())); + } else if (pass->Type() == "all_reduce_deps_pass") { + VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this) + << ", num_trainers:" << num_trainers_; + pass->Erase(kAllOpDescs); pass->Set>( kAllOpDescs, @@ -144,4 +164,5 @@ USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); USE_PASS(sequential_execution_pass); +USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 3236c35efd..9f0a259128 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -73,6 +73,7 @@ struct BuildStrategy { bool fuse_broadcast_op_{false}; + int num_trainers_{1}; bool remove_unnecessary_lock_{false}; // NOTE: diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index bfdfdc56b3..5bd68f9ac2 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -71,7 +71,7 @@ class OperatorBase; class ExecutionContext; /** - * OperatorBase has the basic element that Net will call to do computation. + * OperatorBase has the basic elements that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User * should always construct a proto message OpDesc and call * OpRegistry::CreateOp(op_desc) to get an Operator instance. diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 2c6e337568..0e907e6fd3 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -54,7 +54,7 @@ class ParallelExecutorPrivate { Scope *global_scope_; // not owned std::unique_ptr executor_; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) std::unique_ptr nccl_ctxs_; #endif bool own_local_scope_; @@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor( if (member_->use_cuda_) { // Bcast Parameters to all GPUs -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); ncclUniqueId *nccl_id = nullptr; if (nccl_id_var != nullptr) { @@ -124,7 +124,7 @@ ParallelExecutor::ParallelExecutor( // Step 2. Convert main_program to SSA form and dependency graph. 
Also, insert
  // ncclOp
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
       main_program, member_->places_, loss_var_name, params,
       member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
       }
       auto &dims = main_tensor.dims();
       if (paddle::platform::is_gpu_place(main_tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
         std::vector<void *> buffers;
         size_t numel = main_tensor.numel();
         ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h
index 55ca02038e..44384082db 100644
--- a/paddle/fluid/framework/selected_rows.h
+++ b/paddle/fluid/framework/selected_rows.h
@@ -120,8 +120,22 @@ class SelectedRows {
    */
   int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);

-  void SyncIndex();
+  /*
+   * @brief Get the index of the key from the id_to_index_ map.
+   */
+  inline int64_t GetIndexFromId(int64_t key) {
+    auto iter = id_to_index_.find(key);
+    if (iter == id_to_index_.end()) {
+      return -1;
+    } else {
+      return iter->second;
+    }
+  }

+  void SyncIndex();
+  /*
+   * @brief Get the complete dims of the underlying tensor.
+   */
   DDim GetCompleteDims() const {
     std::vector<int64_t> dims = vectorize(value_->dims());
     dims[0] = height_;
@@ -133,9 +147,10 @@ class SelectedRows {
   // SelectedRows are simply concated when adding together. Until a
   // SelectedRows add a Tensor, will the duplicate rows be handled.
   Vector<int64_t> rows_;
-  std::unordered_map<int64_t, int64_t> id_to_index_;
+  std::unordered_map<int64_t, int64_t>
+      id_to_index_;  // should not be used when rows_ has duplicate entries
   std::unique_ptr<Tensor> value_{nullptr};
-  int64_t height_;
+  int64_t height_;  // height indicates the underlying tensor's height
   std::unique_ptr<RWLock> rwlock_{nullptr};
 };
diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc
index f6219a1417..e52a8317e2 100644
--- a/paddle/fluid/framework/transfer_scope_cache.cc
+++ b/paddle/fluid/framework/transfer_scope_cache.cc
@@ -17,28 +17,16 @@
 namespace paddle {
 namespace framework {

-// Holds all the transfer scope across the process.
 std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
-  typedef std::unordered_map<size_t, Scope*> map_t;
-  thread_local std::unique_ptr<map_t> x(new map_t);
+  thread_local auto* x = new std::unordered_map<size_t, Scope*>;
   return *x;
 }

-// Holds all the transfer scope for this thread.
 std::unordered_set<Scope*>& global_transfer_scope_cache() {
-  typedef std::unordered_set<set_t> set_t;
-  thread_local std::unique_ptr<set_t> x(new set_t);
+  thread_local auto* x = new std::unordered_set<Scope*>;
   return *x;
 }

-// Try to create a transfer scope. If one cached scope has match the
-// requirement, just return that one.
-// Inputs:
-// @type0: the source kernel type.
-// @type1: the target kernel type.
-// @scope: the execution scope of this op.
-// Returns: A scope used to hold the transfer data across the different kernel
-// type.
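One detail worth noting in the transfer_scope_cache.cc change above: the caches switch from `thread_local std::unique_ptr` to a raw `new` that is never freed, so the maps now intentionally outlive thread shutdown. The diff does not state the motivation; a common reason for this pattern (an assumption on my part) is to avoid destruction-order problems when other teardown code may still touch the cache. A minimal sketch of the pattern:

```cpp
// Minimal sketch of the leaky thread-local cache pattern (assumed rationale,
// see the note above). The pointer is never deleted: the OS reclaims it at
// process exit, and no destructor can fire while the cache is still in use.
#include <string>
#include <unordered_map>

std::unordered_map<std::string, int>& ThreadLocalCache() {
  thread_local auto* cache = new std::unordered_map<std::string, int>;
  return *cache;
}
```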
Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, const Scope* scope) { Scope* new_scope{nullptr}; @@ -58,5 +46,27 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, return new_scope; } +void RemoveKidsFromTransferScopeCache(Scope* scope) { + auto it = global_transfer_scope_cache().find(scope); + if (it != global_transfer_scope_cache().end()) { + global_transfer_scope_cache().erase(it); + } + for (auto* s : scope->kids()) { + auto it = global_transfer_scope_cache().find(s); + if (it != global_transfer_scope_cache().end()) { + global_transfer_scope_cache().erase(it); + } + } + + // remove global transfer data cache + auto& cache = global_transfer_data_cache(); + for (auto it = cache.begin(); it != cache.end();) { + if (it->second == scope) + it = cache.erase(it); + else + it++; + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 4bd3f93ef7..27b6b80955 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -35,4 +35,5 @@ function(inference_analysis_test TARGET) endif() endfunction(inference_analysis_test) -inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS reset_tensor_array paddle_inference_api) +inference_analysis_test(test_analyzer SRCS analyzer_tester.cc + EXTRA_DEPS reset_tensor_array paddle_inference_api) diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 7710ed7b61..cb88333d15 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -76,7 +76,8 @@ void TestWord2vecPrediction(const std::string& model_path) { 0.000932706}; const size_t num_elements = outputs.front().data.length() / sizeof(float); // The outputs' buffers are in CPU memory. - for (size_t i = 0; i < std::min((size_t)5UL, num_elements); i++) { + for (size_t i = 0; i < std::min(static_cast(5UL), num_elements); + i++) { LOG(INFO) << "data: " << static_cast(outputs.front().data.data())[i]; PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index ebe56734c6..d998533977 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -284,6 +284,7 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, framework::GetFetchVariable(*scope, "fetch", idx); auto type = fetch.type(); auto output = &(outputs->at(i)); + output->name = fetchs_[idx]->Input("X")[0]; if (type == typeid(float)) { GetFetchOne(fetch, output); output->dtype = PaddleDType::FLOAT32; diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index db57812bc3..12ecb7c15e 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -109,7 +109,7 @@ class AnalysisPredictor : public PaddlePredictor { std::map feed_names_; std::vector fetchs_; // Memory buffer for feed inputs. The temporary LoDTensor will cause serious - // concurrency problems, so cache them. + // concurrency problems, wrong results and memory leak, so cache them. 
std::vector feed_tensors_; details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index a7cef426d1..9c5703c91f 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -185,8 +185,12 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, << inputs.size(); return false; } + + // Cache the inputs memory for better concurrency performance. + feed_tensors_.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { - framework::LoDTensor input; + auto &input = feed_tensors_[i]; framework::DDim ddim = framework::make_ddim(inputs[i].shape); void *input_ptr; if (inputs[i].dtype == PaddleDType::INT64) { @@ -261,6 +265,7 @@ bool NativePaddlePredictor::GetFetch(std::vector *outputs, framework::GetFetchVariable(*scope, "fetch", idx); auto type = fetch.type(); auto output = &(outputs->at(i)); + output->name = fetchs_[idx]->Input("X")[0]; if (type == typeid(float)) { GetFetchOne(fetch, output); output->dtype = PaddleDType::FLOAT32; diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h index 9dfa48d501..c1fcd198cc 100644 --- a/paddle/fluid/inference/api/api_impl.h +++ b/paddle/fluid/inference/api/api_impl.h @@ -69,6 +69,9 @@ class NativePaddlePredictor : public PaddlePredictor { std::vector feeds_; std::map feed_names_; std::vector fetchs_; + // Memory buffer for feed inputs. The temporary LoDTensor will cause serious + // concurrency problems, wrong results and memory leak, so cache them. + std::vector feed_tensors_; // Do not use unique_ptr, use parent scope to delete framework::Scope *sub_scope_{nullptr}; details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 2019d1a14f..3e8fb83e9d 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -86,7 +86,11 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) { munlock(p, size); #endif } +#ifdef _WIN32 + _aligned_free(p); +#else free(p); +#endif } bool CPUAllocator::UseGpu() const { return false; } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index bb9ea3f3ba..832245371e 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -149,6 +149,13 @@ $out = \max(x, 0)$ )DOC"; +UNUSED constexpr char GeluDoc[] = R"DOC( +Gelu Activation Operator. + +$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$ + +)DOC"; + UNUSED constexpr char TanhDoc[] = R"DOC( Tanh Activation Operator. 
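The GeluDoc formula added above can be sanity-checked against a scalar reference. Below is a small self-contained sketch (not Paddle's Eigen functor) of the forward value and of the analytic derivative that the GeluGradFunctor introduced later in this diff computes as `dout * (out / x + temp)`:

```cpp
// Scalar reference for the gelu operator being added; a standalone sketch,
// not Paddle's functor. gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
#define _USE_MATH_DEFINES  // for M_SQRT1_2 / M_2_SQRTPI on MSVC, as in the diff
#include <cmath>
#include <cstdio>

double gelu(double x) { return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2)); }

// d/dx gelu(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi).
// The first term equals gelu(x)/x, which is why GeluGradFunctor can write
// the gradient as dout * (out / x + temp); note that
// 0.5 * M_2_SQRTPI * M_SQRT1_2 == 1 / sqrt(2*pi).
double gelu_grad(double x) {
  return 0.5 * (1.0 + std::erf(x * M_SQRT1_2)) +
         x * 0.5 * M_2_SQRTPI * M_SQRT1_2 * std::exp(-0.5 * x * x);
}

int main() {
  for (double x : {-2.0, -0.5, 0.3, 1.7}) {  // spot-check vs. finite difference
    const double h = 1e-6;
    const double fd = (gelu(x + h) - gelu(x - h)) / (2 * h);
    std::printf("x=%5.2f gelu=%+.6f grad=%+.6f fd=%+.6f\n", x, gelu(x),
                gelu_grad(x), fd);
  }
  return 0;
}
```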
@@ -472,6 +479,7 @@ REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc); REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc); +REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc); REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc); REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc); REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc); @@ -489,6 +497,7 @@ REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid); REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu); REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp); REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh); REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil); @@ -525,6 +534,7 @@ namespace ops = paddle::operators; __macro(Round, round); \ __macro(Log, log); \ __macro(Square, square); \ + __macro(Gelu, gelu); \ __macro(BRelu, brelu); \ __macro(Pow, pow); \ __macro(STanh, stanh); \ diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 4ffc7f364b..c60ca18d13 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -16,6 +16,11 @@ limitations under the License. */ #include #include +#include +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" @@ -212,6 +217,31 @@ struct ReluGradFunctor : public BaseActivationFunctor { } }; +// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) +template +struct GeluFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + auto temp = + ((x * static_cast(M_SQRT1_2)).erf()).template cast().eval(); + out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + } +}; + +template +struct GeluGradFunctor : BaseActivationFunctor { + bool Inplace() const { return IsInplace("gelu"); } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto temp = (static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * + ((-static_cast(0.5) * x.square()).exp())) + .template cast() + .eval(); + dx.device(d) = dout * (out / x + temp); + } +}; + // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { @@ -877,6 +907,7 @@ struct SwishGradFunctor : public BaseActivationFunctor { __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, ExpFunctor, ExpGradFunctor); \ __macro(relu, ReluFunctor, ReluGradFunctor); \ + __macro(gelu, GeluFunctor, GeluGradFunctor); \ __macro(tanh, TanhFunctor, TanhGradFunctor); \ __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.h b/paddle/fluid/operators/bilinear_tensor_product_op.h index f23336f7b9..5017c3a457 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.h +++ b/paddle/fluid/operators/bilinear_tensor_product_op.h @@ -70,7 +70,7 @@ class BilinearTensorProductKernel : public framework::OpKernel { if (bias) { auto bias_vec = EigenMatrix::From(*bias); Eigen::DSizes bcast(batch_size, 1); - output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + output_mat.device(place) = bias_vec.broadcast(bcast).eval() + output_mat; } } }; @@ -99,13 +99,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { auto d_out_mat = 
EigenMatrix::From(*d_out); auto& place = *ctx.template device_context().eigen_device(); auto& dev_ctx = ctx.template device_context(); - // Create the intermediate variable to caculate the Output(Y@Grad). + // Create the intermediate variable to calculate the Output(Y@Grad). Tensor x_scale; x_scale.mutable_data(framework::make_ddim({batch_size, x_dim}), ctx.GetPlace()); auto x_scale_mat = EigenMatrix::From(x_scale); - // Create the intermediate variable to caculate the Output(X@Grad). + // Create the intermediate variable to calculate the Output(X@Grad). Tensor y_scale; y_scale.mutable_data(framework::make_ddim({batch_size, y_dim}), ctx.GetPlace()); @@ -113,65 +113,64 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { math::SetConstant set_zero; - // Set Output(X@Grad) be zero. if (d_x) { d_x->mutable_data(ctx.GetPlace()); set_zero(dev_ctx, d_x, static_cast(0)); } - // Set Output(Y@Grad) be zero. if (d_y) { d_y->mutable_data(ctx.GetPlace()); set_zero(dev_ctx, d_y, static_cast(0)); } + if (d_weight) { + d_weight->mutable_data(ctx.GetPlace()); + } + auto blas = math::GetBlas(ctx); // Caculate the Output(X@Grad) and Output(Y@Grad). - if (d_x || d_y) { + if (d_x || d_y || d_weight) { Eigen::DSizes bcast_for_x(1, y_dim); Eigen::DSizes bcast_for_y(1, x_dim); + Eigen::DSizes bcast_for_weight(1, x_dim); + for (int i = 0; i < out_dim; ++i) { Tensor weight_i = weight->Slice(i, i + 1).Resize( framework::make_ddim({x_dim, y_dim})); auto output_vec = d_out_mat.chip(i, 1); + if (d_x) { y_scale_mat.device(place) = output_vec.reshape(Eigen::DSizes(batch_size, 1)) - .broadcast(bcast_for_x) * + .broadcast(bcast_for_x) + .eval() * y_mat; blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1, y_scale.data(), weight_i.data(), 1, d_x->data()); } - if (d_y) { - x_scale_mat.device(place) = + + if (d_y || d_weight) { + auto output_vec_y = output_vec.reshape(Eigen::DSizes(batch_size, 1)) - .broadcast(bcast_for_y) * - x_mat; - blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, - x_scale.data(), weight_i.data(), 1, d_y->data()); + .broadcast(bcast_for_y) + .eval(); + x_scale_mat.device(place) = output_vec_y * x_mat; + if (d_y) { + blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, + x_scale.data(), weight_i.data(), 1, d_y->data()); + } + if (d_weight) { + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1, + x_scale.data(), y->data(), 0, d_weight_i.data()); + } } } } - // Caculate the gradient of Input(Weight). - if (d_weight) { - d_weight->mutable_data(ctx.GetPlace()); - Eigen::DSizes bcast_for_weight(1, x_dim); - for (int i = 0; i < out_dim; ++i) { - Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( - framework::make_ddim({x_dim, y_dim})); - auto output_vec = d_out_mat.chip(i, 1); - x_scale_mat.device(place) = - output_vec.reshape(Eigen::DSizes(batch_size, 1)) - .broadcast(bcast_for_weight) * - x_mat; - blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1, - x_scale.data(), y->data(), 0, d_weight_i.data()); - } - } - - // Caculate the gradient of Input(Bias). + // calculate the gradient of Input(Bias). 
if (d_bias) { d_bias->mutable_data(ctx.GetPlace()); auto d_bias_mat = framework::EigenVector::Flatten(*d_bias); diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index dd3474dd25..2ccc86c1dc 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -120,6 +120,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel { "Dimensions of Input(X) and Mask must be the same."); ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); } }; diff --git a/paddle/fluid/operators/dropout_op_test.cc b/paddle/fluid/operators/dropout_op_test.cc index 424d273c34..3e401d1c4f 100644 --- a/paddle/fluid/operators/dropout_op_test.cc +++ b/paddle/fluid/operators/dropout_op_test.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef _WIN32 #include +#endif #include #include // NOLINT diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc index 10290a4aef..c600d1e3d7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -19,36 +19,21 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/operators/math/jit_kernel.h" -#include "xbyak.h" -#include "xbyak_util.h" +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" namespace paddle { namespace operators { using framework::DataLayout; using mkldnn::memory; - -static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { - std::transform(format.begin(), format.end(), format.begin(), ::tolower); - - if (!format.compare("nchw")) { - return memory::format::nchw; - } else if (!format.compare("nchw16c")) { - return memory::format::nChw16c; - } else if (!format.compare("nchw8c")) { - return memory::format::nChw8c; - } else if (!format.compare("nhwc")) { - return memory::format::nhwc; - } else { - return memory::format::any; - } -} +using platform::StringToMKLDNNFormat; static void UpdateDataFormat(const framework::ExecutionContext& ctx, framework::Tensor* tensor, const char* attribute) { if (ctx.op().HasAttr(attribute)) { auto format_as_string = ctx.Attr(attribute); - auto format = StringToMKLDNNFormat(format_as_string); + auto format = StringToMKLDNNFormat(&format_as_string); if (format != memory::format::any) { tensor->set_format(format); } @@ -93,8 +78,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto y_dims_untrimmed = y->dims(); auto x_int_dims = paddle::framework::vectorize2int(x_dims); - UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); - UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); + UpdateDataFormat(ctx, const_cast(x), "x_data_format"); + UpdateDataFormat(ctx, const_cast(y), "y_data_format"); Xbyak::util::Cpu cpu; const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F); @@ -156,10 +141,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); if (!(is_x_nchw || is_x_nc)) - ReorderInput((Tensor*)x, ctx.GetPlace(), mkldnn_engine, + ReorderInput(const_cast(x), ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4); if (!(is_y_nchw || is_y_nc)) - ReorderInput((Tensor*)y, ctx.GetPlace(), 
mkldnn_engine,
+      ReorderInput(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
+                   y->dims().size() == 4);
   }
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index dadd054b9a..972dcf5494 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
+#include <string>
 #include <vector>
-
 namespace paddle {
 namespace operators {

@@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+    ctx->ShareLoD("X", /*->*/ "Out");
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };

@@ -86,27 +87,40 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, required) The input tensor with shape [N, D], "
+             "(LoDTensor, required) The input tensor with shape [N, D], "
              "where N is the size of mini-batch, and D is the feature size.");
     AddInput("W",
-             "(Tensor, required), The parameters of hierarchical "
+             "(LoDTensor, required), The parameters of hierarchical "
              "sigmoid operator, each of them is a 2-D tensor, the shape is"
-             "[num_classes - 1, D].");
+             "[K, D], where K is the number of non-leaf nodes in the path "
+             "tree.");
     AddInput("Label",
-             "(Tensor, required), The labels of training data. It's a"
+             "(LoDTensor, required), The labels of training data. It's a"
              "tensor with shape [N, 1].");
+    AddInput("PTable",
+             "(LoDTensor, optional), The path table from the root to the "
+             "current word; it should have shape [N, L], where L is the "
+             "length of the path.")
+        .AsDispensable();
+    AddInput(
+        "PathCode",
+        "(LoDTensor, optional), The code on each node of the path from the "
+        "root to the current word; it should have shape [N, L], where L is "
+        "the length of the path.")
+        .AsDispensable();
     AddInput("Bias",
-             "(Tensor, optional), The bias is a tensor with shape"
-             "[1, num_classes - 1].");
-    AddOutput("Out",
-              "(Tensor, required) The output of hierarchical sigmoid operator."
-              "The shape is [N, 1].");
+             "(LoDTensor, optional), The bias is a tensor with shape "
+             "[num_classes, 1] or [num_classes - 1, 1].")
+        .AsDispensable();
+    AddOutput(
+        "Out",
+        "(LoDTensor, required) The output of hierarchical sigmoid operator."
+        "The shape is [N, 1].");
     AddOutput("PreOut",
-              "(Tensor, required) A intermedia 2-D tensor with shape "
+              "(LoDTensor, required) An intermediate 2-D tensor with shape "
               "[batch_size, code_length], where code_length represents the "
               "maximum path length from root to leaf nodes.")
         .AsIntermediate();
-    AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
+    AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
         .SetDefault(2);
     AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
@@ -115,6 +129,10 @@ belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
)DOC"); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); } }; @@ -124,16 +142,21 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) should not be null"); PADDLE_ENFORCE(ctx->HasInput("PreOut"), "Input(Preout) should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), - "Output(W@Grad should not be null.)"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); - if (ctx->HasOutput(framework::GradVarName("Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Bias"), - ctx->GetInputDim("Bias")); + "Output(W@Grad should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad should not be null."); + if (!ctx->Attrs().Get("is_sparse")) { + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); + } + ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); } - ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } @@ -141,11 +164,55 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); } }; +class HierarchicalSigmoidGradOpGradVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto bias_grad_var_name_vec = + op_desc.Output(framework::GradVarName("Bias")); + std::string bias_grad_var_name; + bool hasBias = false; + if (bias_grad_var_name_vec.size()) { + hasBias = true; + bias_grad_var_name = + op_desc.Output(framework::GradVarName("Bias")).front(); + } + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + block->Var(w_grad_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + if (hasBias) { + VLOG(30) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to SelectedRows"; + block->Var(bias_grad_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } + } else { + VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + block->Var(w_grad_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); + if (hasBias) { + VLOG(30) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to LoDTensor"; + block->Var(bias_grad_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); + } + } + block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType()); + } +}; + } // namespace operators } // namespace paddle @@ -153,7 +220,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, 
ops::HierarchicalSigmoidOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); +REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, + ops::HierarchicalSigmoidGradOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL( hierarchical_sigmoid, ops::HierarchicalSigmoidOpKernel, diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 79980cda53..07ff8f947e 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -14,12 +14,16 @@ limitations under the License. */ #pragma once #include +#include #include +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/platform/transform.h" + namespace paddle { namespace operators { @@ -28,20 +32,38 @@ template ; using platform::Transform; +static std::vector PathToRows(const framework::LoDTensor& path) { + std::set rows; + for (int64_t i = 0; i < path.numel(); ++i) { + int64_t row = path.data()[i]; + if (row < 0) { + continue; + } + rows.emplace(row); + } + return std::vector(rows.begin(), rows.end()); +} template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* label = ctx.Input("Label"); - auto* bias = ctx.Input("Bias"); - auto* out = ctx.Output("Out"); - auto* pre_out = ctx.Output("PreOut"); + auto& in = detail::Ref(ctx.Input("X")); + auto& w = detail::Ref(ctx.Input("W")); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PathCode"); + auto& label = detail::Ref(ctx.Input("Label")); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + auto* pre_out = ctx.Output("PreOut"); size_t num_classes = static_cast(ctx.Attr("num_classes")); - int64_t code_length = math::FindLastSet(num_classes - 1); - int64_t batch_size = in->dims()[0]; - framework::Tensor sum; + bool is_custom = false; + if (path) { + is_custom = true; + } + int64_t code_length = + path ? 
path->dims()[1] : math::FindLastSet(num_classes - 1); + int64_t batch_size = in.dims()[0]; + framework::LoDTensor sum; auto& dev_ctx = ctx.template device_context(); auto* pre_out_data = pre_out->mutable_data( framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); @@ -52,7 +74,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { zero(dev_ctx, pre_out, static_cast(0.0)); auto& place = *ctx.template device_context().eigen_device(); math::RowwiseSum row_sum; - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, + label.data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor(*path, *code, + label.data())); + } std::vector sum_dims({batch_size, 1UL}); sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); @@ -60,15 +90,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - bit_code.Add(pre_out, *bias); + bit_code->Add(*bias, pre_out); } - bit_code.Mul(pre_out, *w, *in); + bit_code->Mul(pre_out, w, in); // clip to [-40, 40] Transform trans; trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, ClipFunctor(static_cast(-40.0), static_cast(40.0))); - bit_code.Sum(*pre_out, out, static_cast(-1)); + bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(dev_ctx, *pre_out, &sum); @@ -84,50 +114,103 @@ template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - auto* w_grad = ctx.Output(framework::GradVarName("W")); - auto* bias_grad = - ctx.Output(framework::GradVarName("Bias")); - auto* label = ctx.Input("Label"); - auto* pre_out = ctx.Input("PreOut"); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - framework::Tensor pre_out_grad; - - pre_out_grad.mutable_data(pre_out->dims(), ctx.GetPlace()); - in_grad->mutable_data(ctx.GetPlace()); - w_grad->mutable_data(ctx.GetPlace()); + auto& in = detail::Ref(ctx.Input("X")); + auto& w = detail::Ref(ctx.Input("W")); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PathCode"); + auto* bias = ctx.Input("Bias"); + auto* in_grad = + ctx.Output(framework::GradVarName("X")); + bool is_sparse = ctx.Attr("is_sparse"); auto& dev_ctx = ctx.template device_context(); math::SetConstant zero; + auto& label = detail::Ref(ctx.Input("Label")); + auto& pre_out = detail::Ref(ctx.Input("PreOut")); + auto& out_grad = detail::Ref( + ctx.Input(framework::GradVarName("Out"))); + framework::LoDTensor pre_out_grad; + + pre_out_grad.mutable_data(pre_out.dims(), ctx.GetPlace()); + in_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, in_grad, static_cast(0.0)); - zero(dev_ctx, w_grad, static_cast(0.0)); size_t num_classes = static_cast(ctx.Attr("num_classes")); - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + bool is_custom = false; + if (path) { + is_custom = true; + } + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, + label.data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor(*path, *code, + 
label.data())); + } auto& place = *ctx.template device_context().eigen_device(); - auto pre_out_mat = EigenMatrix::From(*pre_out); + auto pre_out_mat = EigenMatrix::From(pre_out); auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); - auto out_grad_mat = EigenMatrix::From(*out_grad); + auto out_grad_mat = EigenMatrix::From(out_grad); + Eigen::array bcast{1, static_cast(pre_out_grad.dims()[1])}; // softrelu derivative pre_out_grad_mat.device(place) = static_cast(1.0) - static_cast(1.0) / pre_out_mat.exp(); - bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b) + bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) pre_out_grad_mat.device(place) = pre_out_grad_mat * out_grad_mat.broadcast(bcast); // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // be consistent with the clipping in forward. - if (bias_grad) { - bias_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, bias_grad, static_cast(0.0)); - bit_code.AddGrad(pre_out_grad, bias_grad); + + if (!is_sparse) { + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + if (bias_grad) { + bias_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, bias_grad, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } + auto* w_grad = + ctx.Output(framework::GradVarName("W")); + w_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, w_grad, static_cast(0.0)); + bit_code->MulGradWeight(pre_out_grad, w_grad, in); + } else { + framework::Vector real_rows = PathToRows(*path); + auto* w_grad = + ctx.Output(framework::GradVarName("W")); + w_grad->set_rows(real_rows); + // Build a map of id -> row_index to speed up finding the index of one id + w_grad->SyncIndex(); + w_grad->set_height(w.dims()[0]); + auto* w_grad_value = w_grad->mutable_value(); + framework::DDim temp_dim(w.dims()); + set(temp_dim, 0, real_rows.size()); + + w_grad_value->mutable_data(temp_dim, ctx.GetPlace()); + zero(dev_ctx, w_grad_value, static_cast(0.0)); + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + if (bias_grad) { + bias_grad->set_rows(real_rows); + // build ids -> rows index map + bias_grad->SyncIndex(); + bias_grad->set_height(bias->dims()[0]); + auto* bias_grad_value = bias_grad->mutable_value(); + std::vector dims = {static_cast(real_rows.size()), + bias->dims()[1]}; + bias_grad_value->mutable_data(framework::make_ddim(dims), + ctx.GetPlace()); + zero(dev_ctx, bias_grad_value, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } + bit_code->MulGradWeight(pre_out_grad, w_grad, in); } - bit_code.MulGradWeight(pre_out_grad, w_grad, *in); - bit_code.MulGradError(pre_out_grad, *w, in_grad); + bit_code->MulGradError(pre_out_grad, w, in_grad); } }; diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 1e56e29739..71b9293eed 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -19,16 +19,15 @@ namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, - const framework::Tensor& vec) { - SimpleCodeTable code_table(num_classes_); +void MatrixBitCodeFunctor::Add(const framework::Tensor& vec, + framework::Tensor* tmat) { size_t batch_size = tmat->dims()[0]; size_t width = tmat->dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < 
code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); tmat->data()[i * width + j] += vec.data()[index]; } } @@ -37,31 +36,46 @@ void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, template void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, framework::Tensor* vec) { - SimpleCodeTable code_table(num_classes_); size_t batch_size = tmat.dims()[0]; size_t width = tmat.dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); vec->data()[index] += tmat.data()[i * width + j]; } } } +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, + framework::SelectedRows* vec) { + size_t batch_size = tmat.dims()[0]; + size_t width = tmat.dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table_->get_code(i); + int code_length = code->get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code->calc_index(j); + int64_t row_index = vec->GetIndexFromId(static_cast(index)); + vec->mutable_value()->data()[row_index] += + tmat.data()[i * width + j]; + } + } +} + template void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { T sm = static_cast(0.0); - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { + if (code->calc_bit(j)) { // calc_bit starts from right most bit, while data in tmat[i] is in the // reverse order. 
sm += tmat.data()[i * o_width + j]; @@ -75,7 +89,6 @@ template void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, const framework::Tensor& weight, const framework::Tensor& input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat->dims()[0]; size_t tmat_width = tmat->dims()[1]; size_t input_width = input.dims()[1]; @@ -84,10 +97,10 @@ void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, auto weight_value = weight.data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); T sum = static_cast(0.0); for (size_t k = 0; k < input_width; ++k) { sum += weight_value[weight_width * index + k] * @@ -102,7 +115,6 @@ template void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, const framework::Tensor& input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -111,10 +123,10 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, auto weight_value = weight->data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { weight_value[weight_width * index + k] += @@ -124,11 +136,35 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, } } +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, + framework::SelectedRows* weight, + const framework::Tensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t input_width = input.dims()[1]; + size_t tmat_width = tmat.dims()[1]; + size_t weight_width = weight->value().dims()[1]; + auto tmat_value = tmat.data(); + auto weight_value = weight->mutable_value()->data(); + auto input_value = input.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table_->get_code(i); + int code_length = code->get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code->calc_index(j); + for (size_t k = 0; k < input_width; ++k) { + int64_t row_index = weight->GetIndexFromId(static_cast(index)); + weight_value[row_index * weight_width + k] += + tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; + } + } + } +} + template void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, const framework::Tensor& weight, framework::Tensor* input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t tmat_width = tmat.dims()[1]; size_t input_width = input->dims()[1]; @@ -138,10 +174,10 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, auto input_value = input->data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = 
code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { input_value[input_width * i + k] += @@ -154,14 +190,13 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, template void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat->dims()[0]; size_t o_width = tmat->dims()[1]; for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { + if (code->calc_bit(j)) { tmat->data()[i * o_width + j] -= 1; } } diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index c329b8b611..c30bb52641 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -92,9 +94,27 @@ inline int clz(const T& value) { inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); } #endif // !_WIN32 +// set a code interface to create multiple code +class Code { + public: + virtual ~Code() {} + virtual size_t calc_index(int bit) const = 0; + virtual bool calc_bit(int bit) const = 0; + virtual int get_length() const = 0; +}; +// set a CodeTable interface to create multiple code table +class CodeTable { + public: + virtual std::unique_ptr get_code(int64_t code) const = 0; + virtual size_t size() const = 0; + virtual int get_max_code_length() const = 0; + virtual ~CodeTable() {} +}; -struct SimpleCode { - SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {} +class SimpleCode : public Code { + public: + SimpleCode(size_t code, size_t num_classes, const int64_t* ids) + : c_(static_cast(ids[code]) + num_classes) {} /** * Here the id of root shoud be 1 rather than 0, thus the encoding of class c * is `c + num_classes` and all siblings can get the same weight indice using @@ -104,41 +124,121 @@ struct SimpleCode { * Binary classification path is the suffixes of encoding, thus leave out the * left most bit in calc_bit. */ - inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } - inline bool calc_bit(int bit) const { return c_ & (1 << bit); } - inline int get_length() const { return FindLastSet(c_) - 1; } + size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } + bool calc_bit(int bit) const { return c_ & (1 << bit); } + int get_length() const { return FindLastSet(c_) - 1; } private: size_t c_; }; -struct SimpleCodeTable { - explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {} - SimpleCode operator()(size_t code) const { - return SimpleCode(code, num_classes_); +template +class CustomCode : public Code { + public: + CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, + const int64_t* ids, int index) + : ids_(ids), index_(index) { + ptable_ = ptable.Slice(index, index + 1); + pcode_ = pcode.Slice(index, index + 1); + } + /** + * Here the id of root shoud be 1 rather than 0, thus the encoding of class c + * is `c + num_classes` and all siblings can get the same weight indice using + * prefixes. 
+ * Weight index is the prefixes of encoding, thus leave out the right most + * bit in calc_index. + * Binary classification path is the suffixes of encoding, thus leave out the + * left most bit in calc_bit. + */ + size_t calc_index(int bit) const { return ptable_.data()[bit]; } + bool calc_bit(int bit) const { return pcode_.data()[bit]; } + int get_length() const { + int length = 0; + + for (int i = 0; i < static_cast(ptable_.dims()[1]); i++) { + if (ptable_.data()[i] >= 0) { + length++; + } else { + return length; + } + } + return length; + } + + private: + framework::Tensor ptable_; + framework::Tensor pcode_; + const int64_t* ids_; + const int index_; +}; + +class SimpleCodeTable : public CodeTable { + public: + SimpleCodeTable(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), ids_(ids) {} + std::unique_ptr get_code(int64_t code) const { + std::unique_ptr coder(new SimpleCode(code, num_classes_, ids_)); + return coder; } size_t size() const { return num_classes_; } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } private: size_t num_classes_; + const int64_t* ids_; +}; + +template +class CustomCodeTable : public CodeTable { + public: + CustomCodeTable(const framework::Tensor& ptable, + const framework::Tensor& pcode, const int64_t* ids) + : ptable_(ptable), pcode_(pcode), ids_(ids) {} + + std::unique_ptr get_code(int64_t code) const { + std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); + return coder; + } + + size_t size() const { return static_cast(ptable_.dims()[1]); } + int get_max_code_length() const { + return static_cast(ptable_.dims()[1]); + } + + private: + const framework::Tensor& ptable_; + const framework::Tensor& pcode_; + const int64_t* ids_; }; template class MatrixBitCodeFunctor { public: - explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) - : num_classes_(num_classes), ids_(ids) {} + MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), + ids_(ids), + code_table_(new SimpleCodeTable(num_classes, ids)) {} + + MatrixBitCodeFunctor(const framework::Tensor& ptable, + const framework::Tensor& pcode, const int64_t* ids) + : num_classes_(static_cast(ptable.dims()[1])), + ids_(ids), + code_table_(new CustomCodeTable(ptable, pcode, ids)) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ - void Add(framework::Tensor* tmat, const framework::Tensor& vec); + void Add(const framework::Tensor& vec, framework::Tensor* tmat); /* For j < code_length vec(0, index(i, j)) += tmat(i, j) */ void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec); + /* For selected rows For j < code_length + vec(0, index(i, j)) += tmat(i, j) + */ + void AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec); + /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) */ @@ -159,6 +259,12 @@ class MatrixBitCodeFunctor { */ void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, const framework::Tensor& input); + /* For SelectedRows Weight, For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) + */ + void MulGradWeight(const framework::Tensor& tmat, + framework::SelectedRows* weight, + const framework::Tensor& input); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ @@ -167,6 +273,7 @@ class MatrixBitCodeFunctor { size_t num_classes_; const int64_t* ids_; + std::unique_ptr code_table_; }; } // namespace math } // namespace operators diff --git 
a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc index 690d6f6baa..2708f3bcd8 100644 --- a/paddle/fluid/operators/math/sampler.cc +++ b/paddle/fluid/operators/math/sampler.cc @@ -60,75 +60,30 @@ float LogUniformSampler::Probability(int64_t value) const { return (log((value + 2.0) / (value + 1.0))) / log_range_; } -CustomSampler::CustomSampler(int64_t range, const float* probabilities, +CustomSampler::CustomSampler(int64_t range, const float *probabilities, + const int *alias, const float *alias_probabilities, unsigned int seed) : Sampler(range, seed) { - random_engine_ = std::make_shared(seed_); + random_engine_ = std::make_shared(seed_); real_dist_ = std::make_shared>(0, 1); int_dist_ = std::make_shared>(0, range); - alias_probs_ = std::make_shared>(range + 1); - alias_ = std::make_shared>(range + 1); - probs_ = std::make_shared>(range + 1); - - std::queue> bigs; - std::queue> littles; - for (int64_t i = 0; i <= range; ++i) { - (*probs_)[i] = probabilities[i]; - float normal_prob = probabilities[i] * (range + 1); - if (normal_prob - 1.0 > 1e-4) { - bigs.emplace(i, normal_prob); - } else if (1.0 - normal_prob > 1e-4) { - littles.emplace(i, normal_prob); - } else { - (*alias_probs_)[i] = normal_prob; - (*alias_)[i] = -1; - } - } - - while ((!littles.empty()) && (!bigs.empty())) { - auto big = bigs.front(); - auto little = littles.front(); - bigs.pop(); - littles.pop(); - (*alias_probs_)[little.first] = little.second; - (*alias_)[little.first] = big.first; - auto big_left = big.second - (1 - little.second); - if (big_left - 1.0 > 1e-4) { - bigs.emplace(big.first, big_left); - } else if (1.0 - big_left > 1e-4) { - littles.emplace(big.first, big_left); - } else { - (*alias_probs_)[big.first] = big_left; - (*alias_)[big.first] = -1; - } - } - if (!littles.empty()) { // littles.second is close to 1.0 - auto little = littles.front(); - (*alias_probs_)[little.first] = 1.0; - (*alias_)[little.first] = -1; - } - - if (!bigs.empty()) { // bigs.second is close to 1.0 - auto big = bigs.front(); - (*alias_probs_)[big.first] = 1.0; - (*alias_)[big.first] = -1; - } + alias_probs_ = alias_probabilities; + probs_ = probabilities; + alias_ = alias; } int64_t CustomSampler::Sample() const { auto index = (*int_dist_)(*random_engine_); auto p = (*real_dist_)(*random_engine_); - if (p > (*alias_probs_)[index]) { - return (*alias_)[index]; + if (p > alias_probs_[index]) { + return alias_[index]; } else { return index; } } -float CustomSampler::Probability(int64_t value) const { - return (*probs_)[value]; -} +float CustomSampler::Probability(int64_t value) const { return probs_[value]; } } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sampler.h b/paddle/fluid/operators/math/sampler.h index 836cdad51f..98e0b898a5 100644 --- a/paddle/fluid/operators/math/sampler.h +++ b/paddle/fluid/operators/math/sampler.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include #include @@ -38,9 +39,12 @@ class Sampler { seed_ = seed; } } + virtual ~Sampler(); + // Sample a single value virtual int64_t Sample() const = 0; + // The probability that a single call to Sample() returns the given value. 
virtual float Probability(int64_t value) const = 0; @@ -99,6 +103,7 @@ class LogUniformSampler : public Sampler { class CustomSampler : public Sampler { public: explicit CustomSampler(int64_t range, const float* probabilities, + const int* alias, const float* alias_probabilities, unsigned int seed = 0UL); ~CustomSampler() override {} @@ -108,10 +113,10 @@ class CustomSampler : public Sampler { float Probability(int64_t value) const override; private: - std::shared_ptr<std::vector<float>> alias_probs_; - std::shared_ptr<std::vector<int64_t>> alias_; - std::shared_ptr<std::vector<float>> probs_; - std::shared_ptr random_engine_; + const float* alias_probs_; + const int* alias_; + const float* probs_; + std::shared_ptr random_engine_; std::shared_ptr<std::uniform_real_distribution<float>> real_dist_; std::shared_ptr<std::uniform_int_distribution<int64_t>> int_dist_; }; diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu index 0015fafbc8..51da6de26e 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ b/paddle/fluid/operators/math/sequence_pooling.cu @@ -16,13 +16,12 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence_pooling.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/macros.h" namespace paddle { namespace operators { namespace math { -#define FLT_MAX __FLT_MAX__ - template <typename T> struct MaxPoolFunctor { HOSTDEVICE void operator()(const T* input, const size_t start, diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 9b0d45ae5b..655e171e63 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/nce_op.h" +#include <string> #include <vector> namespace paddle { @@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Label")); PADDLE_ENFORCE(ctx->HasInput("Weight")); @@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input<Tensor>("Input")->type()), platform::CPUPlace()); @@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddInput( - "CustomDistribution", + "CustomDistProbs", "(Tensor) It is used in 'CostumDist' sampler. " "It is a tensor with shape [num_total_classes]." "The i-th element is the probsbility of the i-th class being sampled.") .AsDispensable(); + AddInput( + "CustomDistAlias", + "(Tensor) It is used in 'CustomDist' sampler. " + "It is a tensor with shape [num_total_classes]. " + "The i-th element is the alias index of the i-th class, " + "built by the alias method.") + .AsDispensable(); + AddInput( + "CustomDistAliasProbs", + "(Tensor) It is used in 'CustomDist' sampler. " + "It is a tensor with shape [num_total_classes]. " + "The i-th element is the probability of accepting the i-th class " + "instead of its alias.") + .AsDispensable(); + AddOutput("Cost", "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); AddOutput("SampleLogits", @@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { "kernel to compute grads." "") .AsIntermediate(); + AddAttr<int>("num_total_classes", "Total number of classes in all samples."); AddAttr<int>("num_neg_samples", "The number of negative classes. The default value is 10.") .SetDefault(10); - AddAttr<int>("sampler", "(int) Which sampler to be used to sample negative class." "0: Uniform; 1: LogUniform; 2: CostumDist.") .SetDefault(0); - AddAttr<int>("seed", "(int) The seed used in sampler. If it is 0, " "the sampler will generate a seed randomly.") .SetDefault(0); + AddAttr<bool>("is_sparse", "(bool, default: false) Whether to use sparse update; " "if true, Weight@GRAD and Bias@GRAD will be SelectedRows.") .SetDefault(false); AddAttr<std::vector<int>>("custom_neg_classes", "This attribute only be used in unitest. Classes " @@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling. } }; +class NCEOpGradDescMaker : public framework::DefaultGradOpDescMaker<true> { + using ::paddle::framework::DefaultGradOpDescMaker< + true>::DefaultGradOpDescMaker; + + protected: + virtual std::string GradOpType() const { return "nce_grad"; } +}; + class NCEOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Weight")); PADDLE_ENFORCE(ctx->HasInput("Cost")); @@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input<Tensor>("Input")->type()), platform::CPUPlace()); } }; +class NCEOpGradVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front(); + auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front(); + + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get<bool>(attr); + if (is_sparse) { + VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad + << " are set to SelectedRows"; + block->Var(weight_grad) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad + << " are set to LoDTensor"; + block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR); + block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR); + } + block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType()); + block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType()); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker, - paddle::framework::DefaultGradOpDescMaker<true>); -REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad); +REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpGradDescMaker, ops::NCEOpMaker); +REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUDeviceContext, float>, ops::NCEKernel<paddle::platform::CPUDeviceContext, double>); REGISTER_OP_CPU_KERNEL(nce_grad, diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index e9af8ad4ce..f2ca6ec247 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -16,26 +16,32 @@ limitations under
the License. */ #include #include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" + namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; using Sampler = math::Sampler; +using DDim = framework::DDim; template using EigenMatrix = framework::EigenMatrix; template -void PrepareSamples(const framework::ExecutionContext& context, - Sampler* sampler) { +void PrepareSamples(const framework::ExecutionContext &context, + Sampler *sampler) { auto label = context.Input("Label"); - const int64_t* label_data = label->data(); + const int64_t *label_data = label->data(); auto label_dims = label->dims(); // int num_total_classes = context.Attr("num_total_classes"); // for unitest @@ -44,7 +50,7 @@ void PrepareSamples(const framework::ExecutionContext& context, auto sample_labels = context.Output("SampleLabels"); auto sample_labels_dims = sample_labels->dims(); - int64_t* sample_labels_data = + int64_t *sample_labels_data = sample_labels->mutable_data(context.GetPlace()); int num_label = label_dims.size() == 2 ? label_dims[1] : 1; @@ -70,13 +76,13 @@ void PrepareSamples(const framework::ExecutionContext& context, template class NCEKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext &context) const override { int sampler_type = context.Attr("sampler"); int seed = context.Attr("seed"); int num_total_classes = context.Attr("num_total_classes"); int num_neg_samples = context.Attr("num_neg_samples"); - Sampler* sampler; + Sampler *sampler; switch (sampler_type) { case 0: { sampler = new math::UniformSampler(num_total_classes - 1, seed); @@ -87,11 +93,19 @@ class NCEKernel : public framework::OpKernel { break; } case 2: { - auto custom_dist = context.Input("CustomDistribution"); - const float* custom_dist_data = custom_dist->data(); - PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); - sampler = new math::CustomSampler(num_total_classes - 1, - custom_dist_data, seed); + auto dist_probs = context.Input("CustomDistProbs"); + auto dist_alias = context.Input("CustomDistAlias"); + auto dist_alias_probs = context.Input("CustomDistAliasProbs"); + + PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes); + + const float *probs_data = dist_probs->data(); + const int *alias_data = dist_alias->data(); + const float *alias_probs_data = dist_alias_probs->data(); + sampler = new math::CustomSampler(num_total_classes - 1, probs_data, + alias_data, alias_probs_data, seed); break; } default: { PADDLE_THROW("Unsupported SamplerType."); } @@ -99,17 +113,17 @@ class NCEKernel : public framework::OpKernel { PrepareSamples(context, sampler); auto sample_labels = context.Output("SampleLabels"); - const int64_t* sample_labels_data = sample_labels->data(); + const int64_t *sample_labels_data = sample_labels->data(); auto sample_out = context.Output("SampleLogits"); - T* sample_out_data = sample_out->mutable_data(context.GetPlace()); + T *sample_out_data = sample_out->mutable_data(context.GetPlace()); auto label = context.Input("Label"); auto sample_weight = 
context.Input("SampleWeight"); - const T* sample_weight_data = nullptr; + const T *sample_weight_data = nullptr; if (sample_weight != nullptr) { sample_weight_data = sample_weight->data(); } auto out = context.Output("Cost"); - T* out_data = out->mutable_data(context.GetPlace()); + T *out_data = out->mutable_data(context.GetPlace()); int64_t num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; @@ -119,7 +133,7 @@ class NCEKernel : public framework::OpKernel { // forward bias auto bias = context.Input("Bias"); if (bias != nullptr) { - const T* bias_data = bias->data(); + const T *bias_data = bias->data(); for (int64_t i = 0; i < sample_labels->numel(); ++i) { sample_out_data[i] = bias_data[sample_labels_data[i]]; } @@ -158,16 +172,16 @@ class NCEKernel : public framework::OpKernel { template class NCEGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext &context) const override { auto d_out = context.Input(framework::GradVarName("Cost")); - const T* d_out_data = d_out->data(); + const T *d_out_data = d_out->data(); auto label = context.Input("Label"); auto sample_out = context.Input("SampleLogits"); - const T* sample_out_data = sample_out->data(); + const T *sample_out_data = sample_out->data(); auto sample_labels = context.Input("SampleLabels"); - const int64_t* sample_labels_data = sample_labels->data(); + const int64_t *sample_labels_data = sample_labels->data(); auto sample_weight = context.Input("SampleWeight"); - const T* sample_weight_data = nullptr; + const T *sample_weight_data = nullptr; if (sample_weight != nullptr) { sample_weight_data = sample_weight->data(); } @@ -180,7 +194,7 @@ class NCEGradKernel : public framework::OpKernel { int sampler_type = context.Attr("sampler"); int seed = context.Attr("seed"); - Sampler* sampler; + Sampler *sampler; switch (sampler_type) { case 0: { sampler = new math::UniformSampler(num_total_classes - 1, seed); @@ -191,11 +205,19 @@ class NCEGradKernel : public framework::OpKernel { break; } case 2: { - auto custom_dist = context.Input("CustomDistribution"); - const float* custom_dist_data = custom_dist->data(); - PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); - sampler = new math::CustomSampler(num_total_classes - 1, - custom_dist_data, seed); + auto dist_probs = context.Input("CustomDistProbs"); + auto dist_alias = context.Input("CustomDistAlias"); + auto dist_alias_probs = context.Input("CustomDistAliasProbs"); + + PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes); + + const float *probs_data = dist_probs->data(); + const int *alias_data = dist_alias->data(); + const float *alias_probs_data = dist_alias_probs->data(); + sampler = new math::CustomSampler(num_total_classes - 1, probs_data, + alias_data, alias_probs_data, seed); break; } default: { PADDLE_THROW("Unsupported SamplerType."); } @@ -203,7 +225,7 @@ class NCEGradKernel : public framework::OpKernel { // T b = 1. 
/ num_total_classes * num_neg_samples; Tensor sample_grad; // tmp tensor - T* sample_grad_data = + T *sample_grad_data = sample_grad.mutable_data(sample_labels->dims(), context.GetPlace()); // backward cost for (int64_t i = 0; i < sample_labels->numel(); ++i) { @@ -217,32 +239,105 @@ class NCEGradKernel : public framework::OpKernel { : w * (o * (1 - o) / (o + b)); sample_grad_data[i] *= d_out_data[sample_idx]; } - // get d_bias - auto d_bias = context.Output(framework::GradVarName("Bias")); - if (d_bias != nullptr) { - T* d_bias_data = d_bias->mutable_data(context.GetPlace()); - std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); + + bool is_sparse = context.Attr("is_sparse"); + + if (!is_sparse) { + // get d_bias + auto d_bias = context.Output(framework::GradVarName("Bias")); + if (d_bias != nullptr) { + T *d_bias_data = d_bias->mutable_data(context.GetPlace()); + std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; + } + } + // get d_w + auto d_w = context.Output(framework::GradVarName("Weight")); + if (d_w != nullptr) { + auto d_w_data = d_w->mutable_data(context.GetPlace()); + std::fill(d_w_data, d_w_data + d_w->numel(), 0.0); + auto d_w_matrix = EigenMatrix::From(*d_w); + auto x_matrix = EigenMatrix::From(*(context.Input("Input"))); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_w_matrix.chip(sample_labels_data[i], 0) += + x_matrix.chip(static_cast(i / sample_labels->dims()[1]), 0) * + sample_grad_data[i]; + } + } + } else { + std::vector labels; for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; + labels.push_back(sample_labels_data[i]); } - } - // get d_w - auto d_w = context.Output(framework::GradVarName("Weight")); - if (d_w != nullptr) { - auto d_w_data = d_w->mutable_data(context.GetPlace()); - std::fill(d_w_data, d_w_data + d_w->numel(), 0.0); - auto d_w_matrix = EigenMatrix::From(*d_w); + std::set st(labels.begin(), labels.end()); + labels.assign(st.begin(), st.end()); + + auto *bias_var = context.InputVar("Bias"); + DDim bias_dim; + if (bias_var->IsType()) { + bias_dim = context.Input("Bias")->dims(); + } else if (bias_var->IsType()) { + auto *table_t = context.Input("Bias"); + bias_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter Bias of a NCE_OP " + "must be either LoDTensor or SelectedRows"); + } + + auto d_bias = + context.Output(framework::GradVarName("Bias")); + d_bias->set_rows(labels); + d_bias->set_height(bias_dim[0]); + + d_bias->mutable_value()->Resize( + {static_cast(labels.size()), bias_dim[1]}); + T *d_bias_data = + d_bias->mutable_value()->mutable_data(context.GetPlace()); + std::fill(d_bias_data, d_bias_data + labels.size(), 0.0); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_bias_data[d_bias->Index(sample_labels_data[i])] += + sample_grad_data[i]; + } + + auto *table_var = context.InputVar("Weight"); + DDim table_dim; + if (table_var->IsType()) { + table_dim = context.Input("Weight")->dims(); + } else if (table_var->IsType()) { + auto *table_t = context.Input("Weight"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter Weight of a NCE_OP " + "must be either LoDTensor or SelectedRows"); + } + + auto d_w = context.Output(framework::GradVarName("Weight")); + + d_w->set_rows(labels); + d_w->set_height(table_dim[0]); + + auto *d_table_value = d_w->mutable_value(); + 
d_table_value->Resize( + {static_cast(labels.size()), table_dim[1]}); + auto d_w_data = d_table_value->mutable_data(context.GetPlace()); + std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0); + + auto d_w_matrix = EigenMatrix::From(*d_table_value); auto x_matrix = EigenMatrix::From(*(context.Input("Input"))); for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_w_matrix.chip(sample_labels_data[i], 0) += + d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) += x_matrix.chip(static_cast(i / sample_labels->dims()[1]), 0) * sample_grad_data[i]; } } + // get d_x auto d_x = context.Output(framework::GradVarName("Input")); if (d_x != nullptr) { - auto* d_x_data = d_x->mutable_data(context.GetPlace()); + auto *d_x_data = d_x->mutable_data(context.GetPlace()); std::fill(d_x_data, d_x_data + d_x->numel(), 0.0); auto d_x_matrix = EigenMatrix::From(*d_x); auto w_matrix = EigenMatrix::From(*(context.Input("Weight"))); @@ -251,6 +346,7 @@ class NCEGradKernel : public framework::OpKernel { w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; } } + delete sampler; } }; diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 6c919ee178..7c284312df 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -28,6 +28,12 @@ reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) reader_library(create_py_reader_op SRCS create_py_reader_op.cc) +if (NOT WIN32 AND NOT ON_INFER) + cc_library(ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib) + cc_test(ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader) + reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader) +endif () + cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) # Export local libraries to parent # set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc new file mode 100644 index 0000000000..58a465d87a --- /dev/null +++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
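(Aside on the sparse path added to NCEGradKernel above: a SelectedRows gradient stores only the touched rows, as a row-id list plus a dense value block. Below is a minimal NumPy sketch of the Weight gradient accumulation, with a hypothetical helper name and simplified indexing; it is not the operator's actual API, only an illustration of the bookkeeping.

import numpy as np

def nce_weight_grad_sparse(sample_labels, sample_grad, x):
    # sample_labels: int64 [batch, num_samples], sampled class ids
    # sample_grad:   float [batch, num_samples], per-sample gradients
    # x:             float [batch, dim], the layer input
    rows = sorted(set(sample_labels.ravel().tolist()))  # d_w->set_rows(labels)
    index = {r: i for i, r in enumerate(rows)}          # like d_w->Index(id)
    values = np.zeros((len(rows), x.shape[1]), dtype=x.dtype)
    batch, num_samples = sample_labels.shape
    for i in range(batch):
        for j in range(num_samples):
            # d_w_matrix.chip(Index(label), 0) += x.chip(i, 0) * grad
            values[index[sample_labels[i, j]]] += x[i] * sample_grad[i, j]
    return rows, values

The dense path writes the same updates into a full [num_total_classes, dim] buffer; the sparse path only materializes rows for the classes that were actually sampled.)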
+ +#include "paddle/fluid/operators/reader/ctr_reader.h" + +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class CreateCTRReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + if (out->Get() != nullptr) return; + + const std::string& queue_name = Input("blocking_queue"); + auto* queue_holder_var = scope.FindVar(queue_name); + PADDLE_ENFORCE_NOT_NULL( + queue_holder_var, + "No LoDTensorBlockingQueueHolder variable with name %s found", + queue_name); + auto* queue_holder = + queue_holder_var->template GetMutable(); + + int thread_num = Attr("thread_num"); + std::vector slots = Attr>("slots"); + int batch_size = Attr("batch_size"); + std::vector file_list = + Attr>("file_list"); + out->Reset(std::make_shared(queue_holder->GetQueue(), batch_size, + thread_num, slots, file_list)); + } +}; + +class CreateCTRReaderOpMaker : public FileReaderMakerBase { + protected: + void Apply() override { + AddInput("blocking_queue", + "Name of the `LoDTensorBlockingQueueHolder` variable"); + AddAttr("thread_num", "the thread num to read data"); + AddAttr("batch_size", "the batch size of read data"); + AddAttr>("file_list", + "The list of files that need to read"); + AddAttr>( + "slots", "the slots that should be extract from file"); + + AddComment(R"DOC( + Create CTRReader to support read ctr data with cpp. + )DOC"); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = ::paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(create_ctr_reader, reader::CreateCTRReaderOp, + reader::CreateCTRReaderOpMaker); diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc new file mode 100644 index 0000000000..d1d3ddc89d --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/reader/ctr_reader.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace paddle { +namespace operators { +namespace reader { + +static inline void string_split(const std::string& s, const char delimiter, + std::vector* output) { + size_t start = 0; + size_t end = s.find_first_of(delimiter); + + while (end <= std::string::npos) { + output->emplace_back(s.substr(start, end - start)); + if (end == std::string::npos) { + break; + } + start = end + 1; + end = s.find_first_of(delimiter, start); + } +} + +static inline void parse_line( + const std::string& line, + const std::unordered_map& slot_to_index, + int64_t* label, + std::unordered_map>* slot_to_data) { + std::vector ret; + string_split(line, ' ', &ret); + *label = std::stoi(ret[2]) > 0; + + for (size_t i = 3; i < ret.size(); ++i) { + const std::string& item = ret[i]; + std::vector feasign_and_slot; + string_split(item, ':', &feasign_and_slot); + if (feasign_and_slot.size() == 2 && + slot_to_index.find(feasign_and_slot[1]) != slot_to_index.end()) { + int64_t feasign = std::strtoll(feasign_and_slot[0].c_str(), NULL, 10); + (*slot_to_data)[feasign_and_slot[1]].push_back(feasign); + } + } + + // NOTE:: if the slot has no value, then fill [0] as it's data. + for (auto& item : slot_to_index) { + if (slot_to_data->find(item.first) == slot_to_data->end()) { + (*slot_to_data)[item.first].push_back(0); + } + } +} + +class Reader { + public: + virtual ~Reader() {} + virtual bool HasNext() = 0; + virtual void NextLine(std::string* line) = 0; +}; + +class GzipReader : public Reader { + public: + explicit GzipReader(const std::string& file_name) + : gzstream_(file_name.c_str()) {} + + ~GzipReader() {} + + bool HasNext() override { return gzstream_.peek() != EOF; } + + void NextLine(std::string* line) override { std::getline(gzstream_, *line); } + + private: + igzstream gzstream_; +}; + +class MultiGzipReader : public Reader { + public: + explicit MultiGzipReader(const std::vector& file_list) { + for (auto& file : file_list) { + readers_.emplace_back(std::make_shared(file)); + } + } + + bool HasNext() override { + if (current_reader_index_ >= readers_.size()) { + return false; + } + if (!readers_[current_reader_index_]->HasNext()) { + current_reader_index_++; + return HasNext(); + } + return true; + } + + void NextLine(std::string* line) override { + readers_[current_reader_index_]->NextLine(line); + } + + private: + std::vector> readers_; + size_t current_reader_index_ = 0; +}; + +void MonitorThread(std::vector* thread_status, + std::shared_ptr queue) { + VLOG(30) << "monitor thread in"; + bool reader_thread_is_running = true; + while (reader_thread_is_running) { + VLOG(30) << "reader_thread_is_running"; + reader_thread_is_running = false; + for (size_t i = 0; i < (*thread_status).size(); ++i) { + if ((*thread_status)[i] == Running) { + VLOG(30) << "reader is running!"; + reader_thread_is_running = true; + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + VLOG(30) << "all reader thread is stopped, push empty data into queue"; + queue->Push({}); + VLOG(30) << "monitor thread exited"; +} + +void ReadThread(const std::vector& file_list, + const std::vector& slots, int batch_size, + int thread_id, std::vector* thread_status, + std::shared_ptr queue) { + VLOG(30) << "[" << thread_id << "]" + << " reader thread start! 
thread_id = " << thread_id; + for (auto& file : file_list) { + VLOG(30) << "[" << thread_id << "]" + << " file " << file; + } + (*thread_status)[thread_id] = Running; + VLOG(30) << "set status to running"; + + std::unordered_map slot_to_index; + for (size_t i = 0; i < slots.size(); ++i) { + slot_to_index[slots[i]] = i; + } + + std::string line; + + std::vector>> batch_data; + std::vector batch_label; + + MultiGzipReader reader(file_list); + + VLOG(30) << "reader inited"; + + while (reader.HasNext()) { + batch_data.clear(); + batch_data.reserve(batch_size); + + batch_label.clear(); + batch_label.reserve(batch_size); + + // read batch_size data + for (int i = 0; i < batch_size; ++i) { + if (reader.HasNext()) { + reader.NextLine(&line); + std::unordered_map> slot_to_data; + int64_t label; + parse_line(line, slot_to_index, &label, &slot_to_data); + batch_data.push_back(slot_to_data); + batch_label.push_back(label); + } else { + break; + } + } + + std::vector lod_datas; + + // first insert tensor for each slots + for (auto& slot : slots) { + std::vector lod_data{0}; + std::vector batch_feasign; + + for (size_t i = 0; i < batch_data.size(); ++i) { + auto& feasign = batch_data[i][slot]; + lod_data.push_back(lod_data.back() + feasign.size()); + batch_feasign.insert(batch_feasign.end(), feasign.begin(), + feasign.end()); + } + + framework::LoDTensor lod_tensor; + framework::LoD lod{lod_data}; + lod_tensor.set_lod(lod); + int64_t* tensor_data = lod_tensor.mutable_data( + framework::make_ddim({1, static_cast(batch_feasign.size())}), + platform::CPUPlace()); + memcpy(tensor_data, batch_feasign.data(), + batch_feasign.size() * sizeof(int64_t)); + lod_datas.push_back(lod_tensor); + } + + // insert label tensor + framework::LoDTensor label_tensor; + auto* label_tensor_data = label_tensor.mutable_data( + framework::make_ddim({1, static_cast(batch_label.size())}), + platform::CPUPlace()); + memcpy(label_tensor_data, batch_label.data(), + batch_label.size() * sizeof(int64_t)); + lod_datas.push_back(label_tensor); + + queue->Push(lod_datas); + VLOG(40) << "push one data, queue_size=" << queue->Size(); + } + + (*thread_status)[thread_id] = Stopped; + VLOG(30) << "set status to stopped, thread " << thread_id << " exited"; +} + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h new file mode 100644 index 0000000000..9b2a11bae1 --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -0,0 +1,133 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
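(In ReadThread above, each slot of a batch becomes a single LoDTensor: the per-sample feasign lists are concatenated into one flat [1, total] tensor, and a level-0 LoD records each sample's boundaries. The offset bookkeeping as a small Python sketch, with a hypothetical helper name:

def batch_slot_to_lod(batch_data, slot):
    # batch_data: list of {slot: [feasigns]} dicts, one per sample
    # mirrors: lod_data = {0}; lod_data.push_back(lod_data.back() + size)
    lod = [0]
    flat = []
    for sample in batch_data:
        feasigns = sample[slot]
        lod.append(lod[-1] + len(feasigns))
        flat.extend(feasigns)
    return lod, flat

batch = [{"6002": [0]}, {"6002": [0]}, {"6002": [10, 11, 12, 13, 14]}]
print(batch_slot_to_lod(batch, "6002"))
# ([0, 1, 2, 7], [0, 0, 10, 11, 12, 13, 14]) -- the first expected batch
# for slot 6002 in the unit test below.)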
+ +#pragma once + +#include + +#include // NOLINT +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { +namespace reader { + +enum ReaderThreadStatus { Running, Stopped }; + +void ReadThread(const std::vector<std::string>& file_list, + const std::vector<std::string>& slots, int batch_size, + int thread_id, std::vector<ReaderThreadStatus>* thread_status, + std::shared_ptr<LoDTensorBlockingQueue> queue); + +// monitor all running threads; if they have all stopped, +// then push an empty batch into the LoDTensorBlockingQueue +void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, + std::shared_ptr<LoDTensorBlockingQueue> queue); + +class CTRReader : public framework::FileReader { + public: + explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue, + int batch_size, int thread_num, + const std::vector<std::string>& slots, + const std::vector<std::string>& file_list) + : batch_size_(batch_size), slots_(slots), file_list_(file_list) { + PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger than 0!"); + PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); + PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); + thread_num_ = + file_list_.size() > thread_num ? thread_num : file_list_.size(); + queue_ = queue; + SplitFiles(); + for (size_t i = 0; i < thread_num_; ++i) { + read_thread_status_.push_back(Stopped); + } + } + + ~CTRReader() {} + + void ReadNext(std::vector<framework::LoDTensor>* out) override { + bool success; + *out = queue_->Pop(&success); + if (!success) out->clear(); + } + + void Shutdown() override { + VLOG(3) << "Shutdown reader"; + if (status_ == ReaderStatus::kStopped) { + return; + } + // shutdown should stop all the reader threads + for (auto& read_thread : read_threads_) { + read_thread->join(); + } + monitor_thread_->join(); + + read_threads_.clear(); + monitor_thread_.reset(nullptr); + queue_->Close(); + status_ = ReaderStatus::kStopped; + } + + void Start() override { + VLOG(3) << "Start reader"; + PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!"); + queue_->ReOpen(); + VLOG(3) << "reopen success"; + VLOG(3) << "thread_num " << thread_num_; + for (int thread_id = 0; thread_id < thread_num_; thread_id++) { + read_threads_.emplace_back(new std::thread( + std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, + thread_id, &read_thread_status_, queue_))); + } + monitor_thread_.reset(new std::thread( + std::bind(&MonitorThread, &read_thread_status_, queue_))); + status_ = ReaderStatus::kRunning; + } + + private: + void SplitFiles() { + file_groups_.resize(thread_num_); + for (size_t i = 0; i < file_list_.size(); ++i) { + auto& file_name = file_list_[i]; + std::ifstream f(file_name.c_str()); + PADDLE_ENFORCE(f.good(), "file %s does not exist!", file_name); + file_groups_[i % thread_num_].push_back(file_name); + } + } + + private: + size_t thread_num_; + const int batch_size_; + const std::vector<std::string> slots_; + const std::vector<std::string> file_list_; + std::shared_ptr<LoDTensorBlockingQueue> queue_; + std::vector<std::unique_ptr<std::thread>> read_threads_; + std::unique_ptr<std::thread> monitor_thread_; + std::vector<ReaderThreadStatus> read_thread_status_; + std::vector<std::vector<std::string>> file_groups_; +}; + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc new file mode 100644 index 0000000000..8dba9baebc --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -0,0 +1,155 @@ +// Copyright (c) 2018
PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/ctr_reader.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/reader/blocking_queue.h" + +using paddle::operators::reader::LoDTensorBlockingQueue; +using paddle::operators::reader::LoDTensorBlockingQueueHolder; +using paddle::operators::reader::CTRReader; +using paddle::framework::LoDTensor; +using paddle::framework::LoD; +using paddle::framework::DDim; +using paddle::platform::CPUPlace; +using paddle::framework::make_ddim; + +static void generatedata(const std::vector& data, + const std::string& file_name) { + std::ifstream in(file_name.c_str()); + if (in.good()) { + VLOG(3) << "file " << file_name << " exist, delete it first!"; + remove(file_name.c_str()); + } else { + in.close(); + } + + ogzstream out(file_name.c_str()); + PADDLE_ENFORCE(out.good(), "open file %s failed!", file_name); + for (auto& c : data) { + out << c; + } + out.close(); + PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name); +} + +static inline void check_all_data( + const std::vector& ctr_data, + const std::vector& slots, const std::vector& label_dims, + const std::vector& label_value, + const std::vector>>& data_slot_6002, + const std::vector>>& data_slot_6003, + size_t batch_num, size_t batch_size, + std::shared_ptr queue, CTRReader* reader) { + std::vector out; + for (size_t i = 0; i < batch_num; ++i) { + reader->ReadNext(&out); + ASSERT_EQ(out.size(), slots.size() + 1); + auto& label_tensor = out.back(); + ASSERT_EQ(label_tensor.dims(), label_dims[i]); + for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size(); + ++j) { + auto& label = label_tensor.data()[j]; + ASSERT_TRUE(label == 0 || label == 1); + ASSERT_EQ(label, label_value[i * batch_size + j]); + } + auto& tensor_6002 = out[0]; + ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod()); + ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(), + tensor_6002.data(), + tensor_6002.dims()[1] * sizeof(int64_t)), + 0); + } + reader->ReadNext(&out); + ASSERT_EQ(out.size(), 0); + ASSERT_EQ(queue->Size(), 0); +} + +TEST(CTR_READER, read_data) { + const std::vector ctr_data = { + "aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n", + "bbbb 1 0 5:6003 6:6003 7:6003 8:6004 9:6004 -1\n", + "cccc 1 1 10:6002 11:6002 12:6002 13:6002 14:6002 -2\n", + "dddd 1 0 15:6003 16:6003 17:6003 18:6003 19:6004 -3\n", + "1111 1 1 20:6001 21:6001 22:6001 23:6001 24:6001 12\n", + "2222 1 1 25:6004 26:6004 27:6004 28:6005 29:6005 aa\n", + "3333 1 0 30:6002 31:6003 32:6004 33:6004 34:6005 er\n", + "eeee 1 1 35:6003 36:6003 37:6005 38:6005 39:6005 dd\n", + "ffff 1 1 40:6002 41:6003 42:6004 43:6004 44:6005 66\n", + "gggg 1 1 46:6006 45:6006 47:6003 48:6003 49:6003 ba\n", + }; + std::string gz_file_name = "test_ctr_reader_data.gz"; + generatedata(ctr_data, gz_file_name); + + 
std::vector label_value = {0, 0, 1, 0, 1, 1, 0, 1, 1, 1}; + + std::tuple> a1({{0, 1, 2, 7}}, + {0, 0, 10, 11, 12, 13, 14}); + std::tuple> a2({{0, 1, 2, 3}}, {0, 0, 0}); + std::tuple> a3({{0, 1, 2, 3}}, {30, 0, 40}); + std::tuple> a4({{0, 1}}, {0}); + std::vector>> data_slot_6002{a1, a2, a3, + a4}; + + std::tuple> b1({{0, 1, 4, 5}}, {1, 5, 6, 7, 0}); + std::tuple> b2({{0, 4, 5, 6}}, + {15, 16, 17, 18, 0, 0}); + std::tuple> b3({{0, 1, 3, 4}}, {31, 35, 36, 41}); + std::tuple> b4({{0, 3}}, {47, 48, 49}); + std::vector>> data_slot_6003{b1, b2, b3, + b4}; + + std::vector label_dims = {{1, 3}, {1, 3}, {1, 3}, {1, 1}}; + + LoDTensorBlockingQueueHolder queue_holder; + int capacity = 64; + queue_holder.InitOnce(capacity, {}, false); + + std::shared_ptr queue = queue_holder.GetQueue(); + + int batch_size = 3; + int thread_num = 1; + std::vector slots = {"6002", "6003"}; + std::vector file_list; + for (int i = 0; i < thread_num; ++i) { + file_list.push_back(gz_file_name); + } + + CTRReader reader(queue, batch_size, thread_num, slots, file_list); + + reader.Start(); + size_t batch_num = + std::ceil(static_cast(ctr_data.size()) / batch_size) * thread_num; + check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, + data_slot_6003, batch_num, batch_size, queue, &reader); + + reader.Shutdown(); + + reader.Start(); + check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, + data_slot_6003, batch_num, batch_size, queue, &reader); + reader.Shutdown(); +} diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index ee16fc66e4..9d48557caf 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -1039,6 +1039,11 @@ HOSTDEVICE inline float16 exp(const float16& a) { return float16(::expf(static_cast(a))); } +template <> +HOSTDEVICE inline float16 erf(const float16& a) { + return float16(::erff(static_cast(a))); +} + template <> HOSTDEVICE inline float16 log(const float16& a) { return float16(::logf(static_cast(a))); diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 20bd349b94..2bb36a5d6b 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -20,12 +20,12 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #ifndef _WIN32 -const float fraction_of_gpu_memory_to_use = 0.92f; +constexpr static float fraction_of_gpu_memory_to_use = 0.92f; #else // fraction_of_gpu_memory_to_use cannot be too high on windows, // since the win32 graphic sub-system can occupy some GPU memory // which may lead to insufficient memory left for paddle -const float fraction_of_gpu_memory_to_use = 0.5f; +constexpr static float fraction_of_gpu_memory_to_use = 0.5f; #endif DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use, diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 814012e6c1..761a9815e0 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once #include +#include #include #include #include "paddle/fluid/framework/operator.h" @@ -292,5 +293,21 @@ inline mkldnn::memory::format data_format_to_memory_format( } } +inline mkldnn::memory::format StringToMKLDNNFormat(std::string* format) { + std::transform(format->begin(), format->end(), format->begin(), ::tolower); + + if (!format->compare("nchw")) { + return mkldnn::memory::format::nchw; + } else if (!format->compare("nchw16c")) { + return mkldnn::memory::format::nChw16c; + } else if (!format->compare("nchw8c")) { + return mkldnn::memory::format::nChw8c; + } else if (!format->compare("nhwc")) { + return mkldnn::memory::format::nhwc; + } else { + return mkldnn::memory::format::any; + } +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 0443ff3fc3..8735b65875 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -29,8 +29,16 @@ limitations under the License. */ namespace pybind11 { namespace detail { +#if !defined(PYBIND11_HIDDEN) +#ifdef _WIN32 +#define PYBIND11_HIDDEN __declspec(dllexport) +#else +#define PYBIND11_HIDDEN __attribute__((visibility("hidden"))) +#endif +#endif + // Can be replaced by a generic lambda in C++14 -struct __attribute__((visibility("hidden"))) paddle_variant_caster_visitor +struct PYBIND11_HIDDEN paddle_variant_caster_visitor : public boost::static_visitor { return_value_policy policy; handle parent; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a2a629acdf..e31c2f2173 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -860,6 +860,12 @@ All parameter, weight, gradient are variables in Paddle. self.remove_unnecessary_lock_ = b; }, R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC") + .def_property( + "num_trainers", + [](const BuildStrategy &self) { return self.num_trainers_; }, + [](BuildStrategy &self, int num_trainers) { + self.num_trainers_ = num_trainers; + }) .def_property( "fuse_elewise_add_act_ops", [](const BuildStrategy &self) { diff --git a/paddle/legacy/cuda/src/hl_cuda_device.cc b/paddle/legacy/cuda/src/hl_cuda_device.cc index a6e27a37ff..92197afb3d 100644 --- a/paddle/legacy/cuda/src/hl_cuda_device.cc +++ b/paddle/legacy/cuda/src/hl_cuda_device.cc @@ -137,10 +137,10 @@ inline pid_t gettid() { #define __NR_gettid 224 #endif pid_t tid = syscall(__NR_gettid); -#endif #else // _WIN32 pid_t tid = _getpid(); #endif // _WIN32 +#endif CHECK_NE((int)tid, -1); return tid; } diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index c4e283e76a..a6720fa798 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -469,18 +469,21 @@ function assert_api_spec_approvals() { BRANCH="develop" fi - API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "paddle/fluid/API.spec" || true` - echo "checking API.spec change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}" - if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then - # NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable. 
- APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ - python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433` - echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" - if [ "${APPROVALS}" == "FALSE" ]; then - echo "You must have at least 2 approvals for the api change!" - exit 1 - fi - fi + API_FILES=("paddle/fluid/API.spec" "paddle/fluid/framework/operator.h") + for API_FILE in ${API_FILES[*]}; do + API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "${API_FILE}" || true` + echo "checking ${API_FILE} change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}" + if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then + # NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable. + APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ + python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433` + echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" + if [ "${APPROVALS}" == "FALSE" ]; then + echo "You must have at least 2 approvals for the api change! ${API_FILE}" + exit 1 + fi + fi + done } diff --git a/python/paddle/fluid/contrib/reader/ctr_reader.py b/python/paddle/fluid/contrib/reader/ctr_reader.py new file mode 100644 index 0000000000..b8449e8d84 --- /dev/null +++ b/python/paddle/fluid/contrib/reader/ctr_reader.py @@ -0,0 +1,123 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid import core +from paddle.fluid.executor import global_scope +from paddle.fluid.framework import default_main_program, \ + default_startup_program, Variable +from paddle.fluid.unique_name import generate as unique_name + + +def monkey_patch_reader_methods(reader): + def __get_reader__(): + scope = global_scope() + var = scope.find_var(reader.name) + return var.get_reader() + + def reset(): + return __get_reader__().reset() + + reader.reset = reset + reader.stop_gradient = True + reader.persistable = True + return reader + + +def _copy_reader_var_(block, var): + new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) + new_var.desc.set_shapes(var.desc.shapes()) + new_var.desc.set_dtypes(var.desc.dtypes()) + new_var.persistable = True + return new_var + + +def ctr_reader(feed_data, + capacity, + thread_num, + batch_size, + file_list, + slots, + name=None): + """ + Create a CTR reader for data feeding in Python + + This layer returns a Reader Variable. + The Reader provides :code:`decorate_paddle_reader()` and + :code:`decorate_tensor_provider()` to set a Python generator as the data + source in Python side. When :code:`Executor::Run()` is invoked in C++ + side, the data from the generator would be read automatically. 
Unlike + :code:`DataFeeder.feed()`, the data reading process and + :code:`Executor::Run()` process can run in parallel using + :code:`py_reader`. The :code:`start()` method of the Reader should be + called when each pass begins, while the :code:`reset()` method should be + called when the pass ends and :code:`fluid.core.EOFException` raises. + Note that :code:`Program.clone()` method cannot clone :code:`py_reader`. + + Args: + capacity(int): The buffer capacity maintained by :code:`py_reader`. + thread_num(list|tuple): List of tuples which declaring data shapes. + batch_size(list|tuple): List of strs which declaring data type. + file_list(list|tuple): List of ints which declaring data lod_level. + slots(bool): Whether use double buffer or not. + name(basestring): The prefix Python queue name and Reader name. None will + be generated automatically. + + Returns: + Variable: A Reader from which we can get feeding data. + + Examples: + + 1. The basic usage of :code:`py_reader` is as follows: + """ + if name is None: + queue_name = unique_name('lod_tensor_blocking_queue') + reader_name = unique_name('create_ctr_reader') + else: + queue_name = "_".join([name, "queue"]) + reader_name = "_".join([name, "reader"]) + + var = global_scope().var(queue_name) + feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) + + startup_blk = default_startup_program().current_block() + reader_var = startup_blk.create_var(name=reader_name) + startup_blk.append_op( + type='create_ctr_reader', + inputs={'blocking_queue': [queue_name]}, + outputs={'Out': [reader_var]}, + attrs={ + 'thread_num': thread_num, + 'batch_size': batch_size, + 'file_list': file_list, + 'slots': slots, + }) + + reader_var.persistable = True + + main_prog_reader_var = _copy_reader_var_( + default_main_program().current_block(), reader_var) + + reader = monkey_patch_reader_methods(main_prog_reader_var) + + # monkey patch py_reader special methods + reader.queue = feed_queue + reader.exited = False + + main_blk = default_main_program().current_block() + main_blk.append_op( + type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data}) + + return reader diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 26d7af87b3..0782933c6c 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -637,8 +637,8 @@ def save_inference_model(dirname, if isinstance(target_vars, Variable): target_vars = [target_vars] elif export_for_deployment: - if not (bool(target_vars) and all( - isinstance(var, Variable) for var in target_vars)): + if not (bool(target_vars) and + all(isinstance(var, Variable) for var in target_vars)): raise ValueError("'target_vars' should be a list of Variable.") if main_program is None: @@ -667,10 +667,15 @@ def save_inference_model(dirname, if export_for_deployment: main_program = main_program.clone() global_block = main_program.global_block() + need_to_remove_op_index = [] for i, op in enumerate(global_block.ops): op.desc.set_is_target(False) if op.type == "feed" or op.type == "fetch": - global_block._remove_op(i) + need_to_remove_op_index.append(i) + + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + main_program.desc.flush() main_program = main_program._prune(targets=target_vars) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6b5a55a662..2051a1ea01 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4398,7 +4398,8 @@ def nce(input, name=None, sampler="uniform", 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 6b5a55a662..2051a1ea01 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -4398,7 +4398,8 @@ def nce(input,
         name=None,
         sampler="uniform",
         custom_dist=None,
-        seed=0):
+        seed=0,
+        is_sparse=False):
     """
     ${comment}
 
@@ -4424,11 +4425,12 @@ def nce(input,
         sampler (str): The sampler used to sample class from negtive classes.
                        It can be 'uniform', 'log_uniform' or 'custom_dist'.
                        default: 'uniform'.
-        custom_dist (Variable): A tensor with shape [num_total_classes].
+        custom_dist (list|tuple): A float list or tuple whose length is num_total_classes.
                        It is used when sampler is set to 'custom_dist'.
                        custom_dist[i] is the probsbility of i-th class to be sampled.
                        default: None.
         seed (int): The seed used in sampler. default: 0.
+        is_sparse(bool): Whether to use sparse update; if True, weight@GRAD and bias@GRAD are changed to SelectedRows. default: False.
 
     Returns:
         Variable: The output nce loss.
@@ -4480,12 +4482,7 @@ def nce(input,
         shape=[num_total_classes, dim],
         is_bias=False,
         dtype=input.dtype)
-    inputs = {
-        'Input': input,
-        'Label': label,
-        'Weight': w,
-        'SampleWeight': sample_weight if sample_weight is not None else []
-    }
+    inputs = {}
     if helper.bias_attr:
         b = helper.create_parameter(
             attr=helper.bias_attr,
@@ -4497,18 +4494,10 @@ def nce(input,
     sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
     sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
 
-    if num_neg_samples is None:
-        num_neg_samples = 10
-    else:
-        num_neg_samples = int(num_neg_samples)
-
-    inputs = {
-        'Input': input,
-        'Label': label,
-        'Weight': w,
-        'Bias': b,
-        'SampleWeight': sample_weight if sample_weight is not None else []
-    }
+    inputs['Input'] = input
+    inputs['Label'] = label
+    inputs['Weight'] = w
+    inputs['SampleWeight'] = sample_weight if sample_weight is not None else []
 
     if sampler == "uniform":
         sampler = 0
@@ -4516,17 +4505,73 @@ def nce(input,
         sampler = 1
     elif sampler == "custom_dist":
        assert custom_dist is not None
-        assert isinstance(custom_dist, Variable)
-        inputs['CustomDistribution'] = custom_dist
+        # Build an alias table (Walker's method) from the raw distribution so
+        # the sampler can draw a class in O(1): scale every probability by
+        # the class count, then pair off "big" entries (> 1) with "little"
+        # entries (< 1) until every slot holds at most two classes.
+        custom_dist_len = len(custom_dist)
+        alias_probs_ = [0] * custom_dist_len
+        alias_ = [0] * custom_dist_len
+        bigs = []
+        littles = []
+        for i in range(custom_dist_len):
+            normal_prob = custom_dist[i] * custom_dist_len
+            if normal_prob - 1.0 > 1e-4:
+                bigs.append((i, normal_prob))
+            elif 1.0 - normal_prob > 1e-4:
+                littles.append((i, normal_prob))
+            else:
+                alias_probs_[i] = normal_prob
+                alias_[i] = -1
+
+        while len(bigs) and len(littles):
+            big = bigs.pop(0)
+            little = littles.pop(0)
+
+            big_idx = big[0]
+
+            # the little slot keeps its own mass and borrows the rest from
+            # the big class; whatever is left of the big class is re-binned
+            alias_probs_[little[0]] = little[1]
+            alias_[little[0]] = big_idx
+            big_left = big[1] + little[1] - 1
+            if big_left - 1.0 > 1e-4:
+                bigs.append((big_idx, big_left))
+            elif 1.0 - big_left > 1e-4:
+                littles.append((big_idx, big_left))
+            else:
+                alias_probs_[big_idx] = big_left
+                alias_[big_idx] = -1
+
+        if len(bigs):
+            big = bigs.pop(0)
+            alias_probs_[big[0]] = 1.0
+            alias_[big[0]] = -1
+        if len(littles):
+            little = littles.pop(0)
+            alias_probs_[little[0]] = 1.0
+            alias_[little[0]] = -1
+
+        probs = assign(input=np.array(custom_dist).astype('float32'))
+        custom_alias = assign(input=np.array(alias_).astype('int32'))
+        custom_alias_probs = assign(
+            input=np.array(alias_probs_).astype('float32'))
+
+        inputs['CustomDistProbs'] = probs
+        inputs['CustomDistAlias'] = custom_alias
+        inputs['CustomDistAliasProbs'] = custom_alias_probs
         sampler = 2
     else:
         raise Exception("Unsupported sampler type.")
 
+    if num_neg_samples is None:
+        num_neg_samples = 10
+    else:
+        num_neg_samples = int(num_neg_samples)
+
     attrs = {
         'num_total_classes': int(num_total_classes),
         'num_neg_samples': num_neg_samples,
         'seed': seed,
-        'sampler': sampler
+        'sampler': sampler,
+        'is_sparse': is_sparse
     }
 
     helper.append_op(
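
The alias construction above can be sanity-checked in isolation. The helper below is a standalone sketch of the same pairing loop (the names are mine, not Paddle API); it verifies the alias-table invariant that slot i contributes alias_probs[i]/n to class i and the remaining (1 - alias_probs[i])/n to class alias[i], recovering the original distribution exactly:

    import numpy as np

    def build_alias_table(dist):
        # same pairing loop as in nce() above, on plain Python lists
        n = len(dist)
        alias_probs = [0.0] * n
        alias = [0] * n
        bigs, littles = [], []
        for i, p in enumerate(dist):
            q = p * n
            if q - 1.0 > 1e-4:
                bigs.append((i, q))
            elif 1.0 - q > 1e-4:
                littles.append((i, q))
            else:
                alias_probs[i], alias[i] = q, -1
        while bigs and littles:
            (bi, bq), (li, lq) = bigs.pop(0), littles.pop(0)
            alias_probs[li], alias[li] = lq, bi
            left = bq + lq - 1.0
            if left - 1.0 > 1e-4:
                bigs.append((bi, left))
            elif 1.0 - left > 1e-4:
                littles.append((bi, left))
            else:
                alias_probs[bi], alias[bi] = left, -1
        for idx, _ in bigs + littles:  # at most one leftover survives
            alias_probs[idx], alias[idx] = 1.0, -1
        return alias_probs, alias

    dist = [0.1, 0.2, 0.3, 0.4]
    alias_probs, alias = build_alias_table(dist)
    # -> alias_probs is approximately [0.4, 0.8, 0.6, 1.0], alias == [2, 3, 3, -1]

    n = len(dist)
    recovered = np.zeros(n)
    for i in range(n):
        recovered[i] += alias_probs[i] / n
        if alias[i] >= 0:
            recovered[alias[i]] += (1.0 - alias_probs[i]) / n
    assert np.allclose(recovered, dist, atol=1e-3)
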
@@ -4546,27 +4591,43 @@ def hsigmoid(input,
              num_classes,
              param_attr=None,
              bias_attr=None,
-             name=None):
+             name=None,
+             path_table=None,
+             path_code=None,
+             is_custom=False,
+             is_sparse=False):
     """
     The hierarchical sigmoid operator is used to accelerate the training
     process of language model. This operator organizes the classes into a
-    complete binary tree, each leaf node represents a class(a word) and each
+    complete binary tree, or, with is_custom set, into a tree you supply
+    yourself. Each leaf node represents a class(a word) and each
     internal node acts as a binary classifier. For each word there's a unique
     path from root to it's leaf node, hsigmoid calculate the cost for each
     internal node on the path, and sum them to get a total cost. hsigmoid can
     achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
     represents the size of word dict.
 
-    Refer to `Hierarchical Probabilistic Neural Network Language Model
+    For the default tree, refer to `Hierarchical Probabilistic Neural Network Language Model
     <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
+    To use a custom tree instead, set 'is_custom' to True and prepare the
+    following first:
+    1. Using your word dict, build a binary tree where each leaf node is a
+       word in the dict.
+    2. Build a map from word_id to the word's path from leaf to root; call
+       it path_table.
+    3. Build a map from word_id to the code along that path; call it
+       path_code. The code is the label of each binary classification step
+       on the path: 1 indicates true, 0 indicates false.
+    4. Now every word has a path and a code along the path, and you can pass
+       a batch of paths and codes that corresponds to the same batch of
+       inputs, as in the sketch below.
+
+
     Args:
         input (Variable): The input tensor variable with shape
             :math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
            and :math:`D` is the feature size.
        label (Variable): The tensor variable contains labels of training data.
            It's a tensor with shape is :math:`[N \\times 1]`.
-        num_classes: (int), The number of classes, must not be less than 2.
+        num_classes: (int), The number of classes. With the default tree it
+            must be set, must not be less than 2, and must never be None when
+            is_custom is False. When is_custom is True, it should be the
+            number of non-leaf nodes, i.e. the number of classes used by the
+            binary classifiers.
         param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
              of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
              will create ParamAttr as param_attr. If the Initializer of the param_attr
@@ -4578,9 +4639,19 @@ def hsigmoid(input,
             is not set, the bias is initialized zero. Default: None.
         name (str|None): A name for this layer(optional). If set None, the layer
              will be named automatically. Default: None.
+        path_table: (Variable|None) Stores each batch of samples' path from
+            leaf to root. path_table should have the same shape as path_code,
+            and for each sample i, path_table[i] is an array whose elements
+            index rows of the non-leaf nodes' weight matrix.
+        path_code: (Variable|None) Stores each batch of samples' code along
+            its path, one binary label per non-leaf node, in leaf -> root
+            order.
+        is_custom: (bool|False) Whether to use a user-defined binary tree
+            instead of the default complete binary tree. If set, path_table,
+            path_code and num_classes must all be provided.
+        is_sparse: (bool|False) Whether to use sparse update instead of dense
+            update; if set, the gradients of W and of the input will be sparse.
 
     Returns:
-        Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
+        Out: (LoDTensor) The cost of hierarchical sigmoid operator. The shape is [N, 1].
 
     Examples:
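
The four steps are easiest to see on a toy vocabulary. A hedged sketch in plain Python follows; the node numbering and code convention are invented for illustration, rows are written root-first as in the unit tests below, and the -1 padding matches the sentinel those tests use:

    # Toy tree over 4 words: internal nodes are numbered 0 (root), 1, 2;
    # word0/word1 hang under node 1, word2/word3 under node 2.
    #
    #            0
    #          /   \
    #         1     2
    #        / \   / \
    #      w0  w1 w2  w3
    #
    # code: 0 = left child, 1 = right child (one binary label per node)
    tree_path = {          # word_id -> indexes of internal nodes on its path
        0: [0, 1],
        1: [0, 1],
        2: [0, 2],
        3: [0, 2],
    }
    tree_code = {          # word_id -> binary decision taken at each node
        0: [0, 0],
        1: [0, 1],
        2: [1, 0],
        3: [1, 1],
    }

    max_len = 3  # pad every row to a fixed length with -1 sentinels

    def row(seq, pad=-1):
        return seq + [pad] * (max_len - len(seq))

    batch_word_ids = [2, 0, 3]
    path_table = [row(tree_path[w]) for w in batch_word_ids]
    path_code = [row(tree_code[w]) for w in batch_word_ids]
    # path_table == [[0, 2, -1], [0, 1, -1], [0, 2, -1]]
    # path_code  == [[1, 0, -1], [0, 0, -1], [1, 1, -1]]
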
@@ -4596,27 +4667,62 @@ def hsigmoid(input,
     out = helper.create_variable_for_type_inference(dtype)
     pre_out = helper.create_variable_for_type_inference(dtype)
     dim = input.shape[1]
-    if num_classes < 2:
-        raise ValueError("num_classes must not be less than 2.")
-    weights = helper.create_parameter(
-        attr=helper.param_attr,
-        shape=[num_classes - 1, dim],
-        is_bias=False,
-        dtype=input.dtype)
-    inputs = {"X": input, "W": weights, "Label": label}
-    if helper.bias_attr:
-        bias = helper.create_parameter(
-            attr=helper.bias_attr,
-            shape=[1, num_classes - 1],
-            is_bias=True,
+    if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
+        raise ValueError(
+            "num_classes must not be less than 2 with default tree")
+
+    if is_custom and path_code is None:
+        raise ValueError("path_code should not be None with custom tree")
+    elif is_custom and path_table is None:
+        raise ValueError("path_table should not be None with custom tree")
+    elif is_custom and num_classes is None:
+        raise ValueError("num_classes should not be None with custom tree")
+
+    if not is_custom:
+        weights = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=[num_classes - 1, dim],
+            is_bias=False,
             dtype=input.dtype)
-        inputs['Bias'] = bias
+    else:
+        weights = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=[num_classes, dim],
+            is_bias=False,
+            dtype=input.dtype)
+    inputs = {
+        "X": input,
+        "W": weights,
+        "PTable": path_table,
+        "PathCode": path_code,
+        "Label": label
+    }
+    if helper.bias_attr:
+        if not is_custom:
+            bias = helper.create_parameter(
+                attr=helper.bias_attr,
+                shape=[num_classes - 1, 1],
+                is_bias=True,
+                dtype=input.dtype)
+            inputs['Bias'] = bias
+        else:
+            bias = helper.create_parameter(
+                attr=helper.bias_attr,
+                shape=[num_classes, 1],
+                is_bias=True,
+                dtype=input.dtype)
+            inputs['Bias'] = bias
     helper.append_op(
         type="hierarchical_sigmoid",
         inputs=inputs,
         outputs={"Out": out,
                  "PreOut": pre_out},
-        attrs={"num_classes": num_classes})
+        attrs={"num_classes": num_classes,
+               "is_sparse": is_sparse})
     return out
 
 
@@ -6478,7 +6584,7 @@ def crop(x, shape=None, offsets=None, name=None):
     helper = LayerHelper('crop', **locals())
 
     if not (isinstance(shape, list) or isinstance(shape, tuple) or \
-            isinstance(shape, Variable)):
+        isinstance(shape, Variable)):
         raise ValueError("The shape should be a list, tuple or Variable.")
 
     if offsets is None:
@@ -6600,7 +6706,7 @@ def affine_grid(theta, out_shape, name=None):
     helper = LayerHelper('affine_grid')
 
     if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
-            isinstance(out_shape, Variable)):
+        isinstance(out_shape, Variable)):
         raise ValueError("The out_shape should be a list, tuple or Variable.")
 
     if not isinstance(theta, Variable):
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 3f4dd5eb71..bdcd045341 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -124,16 +124,11 @@ class ParallelExecutor(object):
             os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
         exec_strategy.num_threads = cpu_num * 2
 
-        # Set 1 thread num under nccl2 distribute
-        #   env to make sure all gpus run ops in same order.
-        if num_trainers > 1:
-            assert (use_cuda)
-            # FIXME(gongwb): avoid this set.
-            exec_strategy.num_threads = 1
-
         if build_strategy is None:
             build_strategy = BuildStrategy()
 
+        build_strategy.num_trainers = num_trainers
+
         main = main_program
         main = main if main else framework.default_main_program()
         if scope == None:
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 1006cc568a..26035f303e 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -63,7 +63,7 @@ function(py_test_modules TARGET_NAME)
     set(multiValueArgs MODULES DEPS ENVS)
     cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
     add_test(NAME ${TARGET_NAME}
-          COMMAND env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
+          COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
           ${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
           WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
   if (py_test_modules_SERIAL)
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index ad7591417e..55c43ef115 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 import paddle.fluid.core as core
 from op_test import OpTest
-from scipy.special import expit
+from scipy.special import expit, erf
 
 
 class TestActivation(OpTest):
@@ -295,6 +295,23 @@ class TestRelu(TestActivation):
         self.check_grad(['X'], 'Out', max_relative_error=0.007)
 
 
+class TestGelu(TestActivation):
+    def setUp(self):
+        self.op_type = "gelu"
+        self.init_dtype()
+
+        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+        out = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
+
+
 class TestBRelu(TestActivation):
     def setUp(self):
         self.op_type = "brelu"
@@ -628,6 +645,7 @@ create_test_act_fp16_class(TestCos, grad_atol=0.85)
 create_test_act_fp16_class(TestSin)
 create_test_act_fp16_class(TestRound, grad_check=False)
 create_test_act_fp16_class(TestRelu)
+create_test_act_fp16_class(TestGelu)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
 create_test_act_fp16_class(TestSoftRelu)
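
The reference formula in TestGelu is the exact, erf-based GELU, gelu(x) = x * Phi(x) with Phi the standard normal CDF. A quick numeric spot check of that reference (plain NumPy/SciPy, independent of the test class):

    import numpy as np
    from scipy.special import erf

    def gelu_ref(x):
        # exact (erf-based) GELU, matching the test's reference expression
        return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

    # gelu(0) = 0, gelu is close to the identity for large positive x,
    # and small negative inputs are only partially suppressed
    print(gelu_ref(np.array([0.0, 1.0, -1.0, 5.0])))
    # -> [ 0.          0.84134475 -0.15865525  4.99999857]
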
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
index 6948ae3002..2a6c93f75f 100644
--- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -16,6 +16,8 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+import paddle.fluid.core as core
+import paddle.fluid as fluid
 import math
 from op_test import OpTest
@@ -40,6 +42,29 @@ class CodeTable(object):
         return self.c & (1 << bit)
 
 
+class CodeTableWithCustomTree(object):
+    def __init__(self, path_table, path_code, index):
+        self.ptable_ = path_table
+        self.pcode_ = path_code
+        self.index_ = index
+
+    def cal_index(self, bit):
+        return self.ptable_[self.index_][bit]
+
+    def get_length(self):
+        length = 0
+        for ele in self.ptable_[self.index_]:
+            # the path ends at the first -1 sentinel
+            if ele >= 0:
+                length = length + 1
+            else:
+                return length
+        return length
+
+    def cal_bit(self, bit):
+        return self.pcode_[self.index_][bit]
+
+
 def hsigmoid(x, w, label, bias, num_classes):
     batch_size = x.shape[0]
     code_length = find_latest_set(num_classes - 1)
@@ -52,7 +77,7 @@ def hsigmoid(x, w, label, bias, num_classes):
         length = code_table.get_length()
         for j in range(length):
             idx = code_table.cal_index(j)
-            pre_output[i][j] += bias[0][idx]
+            pre_output[i][j] += bias[idx][0]
     for i in range(batch_size):
         code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
@@ -77,17 +102,58 @@ def hsigmoid(x, w, label, bias, num_classes):
     return pre_output, out
 
 
+def hsigmoidWithCustomTree(x, w, path_table, path_code, label, bias,
+                           num_classes):
+    batch_size = x.shape[0]
+    code_length = len(path_table[0])
+    # init pre_out with shape [N, code_length]
+    pre_output = np.zeros((batch_size, code_length))
+    pre_sum = np.zeros((batch_size, 1))
+    out = np.zeros((batch_size, 1)).astype("float32")
+    if isinstance(bias, np.ndarray):
+        for i in range(batch_size):
+            code_table = CodeTableWithCustomTree(path_table, path_code, i)
+            length = code_table.get_length()
+            for j in range(length):
+                idx = code_table.cal_index(j)
+                pre_output[i][j] += bias[idx][0]
+    for i in range(batch_size):
+        code_table = CodeTableWithCustomTree(path_table, path_code, i)
+        length = code_table.get_length()
+        for j in range(length):
+            idx = code_table.cal_index(j)
+            pre_output[i][j] += np.dot(w[idx], x[i])
+    # clip to [-40.0, 40.0] for numerical stability
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
+    for i in range(batch_size):
+        code_table = CodeTableWithCustomTree(path_table, path_code, i)
+        length = code_table.get_length()
+        code_sum = 0.0
+        for j in range(length):
+            if code_table.cal_bit(j):
+                code_sum += pre_output[i][j]
+        out[i] = -1.0 * code_sum
+    # soft relu
+    pre_output = np.log(1 + np.exp(pre_output))
+    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
+    out += pre_sum
+    return pre_output, out
+
+
 class TestHSigmoidOp(OpTest):
     def setUp(self):
         self.op_type = "hierarchical_sigmoid"
         num_classes = 6
         feature_size = 8
         batch_size = 4
-        x = np.random.random((batch_size, feature_size)).astype("float32")
-        w = np.random.random((num_classes - 1, feature_size)).astype("float32")
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+        w = np.random.random(
+            (num_classes - 1, feature_size)).astype("float32") * 2
         label = np.random.randint(0, num_classes, (batch_size, 1))
-        bias = np.random.random((1, num_classes - 1)).astype("float32")
-        self.attrs = {'num_classes': num_classes}
+        bias = np.random.random((num_classes - 1, 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
         self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
         pre_output, out = hsigmoid(x, w, label, bias, num_classes)
         self.outputs = {'PreOut': pre_output, 'Out': out}
@@ -99,5 +165,185 @@ class TestHSigmoidOp(OpTest):
         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
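
A quick sketch of how the CodeTableWithCustomTree helper above walks one padded row; the row values are taken from the path_table/path_code arrays used in the tests below:

    import numpy as np

    path_table = np.array([(0, 2, -1, -1, -1)])  # one sample, padded with -1
    path_code = np.array([(0, 1, -1, -1, -1)])

    ct = CodeTableWithCustomTree(path_table, path_code, 0)
    assert ct.get_length() == 2      # stops at the first -1 sentinel
    assert ct.cal_index(1) == 2      # second non-leaf node on the path
    assert ct.cal_bit(1) == 1        # binary label at that node
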
+
+
+class TestHSigmoidOpSparse(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6  # build a Huffman tree over classes 1..6; 1, 2, 5 and 6 appear as samples
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32")
+        w = np.random.random((num_classes - 1, feature_size)).astype("float32")
+        label = np.array([0, 1, 4, 5])
+        path_table = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
+                               (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)
+                               ])  # each sample's non-leaf path (root -> leaf), padded with -1
+        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1),
+                              (1, 0, 0, -1, -1), (0, 1, -1, -1, -1)
+                              ])  # the binary code along each stored path
+        bias = np.random.random((num_classes - 1, 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes, 'is_sparse': True}
+        self.inputs = {
+            'X': x,
+            'W': w,
+            'PTable': path_table,
+            'PathCode': path_code,
+            'Label': label,
+            'Bias': bias
+        }
+        pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
+                                                 label, bias, num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
+    def hs_net_conf(self, is_sparse):
+        input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
+        path_table = fluid.layers.data(
+            name='path_table', shape=[3], dtype='int64')
+        path_code = fluid.layers.data(
+            name='path_code', shape=[3], dtype='int64')
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+        data_list = [input_word, path_table, path_code, label]
+
+        emb = fluid.layers.embedding(
+            input=input_word,
+            is_sparse=is_sparse,
+            size=[3, 3],
+            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                scale=1 / math.sqrt(3))))
+
+        cost = fluid.layers.hsigmoid(
+            input=emb,
+            label=label,
+            bias_attr=True,
+            num_classes=3,
+            path_table=path_table,
+            path_code=path_code,
+            is_custom=True,
+            is_sparse=is_sparse)
+
+        avg_cost = fluid.layers.reduce_mean(cost)
+
+        return avg_cost, data_list
+
+    def training_test(self, is_sparse):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            start_up = fluid.default_startup_program()
+            start_up.random_seed = 1  # fix the random seed so both runs match
+            x = np.arange(6).reshape(6)
+            path_table = np.array([(1, 2, -1), (1, 2, -1)])
+            path_code = np.array([(1, 0, -1), (0, 0, -1)])
+            label = np.array([1, 4])
+
+            loss, data_list = self.hs_net_conf(is_sparse)
+            optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
+            optimizer.minimize(loss)
+
+            main_program = fluid.default_main_program()
+            place = fluid.CPUPlace()
+            feeder = fluid.DataFeeder(feed_list=data_list, place=place)
+            exe = fluid.Executor(place)
+
+            exe.run(start_up)
+            result = list()
+            for i in range(10):
+                data = [([[x[i % 2]]], [list(path_table[i % 2])],
+                         [list(path_code[i % 2])], [label[i % 2]])]
+
+                loss_val = exe.run(main_program,
+                                   feed=feeder.feed(data),
+                                   fetch_list=[loss])
+                result.append(loss_val)
+        return result
+
+    def test_hs_grad_with_sparse(self):
+        dense_result = self.training_test(is_sparse=False)
+        sparse_result = self.training_test(is_sparse=True)
+        assert (dense_result == sparse_result)
+
+
+class TestHSigmoidOpWithCustomTree(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6  # build a Huffman tree over classes 1..6; 1, 2, 5 and 6 appear as samples
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+        w = np.random.random(
+            (num_classes - 1, feature_size)).astype("float32") * 2
+        label = np.array([0, 1, 4, 5])
+        path_table = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
+                               (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)
+                               ])  # each sample's non-leaf path (root -> leaf), padded with -1
+        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1),
+                              (1, 0, 0, -1, -1), (0, 1, -1, -1, -1)
+                              ])  # the binary code along each stored path
+        bias = np.random.random((num_classes - 1, 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
+        self.inputs = {
+            'X': x,
+            'W': w,
+            'PTable': path_table,
+            'PathCode': path_code,
+            'Label': label,
+            'Bias': bias
+        }
+        pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
+                                                 label, bias, num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
+
+
+class TestHSigmoidOpWithCustomTreeWithoutBias(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6  # build a Huffman tree over classes 1..6; 1, 2, 5 and 6 appear as samples
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+        w = np.random.random(
+            (num_classes - 1, feature_size)).astype("float32") * 2
+        label = np.array([0, 1, 4, 5])
+        path_table = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
+                               (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)
+                               ])  # each sample's non-leaf path (root -> leaf), padded with -1
+        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1),
+                              (1, 0, 0, -1, -1), (0, 1, -1, -1, -1)
+                              ])  # the binary code along each stored path
+        # the Bias input is omitted on purpose in this case
+        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
+        self.inputs = {
+            'X': x,
+            'W': w,
+            'PTable': path_table,
+            'PathCode': path_code,
+            'Label': label,
+        }
+        pre_output, out = hsigmoidWithCustomTree(
+            x=x,
+            w=w,
+            path_table=path_table,
+            path_code=path_code,
+            label=label,
+            bias=None,
+            num_classes=num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label'))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 559c9cda48..5411607711 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -185,6 +185,25 @@ class TestBook(unittest.TestCase):
                 input=x, label=y, num_classes=2))
         print(str(program))
 
+        # test hsigmoid with a custom tree structure
+        program2 = Program()
+        with program_guard(program2):
+            x2 = layers.data(name='x2', shape=[4, 8], dtype='float32')
+            y2 = layers.data(name='y2', shape=[4], dtype='int64')
+            path_table = layers.data(
+                name='path_table', shape=[4, 6], dtype='int64')
+            path_code = layers.data(
+                name='path_code', shape=[4, 6], dtype='int64')
+            self.assertIsNotNone(
+                layers.hsigmoid(
+                    input=x2,
+                    label=y2,
+                    num_classes=6,
+                    path_table=path_table,
+                    path_code=path_code,
+                    is_custom=True))
+        print(str(program2))
+
     def test_sequence_expand(self):
         program = Program()
         with program_guard(program):
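
Since custom_dist is now a plain float list rather than a Variable, wiring up nce with a custom sampling distribution is straightforward. A hedged sketch (layer names and sizes are invented; the pattern mirrors the test below):

    import numpy as np
    import paddle.fluid as fluid

    num_classes = 20
    words = fluid.layers.data(name='words', shape=[10], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # any normalized frequency table works as the custom distribution
    freqs = np.random.dirichlet(np.ones(num_classes)).astype('float32')

    cost = fluid.layers.nce(input=words,
                            label=label,
                            num_total_classes=num_classes,
                            sampler='custom_dist',
                            custom_dist=freqs.tolist(),
                            num_neg_samples=5,
                            is_sparse=True)  # gradients become SelectedRows
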
diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py
index c01fdd5ddd..f4f9744674 100644
--- a/python/paddle/fluid/tests/unittests/test_nce.py
+++ b/python/paddle/fluid/tests/unittests/test_nce.py
@@ -14,8 +14,12 @@
 
 from __future__ import print_function
 
-import unittest
 import numpy as np
+import unittest
+
+import paddle.fluid as fluid
+import paddle.fluid.initializer as initializer
+
 from op_test import OpTest
 
 
@@ -59,7 +63,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
 
 class TestNCE(OpTest):
     def generate_data(self, dim, batch_size, num_classes, num_true_class,
-                      num_neg_samples):
+                      num_neg_samples, is_sparse):
         input = np.random.randn(batch_size, dim).astype(np.float32)
         weight = np.random.randn(num_classes, dim).astype(np.float32)
         bias = np.random.randn(num_classes).astype(np.float32)
@@ -70,7 +74,8 @@ class TestNCE(OpTest):
             'num_neg_samples': num_neg_samples,
             'custom_neg_classes': list(range(num_neg_samples)),
             'seed': 0,
-            'sampler': 0
+            'sampler': 0,
+            'is_sparse': is_sparse
         }
         self.inputs = {
             'Input': input,
@@ -81,7 +86,7 @@ class TestNCE(OpTest):
         }
 
     def set_data(self):
-        self.generate_data(5, 5, 4, 1, 2)
+        self.generate_data(5, 5, 4, 1, 2, False)
 
     def compute(self):
         out = nce(self.inputs['Input'], self.inputs['Weight'],
@@ -107,9 +112,110 @@ class TestNCE(OpTest):
             ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
 
 
-class TestNCECase1(TestNCE):
+class TestNCECase1Tensor(TestNCE):
     def set_data(self):
-        self.generate_data(10, 20, 10, 2, 5)
+        self.generate_data(10, 20, 10, 2, 5, False)
+
+
+class TestNCECase1SelectedRows(unittest.TestCase):
+    def setUp(self):
+        self.base_lr = 0.0001
+        self.batch_size = 8
+
+    @staticmethod
+    def get_place():
+        place = fluid.core.CPUPlace()
+        return place
+
+    @staticmethod
+    def get_train_data(batch_size):
+        # build batch_size batches of (input, label) pairs
+        batches = []
+        for i in range(batch_size):
+            input = np.random.randn(batch_size, 10).astype(np.float32)
+            labels = np.random.randint(0, 20, (batch_size, 1))
+            batches.append([input, labels])
+        return batches
+
+    def get_optimizer(self):
+        # SGD optimizer
+        optimizer = fluid.optimizer.SGD(learning_rate=self.base_lr)
+        return optimizer
+
+    def train_network(self, num_total_classes, num_neg_samples, sampler,
+                      custom_dist, is_sparse):
+        input = fluid.layers.data(name="input", shape=[10], dtype="float32")
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+        w_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 10],
+            dtype='float32',
+            name='nce_w',
+            initializer=initializer.ConstantInitializer())
+        b_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 1],
+            dtype='float32',
+            name='nce_b',
+            initializer=initializer.ConstantInitializer())
+
+        cost = fluid.layers.nce(input=input,
+                                label=label,
+                                num_total_classes=num_total_classes,
+                                sampler=sampler,
+                                custom_dist=custom_dist,
+                                sample_weight=None,
+                                param_attr='nce_w',
+                                bias_attr='nce_b',
+                                seed=1,
+                                num_neg_samples=num_neg_samples,
+                                is_sparse=is_sparse)
+        avg_cost = fluid.layers.mean(cost)
+        # optimizer
+        optimizer = self.get_optimizer()
+        optimizer.minimize(avg_cost)
+
+        return [avg_cost, [input, label]]
+
+    def test_input_is_selected_rows(self):
+        place = self.get_place()
+        exe = fluid.Executor(place)
+
+        data = self.get_train_data(self.batch_size)
+        nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
+
+        rets = []
+        # for dense
+        dense_scope = fluid.core.Scope()
+        dense_startup_program = fluid.framework.Program()
+        dense_train_program = fluid.framework.Program()
+        with fluid.scope_guard(dense_scope):
+            with fluid.program_guard(dense_train_program,
+                                     dense_startup_program):
+                cost, feeds = self.train_network(20, 5, "custom_dist",
+                                                 nid_freq_arr.tolist(), False)
+                feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+                exe.run(dense_startup_program)
+                loss_val = exe.run(dense_train_program,
+                                   feed=feeder.feed(data),
+                                   fetch_list=[cost.name])
+                rets.append(np.mean(loss_val))
+
+        # for sparse
+        sparse_scope = fluid.core.Scope()
+        sparse_startup_program = fluid.framework.Program()
+        sparse_train_program = fluid.framework.Program()
+        with fluid.scope_guard(sparse_scope):
+            with fluid.program_guard(sparse_train_program,
+                                     sparse_startup_program):
+                cost, feeds = self.train_network(20, 5, "custom_dist",
+                                                 nid_freq_arr.tolist(), True)
+                feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+                exe.run(sparse_startup_program)
+                loss_val = exe.run(sparse_train_program,
+                                   feed=feeder.feed(data),
+                                   fetch_list=[cost.name])
+                rets.append(np.mean(loss_val))
+
+        # the sparse update must reproduce the dense loss exactly
+        self.assertEqual(rets[0], rets[1])
 
 
 if __name__ == '__main__':