Open fuse optimization ops (#18741)

* open fuse optimization ops
test=develop
Author: chengduo (committed via GitHub)
commit 4140fe11a4
parent 582cc29799

(The diff of one file is suppressed because it is too large.)

@@ -88,7 +88,7 @@ struct BuildStrategy {
   bool fuse_elewise_add_act_ops_{false};
   // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
   // should not be sparse types
-  bool fuse_all_optimizer_ops_{false};
+  bool fuse_all_optimizer_ops_{true};
   bool fuse_all_reduce_ops_{false};
   // fuse_relu_depthwise_conv can fuse the `relu ->
   // depthwise_conv`
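This flips the default of fuse_all_optimizer_ops_ to true, so jobs whose gradients can be sparse (which the comment above excludes) now have to opt out explicitly. A minimal sketch, assuming direct C++ use of BuildStrategy; not code from this diff:

// A minimal sketch, not from this commit: opting out of the new default
// because the fuse passes assume dense gradients (no SelectedRows).
paddle::framework::details::BuildStrategy strategy;
strategy.fuse_all_optimizer_ops_ = false;  // override the new default of true
strategy.fuse_all_reduce_ops_ = false;     // same dense-gradient constraint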

@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
 }  // namespace paddle
 
 REGISTER_PASS(coalesce_grad_tensor_pass,
-              paddle::framework::ir::CoalesceGradTensorPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+              paddle::framework::ir::CoalesceGradTensorPass);

@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass);

@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass);

@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass);
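All four fuse passes above drop their registration-time requirement on kPlaces and kLocalScopes. A minimal sketch of what this permits, assuming the standard PassRegistry/Pass API and caller-provided `places`, `scopes`, and `graph`; illustrative only, not code from this commit:

// With the RequirePassAttr lines gone, the pass can be fetched and applied
// without kPlaces/kLocalScopes pre-registered; a caller that has them can
// still attach them before Apply().
auto pass = paddle::framework::ir::PassRegistry::Instance().Get("fuse_adam_op_pass");
pass->SetNotOwned(paddle::framework::details::kPlaces, &places);       // now optional
pass->SetNotOwned(paddle::framework::details::kLocalScopes, &scopes);  // now optional
pass->Apply(graph);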

@@ -26,7 +26,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-constexpr char kGraphvizPath[] = "debug_graphviz_path";
+constexpr char kGraphvizPath[] = "graph_viz_path";
 
 class SSAGraphPrinter {
  public:

@@ -16,6 +16,7 @@ limitations under the License. */
 #include <algorithm>
 #include <unordered_map>
 #include <unordered_set>
+#include "paddle/fluid/framework/ir/graph_printer.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/string/printf.h"
@@ -25,8 +26,6 @@ namespace framework {
 namespace ir {
 using inference::analysis::Dot;
 namespace {
-const char kGraphVizPath[] = "graph_viz_path";
-
 std::string FormatName(const Node* node) {
   if (!node->IsOp() || !node->Op() ||
       !node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) {
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
 }  // namespace
 
 void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
-  const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
+  const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
   VLOG(3) << "draw IR graph viz to " << graph_viz_path;
   std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
   PADDLE_ENFORCE(fout->good());
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
 }  // namespace paddle
 
 REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
-    .RequirePassAttr(paddle::framework::ir::kGraphVizPath);
+    .RequirePassAttr(paddle::framework::ir::kGraphvizPath);
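graph_viz_pass now keys its output path on the shared ir::kGraphvizPath from graph_printer.h instead of the file-local kGraphVizPath. A minimal sketch of driving the pass with that key; the driver code and path are illustrative, only the identifiers come from this diff:

// Illustrative only: register the output path under the shared key,
// then run the pass. Pass::Set takes ownership of the new string.
auto viz = paddle::framework::ir::PassRegistry::Instance().Get("graph_viz_pass");
viz->Set<std::string>(paddle::framework::ir::kGraphvizPath,
                      new std::string("/tmp/ir_graph.dot"));
viz->Apply(graph);  // 'graph' is an assumed ir::Graph*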

@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_printer.h"
 
 namespace paddle {
 namespace framework {
@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
     std::unique_ptr<std::ostream> fout(
         new std::ofstream(Get<std::string>(kGraphvizPath)));
     PADDLE_ENFORCE(fout->good());
-    Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
+    if (Has("graph_printer")) {
+      Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
+    } else {
+      GraphvizSSAGraphPrinter printer;
+      printer.Print(*graph, *fout);
+    }
   }
 };

@@ -24,6 +24,7 @@ namespace framework {
 namespace ir {
 
 Graph* Pass::Apply(Graph* graph) const {
+  CheckPrevPass();
   PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
   for (const std::string& attr : required_pass_attrs_) {
     PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
   PADDLE_ENFORCE(VarDescIsConsistency(*graph),
                  "The VarDescs of persistable variable are not consistency.");
   applied_ = true;
+  if (!graph->Has(kPassRecorder)) {
+    graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
+  }
+  graph->Get<PassRecorder>(kPassRecorder).insert(Type());
   return graph;
 }
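Pass::Apply now records each applied pass's Type() in a PassRecorder stored on the graph. A hypothetical sketch of how a downstream pass could consult that record to enforce ordering; MyFusePass and its prerequisite check are invented for illustration:

// Hypothetical: a pass that must run after coalesce_grad_tensor_pass can
// inspect the recorder that Apply() leaves on the graph.
void MyFusePass::ApplyImpl(ir::Graph* graph) const {
  if (graph->Has(kPassRecorder)) {
    const auto& applied = graph->Get<PassRecorder>(kPassRecorder);
    PADDLE_ENFORCE(applied.count("coalesce_grad_tensor_pass"),
                   "coalesce_grad_tensor_pass must run before this pass.");
  }
  // ... the actual transformation ...
}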

@@ -20,6 +20,7 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -31,6 +32,9 @@ namespace ir {
 template <typename PassType>
 struct PassRegistrar;
 
+typedef std::unordered_set<std::string> PassRecorder;
+constexpr char kPassRecorder[] = "pass_recorder";
+
 class Pass {
  public:
  Pass() = default;
@@ -104,6 +108,10 @@ class Pass {
     LOG(FATAL) << "Calling virtual Pass not implemented.";
   }
 
+  // Some Pass must be placed before this Pass, and some
+  // Pass must be placed after this Pass.
+  virtual void CheckPrevPass() const {}
+
  private:
  template <typename PassType>
  friend struct PassRegistrar;
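CheckPrevPass() is the new hook that Pass::Apply calls before anything else (see pass.cc above). It takes no graph, so an override can only validate the pass's own attributes; a hypothetical example, with MyOrderedPass invented here and headers omitted:

// Hypothetical: fail fast if the pipeline never provided kPlaces.
class MyOrderedPass : public Pass {
 protected:
  void CheckPrevPass() const override {
    PADDLE_ENFORCE(Has(details::kPlaces),
                   "kPlaces must be set before MyOrderedPass runs.");
  }
  void ApplyImpl(ir::Graph* graph) const override { /* ... */ }
};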

@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/pass_builder.h"
+#include <memory>
+#include <utility>
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
 std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
+  VLOG(3) << "Append " << pass_type;
   auto pass = ir::PassRegistry::Instance().Get(pass_type);
   passes_.emplace_back(pass.release());
   return passes_.back();

@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
   void ApplyImpl(ir::Graph *graph) const override {
     VLOG(3) << "Use synchronous batch norm";
     for (const Node *n : graph->Nodes()) {
-      if (n->IsOp()) {
+      if (n->IsOp() && n->Op()) {
         auto *op = n->Op();
         if (op->Type() == "batch_norm") {
           op->SetType("sync_batch_norm");

@@ -32,6 +32,7 @@ feed_dict = {
 class InplaceTestBase(unittest.TestCase):
     def initParameter(self):
         self.use_cuda = True
+        self.fuse_all_optimizer_ops = False
 
     def setUp(self):
         self.initParameter()
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
             self.device_count = fluid.core.get_cuda_device_count()
         else:
             self.device_count = 4
-
         assert batch_size % self.device_count == 0
 
     def build_program_and_scope(self):
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.memory_optimize = memory_optimize
                 build_strategy.enable_inplace = enable_inplace
+                build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
                 compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
                     loss_name=loss.name,
                     build_strategy=build_strategy,
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.memory_optimize = memory_optimize
                 build_strategy.enable_inplace = enable_inplace
+                build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
                 compiled_program = fluid.CompiledProgram(
                     prog).with_data_parallel(
                         loss_name=loss.name,
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
 class CPUInplaceTest(InplaceTestBase):
     def initParameter(self):
         self.use_cuda = False
+        self.fuse_all_optimizer_ops = False
+
+
+class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
+    def initParameter(self):
+        self.use_cuda = True
+        self.fuse_all_optimizer_ops = True
+
+
+class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
+    def initParameter(self):
+        self.use_cuda = False
+        self.fuse_all_optimizer_ops = True
+
 
 if __name__ == '__main__':
