Make fuse_optimizer_op_pass also work when the model contains sparse gradients. (#18664)

* support sparse gradients
test=develop
DDDivano-patch-1
chengduo 6 years ago committed by GitHub
parent 6b78e00da4
commit fd3aad6cb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -95,5 +95,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass record_skip_memory_opt_vars_pass)

@ -76,6 +76,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
if (strategy_.enable_sequential_execution_) {
VLOG(1) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
@ -108,30 +109,34 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) {
VLOG(1) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
VLOG(1) << "Add coalesce_grad_tensor_pass";
AppendPass("coalesce_grad_tensor_pass");
}
// Fuse all the optimization operators.
if (strategy_.is_distribution_) {
VLOG(3) << "Currently, fuse_all_optimizer_ops only works under "
"Non-distributed mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
if (strategy_.fuse_all_optimizer_ops_) {
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3)
<< "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
} else {
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(1) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(1) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
VLOG(1) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
}
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(1) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(1) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
VLOG(1) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
}
// Add a graph viz pass to record a graph.
@ -301,7 +306,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
} else if (pass->Type() == "coalesce_grad_tensor_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
@ -321,7 +326,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
new bool(use_hierarchical_allreduce_));
#endif
}
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
} else if (pass->Type() == "coalesce_grad_tensor_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
@ -389,7 +394,7 @@ USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(coalesce_grad_tensor_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);

@ -17,6 +17,7 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/device_memory_aligment.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool(skip_fused_all_reduce_check, false, "");
@ -24,19 +25,6 @@ namespace paddle {
namespace framework {
namespace details {
// Note(zcd): Addresses should be aligned, otherwise, the results may have
// diff.
static size_t Alignment(size_t size, const platform::Place &place) {
// Allow to allocate the minimum chunk size is 4 KB.
size_t alignment = 1 << 12;
if (platform::is_gpu_place(place)) {
// Allow to allocate the minimum chunk size is 256 B.
alignment = 1 << 8;
}
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
GradientAndLoDTensor;
@ -121,7 +109,7 @@ void FusedAllReduceOpHandle::RunImpl() {
for (size_t k = 1; k < g_tensor.size(); ++k) {
const void *cur_address = g_tensor.at(k - 1).second->data<void>();
int64_t len = g_tensor.at(k - 1).second->numel();
auto offset = Alignment(len * size_of_dtype, places_[0]);
auto offset = platform::Alignment(len * size_of_dtype, places_[0]);
void *infer_next_address = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(cur_address) + offset);
const void *next_address = g_tensor.at(k).second->data<void>();
@ -241,8 +229,8 @@ void FusedAllReduceOpHandle::GetDTypeAndNumel(
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
// Alignment(len)
*numel += Alignment(len * size_of_dtype, places_[0]) / size_of_dtype;
*numel +=
platform::Alignment(len * size_of_dtype, places_[0]) / size_of_dtype;
}
}

@ -62,11 +62,15 @@ typedef std::vector<std::string> FusedGrads;
constexpr char kFusedGrads[] = "fused_gradients";
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
constexpr char kParamsAndDenseGrads[] = "params_and_dense_grads";
constexpr char kParamsAndSparseGrads[] = "params_and_sparse_grads";
typedef std::vector<ProgramDesc> ProgramDescs;
constexpr char kProgramDescs[] = "program_descs";
typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupParamsAndGrads;
constexpr char kGroupParamsAndGrads[] = "group_params_grads";
constexpr char kGroupParamsAndDenseGrads[] = "group_params_dense_grads";
} // namespace details
} // namespace framework

@ -17,6 +17,8 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/profiler.h"
@ -70,6 +72,29 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
InitializeVariable(pair.first, pair.second);
}
}
const ir::Graph &graph = Graph();
if (graph.Has(details::kProgramDescs)) {
auto &program_descs =
graph.Get<details::ProgramDescs>(details::kProgramDescs);
// Init vars
auto &fused_grad_vars = graph.Get<details::FusedVars>(details::kFusedVars);
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
for (auto &var_name : fused_grad_vars) {
auto var = local_exec_scopes_[i]->Var(var_name);
var->GetMutable<LoDTensor>();
}
}
for (auto &program_desc : program_descs) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_exec_scopes_[i], places_[i]);
}
}
}
}
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {

@ -45,7 +45,7 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)

