Code Clean: Move all pass to paddle::framework::ir (#17228)
* move pass to ir * polish code test=develop * fix dependency test=develop
revert-17304-fix_default_paddle_version
parent
648320bb6c
commit
04bd413acb
File diff suppressed because it is too large
Load Diff
@ -1,79 +0,0 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

// Setter/getter for the maximum number of (parameter, gradient) pairs that
// may be placed into a single fused group.
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();

// Setter/getter for the memory-size cap of a single fused group.
// NOTE(review): the unit is presumably bytes — confirm in the .cc definition.
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();

// Graph pass that groups parameter gradients and allocates one continuous
// memory block per group (summary inferred from the pass/method names —
// the implementation lives in the corresponding .cc file).
class AllocContinuousSpaceForGradPass : public ir::Pass {
 protected:
  // Pass entry point: rewrites `graph` in place.
  void ApplyImpl(ir::Graph *graph) const override;

  // Re-sets the graph attribute `attr_name` (of type AttrType) on `graph`.
  // NOTE(review): exact reset semantics live in the .cc — confirm there.
  template <typename AttrType>
  void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;

  // Partitions `params_grads` into groups and writes the result into
  // `group_grads_params`. `var_nodes` maps variable name -> graph node.
  void SetGroupGradsAndParams(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const ParamsAndGrads &params_grads,
      GroupGradsAndParams *group_grads_params) const;

  // Grouping strategy keyed by the layer the gradients belong to
  // (per the method name).
  void SetGroupAccordingToLayers(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const ParamsAndGrads &params_grads,
      GroupGradsAndParams *group_grads_params) const;

  // Grouping strategy bounded by the configured memory size
  // (see GetFuseParameterMemorySize above).
  void SetGroupAccordingToMemorySize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      GroupGradsAndParams *group_grads_params) const;

  // Grouping strategy bounded by the configured group size
  // (see GetFuseParameterGroupsSize above).
  void SetGroupAccordingToGroupSize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      GroupGradsAndParams *group_grads_params) const;

 private:
  // Whether gradients of variables with this type can be fused.
  bool IsSupportedVarType(const proto::VarType::Type &type) const;

  // Collects the (parameter, gradient) pair associated with `node` into
  // `params_grads`.
  void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;

  // Creates `fused_var_name` in the local scopes and allocates the
  // continuous space backing all variables listed in `params_grads`.
  void InitFusedVarsAndAllocSpaceForVars(
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
      const std::unordered_map<std::string, ir::Node *> &vars,
      const std::string &fused_var_name,
      const ParamsAndGrads &params_grads) const;

  // Appends the op performing the continuous-space allocation to
  // `global_block`.
  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
                                 const std::vector<std::string> &grads_name,
                                 const std::string &fused_var_name,
                                 BlockDesc *global_block) const;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
|
@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

// Declaration only; the implementation lives in the matching .cc file.
// NOTE(review): judging by the name, this pass adjusts the "need lock" /
// "record event" behavior of op handles — confirm against the .cc.
class ModifyOpLockAndRecordEventPass : public ir::Pass {
 protected:
  // Pass entry point: rewrites `graph` in place.
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
|
@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
|
||||
#include "paddle/fluid/framework/op_proto_maker.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
|
||||
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
|
||||
op1->Outputs() == op2->Outputs();
|
||||
}
|
||||
|
||||
// Forces the graph's ops to execute one-by-one in original program order:
// matches each stale-program op desc to its graph node via a ready-set
// topological walk, then chains control-dependency vars between every pair
// of consecutive (non-distributed) ops.
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
  // FIXME(zjl): Insert dependencies between some distributed ops may cause
  // the multi_devices_graph_pass fails. So we skip these ops here.
  // Indeed, maybe we should not insert dependencies between these ops
  // casually, which may cause deadlock easily.
  // We should add more skipped distributed ops when found errors in
  // multi_devices_graph_pass
  static std::unordered_set<std::string> skip_dist_ops{
      "send", "recv", "send_barrier", "fetch_barrier"};

  // Op descs in the order they appeared in the original (stale) program.
  auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
  std::vector<ir::Node *> op_node_list;
  op_node_list.reserve(ops.size());

  // op_deps[op]    : number of preceding ops not yet scheduled.
  // pending_ops[op]: ops that directly depend on `op`.
  // ready_ops      : ops whose preceding ops have all been scheduled.
  std::unordered_map<ir::Node *, size_t> op_deps;
  std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
  std::unordered_set<ir::Node *> ready_ops;

  // Build the dependency structure: an op's preceding ops are the unique
  // producers of its input var nodes.
  for (ir::Node *node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    std::unordered_set<ir::Node *> preceding_ops;
    for (auto *in : node->inputs) {
      PADDLE_ENFORCE(in->IsVar(),
                     "Preceding Node of Op Nodes must be Var Node");
      // A var node with no inputs is a graph input: it has no producer op.
      if (in->inputs.empty()) continue;
      PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
                     "Preceding Op Node of Var Node must be unique");
      preceding_ops.insert(in->inputs[0]);
      pending_ops[in->inputs[0]].insert(node);
    }
    op_deps[node] = preceding_ops.size();
    if (preceding_ops.empty()) {
      ready_ops.insert(node);
    }
  }

  // Walk the op descs in program order; each must match exactly one node in
  // the current ready set, otherwise the graph and the stale program
  // disagree and we fail loudly.
  for (auto *op_desc : ops) {
    ir::Node *found_node = nullptr;
    for (auto *node : ready_ops) {
      if (IsSameOpDesc(op_desc, node->Op())) {
        PADDLE_ENFORCE(found_node == nullptr,
                       "Found multiple op_desc in graph: %s", op_desc->Type());
        found_node = node;
      }
    }

    PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                            op_desc->Type());
    // Scheduling `found_node` may release ops that were waiting on it.
    for (auto *pending_op : pending_ops[found_node]) {
      if (--op_deps.at(pending_op) == 0) {
        ready_ops.insert(pending_op);
      }
    }
    ready_ops.erase(found_node);
    // Distributed ops are matched (to keep the walk consistent) but not
    // chained — see the FIXME above.
    if (skip_dist_ops.count(op_desc->Type()) == 0) {
      op_node_list.push_back(found_node);
    }
  }

  // Chain a control-dependency var between every pair of consecutive ops so
  // the executor is forced to run them sequentially.
  for (size_t i = 1; i < op_node_list.size(); ++i) {
    auto *dep_var = graph->CreateControlDepVar();
    op_node_list[i]->inputs.push_back(dep_var);
    op_node_list[i - 1]->outputs.push_back(dep_var);
    dep_var->outputs.push_back(op_node_list[i]);
    dep_var->inputs.push_back(op_node_list[i - 1]);
    VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
             << " and " << op_node_list[i]->Name();
  }
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Register the pass under the name "sequential_execution_pass"; it requires
// the stale-program op descs to be attached to the graph before it runs.
REGISTER_PASS(sequential_execution_pass,
              paddle::framework::details::SequentialExecutionPass)
    .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
|
@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

// Pass that forces ops to execute sequentially in original program order by
// inserting control-dependency variables between consecutive ops.
// Requires the kStaleProgramOpDescs graph attribute.
class SequentialExecutionPass : public ir::Pass {
 protected:
  // Pass entry point: rewrites `graph` in place.
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,4 @@
|
||||
# Common base library for the optimizer-fusion passes.
cc_library(fuse_optimizer_op_pass SRCS fuse_optimizer_op_pass.cc DEPS graph graph_helper)
# Optimizer-specific fusion passes; each depends on the common base above.
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc DEPS fuse_optimizer_op_pass)
|
@ -0,0 +1,18 @@
|
||||
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)

# memory_optimize_helper needs device information: gpu_info when building
# with CUDA support, cpu_info otherwise.
if(WITH_GPU)
  cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
  cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()

cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)

# The test compiles memory_optimize_helper.cc directly instead of linking
# the library target.
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)

cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue