From 47ea2534fb9cac31f1b5c15c54112e6105810cb1 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Thu, 13 Dec 2018 17:12:38 +0800
Subject: [PATCH 01/29] clean parallel do

test=develop
---
 .../operators/controlflow/parallel_do_op.cc   | 426 ------------------
 python/paddle/fluid/backward.py               |  79 +---
 python/paddle/fluid/framework.py              |   4 +-
 python/paddle/fluid/layers/control_flow.py    | 152 +------
 .../tests/book/notest_understand_sentiment.py |  18 +-
 .../fluid/tests/book/test_recognize_digits.py |  15 +-
 .../paddle/fluid/tests/book/test_word2vec.py  |  14 +-
 .../test_memopt_fit_a_line.py                 |  87 ----
 .../fluid/tests/unittests/test_parallel_op.py | 235 ----------
 .../memory_optimization_transpiler.py         |   5 +-
 10 files changed, 10 insertions(+), 1025 deletions(-)
 delete mode 100644 paddle/fluid/operators/controlflow/parallel_do_op.cc
 delete mode 100644 python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py
 delete mode 100644 python/paddle/fluid/tests/unittests/test_parallel_op.py

diff --git a/paddle/fluid/operators/controlflow/parallel_do_op.cc b/paddle/fluid/operators/controlflow/parallel_do_op.cc
deleted file mode 100644
index ab25628d45..0000000000
--- a/paddle/fluid/operators/controlflow/parallel_do_op.cc
+++ /dev/null
@@ -1,426 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. 
*/ - -#include - -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/detail/safe_ref.h" - -namespace paddle { -namespace operators { - -static constexpr char kInputs[] = "inputs"; -static constexpr char kParameters[] = "parameters"; -static constexpr char kPlaces[] = "places"; - -static constexpr char kOutputs[] = "outputs"; -static constexpr char kParallelScopes[] = "parallel_scopes"; - -static constexpr char kParallelBlock[] = "sub_block"; -static constexpr char kUseNCCL[] = "use_nccl"; - -using LoDTensor = framework::LoDTensor; -using SelectedRows = framework::SelectedRows; - -static void SplitTensorAndMoveTensorToScopes( - const framework::Scope &scope, std::vector *sub_scopes, - const std::vector &places, - const std::vector &names) { - size_t num_sub_scopes = 0; - for (auto &argu : names) { - const auto &tensor = - detail::Ref(scope.FindVar(argu), - "Cannot find variable %s in the parent scope", argu) - .Get(); - auto lod_tensors = tensor.SplitLoDTensor(places); - - for (auto &lod : lod_tensors) { - VLOG(3) << lod.dims(); - } - if (num_sub_scopes == 0) { - num_sub_scopes = lod_tensors.size(); - } else { - PADDLE_ENFORCE_EQ(num_sub_scopes, lod_tensors.size()); - } - PADDLE_ENFORCE_NE(num_sub_scopes, 0); - if (sub_scopes->size() == 0) { - sub_scopes->reserve(num_sub_scopes); - for (size_t i = 0; i < num_sub_scopes; ++i) { - sub_scopes->emplace_back(&scope.NewScope()); - } - } - - for (size_t i = 0; i < lod_tensors.size(); ++i) { - *detail::Ref(sub_scopes->at(i)->Var(argu), - "Cannot find variable in the sub-scope", argu) - .GetMutable() = lod_tensors[i]; - } - } -} - -inline void CopyOrShare(const framework::Variable &src, - const platform::Place &dst_place, - framework::Variable *dst) { - if (src.IsType()) { - if (src.Get().place() == dst_place) { - dst->GetMutable()->ShareDataWith(src.Get()); - dst->GetMutable()->set_lod(src.Get().lod()); - } else { - TensorCopy(src.Get(), dst_place, dst->GetMutable()); - } - } else if (src.IsType()) { - auto &src_sr = src.Get(); - auto *dst_sr = dst->GetMutable(); - dst_sr->set_height(src_sr.height()); - if (src_sr.value().place() == dst_place) { - dst_sr->mutable_value()->ShareDataWith(src_sr.value()); - dst_sr->set_rows(src_sr.rows()); - } else { - TensorCopy(src_sr.value(), dst_place, dst_sr->mutable_value()); - } - } else { - PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); - } -} - -void WaitOnPlace(const platform::Place place) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - dev_ctx.Wait(); -} - -void WaitOnPlaces(const std::vector places) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - - for (auto &place : places) { - auto &dev_ctx = *pool.Get(place); - dev_ctx.Wait(); - } -} - -class ParallelDoOp : public framework::OperatorBase { - public: - ParallelDoOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - - auto *block = Attr(kParallelBlock); - auto *program = 
block->Program(); - - auto &places = scope.FindVar(Input(kPlaces))->Get(); - - auto &sub_scopes = *scope.FindVar(Output(kParallelScopes)) - ->GetMutable>(); - - // split input - SplitTensorAndMoveTensorToScopes(scope, &sub_scopes, places, - Inputs(kInputs)); - - // copy parameter - for (auto ¶m : Inputs(kParameters)) { - PADDLE_ENFORCE(scope.FindVar(param)->IsType(), - "Only support parameter type as LoDTensor"); - auto &src = scope.FindVar(param)->Get(); - - auto *sub_scope0 = sub_scopes[0]; - auto *dst0 = sub_scope0->Var(param)->GetMutable(); - dst0->ShareDataWith(src); - - for (size_t i = 1; i < sub_scopes.size(); ++i) { - auto &place = places[i]; - auto *sub_scope = sub_scopes[i]; - auto *dst = sub_scope->Var(param)->GetMutable(); - framework::TensorCopy(src, place, dst); - } - } - WaitOnPlaces(places); - - std::vector> workers; - workers.reserve(places.size()); - for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) { - auto &place = places[place_idx]; - auto *cur_scope = sub_scopes[place_idx]; - - workers.emplace_back(framework::Async([program, cur_scope, place, block] { - framework::Executor executor(place); - executor.Run(*program, cur_scope, block->ID(), - false /*create_local_scope*/); - })); - } - for (auto &worker : workers) { - worker.wait(); - } - WaitOnPlaces(places); - - // merge output - for (auto &o_name : Outputs(kOutputs)) { - std::vector lod_tensors; - lod_tensors.reserve(sub_scopes.size()); - for (auto *sub_scope : sub_scopes) { - lod_tensors.emplace_back(&sub_scope->FindVar(o_name)->Get()); - } - - auto *lod_tensor_to_be_merged = - scope.FindVar(o_name)->GetMutable(); - lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace()); - } - WaitOnPlaces(places); - } -}; - -class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput(kInputs, "").AsDuplicable(); - AddInput(kParameters, "").AsDuplicable(); - AddInput(kPlaces, ""); - AddOutput(kOutputs, "").AsDuplicable(); - AddOutput(kParallelScopes, ""); - AddAttr(kParallelBlock, ""); - AddAttr(kUseNCCL, "true if we use nccl on backward") - .SetDefault(false); - AddComment(R"DOC( -ParallelDo Operator. 
-)DOC"); - } -}; - -class ParallelDoGradOp : public framework::OperatorBase { - public: - ParallelDoGradOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto *block = Attr(kParallelBlock); - auto *program = block->Program(); - - auto &sub_scopes = scope.FindVar(Input(kParallelScopes)) - ->Get>(); - auto &places = scope.FindVar(Input(kPlaces))->Get(); - - // feed output@grad - SplitTensorAndMoveTensorToScopes( - scope, const_cast *>(&sub_scopes), - places, Inputs(framework::GradVarName(kOutputs))); - WaitOnPlaces(places); - - // exe run - std::vector> workers; - for (size_t i = 0; i < sub_scopes.size(); ++i) { - auto &place = places[i]; - auto *cur_scope = sub_scopes[i]; - - // execute - workers.emplace_back(framework::Async([program, cur_scope, place, block] { - framework::Executor executor(place); - executor.Run(*program, cur_scope, block->ID(), - false /*create_local_scope*/); - })); - } - for (auto &worker : workers) { - worker.wait(); - } - WaitOnPlaces(places); - - // NCCL allreduce op will be added by backward, - // so no need to explicitly accumulate grad - if (!(Attr(kUseNCCL))) { - AccumulateGrad(scope, place, sub_scopes, places); - } else { - for (auto &place : places) { - PADDLE_ENFORCE(platform::is_gpu_place(place), - "NCCL only supports cuda place"); - } - } - for (auto &s : Outputs(framework::GradVarName(kParameters))) { - if (s == framework::kEmptyVarName) { - continue; - } - VLOG(3) << "Moving " << s; - CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s)); - } - WaitOnPlaces(places); - } - - void AccumulateGrad(const framework::Scope &scope, - const platform::Place &place, - const std::vector &sub_scopes, - const platform::PlaceList &places) const { - for (auto &s : Outputs(framework::GradVarName(kParameters))) { - if (s == framework::kEmptyVarName) { - continue; - } - VLOG(3) << "Accumulating " << s; - if (s == framework::kEmptyVarName) continue; - std::string tmp_name; - auto *tmp = sub_scopes[0]->Var(&tmp_name); - - for (size_t i = 1; i < sub_scopes.size(); ++i) { - CopyOrShare(*sub_scopes[i]->FindVar(s), places[0], tmp); - WaitOnPlaces(places); - - auto sum_op = framework::OpRegistry::CreateOp( - "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}}, - framework::AttributeMap{{"use_mkldnn", {false}}}); - VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]); - sum_op->Run(*sub_scopes[0], places[0]); - WaitOnPlace(places[0]); - } - - CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s)); - } - WaitOnPlaces(places); - } -}; - -std::ostream &operator<<(std::ostream &sout, - const std::vector &strs) { - std::copy(strs.begin(), strs.end(), - std::ostream_iterator(sout, ",")); - return sout; -} - -class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { - public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; - - protected: - virtual std::unique_ptr Apply() const { - auto *grad = new framework::OpDesc(); - grad->SetType("parallel_do_grad"); - for (auto &input_param : this->InputNames()) { - VLOG(3) << input_param; - grad->SetInput(input_param, this->Input(input_param)); - if (input_param != kPlaces) { - grad->SetOutput(framework::GradVarName(input_param), - this->InputGrad(input_param, false)); - } - } - auto *g_block = this->grad_block_[0]; - - // All 
variable name that needed by gradient operators - std::unordered_set all_inputs_in_grad_blocks; - - for (size_t i = 0; i < g_block->OpSize(); ++i) { - auto *op = g_block->Op(i); - for (auto &var_name : op->InputArgumentNames()) { - all_inputs_in_grad_blocks.insert(var_name); - } - } - - for (auto &output_param : this->OutputNames()) { - if (output_param == kParallelScopes) { - grad->SetInput(output_param, this->Output(output_param)); - grad->SetInput(framework::GradVarName(output_param), - this->Output(output_param)); - } else { - grad->SetInput(output_param, this->Output(output_param)); - std::vector og_names; - for (auto &og_name : this->OutputGrad(output_param)) { - if (all_inputs_in_grad_blocks.count(og_name) != 0) { - // there are some gradient operators who need the OG. So make this - // OG as an input of parallel.do - og_names.push_back(og_name); - } - // else, there is no operator who need the OG. Do not use this OG as - // an input - } - grad->SetInput(framework::GradVarName(output_param), og_names); - } - } - grad->SetInput("Communicator", {"nccl_com__do_not_change_"}); - grad->SetAttrMap(this->Attrs()); - grad->SetBlockAttr(kParallelBlock, grad_block_[0]); - - return std::unique_ptr(grad); - } -}; - -class ParallelDoGradOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInputs(kParameters)); - PADDLE_ENFORCE(ctx->HasInputs(kInputs)); - PADDLE_ENFORCE(ctx->HasInputs(kOutputs)); - - ctx->SetOutputsDim(framework::GradVarName(kParameters), - ctx->GetInputsDim(kParameters)); - - auto i_dims = ctx->GetInputsDim(kInputs); - auto ig_names = ctx->Outputs(framework::GradVarName(kInputs)); - - for (size_t i = 0; i < ig_names.size(); ++i) { - auto &ig_name = ig_names[i]; - if (ig_name == framework::kEmptyVarName) { - continue; - } - - ctx->SetDims({ig_name}, {i_dims[i]}); - } - - auto p_dims = ctx->GetInputsDim(kParameters); - auto pg_names = ctx->Outputs(framework::GradVarName(kParameters)); - for (size_t i = 0; i < pg_names.size(); ++i) { - auto &pg_name = pg_names[i]; - if (pg_name == framework::kEmptyVarName) { - continue; - } - ctx->SetDims({pg_name}, {p_dims[i]}); - } - } -}; - -class ParallelDoGradOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - framework::BlockDesc *sub_block = - boost::get(op_desc.GetAttr(kParallelBlock)); - for (auto &out_vars : op_desc.Outputs()) { - for (auto &out_var : out_vars.second) { - auto &var = block->FindRecursiveOrCreateVar(out_var); - auto sub_var = sub_block->FindRecursiveOrCreateVar(out_var); - if (sub_var.GetType() != var.GetType()) { - var.SetType(sub_var.GetType()); - } - } - } - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp, - paddle::operators::ParallelDoOpProtoMaker, - paddle::operators::ParallelDoGradOpDescMaker); -REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp, - paddle::operators::ParallelDoGradOpShapeInference, - paddle::operators::ParallelDoGradOpVarTypeInference); diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 17fe8dc3c8..b2c3e7c989 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -249,69 +249,6 @@ def serialize_op_decs(op_desc): return proto.__str__() -def _callback_lookup_(op): - """ - Only used in _append_backward_ops_ - Build and 
returns a callback function for certain op. For example - - parallel_do: AllReduce - - :param op: - :return: callback function - """ - if op.type == 'parallel_do' and op.attr('use_nccl'): - all_vars = op.block.vars - param_names = set(op.input('parameters')) - param_names = [ - name for name in param_names - if all_vars[name].stop_gradient is False - ] - param_grad_names = [n + "@GRAD" for n in param_names] - - class ParallelDoCallBack(object): - def __init__(self, param_grad_names, parallel_scopes_name): - self.has_inserted_nccl_init = False - self.param_grad_names = param_grad_names - self.parallel_scopes_name = parallel_scopes_name - - def __call__(self, block, context): - if not self.has_inserted_nccl_init: - op_desc = _create_op_desc_( - "ncclInit", - {"parallel_scopes": self.parallel_scopes_name}, - {"Communicator": ['nccl_com__do_not_change_']}, {}) - block.program.global_block().desc.append_op().copy_from( - op_desc) - self.has_inserted_nccl_init = True - - current_op_desc = context["__current_op_desc__"] - for o_param in current_op_desc.output_names(): - for o_argu in current_op_desc.output(o_param): - if o_argu in self.param_grad_names: - allreduce_out_name = o_argu + "__nccl_all_reduce__" - op_desc = _create_op_desc_( - "ncclReduce", - { - "X": [o_argu], - "Communicator": - ['nccl_com__do_not_change_'] - }, - {"Out": [allreduce_out_name]}, - {"reduction": "ncclSum", - "root": 0}, ) - block.desc.append_op().copy_from(op_desc) - - op_desc = _create_op_desc_( - "assign", {"X": [allreduce_out_name]}, - {"Out": [o_argu]}, {}) - block.desc.append_op().copy_from(op_desc) - - return ParallelDoCallBack(param_grad_names, - op.output("parallel_scopes")) - else: - return None - - def _append_backward_ops_(block, ops, target_block, @@ -349,17 +286,8 @@ def _append_backward_ops_(block, sub_block = program.block(op._block_attr_id("sub_block")) grad_sub_block = program._create_block() grad_sub_block._set_forward_block_idx(sub_block.idx) - cb = _callback_lookup_(op) - if cb is not None: - if callbacks is None: - new_callbacks = [cb] - else: - new_callbacks = callbacks + [_callback_lookup_(op)] - _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, - no_grad_dict, grad_to_var, new_callbacks) - else: - _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, - no_grad_dict, grad_to_var, callbacks) + _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, + no_grad_dict, grad_to_var, callbacks) program._rollback() grad_sub_block_list.append(grad_sub_block.desc) @@ -424,9 +352,6 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): # infer_shape and infer_type op_desc.infer_var_type(block.desc) op_desc.infer_shape(block.desc) - # ncclInit dones't need to set data_type - if op_desc.type() == 'ncclInit': - continue for arg in op_desc.output_arg_names(): if arg in new_vars: _infer_var_data_type_(arg, block) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 0897920594..d0bd78454d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -563,8 +563,8 @@ class Operator(object): OP_WITHOUT_KERNEL_SET = { 'feed', 'fetch', 'save', 'load', 'recurrent', 'go', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', - 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', - 'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id' + 'listen_and_serv', 'save_combine', 'load_combine', 'ncclInit', 'select', + 'checkpoint_notify', 'gen_nccl_id' } def __init__(self, diff 
--git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index b7e3968569..21454370dd 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -226,156 +226,6 @@ class BlockGuard(object): return True -class ParallelDo(object): - """ - ParallelDo is used to represent multi-thread data parallel processing. - - Its vanilla implementation can be shown as the following (:math:`|` means - single thread and :math:`||||` means multiple threads) - - .. code-block:: text - - In the forward pass - | Split input onto different devices - | Copy parameter onto different devices - |||| Compute forward pass in parallel - | Merge output from different devices - - In the backward pass - | Split output@grad onto different devices - |||| Compute backward pass in parallel - | accumulate param@grad from different devices to the first device - | Merge input@grad from different devices - | Copy param@grad to the place of parallel_do_op - - Examples: - - .. code-block:: python - - images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - - # ParallelDo version & Single-thread version - if thread_num > 1: - places = fluid.layers.get_places(thread_num) - pd = fluid.layers.control_flow.ParallelDo(places) - with pd.do(): - images = pd.read_input(images) - label = pd.read_input(label) - predict = cnn_model(images) - cost = fluid.layers.cross_entropy(input=predict, label=label) - - avg_cost = fluid.layers.mean(x=cost) - pd.write_output(avg_cost) - - avg_cost = pd() - avg_cost = fluid.layers.mean(avg_cost) - else: - predict = cnn_model(images) - cost = fluid.layers.cross_entropy(input=predict, label=label) - avg_cost = fluid.layers.mean(x=cost) - - .. warning:: - - It will be soon deprecated, please use ParallelExecutor instead. - """ - - def __init__(self, places, use_nccl=False, name=None): - warnings.warn( - "API ParallelDo is deprecated since 0.15.0. 
Please use ParallelExecutor instead.", - Warning) - self.helper = LayerHelper("parallel_do", name=name) - self.inputs = [] - self.places = places - self.outputs = [] - self.status = StaticRNN.BEFORE_RNN_BLOCK - self.use_nccl = use_nccl - - def do(self): - return BlockGuardWithCompletion(self) - - def parent_block(self): - prog = self.helper.main_program - parent_idx = prog.current_block().parent_idx - assert parent_idx >= 0 - parent_block = prog.block(parent_idx) - return parent_block - - def __call__(self, *args, **kwargs): - if self.status != StaticRNN.AFTER_RNN_BLOCK: - raise ValueError("RNN output can only be retrieved after rnn block") - if len(self.outputs) == 0: - raise ValueError("RNN has no output") - elif len(self.outputs) == 1: - return self.outputs[0] - else: - return self.outputs - - def read_input(self, var): - self.inputs.append(var) - return var - - def write_output(self, var): - self.outputs.append(var) - - def get_parameters(self): - main_program = self.helper.main_program - current_block = main_program.current_block() - parent_block = self.parent_block() - - local_inputs = set() - params = list() - for var in self.inputs: - local_inputs.add(var.name) - - for op in current_block.ops: - for iname in op.input_names: - for in_var_name in op.input(iname): - if in_var_name not in local_inputs: - params.append(in_var_name) - - for oname in op.output_names: - for out_var_name in op.output(oname): - local_inputs.add(out_var_name) - - params = list(set(params)) - - return [parent_block.var(name) for name in params] - - def _complete_op(self): - main_program = self.helper.main_program - current_block = main_program.current_block() - parent_block = self.parent_block() - - step_scope = parent_block.create_var( - type=core.VarDesc.VarType.STEP_SCOPES) - - self.outputs = [ - parent_block.create_var( - name=o.name, - shape=o.shape, - dtype=o.dtype, - lod_level=o.lod_level, - persistable=o.persistable, - stop_gradient=o.stop_gradient) for o in self.outputs - ] - - inputs = [parent_block.var(i.name) for i in self.inputs] - outputs = [parent_block.var(o.name) for o in self.outputs] - - parent_block.append_op( - type='parallel_do', - inputs={ - 'inputs': inputs, - 'parameters': self.get_parameters(), - 'places': self.places - }, - outputs={'outputs': outputs, - 'parallel_scopes': [step_scope]}, - attrs={'sub_block': current_block, - 'use_nccl': self.use_nccl}) - - class BlockGuardWithCompletion(BlockGuard): """ BlockGuardWithCompletion class. 
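The deprecation warning in the class removed above already names the replacement: ParallelExecutor. A minimal data-parallel sketch against the fluid 1.x API of this era (avg_cost, feeder, and train_reader are assumed to exist, as in the deleted docstring's example; this is an illustration, not part of the patch):

    import paddle.fluid as fluid

    use_cuda = fluid.core.is_compiled_with_cuda()
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Replaces the ParallelDo block: one program, executed on all devices,
    # with gradients aggregated for you.
    pe = fluid.ParallelExecutor(use_cuda=use_cuda, loss_name=avg_cost.name)
    for data in train_reader():
        loss_value, = pe.run(feed=feeder.feed(data),
                             fetch_list=[avg_cost.name])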
@@ -384,7 +234,7 @@ class BlockGuardWithCompletion(BlockGuard): """ def __init__(self, rnn): - if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)): + if not isinstance(rnn, StaticRNN): raise TypeError( "BlockGuardWithCompletion takes a StaticRNN or ParallelDo") super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program) diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py index a666507bd9..5658bb4ec4 100644 --- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py +++ b/python/paddle/fluid/tests/book/notest_understand_sentiment.py @@ -15,7 +15,6 @@ from __future__ import print_function from paddle.fluid.layers.device import get_places -from paddle.fluid.layers.control_flow import ParallelDo import unittest import paddle.fluid as fluid import paddle @@ -147,22 +146,7 @@ def train(word_dict, cost, acc_out, prediction = net_method( data, label, input_dim=dict_dim, class_dim=class_dim) else: - places = get_places() - pd = ParallelDo(places) - with pd.do(): - cost, acc, _ = net_method( - pd.read_input(data), - pd.read_input(label), - input_dim=dict_dim, - class_dim=class_dim) - pd.write_output(cost) - pd.write_output(acc) - - cost, acc = pd() - cost = fluid.layers.mean(cost) - acc_out = fluid.layers.mean(acc) - prediction = None - assert save_dirname is None + raise NotImplementedError() adagrad = fluid.optimizer.Adagrad(learning_rate=0.002) adagrad.minimize(cost) diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index 4a70976a48..54936519ce 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -25,7 +25,6 @@ import numpy import paddle import paddle.fluid as fluid from paddle.fluid.layers.device import get_places -from paddle.fluid.layers.control_flow import ParallelDo BATCH_SIZE = 64 @@ -82,19 +81,7 @@ def train(nn_type, net_conf = conv_net if parallel: - places = get_places() - pd = ParallelDo(places) - with pd.do(): - img_ = pd.read_input(img) - label_ = pd.read_input(label) - prediction, avg_loss, acc = net_conf(img_, label_) - for o in [avg_loss, acc]: - pd.write_output(o) - - avg_loss, acc = pd() - # get mean loss and acc through every devices. 
- avg_loss = fluid.layers.mean(avg_loss) - acc = fluid.layers.mean(acc) + raise NotImplementedError() else: prediction, avg_loss, acc = net_conf(img, label) diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py index 9191f0fc20..08f70c9cab 100644 --- a/python/paddle/fluid/tests/book/test_word2vec.py +++ b/python/paddle/fluid/tests/book/test_word2vec.py @@ -17,7 +17,6 @@ from __future__ import print_function import paddle import paddle.fluid as fluid from paddle.fluid.layers.device import get_places -from paddle.fluid.layers.control_flow import ParallelDo import unittest import os import numpy as np @@ -84,18 +83,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True): avg_cost, predict_word = __network__( [first_word, second_word, third_word, forth_word, next_word]) else: - places = get_places() - pd = ParallelDo(places) - with pd.do(): - avg_cost, predict_word = __network__( - list( - map(pd.read_input, [ - first_word, second_word, third_word, forth_word, - next_word - ]))) - pd.write_output(avg_cost) - - avg_cost = fluid.layers.mean(pd()) + raise NotImplementedError() sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) sgd_optimizer.minimize(avg_cost) diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py deleted file mode 100644 index dab2a52bc9..0000000000 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import math -import sys - -import paddle -import paddle.fluid as fluid -from paddle.fluid.layers.device import get_places -from paddle.fluid.layers.control_flow import ParallelDo - -# need to fix random seed and training data to compare the loss -# value accurately calculated by the default and the memory optimization -# version. 
-fluid.default_startup_program().random_seed = 111 - -x = fluid.layers.data(name='x', shape=[13], dtype='float32') -y = fluid.layers.data(name='y', shape=[1], dtype='float32') - -device_type = 'CPU' -use_nccl = False -place = fluid.CPUPlace() -if fluid.core.is_compiled_with_cuda(): - device_type = 'CUDA' - use_nccl = False - place = fluid.CUDAPlace(0) - -places = get_places(device_count=0, device_type=device_type) -pd = ParallelDo(places, use_nccl=use_nccl) -with pd.do(): - x_ = pd.read_input(x) - y_ = pd.read_input(y) - y_predict = fluid.layers.fc(input=x_, size=1, act=None) - cost = fluid.layers.square_error_cost(input=y_predict, label=y_) - avg_cost = fluid.layers.mean(x=cost) - pd.write_output(avg_cost) - -cost = pd() -avg_cost = fluid.layers.mean(x=cost) -sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) -sgd_optimizer.minimize(avg_cost) - -fluid.memory_optimize(fluid.default_main_program(), print_log=True) -# fluid.release_memory(fluid.default_main_program()) - -BATCH_SIZE = 200 - -# fix the order of training data -train_reader = paddle.batch( - paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE, drop_last=False) - -# train_reader = paddle.batch( -# paddle.reader.shuffle( -# paddle.dataset.uci_housing.train(), buf_size=500), -# batch_size=BATCH_SIZE) - -feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) -exe = fluid.Executor(place) - -exe.run(fluid.default_startup_program()) - -PASS_NUM = 100 -for pass_id in range(PASS_NUM): - for data in train_reader(): - avg_loss_value, = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost]) - - if avg_loss_value[0] < 10.0: - exit(0) # if avg cost less than 10.0, we think our code is good. - print(avg_loss_value[0]) - if math.isnan(float(avg_loss_value)): - sys.exit("got NaN loss, training failed.") -exit(1) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_op.py b/python/paddle/fluid/tests/unittests/test_parallel_op.py deleted file mode 100644 index 380e172844..0000000000 --- a/python/paddle/fluid/tests/unittests/test_parallel_op.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest - -import paddle.fluid as fluid -from paddle.fluid.layers.device import get_places -from paddle.fluid.layers.control_flow import ParallelDo -import paddle.fluid.profiler as profiler -import numpy -import six - - -class BaseParallelForTest(unittest.TestCase): - def run_test(self, callback, feed, fetch): - """ - Run the unittest for parallel.for - Args: - callback(callable): A callable function returns a generator. There - are two yields in the generator function. The first yield - returns the data layers, and the second yield returns the loss. - The modified data variables will be sent back during the first - yield. - - feed(dict): The executor feeding dictionary. - fetch(list|basestr): The fetch name lists. 
- - Returns: - None - - Raises: - AssertionError when the computation of cpu, parallel.for in cpu, - gpu, parallel.for in gpu are different. - - """ - cpu = fluid.CPUPlace() - result_cpu = self._run_test_impl_( - callback=callback, - feed=feed, - fetch=fetch, - place=cpu, - use_parallel=False) - result_cpu_parallel = self._run_test_impl_( - callback=callback, - feed=feed, - fetch=fetch, - place=cpu, - use_parallel=True) - if fluid.core.is_compiled_with_cuda(): - gpu = fluid.CUDAPlace(0) - result_gpu = self._run_test_impl_( - callback=callback, - feed=feed, - fetch=fetch, - place=gpu, - use_parallel=False, - use_gpu=True) - result_gpu_parallel = self._run_test_impl_( - callback=callback, - feed=feed, - fetch=fetch, - place=gpu, - use_parallel=True, - use_gpu=True) - result_gpu_nccl = self._run_test_impl_( - callback=callback, - feed=feed, - fetch=fetch, - place=gpu, - use_parallel=True, - use_nccl=True, - use_gpu=True) - self._assert_same_(fetch, result_cpu, result_cpu_parallel, - result_gpu, result_gpu_parallel, result_gpu_nccl) - else: - self._assert_same_(fetch, result_cpu, result_cpu_parallel) - - def _run_test_impl_(self, - callback, - feed, - fetch, - place, - use_parallel=False, - use_nccl=False, - use_gpu=False): - """ - Run a single test, returns the fetch values - Args: - place(Place): the computation place. - use_parallel(bool): Whether use parallel.for or not. - - Returns: - Fetched numpy arrays. - - """ - if isinstance(fetch, six.string_types): - fetch = [fetch] - main = fluid.Program() - startup = fluid.Program() - # Fix seed - main.random_seed = 10 - startup.random_seed = 10 - - with fluid.program_guard(main, startup): - generator = callback() - # Automatically insert parallel do if use_parallel = True - if use_parallel: - thread_num = fluid.core.get_cuda_device_count( - ) if use_gpu else 8 - places = get_places(thread_num) - pd = ParallelDo(places, use_nccl=use_nccl) - data = next(generator) - - if isinstance(data, fluid.framework.Variable): - data = [data] - - with pd.do(): - ins = list(map(pd.read_input, data)) - if len(ins) == 1: - ins = ins[0] - loss = generator.send(ins) # patch input - pd.write_output(loss) - - loss = pd() - else: - data = next(generator) - loss = generator.send(data) - self.assertIsNotNone(loss) - avg_loss = fluid.layers.mean(loss) - fluid.backward.append_backward(loss=avg_loss) - - exe = fluid.Executor(place) - exe.run(startup) - if use_gpu: - profile_type = 'GPU' - else: - profile_type = 'CPU' - with profiler.profiler(profile_type, 'total', '/tmp/profiler'): - return exe.run(main, feed=feed, fetch_list=fetch) - - def _assert_same_(self, fetch, *args): - """ - Assert the return values of `run_test` are same. - Args: - fetch: Fetch list. Used for print error message - *args: The fetch result lists of each situations. 
- - Returns: - None - - Raises: - AssertionError - - """ - - def _impl_(a, b, fetch_id, item_id): - item_str = [ - 'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL' - ] - flag = numpy.allclose(a, b, rtol=0.1, atol=1e-3) - self.assertTrue(flag, - "The {0} are different in {1}, {2} vs {3}".format( - fetch[fetch_id], item_str[item_id], a, b)) - - for i, items in enumerate(zip(*args)): - self.assertGreater(len(items), 0) - for j in range(1, len(items)): - _impl_(items[0], items[j], fetch_id=i, item_id=j) - - -class ParallelOpTest(BaseParallelForTest): - @staticmethod - def __network__(): - x = fluid.layers.data(shape=[784], dtype='float32', name='img') - x = yield x - hidden = fluid.layers.fc(input=x, size=200, param_attr='fc1.w') - hidden = fluid.layers.batch_norm(input=hidden) - loss = fluid.layers.mean(hidden) - yield loss - - def test_simple_fc(self): - self.run_test( - callback=self.__network__, - feed={ - 'img': numpy.random.random(size=(51, 784)).astype('float32') - }, - fetch=['fc1.w@GRAD']) - - def test_fc_with_tiny_data(self): - self.run_test( - callback=self.__network__, - feed={'img': numpy.random.random(size=(1, 784)).astype('float32')}, - fetch=['fc1.w@GRAD']) - - -class ParallelOpTestMultipleInput(BaseParallelForTest): - @staticmethod - def __network__(): - x = fluid.layers.data( - shape=[784], dtype='float32', name='img1', stop_gradient=False) - y = fluid.layers.data( - shape=[784], dtype='float32', name='img2', stop_gradient=False) - yield [x, y] - x = x + y - hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w') - hidden2 = fluid.layers.fc(input=hidden1, size=200, param_attr='fc2.w') - hidden3 = fluid.layers.fc(input=hidden2, size=200, param_attr='fc3.w') - loss = fluid.layers.mean(hidden3) - yield loss - - def test_simple_fc(self): - self.run_test( - callback=self.__network__, - feed={ - 'img1': numpy.random.random(size=(51, 784)).astype('float32'), - 'img2': numpy.random.random(size=(51, 784)).astype('float32') - }, - fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD']) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 95aafec053..d10ea4e472 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -35,11 +35,10 @@ dtype_to_size = { } SUB_BLOCK_OPS = [ - "while", "while_grad", "parallel_do", "parallel_do_grad", - "conditional_block", "conditional_block_grad" + "while", "while_grad", "conditional_block", "conditional_block_grad" ] -SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"), +SUB_BLOCK_PAIR = [("while", "while_grad"), ("conditional_block", "conditional_block_grad")] PRINT_LOG = False From 36da940bc1ec69f1bdcb1d83c473136dc070fd87 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 13 Dec 2018 17:14:53 +0800 Subject: [PATCH 02/29] clean more test=develop --- python/paddle/fluid/layers/control_flow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 21454370dd..9d98e8333b 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -235,8 +235,7 @@ class BlockGuardWithCompletion(BlockGuard): def __init__(self, rnn): if not isinstance(rnn, StaticRNN): - raise TypeError( - "BlockGuardWithCompletion takes a StaticRNN or 
ParallelDo") + raise TypeError("BlockGuardWithCompletion takes a StaticRNN") super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program) self.rnn = rnn From fc6ec6bd1425b01a130cefe7411422e8eb62a95d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 13 Dec 2018 17:43:53 +0800 Subject: [PATCH 03/29] add sparse mode adam --- paddle/fluid/operators/optimizers/adam_op.cc | 5 +++ paddle/fluid/operators/optimizers/adam_op.h | 41 +++++++++++++------ python/paddle/fluid/optimizer.py | 7 +++- .../fluid/tests/unittests/test_adam_op.py | 20 +++++---- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index 5710cda39a..b2c2e5c325 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -110,6 +110,11 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { "(float, default 1.0e-8) " "Constant for numerical stability") .SetDefault(1.0e-8f); + AddAttr( + "sparse_mode", + "(bool, default false) " + "only update the parameter that has gradient in sparse update") + .SetDefault(false); AddComment(R"DOC( Adam Optimizer. diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 3455d1ee54..ca5454ef04 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -177,12 +177,13 @@ struct SparseAdamFunctor { const int64_t* rows_; int64_t row_numel_; int64_t row_count_; + bool sparse_mode_; SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2, T* mom2_out, const T* lr, const T* grad, const T* param, T* param_out, const int64_t* rows, - int64_t row_numel, int64_t row_count) + int64_t row_numel, int64_t row_count, bool sparse_mode) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -198,13 +199,10 @@ struct SparseAdamFunctor { param_out_(param_out), rows_(rows), row_numel_(row_numel), - row_count_(row_count) {} - - inline HOSTDEVICE void operator()(size_t i) const { - auto row_idx = - math::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + row_count_(row_count), + sparse_mode_(sparse_mode) {} + inline HOSTDEVICE void sparse_update(size_t i, T g) const { // The following code is the same as dense T mom1 = moment1_[i]; T mom2 = moment2_[i]; @@ -225,6 +223,13 @@ struct SparseAdamFunctor { moment2_out_[i] = mom2; param_out_[i] = p; } + + inline HOSTDEVICE void operator()(size_t i) const { + auto row_idx = + math::BinarySearch(rows_, row_count_, i / row_numel_); + T g = row_idx >= 0 ? 
grad_[row_idx * row_numel_ + i % row_numel_] : 0; + sparse_update(i, g); + } }; template @@ -240,6 +245,7 @@ class AdamOpKernel : public framework::OpKernel { using paddle::framework::LoDTensor; using paddle::operators::detail::Ref; + bool sparse_mode = ctx.Attr("sparse_mode"); T beta1 = static_cast(ctx.Attr("beta1")); T beta2 = static_cast(ctx.Attr("beta2")); T epsilon = static_cast(ctx.Attr("epsilon")); @@ -351,11 +357,22 @@ class AdamOpKernel : public framework::OpKernel { mom2_out.template mutable_data(ctx.GetPlace()), lr.template data(), grad_data, param.template data(), param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, - grad_merge.rows().size()); - platform::ForRange for_range( - static_cast(ctx.device_context()), - param.numel()); - for_range(functor); + grad_merge.rows().size(), sparse_mode); + if (sparse_mode) { + size_t row_count = grad_merge.rows().size(); + for (size_t row_index = 0; row_index < row_count; ++row_index) { + for (size_t offset = 0; offset < row_numel; ++offset) { + size_t i = rows[row_index] * row_numel + offset; + T g = grad_data[row_index * row_numel + offset]; + functor.sparse_update(i, g); + } + } + } else { + platform::ForRange for_range( + static_cast(ctx.device_context()), + param.numel()); + for_range(functor); + } } else { PADDLE_THROW("Variable type not supported by adam_op"); } diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index da92826d41..9c7482bc40 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -663,7 +663,8 @@ class AdamOptimizer(Optimizer): beta2=0.999, epsilon=1e-8, regularization=None, - name=None): + name=None, + sparse_mode=False): assert learning_rate is not None assert beta1 is not None assert beta2 is not None @@ -676,6 +677,7 @@ class AdamOptimizer(Optimizer): self._beta1 = beta1 self._beta2 = beta2 self._epsilon = epsilon + self._sparse_mode = sparse_mode def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -729,7 +731,8 @@ class AdamOptimizer(Optimizer): attrs={ "beta1": self._beta1, "beta2": self._beta2, - "epsilon": self._epsilon + "epsilon": self._epsilon, + "sparse_mode": self._sparse_mode }) return adam_op diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 5318d2f976..da91875a14 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -194,7 +194,8 @@ def adam_step(inputs, attributes): return param_out, moment1_out, moment2_out -def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad): +def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad, + sparse_mode): ''' Simulate one step of the adam optimizer :param inputs: dict of inputs @@ -230,7 +231,7 @@ def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad): class TestSparseAdamOp(unittest.TestCase): - def setup(self, scope, place): + def setup(self, scope, place, sparse_mode): beta1 = 0.78 beta2 = 0.836 epsilon = 1e-4 @@ -262,19 +263,21 @@ class TestSparseAdamOp(unittest.TestCase): self.sparse_inputs = ["Grad"] - param_out, mom1, mom2 = adam_step_sparse( - self.dense_inputs, self.attrs, height, rows, row_numel, np_array) + param_out, mom1, mom2 = adam_step_sparse(self.dense_inputs, self.attrs, + height, rows, row_numel, + np_array, sparse_mode) self.outputs = { "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2 } - def check_with_place(self, 
place): + def check_with_place(self, place, sparse_mode): scope = core.Scope() - self.setup(scope, place) + self.setup(scope, place, sparse_mode) op_args = dict() + op_args['sparse_mode'] = sparse_mode for key, np_array in self.dense_inputs.items(): var = scope.var(key).get_tensor() var.set(np_array, place) @@ -305,12 +308,13 @@ class TestSparseAdamOp(unittest.TestCase): 0.00001) j += 1 - def test_sparse_sgd(self): + def test_sparse_adam(self): places = [core.CPUPlace()] if core.is_compiled_with_cuda(): places.append(core.CUDAPlace(0)) for place in places: - self.check_with_place(place) + for sparse_mode in (True, False): + self.check_with_place(place, sparse_mode) if __name__ == "__main__": From 3dc29b390537cca68f43f21f44c2c2fde84fa297 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 13 Dec 2018 22:02:55 +0800 Subject: [PATCH 04/29] change sparse_update to adam_update --- paddle/fluid/operators/optimizers/adam_op.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index ca5454ef04..25e23c5f9d 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -202,7 +202,7 @@ struct SparseAdamFunctor { row_count_(row_count), sparse_mode_(sparse_mode) {} - inline HOSTDEVICE void sparse_update(size_t i, T g) const { + inline HOSTDEVICE void adam_update(size_t i, T g) const { // The following code is the same as dense T mom1 = moment1_[i]; T mom2 = moment2_[i]; @@ -228,7 +228,7 @@ struct SparseAdamFunctor { auto row_idx = math::BinarySearch(rows_, row_count_, i / row_numel_); T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; - sparse_update(i, g); + adam_update(i, g); } }; @@ -364,7 +364,7 @@ class AdamOpKernel : public framework::OpKernel { for (size_t offset = 0; offset < row_numel; ++offset) { size_t i = rows[row_index] * row_numel + offset; T g = grad_data[row_index * row_numel + offset]; - functor.sparse_update(i, g); + functor.adam_update(i, g); } } } else { From c624417c6f5f1d61ab539aa9c88e95b929a19054 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 14 Dec 2018 11:27:14 +0800 Subject: [PATCH 05/29] change sparse mode to lazy mode --- paddle/fluid/operators/optimizers/adam_op.cc | 2 +- paddle/fluid/operators/optimizers/adam_op.h | 12 ++++++------ python/paddle/fluid/optimizer.py | 6 +++--- .../paddle/fluid/tests/unittests/test_adam_op.py | 16 ++++++++-------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index b2c2e5c325..7993224327 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -111,7 +111,7 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { "Constant for numerical stability") .SetDefault(1.0e-8f); AddAttr( - "sparse_mode", + "lazy_mode", "(bool, default false) " "only update the parameter that has gradient in sparse update") .SetDefault(false); diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 25e23c5f9d..5870557bb7 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -177,13 +177,13 @@ struct SparseAdamFunctor { const int64_t* rows_; int64_t row_numel_; int64_t row_count_; - bool sparse_mode_; + bool lazy_mode_; SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, const T* beta2_pow, 
const T* mom1, T* mom1_out, const T* mom2, T* mom2_out, const T* lr, const T* grad, const T* param, T* param_out, const int64_t* rows, - int64_t row_numel, int64_t row_count, bool sparse_mode) + int64_t row_numel, int64_t row_count, bool lazy_mode) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -200,7 +200,7 @@ struct SparseAdamFunctor { rows_(rows), row_numel_(row_numel), row_count_(row_count), - sparse_mode_(sparse_mode) {} + lazy_mode_(lazy_mode) {} inline HOSTDEVICE void adam_update(size_t i, T g) const { // The following code is the same as dense @@ -245,7 +245,7 @@ class AdamOpKernel : public framework::OpKernel { using paddle::framework::LoDTensor; using paddle::operators::detail::Ref; - bool sparse_mode = ctx.Attr("sparse_mode"); + bool lazy_mode = ctx.Attr("lazy_mode"); T beta1 = static_cast(ctx.Attr("beta1")); T beta2 = static_cast(ctx.Attr("beta2")); T epsilon = static_cast(ctx.Attr("epsilon")); @@ -357,8 +357,8 @@ class AdamOpKernel : public framework::OpKernel { mom2_out.template mutable_data(ctx.GetPlace()), lr.template data(), grad_data, param.template data(), param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, - grad_merge.rows().size(), sparse_mode); - if (sparse_mode) { + grad_merge.rows().size(), lazy_mode); + if (lazy_mode) { size_t row_count = grad_merge.rows().size(); for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t offset = 0; offset < row_numel; ++offset) { diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 9c7482bc40..c53bf4913a 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -664,7 +664,7 @@ class AdamOptimizer(Optimizer): epsilon=1e-8, regularization=None, name=None, - sparse_mode=False): + lazy_mode=False): assert learning_rate is not None assert beta1 is not None assert beta2 is not None @@ -677,7 +677,7 @@ class AdamOptimizer(Optimizer): self._beta1 = beta1 self._beta2 = beta2 self._epsilon = epsilon - self._sparse_mode = sparse_mode + self._lazy_mode = lazy_mode def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -732,7 +732,7 @@ class AdamOptimizer(Optimizer): "beta1": self._beta1, "beta2": self._beta2, "epsilon": self._epsilon, - "sparse_mode": self._sparse_mode + "lazy_mode": self._lazy_mode }) return adam_op diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index da91875a14..461196689c 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -195,7 +195,7 @@ def adam_step(inputs, attributes): def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad, - sparse_mode): + lazy_mode): ''' Simulate one step of the adam optimizer :param inputs: dict of inputs @@ -231,7 +231,7 @@ def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad, class TestSparseAdamOp(unittest.TestCase): - def setup(self, scope, place, sparse_mode): + def setup(self, scope, place, lazy_mode): beta1 = 0.78 beta2 = 0.836 epsilon = 1e-4 @@ -265,19 +265,19 @@ class TestSparseAdamOp(unittest.TestCase): param_out, mom1, mom2 = adam_step_sparse(self.dense_inputs, self.attrs, height, rows, row_numel, - np_array, sparse_mode) + np_array, lazy_mode) self.outputs = { "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2 } - def check_with_place(self, place, sparse_mode): + def check_with_place(self, place, lazy_mode): scope = core.Scope() - self.setup(scope, 
place, sparse_mode)
+        self.setup(scope, place, lazy_mode)
 
         op_args = dict()
-        op_args['sparse_mode'] = sparse_mode
+        op_args['lazy_mode'] = lazy_mode
         for key, np_array in self.dense_inputs.items():
             var = scope.var(key).get_tensor()
             var.set(np_array, place)
@@ -313,8 +313,8 @@ class TestSparseAdamOp(unittest.TestCase):
         if core.is_compiled_with_cuda():
             places.append(core.CUDAPlace(0))
         for place in places:
-            for sparse_mode in (True, False):
-                self.check_with_place(place, sparse_mode)
+            for lazy_mode in (True, False):
+                self.check_with_place(place, lazy_mode)
 
 
 if __name__ == "__main__":

From eb5d427d3940cd53500fc4003c66ad37ef1738db Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Fri, 14 Dec 2018 11:37:39 +0800
Subject: [PATCH 06/29] add comment for lazy_mode adam optimizer

---
 python/paddle/fluid/optimizer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index c53bf4913a..59c22d4e49 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -641,9 +641,14 @@ class AdamOptimizer(Optimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): a small float value for numerical stability.
-        regularization: A Regularizer, such as
-                        fluid.regularizer.L2DecayRegularizer.
+        regularization: A Regularizer, such as fluid.regularizer.L2DecayRegularizer.
         name: A optional name prefix.
+        lazy_mode(bool, default False): The official Adam algorithm keeps two
+            moving-average accumulators, every element of which is updated at
+            every step in both dense and sparse mode. If the parameter is very
+            large, the update may be very slow. Lazy mode only updates the
+            elements that have a gradient in the current mini-batch; this is
+            much faster but differs from original Adam and may change results.
 
     Examples:
         .. code-block:: python
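A minimal NumPy sketch of the lazy-mode semantics documented above, following the adam_step_sparse reference implementation in test_adam_op.py (the lazy_adam_step name and signature here are illustrative, not part of the patch):

    import numpy as np

    def lazy_adam_step(param, mom1, mom2, rows, grad, lr, beta1, beta2,
                       epsilon, beta1_pow, beta2_pow):
        # Lazy mode touches only the rows present in the sparse gradient;
        # every other row keeps its stale moments and parameter values.
        # With lazy_mode=False the same update runs for every row, using
        # g = 0 for absent rows, so their moment estimates still decay
        # (the original Adam semantics).
        for idx, row in enumerate(rows):
            g = grad[idx]
            mom1[row] = beta1 * mom1[row] + (1 - beta1) * g
            mom2[row] = beta2 * mom2[row] + (1 - beta2) * np.square(g)
            lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
            param[row] -= lr_t * (mom1[row] / (np.sqrt(mom2[row]) + epsilon))
        return param, mom1, mom2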
From cf5264629f914724f91e0a364adca4728b8dcc96 Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Fri, 14 Dec 2018 15:20:07 +0800
Subject: [PATCH 07/29] update API.spec

test=develop
---
 paddle/fluid/API.spec | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 8e6482ca98..cfa28948e9 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -365,7 +365,7 @@ paddle.fluid.optimizer.MomentumOptimizer.__init__ ArgSpec(args=['self', 'learnin
 paddle.fluid.optimizer.MomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.AdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, None, None))
 paddle.fluid.optimizer.AdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdamOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None))
+paddle.fluid.optimizer.AdamOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name', 'lazy_mode'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None, False))
 paddle.fluid.optimizer.AdamOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.AdamaxOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None))
 paddle.fluid.optimizer.AdamaxOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))

From 6445cf1e91f6e9ac169f6834d4b3471136d9bd38 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Sun, 16 Dec 2018 21:22:03 +0800
Subject: [PATCH 08/29] fix

test=develop
---
 python/paddle/fluid/tests/book/test_recognize_digits.py | 2 +-
 python/paddle/fluid/tests/book/test_word2vec.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py
index 54936519ce..3b2c4af8ae 100644
--- a/python/paddle/fluid/tests/book/test_recognize_digits.py
+++ b/python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -260,7 +260,7 @@ def inject_all_tests():
     for use_cuda in (False, True):
         if use_cuda and not core.is_compiled_with_cuda():
             continue
-        for parallel in (False, True):
+        for parallel in (False, ):
             for nn_type in ('mlp', 'conv'):
                 inject_test_method(use_cuda, parallel, nn_type, True)

diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py
index 08f70c9cab..e24a9aa989 100644
--- a/python/paddle/fluid/tests/book/test_word2vec.py
+++ b/python/paddle/fluid/tests/book/test_word2vec.py
@@ -250,7 +250,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel):
 
 for use_cuda in (False, True):
     for is_sparse in (False, True):
-        for is_parallel in (False, True):
+        for is_parallel in (False, ):
             inject_test_method(use_cuda, is_sparse, is_parallel)
 
 if __name__ == '__main__':
From fcde2b2725566a9cde0c8930d2e80e6a044d6784 Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Mon, 17 Dec 2018 10:29:59 +0800
Subject: [PATCH 09/29] add ForRangeIn

---
 paddle/fluid/operators/optimizers/adam_op.h |  7 ++-
 paddle/fluid/platform/for_range.h           | 55 +++++++++++++++++++++
 2 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 5870557bb7..e8b977e2d9 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -359,14 +359,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
           grad_merge.rows().size(), lazy_mode);
       if (lazy_mode) {
+        std::vector<int64_t> id_vector;
         size_t row_count = grad_merge.rows().size();
         for (size_t row_index = 0; row_index < row_count; ++row_index) {
           for (size_t offset = 0; offset < row_numel; ++offset) {
             size_t i = rows[row_index] * row_numel + offset;
-            T g = grad_data[row_index * row_numel + offset];
-            functor.adam_update(i, g);
+            id_vector.push_back(i);
           }
         }
+        platform::ForRangeIn<DeviceContext> for_range_in(
+            static_cast<const DeviceContext&>(ctx.device_context()),
+            id_vector);
+        for_range_in(functor);
       } else {
         platform::ForRange<DeviceContext> for_range(
             static_cast<const DeviceContext&>(ctx.device_context()),
             param.numel());
         for_range(functor);

diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h
index c153e80fe4..9fbaa36723 100644
--- a/paddle/fluid/platform/for_range.h
+++ b/paddle/fluid/platform/for_range.h
@@ -13,11 +13,38 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
+#include <vector>
+
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace platform {
 
+template <typename DeviceContext>
+struct ForRangeIn {
+  ForRangeIn(const DeviceContext& dev_ctx, std::vector<int64_t> range);
+
+  template <typename Function>
+  void operator()(Function func) const;
+};
+
+template <>
+struct ForRangeIn<CPUDeviceContext> {
+  ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector<int64_t> range)
+      : range_(range) {}
+
+  template <typename Function>
+  void operator()(Function func) const {
+    for (auto i : range_) {
+      func(i);
+    }
+  }
+
+  std::vector<int64_t> range_;
+};
+
 template <typename DeviceContext>
 struct ForRange {
   ForRange(const DeviceContext& dev_ctx, size_t limit);
@@ -79,6 +106,34 @@ struct ForRange<CUDADeviceContext> {
   int limit_;
 };
 
+template <typename T, typename Function>
+__global__ static void ForRangeInElemwiseOp(Function func, T* vector,
+                                            int vector_size) {
+  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+  if (idx < vector_size) {
+    func(vector[idx]);
+  }
+}
+
+template <>
+struct ForRangeIn<CUDADeviceContext> {
+  ForRange(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
+      : dev_ctx_(dev_ctx), range_(range) {}
+
+  template <typename Function>
+  inline void operator()(Function func) const {
+    constexpr int num_threads = 1024;
+    int block_size = range_.size() <= num_threads ? limit_ : num_threads;
+    int grid_size = (range_.size() + num_threads - 1) / num_threads;
+
+    ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
+        func, range_.data(), range_.size());
+  }
+
+  const CUDADeviceContext& dev_ctx_;
+  framework::Vector<int64_t> range_;
+};
+
 #endif
 
 }  // namespace platform
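The id_vector built in the adam_op.h hunk above enumerates the flat element indices of exactly the rows present in the merged sparse gradient, which ForRangeIn then visits. A small illustration with made-up values:

    rows = [2, 5]      # rows present in the merged sparse gradient
    row_numel = 3      # elements per row of the parameter
    id_vector = [r * row_numel + off for r in rows for off in range(row_numel)]
    assert id_vector == [6, 7, 8, 15, 16, 17]
    # In lazy mode, functor.adam_update runs only at these flat indices
    # instead of sweeping all param.numel() elements.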
limit_ : num_threads; + int grid_size = (range_.size() + num_threads - 1) / num_threads; + + ForRangeInElemwiseOp<<>>( + func, range_.data(), range_.size()); + } + + const CUDADeviceContext& dev_ctx_; + framework::Vector range_; +}; + #endif } // namespace platform
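As posted, the CUDA specialization in PATCH 09 above cannot compile: its constructor is declared as `ForRange` instead of `ForRangeIn`, and `block_size` reads `limit_`, which is a member of the unrelated `ForRange` struct; PATCH 10 below corrects both slips. For orientation, a minimal sketch of how the helper is meant to be driven, mirroring the CPU specialization and the adam_op.h call site (the functor, buffer, and index values here are hypothetical, not taken from the patches):

#include <vector>
#include "paddle/fluid/platform/for_range.h"

struct AddOneFunctor {
  explicit AddOneFunctor(float* data) : data_(data) {}
  // Invoked once per index supplied in the range vector; indices absent
  // from the vector are never visited, which is the point of ForRangeIn.
  inline void operator()(size_t i) const { data_[i] += 1.0f; }
  float* data_;
};

// Assuming a platform::CPUDeviceContext cpu_ctx and a float buffer buf:
//   std::vector<int64_t> ids = {0, 3, 7};  // e.g. offsets derived from SelectedRows rows
//   paddle::platform::ForRangeIn<paddle::platform::CPUDeviceContext>
//       for_range_in(cpu_ctx, ids);
//   for_range_in(AddOneFunctor(buf));  // calls func(0), func(3), func(7) only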
From 763e8fdf02ebe00b845680b264b7a5c6a56b61ae Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 17 Dec 2018 11:17:10 +0800 Subject: [PATCH 10/29] fix compile error --- paddle/fluid/platform/for_range.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index 9fbaa36723..a767bf9299 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -117,17 +117,18 @@ __global__ static void ForRangeInElemwiseOp(Function func, T* vector, template <> struct ForRangeIn { - ForRange(const CUDADeviceContext& dev_ctx, std::vector range) + ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector range) : dev_ctx_(dev_ctx), range_(range) {} template inline void operator()(Function func) const { constexpr int num_threads = 1024; - int block_size = range_.size() <= num_threads ? limit_ : num_threads; + int range_size = range_.size(); + int block_size = range_size <= num_threads ? range_size : num_threads; int grid_size = (range_.size() + num_threads - 1) / num_threads; ForRangeInElemwiseOp<<>>( - func, range_.data(), range_.size()); + func, range_.data(), range_size); } const CUDADeviceContext& dev_ctx_;
From 96604fda1016bd91c25ace7e7510f0a746ed3797 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 17 Dec 2018 16:59:20 +0800 Subject: [PATCH 11/29] fix gpu data test=develop --- paddle/fluid/operators/optimizers/adam_op.h | 3 ++- paddle/fluid/platform/for_range.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index e8b977e2d9..01d3d60054 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -361,9 +361,10 @@ class AdamOpKernel : public framework::OpKernel { if (lazy_mode) { std::vector id_vector; size_t row_count = grad_merge.rows().size(); + std::vector cpu_rows(grad_merge.rows()); for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t offset = 0; offset < row_numel; ++offset) { - size_t i = rows[row_index] * row_numel + offset; + size_t i = cpu_rows[row_index] * row_numel + offset; id_vector.push_back(i); } }
diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index a767bf9299..ab00d8b8f5 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -128,7 +128,7 @@ struct ForRangeIn { int grid_size = (range_.size() + num_threads - 1) / num_threads; ForRangeInElemwiseOp<<>>( - func, range_.data(), range_size); + func, range_.CUDAData(dev_ctx_.GetPlace()), range_size); } const CUDADeviceContext& dev_ctx_;
From 1141db811455eadc6b44bbb3785b0510f1f51870 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 17 Dec 2018 19:32:32 +0800 Subject: [PATCH 12/29] update test_adam_op test=develop --- paddle/fluid/operators/optimizers/adam_op.h | 1 + .../fluid/tests/unittests/test_adam_op.py | 30 ++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-)
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 01d3d60054..8fc6689ff1 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -358,6 +358,7 @@ class AdamOpKernel : public framework::OpKernel { lr.template data(), grad_data, param.template data(), param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, grad_merge.rows().size(), lazy_mode); + VLOG(3) << "lazy_mode :" << lazy_mode; if (lazy_mode) { std::vector id_vector; size_t row_count = grad_merge.rows().size();
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 461196689c..ff7fc5100e 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -219,14 +219,25 @@ def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad, moment2_out = np.zeros(shape=[height, row_numel]) param_out = np.zeros(shape=[height, row_numel]) - for idx, row_id in enumerate(rows): + def update_row(row_id, update_value): moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 - ) * np_grad[idx] + ) * update_value moment2_out[row_id] = beta2 * moment2[row_id] + ( - 1 - beta2) * np.square(np_grad[idx]) + 1 - beta2) * np.square(update_value) lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / ( np.sqrt(moment2_out[row_id]) + epsilon)) + + if lazy_mode: + for idx, row_id in enumerate(rows): + update_row(row_id, np_grad[idx]) + else: + for row_id in range(param_out.shape[0]): + update_value = np.zeros(np_grad[0].shape).astype("float32") + if row_id in rows: + update_value = np_grad[rows.index(row_id)] + update_row(row_id, update_value) + return param_out, moment1_out, moment2_out @@ -249,6 +260,7 @@ class TestSparseAdamOp(unittest.TestCase): 'Beta2Pow': np.array([beta2**10]).astype("float32"), "LearningRate": np.full((1), 2.0).astype("float32") } + self.init_output = np.full((height, row_numel), 0.0).astype("float32") self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} grad_selected_rows = scope.var('Grad').get_selected_rows() @@ -286,7 +298,7 @@ class TestSparseAdamOp(unittest.TestCase): op_args[s] = s for s in self.outputs: var = scope.var(s).get_tensor() - var.set(self.outputs[s], place) + var.set(self.init_output, place) op_args[s] = s for k in self.attrs: op_args[k] = self.attrs[k] @@ -300,13 +312,9 @@ class TestSparseAdamOp(unittest.TestCase): actual = np.array(out_var) actual = actual.reshape([actual.size]) np_array = np_array.reshape([np_array.size]) - for idx, row_id in enumerate(self.rows): - j = 0 - while j < self.row_numel: - pos = row_id * self.row_numel + j - self.assertLess((actual[pos] - np_array[pos]) / actual[pos], - 0.00001) - j += 1 + + for i in range(np_array.size): + self.assertLess((actual[i] - np_array[i]), 0.00001) def test_sparse_adam(self): places = [core.CPUPlace()]
From fd152289fa694b99704e4821a71b0c1f160896aa Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 17 Dec 2018 22:14:11 +0800 Subject: [PATCH 13/29] clean for range in test=develop --- paddle/fluid/operators/optimizers/adam_op.h | 14 +++--- paddle/fluid/platform/for_range.h | 52 --------------------- 2 files changed, 6 insertions(+), 60 deletions(-)
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 8fc6689ff1..4f212bb69a 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -227,8 +227,10 @@ struct SparseAdamFunctor { inline HOSTDEVICE void operator()(size_t i) const { auto row_idx =
math::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; - adam_update(i, g); + if (!(lazy_mode_ && row_idx < 0)) { + T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + adam_update(i, g); + } } }; @@ -359,19 +361,15 @@ class AdamOpKernel : public framework::OpKernel { param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, grad_merge.rows().size(), lazy_mode); VLOG(3) << "lazy_mode :" << lazy_mode; - if (lazy_mode) { - std::vector id_vector; + if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) { size_t row_count = grad_merge.rows().size(); std::vector cpu_rows(grad_merge.rows()); for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t offset = 0; offset < row_numel; ++offset) { size_t i = cpu_rows[row_index] * row_numel + offset; - id_vector.push_back(i); + functor.adam_update(i, grad_data[row_index * row_numel + offset]); } } - platform::ForRangeIn for_range_in( - static_cast(ctx.device_context()), id_vector); - for_range_in(functor); } else { platform::ForRange for_range( static_cast(ctx.device_context()), diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index ab00d8b8f5..910d1669f2 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -22,29 +22,6 @@ limitations under the License. */ namespace paddle { namespace platform { -template -struct ForRangeIn { - ForRangeIn(const DeviceContext& dev_ctx, std::vector range); - - template - void operator()(Function func) const; -}; - -template <> -struct ForRangeIn { - ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector range) - : range_(range) {} - - template - void operator()(Function func) const { - for (auto i : range_) { - func(i); - } - } - - std::vector range_; -}; - template struct ForRange { ForRange(const DeviceContext& dev_ctx, size_t limit); @@ -106,35 +83,6 @@ struct ForRange { int limit_; }; -template -__global__ static void ForRangeInElemwiseOp(Function func, T* vector, - int vector_size) { - size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx < vector_size) { - func(vector[idx]); - } -} - -template <> -struct ForRangeIn { - ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector range) - : dev_ctx_(dev_ctx), range_(range) {} - - template - inline void operator()(Function func) const { - constexpr int num_threads = 1024; - int range_size = range_.size(); - int block_size = range_size <= num_threads ? range_size : num_threads; - int grid_size = (range_.size() + num_threads - 1) / num_threads; - - ForRangeInElemwiseOp<<>>( - func, range_.CUDAData(dev_ctx_.GetPlace()), range_size); - } - - const CUDADeviceContext& dev_ctx_; - framework::Vector range_; -}; - #endif } // namespace platform From 56686d0f34db008238095331b6f981d8f4e5d3d4 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 17 Dec 2018 22:16:52 +0800 Subject: [PATCH 14/29] clean code test=develop --- paddle/fluid/platform/for_range.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index 910d1669f2..c153e80fe4 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once - -#include - -#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { From fe3995d33527e8503739b6de3dd555fa3ad35073 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 18 Dec 2018 07:15:42 +0800 Subject: [PATCH 15/29] refine code test=develop --- paddle/fluid/operators/optimizers/adam_op.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 4f212bb69a..f214d8272f 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -227,7 +227,9 @@ struct SparseAdamFunctor { inline HOSTDEVICE void operator()(size_t i) const { auto row_idx = math::BinarySearch(rows_, row_count_, i / row_numel_); - if (!(lazy_mode_ && row_idx < 0)) { + if (lazy_mode_ && row_idx < 0) { + return; + } else { T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; adam_update(i, g); } From 52bc4ee75adf64e449dfdbbdbbe3e41cdc593bdc Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 16 Dec 2018 20:27:17 +0800 Subject: [PATCH 16/29] delay infer scope test=develop --- paddle/fluid/framework/operator.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a62afe248b..86e1713b02 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -703,8 +703,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { - RuntimeInferShapeContext infer_shape_ctx(*this, scope); - this->InferShape(&infer_shape_ctx); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -758,6 +756,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } + RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope); + this->InferShape(&infer_shape_ctx); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); if (!transfered_inplace_vars.empty()) { From bbff0df320f0f68634a5ae3c4d9507b52a1134f7 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 16 Dec 2018 21:49:25 +0800 Subject: [PATCH 17/29] try cache variables test=develop --- paddle/fluid/framework/ngraph_operator.cc | 15 +++++++- paddle/fluid/framework/operator.cc | 47 ++++++++++++++++------- paddle/fluid/framework/operator.h | 22 ++++++++--- paddle/fluid/framework/type_defs.h | 3 ++ 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index e2cdfc845f..e37f0915c5 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -278,7 +278,20 @@ std::shared_ptr NgraphEngine::backend_ = ngraph::runtime::Backend::create("CPU"); void NgraphEngine::GetNgInputShape(std::shared_ptr op) { - op->RuntimeInferShape(scope_, place_); + RuntimeContext ctx; + for (auto& var_name_item : op->Inputs()) { + std::vector input_vars = ctx.inputs[var_name_item.first]; + for (auto& var_name : var_name_item.second) { + input_vars.push_back(scope_.FindVar(var_name)); + } + } + for (auto& var_name_item : op->Outputs()) { + std::vector output_vars = ctx.outputs[var_name_item.first]; + for (auto& var_name : var_name_item.second) { + output_vars.push_back(scope_.FindVar(var_name)); + } + } + 
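// An aside on the block above (editorial sketch, not part of the patch):
// `input_vars` and `output_vars` here bind the RuntimeContext map entries
// by value, so the pointers pushed into them never reach `ctx` itself.
// Compare the reference form that PATCH 18 below folds into a
// RuntimeContext constructor, roughly:
//
//   RuntimeContext::RuntimeContext(const VariableNameMap& innames,
//                                  const VariableNameMap& outnames,
//                                  const Scope& scope) {
//     for (auto& item : innames) {
//       std::vector<Variable*>& in_vars = inputs[item.first];  // reference, not copy
//       for (auto& name : item.second) in_vars.push_back(scope.FindVar(name));
//     }
//     // ctx.outputs is filled the same way from outnames
//   }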
op->RuntimeInferShape(scope_, place_, ctx); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { auto* var = scope_.FindVar(var_name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 86e1713b02..79e3d29a63 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -477,23 +477,22 @@ bool OpSupportGPU(const std::string& op_type) { class RuntimeInferShapeContext : public InferShapeContext { public: - RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) - : op_(op), scope_(scope) {} + RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope, + const RuntimeContext& ctx) + : op_(op), scope_(scope), ctx_(ctx) {} bool HasInput(const std::string& name) const override { // has only one input - const auto& ins = op_.Inputs(); + const auto& ins = ctx_.inputs; auto it = ins.find(name); if (it == ins.end()) { return false; } const auto& in = it->second; - if (in.size() == 0 || in[0] == kEmptyVarName) { - return false; - } + if (in.size() == 0) return false; PADDLE_ENFORCE_EQ(in.size(), 1UL, "Input %s should not have more than one inputs", name); - return scope_.FindVar(in[0]) != nullptr; + return in[0] != nullptr; } bool HasOutput(const std::string& name) const override { @@ -678,6 +677,7 @@ class RuntimeInferShapeContext : public InferShapeContext { private: const OperatorBase& op_; const Scope& scope_; + const RuntimeContext& ctx_; }; static void CheckTensorNANOrInf(const std::string& name, @@ -696,8 +696,9 @@ static void CheckTensorNANOrInf(const std::string& name, } void OperatorWithKernel::RuntimeInferShape(const Scope& scope, - const platform::Place& place) const { - RuntimeInferShapeContext infer_shape_ctx(*this, scope); + const platform::Place& place, + const RuntimeContext& ctx) const { + RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx); this->InferShape(&infer_shape_ctx); } @@ -743,10 +744,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, KernelTypeToString(expected_kernel_key)); } + RuntimeContext ctx; // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; auto* transfer_scope = - TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars); + PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx); // exec scope is the scope that kernel actually executed on. 
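// (On the line below: PrepareData returns a non-null child "transfer scope"
// only when some input tensor had to be transformed for the chosen kernel,
// e.g. copied to another place or converted to another layout; exec_scope
// then refers to that child, so the kernel sees the transformed inputs, and
// TransferInplaceVarsBack later copies any in-place outputs back to the
// original scope.)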
const Scope& exec_scope = @@ -756,7 +758,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } - RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope); + RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); this->InferShape(&infer_shape_ctx); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); @@ -797,13 +799,20 @@ void OperatorWithKernel::TransferInplaceVarsBack( } } -Scope* OperatorWithKernel::TryTransferData( +Scope* OperatorWithKernel::PrepareData( const Scope& scope, const OpKernelType& expected_kernel_key, - std::vector* transfered_inplace_vars) const { + std::vector* transfered_inplace_vars, + RuntimeContext* ctx) const { Scope* new_scope = nullptr; for (auto& var_name_item : Inputs()) { - for (auto& var_name : var_name_item.second) { + std::vector& input_vars = ctx->inputs[var_name_item.first]; + input_vars.resize(var_name_item.second.size()); + + for (size_t i = 0; i < var_name_item.second.size(); ++i) { + auto& var_name = var_name_item.second[i]; auto* var = scope.FindVar(var_name); + input_vars[i] = var; + // Only tensor can be tranfer to another device. if (var == nullptr || !VarIsTensor(*var)) { continue; @@ -851,12 +860,22 @@ Scope* OperatorWithKernel::TryTransferData( } auto* trans_var = new_scope->Var(var_name); + input_vars[i] = var; Tensor out; TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out); SetTensorToVariable(*var, out, trans_var); } } + for (auto& var_name_item : Outputs()) { + std::vector& output_vars = ctx->outputs[var_name_item.first]; + output_vars.resize(var_name_item.second.size()); + + for (size_t i = 0; i < var_name_item.second.size(); ++i) { + auto& var_name = var_name_item.second[i]; + output_vars[i] = scope.FindVar(var_name); + } + } return new_scope; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 0a6a28a5bc..438ae25398 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -70,6 +70,14 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); class OperatorBase; class ExecutionContext; +class RuntimeContext { + public: + RuntimeContext() {} + + VariableValueMap inputs; + VariableValueMap outputs; +}; + /** * OperatorBase has the basic elements that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -129,7 +137,8 @@ class OperatorBase { void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } virtual void RuntimeInferShape(const Scope& scope, - const platform::Place& place) const {} + const platform::Place& place, + const RuntimeContext& ctx) const {} protected: std::string type_; @@ -350,8 +359,8 @@ class OperatorWithKernel : public OperatorBase { OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); } - void RuntimeInferShape(const Scope& scope, - const platform::Place& place) const override; + void RuntimeInferShape(const Scope& scope, const platform::Place& place, + const RuntimeContext& ctx) const override; protected: virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; @@ -371,9 +380,10 @@ class OperatorWithKernel : public OperatorBase { * * * transfered_inplace_vars is a output vector. 
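 * (In this revision the renamed PrepareData additionally records each
 * resolved input Variable* into ctx->inputs while it walks the inputs,
 * so the cached pointers can be reused instead of repeating the
 * Scope::FindVar lookups.)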
*/ - Scope* TryTransferData( - const Scope& scope, const OpKernelType& expected_kernel_key, - std::vector* transfered_inplace_vars) const; + Scope* PrepareData(const Scope& scope, + const OpKernelType& expected_kernel_key, + std::vector* transfered_inplace_vars, + RuntimeContext* ctx) const; void TransferInplaceVarsBack(const Scope& scope, const std::vector& inplace_vars, diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 2de6233a9e..938e2024c3 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -28,8 +28,11 @@ class OperatorBase; class OpDesc; class InferShapeContext; class BlockDesc; +class Variable; using VariableNameMap = std::map>; +// TODO(panyx0718): Replace vector with something like gtl::Vector. +using VariableValueMap = std::map>; // The order should be as same as framework.proto using Attribute = From 840e6729e224d867386bdfc9ff12af4b71ee7188 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 17 Dec 2018 21:27:56 +0800 Subject: [PATCH 18/29] inject context test=develop --- paddle/fluid/framework/ngraph_operator.cc | 14 +------- paddle/fluid/framework/operator.cc | 36 +++++++++++-------- paddle/fluid/framework/operator.h | 9 +++-- .../fluid/operators/beam_search_decode_op.cc | 3 +- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index e37f0915c5..23f681ce88 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -278,19 +278,7 @@ std::shared_ptr NgraphEngine::backend_ = ngraph::runtime::Backend::create("CPU"); void NgraphEngine::GetNgInputShape(std::shared_ptr op) { - RuntimeContext ctx; - for (auto& var_name_item : op->Inputs()) { - std::vector input_vars = ctx.inputs[var_name_item.first]; - for (auto& var_name : var_name_item.second) { - input_vars.push_back(scope_.FindVar(var_name)); - } - } - for (auto& var_name_item : op->Outputs()) { - std::vector output_vars = ctx.outputs[var_name_item.first]; - for (auto& var_name : var_name_item.second) { - output_vars.push_back(scope_.FindVar(var_name)); - } - } + RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_); op->RuntimeInferShape(scope_, place_, ctx); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 79e3d29a63..461d357527 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } } +RuntimeContext::RuntimeContext(const VariableNameMap& innames, + const VariableNameMap& outnames, + const Scope& scope) { + for (auto& var_name_item : innames) { + std::vector& input_vars = inputs[var_name_item.first]; + for (auto& var_name : var_name_item.second) { + input_vars.push_back(scope.FindVar(var_name)); + } + } + for (auto& var_name_item : outnames) { + std::vector& output_vars = outputs[var_name_item.first]; + for (auto& var_name : var_name_item.second) { + output_vars.push_back(scope.FindVar(var_name)); + } + } +} + void OperatorBase::Run(const Scope& scope, const platform::Place& place) { VLOG(4) << place << " " << DebugStringEx(&scope); if (platform::is_gpu_place(place)) { @@ -704,6 +721,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { + RuntimeContext 
ctx(Inputs(), Outputs(), scope); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -717,15 +735,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, OpKernelMap& kernels = kernels_iter->second; - // TODO(dzhwinter) : kernel fallback mechanism will be added when all the - // transform functions are ready. - - // for (auto& candidate : kKernelPriority) { - // Do selection - // } - - auto expected_kernel_key = - this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx)); + auto expected_kernel_key = this->GetExpectedKernelType( + ExecutionContext(*this, scope, *dev_ctx, ctx)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); @@ -744,7 +755,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, KernelTypeToString(expected_kernel_key)); } - RuntimeContext ctx; // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; auto* transfer_scope = @@ -760,7 +770,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); this->InferShape(&infer_shape_ctx); - kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); + kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. @@ -784,6 +794,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } } + void OperatorWithKernel::TransferInplaceVarsBack( const Scope& scope, const std::vector& inplace_vars, const Scope& transfer_scope) const { @@ -806,7 +817,6 @@ Scope* OperatorWithKernel::PrepareData( Scope* new_scope = nullptr; for (auto& var_name_item : Inputs()) { std::vector& input_vars = ctx->inputs[var_name_item.first]; - input_vars.resize(var_name_item.second.size()); for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; @@ -869,8 +879,6 @@ Scope* OperatorWithKernel::PrepareData( } for (auto& var_name_item : Outputs()) { std::vector& output_vars = ctx->outputs[var_name_item.first]; - output_vars.resize(var_name_item.second.size()); - for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; output_vars[i] = scope.FindVar(var_name); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 438ae25398..e359414d15 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -72,7 +72,8 @@ class ExecutionContext; class RuntimeContext { public: - RuntimeContext() {} + RuntimeContext(const VariableNameMap& innames, + const VariableNameMap& outnames, const Scope& scope); VariableValueMap inputs; VariableValueMap outputs; @@ -165,8 +166,9 @@ class OperatorBase { class ExecutionContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, - const platform::DeviceContext& device_context) - : op_(op), scope_(scope), device_context_(device_context) {} + const platform::DeviceContext& device_context, + const RuntimeContext& ctx) + : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {} const OperatorBase& op() const { return op_; } @@ -295,6 +297,7 @@ class ExecutionContext { const OperatorBase& op_; const Scope& scope_; const platform::DeviceContext& device_context_; + const RuntimeContext& ctx_; }; template <> diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index 
ae9765b761..7f2bde55c9 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(dev_place); - framework::ExecutionContext ctx(*this, scope, dev_ctx); + framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope); + framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx); const LoDTensorArray* ids = ctx.Input("Ids"); const LoDTensorArray* scores = ctx.Input("Scores"); From eaf8ba35b519b780629a7108d08ffd3895ac18fe Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 09:42:57 +0800 Subject: [PATCH 19/29] change input test=develop --- paddle/fluid/framework/operator.cc | 50 ++++++++++++++++++++++++++++++ paddle/fluid/framework/operator.h | 33 +++++++++++++++----- paddle/fluid/operators/prelu_op.cc | 2 +- 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 461d357527..87f61f3afc 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -143,12 +143,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, for (auto& var_name_item : innames) { std::vector& input_vars = inputs[var_name_item.first]; for (auto& var_name : var_name_item.second) { + LOG(ERROR) << "first in " << var_name_item.first << ":" << var_name; input_vars.push_back(scope.FindVar(var_name)); } } for (auto& var_name_item : outnames) { std::vector& output_vars = outputs[var_name_item.first]; for (auto& var_name : var_name_item.second) { + LOG(ERROR) << "first out " << var_name_item.first << ":" << var_name; output_vars.push_back(scope.FindVar(var_name)); } } @@ -429,11 +431,52 @@ bool ExecutionContext::HasOutput(const std::string& name) const { return var != nullptr; } +const Variable* ExecutionContext::InputVar(const std::string& name) const { + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) return nullptr; + + PADDLE_ENFORCE_LE(it->second.size(), 1UL, + "Operator %s's input %s should contain only one variable.", + op_.Type(), name); + return it->second.empty() ? nullptr : it->second[0]; +} + +Variable* ExecutionContext::OutputVar(const std::string& name) const { + auto opt = op_.Output(name); + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); +} + +const Variable* ExecutionContext::FastInputVar(const std::string& name) const { + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) return nullptr; + + PADDLE_ENFORCE_LE(it->second.size(), 1UL, + "Operator %s's input %s should contain only one variable.", + op_.Type(), name); + return it->second.empty() ? nullptr : it->second[0]; +} + +Variable* ExecutionContext::FastOutputVar(const std::string& name) const { + auto it = ctx_.outputs.find(name); + if (it == ctx_.outputs.end()) return nullptr; + + PADDLE_ENFORCE_LE(it->second.size(), 1UL, + "Operator %s's output %s should contain only one variable.", + op_.Type(), name); + return it->second.empty() ? 
nullptr : it->second[0]; +} + template <> const Tensor* ExecutionContext::Input(const std::string& name) const { return Input(name); } +template <> +const Tensor* ExecutionContext::FastInput( + const std::string& name) const { + return FastInput(name); +} + template <> const std::vector ExecutionContext::MultiInput( const std::string& name) const { @@ -458,6 +501,11 @@ Tensor* ExecutionContext::Output(const std::string& name) const { return Output(name); } +template <> +Tensor* ExecutionContext::FastOutput(const std::string& name) const { + return FastOutput(name); +} + template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const { @@ -822,6 +870,7 @@ Scope* OperatorWithKernel::PrepareData( auto& var_name = var_name_item.second[i]; auto* var = scope.FindVar(var_name); input_vars[i] = var; + LOG(ERROR) << "second in " << var_name_item.first << ":" << var_name; // Only tensor can be tranfer to another device. if (var == nullptr || !VarIsTensor(*var)) { @@ -882,6 +931,7 @@ Scope* OperatorWithKernel::PrepareData( for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; output_vars[i] = scope.FindVar(var_name); + LOG(ERROR) << "second out " << var_name_item.first << ":" << var_name; } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index e359414d15..0aad91dbee 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -191,15 +191,9 @@ class ExecutionContext { return op_.Outputs(name).size(); } - const Variable* InputVar(const std::string& name) const { - auto ipt = op_.Input(name); - return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - } + const Variable* InputVar(const std::string& name) const; - Variable* OutputVar(const std::string& name) const { - auto opt = op_.Output(name); - return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); - } + Variable* OutputVar(const std::string& name) const; const std::vector MultiInputVar( const std::string& name) const { @@ -238,6 +232,22 @@ class ExecutionContext { return var == nullptr ? nullptr : var->GetMutable(); } + template + const T* FastInput(const std::string& name) const { + auto* var = FastInputVar(name); + return var == nullptr ? nullptr : &var->Get(); + } + + template + T* FastOutput(const std::string& name) const { + auto var = FastOutputVar(name); + return var == nullptr ? 
nullptr : var->GetMutable(); + } + + const Variable* FastInputVar(const std::string& name) const; + + Variable* FastOutputVar(const std::string& name) const; + template const std::vector MultiInput(const std::string& name) const { auto names = op_.Inputs(name); @@ -303,6 +313,10 @@ class ExecutionContext { template <> const Tensor* ExecutionContext::Input(const std::string& name) const; +template <> +const Tensor* ExecutionContext::FastInput( + const std::string& name) const; + template <> const std::vector ExecutionContext::MultiInput( const std::string& name) const; @@ -310,6 +324,9 @@ const std::vector ExecutionContext::MultiInput( template <> Tensor* ExecutionContext::Output(const std::string& name) const; +template <> +Tensor* ExecutionContext::FastOutput(const std::string& name) const; + template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 62c55c4f55..b6155ed3dd 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,7 +56,7 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), + return framework::OpKernelType(ctx.FastInput("X")->type(), ctx.device_context()); } }; From fb8ae30331f42b6b9ef67c80e0ccb3fffcbf9836 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 12:35:45 +0800 Subject: [PATCH 20/29] fix test=develop --- paddle/fluid/framework/operator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 87f61f3afc..807667e684 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -919,7 +919,7 @@ Scope* OperatorWithKernel::PrepareData( } auto* trans_var = new_scope->Var(var_name); - input_vars[i] = var; + input_vars[i] = trans_var; Tensor out; TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out); From 70981f5d799b5ab1593743b6ec88af6c40698a3b Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 15:30:23 +0800 Subject: [PATCH 21/29] clean test=develop --- paddle/fluid/framework/operator.cc | 36 ++++++++++++------------------ paddle/fluid/framework/operator.h | 16 ++++++------- paddle/fluid/operators/prelu_op.cc | 2 +- 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 807667e684..7d5a6198a0 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -143,14 +143,12 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, for (auto& var_name_item : innames) { std::vector& input_vars = inputs[var_name_item.first]; for (auto& var_name : var_name_item.second) { - LOG(ERROR) << "first in " << var_name_item.first << ":" << var_name; input_vars.push_back(scope.FindVar(var_name)); } } for (auto& var_name_item : outnames) { std::vector& output_vars = outputs[var_name_item.first]; for (auto& var_name : var_name_item.second) { - LOG(ERROR) << "first out " << var_name_item.first << ":" << var_name; output_vars.push_back(scope.FindVar(var_name)); } } @@ -441,22 +439,13 @@ const Variable* ExecutionContext::InputVar(const std::string& name) const { return it->second.empty() ? 
nullptr : it->second[0]; } -Variable* ExecutionContext::OutputVar(const std::string& name) const { - auto opt = op_.Output(name); - return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); -} - -const Variable* ExecutionContext::FastInputVar(const std::string& name) const { - auto it = ctx_.inputs.find(name); - if (it == ctx_.inputs.end()) return nullptr; - - PADDLE_ENFORCE_LE(it->second.size(), 1UL, - "Operator %s's input %s should contain only one variable.", - op_.Type(), name); - return it->second.empty() ? nullptr : it->second[0]; +const Variable* ExecutionContext::LegacyInputVar( + const std::string& name) const { + auto ipt = op_.Input(name); + return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); } -Variable* ExecutionContext::FastOutputVar(const std::string& name) const { +Variable* ExecutionContext::OutputVar(const std::string& name) const { auto it = ctx_.outputs.find(name); if (it == ctx_.outputs.end()) return nullptr; @@ -466,15 +455,20 @@ Variable* ExecutionContext::FastOutputVar(const std::string& name) const { return it->second.empty() ? nullptr : it->second[0]; } +Variable* ExecutionContext::LegacyOutputVar(const std::string& name) const { + auto opt = op_.Output(name); + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); +} + template <> const Tensor* ExecutionContext::Input(const std::string& name) const { return Input(name); } template <> -const Tensor* ExecutionContext::FastInput( +const Tensor* ExecutionContext::LegacyInput( const std::string& name) const { - return FastInput(name); + return LegacyInput(name); } template <> @@ -502,8 +496,8 @@ Tensor* ExecutionContext::Output(const std::string& name) const { } template <> -Tensor* ExecutionContext::FastOutput(const std::string& name) const { - return FastOutput(name); +Tensor* ExecutionContext::LegacyOutput(const std::string& name) const { + return LegacyOutput(name); } template <> @@ -870,7 +864,6 @@ Scope* OperatorWithKernel::PrepareData( auto& var_name = var_name_item.second[i]; auto* var = scope.FindVar(var_name); input_vars[i] = var; - LOG(ERROR) << "second in " << var_name_item.first << ":" << var_name; // Only tensor can be tranfer to another device. if (var == nullptr || !VarIsTensor(*var)) { @@ -931,7 +924,6 @@ Scope* OperatorWithKernel::PrepareData( for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; output_vars[i] = scope.FindVar(var_name); - LOG(ERROR) << "second out " << var_name_item.first << ":" << var_name; } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 0aad91dbee..39190d07b4 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -233,20 +233,20 @@ class ExecutionContext { } template - const T* FastInput(const std::string& name) const { - auto* var = FastInputVar(name); + const T* LegacyInput(const std::string& name) const { + auto* var = LegacyInputVar(name); return var == nullptr ? nullptr : &var->Get(); } template - T* FastOutput(const std::string& name) const { - auto var = FastOutputVar(name); + T* LegacyOutput(const std::string& name) const { + auto var = LegacyOutputVar(name); return var == nullptr ? 
nullptr : var->GetMutable(); } - const Variable* FastInputVar(const std::string& name) const; + const Variable* LegacyInputVar(const std::string& name) const; - Variable* FastOutputVar(const std::string& name) const; + Variable* LegacyOutputVar(const std::string& name) const; template const std::vector MultiInput(const std::string& name) const { @@ -314,7 +314,7 @@ template <> const Tensor* ExecutionContext::Input(const std::string& name) const; template <> -const Tensor* ExecutionContext::FastInput( +const Tensor* ExecutionContext::LegacyInput( const std::string& name) const; template <> @@ -325,7 +325,7 @@ template <> Tensor* ExecutionContext::Output(const std::string& name) const; template <> -Tensor* ExecutionContext::FastOutput(const std::string& name) const; +Tensor* ExecutionContext::LegacyOutput(const std::string& name) const; template <> std::vector ExecutionContext::MultiOutput( diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index b6155ed3dd..62c55c4f55 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,7 +56,7 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.FastInput("X")->type(), + return framework::OpKernelType(ctx.Input("X")->type(), ctx.device_context()); } }; From f897bd16c0e4deb683075e137e7bfe5890488205 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 15:40:23 +0800 Subject: [PATCH 22/29] clean test=develop --- paddle/fluid/framework/operator.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7d5a6198a0..8c83748668 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -812,6 +812,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); this->InferShape(&infer_shape_ctx); + // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext + // not Scope. Imperative mode only pass inputs and get outputs. 
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx)); if (!transfered_inplace_vars.empty()) { @@ -919,13 +921,6 @@ Scope* OperatorWithKernel::PrepareData( SetTensorToVariable(*var, out, trans_var); } } - for (auto& var_name_item : Outputs()) { - std::vector& output_vars = ctx->outputs[var_name_item.first]; - for (size_t i = 0; i < var_name_item.second.size(); ++i) { - auto& var_name = var_name_item.second[i]; - output_vars[i] = scope.FindVar(var_name); - } - } return new_scope; } From 19ebd8b4cfffa2ba42c68fa4c761c54e857c6566 Mon Sep 17 00:00:00 2001 From: peizhilin Date: Tue, 18 Dec 2018 20:20:19 +0800 Subject: [PATCH 23/29] add ctc support for windows --- CMakeLists.txt | 4 ++-- cmake/external/warpctc.cmake | 30 ++++++++++++++++++++++----- cmake/operators.cmake | 2 +- paddle/fluid/operators/CMakeLists.txt | 4 +--- paddle/fluid/platform/port.h | 1 - python/paddle/fluid/__init__.py | 10 +++++++-- python/paddle/fluid/framework.py | 18 +++++++++++----- python/setup.py.in | 9 ++++---- 8 files changed, 55 insertions(+), 23 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cb646d3ce5..c31f51a3f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -208,10 +208,10 @@ include(external/xxhash) # download xxhash include(external/dlpack) include(external/snappy) # download snappy include(external/snappystream) # download snappystream +include(external/warpctc) # download, build, install warpctc if (NOT WIN32) -# there is no official support of warpctc, nccl, cupti in windows -include(external/warpctc) # download, build, install warpctc +# there is no official support of nccl, cupti in windows include(cupti) include(external/gzstream) endif (NOT WIN32) diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index 07e1137e16..7b937c93fe 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -26,25 +26,33 @@ SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" # Used in unit test test_WarpCTCLayer SET(WARPCTC_LIB_DIR "${WARPCTC_INSTALL_DIR}/lib" CACHE PATH "Warp-ctc Library Directory" FORCE) -SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" - CACHE FILEPATH "Warp-ctc Library" FORCE) -IF(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" ) +IF(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR WIN32) SET(USE_OMP OFF) ELSE() SET(USE_OMP ON) ENDIF() +IF(WIN32) + SET(WARPCTC_REPOSITORY "https://github.com/wopeizl/warp-ctc.git") +ELSE() + SET(WARPCTC_REPOSITORY "https://github.com/dzhwinter/warp-ctc.git") +ENDIF() + ExternalProject_Add( extern_warpctc ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY "https://github.com/dzhwinter/warp-ctc.git" + GIT_REPOSITORY ${WARPCTC_REPOSITORY} PREFIX ${WARPCTC_SOURCES_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} + -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR} -DWITH_GPU=${WITH_GPU} -DWITH_OMP=${USE_OMP} @@ -59,6 +67,18 @@ ExternalProject_Add( -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR} ) +IF(WIN32) + IF(NOT EXISTS 
"${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}") + add_custom_command(TARGET extern_warpctc POST_BUILD + COMMAND cmake -E copy ${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} ${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} + ) + ENDIF() + SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "Warp-ctc Library" FORCE) +else(WIN32) + SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "Warp-ctc Library" FORCE) +ENDIF(WIN32) MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}") INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its headers. diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 2ced43f9e6..70d159b4f3 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -84,7 +84,7 @@ function(op_library TARGET) endif() if (WIN32) # remove windows unsupported op, because windows has no nccl, no warpctc such ops. - foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op") + foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op") if ("${TARGET}" STREQUAL "${windows_unsupport_op}") return() endif() diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 257bfc0a3f..d9b0c66e57 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -64,9 +64,7 @@ endif() set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) -if (NOT WIN32) - set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) -endif() +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) if (WITH_GPU) diff --git a/paddle/fluid/platform/port.h b/paddle/fluid/platform/port.h index ad070171df..c1b81159ac 100644 --- a/paddle/fluid/platform/port.h +++ b/paddle/fluid/platform/port.h @@ -55,7 +55,6 @@ static void *dlsym(void *handle, const char *symbol_name) { static void *dlopen(const char *filename, int flag) { std::string file_name(filename); - file_name.replace(0, file_name.size() - 1, '/', '\\'); HMODULE hModule = LoadLibrary(file_name.c_str()); if (!hModule) { throw std::runtime_error(file_name + " not found."); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index b00510d443..8f3660ca38 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -102,6 +102,13 @@ def __bootstrap__(): import sys import os import platform + + if os.name == 'nt': + third_lib_path = os.path.abspath(os.path.dirname( + __file__)) + os.sep + '..' + os.sep + 'libs' + os.environ['path'] += ';' + third_lib_path + sys.path.append(third_lib_path) + from . 
import core in_test = 'unittest' in sys.modules @@ -128,13 +135,12 @@ def __bootstrap__(): 'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb', 'fast_eager_deletion_mode', 'allocator_strategy', 'reader_queue_speed_test_mode', - 'print_sub_graph_dir', 'pe_profile_fname' + 'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir' ] if 'Darwin' not in sysstr: read_env_flags.append('use_pinned_memory') if os.name != 'nt': - read_env_flags.append('warpctc_dir') read_env_flags.append('cpu_deterministic') if core.is_compiled_with_dist(): diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d0bd78454d..b5d603d478 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -16,6 +16,7 @@ from __future__ import print_function import collections import contextlib +import os import re import six import sys @@ -27,11 +28,18 @@ from .proto import framework_pb2 try: from . import core except ImportError as e: - raise ImportError( - """NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\" - if you encounters \"libmkldnn.so not found\" errors. If you have python - installed in other directory, replace \"/usr/local/lib\" with your own - directory. The original error is: \n""" + cpt.get_exception_message(e)) + if os.name == 'nt': + raise ImportError( + """NOTE: You may need to run \"set PATH=c:\python27\lib:%PATH%\" + if you encounters \"mkldnn.dll not found\" errors. If you have python + installed in other directory, replace \"c:\python27\lib" with your own + directory. The original error is: \n""" + cpt.get_exception_message(e)) + else: + raise ImportError( + """NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\" + if you encounters \"libmkldnn.so not found\" errors. If you have python + installed in other directory, replace \"/usr/local/lib\" with your own + directory. The original error is: \n""" + cpt.get_exception_message(e)) except Exception as e: raise e from . 
import unique_name
diff --git a/python/setup.py.in b/python/setup.py.in index cf8f28bd25..fefe8fbaa7 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -160,10 +160,11 @@ if '${WITH_FLUID_ONLY}'== 'OFF': # put all thirdparty libraries in paddle.libs libs_path='${PADDLE_BINARY_DIR}/python/paddle/libs' -if os.name != 'nt': - package_data['paddle.libs']= [] - package_data['paddle.libs']=['libwarpctc' + ext_name] - shutil.copy('${WARPCTC_LIBRARIES}', libs_path) + +package_data['paddle.libs']= [] +package_data['paddle.libs']=['libwarpctc' + ext_name] +shutil.copy('${WARPCTC_LIBRARIES}', libs_path) + if '${WITH_MKL}' == 'ON': shutil.copy('${MKLML_LIB}', libs_path) shutil.copy('${MKLML_IOMP_LIB}', libs_path)
From ed5bd5e58639bfe8e584f4acdce2398701b12853 Mon Sep 17 00:00:00 2001 From: peizhilin Date: Tue, 18 Dec 2018 20:23:24 +0800 Subject: [PATCH 24/29] test=develop --- paddle/fluid/platform/dynload/CMakeLists.txt | 2 -- paddle/fluid/platform/dynload/cudnn.h | 2 +- paddle/fluid/platform/dynload/dynamic_loader.cc | 2 ++ paddle/fluid/platform/dynload/dynamic_loader.h | 6 ++++++ paddle/fluid/platform/dynload/mklml.h | 2 +- paddle/fluid/platform/dynload/tensorrt.h | 2 +- paddle/fluid/platform/dynload/warpctc.h | 2 +- 7 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 5939c500c9..07159d4a12 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -16,9 +16,7 @@ if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) endif(CUPTI_FOUND) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) -if (NOT WIN32) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) -endif(NOT WIN32) if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) endif()
diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 550fe2edee..2f4f8101e4 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); #define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ struct DynLoad__##__name { \ template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ using cudnn_func = decltype(&::__name); \ std::call_once(cudnn_dso_flag, []() { \ cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index cc5cda6106..eddebfe92a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -201,6 +201,8 @@ void* GetCurandDsoHandle() { void* GetWarpCTCDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib"); +#elif defined(_WIN32) + return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "warpctc.dll"); #else return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so"); #endif
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 84fd2ce998..edb4c649ad 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -18,6 +18,12 @@ namespace paddle { namespace platform { namespace dynload { +#ifndef _WIN32 +#define DECLARE_TYPE(__name, ...) decltype(__name(__VA_ARGS__)) +#else +#define DECLARE_TYPE(__name, ...) decltype(auto) +#endif +
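// (Editorial sketch, not patch content: on the non-Windows branch
// DECLARE_TYPE(__name, args...) expands to decltype(__name(args...)), the
// exact trailing return type the wrappers used before; on _WIN32 it
// degrades to C++14 decltype(auto), presumably because the MSVC toolchain
// of the time could not expand a parameter pack inside decltype in a
// trailing return type. Each dynload wrapper then reads:
//
//   template <typename... Args>
//   auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {
//     /* resolve __name from the DSO once, then forward the call */
//   }
//
// which is the pattern applied to cudnn.h above and to mklml.h,
// tensorrt.h, and warpctc.h below.)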
void* GetCublasDsoHandle(); void* GetCUDNNDsoHandle(); void* GetCUPTIDsoHandle();
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index c3f9433503..d0619293ac 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -34,7 +34,7 @@ extern void* mklml_dso_handle; #define DYNAMIC_LOAD_MKLML_WRAP(__name) \ struct DynLoad__##__name { \ template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ using mklmlFunc = decltype(&::__name); \ std::call_once(mklml_dso_flag, []() { \ mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h index 5d67658b94..751aa54b1a 100644 --- a/paddle/fluid/platform/dynload/tensorrt.h +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -33,7 +33,7 @@ extern void* tensorrt_dso_handle; #define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ struct DynLoad__##__name { \ template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ using tensorrt_func = decltype(__name(args...)) (*)(Args...); \ std::call_once(tensorrt_dso_flag, []() { \ tensorrt_dso_handle = \
diff --git a/paddle/fluid/platform/dynload/warpctc.h b/paddle/fluid/platform/dynload/warpctc.h index 18ed9956f1..bc1977b05d 100644 --- a/paddle/fluid/platform/dynload/warpctc.h +++ b/paddle/fluid/platform/dynload/warpctc.h @@ -34,7 +34,7 @@ extern void* warpctc_dso_handle; #define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \ struct DynLoad__##__name { \ template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ using warpctcFunc = decltype(&::__name); \ std::call_once(warpctc_dso_flag, []() { \ warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
From b73d7d2f21a4010d10b1a2456e5991d77ed5e01e Mon Sep 17 00:00:00 2001 From: peizhilin Date: Tue, 18 Dec 2018 20:27:14 +0800 Subject: [PATCH 25/29] test=develop --- python/setup.py.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/python/setup.py.in b/python/setup.py.in index fefe8fbaa7..22b9537a90 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -162,7 +162,7 @@ if '${WITH_FLUID_ONLY}'== 'OFF': libs_path='${PADDLE_BINARY_DIR}/python/paddle/libs' package_data['paddle.libs']= [] -package_data['paddle.libs']=['libwarpctc' + ext_name] +package_data['paddle.libs']=[('libwarpctc' if os.name != 'nt' else 'warpctc') + ext_name] shutil.copy('${WARPCTC_LIBRARIES}', libs_path) if '${WITH_MKL}' == 'ON':
From aa6e9c30becf0215870fd3633684c97a6d614263 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Wed, 19 Dec 2018 03:54:05 +0100 Subject: [PATCH 26/29] [MKL-DNN] Added transpose/transpose2 Op (#14872)
* - Added transpose MKLDNN Op
- Few basic UT works
- Added 1D transpose
- implementing generic mem desc for MKLDNN transpose
- Modified transpose op to support more dimensional data eg.
5,6..10 - Added is_test attribute to transpose op test=develop * - Added support for MKLDNN::memory::format::any for Transpose MKLDNN op test=develop * - Additional transpose mkldnn op correction to mkldnn layout test=develop * Cosmetic fixes test=develop * - Removed const_cast to obey coding standard test=develop --- paddle/fluid/operators/transpose_mkldnn_op.cc | 124 ++++++++++++++++++ paddle/fluid/operators/transpose_op.cc | 49 ++++++- .../unittests/test_transpose_mkldnn_op.py | 76 +++++++++++ .../tests/unittests/test_transpose_op.py | 13 +- 4 files changed, 258 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/transpose_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py diff --git a/paddle/fluid/operators/transpose_mkldnn_op.cc b/paddle/fluid/operators/transpose_mkldnn_op.cc new file mode 100644 index 0000000000..37f1cadc7d --- /dev/null +++ b/paddle/fluid/operators/transpose_mkldnn_op.cc @@ -0,0 +1,124 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/framework/data_layout_transform.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using framework::DataLayout; + +template +class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + const bool is_test = ctx.Attr("is_test"); + PADDLE_ENFORCE( + is_test == true, + "ConvTransposeMKLDNN works only for inference!. Set is_test = True"); + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + std::vector axis = ctx.Attr>("axis"); + int ndims = axis.size(); + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + const T* input_data = input->data(); + + if (ndims == 1) { + output->ShareDataWith(*input); + return; + } + + std::vector nchw_axis(ndims, 0); + for (size_t i = 0; i < nchw_axis.size(); ++i) { + nchw_axis[i] = i; + } + + std::vector nchw_tz = paddle::framework::vectorize2int(input->dims()); + std::string data_format = ctx.Attr("data_format"); + + auto src_md = + input->format() != mkldnn::memory::format::nchw + ? 
+    auto src_md =
+        input->format() != mkldnn::memory::format::nchw
+            ? platform::MKLDNNMemDesc(nchw_tz, platform::MKLDNNGetDataType<T>(),
+                                      input->format())
+            : Axis2MemoryDesc(nchw_tz, nchw_axis);
+
+    this->TransposeKernel(ctx.GetPlace(), Axis2MemoryDesc(nchw_tz, axis),
+                          src_md, output, input_data, nchw_tz, mkldnn_engine);
+  }
+
+ protected:
+  mkldnn::memory::desc Axis2MemoryDesc(std::vector<int>& nchw_tz,
+                                       std::vector<int>& axis) const {
+    mkldnn_memory_desc_t mem_fmt;
+
+    mem_fmt.primitive_kind = mkldnn_memory;
+    mem_fmt.ndims = axis.size();
+    for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
+      mem_fmt.dims[i] = nchw_tz[i];  // logical dimensions (nchw format,
+                                     // regardless of physical layout)
+    }
+    mem_fmt.data_type = mkldnn_f32;
+    mem_fmt.format = mkldnn_blocked;
+
+    unsigned int total_stride = 1;
+    for (int i = nchw_tz.size() - 1; i >= 0; --i) {
+      mem_fmt.layout_desc.blocking.padding_dims[i] =
+          nchw_tz[i];  // logical dimensions (nchw format, regardless of
+                       // physical layout)
+      mem_fmt.layout_desc.blocking.block_dims[i] = 1;
+      mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0;  // no offset
+      mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
+      mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
+      total_stride *= nchw_tz[axis[i]];
+    }
+    mem_fmt.layout_desc.blocking.offset_padding = 0;  // no initial offset
+    return mem_fmt;
+  }
+
+  void TransposeKernel(platform::Place place, mkldnn::memory::desc md_o,
+                       mkldnn::memory::desc md_i, Tensor* output,
+                       const T* data_i, std::vector<int>& nchw_dims,
+                       const mkldnn::engine& eng) const {
+    // Make memory primitive descriptors
+    auto mpd_o = mkldnn::memory::primitive_desc(md_o, eng);
+    auto mpd_i = mkldnn::memory::primitive_desc(md_i, eng);
+
+    auto data_o = output->mutable_data<T>(
+        place, paddle::memory::Allocator::kDefault, mpd_o.get_size());
+
+    auto src = mkldnn::memory(mpd_i, (T*)(data_i));
+    auto dst = mkldnn::memory(mpd_o, data_o);
+
+    auto r = mkldnn::reorder(src, dst);
+    mkldnn::stream(mkldnn::stream::kind::eager).submit({r}).wait();
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::TransposeMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::TransposeMKLDNNOpKernel<float>);
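A note on the kernel above: Axis2MemoryDesc never moves logical dimensions. It describes the destination as an mkldnn_blocked memory whose strides are accumulated over the permuted dims (strides[0][axis[i]] = total_stride), so the single mkldnn::reorder in TransposeKernel is what materializes the transposed data. A standalone sketch of that stride trick with no MKL-DNN dependency follows; PermutedStrides and Reorder are illustrative helpers, not the patch's API:

#include <cassert>
#include <cstddef>
#include <vector>

// Mirror the loop in Axis2MemoryDesc: walk the permuted dims from last to
// first, assigning each logical dimension the stride it has in the
// transposed layout and accumulating the running stride.
std::vector<int> PermutedStrides(const std::vector<int>& dims,
                                 const std::vector<int>& axis) {
  std::vector<int> strides(dims.size(), 0);
  int total_stride = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[axis[i]] = total_stride;
    total_stride *= dims[axis[i]];
  }
  return strides;
}

// Naive "reorder": copy src (dense, row-major in the logical dims) into dst
// laid out with the permuted strides. dst then reads back, sequentially, as
// the transposed tensor.
void Reorder(const std::vector<float>& src, std::vector<float>* dst,
             const std::vector<int>& dims, const std::vector<int>& strides) {
  std::vector<int> idx(dims.size(), 0);
  for (std::size_t linear = 0; linear < src.size(); ++linear) {
    int offset = 0;
    for (std::size_t d = 0; d < dims.size(); ++d) offset += idx[d] * strides[d];
    (*dst)[offset] = src[linear];
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
      if (++idx[d] < dims[d]) break;  // odometer-style index increment
      idx[d] = 0;
    }
  }
}

int main() {
  std::vector<int> dims = {2, 3};  // logical 2x3
  std::vector<int> axis = {1, 0};  // plain matrix transpose
  std::vector<float> src = {0, 1, 2, 3, 4, 5};
  std::vector<float> dst(src.size());
  Reorder(src, &dst, dims, PermutedStrides(dims, axis));
  // Transposed 3x2 result, read row-major: {0, 3, 1, 4, 2, 5}.
  assert(dst[1] == 3 && dst[2] == 1 && dst[5] == 5);
  return 0;
}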
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index bc1f59bc1a..b3b379d16f 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -16,6 +16,10 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
@@ -53,11 +57,32 @@ class TransposeOp : public framework::OperatorWithKernel {
     }
     ctx->SetOutputDim("Out", out_dims);
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+      layout_ = framework::DataLayout::kMKLDNN;
+    }
+#endif
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.GetPlace(), layout_, library_);
+  }
 };
 
 class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddInput(
         "X",
         "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
@@ -67,6 +92,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
         "(vector<int>) A list of values, and the size of the list should be "
         "the same with the input tensor rank. This operator permutes the input "
         "tensor's axes according to the values given.");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default \"AnyLayout\") Only used in MKLDNN kernels. "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Specify the data format of the output data; "
+        "the input will be transformed automatically.")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Transpose Operator.
 
@@ -144,8 +179,18 @@ class Transpose2Op : public TransposeOp {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+      layout_ = framework::DataLayout::kMKLDNN;
+    }
+#endif
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.GetPlace(), layout_, library_);
   }
 };
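Both GetExpectedKernelType overrides added above follow the same two-step selection: start from the plain library and the layout derived from data_format, then upgrade to the MKLDNN kernel only when support was compiled in and CanMKLDNNBeUsed(ctx) agrees at runtime. A stripped-down sketch of that decision, with KernelKind and the two booleans as hypothetical stand-ins for the framework types:

#include <iostream>

enum class KernelKind { kPlain, kMKLDNN };

// Mirrors the #ifdef PADDLE_WITH_MKLDNN block: the library is only upgraded
// from kPlain when support is compiled in and the runtime check passes.
KernelKind SelectKernel(bool built_with_mkldnn, bool can_use_mkldnn) {
  KernelKind library = KernelKind::kPlain;
  if (built_with_mkldnn && can_use_mkldnn) {
    library = KernelKind::kMKLDNN;
  }
  return library;
}

int main() {
  std::cout << (SelectKernel(true, true) == KernelKind::kMKLDNN) << "\n";   // 1
  std::cout << (SelectKernel(true, false) == KernelKind::kMKLDNN) << "\n";  // 0
  std::cout << (SelectKernel(false, true) == KernelKind::kMKLDNN) << "\n";  // 0
  return 0;
}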
diff --git a/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py
new file mode 100644
index 0000000000..61ac879011
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+from test_transpose_op import TestTransposeOp
+
+
+class TestTransposeMKLDNN(TestTransposeOp):
+    def init_op_type(self):
+        self.op_type = "transpose2"
+        self.use_mkldnn = True
+        self.is_test = True
+        return
+
+    def test_check_grad(self):
+        return
+
+    def test_check_grad_no_input(self):
+        return
+
+    def test_check_grad_no_filter(self):
+        return
+
+
+class TestCase0MKLDNN(TestTransposeMKLDNN):
+    def initTestCase(self):
+        self.shape = (3, )
+        self.axis = (0, )
+
+
+class TestCase1a(TestTransposeMKLDNN):
+    def initTestCase(self):
+        self.shape = (3, 4, 5)
+        self.axis = (0, 2, 1)
+
+
+class TestCase1b(TestTransposeMKLDNN):
+    def initTestCase(self):
+        self.shape = (3, 4, 5)
+        self.axis = (2, 1, 0)
+
+
+class TestCase2(TestTransposeMKLDNN):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5)
+        self.axis = (0, 2, 3, 1)
+
+
+class TestCase3(TestTransposeMKLDNN):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.axis = (4, 2, 3, 1, 0)
+
+
+class TestCase4(TestTransposeMKLDNN):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6, 1)
+        self.axis = (4, 2, 3, 1, 0, 5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py
index bbcabb751f..93be9d28da 100644
--- a/python/paddle/fluid/tests/unittests/test_transpose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py
@@ -21,15 +21,24 @@ from op_test import OpTest
 
 class TestTransposeOp(OpTest):
     def setUp(self):
+        self.init_op_type()
         self.initTestCase()
-        self.op_type = "transpose2"
         self.inputs = {'X': np.random.random(self.shape).astype("float32")}
-        self.attrs = {'axis': list(self.axis)}
+        self.attrs = {
+            'axis': list(self.axis),
+            'use_mkldnn': self.use_mkldnn,
+            'is_test': self.is_test,
+        }
         self.outputs = {
             'XShape': np.random.random(self.shape).astype("float32"),
             'Out': self.inputs['X'].transpose(self.axis)
         }
 
+    def init_op_type(self):
+        self.op_type = "transpose2"
+        self.use_mkldnn = False
+        self.is_test = False
+
     def test_check_output(self):
         self.check_output(no_check_set=['XShape'])
 

From b849157e9d3584a8d4b891340706c181c542deb0 Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Wed, 19 Dec 2018 11:44:48 +0800
Subject: [PATCH 27/29] Add size enforce (#14919)

---
 .../distributed/brpc_sendrecvop_utils.cc      | 23 ++++++++++++++-----
 .../fluid/operators/distributed/grpc_serde.cc |  8 +++++++
 .../operators/distributed/sendrecvop_utils.h  |  9 ++++++--
 .../distributed/variable_response.cc          |  2 +-
 4 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc b/paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
index 6fed9ba92c..e4604db3a3 100644
--- a/paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
+++ b/paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <nccl.h>
 #endif
 #include <sys/time.h>
+#include <limits>
 #include <thread>  // NOLINT
 
 #include "paddle/fluid/framework/data_type.h"
@@ -31,7 +32,12 @@ namespace distributed {
 
 class IOBufWriter {
  public:
-  static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
+  static void Append(const std::string& varname, butil::IOBuf* iobuf, int k,
+                     const char* v, int64_t vlen) {
+    if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
+      LOG(FATAL) << "Append varname:" << varname << ", vlen:" << vlen;
+    }
+
     iobuf->append(reinterpret_cast<char*>(&k), 4);
     iobuf->append(reinterpret_cast<char*>(&vlen), 8);
     iobuf->append(v, vlen);
@@ -87,6 +93,10 @@ class IOBufWriter {
                               int k, const char* v, int64_t vlen,
                               bool in_cuda_pinned, void (*destroy)(void*),
                               void* user_data) {
+    if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
+      LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
+    }
+
 #ifdef PADDLE_WITH_BRPC_RDMA
     IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
                                     destroy, user_data);
@@ -134,7 +144,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
     request->set_type(::sendrecv::NCCL_ID);
     const ncclUniqueId& uid = var->Get<ncclUniqueId>();
     // TODO(gongwb): use append_zero to avoid data copy.
-    IOBufWriter::Append(iobuf,
+    IOBufWriter::Append(name, iobuf,
                         sendrecv::VariableMessage::kSerializedFieldNumber,
                         uid.internal, NCCL_UNIQUE_ID_BYTES);
     return;
@@ -149,7 +159,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
   // FIXME(gongwb): it seems that we can use zero copy.
   if (var_is_not_stable) {
     IOBufWriter::Append(
-        iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
+        name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
         static_cast<const char*>(payload->ptr()), payload->memory_size());
   } else {
     if (platform::is_gpu_place(ctx.GetPlace())) {
@@ -171,10 +181,11 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
 
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
-    size_t rows_memory_size =
-        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
+    PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
+    size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
 
-    IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
+    IOBufWriter::Append(name, iobuf,
+                        ::sendrecv::VariableMessage::kRowsFieldNumber,
                         reinterpret_cast<const char*>(slr->rows().data()),
                         static_cast<int64_t>(rows_memory_size));
   }
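The guards added in this patch exist because vlen travels as an int64_t while the underlying RPC buffer APIs index with int: a length at or above std::numeric_limits<int>::max(), or a negative one, would otherwise be narrowed silently on the wire, so LOG(FATAL) converts the condition into an immediate, attributable crash. A small standalone sketch of the same narrowing check; CheckedNarrow is an illustrative helper, not part of the patch:

#include <cstdint>
#include <iostream>
#include <limits>

// Narrow an int64_t length to int, refusing values the target type cannot
// represent instead of silently wrapping around.
bool CheckedNarrow(int64_t vlen, int* out) {
  if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
    return false;  // the serializers LOG(FATAL) at this point instead
  }
  *out = static_cast<int>(vlen);
  return true;
}

int main() {
  int narrowed = 0;
  std::cout << CheckedNarrow(4096, &narrowed) << "\n";              // 1
  std::cout << CheckedNarrow(int64_t{1} << 33, &narrowed) << "\n";  // 0
  std::cout << CheckedNarrow(-1, &narrowed) << "\n";                // 0
  return 0;
}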
diff --git a/paddle/fluid/operators/distributed/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc_serde.cc
index 299dfe3543..a9dea9cfd2 100644
--- a/paddle/fluid/operators/distributed/grpc_serde.cc
+++ b/paddle/fluid/operators/distributed/grpc_serde.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include <nccl.h>
 #endif
+#include <limits>
 #include <thread>  // NOLINT
 
 #include "google/protobuf/io/coded_stream.h"
@@ -102,6 +103,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
     e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
                               payload->memory_size());
+    if (payload->memory_size() >= std::numeric_limits<int>::max()) {
+      LOG(FATAL) << "AppendZeroCopy varname:" << name
+                 << ", vlen:" << payload->memory_size();
+    }
     // steal reference of tensor data
     ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
     int num_slices = 2;       // only SelectedRows have rows buffer
@@ -115,7 +120,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
+
+    PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
     size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
+
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
     slices[2] = ::grpc::Slice(e2.size());
     memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.h b/paddle/fluid/operators/distributed/sendrecvop_utils.h
index 33eded0e6c..6a87178be5 100644
--- a/paddle/fluid/operators/distributed/sendrecvop_utils.h
+++ b/paddle/fluid/operators/distributed/sendrecvop_utils.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <iostream>
 #include <string>
+#include <typeindex>
 #include <vector>
 
 #include "paddle/fluid/framework/data_type.h"
@@ -23,9 +24,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/platform/port.h"
-
 #include "paddle/fluid/operators/distributed/send_recv.pb.h"
+#include "paddle/fluid/platform/port.h"
 
 namespace paddle {
 namespace operators {
@@ -83,6 +83,11 @@ inline framework::proto::VarType::Type ToVarType(
   }
 }
 
+template