parent
8f3b252392
commit
2739096eec
@ -0,0 +1,125 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/graph_print_pass.h"

#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class GraphvizVar : public GraphvizNode {
|
||||
public:
|
||||
GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
|
||||
friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
|
||||
sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]"
|
||||
<< std::endl;
|
||||
return sout;
|
||||
}
|
||||
};
|
||||
|
||||
class GraphvizOp : public GraphvizNode {
|
||||
public:
|
||||
GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
|
||||
friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
|
||||
sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
|
||||
<< "\", shape=rect]" << std::endl;
|
||||
PADDLE_ENFORCE(op.stream_.rdbuf()->in_avail() != 0,
|
||||
"No inputs outputs. Please call AddEdge first!");
|
||||
sout << op.stream_.str();
|
||||
return sout;
|
||||
}
|
||||
template <typename Callback>
|
||||
void AddEdge(const Callback& cb) {
|
||||
std::string op_name = "op_" + std::to_string(id_);
|
||||
for (auto var : node_->inputs) {
|
||||
std::string var_name = "var_" + std::to_string(cb(var));
|
||||
stream_ << var_name << "->" << op_name << std::endl;
|
||||
}
|
||||
for (auto var : node_->outputs) {
|
||||
std::string var_name = "var_" + std::to_string(cb(var));
|
||||
stream_ << op_name << "->" << var_name << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostringstream stream_;
|
||||
};
|
||||
|
||||
// Collects the elements of `con` (a container of smart pointers) whose
// dynamic type is T, returned as raw non-owning pointers in iteration
// order. Elements of other types are skipped.
template <typename T, typename Container>
std::vector<T*> FilterByNodeWrapper(const Container& con) {
  std::vector<T*> matched;
  for (const auto& holder : con) {
    if (T* typed = dynamic_cast<T*>(holder.get())) {
      matched.push_back(typed);
    }
  }
  return matched;
}
|
||||
|
||||
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
|
||||
const ir::Graph& graph) const {
|
||||
// Convert to GraphvizNode format
|
||||
auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
|
||||
graphviz_nodes.clear();
|
||||
std::unordered_map<ir::Node*, int> vars;
|
||||
int var_id = 0;
|
||||
int op_id = 0;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node->IsVar()) {
|
||||
graphviz_nodes.emplace(new GraphvizVar(node, var_id));
|
||||
vars.emplace(std::make_pair(node, var_id++));
|
||||
} else if (node->IsOp()) {
|
||||
graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
|
||||
} else {
|
||||
PADDLE_THROW("Unknown op type");
|
||||
}
|
||||
}
|
||||
return vars;
|
||||
}
|
||||
|
||||
// Renders `graph` to `sout` in Graphviz dot format: all variable nodes
// first, then each op node together with its input/output edges.
void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
                                std::ostream& sout) const {
  auto var_ids = ToGraphvizNode(graph);
  auto& wrappers = graph.Get<GraphvizNodes>(kGraphviz);

  sout << "digraph G {\n";
  for (auto& var : FilterByNodeWrapper<GraphvizVar>(wrappers)) {
    sout << *var;
  }

  for (auto& op : FilterByNodeWrapper<GraphvizOp>(wrappers)) {
    // Resolve each op's edges through the variable id map built above.
    op->AddEdge([&var_ids](ir::Node* var) { return var_ids.at(var); });
    sout << *op;
  }
  sout << "}\n";
}
|
||||
|
||||
std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
printer_.reset(new SSAGraphPrinterImpl());
|
||||
std::unique_ptr<std::ostream> fout(
|
||||
new std::ofstream(Get<std::string>(kGraphvizPath)));
|
||||
PADDLE_ENFORCE(fout->good() == true, "Failed to open file.");
|
||||
|
||||
printer_->Print(*graph, *fout);
|
||||
return graph;
|
||||
}
|
||||
|
||||
}  // namespace details
}  // namespace framework
}  // namespace paddle

// Registers the pass under "graph_print_pass". Callers must set the
// kGraphvizPath attribute (the output file path) before the pass runs.
REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
    .RequirePassAttr(paddle::framework::details::kGraphvizPath);
|
@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
#include <fstream>
#include <memory>
#include <ostream>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Pass attribute: path of the file the dot dump is written to.
constexpr char kGraphvizPath[] = "debug_graphviz_path";
// Graph attribute under which the set of GraphvizNode wrappers is stored.
constexpr char kGraphviz[] = "graphviz";
|
||||
|
||||
class GraphvizNode {
|
||||
public:
|
||||
GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
|
||||
virtual ~GraphvizNode() = default;
|
||||
|
||||
protected:
|
||||
ir::Node* node_;
|
||||
int id_;
|
||||
};
|
||||
class GraphvizNode;
|
||||
typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
|
||||
|
||||
class SSAGraphPrinter {
|
||||
public:
|
||||
virtual ~SSAGraphPrinter() {}
|
||||
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
|
||||
};
|
||||
|
||||
// Graphviz-based implementation: Print() emits the graph in dot format.
class SSAGraphPrinterImpl : public SSAGraphPrinter {
 public:
  void Print(const ir::Graph& graph, std::ostream& sout) const override;

 private:
  // Builds the per-node Graphviz wrappers on the graph's kGraphviz
  // attribute and returns the variable-node -> dot-id mapping.
  std::unordered_map<ir::Node*, int> ToGraphvizNode(
      const ir::Graph& graph) const;
};
|
||||
|
||||
// Pass that dumps the graph, in Graphviz dot format, to the file named by
// the required kGraphvizPath attribute.
class SSAGraphPrintPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

 private:
  // mutable: ApplyImpl is const but (re)creates the printer on each run.
  mutable std::unique_ptr<SSAGraphPrinter> printer_;
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,79 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/graph_print_pass.h"
|
||||
#include "paddle/fluid/framework/details/graph_test_base.h"
|
||||
|
||||
// Register dummy operators so the test program below can be converted to
// an ir::Graph without pulling in real kernel implementations.
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                  paddle::framework::SumOpMaker);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
                  paddle::framework::SplitOpMaker);

