[api 2.0] add collective op for cpu using gloo and paddle.distributed.* apis (#26552)
parent 07973c577e
commit 1c68138327
paddle/fluid/operators/collective/barrier_op.cc
@@ -0,0 +1,47 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/barrier_op.h"

#include <memory>

namespace paddle {
namespace operators {

class BarrierOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {}
};

class BarrierOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) Input data (only used in CUDAKernel).");
    AddOutput("Out", "(Tensor) Output data (only used in CUDAKernel).");
    AddAttr<int>("ring_id", "(int default 0) communication ring id.")
        .SetDefault(0);
    AddComment(R"DOC(
Barrier Operator - Barrier among all participants.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(barrier, ops::BarrierOp, ops::BarrierOpMaker);
REGISTER_OP_CPU_KERNEL(barrier, ops::BarrierOpCPUKernel<int>);
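A quick orientation sketch (not part of the diff): how the op registered above is typically reached from Python. It mirrors the test runners later in this diff; the Executor plumbing is illustrative and assumes a multi-process CPU launch with the gloo context already initialized, otherwise the CPU kernel's precondition check fires.

import paddle
import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    paddle.distributed.barrier()  # appends a "barrier" op with ring_id=0

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
exe.run(main_prog)  # returns only after every rank has reached the barrier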
paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -0,0 +1,64 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/barrier_op.h"

#include <memory>

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class BarrierOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");

    auto place = ctx.GetPlace();
    ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
    int64_t numel = in->numel();
    const void* sendbuff = in->data<void>();
    void* recvbuff = out->mutable_data<T>(place);

    int rid = ctx.Attr<int>("ring_id");
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
    auto stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
    ncclRedOp_t nccl_red_type = ncclSum;
    // The barrier is implemented as a one-element all-reduce: the collective
    // can only complete once every rank in the ring has reached this point.
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
    auto comm_stream =
        platform::NCCLCommContext::Instance().Get(rid, place)->stream();
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(comm_stream));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should be compiled with NCCL."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(barrier, ops::BarrierOpCUDAKernel<int>);
paddle/fluid/operators/collective/barrier_op.h
@@ -0,0 +1,54 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include <gloo/barrier.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class BarrierOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
    auto gloo = paddle::framework::GlooWrapper::GetInstance();
    PADDLE_ENFORCE_EQ(
        gloo->IsInitialized(), true,
        platform::errors::PreconditionNotMet(
            "You must initialize the gloo environment first to use it."));
    gloo::BarrierOptions opts(gloo->GetContext());
    gloo::barrier(opts);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should be compiled with GLOO by setting WITH_GLOO=ON."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle
paddle/fluid/platform/gloo_context.cc
@@ -0,0 +1,33 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/gloo_context.h"

namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
void GlooParallelContext::Init() {
  // Publish the strategy to the process-wide GlooWrapper singleton, then
  // perform the rendezvous through the configured store.
  auto gloo_ptr = paddle::framework::GlooWrapper::GetInstance();
  gloo_ptr->SetRank(strategy_.rank);
  gloo_ptr->SetSize(strategy_.rank_num);
  gloo_ptr->SetPrefix(strategy_.prefix);
  gloo_ptr->SetIface(strategy_.iface);
  gloo_ptr->SetTimeoutSeconds(strategy_.init_seconds, strategy_.run_seconds);
  gloo_ptr->SetHdfsStore(strategy_.path, strategy_.fs_name, strategy_.fs_ugi);
  gloo_ptr->Init();
}
#endif

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/gloo_context.h
@@ -0,0 +1,51 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <string>

#include "paddle/fluid/framework/fleet/gloo_wrapper.h"

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_GLOO)
struct GlooParallelStrategy {
  int rank{0};         // index of this process in the gloo ring
  int rank_num{1};     // total number of participating processes
  std::string iface;   // network interface used by gloo's transport
  std::string prefix;  // key prefix for the rendezvous store
  int init_seconds{9999999};  // timeout for context initialization
  int run_seconds{9999999};   // timeout for running collectives
  std::string path;     // store path used for rendezvous
  std::string fs_name;  // HDFS filesystem name
  std::string fs_ugi;   // HDFS user credentials (ugi)
};

class GlooParallelContext {
 public:
  explicit GlooParallelContext(const GlooParallelStrategy& strategy)
      : strategy_(strategy) {}

  virtual ~GlooParallelContext() {}

  virtual void Init();

 protected:
  GlooParallelStrategy strategy_;
};
#endif

}  // namespace platform
}  // namespace paddle
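For reference, a hedged Python view of the defaults above, read back through the pybind properties added in the next file; export on paddle.fluid.core is an assumption here.

from paddle.fluid import core  # assumed home of the bindings below

s = core.GlooParallelStrategy()
assert s.rank == 0 and s.rank_num == 1  # single-process defaults
assert s.init_seconds == 9999999        # effectively "wait forever"
assert s.run_seconds == 9999999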
paddle/fluid/pybind/gloo_context_py.cc
@@ -0,0 +1,111 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/gloo_context_py.h"

#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/platform/gloo_context.h"

namespace paddle {
namespace pybind {

namespace py = ::pybind11;

// Bind Methods
void BindGlooContext(py::module *m) {
  // define parallel context for gloo
#if defined(PADDLE_WITH_GLOO)
  py::class_<platform::GlooParallelStrategy> gloo_parallel_strategy(
      *m, "GlooParallelStrategy", "");
  gloo_parallel_strategy.def(py::init())
      .def_property("rank_num",
                    [](const platform::GlooParallelStrategy &self) {
                      return self.rank_num;
                    },
                    [](platform::GlooParallelStrategy &self, int nranks) {
                      self.rank_num = nranks;
                    })
      .def_property(
          "rank",
          [](const platform::GlooParallelStrategy &self) { return self.rank; },
          [](platform::GlooParallelStrategy &self, int rank) {
            self.rank = rank;
          })
      .def_property(
          "iface",
          [](const platform::GlooParallelStrategy &self) { return self.iface; },
          [](platform::GlooParallelStrategy &self, const std::string &iface) {
            self.iface = iface;
          })
      .def_property("prefix",
                    [](const platform::GlooParallelStrategy &self) {
                      return self.prefix;
                    },
                    [](platform::GlooParallelStrategy &self,
                       const std::string &prefix) { self.prefix = prefix; })
      .def_property("init_seconds",
                    [](const platform::GlooParallelStrategy &self) {
                      return self.init_seconds;
                    },
                    [](platform::GlooParallelStrategy &self, int init_seconds) {
                      self.init_seconds = init_seconds;
                    })
      .def_property("run_seconds",
                    [](const platform::GlooParallelStrategy &self) {
                      return self.run_seconds;
                    },
                    [](platform::GlooParallelStrategy &self, int run_seconds) {
                      self.run_seconds = run_seconds;
                    })
      .def_property(
          "path",
          [](const platform::GlooParallelStrategy &self) { return self.path; },
          [](platform::GlooParallelStrategy &self, const std::string &path) {
            self.path = path;
          })
      .def_property("fs_name",
                    [](const platform::GlooParallelStrategy &self) {
                      return self.fs_name;
                    },
                    [](platform::GlooParallelStrategy &self,
                       const std::string &fs_name) { self.fs_name = fs_name; })
      .def_property("fs_ugi",
                    [](const platform::GlooParallelStrategy &self) {
                      return self.fs_ugi;
                    },
                    [](platform::GlooParallelStrategy &self,
                       const std::string &fs_ugi) { self.fs_ugi = fs_ugi; });

  py::class_<platform::GlooParallelContext> gloo_ctx(*m, "GlooParallelContext");
  gloo_ctx.def(py::init<const platform::GlooParallelStrategy &>())
      .def("init", [](platform::GlooParallelContext &self) { self.Init(); });
#endif
}

}  // namespace pybind
}  // namespace paddle
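A hedged sketch of driving these bindings from Python (again assuming BindGlooContext is hooked into paddle.fluid.core; all values are illustrative). Two processes would each run this with their own rank.

from paddle.fluid import core

strategy = core.GlooParallelStrategy()
strategy.rank = 0                  # this process's id
strategy.rank_num = 2              # world size
strategy.iface = "eth0"            # NIC for gloo's transport
strategy.prefix = "demo"
strategy.path = "/tmp/gloo_store"  # rendezvous store path

ctx = core.GlooParallelContext(strategy)
ctx.init()  # expected to block until all rank_num processes have joined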
paddle/fluid/pybind/gloo_context_py.h
@@ -0,0 +1,26 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace paddle {
namespace pybind {

void BindGlooContext(pybind11::module* m);

}  // namespace pybind
}  // namespace paddle
File diff suppressed because it is too large.
python/paddle/fluid/tests/unittests/collective_allgather_api.py
@@ -0,0 +1,53 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main


class TestCollectiveAllgatherAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank):
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            paddle.distributed.all_gather(tensor_list, tindata)
            return tensor_list


if __name__ == "__main__":
    runtime_main(TestCollectiveAllgatherAPI, "allgather")
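To make the expected behavior concrete, a small numpy analogy (illustration only, not Paddle code) of what all_gather leaves in tensor_list on every rank, with shapes shrunk from the [10, 1000] input above:

import numpy as np

world_size = 2
# Rank r contributes an array filled with r; afterwards every rank holds the
# same list with one entry per rank, in rank order.
contributions = [np.full((2, 3), r, dtype=np.float32) for r in range(world_size)]
tensor_list = list(contributions)
assert [float(t[0, 0]) for t in tensor_list] == [0.0, 1.0]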
python/paddle/fluid/tests/unittests/collective_allreduce_api.py
@@ -0,0 +1,52 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main


class TestCollectiveAllreduceAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank):
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            paddle.distributed.all_reduce(tindata)
            return [tindata]


if __name__ == "__main__":
    runtime_main(TestCollectiveAllreduceAPI, "allreduce")
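Again as a numpy analogy (illustration only): with the default sum reduction, every rank ends up holding the elementwise sum of all contributions.

import numpy as np

world_size = 2
contributions = [np.full((2, 3), r, dtype=np.float32) for r in range(world_size)]
result = np.sum(contributions, axis=0)  # what tindata becomes on each rank
assert (result == 1.0).all()            # 0 + 1 = 1 everywhere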
python/paddle/fluid/tests/unittests/collective_barrier_api.py
@@ -0,0 +1,50 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main


class TestCollectiveBarrierAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank):
        with fluid.program_guard(main_prog, startup_program):
            paddle.distributed.barrier()
            return []


if __name__ == "__main__":
    runtime_main(TestCollectiveBarrierAPI, "barrier")
python/paddle/fluid/tests/unittests/collective_broadcast_api.py
@@ -0,0 +1,52 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main


class TestCollectiveBroadcastAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank):
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            paddle.distributed.broadcast(tindata, src=1)
            return [tindata]


if __name__ == "__main__":
    runtime_main(TestCollectiveBroadcastAPI, "broadcast")
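And the matching numpy analogy for broadcast with src=1 (illustration only): rank 1's tensor overwrites every other rank's copy.

import numpy as np

world_size = 2
tensors = [np.full((2, 3), r, dtype=np.float32) for r in range(world_size)]
src = 1
tensors = [tensors[src].copy() for _ in range(world_size)]  # post-broadcast state
assert all(float(t[0, 0]) == 1.0 for t in tensors)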
Some files were not shown because too many files have changed in this diff.