Add collective async wait op (#31463)
parent 0205e9f84e
commit 83a2fb1f08
@@ -0,0 +1,91 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace paddle {
namespace operators {

class CWaitCommOp : public framework::OperatorBase {
 public:
  CWaitCommOp(const std::string& type, const framework::VariableNameMap& inputs,
              const framework::VariableNameMap& outputs,
              const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                      platform::errors::PreconditionNotMet(
                          "wait_comm op can run on GPU places only for now."));

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    int ring_id = Attr<int>("ring_id");

    auto compute_stream =
        static_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place))
            ->stream();
    auto comm_stream =
        platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();

    auto event =
        platform::NCCLCommContext::Instance().Get(ring_id, place)->comm_event();

    // comm_stream --> event --> compute_stream
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, comm_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
#endif
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should be compiled with GPU support."));
#endif
  }
};

class CWaitCommOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) Dependency of the variable that needs to sync")
        .AsDuplicable();
    AddOutput("Out", "(Tensor) Dependency of the variable that needs to sync")
        .AsDuplicable();
    AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
    AddComment(R"DOC(
CWaitComm Operator

Make the compute stream wait on the comm stream, using an async event.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_wait_comm, ops::CWaitCommOp, ops::CWaitCommOpMaker);
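A note on intended usage: c_wait_comm makes the compute stream wait for work already queued on the comm stream, so it belongs after a collective launched with use_calc_stream=False and before any op that reads the result. A minimal, hypothetical Python sketch; the names `block` and `var` are placeholders, not part of this patch:

# Hypothetical sketch: `block` and `var` are placeholder names.
# The allreduce is queued on the comm stream ...
block.append_op(
    type='c_allreduce_sum',
    inputs={'X': var},
    outputs={'Out': var},
    attrs={'ring_id': 0,
           'use_calc_stream': False})
# ... so the compute stream must wait for it before `var` is read.
block.append_op(
    type='c_wait_comm',  # comm_stream --> event --> compute_stream
    inputs={'X': var},
    outputs={'Out': var},
    attrs={'ring_id': 0})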
@@ -0,0 +1,95 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace paddle {
namespace operators {

class CWaitComputeOp : public framework::OperatorBase {
 public:
  CWaitComputeOp(const std::string& type,
                 const framework::VariableNameMap& inputs,
                 const framework::VariableNameMap& outputs,
                 const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    PADDLE_ENFORCE_EQ(
        is_gpu_place(place), true,
        platform::errors::PreconditionNotMet(
            "wait_compute op can run on GPU places only for now."));

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    int ring_id = Attr<int>("ring_id");

    auto compute_stream =
        static_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place))
            ->stream();
    auto comm_stream =
        platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();

    auto event = platform::NCCLCommContext::Instance()
                     .Get(ring_id, place)
                     ->compute_event();

    // compute_stream --> event --> comm_stream
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, compute_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
#endif
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should be compiled with GPU support."));
#endif
  }
};

class CWaitComputeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) Dependency of the variable that needs to sync")
        .AsDuplicable();
    AddOutput("Out", "(Tensor) Dependency of the variable that needs to sync")
        .AsDuplicable();
    AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
    AddComment(R"DOC(
CWaitCompute Operator

Make the comm stream wait on the compute stream, using an async event.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_wait_compute, ops::CWaitComputeOp,
                  ops::CWaitComputeOpMaker);
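The two ops are designed to bracket a collective that runs on the comm stream: c_wait_compute orders pending compute before the collective, and c_wait_comm orders the collective before subsequent compute. A hypothetical sketch of the full pattern, which the test below builds op by op (`block` and `var` are again placeholders):

# Hypothetical sketch of the full bracketing pattern.
block.append_op(type='c_wait_compute',   # compute_stream --> event --> comm_stream
                inputs={'X': var}, outputs={'Out': var},
                attrs={'ring_id': 0})
block.append_op(type='c_allreduce_sum',  # queued on the comm stream
                inputs={'X': var}, outputs={'Out': var},
                attrs={'ring_id': 0, 'use_calc_stream': False})
block.append_op(type='c_wait_comm',      # comm_stream --> event --> compute_stream
                inputs={'X': var}, outputs={'Out': var},
                attrs={'ring_id': 0})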
@@ -0,0 +1,114 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from test_collective_base import TestCollectiveRunnerBase, runtime_main

paddle.enable_static()


class TestCollectiveAllreduce(TestCollectiveRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        ring_id = 0
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            toutdata = main_prog.current_block().create_var(
                name="outofallreduce",
                dtype='float32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False)

            # tout = tin + tin - tin = tin
            main_prog.global_block().append_op(
                type="elementwise_add",
                inputs={'X': tindata,
                        'Y': tindata},
                outputs={'Out': toutdata})
            main_prog.global_block().append_op(
                type="elementwise_sub",
                inputs={'X': toutdata,
                        'Y': tindata},
                outputs={'Out': toutdata})

            # make the comm stream wait for the pending compute ops
            main_prog.global_block().append_op(
                type='c_wait_compute',
                inputs={'X': toutdata},
                outputs={'Out': toutdata},
                attrs={'ring_id': ring_id})

            # run the allreduce on the comm stream, not the compute stream
            main_prog.global_block().append_op(
                type="c_allreduce_sum",
                inputs={'X': toutdata},
                outputs={'Out': toutdata},
                attrs={'ring_id': ring_id,
                       'use_calc_stream': False})

            # make the compute stream wait for the allreduce to finish
            main_prog.global_block().append_op(
                type="c_wait_comm",
                inputs={'X': toutdata},
                outputs={'Out': toutdata},
                attrs={'ring_id': ring_id})

            # tout = tin + tout - tin = tout
            main_prog.global_block().append_op(
                type="elementwise_add",
                inputs={'X': tindata,
                        'Y': toutdata},
                outputs={'Out': toutdata})
            main_prog.global_block().append_op(
                type="elementwise_sub",
                inputs={'X': toutdata,
                        'Y': tindata},
                outputs={'Out': toutdata})

            return toutdata


if __name__ == "__main__":
    runtime_main(TestCollectiveAllreduce, "allreduce", 0)
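Because the elementwise add/sub pairs cancel out, the only cross-rank effect in get_model is the allreduce_sum, so each rank's output should equal the elementwise sum of all ranks' inputs. A minimal sketch of the numeric check the harness is assumed to perform (all variable names here are illustrative, not from this patch):

import numpy as np

# Illustrative stand-ins for the two ranks' inputs and fetched outputs.
in0 = np.random.rand(10, 1000).astype('float32')
in1 = np.random.rand(10, 1000).astype('float32')
expected = in0 + in1  # allreduce_sum across the two ranks

# Both ranks should fetch (approximately) the same summed tensor.
out0, out1 = expected.copy(), expected.copy()  # placeholders for fetched results
assert np.allclose(out0, expected, rtol=1e-05, atol=1e-05)
assert np.allclose(out1, expected, rtol=1e-05, atol=1e-05)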
@@ -0,0 +1,37 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle

from test_collective_base import TestDistBase

paddle.enable_static()


class TestCWaitOp(TestDistBase):
    def _setup_config(self):
        pass

    def test_allreduce_wait(self):
        self.check_with_place(
            "collective_allreduce_op_wait.py",
            "allreduce",
            check_error_log=True)


if __name__ == '__main__':
    unittest.main()