From f17534af08f4ad001fb1abbb55de970dfc41567b Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Fri, 15 Jan 2021 10:57:54 +0800 Subject: [PATCH] ps cache support sparse --- .../pass/const_input_to_attr_registry.cc | 4 +- .../parallel/graph_util/generate_graph.cc | 19 ++- .../parallel/ops_info/gather_v2_p_info.cc | 15 +- .../frontend/parallel/ops_info/ops_utils.h | 1 + .../frontend/parallel/ops_info/unique_info.cc | 62 +++++++++ .../frontend/parallel/ops_info/unique_info.h | 5 + mindspore/ccsrc/pipeline/jit/init.cc | 3 +- mindspore/ccsrc/ps/parameter_server.h | 10 +- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 88 +++++++----- .../ccsrc/ps/ps_cache/ps_cache_manager.h | 20 ++- mindspore/ccsrc/ps/ps_context.cc | 6 + mindspore/ccsrc/ps/ps_context.h | 1 + .../device/ascend/ascend_kernel_runtime.cc | 2 +- .../runtime/device/gpu/gpu_kernel_runtime.cc | 7 +- .../ccsrc/runtime/device/kernel_runtime.cc | 48 +++++-- .../ccsrc/runtime/device/kernel_runtime.h | 1 + mindspore/ccsrc/utils/utils.h | 2 +- mindspore/core/utils/convert_utils_base.h | 7 + mindspore/nn/layer/embedding.py | 32 +++-- mindspore/nn/optim/lazyadam.py | 21 +-- mindspore/parallel/_ps_context.py | 3 + mindspore/parallel/_utils.py | 5 +- .../run_parameter_server_train_cluster.sh | 12 +- .../run_parameter_server_train_distribute.sh | 11 +- .../run_parameter_server_train_standalone.sh | 9 +- .../recommend/wide_and_deep/src/callbacks.py | 5 +- .../wide_and_deep/src/wide_and_deep.py | 6 +- ...in_and_eval_parameter_server_distribute.py | 8 +- ...in_and_eval_parameter_server_standalone.py | 9 +- .../python_file_for_ci/callbacks.py | 128 ++++++++++++++++++ .../run_wide_and_deep_auto_parallel.sh | 3 +- 31 files changed, 446 insertions(+), 107 deletions(-) create mode 100644 tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/callbacks.py diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc index 52fc6e507b..d624f45a13 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -53,7 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimReduceAny->name(), {1}); Register(prim::kPrimUnsortedSegmentMin->name(), {2}); Register(prim::kPrimUnsortedSegmentMax->name(), {2}); - Register(kSparseGatherV2, {2}); + Register(kSparseGatherV2OpName, {2}); Register(kUnsortedSegmentProdOpName, {2}); Register(kSimpleMeanGradOpName, {1}); Register(kMeanGradOpName, {1}); @@ -109,7 +109,7 @@ bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_nam ConstInputToAttrInfoRegister *reg) const { if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) { *reg = op_input_to_attr_map_.at(op_name); - MS_LOG(DEBUG) << op_name << " const2attr find in registery."; + MS_LOG(DEBUG) << op_name << " const2attr find in registry."; return true; } return false; diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc index 7b5caf7c40..b3e7704ebf 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc @@ -31,15 +31,22 @@ std::string GetOpPythonPath(const OperatorName &op_name) { // almost all ops are defined in two main paths const std::string ops_module = OP_PATH; const std::string inner_ops_module = INNER_OP_PATH; + const std::string 
functional_op_module = FUNCTIONAL_OP_PATH;
   py::module mod = py::module::import(common::SafeCStr(ops_module));
   py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
-  if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
-    if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-      MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
-    }
+  py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
+
+  if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+    return inner_ops_module;
+  }
+  if (py::hasattr(mod, common::SafeCStr(op_name))) {
     return ops_module;
   }
-  return inner_ops_module;
+  if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
+    MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module
+                      << " don't have op:" << op_name;
+  }
+  return functional_op_module;
 }
 
 ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
@@ -141,7 +148,7 @@ Status GenerateGraph::Init(const CNodePtr &cnode) {
 }
 
 AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
-  CNodePtr cnode = func_graph_->NewCNode(inputs);  // using NewCNode to creat anfnode
+  CNodePtr cnode = func_graph_->NewCNode(inputs);  // using NewCNode to create anfnode
   MS_EXCEPTION_IF_NULL(cnode);
   cnode->set_scope(scope_);
   if (inputs.size() < 2) {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
index cf2ff01ca5..dde660ed87 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
@@ -24,8 +24,10 @@
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
+#include "frontend/parallel/context.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
+#include "ps/ps_cache/ps_cache_manager.h"
+#include "utils/ms_context.h"
 #endif
 
 namespace mindspore {
 namespace parallel {
@@ -158,6 +160,15 @@ Status GatherV2PInfo::GetAttrs() {
   if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
     dynamic_shape_indices_ = true;
   }
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
+  bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
+  if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
+    dynamic_shape_indices_ = true;
+  }
+#endif
   return SUCCESS;
 }
 
@@ -531,7 +542,7 @@ Status GatherV2PInfo::InferBias() {
   }
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
-    bias_ = 0;
+    bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
     return SUCCESS;
   }
 #endif
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 0a1b2563e8..9cff810462 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -68,6 +68,7 @@ constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
 constexpr char OP_PATH[] = "mindspore.ops.operations";
 constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
+constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
 constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
 constexpr char GET_OP_FUNCTION[] = "_get_python_op";
 constexpr char KEEP_DIMS[] = "keep_dims";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
index 8b70389d45..271c4c1bb5 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
@@ -23,9 +23,13 @@
 #include "ir/value.h"
 #include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/context.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#include "ps/ps_cache/ps_cache_manager.h"
+#endif
 
 namespace mindspore {
 namespace parallel {
@@ -186,5 +190,63 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
   }
   return SUCCESS;
 }
+
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  GenerateGraph gen_g = GenerateGraph();
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateGraph Init failed";
+    return FAILED;
+  }
+  auto bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
+  auto slice_size = SizeToLong(ps::PsCacheManager::GetInstance().vocab_cache_size());
+
+  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias)});
+  auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
+  auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
+  auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()});
+  auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(0)});
+  auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(1)});
+  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1});
+  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
+  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast});
+
+  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
+  OperatorAttrs attrs = {attr_op};
+  AnfNodePtr reduce_op;
+  reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
+  auto make_tuple = gen_g.PushBack({gen_g.NewOpInst(MAKE_TUPLE), tuple_getitem_0, reduce_op});
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 1), std::make_pair(unique, 1)};
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, make_tuple));
+  return SUCCESS;
+}
+#endif
+
+ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+    auto inputs = cnode->inputs();
+    if (inputs.empty()) {
+      MS_LOG(EXCEPTION) << "Invalid inputs";
+    }
+    const auto &primitive = GetValueNode<PrimitivePtr>(inputs[0]);
+    const auto &attr = primitive->GetAttr("cache_enable");
+    if (attr == nullptr) {
+      return nullptr;
+    }
+    auto need_mask = GetValue<bool>(attr);
+    if (!need_mask) {
+      return nullptr;
+    }
+    if (ComputeReplaceGraph(cnode) != SUCCESS) {
+      MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
+    }
+    return replace_graph_;
+  }
+#endif
+  return nullptr;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
index 94323bc523..b2037617f7 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
@@ -39,6 +39,7 @@ class UniqueInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int64_t stage_id) override;
+  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
 
 protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -50,8 +51,12 @@ class UniqueInfo : public OperatorInfo {
   Status InferMirrorOps() override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferAsLossDivisor() override { return SUCCESS; }
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
+#endif
 
 private:
+  std::string replace_op_name_ = UNIQUE;
   int64_t dev_num_ = 1;
 };
 }  // namespace parallel
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index ac0c8d01bb..bc09d18947 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -321,7 +321,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.")
     .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
     .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
-    .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.");
+    .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
+    .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
 
   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
     .def(py::init())
diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h
index 7d7fb8c8c5..2a790f7332 100644
--- a/mindspore/ccsrc/ps/parameter_server.h
+++ b/mindspore/ccsrc/ps/parameter_server.h
@@ -773,12 +773,14 @@ void ParameterServer<T>::GetEmbeddingTableParamPtr() {
   for (auto cnode : cnodes) {
     MS_EXCEPTION_IF_NULL(cnode);
     std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
-    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) {
+    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
       auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
       MS_EXCEPTION_IF_NULL(embedding_table);
-      MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
-      embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
-      count++;
+      if (embedding_table->isa<Parameter>()) {
+        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
+        embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
+        count++;
+      }
     }
   }
 }
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
index 7f7c09368c..d82d0c3ce3 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
@@ -35,11 +35,11 @@ void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t c
   if (vocab_size_ == 0) {
     vocab_size_ = vocab_size;
   }