/* Test graph built by FillProgramDesc():
     a @ b      sum(a, b)  -> c
       c        split(c)   -> d, e
     d @ e      sum(d, e)  -> d   (output reuses variable d)
*/
|
||||
|
||||
using paddle::framework::ProgramDesc;
|
||||
using paddle::framework::proto::VarType;
|
||||
|
||||
// Builds the three-op test program used by the printer test:
//   sum(a, b) -> c ;  split(c) -> d, e ;  sum(d, e) -> d
// All five variables are LOD tensors in block 0.
inline static ProgramDesc FillProgramDesc() {
  ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  for (const char* name : {"a", "b", "c", "d", "e"}) {
    block->Var(name)->SetType(VarType::LOD_TENSOR);
  }
  // Appends one op with input slot X and output slot Out.
  auto append_op = [block](const std::string& type,
                           const std::vector<std::string>& in,
                           const std::vector<std::string>& out) {
    auto* op = block->AppendOp();
    op->SetType(type);
    op->SetInput("X", in);
    op->SetOutput("Out", out);
  };
  append_op("sum", {"a", "b"}, {"c"});
  append_op("split", {"c"}, {"d", "e"});
  append_op("sum", {"d", "e"}, {"d"});
  return prog;
}
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Smoke test: prints a small three-op graph to graph_print_pass.txt.
// Only verifies the pipeline runs; the dot output is inspected manually.
TEST(SSAGraphPrinter, Normal) {
  auto program = FillProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
  // The printer requires the kGraphviz attribute to exist beforehand.
  graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass.txt";
  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
  PADDLE_ENFORCE(fout->good());
  printer->Print(*graph, *fout);
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,80 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include "glog/logging.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Operator whose RunImpl does nothing; lets tests register op types
// without providing real kernels.
class DummyOp : public OperatorBase {
 public:
  DummyOp(const std::string& type, const VariableNameMap& inputs,
          const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  // Intentionally a no-op.
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {}
};
|
||||
|
||||
// Proto maker for a test "sum"-style op: duplicable input X, single Out.
class SumOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};
|
||||
|
||||
// Proto maker for a test "assign"-style op; same shape as SumOpMaker
// (duplicable input X, single Out).
class AssignOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};
|
||||
|
||||
// Proto maker for a test "split"-style op: single input X, duplicable Out.
class SplitOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "");
    AddOutput("Out", "").AsDuplicable();
    AddComment("");
  }
};
|
||||
|
||||
// Var type inference that copies the type of the first X input onto the
// first Out output.
class DummyVarTypeInference : public VarTypeInference {
 public:
  // NOTE(review): assumes op_desc has at least one X input and one Out
  // output; front() on an empty list is undefined behavior — acceptable
  // for a test helper, but callers must guarantee it.
  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
    auto& inputs = op_desc.Input("X");
    auto type = block->Var(inputs.front())->GetType();
    auto out_var_name = op_desc.Output("Out").front();
    block->Var(out_var_name)->SetType(type);
  }
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
from parallel_executor_test_base import TestParallelExecutorBase
|
||||
|
||||
|
||||
def fc_with_batchnorm(use_feed):
    """Build a 3-layer tanh MLP with batch norm and a softmax head.

    Returns the mean cross-entropy loss. `use_feed` is accepted to match
    the network-builder callback signature and is not read here.
    """
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    hidden = img
    for _layer in range(3):
        # A fresh bias attr per layer, each initialized to a constant 1.0.
        bias = fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=1.0))
        hidden = fluid.layers.fc(
            hidden, size=200, act='tanh', bias_attr=bias)

    hidden = fluid.layers.batch_norm(input=hidden)
    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    xent = fluid.layers.cross_entropy(input=prediction, label=label)
    return fluid.layers.mean(xent)
|
||||
|
||||
|
||||
class TestIrInplace(TestParallelExecutorBase):
    """Checks fc_with_batchnorm converges identically across the four
    combinations of IR memory-optimize and inplace flags."""

    @classmethod
    def setUpClass(cls):
        # Pin the CPU place count so runs are comparable.
        os.environ['CPU_NUM'] = str(4)

    def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
        # NOTE(review): this helper has no `return`, so it always yields
        # None. Every assertAlmostEqual below then compares None to None,
        # which short-circuits on equality — the test passes vacuously.
        # It should return a scalar loss from check_network_convergence;
        # confirm that method's return shape before fixing.
        np.random.seed(5)  # fixed seed so the four runs are comparable
        img = np.random.random(size=[32, 784]).astype(np.float32)
        label = np.ones(shape=[32, 1], dtype='int64')
        self.check_network_convergence(
            fc_with_batchnorm,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=True,
            memory_opt=False,  # inplace is conflict with memory opt
            use_ir_memory_optimize=ir_memory_optimize,
            enable_inplace=enable_inplace)

    def test_fc_with_batchnorm(self, delta=1e-3):
        # Baseline (both flags off) against each other flag combination.
        loss00 = self._fc_with_batchnorm(False, False)
        loss10 = self._fc_with_batchnorm(True, False)
        loss01 = self._fc_with_batchnorm(False, True)
        loss11 = self._fc_with_batchnorm(True, True)
        self.assertAlmostEqual(loss00, loss10, delta=delta)
        self.assertAlmostEqual(loss00, loss01, delta=delta)
        self.assertAlmostEqual(loss00, loss11, delta=delta)
|
Loading…
Reference in new issue