@ -55,8 +55,8 @@ class FuseOptimizerOpPass : public ir::Pass {
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0;
void GetSpecifiedOpsAndVars(
const std::string &op_type, const std::vector<std::string> &aux_vars_name,
ir::Node *node, std::vector<ir::Node *> *ops,
const std::vector<std::string> &aux_vars_name,
const std::vector<ir::Node *> &opt_nodes,
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const;
@ -67,27 +67,30 @@ class FuseOptimizerOpPass : public ir::Pass {
bool check_name = true) const;
void InitFusedGradsAndAllocSpaceForGrads(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &params,
const std::vector<std::string> &grads, const std::string &fused_grad_name,
ir::Graph *result) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &aux_var_names,
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name)
const;
const std::unordered_map<std::string, std::string> &fused_vars_name,
ir::Graph *result) const;
std::unordered_map<std::string, std::vector<Node *>> GetVarInfo(
const Graph &result) const;
proto::VarType::Type GetTypeOfVar(
const std::unordered_map<std::string, std::vector<Node *>> &var_nodes,
const std::string &name) const;
void RunInitOps(const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const BlockDesc &global_block) const;
void GradientsFilter(const std::vector<size_t> &new_grad_idx,
std::vector<Node *> *opt_nodes,
std::unordered_map<std::string, std::vector<std::string>>
*aux_var_set) const;
void InitVars(const std::vector<Scope *> &local_scopes,
const std::string &fused_var_name) const;
bool IsLoDTensorType(const proto::VarType::Type &type) const;
};
} // namespace ir

@ -108,8 +108,6 @@ bool VarDescIsConsistency(const Graph &graph) {
var_name2node_set;
for (ir::Node *node : graph.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may have the same name node. For example, parameter
// is the input of operator and it also is the output of optimizer;
var_name2node_set[node->Var()->Name()].emplace(node);
}
}

@ -38,6 +38,10 @@ struct NodeComp {
bool HasCircle(const Graph &graph);
// Check if the var desc of node is consistency.
// The graph may have the same name node, for example, parameter
// is the input of operator and it also is the output of optimizer.
// For the persistable variable, the var_desc of the nodes with
// the same node name should be equal.
bool VarDescIsConsistency(const Graph &graph);
// Find All Circles for debugging,

@ -38,7 +38,7 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
size_t num_of_all_reduce = params_grads.size();
std::unordered_set<std::string> grads;
grads.reserve(num_of_all_reduce);
@ -60,8 +60,8 @@ class FuseAllReduceOpPass : public ir::Pass {
"it is not supported currently.");
VLOG(10) << "Insert fused_all_reduce";
auto &group_params_grads =
graph->Get<details::GroupParamsAndGrads>(details::kGroupParamsAndGrads);
auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>(
details::kGroupParamsAndDenseGrads);
for (auto &group_p_g : group_params_grads) {
size_t group_size = group_p_g.size();

@ -49,7 +49,7 @@ class Node {
public:
virtual ~Node() {
if (!wrapper_.empty()) {
VLOG(4) << "ir::Node deleting a wrapper node " << Name();
VLOG(10) << "ir::Node deleting a wrapper node " << Name();
wrapper_deleter_();
}
}

@ -33,16 +33,12 @@ Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
attr);
}
auto* native_graph = graph;
ApplyImpl(graph);
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
PADDLE_ENFORCE(VarDescIsConsistency(*graph),
"The VarDescs of persistable variable are not consistency.");
PADDLE_ENFORCE(graph == native_graph,
"Pass::Apply() cannot delete the passed graph and shouldn't "
"return a new graph.(For the need of pybind11)");
applied_ = true;
return graph;
}

@ -88,6 +88,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)

@ -18,6 +18,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_memory_aligment.h"
namespace paddle {
namespace operators {
@ -26,7 +27,7 @@ static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
template <typename DeviceContext, typename T>
class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
class CoalesceTensorOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto &in_var_names = context.Inputs("Input");
@ -86,8 +87,8 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx,
&sub_tensor);
offset +=
Alignment(len * size_of_dtype, context.GetPlace()) / size_of_dtype;
offset += platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype;
}
} else if (context.Attr<bool>("set_constant")) {
math::SetConstant<DeviceContext, T> set_constant;
@ -106,7 +107,8 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
->ShareDataWith(fused_tensor->Slice(
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
.Resize(dim);
len = Alignment(len * size_of_dtype, context.GetPlace()) / size_of_dtype;
len = platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype;
offset += len;
ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
<< " address: " << out_tensors[i]->data<void>() << ", ";
@ -115,19 +117,6 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
}
private:
// Note(zcd): Addresses should be aligned, otherwise, the results may have
// diff.
size_t Alignment(size_t size, const platform::Place &place) const {
// Allow to allocate the minimum chunk size is 4 KB.
size_t alignment = 1 << 12;
if (platform::is_gpu_place(place)) {
// Allow to allocate the minimum chunk size is 256 B.
alignment = 1 << 8;
}
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
void GetMemSizeAndDtype(
const std::vector<const framework::LoDTensor *> &lod_tensors,
const std::vector<std::string> var_names, size_t *numel,
@ -156,7 +145,8 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(size, 0);
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
<< "), ";
*numel += Alignment(static_cast<size_t>(size) * size_of_dtype, place) /
*numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype,
place) /
size_of_dtype;
}
@ -176,17 +166,17 @@ class AllocContinuousSpaceOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("Input",
"(vector<LoDTensor>) The input tensors of"
" alloc_continuous_space operator.")
" coalesce_tensor operator.")
.AsDuplicable();
AddOutput("Output",
"(vector<LoDTensor>) The output "
"tensors of alloc_continuous_space operator. And the address "
"tensors of coalesce_tensor operator. And the address "
"of output tensors are continuous, they are sliced from the "
"tensor of FusedOutput.")
.AsDuplicable();
AddOutput("FusedOutput",
"(LoDTensor) The output tensor "
"of alloc_continuous_space operator. And the tensors of"
"of coalesce_tensor operator. And the tensors of"
" Output is sliced from the tensor of FusedOutput.");
AddAttr<bool>("copy_data", "Whether to copy the Input value to Output.")
.SetDefault(false);
@ -204,7 +194,7 @@ class AllocContinuousSpaceOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
AllocContinuousSpace Operator.
alloc_continuous_space is used to make the address of Output
coalesce_tensor is used to make the address of Output
continuous according to the Input. This Op will alloc a big tensor
according to the tensors of Input, the dtype is the same with those input tensors,
the size is the sum of those input tensors' numel, and the dim of the big
@ -213,7 +203,7 @@ The tensors of Output are sliced from the tensor of FusedOutput.
Note that, the dtype of Input should be the same, and the dim of Input
and Output should equal.
The tensors of Input and Output could be the same or different. And
alloc_continuous_space allows copying the value of Input to Output, or
coalesce_tensor allows copying the value of Input to Output, or
setting the Output with a constant value.
)DOC");
@ -223,27 +213,22 @@ setting the Output with a constant value.
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(alloc_continuous_space,
paddle::operators::AllocContinuousSpaceOp,
REGISTER_OPERATOR(coalesce_tensor, paddle::operators::AllocContinuousSpaceOp,
paddle::operators::AllocContinuousSpaceOpMaker);
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
alloc_continuous_space,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext,
plat::float16>,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext, int>,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext, float>,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext,
double>);
coalesce_tensor,
ops::CoalesceTensorOp<paddle::platform::CPUDeviceContext, plat::float16>,
ops::CoalesceTensorOp<paddle::platform::CPUDeviceContext, int>,
ops::CoalesceTensorOp<paddle::platform::CPUDeviceContext, float>,
ops::CoalesceTensorOp<paddle::platform::CPUDeviceContext, double>);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL(
alloc_continuous_space,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext, int>,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext, float>,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext,
double>);
coalesce_tensor,
ops::CoalesceTensorOp<paddle::platform::CUDADeviceContext, plat::float16>,
ops::CoalesceTensorOp<paddle::platform::CUDADeviceContext, int>,
ops::CoalesceTensorOp<paddle::platform::CUDADeviceContext, float>,
ops::CoalesceTensorOp<paddle::platform::CUDADeviceContext, double>);
#endif