-  if (cache_vocab_size_ == 0) {
-    cache_vocab_size_ = cache_vocab_size;
+  if (vocab_cache_size_ == 0) {
+    vocab_cache_size_ = cache_vocab_size;
   }
-  if (host_cache_vocab_size_ == 0) {
-    host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor;
+  if (host_vocab_cache_size_ == 0) {
+    host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
   }
 }
@@ -148,8 +148,8 @@ void PsCacheManager::Initialize() {
     Util::SetInternalEnvVar();
     worker.Run();
   }
-  embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
-  embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
+  embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
+  embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
   AddEmbeddingTable();
   AllocMemForHashTable();
   SetLocalIdRank();
@@ -220,13 +220,13 @@ void PsCacheManager::AllocMemForHashTable() {
   for (auto &item : hash_tables_) {
     size_t embedding_size = item.second.embedding_size;
     auto &device_address = item.second.device_address;
-    device_address.size = cache_vocab_size_ * embedding_size * sizeof(float);
+    device_address.size = vocab_cache_size_ * embedding_size * sizeof(float);
     auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
     MS_EXCEPTION_IF_NULL(addr);
     device_address.addr = addr;
 
     auto &host_address = item.second.host_address;
-    auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size];
+    auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size];
     MS_EXCEPTION_IF_NULL(host_address_ptr);
     host_address = std::shared_ptr<float>(host_address_ptr, std::default_delete<float[]>());
     MS_EXCEPTION_IF_NULL(host_address);
@@ -239,21 +239,28 @@
   embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
     embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
   MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
-  if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) {
+  if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
     MS_LOG(EXCEPTION) << "MallocConstantMemory failed.";
   }
 }
 
 void PsCacheManager::SetLocalIdRank() {
   auto worker_num = ::ps::NumWorkers();
-  auto worker_id = ::ps::MyRank();
-  auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num));
-  range_bound_.first = local_shard_size * worker_id;
-  range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_);
-  MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first
-               << ", rank id end:" << range_bound_.second;
+  auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
+  vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
+  emb_table_slice_bounds_.first = local_shard_size * rank_id_;
+  emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
+  cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
+  cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
+  MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
+               << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
+               << ", cache indices begin: " << cache_indices_bounds_.first
+               << ", cache indices end: " << cache_indices_bounds_.second
", vocab_cache_size_diff: " << vocab_cache_size_diff_; } +int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; } + std::string PsCacheManager::channel_name() { std::lock_guard locker(channel_mutex_); return channel_name_; @@ -398,8 +405,8 @@ bool PsCacheManager::ProcessData() { return true; } -bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, - bool *in_device, size_t *hash_hit_count) { +bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, + bool *in_device, bool *out_range, size_t *hash_hit_count) { MS_ERROR_IF_NULL(batch_ids); MS_ERROR_IF_NULL(hash_index); MS_ERROR_IF_NULL(in_device); @@ -410,9 +417,19 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc const auto &hash_id_to_index = device_hash_map->hash_id_to_index(); for (size_t i = 0; i < batch_ids_len; ++i) { + if (batch_ids[i] < emb_table_slice_bounds_.first) { + hash_index[i] = batch_ids[i] - vocab_cache_size_diff_; + out_range[i] = true; + continue; + } + if (batch_ids[i] >= emb_table_slice_bounds_.second) { + hash_index[i] = batch_ids[i] + cache_indices_bounds_.second; + out_range[i] = true; + continue; + } auto iter = hash_id_to_index.find(batch_ids[i]); if (iter != hash_id_to_index.end()) { - hash_index[i] = iter->second; + hash_index[i] = iter->second + cache_indices_bounds_.first; if (device_hash_map->hash_step(iter->second) != data_step_) { ++(*hash_hit_count); device_hash_map->set_hash_step(iter->second, data_step_); @@ -423,11 +440,12 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc return true; } -bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, - bool *in_device) { +bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, + bool *in_device, bool *out_range) { MS_ERROR_IF_NULL(batch_ids); MS_ERROR_IF_NULL(hash_index); MS_ERROR_IF_NULL(in_device); + MS_ERROR_IF_NULL(out_range); size_t thread_num = batch_ids_len / kMinIdsPerThread + 1; thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; @@ -441,8 +459,9 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id break; } size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 
-    threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
-                             hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
+    threads[i] =
+      std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens,
+                  hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i);
     task_offset += task_proc_lens;
   }
   if (task_offset != batch_ids_len) {
@@ -477,27 +496,26 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
   MS_ERROR_IF_NULL(batch_ids);
   MS_ERROR_IF_NULL(hash_index);
   statistics_info_.batch_id_count_ = batch_ids_len;
   std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
+  std::unique_ptr<bool[]> out_range(new bool[batch_ids_len]);
   if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
-    MS_LOG(EXCEPTION) << "Data in device memset failed.";
+    MS_LOG(EXCEPTION) << "Initialize in_device array failed.";
+  }
+  if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
+    MS_LOG(EXCEPTION) << "Initialize out_range array failed.";
   }
