You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/ir/lock_free_optimize_pass.h

132 lines
4.0 KiB

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
#include <string>
#include <vector>
#include <boost/algorithm/string/predicate.hpp>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
/*
* Remove the sum op of all gradients of the backward op.
* And remove the dependecies of the optimizer related to the
* same backward op.
*
* Before this pass:
*
* forward_op1 forward_op2
* | |
* grad_op1 grad_op2
* \ /
* \ /
* sum_op
* |
* sgd_op
*
* After this pass:
* forward_op1 forward_op2
* | |
* grad_op1 grad_op2
* | |
* sgd_op1 sgd_op2
*
* sgd_op1 and sgd_op2 will update the same weight which holds the same
* memory, so we could benefits from the acceleration
*/
class LockFreeOptimizePass : public Pass {
public:
virtual ~LockFreeOptimizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
// Create a new sgd node via current optimizer node
ir::Node* CreateNewSGDNode(ir::Graph* graph, ir::Node* forward_node,
ir::Node* backward_node, ir::Node* grad_sum_node,
ir::Node* optimize_node) const;
// Replace the input weight's optimizers
void ReplaceUpstreamNode(ir::Node* upstream_node,
ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const;
// Replace the output weight's optimizers
void ReplaceAllDownstreamNode(ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const;
// Find all weight variables in graph
bool FindAllWeightVars(ir::Graph* graph) const;
// Find the forward_op node via the backward_op node
ir::Node* FindForwardOpViaBackwardOp(ir::Graph* graph,
ir::Node* backward_node) const;
std::vector<ir::Node*> FindConnectedNode(ir::Node* upstream_node,
ir::Node* downstream_node) const;
inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kOperation && node->Name() == name;
}
inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kVariable && node->Name() == name;
}
inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kVariable &&
boost::algorithm::ends_with(node->Name(), name);
}
inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kVariable &&
node->Name().find(name) != std::string::npos;
}
inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
PADDLE_ENFORCE(ctrl_dep_node);
PADDLE_ENFORCE(node);
return IsControlDepVar(*ctrl_dep_node) &&
ctrl_dep_node->inputs.size() >= 1u &&
ctrl_dep_node->inputs[0] == node;
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_