ParallelDyGraph with GPU collective mode (#16827)
Implement dygraph.parallel.DataParallel to hook the allreduce op.
parent 1a4a51db2b
commit 0b07eef118
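The commit message above is the only place this diff mentions dygraph.parallel.DataParallel; its implementation is not included in the hunks below. As rough orientation, a training loop relying on such a wrapper might look like the following sketch. The constructor arguments, the exact point where gradients get allreduced, and helpers such as clear_gradients are assumptions for illustration, not code from this commit.

# Hypothetical usage sketch (not part of this diff): wrap a dygraph model in
# DataParallel so gradient tensors are allreduced across GPUs after backward.
# Constructor arguments and loop details below are assumptions.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    model = MNIST("mnist")  # the dygraph MNIST layer defined later in this diff
    model = fluid.dygraph.parallel.DataParallel(model)  # assumed constructor
    opt = fluid.optimizer.SGDOptimizer(learning_rate=1e-3)

    img = to_variable(np.random.rand(2, 1, 28, 28).astype('float32'))
    label = to_variable(np.random.randint(0, 10, size=(2, 1)).astype('int64'))

    avg_loss = fluid.layers.mean(
        fluid.layers.cross_entropy(model(img), label))
    avg_loss.backward()      # DataParallel is expected to hook allreduce here
    opt.minimize(avg_loss)
    model.clear_gradients()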
@ -0,0 +1,25 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    allreduce, ops::AllReduceOpKernel<plat::CUDADeviceContext, float>,
    ops::AllReduceOpKernel<plat::CUDADeviceContext, double>,
    ops::AllReduceOpKernel<plat::CUDADeviceContext, int>,
    ops::AllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
    ops::AllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);
@ -0,0 +1,87 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class AllReduceOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto place = ctx.GetPlace();
    PADDLE_ENFORCE(is_gpu_place(place),
                   "AllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    int dtype = platform::ToNCCLDataType(in->type());
    int64_t numel = in->numel();
    auto* sendbuff = in->data<void>();
    out->Resize(in->dims());
    void* recvbuff = out->mutable_data<T>(place);

    auto* comm = dev_ctx.nccl_comm();
    // FIXME(typhoonzero): should use nccl stream here.
    auto stream = dev_ctx.stream();
    PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly.");

    int reduce_type = ctx.Attr<int>("reduce_type");
    ncclRedOp_t red_type = ncclSum;
    switch (reduce_type) {
      case 0:
        red_type = ncclSum;
        break;
      case 1:
        red_type = ncclProd;
        break;
      case 2:
        red_type = ncclMax;
        break;
      case 3:
        red_type = ncclMin;
        break;
    }
    VLOG(0) << "call allreduce with type: " << reduce_type;
    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
        comm, stream));
    if (ctx.Attr<bool>("sync_mode")) {
      VLOG(0) << "sync allreduce...";
      cudaError_t e_sync = cudaStreamSynchronize(stream);
      if (e_sync != 0) {
        LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
      }
    }
#else
    PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
  }
};

}  // namespace operators
}  // namespace paddle
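Judging by the #pragma once and the include in the first hunk, this second file is paddle/fluid/operators/distributed_ops/allreduce_op.h. The kernel reads its input from slot "X", writes to "Out", and honors two attributes: an integer reduce_type (0 = ncclSum, 1 = ncclProd, 2 = ncclMax, 3 = ncclMin) and a bool sync_mode that adds a cudaStreamSynchronize after the collective. A sketch of wiring the op up by hand from Python follows; appending it directly with append_op only illustrates the op's interface, and is not necessarily how the dygraph hook attaches it.

# Sketch: appending the allreduce op to a fluid block by hand, to show the
# operator interface implied by the kernel above. Variable names are made up.
import paddle.fluid as fluid

block = fluid.default_main_program().global_block()
x = block.create_var(name="grad_in", shape=[8, 8], dtype="float32")
out = block.create_var(name="grad_out", shape=[8, 8], dtype="float32")

block.append_op(
    type="allreduce",
    inputs={"X": [x]},
    outputs={"Out": [out]},
    attrs={
        "reduce_type": 0,   # 0: ncclSum, 1: ncclProd, 2: ncclMax, 3: ncclMin
        "sync_mode": True,  # synchronize the CUDA stream after the allreduce
    })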
@ -0,0 +1,136 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import contextlib
import unittest
import numpy as np
import six
import pickle

import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable

from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase


class SimpleImgConvPool(fluid.dygraph.Layer):
    def __init__(self,
                 name_scope,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 pool_padding=0,
                 pool_type='max',
                 global_pooling=False,
                 conv_stride=1,
                 conv_padding=0,
                 conv_dilation=1,
                 conv_groups=1,
                 act=None,
                 use_cudnn=False,
                 param_attr=None,
                 bias_attr=None):
        super(SimpleImgConvPool, self).__init__(name_scope)

        self._conv2d = Conv2D(
            self.full_name(),
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            param_attr=None,
            bias_attr=None,
            use_cudnn=use_cudnn)

        self._pool2d = Pool2D(
            self.full_name(),
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            global_pooling=global_pooling,
            use_cudnn=use_cudnn)

    def forward(self, inputs):
        x = self._conv2d(inputs)
        x = self._pool2d(x)
        return x


class MNIST(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(MNIST, self).__init__(name_scope)

        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            self.full_name(), 1, 20, 5, 2, 2, act="relu")

        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            self.full_name(), 20, 50, 5, 2, 2, act="relu")

        pool_2_shape = 50 * 4 * 4
        SIZE = 10
        scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
        self._fc = FC(self.full_name(),
                      10,
                      param_attr=fluid.param_attr.ParamAttr(
                          initializer=fluid.initializer.NormalInitializer(
                              loc=0.0, scale=scale)),
                      act="softmax")

    def forward(self, inputs):
        x = self._simple_img_conv_pool_1(inputs)
        x = self._simple_img_conv_pool_2(x)
        x = self._fc(x)
        return x


class TestMnist(TestParallelDyGraphRunnerBase):
    def get_model(self):
        model = MNIST("mnist")
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=2, drop_last=True)
        opt = SGDOptimizer(learning_rate=1e-3)
        return model, train_reader, opt

    def run_one_loop(self, model, opt, data):
        batch_size = len(data)
        dy_x_data = np.array([x[0].reshape(1, 28, 28)
                              for x in data]).astype('float32')
        y_data = np.array(
            [x[1] for x in data]).astype('int64').reshape(batch_size, 1)
        img = to_variable(dy_x_data)
        label = to_variable(y_data)
        label.stop_gradient = True

        cost = model(img)
        loss = fluid.layers.cross_entropy(cost, label)
        avg_loss = fluid.layers.mean(loss)
        return avg_loss


if __name__ == "__main__":
    runtime_main(TestMnist)
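TestMnist above only provides get_model and run_one_loop; the multi-process driving logic lives in TestParallelDyGraphRunnerBase from test_dist_base, which is not part of this diff. A single-process smoke test of those two hooks under a dygraph guard could look like the sketch below; this driver loop is an assumption written for illustration, not the runner's actual implementation.

# Sketch: exercising TestMnist.get_model / run_one_loop in a single process,
# without the distributed runner. The construction of TestMnist and the loop
# body are assumptions for illustration.
import paddle.fluid as fluid

def smoke_test_one_batch():
    with fluid.dygraph.guard():
        runner = TestMnist()
        model, train_reader, opt = runner.get_model()
        data = next(train_reader())         # one batch of (image, label) pairs
        avg_loss = runner.run_one_loop(model, opt, data)
        avg_loss.backward()
        opt.minimize(avg_loss)
        model.clear_gradients()
        print("loss:", avg_loss.numpy())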
@ -0,0 +1,32 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
from test_dist_base import TestDistBase


class TestParallelDygraphMnist(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._nccl2_mode = True
        self._dygraph = True

    def test_mnist(self):
        self.check_with_place(
            "parallel_dygraph_mnist.py", delta=1e-5, check_error_log=True)


if __name__ == "__main__":
    unittest.main()