-  CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
+  RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get()));
   RETURN_IF_FALSE(ResetEmbeddingHashMap());
   for (size_t i = 0; i < batch_ids_len; i++) {
-    if (in_device[i]) {
+    if (in_device[i] || out_range[i]) {
       continue;
     }
     bool need_swap_host_to_device = true;
     bool need_swap_device_to_host = true;
-    auto id = batch_ids[i];
-    if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) {
-      hash_index[i] = -1;
-      continue;
-    }
     int index = INVALID_INDEX_VALUE;
-    RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index));
-    hash_index[i] = index;
+    RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index));
+    hash_index[i] = index + cache_indices_bounds_.first;
     if (need_swap_host_to_device) {
-      RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
+      RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i]));
     }
     if (need_swap_device_to_host) {
       RETURN_IF_FALSE(ParseHostDataDeviceToHost());
@@ -667,7 +685,7 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
 
 bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
                                          const int *indices_addr, float *output_addr) {
-  size_t first_dim_size = host_cache_vocab_size_;
+  size_t first_dim_size = host_vocab_cache_size_;
   size_t outer_dim_size = embedding_size;
 
   size_t thread_num = indices_lens / 10000 + 1;
@@ -697,7 +715,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
 
 bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
                                          float *insert_data, float *hash_table_addr) {
-  size_t first_dim_size = host_cache_vocab_size_;
+  size_t first_dim_size = host_vocab_cache_size_;
   size_t thread_num = insert_indices_size / 10000 + 1;
   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
   std::thread threads[kMaxThreadNum];
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
index f67bcc8610..c78efafb95 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
@@ -125,7 +125,10 @@ class PsCacheManager {
   const size_t &QueryHashTableSize(const std::string &param_name) const;
   bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
   void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
+  void set_rank_id(int rank_id) { rank_id_ = rank_id; }
   bool initialized_ps_cache() const { return initialized_ps_cache_; }
+  size_t vocab_cache_size() const { return vocab_cache_size_; }
+  int cache_indices_lower_bound() const;
   void DoProcessData(uint32_t device_id, void *context);
   void IncreaseGraphStep(const std::string &channel_name);
   void SyncEmbeddingTable();
@@ -170,10 +173,12 @@ class PsCacheManager {
   void DumpStatisticsInfo(size_t each_print_step = 1000);
   bool SyncHostEmbeddingTable();
   bool SyncDeviceEmbeddingTable();
-  bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
-                           size_t *hash_hit_count);
-  bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
+  bool CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
+                                   bool *out_range, size_t *hash_hit_count);
+  bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
+                               bool *out_range);
   bool ResetEmbeddingHashMap();
+
   bool initialized_ps_cache_{false};
   std::string channel_name_;
   std::mutex channel_mutex_;
@@ -190,11 +195,14 @@ class PsCacheManager {
   std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
 
   size_t vocab_size_{0};
-  size_t cache_vocab_size_{0};
-  size_t host_cache_vocab_size_{0};
+  size_t vocab_cache_size_{0};
+  size_t host_vocab_cache_size_{0};
   size_t batch_elements_{0};
   PsCacheStatisticsInfo statistics_info_;
-  std::pair<size_t, size_t> range_bound_;
+  std::pair<int, int> emb_table_slice_bounds_;
+  std::pair<int, int> cache_indices_bounds_;
+  int vocab_cache_size_diff_{0};
+  int rank_id_{0};
   std::atomic_bool finish_insert_init_info_{false};
   std::atomic_bool finish_init_parameter_server_{false};
   std::atomic_bool running_{false};
diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc
index 5d3942f96c..2ce0e6e472 100644
--- a/mindspore/ccsrc/ps/ps_context.cc
+++ b/mindspore/ccsrc/ps/ps_context.cc
@@ -129,5 +129,11 @@ void PSContext::set_cache_enable(bool cache_enable) const {
   PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
 #endif
 }
+
+void PSContext::set_rank_id(int rank_id) const {
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  ps_cache_instance.set_rank_id(rank_id);
+#endif
+}
 }  // namespace ps
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h
index 070a4df464..bf14121382 100644
--- a/mindspore/ccsrc/ps/ps_context.h
+++ b/mindspore/ccsrc/ps/ps_context.h
@@ -52,6 +52,7 @@ class PSContext {
   void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
   void set_cache_enable(bool cache_enable) const;
+  void set_rank_id(int rank_id) const;
 
 private:
   PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index fd116d1ec7..3e7516d686 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -391,7 +391,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
 bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
   InnerSetContext();
   if (graph->is_dynamic_shape()) {
-    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
+    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
       MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode.";
     }
     if (DumpJsonParser::GetInstance().async_dump_enabled()) {
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index 20a13b733e..e47d89d8c2 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -851,7 +851,7 @@ void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_addre
       MS_LOG(WARNING) << "Unexpected device address status: " << status;
       break;
     default:
-      MS_LOG(EXCEPTION) << "Invaild device address status: " << status;
+      MS_LOG(EXCEPTION) << "Invalid device address status: " << status;
   }
 }
@@ -1092,6 +1092,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
   MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   auto cnode = kernel->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
+  // Can not free the input addr of communication op when enable multi stream
   if (AnfAlgo::IsCommunicationOp(kernel)) {
     return;
   }
@@ -1106,7 +1107,9 @@
     }
 
     auto kernel_with_index = GetPrevNodeOutput(kernel, i);
-    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
+    // Maintain output addr of fused communication op to improve training performance
+    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first) &&
+        AnfAlgo::GetInputTensorNum(kernel_with_index.first) > 1) {
      continue;
    }
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index 609a8f316f..2f1373c7ec 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -1049,7 +1049,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
   MS_EXCEPTION_IF_NULL(graph);
   for (const auto &kernel : graph->execution_order()) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
+    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
       continue;
     }
     auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
@@ -1061,13 +1062,15 @@
       continue;
     }
     auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
-    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
-      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
+    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
+      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
       MS_EXCEPTION_IF_NULL(input_index.first);
     }
-    if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
+    auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first);
+    if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) {
       bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
-      if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
+      if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
+          (full_batch && (input_index_node_name != kMinimumOpName))) {
         MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
                       << ") cache is from " << input_index.first->fullname_with_scope();
         MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
@@ -1082,6 +1085,28 @@
   }
 }
 
