commit 7cd2761736
@@ -0,0 +1,126 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/details/build_strategy.h"

#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
namespace framework {
namespace details {

class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
    // Add a graph viz pass to record the original graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    // Add op fusion.
    if (strategy.fuse_elewise_add_act_ops_) {
      auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
      // Add a graph viz pass to record the fused graph.
      if (!strategy.debug_graphviz_path_.empty()) {
        auto viz_pass = AppendPass("graph_viz_pass");
        const std::string graph_path = string::Sprintf(
            "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
        viz_pass->Set<std::string>("graph_viz_path",
                                   new std::string(graph_path));
      }
    }

    // Convert the graph to run on multiple devices.
    auto multi_devices_pass = AppendPass("multi_devices_pass");
    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                         &strategy_);

    // Add a graph print pass to record the graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      multi_devices_print_pass->SetNotOwned<const std::string>(
          "debug_graphviz_path", &strategy_.debug_graphviz_path_);
      multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
          "graph_printer", new details::GraphvizSSAGraphPrinter);
    }

    // Verify that the graph is correct for the multi-device executor.
    AppendPass("multi_devices_check_pass");
  }

 private:
  BuildStrategy strategy_;
};

std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
    const {
  pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
  return pass_builder_;
}

std::unique_ptr<ir::Graph> BuildStrategy::Apply(
    const ProgramDesc &main_program, const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &param_names,
    const std::vector<Scope *> &local_scopes,
#ifdef PADDLE_WITH_CUDA
    const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
    const bool use_cuda) const {
#endif
  // Create a default pass builder if it was not initialized by the user.
  if (!pass_builder_) {
    CreatePassesFromStrategy();
  }

  std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));

  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
    if (pass->Type() == "multi_devices_pass") {
      pass->Erase("places");
      pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
      pass->Erase("loss_var_name");
      pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
      pass->Erase("params");
      pass->SetNotOwned<const std::unordered_set<std::string>>("params",
                                                               &param_names);
      pass->Erase("local_scopes");
      pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
                                                    &local_scopes);
#ifdef PADDLE_WITH_CUDA
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase("nccl_ctxs");
      pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
    }
    graph = pass->Apply(std::move(graph));
  }
  return graph;
}
}  // namespace details
}  // namespace framework
}  // namespace paddle

USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
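The two strategy fields used above (debug_graphviz_path_ and fuse_elewise_add_act_ops_) are what decide which optional passes ParallelExecutorPassBuilder appends. A minimal sketch of that effect from the Python side follows; it is not part of this commit, it assumes the two fields are exposed on fluid.BuildStrategy by the pybind layer (as debug_graphviz_path and fuse_elewise_add_act_ops), and it otherwise only calls pass-builder methods exercised by the test at the end of this commit.

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    # Assumed pybind-exposed counterparts of the C++ BuildStrategy fields
    # debug_graphviz_path_ and fuse_elewise_add_act_ops_ used above.
    build_strategy.debug_graphviz_path = "/tmp/debug_graph"
    build_strategy.fuse_elewise_add_act_ops = True

    # Builds the default pipeline assembled by ParallelExecutorPassBuilder:
    # with both options on it should contain graph_viz_pass,
    # fuse_elewise_add_act_pass, a second graph_viz_pass, multi_devices_pass,
    # multi_devices_print_pass and multi_devices_check_pass.
    pass_builder = build_strategy._create_passes_from_strategy()
    print(len(pass_builder.all_passes()))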
@@ -0,0 +1,43 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/pass_builder.h"

namespace paddle {
namespace framework {
namespace ir {

std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
  auto pass = ir::PassRegistry::Instance().Get(pass_type);
  passes_.emplace_back(pass.release());
  return passes_.back();
}

void PassBuilder::RemovePass(size_t idx) {
  PADDLE_ENFORCE(passes_.size() > idx);
  passes_.erase(passes_.begin() + idx);
}

std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
                                              const std::string& pass_type) {
  PADDLE_ENFORCE(passes_.size() >= idx);
  std::shared_ptr<Pass> pass(
      ir::PassRegistry::Instance().Get(pass_type).release());
  passes_.insert(passes_.begin() + idx, std::move(pass));
  return passes_[idx];
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,49 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class PassBuilder {
 public:
  PassBuilder() {}

  virtual ~PassBuilder() {}

  // Append a new pass to the end of the pipeline.
  std::shared_ptr<Pass> AppendPass(const std::string& pass_type);

  // Insert a new pass at position `idx`.
  std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type);

  // Remove the pass at position `idx`.
  void RemovePass(size_t idx);

  // Returns a list of all passes.
  std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; }

 protected:
  std::vector<std::shared_ptr<Pass>> passes_;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
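The PassBuilder interface above is what the Python-side pass builder object wraps. The sketch below is not part of this commit; it shows the corresponding Python workflow using only the bindings that the test file below exercises (_create_passes_from_strategy, append_pass, insert_pass, remove_pass, all_passes, and Pass.set_str), with an illustrative output path, so treat it as a sketch rather than documented API.

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    pass_builder = build_strategy._create_passes_from_strategy()

    # AppendPass: add a graph_viz_pass at the end and configure its output path.
    viz_pass = pass_builder.append_pass("graph_viz_pass")
    viz_pass.set_str("graph_viz_path", "/tmp/custom_viz_pass")

    # InsertPass / RemovePass: insert another pass at the end, then drop it.
    idx = len(pass_builder.all_passes())
    pass_builder.insert_pass(idx, "graph_viz_pass")
    pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)

    # AllPasses: the customized pipeline that ParallelExecutor applies when it
    # is constructed with this build_strategy (as in the test below).
    print(len(pass_builder.all_passes()))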
@@ -0,0 +1,121 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import unittest
import os
import sys
import math


def simple_fc_net():
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = img
    for _ in range(4):
        hidden = fluid.layers.fc(
            hidden,
            size=200,
            act='tanh',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))
    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    loss = fluid.layers.mean(loss)
    return loss


class TestPassBuilder(unittest.TestCase):
    def check_network_convergence(self, use_cuda, build_strategy=None):
        os.environ['CPU_NUM'] = str(4)
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = simple_fc_net()
            test_program = main.clone(for_test=True)

            opt = fluid.optimizer.SGD(learning_rate=0.001)
            opt.minimize(loss)

            batch_size = 32
            image = np.random.normal(size=(batch_size, 784)).astype('float32')
            label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")

            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup)
            feed_dict = {'image': image, 'label': label}

            train_exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                loss_name=loss.name,
                main_program=main,
                build_strategy=build_strategy)

            test_exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=test_program,
                share_vars_from=train_exe,
                build_strategy=build_strategy)

            for i in range(5):
                test_loss, = test_exe.run([loss.name], feed=feed_dict)

                train_loss, = train_exe.run([loss.name], feed=feed_dict)

                avg_test_loss_val = np.array(test_loss).mean()
                if math.isnan(float(avg_test_loss_val)):
                    sys.exit("got NaN loss, testing failed.")

                avg_train_loss_val = np.array(train_loss).mean()
                if math.isnan(float(avg_train_loss_val)):
                    sys.exit("got NaN loss, training failed.")

                self.assertTrue(
                    np.allclose(
                        train_loss, test_loss, atol=1e-8),
                    "Train loss: " + str(train_loss) + "\n Test loss:" +
                    str(test_loss))

    def test_parallel_testing_with_new_strategy(self):
        build_strategy = fluid.BuildStrategy()
        pass_builder = build_strategy._create_passes_from_strategy()
        origin_len = len(pass_builder.all_passes())

        viz_pass = pass_builder.append_pass("graph_viz_pass")
        self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))

        pass_builder.insert_pass(
            len(pass_builder.all_passes()), "graph_viz_pass")
        self.assertEqual(origin_len + 2, len(pass_builder.all_passes()))

        pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
        self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
        viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")

        self.check_network_convergence(
            use_cuda=core.is_compiled_with_cuda(),
            build_strategy=build_strategy)
        try:
            os.stat("/tmp/test_viz_pass")
        except os.error:
            self.assertFalse(True)


if __name__ == '__main__':
    unittest.main()