PaddleBox Framework Part2 (#22466)

* Add two types of Metric Calculator: MultiTaskCalculator & CmatchRankCalculator.
* Add a config for the DynamicAdjustChannelNum function to denote whether to discard the remaining instances when they cannot be distributed evenly across channels (see the Python sketch below).
* Remove the CPU code in Pull/PushSparse; we will add it back once it has been fully tested.
* Fix some known issues, such as copying persistable vars after running one epoch.
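
A minimal sketch of the new flag as seen from the Python side; the channel count is illustrative, and `dataset.dataset` is the underlying core dataset handle:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
# BoxPSDataset now forwards discard_remaining_ins=True before training,
# so instances that cannot be spread evenly across channels are dropped.
dataset.dataset.dynamic_adjust_channel_num(8, True)  # 8 channels, drop remainder
dataset.dataset.dynamic_adjust_readers_num(8)
```
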
Authored by hutuxian; committed via GitHub (commit 175954d894, parent 3132681e8a).

@ -19,7 +19,7 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL))
MESSAGE(STATUS "use pre defined download url")
SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE)
SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE)
SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps_stub.tar.gz" CACHE STRING "" FORCE)
SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}")
SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps")

@ -193,7 +193,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
@ -204,7 +204,7 @@ else()
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()

@ -43,6 +43,7 @@ class ChannelObject {
capacity_ = (std::min)(MaxCapacity(), capacity);
}
const std::deque<T>& GetData() const { return data_; }
void Clear() {
std::unique_lock<std::mutex> lock(mutex_);
data_.clear();

@ -390,7 +390,8 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
}
template <typename T>
void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num) {
void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins) {
if (channel_num_ == channel_num) {
VLOG(3) << "DatasetImpl<T>::DynamicAdjustChannelNum channel_num_="
<< channel_num_ << ", channel_num_=channel_num, no need to adjust";
@ -439,13 +440,13 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num) {
total_data_channel->Write(std::move(local_vec));
}
total_data_channel->Close();
total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num +
1);
// will discard the remaining instances,
// TODO(hutuxian): should add a config here to choose how to deal with
// remaining instances
if (static_cast<int>(total_data_channel->Size()) >= channel_num) {
total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
}
if (static_cast<int>(input_channel_->Size()) >= channel_num) {
input_channel_->SetBlockSize(input_channel_->Size() / channel_num);
input_channel_->SetBlockSize(input_channel_->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
}
for (int i = 0; i < channel_num; ++i) {
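
A worked example of the block-size rule above: keeping the remainder rounds the block size up so every instance is distributed, while discarding uses the floor and drops the leftovers.

```python
# Block-size arithmetic from DynamicAdjustChannelNum (illustrative values).
n, channel_num = 10, 3
keep_all = n // channel_num + 1  # 4 -> channels receive 4, 4, 2 instances
discard = n // channel_num       # 3 -> channels receive 3, 3, 3; 1 instance dropped
```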

@ -126,8 +126,9 @@ class Dataset {
virtual void DestroyPreLoadReaders() = 0;
// set preload thread num
virtual void SetPreLoadThreadNum(int thread_num) = 0;
// separate train thread and dataset thread
virtual void DynamicAdjustChannelNum(int channel_num) = 0;
// separate train thread and dataset thread
virtual void DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins = false) = 0;
virtual void DynamicAdjustReadersNum(int thread_num) = 0;
// set fleet send sleep seconds
virtual void SetFleetSendSleepSeconds(int seconds) = 0;
@ -195,7 +196,8 @@ class DatasetImpl : public Dataset {
virtual void CreatePreLoadReaders();
virtual void DestroyPreLoadReaders();
virtual void SetPreLoadThreadNum(int thread_num);
virtual void DynamicAdjustChannelNum(int channel_num);
virtual void DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins = false);
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds);

@ -8,7 +8,7 @@ if(WITH_NCCL)
cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope)
endif()
if(WITH_BOX_PS)
cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor box_ps)
nv_library(box_wrapper SRCS box_wrapper.cc box_wrapper.cu DEPS framework_proto lod_tensor box_ps)
else()
cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor)
endif(WITH_BOX_PS)

@ -91,6 +91,7 @@ void BasicAucCalculator::calculate_bucket_error() {
_bucket_error = error_count > 0 ? error_sum / error_count : 0.0;
}
// Deprecated: should use BeginFeedPass & EndFeedPass
void BoxWrapper::FeedPass(int date,
const std::vector<uint64_t>& feasgin_to_box) const {
int ret = boxps_ptr_->FeedPass(date, feasgin_to_box);
@ -140,47 +141,8 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
reinterpret_cast<boxps::FeatureValueGpu*>(buf->ptr());
if (platform::is_cpu_place(place)) {
// Note: Only GPU is supported in paddlebox now, and the following code has
// not been fully tested yet
LoDTensor total_keys_tensor;
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
int64_t offset = 0;
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
for (size_t i = 0; i < keys.size(); ++i) {
memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
boost::get<platform::CPUPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t));
offset += slot_lengths[i];
}
VLOG(3) << "Begin call PullSparseCPU in BoxPS";
pull_boxps_timer.Start();
// TODO(hutuxian): should use boxps::FeatureValue in the future
int ret = boxps_ptr_->PullSparseCPU(total_keys, total_values_gpu,
static_cast<int>(total_length));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PullSparseCPU failed in BoxPS."));
pull_boxps_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
offset = 0;
for (size_t i = 0; i < values.size(); ++i) {
int64_t fea_num = slot_lengths[i];
VLOG(3) << "Begin Copy slot[" << i << "] fea_num[" << fea_num << "]";
for (auto j = 0; j < fea_num; ++j) {
// Copy the emb from BoxPS to paddle tensor. Since
// 'show','click','emb'
// are continuous in memory, so we copy here using the 'show' address
memory::Copy(
boost::get<platform::CPUPlace>(place), values[i] + j * hidden_size,
boost::get<platform::CPUPlace>(place),
reinterpret_cast<float*>(&((total_values_gpu + offset)->show)),
sizeof(float) * hidden_size);
++offset;
}
}
PADDLE_THROW(platform::errors::Unimplemented(
"Warning: CPUPlace is not supported in PaddleBox now."));
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
@ -253,43 +215,8 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
boxps::FeaturePushValueGpu* total_grad_values_gpu =
reinterpret_cast<boxps::FeaturePushValueGpu*>(buf->ptr());
if (platform::is_cpu_place(place)) {
// Note: only GPU is supported in paddlebox now, and the following code has
// not been fully tested yet
LoDTensor total_keys_tensor;
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
int64_t offset = 0;
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
for (size_t i = 0; i < keys.size(); ++i) {
memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
boost::get<platform::CPUPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t));
offset += slot_lengths[i];
}
offset = 0;
VLOG(3) << "Begin copy grad tensor to BoxPS struct";
for (size_t i = 0; i < grad_values.size(); ++i) {
int64_t fea_num = slot_lengths[i];
for (auto j = 0; j < fea_num; ++j) {
// Copy the emb grad from paddle tensor to BoxPS. Since
// 'show','click','emb' are continuous in memory, here we copy
// using 'show' address
memory::Copy(
boost::get<platform::CPUPlace>(place),
reinterpret_cast<float*>(&((total_grad_values_gpu + offset)->show)),
boost::get<platform::CPUPlace>(place),
grad_values[i] + j * hidden_size, sizeof(float) * hidden_size);
++offset;
}
}
VLOG(3) << "Begin call PushSparseCPU in BoxPS";
push_boxps_timer.Start();
int ret = boxps_ptr_->PushSparseCPU(total_keys, total_grad_values_gpu,
static_cast<int>(total_length));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PushSparseCPU failed in BoxPS."));
push_boxps_timer.Pause();
PADDLE_THROW(platform::errors::Unimplemented(
"Warning: CPUPlace is not supported in PaddleBox now."));
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int device_id = boost::get<platform::CUDAPlace>(place).GetDeviceId();

File diff suppressed because it is too large.

@ -168,6 +168,11 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
SectionWorker::cpu_id_.store(pipeline_config_.start_cpu_core_id());
scope_queues_.resize(section_num_);
pipeline_scopes_.resize(pipeline_num_);
for (auto& var : main_program.Block(0).AllVars()) {
if (var->Persistable()) {
persistable_vars_.push_back(var->Name());
}
}
VLOG(3) << "Init ScopeQueues and create all scopes";
for (int i = 0; i < section_num_; ++i) {
@ -266,7 +271,7 @@ void PipelineTrainer::Finalize() {
for (auto& th : section_threads_) {
th.join();
}
for (const auto& var : *param_need_sync_) {
for (const auto& var : persistable_vars_) {
auto* root_tensor = root_scope_->Var(var)->GetMutable<LoDTensor>();
// TODO(hutuxian): Add a final all-reduce?
const auto& thread_tensor =
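
PipelineTrainer now records every persistable variable of Block 0 during InitTrainerEnv and syncs that full list back to the root scope in Finalize, instead of only `param_need_sync_`. The rough Python equivalent of the selection, as a sketch against a built `main_program`:

```python
# Collect the names of all persistable vars in a Program (mirrors the C++ loop).
persistable_vars = [
    var.name for var in main_program.list_vars() if var.persistable
]
```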

@ -15,6 +15,7 @@ limitations under the License. */
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
@ -146,6 +147,9 @@ void SectionWorker::TrainFiles() {
int64_t accum_num = 0;
int batch_size = 0;
Scope* scope = nullptr;
if (device_reader_ != nullptr) {
device_reader_->Start();
}
while (in_scope_queue_->Receive(&scope)) {
if (device_reader_ != nullptr) {
device_reader_->AssignFeedVar(*scope);
@ -202,6 +206,17 @@ void SectionWorker::TrainFiles() {
// No effect when it is a CPUDeviceContext
dev_ctx_->Wait();
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
continue;
}
metric_msg->add_data(exe_scope);
}
#endif
if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) {
// FIXME: Temporarily we assume two adjacent sections are in different
// places,
@ -273,6 +288,9 @@ void SectionWorker::TrainFilesWithProfiler() {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
if (device_reader_ != nullptr) {
device_reader_->Start();
}
bool started = false;
while (in_scope_queue_->Receive(&scope)) {
@ -330,9 +348,11 @@ void SectionWorker::TrainFilesWithProfiler() {
SEC_LOG << "begin running ops";
cal_timer.Resume();
int op_id = 0;
dev_ctx_->Wait();
for (auto& op : ops_) {
timeline.Start();
op->Run(*exe_scope, place_);
dev_ctx_->Wait();
timeline.Pause();
op_total_time[op_id++] += timeline.ElapsedUS();
}
@ -342,6 +362,17 @@ void SectionWorker::TrainFilesWithProfiler() {
// No effect when it is a CPUDeviceContext
dev_ctx_->Wait();
cal_timer.Pause();
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
continue;
}
metric_msg->add_data(exe_scope);
}
#endif
if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) {
// FIXME: Temporarily we assume two adjacent sections are in different
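
The new block accumulates metric data once per batch: a metric registered for the join pass is skipped during an update pass and vice versa, keyed off the wrapper's pass flag. In pseudo-Python (names mirror the C++ above; this is not a real Paddle API):

```python
# Sketch of the per-batch metric filtering.
for name, metric_msg in metric_list.items():
    if metric_msg.is_join != box.pass_flag:
        continue  # this metric belongs to the other pass type
    metric_msg.add_data(exe_scope)
```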

@ -157,6 +157,7 @@ class PipelineTrainer : public TrainerBase {
// The parameters that should be synchronized between different cards using
// nccl all-reduce
std::shared_ptr<std::vector<std::string>> param_need_sync_;
std::vector<std::string> persistable_vars_;
std::vector<std::unique_ptr<SyncFunctor>> sync_functors_;
std::shared_ptr<platform::NCCLContextMap> nccl_ctx_map_;

@ -29,6 +29,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#ifdef PADDLE_WITH_BOX_PS
#include <boxps_public.h>
#endif
namespace py = pybind11;
@ -40,6 +43,8 @@ void BindBoxHelper(py::module* m) {
.def(py::init([](paddle::framework::Dataset* dataset) {
return std::make_shared<paddle::framework::BoxHelper>(dataset);
}))
.def("set_date", &framework::BoxHelper::SetDate,
py::call_guard<py::gil_scoped_release>())
.def("begin_pass", &framework::BoxHelper::BeginPass,
py::call_guard<py::gil_scoped_release>())
.def("end_pass", &framework::BoxHelper::EndPass,
@ -51,5 +56,35 @@ void BindBoxHelper(py::module* m) {
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>());
} // end BoxHelper
#ifdef PADDLE_WITH_BOX_PS
void BindBoxWrapper(py::module* m) {
py::class_<framework::BoxWrapper, std::shared_ptr<framework::BoxWrapper>>(
*m, "BoxWrapper")
.def(py::init([]() {
// return std::make_shared<paddle::framework::BoxHelper>(dataset);
return framework::BoxWrapper::GetInstance();
}))
.def("save_base", &framework::BoxWrapper::SaveBase,
py::call_guard<py::gil_scoped_release>())
.def("feed_pass", &framework::BoxWrapper::FeedPass,
py::call_guard<py::gil_scoped_release>())
.def("save_delta", &framework::BoxWrapper::SaveDelta,
py::call_guard<py::gil_scoped_release>())
.def("initialize_gpu", &framework::BoxWrapper::InitializeGPU,
py::call_guard<py::gil_scoped_release>())
.def("init_metric", &framework::BoxWrapper::InitMetric,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList,
py::call_guard<py::gil_scoped_release>())
.def("flip_pass_flag", &framework::BoxWrapper::FlipPassFlag,
py::call_guard<py::gil_scoped_release>())
.def("finalize", &framework::BoxWrapper::Finalize,
py::call_guard<py::gil_scoped_release>());
} // end BoxWrapper
#endif
} // end namespace pybind
} // end namespace paddle
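
For reference, a hedged sketch of driving the newly bound singleton from Python; only the method names come from the bindings above, and the zero-argument calls are assumptions about their signatures:

```python
from paddle.fluid import core

box = core.BoxWrapper()             # py::init returns the process-wide singleton
names = box.get_metric_name_list()  # names of the registered metrics
box.flip_pass_flag()                # toggle between join and update passes
box.finalize()                      # release BoxPS resources
```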

@ -23,6 +23,9 @@ namespace paddle {
namespace pybind {
void BindBoxHelper(py::module* m);
#ifdef PADDLE_WITH_BOX_PS
void BindBoxWrapper(py::module* m);
#endif
} // namespace pybind
} // namespace paddle

@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/load_op_lib.h"
@ -1456,6 +1457,9 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
m.def("run_cmd", [](const std::string &cmd) -> const std::string {
return paddle::framework::shell_get_command_output(cmd);
});
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
@ -2245,6 +2249,9 @@ All parameter, weight, gradient are variables in Paddle.
BindFleetWrapper(&m);
BindGlooWrapper(&m);
BindBoxHelper(&m);
#ifdef PADDLE_WITH_BOX_PS
BindBoxWrapper(&m);
#endif
#ifdef PADDLE_WITH_NCCL
BindNCCLWrapper(&m);
#endif
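
The `run_cmd` binding above wraps `shell_get_command_output`, so a shell command's captured stdout comes back as a Python string:

```python
from paddle.fluid import core

out = core.run_cmd("echo hello")  # captured stdout, typically "hello\n"
```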

@ -314,12 +314,12 @@ class InMemoryDataset(DatasetBase):
def _dynamic_adjust_before_train(self, thread_num):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num)
self.dataset.dynamic_adjust_channel_num(thread_num, False)
self.dataset.dynamic_adjust_readers_num(thread_num)
def _dynamic_adjust_after_train(self):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(self.thread_num)
self.dataset.dynamic_adjust_channel_num(self.thread_num, False)
self.dataset.dynamic_adjust_readers_num(self.thread_num)
def set_queue_num(self, queue_num):
@ -793,6 +793,15 @@ class BoxPSDataset(InMemoryDataset):
super(BoxPSDataset, self).__init__()
self.boxps = core.BoxPS(self.dataset)
def set_date(self, date):
"""
Set the pass date for BoxPS. `date` is a string in YYYYMMDD format.
"""
year = int(date[:4])
month = int(date[4:6])
day = int(date[6:])
self.boxps.set_date(year, month, day)
def begin_pass(self):
"""
Begin Pass
@ -865,3 +874,8 @@ class BoxPSDataset(InMemoryDataset):
"""
self._prepare_to_run()
self.boxps.preload_into_memory()
def _dynamic_adjust_before_train(self, thread_num):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, True)
self.dataset.dynamic_adjust_readers_num(thread_num)
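
Putting the BoxPSDataset pieces together, the intended flow with the new `set_date` hook looks roughly like the unit test further below; variable setup (`x`, `y`, `filelist`, `exe`) is omitted:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.set_date("20190930")  # YYYYMMDD, parsed into (year, month, day)
dataset.set_use_var([x, y])
dataset.set_batch_size(2)
dataset.set_thread(1)
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.begin_pass()
exe.train_from_dataset(
    program=fluid.default_main_program(), dataset=dataset, print_period=1)
dataset.end_pass()
```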

@ -57,6 +57,7 @@ endif()
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_pipeline)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185

@ -90,7 +90,6 @@ class TestBoxPSPreload(unittest.TestCase):
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
emb_x, emb_y = _pull_box_sparse([x, y], size=2)
emb_xp = _pull_box_sparse(x, size=2)
layers.Print(emb_xp)
concat = layers.concat([emb_x, emb_y], axis=1)
fc = layers.fc(input=concat,
name="fc",
@ -102,7 +101,6 @@ class TestBoxPSPreload(unittest.TestCase):
place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
batch_size = 2
def binary_print(slot, fout):
@ -125,6 +123,7 @@ class TestBoxPSPreload(unittest.TestCase):
def create_dataset():
dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.set_date("20190930")
dataset.set_use_var([x, y])
dataset.set_batch_size(2)
dataset.set_thread(1)
@ -134,6 +133,14 @@ class TestBoxPSPreload(unittest.TestCase):
datasets = []
datasets.append(create_dataset())
datasets.append(create_dataset())
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer,
cut_list=[],
place_list=[place],
concurrency_list=[1],
queue_size=1,
sync_steps=-1)
optimizer.minimize(loss)
exe.run(fluid.default_startup_program())
datasets[0].load_into_memory()
@ -149,7 +156,8 @@ class TestBoxPSPreload(unittest.TestCase):
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=datasets[1],
print_period=1)
print_period=1,
debug=True)
datasets[1].end_pass()
for f in filelist:
os.remove(f)

@ -379,7 +379,7 @@ class SingleProcessMultiThread(GradAllReduce):
'''
def __init__(self):
GradAllReduce.__init__(self, -1)
GradAllReduce.__init__(self, 1)
self.mode = "single_process_multi_thread"
def _transpile_startup_program(self):
