diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index f185bf505f..541b3fede9 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -42,8 +42,8 @@ DatasetImpl::DatasetImpl() { channel_num_ = 1; file_idx_ = 0; cur_channel_ = 0; - fleet_send_batch_size_ = 80000; - fleet_send_sleep_seconds_ = 2; + fleet_send_batch_size_ = 1024; + fleet_send_sleep_seconds_ = 0; merge_by_insid_ = false; erase_duplicate_feas_ = true; keep_unmerged_ins_ = true; @@ -51,6 +51,7 @@ DatasetImpl::DatasetImpl() { parse_ins_id_ = false; parse_content_ = false; preload_thread_num_ = 0; + global_index_ = 0; } // set filelist, file_idx_ will reset to zero. @@ -291,7 +292,7 @@ void DatasetImpl::LocalShuffle() { } template -void DatasetImpl::GlobalShuffle() { +void DatasetImpl::GlobalShuffle(int thread_num) { VLOG(3) << "DatasetImpl::GlobalShuffle() begin"; platform::Timer timeline; timeline.Start(); @@ -358,13 +359,21 @@ void DatasetImpl::GlobalShuffle() { ars.shrink_to_fit(); data.clear(); data.shrink_to_fit(); - sleep(this->fleet_send_sleep_seconds_); + // currently we find bottleneck is server not able to handle large data + // in time, so we can remove this sleep and set fleet_send_batch_size to + // 1024, and set server thread to 24. + if (fleet_send_sleep_seconds_ != 0) { + sleep(this->fleet_send_sleep_seconds_); + } } }; - VLOG(3) << "start global shuffle threads"; std::vector global_shuffle_threads; - for (int i = 0; i < thread_num_; ++i) { + if (thread_num == -1) { + thread_num = thread_num_; + } + VLOG(3) << "start global shuffle threads, num = " << thread_num; + for (int i = 0; i < thread_num; ++i) { global_shuffle_threads.push_back(std::thread(global_shuffle_func)); } for (std::thread& t : global_shuffle_threads) { @@ -378,6 +387,101 @@ void DatasetImpl::GlobalShuffle() { << timeline.ElapsedSec() << " seconds"; } +template +void DatasetImpl::DynamicAdjustChannelNum(int channel_num) { + if (channel_num_ == channel_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustChannelNum channel_num_=" + << channel_num_ << ", channel_num_=channel_num, no need to adjust"; + return; + } + VLOG(3) << "adjust channel num from " << channel_num_ << " to " + << channel_num; + channel_num_ = channel_num; + std::vector>* origin_channels = nullptr; + std::vector>* other_channels = nullptr; + // find out which channel (output or consume) has data + int cur_channel = 0; + uint64_t output_channels_data_size = 0; + uint64_t consume_channels_data_size = 0; + CHECK(multi_output_channel_.size() == multi_consume_channel_.size()); + for (int i = 0; i < multi_output_channel_.size(); ++i) { + output_channels_data_size += multi_output_channel_[i]->Size(); + consume_channels_data_size += multi_consume_channel_[i]->Size(); + } + if (output_channels_data_size != 0) { + CHECK(consume_channels_data_size == 0); // NOLINT + cur_channel = 0; + } else { + CHECK(output_channels_data_size == 0); // NOLINT + cur_channel = 1; + } + if (cur_channel == 0) { + origin_channels = &multi_output_channel_; + other_channels = &multi_consume_channel_; + } else { + origin_channels = &multi_consume_channel_; + other_channels = &multi_output_channel_; + } + CHECK(origin_channels != nullptr); // NOLINT + CHECK(other_channels != nullptr); // NOLINT + + paddle::framework::Channel total_data_channel = + paddle::framework::MakeChannel(); + std::vector> new_channels; + std::vector> new_other_channels; + std::vector local_vec; + for (int i = 0; i < origin_channels->size(); ++i) { + local_vec.clear(); + 
(*origin_channels)[i]->Close(); + (*origin_channels)[i]->ReadAll(local_vec); + total_data_channel->Write(std::move(local_vec)); + } + total_data_channel->Close(); + total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num + + 1); + + for (int i = 0; i < channel_num; ++i) { + local_vec.clear(); + total_data_channel->Read(local_vec); + new_other_channels.push_back(paddle::framework::MakeChannel()); + new_channels.push_back(paddle::framework::MakeChannel()); + new_channels[i]->Write(std::move(local_vec)); + } + + total_data_channel->Clear(); + origin_channels->clear(); + other_channels->clear(); + *origin_channels = new_channels; + *other_channels = new_other_channels; + + new_channels.clear(); + new_other_channels.clear(); + std::vector>().swap(new_channels); + std::vector>().swap(new_other_channels); + local_vec.clear(); + std::vector().swap(local_vec); + VLOG(3) << "adjust channel num done"; +} + +template +void DatasetImpl::DynamicAdjustReadersNum(int thread_num) { + if (thread_num_ == thread_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" + << thread_num_ << ", thread_num_=thread_num, no need to adjust"; + return; + } + VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num; + thread_num_ = thread_num; + std::vector>().swap(readers_); + CreateReaders(); + VLOG(3) << "adjust readers num done"; +} + +template +void DatasetImpl::SetFleetSendSleepSeconds(int seconds) { + fleet_send_sleep_seconds_ = seconds; +} + template void DatasetImpl::CreateReaders() { VLOG(3) << "Calling CreateReaders()"; @@ -509,7 +613,16 @@ int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, CHECK(ar.Cursor() == ar.Finish()); auto fleet_ptr = FleetWrapper::GetInstance(); - int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_; + // We do not use a random index here because it does not perform well. + // To make sure each channel gets data evenly, we put the data into the + // channels one by one (round-robin).
+ // int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_; + int64_t index = 0; + { + std::unique_lock lk(global_index_mutex_); + index = global_index_++; + } + index = index % channel_num_; VLOG(3) << "ramdom index=" << index; multi_output_channel_[index]->Write(std::move(data)); diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 176a53f8f3..bcf344d23a 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -99,7 +99,7 @@ class Dataset { // local shuffle data virtual void LocalShuffle() = 0; // global shuffle data - virtual void GlobalShuffle() = 0; + virtual void GlobalShuffle(int thread_num = -1) = 0; // for slots shuffle virtual void SlotsShuffle(const std::set& slots_to_replace) = 0; virtual void GetRandomData(const std::set& slots_to_replace, @@ -120,6 +120,11 @@ class Dataset { virtual void DestroyPreLoadReaders() = 0; // set preload thread num virtual void SetPreLoadThreadNum(int thread_num) = 0; + // seperate train thread and dataset thread + virtual void DynamicAdjustChannelNum(int channel_num) = 0; + virtual void DynamicAdjustReadersNum(int thread_num) = 0; + // set fleet send sleep seconds + virtual void SetFleetSendSleepSeconds(int seconds) = 0; protected: virtual int ReceiveFromClient(int msg_type, int client_id, @@ -169,7 +174,7 @@ class DatasetImpl : public Dataset { virtual void WaitPreLoadDone(); virtual void ReleaseMemory(); virtual void LocalShuffle(); - virtual void GlobalShuffle(); + virtual void GlobalShuffle(int thread_num = -1); virtual void SlotsShuffle(const std::set& slots_to_replace) {} virtual void GetRandomData(const std::set& slots_to_replace, std::vector* result) {} @@ -181,6 +186,9 @@ class DatasetImpl : public Dataset { virtual void CreatePreLoadReaders(); virtual void DestroyPreLoadReaders(); virtual void SetPreLoadThreadNum(int thread_num); + virtual void DynamicAdjustChannelNum(int channel_num); + virtual void DynamicAdjustReadersNum(int thread_num); + virtual void SetFleetSendSleepSeconds(int seconds); protected: virtual int ReceiveFromClient(int msg_type, int client_id, @@ -217,6 +225,8 @@ class DatasetImpl : public Dataset { std::vector merge_slots_list_; bool slots_shuffle_fea_eval_ = false; int preload_thread_num_; + std::mutex global_index_mutex_; + int64_t global_index_ = 0; }; // use std::vector or Record as data type diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index c4f13975b7..56a9ebc36e 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -148,11 +148,18 @@ void DistMultiTrainer::Finalize() { if (root_tensor->numel() != thread_tensor->numel()) { continue; } -#define MergeCallback(cpp_type, proto_type) \ - do { \ - if (root_tensor->type() == proto_type) { \ - MergeToRootScope(root_tensor, thread_tensor); \ - } \ +#define MergeCallback(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + if (thread_tensor->type() != proto_type) { \ + VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \ + << "] " << need_merge_var_names_[i] \ + << ", root tensor type=" << root_tensor->type() \ + << ", thread tensor type=" << thread_tensor->type(); \ + exit(-1); \ + } \ + MergeToRootScope(root_tensor, thread_tensor); \ + } \ } while (0) _ForEachDataType_(MergeCallback); } @@ -163,6 +170,10 @@ void DistMultiTrainer::Finalize() { } pull_dense_worker_->Stop(); root_scope_->DropKids(); + + // flush local client push 
queue + auto fleet_ptr_ = FleetWrapper::GetInstance(); + fleet_ptr_->ClientFlush(); } template diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 0e1102f341..22a9b79d7f 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -66,6 +66,14 @@ paddle::ps::Archive& operator>>(paddle::ps::Archive& ar, std::shared_ptr FleetWrapper::pslib_ptr_ = NULL; #endif +void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, + int connect_timeout_ms, + int max_retry) { + client2client_request_timeout_ms_ = request_timeout_ms; + client2client_connect_timeout_ms_ = connect_timeout_ms; + client2client_max_retry_ = max_retry; +} + void FleetWrapper::InitServer(const std::string& dist_desc, int index) { #ifdef PADDLE_WITH_PSLIB if (!is_initialized_) { @@ -142,7 +150,9 @@ std::vector FleetWrapper::GetClientsInfo() { void FleetWrapper::CreateClient2ClientConnection() { #ifdef PADDLE_WITH_PSLIB VLOG(3) << "Going to create client2client connection"; - pslib_ptr_->create_client2client_connection(); + pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_, + client2client_connect_timeout_ms_, + client2client_max_retry_); #endif } @@ -344,7 +354,9 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( slot = boost::lexical_cast(sparse_key_names[i]); } Variable* g_var = scope.FindVar(sparse_grad_names[i]); - CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found"; + if (g_var == nullptr) { + continue; + } LoDTensor* g_tensor = g_var->GetMutable(); if (g_tensor == nullptr) { LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null"; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 4fdfd6fc66..4aa626340d 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -59,7 +59,17 @@ class FleetWrapper { scale_sparse_gradient_with_batch_size_ = true; // trainer sleep some time for pslib core dump sleep_seconds_before_fail_exit_ = 300; + // pslib request server timeout ms + client2client_request_timeout_ms_ = 500000; + // pslib connect server timeout_ms + client2client_connect_timeout_ms_ = 10000; + // pslib request max retry + client2client_max_retry_ = 3; } + + void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, + int max_retry); + // Pull sparse variables from server in Sync mode // Param: scope, table_id, var_names, fea_keys // Param: fea_values @@ -200,6 +210,9 @@ class FleetWrapper { static bool is_initialized_; bool scale_sparse_gradient_with_batch_size_; int32_t sleep_seconds_before_fail_exit_; + int client2client_request_timeout_ms_; + int client2client_connect_timeout_ms_; + int client2client_max_retry_; DISABLE_COPY_AND_ASSIGN(FleetWrapper); }; diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index a5c76db6fa..5dc83ac7b3 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -124,6 +124,9 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker { "'epsilon' should be between 0.0 and 0.001."); }); AddAttr("data_layout", "").SetDefault("NCHW"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddInput("X", "The input tensor"); AddInput("BatchSize", "BatchSize is a 1-dimensional tensor of size C " @@ -224,7 +227,6 @@ class DataNormGradOp : public 
framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Scales"), ""); // check output - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSquareSum")), @@ -237,7 +239,9 @@ class DataNormGradOp : public framework::OperatorWithKernel { (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C}); ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C}); ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C}); @@ -304,7 +308,10 @@ class DataNormGradKernel : x_dims[x_dims.size() - 1]); // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); + Tensor *d_x = nullptr; + if (ctx.HasOutput(framework::GradVarName("X"))) { + d_x = ctx.Output(framework::GradVarName("X")); + } auto *d_batch_size = ctx.Output(framework::GradVarName("BatchSize")); auto *d_batch_sum = ctx.Output(framework::GradVarName("BatchSum")); @@ -331,10 +338,12 @@ class DataNormGradKernel ConstEigenVectorArrayMap means_arr(means->data(), C); ConstEigenArrayMap x_arr(x->data(), C, N); ConstEigenArrayMap d_y_arr(d_y->data(), C, N); - EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, N); - d_x_arr.setZero(); - for (int nc = 0; nc < N; ++nc) { - d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr; + if (d_x != nullptr) { + EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, N); + d_x_arr.setZero(); + for (int nc = 0; nc < N; ++nc) { + d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr; + } } // calculate data sum and squre sum diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 0125465e6e..dd513d4b85 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -257,6 +257,15 @@ void BindDataset(py::module *m) { py::call_guard()) .def("destroy_preload_readers", &framework::Dataset::DestroyPreLoadReaders, + py::call_guard()) + .def("dynamic_adjust_channel_num", + &framework::Dataset::DynamicAdjustChannelNum, + py::call_guard()) + .def("dynamic_adjust_readers_num", + &framework::Dataset::DynamicAdjustReadersNum, + py::call_guard()) + .def("set_fleet_send_sleep_seconds", + &framework::Dataset::SetFleetSendSleepSeconds, py::call_guard()); py::class_(*m, "IterableDatasetWrapper") diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 90772b3546..e7c7750c27 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -65,7 +65,9 @@ void BindFleetWrapper(py::module* m) { .def("client_flush", &framework::FleetWrapper::ClientFlush) .def("load_from_paddle_model", &framework::FleetWrapper::LoadFromPaddleModel) - .def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable); + .def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable) + .def("set_client2client_config", + &framework::FleetWrapper::SetClient2ClientConfig); } // end FleetWrapper } // end namespace pybind } // end namespace paddle diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 486fe0995f..1ae2d056e8 100644 --- a/python/paddle/fluid/dataset.py +++ 
b/python/paddle/fluid/dataset.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""This is the definition of the dataset classes, which provide high-performance IO.""" from paddle.fluid.proto import data_feed_pb2 from google.protobuf import text_format @@ -70,7 +71,7 @@ class DatasetBase(object): self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc.pipe_command = "cat" self.dataset = core.Dataset("MultiSlotDataset") - self.thread_num = 0 + self.thread_num = 1 self.filelist = [] def set_pipe_command(self, pipe_command): @@ -265,6 +266,12 @@ class DatasetBase(object): """ return text_format.MessageToString(self.proto_desc) + def _dynamic_adjust_before_train(self, thread_num): + pass + + def _dynamic_adjust_after_train(self): + pass + class InMemoryDataset(DatasetBase): """ @@ -281,19 +288,19 @@ class InMemoryDataset(DatasetBase): super(InMemoryDataset, self).__init__() self.proto_desc.name = "MultiSlotInMemoryDataFeed" self.fleet_send_batch_size = None + self.is_user_set_queue_num = False self.queue_num = None self.parse_ins_id = False self.parse_content = False self.merge_by_lineid = False + self.fleet_send_sleep_seconds = None def _prepare_to_run(self): """ Set data_feed_desc before load or shuffle, user no need to call this function. """ - if self.thread_num > len(self.filelist): - self.thread_num = len(self.filelist) - if self.thread_num == 0: + if self.thread_num <= 0: self.thread_num = 1 self.dataset.set_thread_num(self.thread_num) if self.queue_num is None: @@ -305,6 +312,16 @@ class InMemoryDataset(DatasetBase): self.dataset.create_channel() self.dataset.create_readers() + def _dynamic_adjust_before_train(self, thread_num): + if not self.is_user_set_queue_num: + self.dataset.dynamic_adjust_channel_num(thread_num) + self.dataset.dynamic_adjust_readers_num(thread_num) + + def _dynamic_adjust_after_train(self): + if not self.is_user_set_queue_num: + self.dataset.dynamic_adjust_channel_num(self.thread_num) + self.dataset.dynamic_adjust_readers_num(self.thread_num) + def set_queue_num(self, queue_num): """ Set Dataset output queue num, training threads get data from queues @@ -320,6 +337,7 @@ class InMemoryDataset(DatasetBase): dataset.set_queue_num(12) """ + self.is_user_set_queue_num = True self.queue_num = queue_num def set_parse_ins_id(self, parse_ins_id): @@ -356,9 +374,9 @@ """ self.parse_content = parse_content - def set_fleet_send_batch_size(self, fleet_send_batch_size): + def set_fleet_send_batch_size(self, fleet_send_batch_size=1024): """ - Set fleet send batch size, default is 80000 + Set fleet send batch size, default is 1024 Args: fleet_send_batch_size(int): fleet send batch size @@ -373,6 +391,23 @@ """ self.fleet_send_batch_size = fleet_send_batch_size + def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0): + """ + Set fleet send sleep seconds, default is 0 + + Args: + fleet_send_sleep_seconds(int): fleet send sleep seconds + + Examples: + ..
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_fleet_send_sleep_seconds(2) + + """ + self.fleet_send_sleep_seconds = fleet_send_sleep_seconds + def set_merge_by_lineid(self, var_list, erase_duplicate_feas=True, @@ -480,7 +515,7 @@ class InMemoryDataset(DatasetBase): """ self.dataset.local_shuffle() - def global_shuffle(self, fleet=None): + def global_shuffle(self, fleet=None, thread_num=12): """ Global shuffle. Global shuffle can be used only in distributed mode. i.e. multiple @@ -500,6 +535,7 @@ class InMemoryDataset(DatasetBase): Args: fleet(Fleet): fleet singleton. Default None. + thread_num(int): shuffle thread num. Default is 12. """ trainer_num = 1 @@ -507,13 +543,16 @@ class InMemoryDataset(DatasetBase): fleet._role_maker._barrier_worker() trainer_num = fleet.worker_num() if self.fleet_send_batch_size is None: - self.fleet_send_batch_size = 800 * trainer_num + self.fleet_send_batch_size = 1024 + if self.fleet_send_sleep_seconds is None: + self.fleet_send_sleep_seconds = 0 self.dataset.register_client2client_msg_handler() self.dataset.set_trainer_num(trainer_num) self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size) + self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds) if fleet is not None: fleet._role_maker._barrier_worker() - self.dataset.global_shuffle() + self.dataset.global_shuffle(thread_num) if fleet is not None: fleet._role_maker._barrier_worker() if self.merge_by_lineid: @@ -666,6 +705,9 @@ class QueueDataset(DatasetBase): dataset = fluid.DatasetFactory().create_dataset("QueueDataset") dataset.local_shuffle() + Raises: + NotImplementedError: QueueDataset does not support local shuffle + """ raise NotImplementedError( "QueueDataset does not support local shuffle, " @@ -689,6 +731,9 @@ class QueueDataset(DatasetBase): dataset = fluid.DatasetFactory().create_dataset("QueueDataset") dataset.global_shuffle(fleet) + Raises: + NotImplementedError: QueueDataset does not support global shuffle + """ raise NotImplementedError( "QueueDataset does not support global shuffle, " @@ -708,14 +753,16 @@ class FileInstantDataset(DatasetBase): def __init__(self): """ - Init + Initialize FileInstantDataset + This class should be created by DatasetFactory """ super(FileInstantDataset, self).__init__() self.proto_desc.name = "MultiSlotFileInstantDataFeed" def local_shuffle(self): """ - Local shuffle, FileInstantDataset does not support local shuffle + Local shuffle + FileInstantDataset does not support local shuffle """ raise NotImplementedError( "FileInstantDataset does not support local shuffle, " @@ -724,6 +771,7 @@ class FileInstantDataset(DatasetBase): def global_shuffle(self, fleet=None): """ Global shuffle + FileInstantDataset does not support global shuffle """ raise NotImplementedError( "FileInstantDataset does not support global shuffle, " @@ -743,26 +791,30 @@ class BoxPSDataset(InMemoryDataset): def __init__(self): """ - Init + Initialize BoxPSDataset + This class should be created by DatasetFactory """ super(BoxPSDataset, self).__init__() self.boxps = core.BoxPS(self.dataset) def begin_pass(self): """ - Notify BoxPS to begin next pass + Begin Pass + Notify BoxPS to begin next pass """ self.boxps.begin_pass() def end_pass(self): """ - Notify BoxPS to end current pass + End Pass + Notify BoxPS to end current pass """ self.boxps.end_pass() def wait_preload_done(self): """ - Wait async proload done + Wait async proload done + Wait Until Feed Pass Done """ 
self.boxps.wait_feed_pass_done() diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 9060455f4d..ed0479be84 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -803,7 +803,6 @@ class Executor(object): program.program._fleet_opt) trainer._set_program(program.program) - # The following thread_num-determined logic will be deprecated if thread <= 0: if dataset.thread_num <= 0: raise RuntimeError( @@ -889,9 +888,11 @@ class Executor(object): trainer._set_infer(True) trainer._gen_trainer_desc() self._dump_debug_info(program=program, trainer=trainer) + dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num) self._default_executor.run_from_dataset(program.desc, scope, dataset.dataset, trainer._desc()) + dataset._dynamic_adjust_after_train() dataset._finish_to_run() return None @@ -973,8 +974,10 @@ class Executor(object): print_period=print_period) trainer._gen_trainer_desc() self._dump_debug_info(program=program, trainer=trainer) + dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num) self._default_executor.run_from_dataset(program.desc, scope, dataset.dataset, trainer._desc()) + dataset._dynamic_adjust_after_train() dataset._finish_to_run() return None diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 1e84365ada..e9668805e4 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -32,11 +32,20 @@ class PSLib(Fleet): self._fleet_ptr = None self._main_programs = [] self._scopes = [] + self._client2client_request_timeout_ms = 500000 + self._client2client_connect_timeout_ms = 10000 + self._client2client_max_retry = 3 def init(self, role_maker=None): super(PSLib, self).init(MPISymetricRoleMaker()) self._fleet_ptr = fluid.core.Fleet() + def _set_client_communication_config(self, request_timeout_ms, + connect_timeout_ms, max_retry): + self._client2client_request_timeout_ms = request_timeout_ms + self._client2client_connect_timeout_ms = connect_timeout_ms + self._client2client_max_retry = max_retry + def init_worker(self): """ init_worker(): will be called by user. 
When a user knows current process is_server(), he/she @@ -72,6 +81,10 @@ class PSLib(Fleet): info = self._fleet_ptr.get_clients_info() all_info = self._role_maker._worker_gather(info[0]) self._fleet_ptr.gather_clients(all_info) + self._fleet_ptr.set_client2client_config( + self._client2client_request_timeout_ms, + self._client2client_connect_timeout_ms, + self._client2client_max_retry) self._fleet_ptr.create_client2client_connection() # barrier for init model self._role_maker._barrier_worker() diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py index e3dc12f7cb..6813b76789 100644 --- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py +++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py @@ -311,14 +311,23 @@ class FleetUtil(object): xbox_base_key, data_path, hadoop_fs_name, - monitor_data={}): + monitor_data={}, + mode="patch"): xbox_dict = collections.OrderedDict() - xbox_dict["id"] = str(int(time.time())) + if mode == "base": + xbox_dict["id"] = str(xbox_base_key) + elif mode == "patch": + xbox_dict["id"] = str(int(time.time())) + else: + print("warning: unknown mode %s, set it to patch" % mode) + mode = "patch" + xbox_dict["id"] = str(int(time.time())) xbox_dict["key"] = str(xbox_base_key) if model_path.startswith("hdfs:") or model_path.startswith("afs:"): model_path = model_path[model_path.find(":") + 1:] xbox_dict["input"] = hadoop_fs_name + model_path.rstrip("/") + "/000" xbox_dict["record_count"] = "111111" + xbox_dict["partition_type"] = "2" xbox_dict["job_name"] = "default_job_name" xbox_dict["ins_tag"] = "feasign" xbox_dict["ins_path"] = data_path @@ -477,13 +486,16 @@ class FleetUtil(object): day = str(day) pass_id = str(pass_id) xbox_base_key = int(xbox_base_key) + mode = None if pass_id != "-1": + mode = "patch" suffix_name = "/%s/delta-%s/" % (day, pass_id) model_path = output_path.rstrip("/") + suffix_name if donefile_name is None: donefile_name = "xbox_patch_done.txt" else: + mode = "base" suffix_name = "/%s/base/" % day model_path = output_path.rstrip("/") + suffix_name if donefile_name is None: @@ -495,7 +507,8 @@ class FleetUtil(object): if fleet.worker_index() == 0: donefile_path = output_path + "/" + donefile_name xbox_str = self._get_xbox_str(output_path, day, model_path, \ - xbox_base_key, data_path, hadoop_fs_name, monitor_data={}) + xbox_base_key, data_path, hadoop_fs_name, monitor_data={}, + mode=mode) configs = { "fs.default.name": hadoop_fs_name, "hadoop.job.ugi": hadoop_fs_ugi diff --git a/python/paddle/fluid/incubate/fleet/utils/hdfs.py b/python/paddle/fluid/incubate/fleet/utils/hdfs.py index 6b9efc4856..1d1714bf72 100644 --- a/python/paddle/fluid/incubate/fleet/utils/hdfs.py +++ b/python/paddle/fluid/incubate/fleet/utils/hdfs.py @@ -24,10 +24,20 @@ import copy import errno import logging -from paddle.fluid.log_helper import get_logger __all__ = ["HDFSClient"] + +def get_logger(name, level, fmt): + logger = logging.getLogger(name) + logger.setLevel(level) + handler = logging.FileHandler('hdfs.log', mode='w') + formatter = logging.Formatter(fmt=fmt) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') @@ -461,7 +471,7 @@ class HDFSClient(object): procs = [] for i in range(multi_processes): - process_datas = HDFSClient.split_flies(all_files, i, + process_datas = HDFSClient.split_files(all_files, i, multi_processes) p = multiprocessing.Process( 
target=__subprocess_download, @@ -551,7 +561,7 @@ class HDFSClient(object): procs = [] for i in range(multi_processes): - process_datas = HDFSClient.split_flies(all_files, i, + process_datas = HDFSClient.split_files(all_files, i, multi_processes) p = multiprocessing.Process( target=__subprocess_upload, args=( diff --git a/python/paddle/fluid/tests/unittests/test_data_norm_op.py b/python/paddle/fluid/tests/unittests/test_data_norm_op.py new file mode 100644 index 0000000000..0273664d5d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_data_norm_op.py @@ -0,0 +1,203 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This is unit test of Test data_norm Op.""" + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from op_test import OpTest +from paddle.fluid.framework import grad_var_name + + +def _reference_testing(x, batch_size, batch_sum, batch_square_sum): + x_shape = x.shape + means_arr = batch_sum / batch_size + scales_arr = np.sqrt(batch_size / batch_square_sum) + for i in range(x_shape[0]): + x[i] -= means_arr + x[i] *= scales_arr + y = np.array(x) + return y + + +def create_or_get_tensor(scope, var_name, var, place): + tensor = scope.var(var_name).get_tensor() + if var is not None: + assert isinstance(var, np.ndarray) + tensor.set_recursive_sequence_lengths([]) + tensor.set(var, place) + return tensor + + +class TestDataNormOpInference(unittest.TestCase): + """ + test class for data norm op + test forward + """ + + def setUp(self): + """ + init members of this class + """ + self.dtype = np.float32 + self.use_mkldnn = False + + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def check_with_place(self, place, data_layout, dtype, shape): + """ + do forward and check + + Args: + place(Place): CPUPlace + data_layout(str): NCHW or NWHC + dtype(dtype): np.float32 + shape(list): input shape + + """ + epsilon = 0.00001 + if len(shape) == 2: + x_shape = shape + c = x_shape[1] + else: + ValueError("len(shape) should be equal to 2") + scale_shape = [c] + + x_val = np.random.random_sample(x_shape).astype(dtype) + x_val = x_val - 0.5 + batch_size = np.ones(scale_shape).astype(np.float32) + batch_size *= 1e4 + batch_sum = np.zeros(scale_shape).astype(np.float32) + batch_square_sum = np.ones(scale_shape).astype(np.float32) + batch_square_sum *= 1e4 + + y_out = _reference_testing(x_val, batch_size, batch_sum, + batch_square_sum).astype(dtype) + + scope = core.Scope() + + # create input + x_tensor = create_or_get_tensor(scope, "x_val", + OpTest.np_dtype_to_fluid_dtype(x_val), + place) + batch_size_tensor = create_or_get_tensor( + scope, "batch_size", + OpTest.np_dtype_to_fluid_dtype(batch_size), place) + batch_sum_tensor = create_or_get_tensor( + scope, "batch_sum", + OpTest.np_dtype_to_fluid_dtype(batch_sum), 
place) + batch_square_sum_tensor = create_or_get_tensor( + scope, "batch_square_sum", + OpTest.np_dtype_to_fluid_dtype(batch_square_sum), place) + + # create output + y_tensor = create_or_get_tensor(scope, "y_out", None, place) + mean_tensor = create_or_get_tensor(scope, "mean", None, place) + scales_tensor = create_or_get_tensor(scope, "scales", None, place) + + data_norm_op = Operator( + "data_norm", + # inputs + X="x_val", + BatchSize="batch_size", + BatchSum="batch_sum", + BatchSquareSum="batch_square_sum", + # outputs + Y="y_out", + Means="mean", + Scales="scales", + # attrs + epsilon=epsilon, + use_mkldnn=self.use_mkldnn) + + data_norm_op.run(scope, place) + + # check inference result + self.__assert_close( + y_tensor, + y_out, + "inference output are different at " + str(place) + ", " + + data_layout + ", " + str(np.dtype(dtype)) + + str(np.array(y_tensor)) + str(y_out), + atol=1e-3) + + def test_check_output(self): + """ + test check forward, check output + """ + places = [core.CPUPlace()] + for place in places: + for data_format in ["NCHW", "NHWC"]: + self.check_with_place(place, data_format, self.dtype, [2, 3]) + + +class TestDataNormOp(OpTest): + """ + test class for data norm op + test forward and backward + """ + + def setUp(self): + """ + init data norm op test env + """ + self.op_type = 'data_norm' + self.use_mkldnn = False + epsilon = 0.00001 + x_shape = [2, 3] + scale_shape = [3] + tp = np.float32 + + x_val = np.array([[-0.35702616, -0.42756206, -0.08306625], + [0.41199666, -0.21719968, -0.10180971]]).astype(tp) + batch_size = np.ones(scale_shape).astype(tp) + batch_size *= 1e4 + batch_sum = np.zeros(scale_shape).astype(tp) + batch_square_sum = np.ones(scale_shape).astype(tp) + batch_square_sum *= 1e4 + + y = np.array(x_val) + + mean = np.array([[0, 0, 0], [0, 0, 0]]).astype(tp) + scale = np.array([[1, 1, 1], [1, 1, 1]]).astype(tp) + + self.inputs = { + "X": x_val, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum + } + self.outputs = {"Y": y, "Means": mean, "Scales": scale} + self.attrs = {"epsilon": epsilon, "use_mkldnn": self.use_mkldnn} + + def test_check_output(self): + """ + test check forward, check output + """ + self.check_output() + + def test_check_grad(self): + """ + test check backward, check grad + """ + self.check_grad(['X'], 'Y', no_grad_set=set([])) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 27557897ba..8bfa88dc2c 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -237,6 +237,25 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda( ) else fluid.CUDAPlace(0)) exe.run(fluid.default_startup_program()) + + for i in range(2): + try: + exe.train_from_dataset(fluid.default_main_program(), dataset) + exe.train_from_dataset( + fluid.default_main_program(), dataset, thread=1) + exe.train_from_dataset( + fluid.default_main_program(), dataset, thread=2) + exe.train_from_dataset( + fluid.default_main_program(), dataset, thread=2) + exe.train_from_dataset( + fluid.default_main_program(), dataset, thread=3) + exe.train_from_dataset( + fluid.default_main_program(), dataset, thread=4) + except ImportError as e: + pass + except Exception as e: + self.assertTrue(False) + if self.use_data_loader: data_loader = fluid.io.DataLoader.from_dataset(dataset, fluid.cpu_places(), @@ 
-253,12 +272,14 @@ class TestDataset(unittest.TestCase): self.assertTrue(False) dataset.set_merge_by_lineid(slots_vars) + dataset.set_fleet_send_sleep_seconds(2) dataset.preload_into_memory() dataset.wait_preload_done() dataset.release_memory() dataset.preload_into_memory(1) dataset.wait_preload_done() fleet_ptr = fluid.core.Fleet() + fleet_ptr.set_client2client_config(1, 1, 1) os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_b.txt") @@ -311,6 +332,19 @@ class TestDataset(unittest.TestCase): except Exception as e: self.assertTrue(False) + dataset2 = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset2.set_use_var(slots_vars) + dataset2.set_batch_size(32) + dataset2.set_thread(3) + dataset2.set_pipe_command("cat") + dataset.set_filelist([]) + try: + exe.train_from_dataset(fluid.default_main_program(), dataset2) + except ImportError as e: + print("warning: we skip trainer_desc_pb2 import problem in windows") + except Exception as e: + self.assertTrue(False) + os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_b.txt")
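Usage sketch (not part of the patch): a minimal Python example of how the new dataset knobs introduced above fit together. It assumes an already-initialized pslib fleet worker; the import path, file names, and slot variables below are placeholders.

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet  # assumed import path

# Placeholder slot variables; a real job defines its own inputs.
show = fluid.layers.data(name="show", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64", lod_level=1)

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(12)
dataset.set_filelist(["train_part_0.txt", "train_part_1.txt"])  # placeholder files
dataset.set_use_var([show, label])
dataset.set_pipe_command("cat")

# New tuning knobs from this patch: 1024-instance send batches, no sleep
# between sends, and an explicit global shuffle thread count.
dataset.set_fleet_send_batch_size(1024)
dataset.set_fleet_send_sleep_seconds(0)

dataset.load_into_memory()
dataset.global_shuffle(fleet, thread_num=12)

If set_queue_num is not called, the executor now wraps train_from_dataset with dataset._dynamic_adjust_before_train(thread_num) and dataset._dynamic_adjust_after_train(), so channel and reader counts follow the training thread count automatically.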
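The client-to-client RPC settings added to FleetWrapper are exposed on the Python side through PSLib._set_client_communication_config, which must be called before init_worker so the values reach set_client2client_config before the connections are created. A short sketch, reusing the fleet object from the example above; the values shown are simply the new C++ defaults:

# Call before fleet.init_worker(); otherwise the defaults
# (500000 ms request timeout, 10000 ms connect timeout, 3 retries) apply.
fleet._set_client_communication_config(
    request_timeout_ms=500000,
    connect_timeout_ms=10000,
    max_retry=3)
fleet.init_worker()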
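On the C++ side, ReceiveFromClient now assigns incoming records to output channels round-robin through the mutex-guarded global_index_ counter instead of a random index, and DynamicAdjustChannelNum rebalances existing data by draining every channel and re-splitting it into channel_num nearly equal blocks via SetBlockSize(Size() / channel_num + 1). A small illustrative Python sketch of those two policies (not the actual implementation):

def round_robin_index(state, channel_num):
    # Mirrors the guarded global_index_++ followed by % channel_num_.
    idx = state["global_index"] % channel_num
    state["global_index"] += 1
    return idx

def rebalance(channels, channel_num):
    # Drain all channels, then split into channel_num blocks of size
    # total // channel_num + 1; the last block may be smaller.
    total = [rec for ch in channels for rec in ch]
    block = len(total) // channel_num + 1
    return [total[i * block:(i + 1) * block] for i in range(channel_num)]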
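In fleet_util.py, the xbox donefile writer now distinguishes base and patch dumps: a base donefile reuses xbox_base_key as its id, a patch donefile uses the current timestamp, and an unknown mode falls back to patch with a warning. The id selection restated as a standalone sketch:

import time

def xbox_id(mode, xbox_base_key):
    if mode == "base":
        return str(xbox_base_key)
    if mode != "patch":
        print("warning: unknown mode %s, set it to patch" % mode)
    return str(int(time.time()))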