@ -102,17 +102,17 @@ cc_test(lodtensor_printer_test SRCS lodtensor_printer_test.cc DEPS lodtensor_pri
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
if(WITH_GPU)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
nv_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
else()
cc_library(profiler SRCS profiler.cc DEPS device_tracer enforce)
cc_library(profiler SRCS profiler.cc DEPS device_tracer enforce)
cc_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info place)
endif()
cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
IF(WITH_GPU)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
ENDIF()
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
if(WITH_GPU)

@ -0,0 +1,34 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_memory_aligment.h"
namespace paddle {
namespace platform {
size_t Alignment(size_t size, const platform::Place &place) {
size_t alignment = 1024;
if (platform::is_cpu_place(place)) {
alignment = CpuMinChunkSize();
} else {
#ifdef PADDLE_WITH_CUDA
alignment = GpuMinChunkSize();
#else
PADDLE_THROW("Fluid is not compiled with CUDA");
#endif
}
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
} // namespace platform
} // namespace paddle

@ -0,0 +1,27 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h"
#endif
namespace paddle {
namespace platform {
size_t Alignment(size_t size, const platform::Place &place);
} // namespace platform
} // namespace paddle

@ -27,7 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"

@ -38,6 +38,7 @@ class TestParallelExecutorBase(unittest.TestCase):
batch_size=None,
allow_op_delay=False,
feed_dict=None,
get_data_from_feeder=None,
seed=None,
use_parallel_executor=True,
use_reduce=False,
@ -74,6 +75,10 @@ class TestParallelExecutorBase(unittest.TestCase):
if memory_opt:
fluid.memory_optimize(main)
if get_data_from_feeder is not None:
assert feed_dict is None
feed_dict = get_data_from_feeder()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
@ -81,6 +86,7 @@ class TestParallelExecutorBase(unittest.TestCase):
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce

@ -55,6 +55,34 @@ def fc_with_batchnorm(use_feed=None):
return loss
def bow_net(use_feed,
dict_dim,
is_sparse=False,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
emb = fluid.layers.embedding(
input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim])
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
def init_data(batch_size=32, img_shape=[784], label_range=9):
np.random.seed(5)
assert isinstance(img_shape, list)

@ -24,7 +24,7 @@ alignment = 256
class TestAllocContinuousSpace(OpTest):
def setUp(self):
self.op_type = "alloc_continuous_space"
self.op_type = "coalesce_tensor"
self.dtype = np.float32
attrs = self.init_attr()
self.copy_data = attrs["copy_data"]
@ -64,14 +64,13 @@ class TestAllocContinuousSpace(OpTest):
out[0:length] = input[1].flatten()
inputs.append(out)
alloc_continuous_space_var = np.concatenate([input for input in inputs])
coalesce_tensor_var = np.concatenate([input for input in inputs])
if set_constant:
alloc_continuous_space_var = np.ones(
(len(alloc_continuous_space_var))) * constant
coalesce_tensor_var = np.ones((len(coalesce_tensor_var))) * constant
outputs = [(out[0],
np.ones(out[1].shape).astype(self.dtype) * constant)
for out in outputs]
return outputs, alloc_continuous_space_var
return outputs, coalesce_tensor_var
def test_check_output(self):
if core.is_compiled_with_cuda():

@ -11,71 +11,122 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from simple_nets import simple_fc_net, fc_with_batchnorm, init_data
from simple_nets import simple_fc_net, fc_with_batchnorm, init_data, bow_net
from fake_reader import fake_imdb_reader
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
from functools import partial
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
class TestMNIST(TestParallelExecutorBase):
class TestFuseAllReduceOpsBase(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare_fuse_all_reduce_ops(self, model, use_cuda):
def compare_fuse_all_reduce_ops(self,
model,
use_cuda,
init_feed_dicta=None,
get_data_from_feeder=None,
optimizer=None,
fuse_all_optimizer_ops=False):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = init_data()
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer
feed_dict_data = None
if init_feed_dicta is not None:
img, label = init_feed_dicta()
feed_dict_data = {"image": img, "label": label}
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
feed_dict=feed_dict_data,
get_data_from_feeder=get_data_from_feeder,
use_cuda=use_cuda,
fuse_all_reduce_ops=False,
fuse_all_optimizer_ops=fuse_all_optimizer_ops,
memory_opt=False,
optimizer=_optimizer)
optimizer=optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
feed_dict=feed_dict_data,
get_data_from_feeder=get_data_from_feeder,
use_cuda=use_cuda,
fuse_all_reduce_ops=True,
fuse_all_optimizer_ops=fuse_all_optimizer_ops,
memory_opt=False,
optimizer=_optimizer)
optimizer=optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
def test_simple_fc_with_fuse_op(self):
self._compare_fuse_all_reduce_ops(simple_fc_net, True)
self._compare_fuse_all_reduce_ops(simple_fc_net, False)
def optimizer(self, learning_rate=1e-3):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-3))
return optimizer
class TestFuseAllReduceOps(TestFuseAllReduceOpsBase):
def _decorate_compare_fused_all_reduce(self, model, use_cuda):
self.compare_fuse_all_reduce_ops(
model,
use_cuda,
init_feed_dicta=init_data,
optimizer=self.optimizer,
fuse_all_optimizer_ops=True)
def test_simple_fc_with_fuse_all_reduce(self):
self._decorate_compare_fused_all_reduce(simple_fc_net, True)
self._decorate_compare_fused_all_reduce(simple_fc_net, False)
def test_batchnorm_fc_with_fuse_all_reduce(self):
self._decorate_compare_fused_all_reduce(fc_with_batchnorm, True)
self._decorate_compare_fused_all_reduce(fc_with_batchnorm, False)
class TestFuseAllReduceOpsAndOptiOps(TestFuseAllReduceOps):
def _decorate_compare_fused_all_reduce(self, model, use_cuda):
self.compare_fuse_all_reduce_ops(
model,
use_cuda,
init_feed_dicta=init_data,
optimizer=self.optimizer,
fuse_all_optimizer_ops=True)
class TestFuseAllReduceOpsWithSparseGrad(TestFuseAllReduceOpsBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
cls.word_dict_len = 5147
batch_size = 64
reader = fake_imdb_reader(cls.word_dict_len, batch_size * 100)
reader = paddle.batch(reader, batch_size=batch_size)()
cls.train_data = next(reader)
def get_data_from_feeder(self):
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
return feeder.feed(self.train_data)
def _decorate_compare_fused_all_reduce(self, model, use_cuda):
self.compare_fuse_all_reduce_ops(
model,
use_cuda,
get_data_from_feeder=self.get_data_from_feeder,
optimizer=self.optimizer)
def test_batchnorm_fc_with_fuse_op(self):
self._compare_fuse_all_reduce_ops(fc_with_batchnorm, True)
self._compare_fuse_all_reduce_ops(fc_with_batchnorm, False)
def test_simple_bow_net_with_fuse_all_reduce(self):
model = partial(bow_net, dict_dim=self.word_dict_len, is_sparse=True)
self._decorate_compare_fused_all_reduce(model, True)
self._decorate_compare_fused_all_reduce(model, False)
if __name__ == '__main__':

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save