+void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true);
+  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
+    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
+    MS_EXCEPTION_IF_NULL(pre_node.first);
+  }
+  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
+    MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode";
+  }
+
+  pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
+  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
+    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
+    MS_EXCEPTION_IF_NULL(pre_node.first);
+  }
+  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
+    MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
+                         "value can not be changed before delivering to kernel[Unique] in parameter server cache mode.";
+  }
+}
+
 void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   AnfNodePtr first_cache_input_index = nullptr;
@@ -1090,16 +1115,23 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
   MS_EXCEPTION_IF_NULL(first_cache_input_index);
   for (const auto &kernel : graph->execution_order()) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
+    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
       continue;
     }
     auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
     auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
     MS_EXCEPTION_IF_NULL(input_param.first);
     MS_EXCEPTION_IF_NULL(input_index.first);
+    if (!input_param.first->isa<Parameter>()) {
+      continue;
+    }
     auto param_name = input_param.first->fullname_with_scope();
-    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
-      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
+    if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
+      CheckSparsePSEmbeddingCache(kernel);
+    }
+    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
+      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
       MS_EXCEPTION_IF_NULL(input_index.first);
     }
     if (input_index.first == first_cache_input_index) {
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h
index f759db9e88..54f7fb2052 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h
@@ -138,6 +138,7 @@ class KernelRuntime {
   void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
                                 size_t *first_cache_size);
   void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
+  void CheckSparsePSEmbeddingCache(const CNodePtr &node);
 #endif
 
 protected:
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 30e6dc13e2..ada68a5c3c 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -83,7 +83,7 @@ constexpr auto kScatterNdOpName = "ScatterNd";
 constexpr auto kStridedSliceAssignOpName = "StridedSliceAssign";
 constexpr auto kStridedSliceOpName = "StridedSlice";
 constexpr auto kStridedSliceGradOpName = "StridedSliceGrad";
-constexpr auto kSparseGatherV2 = "SparseGatherV2";
+constexpr auto kSparseGatherV2OpName = "SparseGatherV2";
 constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd";
 constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
 constexpr auto kFlattenGradOpName = "FlattenGrad";
diff --git a/mindspore/core/utils/convert_utils_base.h b/mindspore/core/utils/convert_utils_base.h
index b5673f5ac0..bcdfb299a7 100644
--- a/mindspore/core/utils/convert_utils_base.h
+++ b/mindspore/core/utils/convert_utils_base.h
@@ -73,6 +73,13 @@ inline size_t FloatToSize(float u) {
 }
 inline float IntToFloat(int32_t v) { return static_cast<float>(v); }
 
+inline int FloatToInt(float u) {
+  if (u > static_cast<float>((std::numeric_limits<int>::max)())) {
+    MS_LOG(EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int.";
+  }
+  return static_cast<int>(u);
+}
+
 inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
 
 inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index 5618223ef5..fb056a64ce 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -20,10 +20,12 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
-from mindspore.communication.management import get_group_size
-from mindspore.context import ParallelMode, get_context
+from mindspore.communication.management import get_group_size, get_rank
+from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
-from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
+from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
+from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
+from mindspore import context
 from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
 from mindspore.ops.primitive import constexpr
@@ -227,8 +229,6 @@ class EmbeddingLookup(Cell):
         self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
         self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
-        if self.cache_enable and enable_ps:
-            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.GatherV2()
@@ -238,6 +238,10 @@ class EmbeddingLookup(Cell):
         self.shape = P.Shape()
         if is_auto_parallel:
             self.unique = P.Unique().shard(((1,),))
+        if self.cache_enable and enable_ps:
+            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
+            if is_auto_parallel:
+                self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
         if slice_mode == "field_slice" and is_auto_parallel:
             if not manual_shapes:
@@ -252,7 +256,7 @@ class EmbeddingLookup(Cell):
             self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
         elif slice_mode == "table_row_slice" and is_auto_parallel:
             full_batch = _get_full_batch()
-            if target == 'DEVICE' and not full_batch:
+            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                 indices_shape_size = 1
                 self.gather_revert.shard(((1, 1), (get_group_size(),)))
                 self.forward_unique = True
@@ -293,7 +297,7 @@ class EmbeddingLookup(Cell):
             raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
         if not self.sparse:
             raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
-        if get_context("device_target") != 'Ascend':
+        if context.get_context("device_target") != 'Ascend':
             raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
         logger.info("EmbeddingLookup cache enable takes effect.")
@@ -320,21 +324,29 @@ class EmbeddingLookup(Cell):
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
-            device_num = get_group_size()
+            rank_size = get_group_size()
+            rank_id = get_rank()
             full_batch = _get_full_batch()
-            if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"):
+            if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                 raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
                                  "in 'full_batch' and 'table_row_slice' parallel strategy.")
-            self.vocab_cache_size = self.vocab_cache_size * device_num
+            self.vocab_cache_size = self.vocab_cache_size * rank_size
+            _set_rank_id(rank_id)
         self.cache_enable = True
         if _is_role_worker():
             self.vocab_size = self.vocab_cache_size
+        if context.get_context("enable_sparse") != self.sparse:
+            raise ValueError("The value of parameter 'sparse' must be the same for all EmbeddingLookup "
+                             "kernels and equal the value of 'enable_sparse' in context setting in "
+                             "parameter server cache mode")
 
     def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
         """PS embeddingLookup cache enable set."""
         self.embedding_table.cache_enable = True
         self.embedding_table.is_param_ps = True
         _set_cache_enable(True)
+        if self.sparse:
+            self.forward_unique = True
         if _is_role_worker():
             _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py
index 7af8836643..f4e9ee5d62 100644
--- a/mindspore/nn/optim/lazyadam.py
+++ b/mindspore/nn/optim/lazyadam.py
@@ -28,14 +28,15 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
 
 @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
-                         "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", 
"Tensor", "Bool") + "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", + "Bool") def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power, - beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter): + beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter, cache_enable): """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" success = True indices = gradient.indices values = gradient.values - if ps_parameter: + if ps_parameter and not cache_enable: op_shape = P.Shape() shapes = (op_shape(params), op_shape(m), op_shape(v), op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), @@ -75,12 +76,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, - beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") +def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power, + beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter, cache_enable): """Apply lazy adam optimizer to the weight parameter using Tensor.""" success = True - if ps_parameter: + if ps_parameter and not cache_enable: op_shape = P.Shape() success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) @@ -245,12 +246,14 @@ class LazyAdam(Optimizer): success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, self.use_locking, self.use_nesterov, self._is_device, self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps), - lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters) + lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, + self.cache_enable) else: success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, self.use_locking, self.use_nesterov, self._is_device, self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr), - gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters) + gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, + self.cache_enable) return success @Optimizer.target.setter diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index b14c8c8f2d..70414eedef 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -142,3 +142,6 @@ def _set_cache_enable(cache_enable): os.environ['GOTO_NUM_THREADS'] = '2' os.environ['OMP_NUM_THREADS'] = '2' ps_context().set_cache_enable(cache_enable) + +def _set_rank_id(rank_id): + ps_context().set_rank_id(rank_id) diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 8dc97fc8f2..8b1ae18a8f 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -190,7 +190,10 @@ def _get_python_op(op_name, op_path, instance_name, arglist): """Get python operator.""" module = 
__import__(op_path, fromlist=["None"]) cls = getattr(module, op_name) - op = cls(*arglist) + if op_path != "mindspore.ops.functional": + op = cls(*arglist) + else: + op = cls op.set_prim_instance_name(instance_name) return op diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh index de312f6f98..64a1f1ed76 100644 --- a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh @@ -17,7 +17,8 @@ #bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET # LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM -# SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE VOCAB_CACHE_SIZE +# SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE +# VOCAB_CACHE_SIZE SPARSE execute_path=$(pwd) script_self=$(readlink -f "$0") self_path=$(dirname "${script_self}") @@ -37,11 +38,16 @@ export MS_SCHED_PORT=$9 export MS_ROLE=${10} export RANK_TABLE_FILE=${11} export VOCAB_CACHE_SIZE=${12} +export SPARSE=${13} if [[ ! -n "${12}" ]]; then export VOCAB_CACHE_SIZE=0 fi +if [[ ! -n "${13}" ]]; then + export SPARSE=0 +fi + echo "=====Role is $MS_ROLE======" if [[ "$MS_ROLE" == "MS_SCHED" ]]; then @@ -73,7 +79,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then mpirun --allow-run-as-root -n $LOCAL_WORKER_NUM --output-filename log_output --merge-stderr-to-stdout \ python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ --device_target=$DEVICE --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ - --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 & + --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 & else for((i=0;i<$LOCAL_WORKER_NUM;i++)); do @@ -84,7 +90,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then export DEVICE_ID=$i python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ - --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 & + --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 & done fi fi diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh index b2cf337753..73db06fc69 100644 --- a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh @@ -17,7 +17,7 @@ #bash run_parameter_server_train_distribute.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET # SERVER_NUM SCHED_HOST SCHED_PORT RANK_TABLE_FILE -# VOCAB_CACHE_SIZE +# VOCAB_CACHE_SIZE SPARSE execute_path=$(pwd) script_self=$(readlink -f "$0") self_path=$(dirname "${script_self}") @@ -33,11 +33,16 @@ export MS_SCHED_HOST=$6 export MS_SCHED_PORT=$7 export RANK_TABLE_FILE=$8 export VOCAB_CACHE_SIZE=$9 +export SPARSE=${10} if [[ ! -n "$9" ]]; then export VOCAB_CACHE_SIZE=0 fi +if [[ ! 
-n "${10}" ]]; then + export SPARSE=0 +fi + export MS_ROLE=MS_SCHED rm -rf ${execute_path}/sched/ mkdir ${execute_path}/sched/ @@ -65,7 +70,7 @@ if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ - --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 & + --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 & else for((i=0;i<$MS_WORKER_NUM;i++)); do @@ -76,7 +81,7 @@ else export DEVICE_ID=$i python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ - --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 & + --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 & done fi diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh index 78b4c64a03..8c7bb208c8 100644 --- a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh @@ -16,7 +16,7 @@ #bash run_parameter_server_train_standalone.sh EPOCHS DEVICE_TARGET DATASET SERVER_NUM SCHED_HOST -# SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE +# SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE SPARSE execute_path=$(pwd) script_self=$(readlink -f "$0") self_path=$(dirname "${script_self}") @@ -31,11 +31,16 @@ export MS_SCHED_HOST=$5 export MS_SCHED_PORT=$6 DEVICE_ID=$7 export VOCAB_CACHE_SIZE=$8 +export SPARSE=$9 if [[ ! -n "$8" ]]; then export VOCAB_CACHE_SIZE=0 fi +if [[ ! -n "$9" ]]; then + export SPARSE=0 +fi + # Set device id if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then if [[ ! 
-n "$DEVICE_ID" ]]; then @@ -76,4 +81,4 @@ mkdir ${execute_path}/worker/ cd ${execute_path}/worker/ || exit python -s ${self_path}/../train_and_eval_parameter_server_standalone.py --device_target=$DEVICE_TARGET \ --epochs=$EPOCH_SIZE --data_path=$DATASET --parameter_server=1 \ - --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 & + --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 & diff --git a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py index bebefc9ca0..c10e221ad6 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py +++ b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py @@ -115,8 +115,11 @@ class EvalCallBack(Callback): if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, ParallelMode.DATA_PARALLEL): rank_id = get_rank() + enable_data_sink = not self.sparse + if bool(self.config.parameter_server): + enable_data_sink = True start_time = time.time() - out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse)) + out = self.model.eval(self.eval_dataset, dataset_sink_mode=enable_data_sink) end_time = time.time() eval_time = int(end_time - start_time) diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 3bbf47e13e..f167617f07 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -202,7 +202,7 @@ class WideDeepModel(nn.Cell): self.unique = P.Unique().shard(((1,),)) self.wide_gatherv2 = P.GatherV2() self.deep_gatherv2 = P.GatherV2() - if is_auto_parallel and sparse and not is_field_slice: + if is_auto_parallel and sparse and not is_field_slice and not parameter_server: target = 'DEVICE' if host_device_mix: target = 'CPU' @@ -376,12 +376,12 @@ class TrainStepWrap(nn.Cell): self.weights_w = ParameterTuple(weights_w) self.weights_d = ParameterTuple(weights_d) - if (sparse and is_auto_parallel) or (parameter_server and not cache_enable): + if (sparse and is_auto_parallel) or (sparse and parameter_server): self.optimizer_d = LazyAdam( self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) - if host_device_mix or parameter_server: + if host_device_mix or (parameter_server and not cache_enable): self.optimizer_w.target = "CPU" self.optimizer_d.target = "CPU" else: diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py index 03840efaac..19fd171d30 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py @@ -43,7 +43,7 @@ def get_wide_deep_net(config): if cache_enable: loss_net = VirtualDatasetCellTriple(loss_net) train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), - cache_enable=(config.vocab_cache_size > 0)) + sparse=config.sparse, cache_enable=(config.vocab_cache_size > 0)) eval_net = PredictWithSigmoid(wide_deep_net) if cache_enable: eval_net = VirtualDatasetCellTriple(eval_net) @@ -138,7 +138,7 @@ def train_and_eval(config): callback_list.append(ckpoint_cb) 
model.train(epochs, ds_train, callbacks=callback_list, - dataset_sink_mode=bool(parameter_server and cache_enable)) + dataset_sink_mode=(parameter_server and cache_enable)) if __name__ == "__main__": @@ -148,7 +148,6 @@ if __name__ == "__main__": cache_enable = wide_deep_config.vocab_cache_size > 0 if cache_enable and wide_deep_config.device_target != "GPU": context.set_context(variable_memory_max_size="24GB") - context.set_context(enable_sparse=True) context.set_ps_context(enable_ps=True) init() context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) @@ -159,5 +158,8 @@ if __name__ == "__main__": else: context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=get_group_size()) + wide_deep_config.sparse = True + if wide_deep_config.sparse: + context.set_context(enable_sparse=True) train_and_eval(wide_deep_config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py index a5a868e6a1..f051fb055a 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py @@ -29,7 +29,6 @@ from src.metrics import AUCMetric from src.config import WideDeepConfig sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -context.set_context(enable_sparse=True) def get_wide_deep_net(config): @@ -39,7 +38,7 @@ def get_wide_deep_net(config): wide_deep_net = WideDeepModel(config) loss_net = NetWithLossClass(wide_deep_net, config) train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), - cache_enable=(config.vocab_cache_size > 0)) + sparse=config.sparse, cache_enable=(config.vocab_cache_size > 0)) eval_net = PredictWithSigmoid(wide_deep_net) return train_net, eval_net @@ -81,7 +80,6 @@ def train_and_eval(config): else: dataset_type = DataType.H5 parameter_server = bool(config.parameter_server) - cache_enable = config.vocab_cache_size > 0 print("epochs is {}".format(epochs)) ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size, data_type=dataset_type) @@ -121,6 +119,11 @@ if __name__ == "__main__": wide_deep_config.argparse_init() context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) + cache_enable = wide_deep_config.vocab_cache_size > 0 + if not cache_enable: + wide_deep_config.sparse = True + if wide_deep_config.sparse: + context.set_context(enable_sparse=True) context.set_ps_context(enable_ps=True) train_and_eval(wide_deep_config) diff --git a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/callbacks.py b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/callbacks.py new file mode 100644 index 0000000000..54f6de5c06 --- /dev/null +++ b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/callbacks.py @@ -0,0 +1,128 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+callbacks
+"""
+import time
+from mindspore.train.callback import Callback
+from mindspore import context
+from mindspore.context import ParallelMode
+from mindspore.communication.management import get_rank
+
+def add_write(file_path, out_str):
+    """
+    add lines to the file
+    """
+    with open(file_path, 'a+', encoding="utf-8") as file_out:
+        file_out.write(out_str + "\n")
+
+
+class LossCallBack(Callback):
+    """
+    Monitor the loss in training.
+
+    If the loss is NAN or INF, terminate the training.
+
+    Note:
+        If per_print_times is 0, do NOT print loss.
+        If this process is MS_PSERVER role, do not run callbacks.
+
+    Args:
+        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
+    """
+    def __init__(self, config=None, per_print_times=1):
+        super(LossCallBack, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be int and >= 0.")
+        self._per_print_times = per_print_times
+        self.config = config
+
+    def step_end(self, run_context):
+        """Monitor the loss in training."""
+        cb_params = run_context.original_args()
+        if cb_params.net_outputs is None:
+            return
+        wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        cur_num = cb_params.cur_step_num
+        rank_id = 0
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
+                             ParallelMode.DATA_PARALLEL):
+            rank_id = get_rank()
+
+        print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch,
+              wide_loss, deep_loss, flush=True)
+
+        # raise ValueError
+        if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
+            loss_file = open(self.config.loss_file_name, "a+")
+            loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
+                            (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
+            loss_file.write("\n")
+            loss_file.close()
+            print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
+                  (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
+
+
+class EvalCallBack(Callback):
+    """
+    Monitor the loss in evaluating.
+
+    If the loss is NAN or INF, terminate evaluating.
+
+    Note:
+        If per_print_times is 0, do NOT print loss.
+
+    Args:
+        print_per_step (int): Print the loss every `print_per_step` steps. Default: 1.
+ """ + def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): + super(EvalCallBack, self).__init__() + if not isinstance(print_per_step, int) or print_per_step < 0: + raise ValueError("print_per_step must be int and >= 0.") + self.print_per_step = print_per_step + self.model = model + self.eval_dataset = eval_dataset + self.aucMetric = auc_metric + self.aucMetric.clear() + self.eval_file_name = config.eval_file_name + self.eval_values = [] + self.sparse = config.sparse + self.config = config + + def epoch_end(self, run_context): + """ + epoch end + """ + self.aucMetric.clear() + parallel_mode = context.get_auto_parallel_context("parallel_mode") + if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + context.set_auto_parallel_context(strategy_ckpt_save_file="", + strategy_ckpt_load_file=self.config.stra_ckpt) + rank_id = 0 + if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, + ParallelMode.DATA_PARALLEL): + rank_id = get_rank() + start_time = time.time() + out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse)) + end_time = time.time() + eval_time = int(end_time - start_time) + + time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) + out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\ + format(time_str, rank_id, out.values(), eval_time) + print(out_str) + self.eval_values = out.values() + add_write(self.eval_file_name, out_str) diff --git a/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh b/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh index 0cdd327212..48ba33ef8e 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh +++ b/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh @@ -34,6 +34,7 @@ cp -r ${CODE_DIR} ${BASE_PATH}/wide_and_deep cp -f ${BASE_PATH}/python_file_for_ci/train_and_test_multinpu_ci.py ${BASE_PATH}/wide_and_deep/train_and_test_multinpu_ci.py cp -f ${BASE_PATH}/python_file_for_ci/__init__.py ${BASE_PATH}/wide_and_deep/__init__.py cp -f ${BASE_PATH}/python_file_for_ci/config.py ${BASE_PATH}/wide_and_deep/src/config.py +cp -f ${BASE_PATH}/python_file_for_ci/callbacks.py ${BASE_PATH}/wide_and_deep/src/callbacks.py cp -f ${BASE_PATH}/python_file_for_ci/datasets.py ${BASE_PATH}/wide_and_deep/src/datasets.py cp -f ${BASE_PATH}/python_file_for_ci/wide_and_deep.py ${BASE_PATH}/wide_and_deep/src/wide_and_deep.py source ${BASE_PATH}/env.sh @@ -55,7 +56,7 @@ for((i=0; i<${DEVICE_NUM}; i++)); do wait ${process_pid[i]} status=`echo $?` if [ "${status}" != "0" ]; then - echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}" + echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}" exit 1 else echo "[INFO] test wide_and_deep semi auto parallel success."