Collective ops (#15572)
* wip allreduce in op * wip * wip * wip * wip adding test * wip for conflict with mp mode * fix tests test=develop * fix cpu build test=develop * fix travis clang format test=develop * fix cpu build test=develop * update api.spec test=develop * delete comment test=develop * fix cpplint test=develop * fix test=develop * follow comment test=develop * add file test=develop * fix build test=develop * update test=develop * to be compatible with sync_bn, and fix mp mode in develop test=developrevert-16190-refine_parallel_executor
parent
b9fc80a133
commit
6382b62f6b
@ -0,0 +1,143 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <future> // NOLINT
|
||||
#include <ostream>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include "paddle/fluid/platform/nccl_helper.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
struct MutableDataFunctor {
|
||||
MutableDataFunctor(void** data, framework::LoDTensor* tensor,
|
||||
const platform::Place& place)
|
||||
: data_(data), tensor_(tensor), place_(place) {}
|
||||
|
||||
template <typename T>
|
||||
void apply() {
|
||||
*data_ = tensor_->mutable_data<T>(place_);
|
||||
}
|
||||
|
||||
void** data_;
|
||||
framework::LoDTensor* tensor_;
|
||||
platform::Place place_;
|
||||
};
|
||||
|
||||
class AllReduceOp : public framework::OperatorBase {
|
||||
using OperatorBase::OperatorBase;
|
||||
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& place) const override {
|
||||
PADDLE_ENFORCE(is_gpu_place(place),
|
||||
"AllReduce op can run on gpu place only for now.");
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
||||
auto* ctx = pool.Get(place);
|
||||
auto in_names = Inputs("X");
|
||||
auto out_names = Outputs("Out");
|
||||
PADDLE_ENFORCE_EQ(in_names.size(), 1, "Only support one input");
|
||||
PADDLE_ENFORCE_EQ(out_names.size(), 1, "Only support one output");
|
||||
|
||||
auto* in = scope.FindVar(in_names[0]);
|
||||
auto* out = scope.FindVar(out_names[0]);
|
||||
|
||||
PADDLE_ENFORCE(in->IsType<framework::LoDTensor>() ||
|
||||
out->IsType<framework::LoDTensor>(),
|
||||
"Only support allreduce LoDTensors");
|
||||
|
||||
int dtype = -1;
|
||||
auto in_tensor = in->Get<framework::LoDTensor>();
|
||||
dtype = platform::ToNCCLDataType(in_tensor.type());
|
||||
|
||||
int64_t numel = in_tensor.numel();
|
||||
auto* sendbuff = in_tensor.data<void>();
|
||||
auto* out_tensor = out->GetMutable<framework::LoDTensor>();
|
||||
out_tensor->Resize(in_tensor.dims());
|
||||
void* recvbuff = nullptr;
|
||||
framework::VisitDataType(in_tensor.type(),
|
||||
MutableDataFunctor(&recvbuff, out_tensor, place));
|
||||
|
||||
auto cuda_ctx = static_cast<platform::CUDADeviceContext*>(ctx);
|
||||
auto* comm = cuda_ctx->nccl_comm();
|
||||
// FIXME(typhoonzero): should use nccl stream here.
|
||||
auto stream = cuda_ctx->stream();
|
||||
|
||||
int reduce_type = Attr<int>("reduce_type");
|
||||
ncclRedOp_t red_type = ncclSum;
|
||||
switch (reduce_type) {
|
||||
case 0:
|
||||
red_type = ncclSum;
|
||||
break;
|
||||
case 1:
|
||||
red_type = ncclProd;
|
||||
break;
|
||||
case 2:
|
||||
red_type = ncclMax;
|
||||
break;
|
||||
case 3:
|
||||
red_type = ncclMin;
|
||||
break;
|
||||
}
|
||||
|
||||
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
||||
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
|
||||
comm, stream));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the proto of the allreduce operator: one input "X", one output
// "Out", and the integer attribute "reduce_type" (default 0).
class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Marked `override` so a signature drift from the base class is caught at
  // compile time.
  void Make() override {
    AddInput("X", "(Tensor), tensor to be allreduced.");
    AddOutput("Out", "(Tensor) the result of allreduced.");
    AddAttr<int>("reduce_type", "(int) determine the reduce type.")
        .SetDefault(0);
    AddComment(R"DOC(
***AllReduce Operator***

Call NCCL AllReduce internally. Note that this op must be used when one
thread is managing one GPU device.

For speed reasons, reduce_type should be an integer:

0: sum
1: prod
2: max
3: min

If input and output are the same variable, in-place allreduce will be used.
)DOC");
  }
};
|
||||
|
||||
// Shape-inference functor registered for the allreduce operator. It is a
// deliberate no-op: the call operator has an empty body.
class AllReduceOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* /*ctx*/) const override {
    // Nothing to infer here.
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(allreduce, ops::AllReduceOp,
|
||||
paddle::framework::EmptyGradOpMaker, ops::AllReduceOpMaker,
|
||||
ops::AllReduceOpShapeInference);
|
@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
from ..layer_helper import LayerHelper, unique_name
|
||||
|
||||
|
||||
def _allreduce(x, out=None, reduce_type="sum"):
    """Append an ``allreduce`` op that reduces ``x`` into ``out``.

    Args:
        x: input variable; it is passed to the op as ``X``.
        out: output variable. When ``None``, a new variable is created that
            copies ``x``'s shape, dtype, type and persistable flag, with
            ``stop_gradient=True``.
        reduce_type: one of ``"sum"``, ``"prod"``, ``"max"``, ``"min"``.

    Returns:
        The output variable of the appended op.

    Raises:
        TypeError: if ``reduce_type`` is not one of the supported names.
    """
    # Must stay the first statement: locals() is forwarded to LayerHelper and
    # should contain only the function arguments.
    helper = LayerHelper("allreduce", **locals())

    # The op receives the reduction as an int attribute; the position of the
    # name in this tuple is the value passed to the op.
    red_typ_int = 0
    for code, known_name in enumerate(("sum", "prod", "max", "min")):
        if reduce_type == known_name:
            red_typ_int = code
            break
    else:
        raise TypeError("reduce type can only be [sum|prod|max|min]")

    result = out
    if result is None:
        tmp_name = unique_name.generate(".".join([x.name, 'tmp']))
        result = helper.create_variable(
            name=tmp_name,
            shape=x.shape,
            dtype=x.dtype,
            type=x.type,
            persistable=x.persistable,
            stop_gradient=True)

    helper.append_op(
        type='allreduce',
        inputs={'X': [x]},
        outputs={'Out': [result]},
        attrs={"reduce_type": red_typ_int})
    return result
|
@ -0,0 +1,120 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.profiler as profiler
|
||||
from paddle.fluid import core
|
||||
import unittest
|
||||
from multiprocessing import Process
|
||||
import os
|
||||
import signal
|
||||
from functools import reduce
|
||||
from test_dist_base import TestDistRunnerBase, runtime_main
|
||||
|
||||
# dtype used for the image input layer of the test model.
DTYPE = "float32"

# Runs the MNIST dataset's fetch() when this module is imported.
paddle.dataset.mnist.fetch()

# Fix seed for test: both default programs get the same constant seed.
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
|
||||
|
||||
|
||||
def cnn_model(data):
    """Build the small CNN used by the distributed allreduce test.

    The network is two ``simple_img_conv_pool`` stages (5x5 filters, 2x2
    pooling with stride 2, relu; 20 then 50 filters) followed by a 10-unit
    softmax ``fc`` layer. Every ``param_attr`` uses a constant 0.01
    initializer.

    Args:
        data: input variable fed to the first conv-pool stage.

    Returns:
        The output variable of the final softmax ``fc`` layer.
    """

    def _conv_pool(stage_input, num_filters):
        # Both stages share every setting except the filter count. A fresh
        # ParamAttr is created per call, as in the original two call sites.
        return fluid.nets.simple_img_conv_pool(
            input=stage_input,
            filter_size=5,
            num_filters=num_filters,
            pool_size=2,
            pool_stride=2,
            act="relu",
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
                value=0.01)))

    conv_pool_1 = _conv_pool(data, 20)
    conv_pool_2 = _conv_pool(conv_pool_1, 50)

    # Number of output classes. The previously computed input_shape,
    # param_shape and scale locals were never used and have been removed.
    SIZE = 10

    predict = fluid.layers.fc(
        input=conv_pool_2,
        size=SIZE,
        act="softmax",
        param_attr=fluid.param_attr.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01)))
    return predict
|
||||
|
||||
|
||||
class TestDistMnist2x2(TestDistRunnerBase):
    """MNIST runner whose gradients go through the ``allreduce`` layer."""

    def get_model(self, batch_size=2, single_device=False):
        """Build the train program and its readers.

        Args:
            batch_size: batch size used by both readers.
            single_device: when True the optimizer's ``minimize`` is used
                directly; otherwise each gradient is passed through
                ``_allreduce`` before ``apply_gradients``.

        Returns:
            Tuple of (inference program, average cost, train reader,
            test reader, batch accuracy, prediction).
        """
        # Input data
        pixel = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
        target = fluid.layers.data(name='label', shape=[1], dtype='int64')

        # Train program
        prediction = cnn_model(pixel)
        loss = fluid.layers.cross_entropy(input=prediction, label=target)
        mean_loss = fluid.layers.mean(x=loss)

        # Evaluator
        sample_count = fluid.layers.create_tensor(dtype='int64')
        accuracy = fluid.layers.accuracy(
            input=prediction, label=target, total=sample_count)

        # Cloned before any optimizer ops are appended to the main program.
        infer_prog = fluid.default_main_program().clone()

        # Reader (both readers are built from the MNIST test split)
        reader_for_train = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        reader_for_test = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)

        # Optimization
        # TODO(typhoonzero): fix distributed adam optimizer
        # opt = fluid.optimizer.AdamOptimizer(
        #     learning_rate=0.001, beta1=0.9, beta2=0.999)
        optimizer = fluid.optimizer.Momentum(
            learning_rate=self.lr, momentum=0.9)
        if not single_device:
            # multi device or distributed multi device
            # NOTE: scale will be done on loss scale in
            # multi_devices_graph_pass using nranks.
            reduced_pairs = [[param, fluid.layers.collective._allreduce(grad)]
                             for param, grad in optimizer.backward(mean_loss)]
            optimizer.apply_gradients(reduced_pairs)
        else:
            optimizer.minimize(mean_loss)

        return (infer_prog, mean_loss, reader_for_train, reader_for_test,
                accuracy, prediction)


if __name__ == "__main__":
    runtime_main(TestDistMnist2x2)
|
@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import unittest
|
||||
from test_dist_base import TestDistBase
|
||||
|
||||
|
||||
class TestDistMnistNCCL2(TestDistBase):
    """Runs dist_allreduce_op.py through the TestDistBase harness."""

    def _setup_config(self):
        # Flags read by the TestDistBase harness.
        self._sync_mode = True
        self._use_reduce = False
        self._use_reader_alloc = False
        self._nccl2_mode = True
        self._nccl2_reduce_layer = True

    def test_dist_train(self):
        import paddle.fluid as fluid

        # The check only runs on CUDA builds; otherwise the test is a no-op.
        if not fluid.core.is_compiled_with_cuda():
            return
        self.check_with_place("dist_allreduce_op.py", delta=1e-5)


if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue