parent
eb1aeb175b
commit
36c2a9af27
@ -0,0 +1,150 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/details/build_strategy.h"
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
|
||||
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
||||
public:
|
||||
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
|
||||
: ir::PassBuilder(), strategy_(strategy) {
|
||||
// Apply a graph viz pass to record a graph.
|
||||
if (!strategy_.debug_graphviz_path_.empty()) {
|
||||
auto viz_pass = AppendPass("graph_viz_pass");
|
||||
const std::string graph_path = string::Sprintf(
|
||||
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
|
||||
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
|
||||
}
|
||||
|
||||
// Apply op fusion.
|
||||
if (strategy.fuse_elewise_add_act_ops_) {
|
||||
auto fuse_elewise_add_act_pass =
|
||||
ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass");
|
||||
graph = fuse_elewise_add_act_pass->Apply(std::move(graph));
|
||||
// Apply a graph viz pass to record a graph.
|
||||
if (!strategy.debug_graphviz_path_.empty()) {
|
||||
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
|
||||
const std::string graph_path = string::Sprintf(
|
||||
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
|
||||
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
|
||||
graph = viz_pass->Apply(std::move(graph));
|
||||
}
|
||||
}
|
||||
|
||||
// Convert graph to run on multi-devices.
|
||||
auto multi_devices_pass = AppendPass("multi_devices_pass");
|
||||
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
|
||||
&strategy_);
|
||||
|
||||
// Apply a graph print pass to record a graph with device info.
|
||||
if (!strategy_.debug_graphviz_path_.empty()) {
|
||||
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
|
||||
multi_devices_print_pass->SetNotOwned<const std::string>(
|
||||
"debug_graphviz_path", &strategy_.debug_graphviz_path_);
|
||||
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
|
||||
"graph_printer", new details::GraphvizSSAGraphPrinter);
|
||||
}
|
||||
|
||||
// Verify that the graph is correct for multi-device executor.
|
||||
AppendPass("multi_devices_check_pass");
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> Build(
|
||||
const ProgramDesc &main_program,
|
||||
const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶m_names,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
|
||||
#else
|
||||
const bool use_cuda) const {
|
||||
#endif
|
||||
// Convert the program to graph.
|
||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
|
||||
|
||||
for (std::shared_ptr<ir::Pass> &pass : AllPasses()) {
|
||||
if (pass->Type() == "multi_devices_pass") {
|
||||
pass->SetNotOwned<const std::vector<platform::Place>>("places",
|
||||
&places);
|
||||
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
|
||||
pass->SetNotOwned<const std::unordered_set<std::string>>("params",
|
||||
¶m_names);
|
||||
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
|
||||
&local_scopes);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
|
||||
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
|
||||
#endif
|
||||
}
|
||||
graph = pass->Apply(std::move(graph));
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
private:
|
||||
BuildStrategy strategy_;
|
||||
};
|
||||
|
||||
ir::PassBuilder *BuildStrategy::CreatePassBuilder() const {
|
||||
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
|
||||
return pass_builder_.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
|
||||
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶m_names,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
|
||||
#else
|
||||
const bool use_cuda) const {
|
||||
#endif
|
||||
if (!pass_builder_) {
|
||||
CreatePassBuilder();
|
||||
}
|
||||
// std::unique_ptr<ir::Graph> graph;
|
||||
ParallelExecutorPassBuilder *builder =
|
||||
reinterpret_cast<ParallelExecutorPassBuilder *>(pass_builder_.get());
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::unique_ptr<ir::Graph> graph =
|
||||
builder->Build(main_program, places, loss_var_name, param_names,
|
||||
local_scopes, use_cuda, nccl_ctxs);
|
||||
#else
|
||||
std::unique_ptr<ir::Graph> graph = builder->Build(
|
||||
main_program, places, loss_var_name, param_names, local_scopes, use_cuda);
|
||||
#endif
|
||||
return graph;
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
USE_PASS(fuse_elewise_add_act_pass);
|
||||
USE_PASS(graph_viz_pass);
|
||||
USE_PASS(multi_devices_pass);
|
||||
USE_PASS(multi_devices_check_pass);
|
||||
USE_PASS(multi_devices_print_pass);
|
@ -0,0 +1,43 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/pass_builder.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
|
||||
auto pass = ir::PassRegistry::Instance().Get(pass_type);
|
||||
passes_.emplace_back(pass.release());
|
||||
return passes_.back();
|
||||
}
|
||||
|
||||
void PassBuilder::RemovePass(size_t idx) {
|
||||
PADDLE_ENFORCE(passes_.size() > idx);
|
||||
passes_.erase(passes_.begin() + idx);
|
||||
}
|
||||
|
||||
std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
|
||||
const std::string& pass_type) {
|
||||
PADDLE_ENFORCE(passes_.size() >= idx);
|
||||
std::shared_ptr<Pass> pass(
|
||||
ir::PassRegistry::Instance().Get(pass_type).release());
|
||||
passes_.insert(passes_.begin() + idx, std::move(pass));
|
||||
return passes_[idx];
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,45 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class PassBuilder {
|
||||
public:
|
||||
PassBuilder() {}
|
||||
|
||||
virtual ~PassBuilder() {}
|
||||
|
||||
std::shared_ptr<Pass> AppendPass(const std::string& pass_type);
|
||||
|
||||
std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type);
|
||||
|
||||
void RemovePass(size_t idx);
|
||||
|
||||
std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; }
|
||||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<Pass>> passes_;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,110 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
import numpy as np
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
def simple_fc_net():
|
||||
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
hidden = img
|
||||
for _ in range(4):
|
||||
hidden = fluid.layers.fc(
|
||||
hidden,
|
||||
size=200,
|
||||
act='tanh',
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=1.0)))
|
||||
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
|
||||
loss = fluid.layers.cross_entropy(input=prediction, label=label)
|
||||
loss = fluid.layers.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class TestPassBuilder(unittest.TestCase):
|
||||
def check_network_convergence(self, use_cuda, build_strategy=None):
|
||||
os.environ['CPU_NUM'] = str(4)
|
||||
main = fluid.Program()
|
||||
startup = fluid.Program()
|
||||
with fluid.program_guard(main, startup):
|
||||
loss = simple_fc_net()
|
||||
test_program = main.clone(for_test=True)
|
||||
|
||||
opt = fluid.optimizer.SGD(learning_rate=0.001)
|
||||
opt.minimize(loss)
|
||||
|
||||
batch_size = 32
|
||||
image = np.random.normal(size=(batch_size, 784)).astype('float32')
|
||||
label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup)
|
||||
feed_dict = {'image': image, 'label': label}
|
||||
|
||||
train_exe = fluid.ParallelExecutor(
|
||||
use_cuda=use_cuda,
|
||||
loss_name=loss.name,
|
||||
main_program=main,
|
||||
build_strategy=build_strategy)
|
||||
|
||||
test_exe = fluid.ParallelExecutor(
|
||||
use_cuda=use_cuda,
|
||||
main_program=test_program,
|
||||
share_vars_from=train_exe,
|
||||
build_strategy=build_strategy)
|
||||
|
||||
for i in range(5):
|
||||
test_loss, = test_exe.run([loss.name], feed=feed_dict)
|
||||
|
||||
train_loss, = train_exe.run([loss.name], feed=feed_dict)
|
||||
|
||||
avg_test_loss_val = np.array(test_loss).mean()
|
||||
if math.isnan(float(avg_test_loss_val)):
|
||||
sys.exit("got NaN loss, testing failed.")
|
||||
|
||||
avg_train_loss_val = np.array(train_loss).mean()
|
||||
if math.isnan(float(avg_train_loss_val)):
|
||||
sys.exit("got NaN loss, training failed.")
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
train_loss, test_loss, atol=1e-8),
|
||||
"Train loss: " + str(train_loss) + "\n Test loss:" +
|
||||
str(test_loss))
|
||||
|
||||
def test_parallel_testing_with_new_strategy(self):
|
||||
build_strategy = fluid.BuildStrategy()
|
||||
pass_builder = build_strategy.create_pass_builder()
|
||||
viz_pass = pass_builder.append_pass("graph_viz_pass")
|
||||
all_passes = pass_builder.all_passes()
|
||||
pass_builder.insert_pass(len(all_passes), "graph_viz_pass")
|
||||
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
|
||||
viz_pass.set_str("graph_viz_path", "/tmp/viz_pass")
|
||||
|
||||
self.check_network_convergence(
|
||||
use_cuda=core.is_compiled_with_cuda(),
|
||||
build_strategy=build_strategy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue