Merge develop

test=develop
move-code
sneaxiy 6 years ago
commit 2f54d9f995

@@ -15,7 +15,7 @@ paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=N
paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd0c3ebd813c39958c92b78e3eef7e912'))
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'aba8093edebf2d5c869b735b92811e45'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'f482e93b38b4018796969a2e1dde479d'))
paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'e148d3ab1ed8edf3e928212a375959c0'))
paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', 'b94d1f6bcc29c4fb58fc0058561250c2'))
paddle.fluid.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
@@ -47,7 +47,7 @@ paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'f
paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', 'c8ac0dfcb3b187aba25d03af7fea56b2'))
paddle.fluid.AsyncExecutor.stop (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '5f23d043607bb5d55e466ec3f578e093'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', '5e8cca4619a5d7c3280fb3cae7021b14'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a8c7793803cf976680d9478e378fa356'))
paddle.fluid.CompiledProgram.with_inference_optimize (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None), ('document', '9e5b009d850191a010e859189c127fd8'))
paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None
paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.GradientScaleStrategy, arg0: int) -> None
@@ -286,7 +286,7 @@ paddle.fluid.layers.DynamicRNN.block (ArgSpec(args=['self'], varargs=None, keywo
paddle.fluid.layers.DynamicRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'value', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, False, 'float32')), ('document', 'b9174d4e91505b0c8ecc193eb51e248d'))
paddle.fluid.layers.DynamicRNN.output (ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None), ('document', 'b439a176a3328de8a75bdc5c08eece4a'))
paddle.fluid.layers.DynamicRNN.static_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', 'f29ad2478b6b2ad4f413d2936a331ea0'))
paddle.fluid.layers.DynamicRNN.step_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', '169d694d2224f62b4f3afdc3dbc19e95'))
paddle.fluid.layers.DynamicRNN.step_input (ArgSpec(args=['self', 'x', 'level'], varargs=None, keywords=None, defaults=(0,)), ('document', '7568c5ac7622a10288d3307a94134655'))
paddle.fluid.layers.DynamicRNN.update_memory (ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None), ('document', '5d83987da13b98363d6a807a52d8024f'))
paddle.fluid.layers.StaticRNN.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)), ('document', 'c24e368e23afac1ed91a78a639d7a9c7'))
@@ -305,7 +305,7 @@ paddle.fluid.layers.tanh (ArgSpec(args=['x', 'name'], varargs=None, keywords=Non
paddle.fluid.layers.atan (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3a46e0b5f9ce82348406478e610f14c9'))
paddle.fluid.layers.tanh_shrink (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1e521554b9fdda9061ec6d306f0709b7'))
paddle.fluid.layers.softshrink (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9eef31597bbafa2bd49691e072296e13'))
paddle.fluid.layers.sqrt (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '072a8541e0f632366bba10f67cb0db27'))
paddle.fluid.layers.sqrt (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e9e27491c39ac74d0b1ffe506aec0ebb'))
paddle.fluid.layers.abs (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '64650ac42cf82e9920cb0b172b1d29fd'))
paddle.fluid.layers.ceil (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c75d67dc5fe28f68e4cfffead4f698ad'))
paddle.fluid.layers.floor (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '647b16c5da5ef909649ae02abb434973'))
@@ -503,7 +503,7 @@ paddle.fluid.CUDAPinnedPlace.__init__ __init__(self: paddle.fluid.core.CUDAPinne
paddle.fluid.ParamAttr.__init__ (ArgSpec(args=['self', 'name', 'initializer', 'learning_rate', 'regularizer', 'trainable', 'gradient_clip', 'do_model_average'], varargs=None, keywords=None, defaults=(None, None, 1.0, None, True, None, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.WeightNormParamAttr.__init__ (ArgSpec(args=['self', 'dim', 'name', 'initializer', 'learning_rate', 'regularizer', 'trainable', 'gradient_clip', 'do_model_average'], varargs=None, keywords=None, defaults=(None, None, None, 1.0, None, True, None, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.DataFeeder.__init__ (ArgSpec(args=['self', 'feed_list', 'place', 'program'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.DataFeeder.decorate_reader (ArgSpec(args=['self', 'reader', 'multi_devices', 'num_places', 'drop_last'], varargs=None, keywords=None, defaults=(None, True)), ('document', '0eed2f198dc73c08a41b61edbc755753'))
paddle.fluid.DataFeeder.decorate_reader (ArgSpec(args=['self', 'reader', 'multi_devices', 'num_places', 'drop_last'], varargs=None, keywords=None, defaults=(None, True)), ('document', 'f8f3df23c5633c614db781a91b81fb62'))
paddle.fluid.DataFeeder.feed (ArgSpec(args=['self', 'iterable'], varargs=None, keywords=None, defaults=None), ('document', '459e316301279dfd82001b46f0b8ffca'))
paddle.fluid.DataFeeder.feed_parallel (ArgSpec(args=['self', 'iterable', 'num_places'], varargs=None, keywords=None, defaults=(None,)), ('document', '543863d1f9d4853758adb613b8659e85'))
paddle.fluid.clip.ErrorClipByValue.__init__ (ArgSpec(args=['self', 'max', 'min'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
@@ -528,11 +528,11 @@ paddle.reader.compose (ArgSpec(args=[], varargs='readers', keywords='kwargs', de
paddle.reader.chain (ArgSpec(args=[], varargs='readers', keywords=None, defaults=None), ('document', 'd22c34e379a53901ae67a6bca7f4def4'))
paddle.reader.shuffle (ArgSpec(args=['reader', 'buf_size'], varargs=None, keywords=None, defaults=None), ('document', 'e42ea6fee23ce26b23cb142cd1d6522d'))
paddle.reader.firstn (ArgSpec(args=['reader', 'n'], varargs=None, keywords=None, defaults=None), ('document', 'c5bb8f7dd4f917f1569a368aab5b8aad'))
paddle.reader.xmap_readers (ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,)), ('document', '283bc0b8a0e26ae186b8b9bee4aec560'))
paddle.reader.xmap_readers (ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,)), ('document', '9c804a42f8a4dbaa76b3c98e0ab7f796'))
paddle.reader.PipeReader.__init__ (ArgSpec(args=['self', 'command', 'bufsize', 'file_type'], varargs=None, keywords=None, defaults=(8192, 'plain')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.reader.PipeReader.get_line (ArgSpec(args=['self', 'cut_lines', 'line_break'], varargs=None, keywords=None, defaults=(True, '\n')), ('document', '5f80a7ed70052f01665e4c74acccfa69'))
paddle.reader.PipeReader.get_line (ArgSpec(args=['self', 'cut_lines', 'line_break'], varargs=None, keywords=None, defaults=(True, '\n')), ('document', '9621ae612e595b6c34eb3bb5f3eb1a45'))
paddle.reader.multiprocess_reader (ArgSpec(args=['readers', 'use_pipe', 'queue_size'], varargs=None, keywords=None, defaults=(True, 1000)), ('document', '7d8b3a96e592107c893d5d51ce968ba0'))
paddle.reader.Fake.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.reader.creator.np_array (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '28d457fbc9a71efa4ac91a3be179cada'))
paddle.reader.creator.text_file (ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None), ('document', '44fe286ab6175a5464d3a961a68c266a'))
paddle.reader.creator.recordio (ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)), ('document', '11b3704ea42cfd537953387a7e58dae8'))
paddle.reader.creator.text_file (ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None), ('document', 'f45fcb7add066c8e042c6774fc7c3db2'))
paddle.reader.creator.recordio (ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)), ('document', 'b4a94ee0e2cefb495619275c2f8c61d2'))

@@ -9,6 +9,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
@@ -22,6 +23,8 @@ endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
@@ -35,6 +38,8 @@ if(WITH_GPU)
else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_rpc)
@@ -71,6 +76,8 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
@@ -98,5 +105,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass)
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass)

@@ -46,7 +46,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
// Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
if (strategy_.enable_sequential_execution_) {
VLOG(10) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
}
@@ -57,6 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
VLOG(10) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass");
}
@@ -68,29 +78,30 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add automatically inplace.
if (strategy_.enable_inplace_) {
VLOG(10) << "Add inplace_pass";
AppendPass("inplace_pass");
}
if (strategy.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// For single-card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be placed before MultiDevPass.
if (strategy.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}
// Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
// Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path",
new std::string(graph_path));
}
}
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
@@ -108,11 +119,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side effect of this is that memory_optimize cannot foresee the fetched
// vars, so the fetch list should be set persistable before calling the Run
// interface.
if (strategy.memory_optimize_) {
auto memory_optimize_pass = AppendPass("memory_optimize_pass");
VLOG(10) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
AppendMultiDevPass(strategy);
if (strategy.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops first counts the number of all_reduce
// operators; if the number is zero, fuse_all_reduce_op_pass does nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass");
}
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
@@ -129,27 +148,29 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("multi_devices_check_pass");
if (SeqOnlyAllReduceOps(strategy)) {
VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
if (strategy_.remove_unnecessary_lock_) {
VLOG(10) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass");
}
}
// Convert the graph to run on multiple devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
ir::Pass *multi_devices_pass = nullptr;
if (strategy_.is_distribution_) {
VLOG(3) << "multi device parameter server mode";
VLOG(10) << "Add dist_multi_devices_pass";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(3) << "multi devices collective mode with allreduce";
VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(3) << "multi deivces collective mode with reduce";
VLOG(10) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
@@ -206,9 +227,26 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
} else if (pass->Type() == "sequential_execution_pass") {
LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_;
@@ -239,7 +277,7 @@ USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
@@ -249,4 +287,6 @@ USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_all_reduce_op_pass);
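For reference, the Erase/SetNotOwned pairs added to BuildStrategy::Apply above are the standard way runtime objects reach a pass: the pass keeps a raw, non-owning pointer under a string key and reads it back later with Get<>. A minimal sketch of wiring and running one of the new passes the same way (the surrounding declarations of places, local_scopes and graph are assumed):

// Sketch only; mirrors the attribute wiring in BuildStrategy::Apply.
// SetNotOwned stores a raw pointer, so `places` and `local_scopes` must
// outlive the pass.
auto pass = ir::PassRegistry::Instance().Get("fuse_all_reduce_op_pass");
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes, &local_scopes);
graph = pass->Apply(std::move(graph));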

@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
@@ -75,6 +76,8 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool fuse_all_reduce_ops_{false};
bool fuse_relu_depthwise_conv_{false};
bool sync_batch_norm_{false};

@@ -0,0 +1,195 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
class FuseAllReduceOpPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs);
#endif
std::unordered_set<std::string> grads;
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
size_t num_of_all_reduce = params_grads.size();
grads.reserve(num_of_all_reduce);
for (auto p_g : params_grads) {
grads.insert(p_g.second);
}
size_t num_place = places.size();
std::unordered_map<std::string, ir::Node *> all_reduce_ops;
all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) {
if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>());
auto *all_reduce_op_handle =
dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>());
if (all_reduce_op_handle) {
auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The input names should all be the same.
auto &grad_name = inputs[0]->name();
for (size_t i = 1; i < inputs.size(); ++i) {
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
"The input name should be the same.");
}
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
all_reduce_ops.emplace(grad_name, node);
}
}
}
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
if (all_reduce_ops.size() == 0) {
return std::move(graph);
}
PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
"The number of all_reduce OpHandle is not equal to the "
"number of grads. Maybe some gradients are sparse type, "
"it is not supported currently.");
VLOG(10) << "Insert fused_all_reduce";
auto &group_grads_params =
graph->Get<GroupGradsAndParams>(kGroupGradsAndParams);
for (auto &group_g_p : group_grads_params) {
size_t group_size = group_g_p.size();
PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0));
std::vector<ir::Node *> group_all_reduce_ops;
group_all_reduce_ops.reserve(group_size);
for (auto &g_p : group_g_p) {
group_all_reduce_ops.emplace_back(all_reduce_ops.at(g_p.first));
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
InsertFusedAllReduce(places, local_scopes, group_size,
group_all_reduce_ops, nccl_ctxs, &result);
#else
InsertFusedAllReduce(places, local_scopes, group_size,
group_all_reduce_ops, &result);
#endif
}
return std::move(graph);
}
void InsertFusedAllReduce(const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const size_t num_of_all_reduce,
const std::vector<ir::Node *> &all_reduce_ops,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
std::vector<VarHandleBase *> inputs;
std::vector<VarHandleBase *> outputs;
for (auto &op : all_reduce_ops) {
auto &op_handle = op->Wrapper<OpHandleBase>();
inputs.insert(inputs.end(), op_handle.Inputs().begin(),
op_handle.Inputs().end());
// Detach this op from the output lists of its input vars.
for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
[&op_handle](VarHandleBase *var_handle) {
var_handle->RemoveOutput(&op_handle, op_handle.Node());
});
outputs.insert(outputs.end(), op_handle.Outputs().begin(),
op_handle.Outputs().end());
// Clear the generated op of each output var.
for_each(
op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); });
result->RemoveNode(op_handle.Node());
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
local_scopes, nccl_ctxs, result);
#else
CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
local_scopes, result);
#endif
}
private:
void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs,
const std::vector<VarHandleBase *> &outputs,
const size_t num_of_all_reduce,
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *op_handle = new FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, nccl_ctxs);
#else
auto *op_handle = new FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce);
#endif
for (auto in : inputs) {
op_handle->AddInput(in);
}
for (auto out : outputs) {
op_handle->AddOutput(out);
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (!nccl_ctxs) {
SetCommunicationContext(places, op_handle);
}
#else
SetCommunicationContext(places, op_handle);
#endif
}
void SetCommunicationContext(const std::vector<platform::Place> &places,
FusedAllReduceOpHandle *op_handle) const {
for (size_t i = 0; i < places.size(); ++i) {
op_handle->SetDeviceContext(
places[i], platform::DeviceContextPool::Instance().Get(places[i]));
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_all_reduce_op_pass,
paddle::framework::details::FuseAllReduceOpPass);
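The pass above consumes two graph attributes that alloc_continuous_space_for_grad_pass is expected to produce beforehand: kParamsAndGrads, a flat list of (parameter, gradient) name pairs, and kGroupGradsAndParams, the same pairs bucketed into fusion groups with the order reversed to (gradient, parameter). A hypothetical illustration of their shapes, using the typedefs from multi_devices_helper.h later in this diff (the variable names and @GRAD suffix are illustrative):

// Made-up values; only the shapes matter.
ParamsAndGrads params_grads = {
    {"fc_0.w_0", "fc_0.w_0@GRAD"},
    {"fc_0.b_0", "fc_0.b_0@GRAD"},
};
// One inner vector per fused all_reduce group; pairs stored as (grad, param).
GroupGradsAndParams group_grads_params = {
    {{"fc_0.w_0@GRAD", "fc_0.w_0"}, {"fc_0.b_0@GRAD", "fc_0.b_0"}},
};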

@@ -0,0 +1,248 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool(skip_fused_all_reduce_check, false, "");
namespace paddle {
namespace framework {
namespace details {
typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
GradientAndLoDTensor;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
FusedAllReduceOpHandle::FusedAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
const platform::NCCLContextMap *ctxs)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
num_of_all_reduce_(num_of_all_reduce),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
#else
FusedAllReduceOpHandle::FusedAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const size_t num_of_all_reduce)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
num_of_all_reduce_(num_of_all_reduce) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
#endif
void FusedAllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
VLOG(4) << this->DebugString();
WaitInputVarGenerated();
// The input: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)...
// The output: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)...
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(
in_var_handles.size(), place_num * num_of_all_reduce_,
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
GradientAndLoDTensor grads_tensor;
grads_tensor.resize(place_num);
int64_t numel = -1;
auto dtype = static_cast<framework::proto::VarType::Type>(0);
for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
auto &g_tensor = grads_tensor.at(scope_idx);
g_tensor.reserve(num_of_all_reduce_);
GetGradLoDTensor(scope_idx, in_var_handles, out_var_handles, &g_tensor);
int64_t element_num = 0;
framework::proto::VarType::Type ele_dtype =
static_cast<framework::proto::VarType::Type>(0);
GetDTypeAndNumel(g_tensor, &ele_dtype, &element_num);
if (numel == -1) {
numel = element_num;
}
if (dtype == static_cast<framework::proto::VarType::Type>(0)) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype,
static_cast<framework::proto::VarType::Type>(0));
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype);
// Check whether the address space is contiguous.
std::sort(
g_tensor.begin(), g_tensor.end(),
[](const std::pair<std::string, const LoDTensor *> &grad1,
const std::pair<std::string, const LoDTensor *> &grad2) -> bool {
return grad1.second->data<void>() < grad2.second->data<void>();
});
for (size_t k = 1; k < g_tensor.size(); ++k) {
const void *pre_address = g_tensor.at(k - 1).second->data<void>();
int64_t len = g_tensor.at(k - 1).second->numel();
auto offset = len * framework::SizeOfType(dtype);
void *next_address = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(pre_address) + offset);
const void *cur_address = g_tensor.at(k).second->data<void>();
VLOG(10) << k << ", "
<< " pre_address(" << g_tensor.at(k - 1).first
<< "): " << pre_address << ", cur_address("
<< g_tensor.at(k).first << "): " << cur_address
<< ", offset:" << offset << ", " << next_address << ", "
<< cur_address;
PADDLE_ENFORCE_EQ(next_address, cur_address);
}
}
if (!FLAGS_skip_fused_all_reduce_check) {
for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
for (size_t j = 1; j < num_of_all_reduce_; ++j) {
PADDLE_ENFORCE_EQ(grads_tensor.at(0).at(j).first,
grads_tensor.at(scope_idx).at(j).first);
}
}
}
std::vector<const void *> lod_tensor_data;
for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
auto data = grads_tensor.at(scope_idx).at(0).second->data<void>();
lod_tensor_data.emplace_back(data);
}
if (platform::is_gpu_place(places_[0])) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int nccl_dtype = platform::ToNCCLDataType(dtype);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
void *buffer = const_cast<void *>(lod_tensor_data.at(i));
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(nccl_dtype),
ncclSum, comm, stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when NCCL is managed per thread per device.
all_reduce_calls[0]();
} else {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
} else {
// Specially handle the gradients of CPU-only operators, e.g. CRF.
auto grad_name = grads_tensor.at(0).at(0).first;
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(grad_name)
->GetMutable<framework::LoDTensor>();
// Reduce all data into trg on the CPU.
ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
VisitDataType(trg.type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i];
auto *var = scope.FindVar(grad_name);
auto *dev_ctx = dev_ctxes_.at(p);
size_t size = numel * SizeOfType(trg.type());
RunAndRecordEvent(p, [&trg, var, dev_ctx, p, size] {
auto dst_ptr = var->GetMutable<framework::LoDTensor>()->data<void>();
platform::CPUPlace cpu_place;
memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data<void>(), size);
});
}
}
}
void FusedAllReduceOpHandle::GetGradLoDTensor(
const size_t &scope_idx, const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles,
std::vector<std::pair<std::string, const LoDTensor *>> *grad_tensor) const {
auto *local_scope =
local_scopes_.at(scope_idx)->FindVar(kLocalExecScopeName)->Get<Scope *>();
size_t place_num = places_.size();
for (size_t j = 0; j < in_var_handles.size(); j += place_num) {
auto var_name = in_var_handles[j]->name();
PADDLE_ENFORCE_EQ(var_name, out_var_handles[j]->name());
auto &lod_tensor = local_scope->FindVar(var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx));
grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor));
}
}
void FusedAllReduceOpHandle::GetDTypeAndNumel(
const std::vector<std::pair<std::string, const LoDTensor *>> &grad_tensor,
proto::VarType::Type *dtype, int64_t *numel) const {
*numel = 0;
for (size_t i = 0; i < grad_tensor.size(); ++i) {
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
*numel += len;
// Get dtype
auto ele_type = grad_tensor.at(i).second->type();
if (i == 0) {
*dtype = ele_type;
}
PADDLE_ENFORCE_EQ(ele_type, *dtype);
}
}
std::string FusedAllReduceOpHandle::Name() const { return "fused_all_reduce"; }
} // namespace details
} // namespace framework
} // namespace paddle
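The address check in RunImpl reduces to pointer arithmetic: once the gradients are sorted by address, tensor k must start exactly where tensor k-1 ends. A self-contained sketch of the same invariant over plain buffers (buffer addresses, lengths, and the element size are assumed to be known):

#include <cstddef>
#include <cstdint>
#include <vector>

// True iff the buffers form one contiguous allocation, i.e. each buffer
// starts exactly at the end of the previous one.
bool IsContiguous(const std::vector<const void *> &addrs,
                  const std::vector<int64_t> &numels, size_t elem_size) {
  for (size_t k = 1; k < addrs.size(); ++k) {
    auto expected = reinterpret_cast<uintptr_t>(addrs[k - 1]) +
                    static_cast<uintptr_t>(numels[k - 1]) * elem_size;
    if (reinterpret_cast<uintptr_t>(addrs[k]) != expected) return false;
  }
  return true;
}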

@@ -0,0 +1,76 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
struct FusedAllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const size_t num_of_all_reduce,
const platform::NCCLContextMap *ctxs);
#else
FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const size_t num_of_all_reduce);
#endif
std::string Name() const override;
// Delaying and buffering nccl_all_reduce calls together can significantly
// improve performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return true; }
protected:
void RunImpl() override;
private:
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
size_t num_of_all_reduce_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs_;
#endif
// Get the dtype and the total element count of the input gradients.
void GetDTypeAndNumel(
const std::vector<std::pair<std::string, const LoDTensor *>> &g_tensor,
proto::VarType::Type *dtype, int64_t *total_num) const;
// Get each gradient's name and LoDTensor.
void GetGradLoDTensor(const size_t &scope_idx,
const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles,
std::vector<std::pair<std::string, const LoDTensor *>>
*grad_tensor) const;
};
} // namespace details
} // namespace framework
} // namespace paddle

@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& inputs = ctx->Input("X");
auto type = ctx->GetType(inputs.front());
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, type);
}
};
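This hunk shows the shape of the interface migration: var type inference callbacks no longer receive (const OpDesc&, BlockDesc*) but a single InferVarTypeContext* that mediates reads (Input, Output, GetType) and writes (SetType). A hypothetical further example in the new style, following the same pattern (the class is made up):

class PassThroughVarTypeInference : public VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    // Forward the type of the first input to every output variable.
    auto& inputs = ctx->Input("X");
    auto type = ctx->GetType(inputs.front());
    for (auto& out_name : ctx->Output("Out")) {
      ctx->SetType(out_name, type);
    }
  }
};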

@@ -11,18 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include <algorithm>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
@@ -134,21 +133,26 @@ void AddOutputToLeafOps(ir::Graph *graph) {
}
} // namespace
void MultiDevSSAGraphBuilderBase::CheckGraph(const ir::Graph &graph) const {}
void MultiDevSSAGraphBuilderBase::Init() const {
all_vars_.clear();
loss_var_name_ = Get<const std::string>(kLossVarName);
VLOG(10) << "Init MultiDevSSAGraphBuilder, loss name: " << loss_var_name_;
places_ = Get<const std::vector<platform::Place>>(kPlaces);
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
strategy_ = Get<const BuildStrategy>(kStrategy);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
nccl_ctxs_ = &Get<platform::NCCLContextMap>(kNCCLCtxs);
#endif
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
Init();
CheckGraph(*graph);
std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
auto nodes = graph->ReleaseNodes();
@@ -166,7 +170,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
result.Set(kGraphOps, new GraphOps);
bool is_forwarding = true;
bool insert_collection_ops = NeedCollectiveOps();
for (ir::Node *node : sorted_ops) {
if (DealWithSpecialOp(&result, node)) {
@@ -185,8 +188,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
CreateComputationalOps(&result, node, places_.size());
}
// Insert collection ops
if (!is_forwarding && insert_collection_ops) {
// Insert collective ops if nranks > 1
if (!is_forwarding && Get<size_t>(kNRanks) > 1) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
@@ -200,13 +203,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, p_name, g_name);
if (NeedCollectiveForGrad(g_name, sorted_ops)) {
InsertCollectiveOp(&result, p_name, g_name);
}
}
} catch (boost::bad_get e) {
}
@@ -226,6 +229,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
* Only variables should be the leaves of the graph.
*/
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph;
}
@@ -258,6 +262,11 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
}
}
bool MultiDevSSAGraphBuilderBase::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const {
return false;
}
std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
const ir::Graph &graph) const {
return ir::TopologySortOperations(graph);
@@ -271,8 +280,20 @@ bool MultiDevSSAGraphBuilderBase::UseGPU() const {
return use_gpu;
}
bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const {
return Get<size_t>(kNRanks) > 1;
bool MultiDevSSAGraphBuilderBase::NeedCollectiveForGrad(
const std::string &grad_name, std::vector<ir::Node *> ops) const {
// If the graph already contains an allreduce op for the current gradient
// variable, we do not need to add an allreduce_op_handle for this gradient.
// NOTE: this covers the case where collective ops would otherwise be added
// for all gradients.
for (auto *node : ops) {
if (node->Op()->Type() != "allreduce") continue;
for (auto in_name : node->Op()->InputArgumentNames()) {
if (in_name == grad_name) {
return false;
}
}
}
return true;
}
void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
@@ -496,20 +517,17 @@ VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
}
bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const {
return boost::get<int>(
return !loss_var_name_.empty() && node->Op() &&
boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode
static_cast<int>(OpRole::kLoss));
}
bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
}
return false;
return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS;
}
void AllReduceSSAGraphBuilder::InsertCollectiveOp(
@@ -995,7 +1013,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(
allreduce_mode_multi_devices_pass,
all_reduce_mode_multi_devices_pass,
paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder);

@@ -14,7 +14,10 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -31,12 +34,6 @@ namespace framework {
class Scope;
namespace details {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
@@ -44,18 +41,21 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
virtual void Init() const;
virtual void CheckGraph(const ir::Graph &graph) const;
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const = 0;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const = 0;
bool UseGPU() const;
bool NeedCollectiveOps() const;
bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const;
bool IsScaleLossOp(ir::Node *node) const;
@@ -109,10 +109,6 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
return false;
}
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};

@@ -16,6 +16,9 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
@@ -44,6 +47,26 @@ const char kGraphVars[] = "vars";
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
typedef std::unordered_set<std::string> FusedVars;
constexpr char kFusedVars[] = "fused_vars";
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupGradsAndParams;
constexpr char kGroupGradsAndParams[] = "group_grads_params";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
} // namespace details
} // namespace framework
} // namespace paddle

@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
@@ -127,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template <typename T>
struct OpInfoFiller<T, kVarTypeInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
info->infer_var_type_ = [](InferVarTypeContext* context) {
T inference;
inference(fwd_op, block);
inference(context);
};
}
};

@@ -53,6 +53,31 @@ struct ReduceLoDTensor {
}
};
struct ReduceBufferData {
const std::vector<const void *> &src_data_;
void *dst_data_;
int64_t numel_;
ReduceBufferData(const std::vector<const void *> &src, void *dst,
int64_t numel)
: src_data_(src), dst_data_(dst), numel_(numel) {}
template <typename T>
void apply() const {
T *dst_data = reinterpret_cast<T *>(dst_data_);
for (size_t i = 0; i < src_data_.size(); ++i) {
auto src_data = reinterpret_cast<const T *>(src_data_[i]);
VLOG(10) << "dst: " << dst_data_ << ", " << src_data;
if (src_data == dst_data_) {
continue;
}
std::transform(src_data, src_data + numel_, dst_data, dst_data,
[](T a, T b) -> T { return a + b; });
}
}
};
inline void GatherLocalSelectedRows(
const std::vector<const SelectedRows *> &src_selected_rows_,
const std::vector<platform::Place> &in_places,
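ReduceBufferData is a dtype-dispatch functor: VisitDataType instantiates apply<T>() for the tensor's runtime type, and apply() accumulates every source buffer into the destination, skipping the source entry that aliases the destination so its data is counted only once. A standalone sketch of the float case (plain vectors stand in for tensor buffers):

// Illustration only: sum three buffers into dst the way
// ReduceBufferData::apply<float>() does in FusedAllReduceOpHandle::RunImpl.
std::vector<float> dst = {1.f, 1.f, 1.f};
std::vector<float> src1 = {2.f, 2.f, 2.f};
std::vector<float> src2 = {3.f, 3.f, 3.f};
std::vector<const void *> srcs = {dst.data(), src1.data(), src2.data()};
ReduceBufferData func(srcs, dst.data(), /*numel=*/3);
func.apply<float>();  // dst == {6, 6, 6}; the aliasing first entry is skipped.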

@@ -46,6 +46,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(cpu_quantize_placement_pass base)
pass_library(cpu_quantize_pass inference)
pass_library(cpu_quantize_squash_pass inference)
pass_library(fc_fuse_pass inference)
@@ -103,6 +104,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_cpu_quantize_placement_pass SRCS cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if(NOT WIN32)

@@ -0,0 +1,58 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/cpu_quantize_placement_pass.h"
#include <string>
#include <unordered_set>
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Marks operators which are to be quantized.";
const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list =
Get<std::unordered_set<std::string>>("quantize_enabled_op_types");
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
n->id()) != excluded_ids_list.end())
continue;
auto* op = n->Op();
if (op->HasAttr("use_quantizer") || op->HasProtoAttr("use_quantizer")) {
if (op_types_list.empty()) {
op->SetAttr("use_quantizer", true);
} else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
op->SetAttr("use_quantizer", true);
}
}
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_quantize_placement_pass,
paddle::framework::ir::CPUQuantizePlacementPass)
// a set of operator type names to be quantized ("conv2d" etc.)
.RequirePassAttr("quantize_enabled_op_types")
// a set of operator ids that are to be excluded from quantization
.RequirePassAttr("quantize_excluded_op_ids");

@@ -0,0 +1,34 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Specifies which operators should be quantized.
*/
class CPUQuantizePlacementPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@@ -0,0 +1,129 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_placement_pass.h"
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
boost::tribool use_quantizer) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (!boost::indeterminate(use_quantizer))
op->SetAttr("use_quantizer", use_quantizer);
if (type == "conv2d") {
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") {
op->SetInput("X", inputs);
} else if (type == "concat") {
op->SetAttr("axis", 1);
op->SetInput("X", {inputs[0], inputs[1]});
} else if (type == "pool2d") {
op->SetInput("X", {inputs[0]});
} else {
FAIL() << "Unexpected operator type.";
}
op->SetOutput("Out", {outputs[0]});
}
// operator use_quantizer
// ---------------------------------------
// (a,b)->concat->c none
// (c,weights,bias)->conv->f false
// f->relu->g none
// g->pool->h false
// (h,weights2,bias2)->conv->k false
// k->pool->l false
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"}, boost::indeterminate);
SetOp(&prog, "conv2d", "conv1", {"c", "weights", "bias"}, {"f"}, false);
SetOp(&prog, "relu", "relu1", {"f"}, {"g"}, boost::indeterminate);
SetOp(&prog, "pool2d", "pool1", {"g"}, {"h"}, false);
SetOp(&prog, "conv2d", "conv2", {"h", "weights2", "bias2"}, {"k"}, false);
SetOp(&prog, "pool2d", "pool2", {"k"}, {"l"}, false);
return prog;
}
void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
std::initializer_list<int> quantize_excluded_op_ids,
unsigned expected_use_quantizer_true_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_quantize_placement_pass");
pass->Set("quantize_enabled_op_types",
new std::unordered_set<std::string>(quantize_enabled_op_types));
pass->Set("quantize_excluded_op_ids",
new std::unordered_set<int>(quantize_excluded_op_ids));
graph = pass->Apply(std::move(graph));
unsigned use_quantizer_true_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->HasAttr("use_quantizer") &&
boost::get<bool>(op->GetAttr("use_quantizer"))) {
++use_quantizer_true_count;
}
}
}
EXPECT_EQ(use_quantizer_true_count, expected_use_quantizer_true_count);
}
TEST(QuantizerPlacementPass, enabled_pool) { MainTest({"pool2d"}, {}, 2); }
TEST(QuantizerPlacementPass, enabled_conv_excluded_one) {
MainTest({"conv2d"}, {4}, 1);
}
TEST(QuantizerPlacementPass, excluded_none) {
// 2 conv + 2 pool
MainTest({}, {}, 4);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_quantize_placement_pass);

@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
void operator()(InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
}
};
@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
class DummyOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
} // namespace framework
} // namespace paddle
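A self-contained sketch of the pattern this hunk migrates to, with a hypothetical MiniCtx standing in for InferVarTypeContext (shown in full later in this diff): every variable lookup goes through the context instead of a BlockDesc.

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

enum class VarType { LOD_TENSOR, SELECTED_ROWS };

// Hypothetical stand-in exposing the same accessors the new interface uses.
struct MiniCtx {
  std::unordered_map<std::string, std::vector<std::string>> inputs, outputs;
  std::unordered_map<std::string, VarType> types;
  const std::vector<std::string>& Input(const std::string& n) const {
    return inputs.at(n);
  }
  const std::vector<std::string>& Output(const std::string& n) const {
    return outputs.at(n);
  }
  VarType GetType(const std::string& n) const { return types.at(n); }
  void SetType(const std::string& n, VarType t) { types[n] = t; }
};

// Same propagation rule as SumOpVarTypeInference above: the output is a
// LoDTensor iff any input is.
void InferSumVarType(MiniCtx* ctx) {
  const auto& ins = ctx->Input("X");
  bool any_lod =
      std::any_of(ins.begin(), ins.end(), [&](const std::string& n) {
        return ctx->GetType(n) == VarType::LOD_TENSOR;
      });
  ctx->SetType(ctx->Output("Out").front(),
               any_lod ? VarType::LOD_TENSOR : VarType::SELECTED_ROWS);
}

int main() {
  MiniCtx ctx;
  ctx.inputs["X"] = {"a", "b"};
  ctx.outputs["Out"] = {"out"};
  ctx.types = {{"a", VarType::SELECTED_ROWS}, {"b", VarType::LOD_TENSOR}};
  InferSumVarType(&ctx);
  return ctx.GetType("out") == VarType::LOD_TENSOR ? 0 : 1;
}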

@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace framework {
@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
// var type inference. Hence, we don't do any "default" setting here.
auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) {
info.infer_var_type_(*this, block);
InferVarTypeContext context(this, block);
info.infer_var_type_(&context);
}
}
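The call-site change in this hunk is a small adapter pattern; a standalone sketch with stand-in types (all names hypothetical):

#include <functional>

struct Op {};     // stand-in for OpDesc
struct Block {};  // stand-in for BlockDesc
struct Context {  // stand-in for InferVarTypeContext
  const Op* op;
  Block* block;
};

using InferVarTypeFN = std::function<void(Context*)>;

// The (op, block) pair is wrapped in a stack-allocated context exactly once,
// and the registered callback only ever sees the context.
void InferVarType(const Op& op, Block* block, const InferVarTypeFN& fn) {
  if (fn) {
    Context ctx{&op, block};
    fn(&ctx);
  }
}

int main() {
  Op op;
  Block block;
  InferVarType(op, &block, [](Context*) { /* inspect or set var types */ });
  return 0;
}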

@ -254,18 +254,29 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
member_->places_, nccl_id, build_strategy.num_trainers_,
build_strategy.trainer_id_));
std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs;
dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_));
// Initialize device context's nccl comm
// Note, more than one ParallelExecutor with same place, the nccl comm will
// Initialize the device contexts' nccl comm; it will be used by regular
// operators such as sync_batch_norm, and by collective ops.
// NOTE: if more than one ParallelExecutor is created on the same place, the
// nccl comm will be overwritten, which causes problems.
// NOTE: NCCL group calls and non-group calls cannot share the same NCCL
// communicator, so ParallelGraph and Multi-Process mode reuse the same
// communicators.
std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs;
if (nccl_id == nullptr) {
dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_));
}
for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
auto &nccl_ctx = dev_nccl_ctxs->at(dev_id);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
pool.Get(member_->places_[dev_id]));
dev_ctx->set_nccl_comm(nccl_ctx.comm());
if (nccl_id != nullptr) {
auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[dev_id]);
dev_ctx->set_nccl_comm(nccl_ctx.comm());
} else {
auto &nccl_ctx = dev_nccl_ctxs->at(member_->places_[dev_id]);
dev_ctx->set_nccl_comm(nccl_ctx.comm());
}
}
#else
PADDLE_THROW("Not compiled with CUDA");
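Stated plainly: with an externally supplied nccl_id (multi-process or parallel-graph mode), the device contexts reuse the communicators already held by member_->nccl_ctxs_; otherwise a throwaway local map serves them. A minimal sketch of that selection rule, with hypothetical stand-in types:

struct Comm {};  // stand-in for an owned ncclComm_t

// Mirrors the branch inside the per-device loop above.
const Comm* PickComm(bool has_external_nccl_id, const Comm* global_comm,
                     const Comm* local_comm) {
  return has_external_nccl_id ? global_comm : local_comm;
}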

@ -34,7 +34,7 @@ DEFINE_double(
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, false,
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");

@ -44,6 +44,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
<< dst_place;
return;
}
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
}
#endif
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
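The guard pattern in this hunk, reduced to a standalone sketch (all types are stand-ins): layout-specific metadata is copied only in builds where it exists, so non-MKLDNN builds compile unchanged.

enum class DataLayout { kNCHW, kMKLDNN };

struct Tensor {
  DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_MKLDNN
  int mkldnn_prim_desc = 0;  // stand-in for the MKLDNN primitive descriptor
#endif
};

void CopyMeta(const Tensor& src, Tensor* dst) {
#ifdef PADDLE_WITH_MKLDNN
  // Propagate the primitive descriptor only for MKLDNN-laid-out tensors.
  if (src.layout == DataLayout::kMKLDNN) {
    dst->mkldnn_prim_desc = src.mkldnn_prim_desc;
  }
#endif
}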

@ -27,6 +27,7 @@ namespace framework {
class OperatorBase;
class OpDesc;
class InferShapeContext;
class InferVarTypeContext;
class BlockDesc;
class Variable;
@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const std::vector<BlockDesc*>& grad_block)>;
using InferVarTypeFN =
std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
std::function<void(framework::InferVarTypeContext* /*context*/)>;
using InferShapeFN = std::function<void(InferShapeContext*)>;

@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
@ -21,26 +23,123 @@ limitations under the License. */
namespace paddle {
namespace framework {
class OpDesc;
class BlockDesc;
// Default InferVarType context, backed by an OpDesc and a BlockDesc.
class InferVarTypeContext {
public:
InferVarTypeContext(const OpDesc* op, BlockDesc* block)
: op_(op), block_(block) {}
virtual ~InferVarTypeContext() {}
virtual Attribute GetAttr(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->GetAttr(name);
}
virtual bool HasVar(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindVarRecursive(name) != nullptr;
}
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Inputs().count(name) > 0;
}
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Outputs().count(name) > 0;
}
virtual const std::vector<std::string>& Input(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Input(name);
}
virtual const std::vector<std::string>& Output(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Output(name);
}
virtual proto::VarType::Type GetType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetType();
}
virtual void SetType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetType(type);
}
virtual proto::VarType::Type GetDataType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataType();
}
virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
}
virtual std::vector<proto::VarType::Type> GetDataTypes(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
}
virtual void SetDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
}
virtual std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetShape();
}
virtual void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetShape(dims);
}
virtual int32_t GetLoDLevel(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
}
virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
}
protected:
const OpDesc* op_;
BlockDesc* block_;
};
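Every accessor above is virtual, which presumably is what lets tests (and alternative runtimes) substitute their own storage; a hypothetical in-memory mock as a sketch:

// Hypothetical mock built on the class above: variable types live in a plain
// map, so no OpDesc or BlockDesc is needed.
class MockInferVarTypeContext : public InferVarTypeContext {
 public:
  MockInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {}

  proto::VarType::Type GetType(const std::string& name) const override {
    return types_.at(name);
  }

  void SetType(const std::string& name, proto::VarType::Type type) override {
    types_[name] = type;
  }

 private:
  std::unordered_map<std::string, proto::VarType::Type> types_;
};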
class VarTypeInference {
public:
virtual ~VarTypeInference() {}
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
virtual void operator()(InferVarTypeContext* context) const = 0; // NOLINT
};
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const final {
void operator()(framework::InferVarTypeContext* ctx) const final { // NOLINT
auto in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = op_desc.Input(i_o_n.first).at(0);
auto& out_name = op_desc.Output(i_o_n.second).at(0);
auto& x_name = ctx->Input(i_o_n.first).at(0);
auto& out_name = ctx->Output(i_o_n.second).at(0);
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
ctx->SetType(out_name, ctx->GetType(x_name));
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
}
}
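A sketch of how an operator would plug into this helper; the GetInputOutputWithSameType override shown is hypothetical but mirrors the input/output map the loop above consumes:

// Hypothetical concrete op: output "Out" inherits both the var type and the
// data type of input "X" via the loop in PassInDtypeAndVarTypeToOutput.
class SoftmaxLikeInferVarType : public PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return {{"X", "Out"}};
  }
};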
