polish parallel dygraph code (#17164)

* add var grad hook test=develop
resnext-opt
Yan Xu 6 years ago committed by chengduo
parent d7df4e5e5b
commit 0217555530

@@ -150,9 +150,9 @@ class Autograd {
     const std::vector<VarBase*>& ingrads = it->second;
     for (size_t i = 0; i < ingrads.size(); ++i) {
       if (!ingrads[i]) continue;
-      if (ready_op->input_vars_[it->first][i]->IsStopGradient()) {
-        continue;
-      }
+      auto p = ready_op->input_vars_[it->first][i];
+      if (p->IsStopGradient()) continue;
       OpBase* pre_op = ready_op->pre_ops_[it->first][i];
       if (!pre_op) continue;
@@ -415,15 +415,11 @@ void OpBase::InvokeBackwardHooks() {
   }
 }

-void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) {
+void OpBase::RegisterBackwardHooks(const py::object& callable) {
   VLOG(3) << "Register backward hooks " << trace_id_;

   // TODO(minqiyang): check the callable format
-  if (front) {
-    backward_hooks_.insert(backward_hooks_.begin(), callable);
-  } else {
-    backward_hooks_.push_back(callable);
-  }
+  backward_hooks_.push_back(callable);
 }

 void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {

@@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase {
     return grad_op_descs_[index]->Type();
   }

-  void RegisterBackwardHooks(const py::object& callable, bool front = false);
+  void RegisterBackwardHooks(const py::object& callable);

   void InvokeBackwardHooks();

@@ -39,6 +39,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
     int dtype = platform::ToNCCLDataType(in->type());
     int64_t numel = in->numel();
     auto* sendbuff = in->data<void>();
@@ -66,12 +67,10 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
         red_type = ncclMin;
         break;
     }
-    VLOG(0) << "call allreduce with type: " << reduce_type;
     PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
         comm, stream));
     if (ctx.Attr<bool>("sync_mode")) {
-      VLOG(0) << "sync allreduce...";
       cudaError_t e_sync = cudaStreamSynchronize(stream);
       if (e_sync != 0) {
         LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
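
With the VLOG(0) lines removed, the kernel's behavior is: issue ncclAllReduce on the op's stream, then block on cudaStreamSynchronize only when the "sync_mode" attribute is set. For reference, the Python-side call that reaches this kernel, as used in the DataParallel change below; a minimal sketch, assuming g_var is a gradient wrapped in a framework.Variable:

    # sync_mode=True maps to the kernel's sync branch: the CUDA stream is
    # synchronized after the allreduce completes
    collective._allreduce(g_var, g_var, sync_mode=True)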

@@ -252,11 +252,9 @@ PYBIND11_MODULE(core, m) {
   py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
       .def(py::init<const std::string &>())
       .def("register_backward_hooks",
-           [](imperative::OpBase &self, const py::object &callable,
-              bool front = false) {
-             self.RegisterBackwardHooks(callable, front);
-           },
-           py::arg("callable"), py::arg("front") = false)
+           [](imperative::OpBase &self, const py::object &callable) {
+             self.RegisterBackwardHooks(callable);
+           })
       .def_property("_trace_id",
                     [](const imperative::OpBase &self) {
                       pybind11::gil_scoped_release release;
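
With the front parameter gone from both the C++ method and this binding, hooks always run in registration order. A minimal sketch of calling the simplified binding (the hook itself is hypothetical; op is a tracer op record whose .iop is the C++-side OpBase, as in the hook-based code removed from DataParallel below):

    def _debug_hook(iop):
        # invoked when the op's backward hooks fire; iop is the OpBase
        print("backward hooks fired for trace id", iop._trace_id)

    op.iop.register_backward_hooks(_debug_hook)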

@@ -13,12 +13,14 @@
 # limitations under the License.

 import os
 import six
+import numpy as np

 from .. import core
 from . import layers
 from .. import framework
 from ..layers import collective
+from . import to_variable

 __all__ = ["prepare_context"]
@@ -75,31 +77,33 @@ class Env(object):
 class DataParallel(layers.Layer):
-    def __init__(self, layers):
+    def __init__(self, layers, strategy):
         super(DataParallel,
               self).__init__(layers.full_name() + "_data_parallel")
         self._layers = layers
+        self._strategy = strategy
+
+    def build_once(self, *inputs, **kwargs):
+        #TODO(Yancey1989): broadcast all the paramters
+        pass

     def forward(self, *inputs, **kwargs):
-        def _collective_hook(iop):
-            op = framework._dygraph_tracer()._ops[iop._trace_id]
-            for k, v in six.iteritems(op.inputs):
-                for ivar in v:
-                    g = ivar._grad_ivar()
-                    if g:
-                        g_var = framework.Variable(
-                            block=self._helper.main_program.current_block(),
-                            name=ivar._grad_name(),
-                            stop_gradient=True,
-                            ivar=g)
-                        collective._allreduce(g_var, g_var, sync_mode=True)
-
-        outs = self._layers(*inputs, **kwargs)
-        for _, op in six.iteritems(framework._dygraph_tracer()._ops):
-            # hook collective ops
-            op.iop.register_backward_hooks(_collective_hook, front=True)
-        return outs
+        return self._layers(*inputs, **kwargs)
+
+    def scale_loss(self, loss):
+        if self._strategy.nranks < 2:
+            return loss
+        loss_scale = to_variable(
+            np.array([self._strategy.nranks]).astype("float32"))
+        loss_scale.stop_gradient = True
+        loss = loss / loss_scale
+        return loss
+
+    def apply_collective_grads(self):
+        if self._strategy.nranks < 2:
+            return
+
+        for param in self._layers.parameters():
+            if param.trainable and param._ivar._grad_ivar():
+                g_var = framework.Variable(
+                    block=self._helper.main_program.current_block(),
+                    name=param._ivar._grad_name(),
+                    stop_gradient=True,
+                    ivar=param._ivar._grad_ivar())
+                collective._allreduce(g_var, g_var, sync_mode=True)
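
Taken together, DataParallel no longer hooks collective ops inside forward; loss scaling and the gradient allreduce become explicit steps around backward. A minimal usage sketch under the new contract (mirroring the run_trainer change below; model, opt, reader, and the strategy field values are placeholders):

    strategy = dygraph.parallel.ParallelStrategy()
    strategy.nranks = nranks
    strategy.local_rank = trainer_id
    strategy.trainer_endpoints = endpoints
    strategy.current_endpoint = current_endpoint
    dygraph.parallel.prepare_context(strategy)
    model = dygraph.parallel.DataParallel(model, strategy)

    for data in reader():
        loss = model(*data)                 # plain forward, no hooks
        loss = model.scale_loss(loss)       # divide by nranks; no-op when nranks < 2
        loss.backward()
        model.apply_collective_grads()      # allreduce trainable parameter grads
        opt.minimize(loss)
        model.clear_gradients()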

@@ -101,11 +101,13 @@ class MNIST(fluid.dygraph.Layer):
                 loc=0.0, scale=scale)),
             act="softmax")

-    def forward(self, inputs):
+    def forward(self, inputs, label):
         x = self._simple_img_conv_pool_1(inputs)
         x = self._simple_img_conv_pool_2(x)
-        x = self._fc(x)
-        return x
+        cost = self._fc(x)
+        loss = fluid.layers.cross_entropy(cost, label)
+        avg_loss = fluid.layers.mean(loss)
+        return avg_loss
 class TestMnist(TestParallelDyGraphRunnerBase):
@@ -113,7 +115,7 @@ class TestMnist(TestParallelDyGraphRunnerBase):
         model = MNIST("mnist")
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=2, drop_last=True)
-        opt = SGDOptimizer(learning_rate=1e-3)
+        opt = fluid.optimizer.SGD(learning_rate=1e-3)
         return model, train_reader, opt
     def run_one_loop(self, model, opt, data):
@@ -126,9 +128,8 @@ class TestMnist(TestParallelDyGraphRunnerBase):
         label = to_variable(y_data)
         label.stop_gradient = True

-        cost = model(img)
-        loss = fluid.layers.cross_entropy(cost, label)
-        avg_loss = fluid.layers.mean(loss)
+        avg_loss = model(img, label)
         return avg_loss

@@ -31,7 +31,7 @@ import paddle.fluid.dygraph as dygraph
 from paddle.fluid.dygraph.base import to_variable
 from paddle.fluid.dygraph.parallel import DataParallel

-RUN_STEP = 10
+RUN_STEP = 5
 DEFAULT_BATCH_SIZE = 2
@@ -200,6 +200,7 @@ class TestParallelDyGraphRunnerBase(object):
             "train_one_loop should be implemented by the child classes.")

     def run_trainer(self, args):
+
         seed = 90
         device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
         place = fluid.CUDAPlace(device_id)
@@ -217,32 +218,35 @@ class TestParallelDyGraphRunnerBase(object):
         with fluid.dygraph.guard(place):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
+            np.random.seed(seed)
+            import random
+            random.seed = seed
             model, train_reader, opt = self.get_model()
             nranks = len(args.endpoints.split(",")) if args.endpoints else 1

             if args.update_method == "nccl2":
-                sys.stderr.write("")
-                model = dygraph.parallel.DataParallel(model)
                 strategy = dygraph.parallel.ParallelStrategy()
                 strategy.nranks = nranks
                 strategy.local_rank = args.trainer_id
                 strategy.trainer_endpoints = args.endpoints.split(",")
                 strategy.current_endpoint = args.current_endpoint
                 dygraph.parallel.prepare_context(strategy)
+                model = dygraph.parallel.DataParallel(model, strategy)

             out_losses = []
             for step_id, data in enumerate(train_reader()):
                 data = _get_data(data)
                 if step_id == RUN_STEP:
                     break
                 loss = self.run_one_loop(model, opt, data)
-                out_losses.append(loss.numpy())

                 # FIXME(Yancey1989): scale the loss inplace
-                loss.stop_gradient = True
-                loss_scale = to_variable(np.array([nranks]).astype("float32"))
-                loss = loss / loss_scale
+                if args.update_method == "nccl2":
+                    loss = model.scale_loss(loss)
+                out_losses.append(loss.numpy())

                 loss.backward()
+                if args.update_method == "nccl2":
+                    model.apply_collective_grads()

                 opt.minimize(loss)
                 model.clear_gradients()
@@ -663,9 +667,6 @@ class TestDistBase(unittest.TestCase):
             local_loss = local_losses[step_id]
             tr0_loss = tr0_losses[step_id]
             tr1_loss = tr1_losses[step_id]
-            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss]))
-            if not self._dygraph:
-                # Parallel DyGraph already scaled the loss in training
-                dist_loss = dist_loss / 2
+            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
             print("=======", local_loss, ":", dist_loss[0], "=======")
             self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)

@@ -15,6 +15,7 @@
 from __future__ import print_function
 import unittest
 from test_dist_base import TestDistBase
+import paddle.fluid as fluid


 class TestParallelDygraphMnist(TestDistBase):
@@ -24,8 +25,8 @@ class TestParallelDygraphMnist(TestDistBase):
         self._dygraph = True

     def test_mnist(self):
-        self.check_with_place(
-            "parallel_dygraph_mnist.py", delta=1e-5, check_error_log=True)
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place("parallel_dygraph_mnist.py", delta=1e-5)


 if __name__ == "__main__":

@@ -0,0 +1,35 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+from test_dist_base import TestDistBase
+import paddle.fluid as fluid
+
+
+class TestParallelDygraphSeResNeXt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._nccl2_mode = True
+        self._dygraph = True
+
+    def test_se_resnext(self):
+        # TODO(Yancey1989): BN and Dropout is related with batchsize, so the delta is the 1,
+        # try to remove the BN and Dropout in the network and using delta = 1e-5
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place("parallel_dygraph_se_resnext.py", delta=1)
+
+
+if __name__ == "__main__":
+    unittest.main()