|
|
|
@ -36,9 +36,9 @@
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
|
|
|
|
#include "frontend/parallel/ps/worker.h"
|
|
|
|
|
#include "frontend/parallel/ps/common.h"
|
|
|
|
|
#include "frontend/parallel/ps/util.h"
|
|
|
|
|
#include "ps/worker.h"
|
|
|
|
|
#include "ps/common.h"
|
|
|
|
|
#include "ps/util.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@ -1380,7 +1380,7 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
|
|
|
|
|
|
|
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
|
|
|
|
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
|
|
|
|
|
if (!parallel::ps::Util::IsRoleOfWorker()) {
|
|
|
|
|
if (!ps::Util::IsRoleOfWorker()) {
|
|
|
|
|
MS_LOG(INFO) << "Not parameter server mode.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -1393,7 +1393,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
|
|
|
|
|
if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
|
|
|
|
|
size_t embedding_table_idx = 0;
|
|
|
|
|
auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
|
|
|
|
|
size_t key = parallel::ps::worker.SetParamKey(embedding_table->fullname_with_scope());
|
|
|
|
|
size_t key = ps::worker.SetParamKey(embedding_table->fullname_with_scope());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
|
|
|
|
|
} else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
|
|
|
|
|
auto pull_node = FindPullNode(node, node_list);
|
|
|
|
@ -1404,12 +1404,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
|
|
|
|
|
// Second input of Pull node is the trainable parameter.
|
|
|
|
|
size_t parameter_index = 1;
|
|
|
|
|
auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
|
|
|
|
|
size_t key = parallel::ps::worker.SetParamKey(parameter_node->fullname_with_scope());
|
|
|
|
|
size_t key = ps::worker.SetParamKey(parameter_node->fullname_with_scope());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);
|
|
|
|
|
|
|
|
|
|
std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
|
|
|
|
|
parallel::ps::worker.SetKeyOptimId(key, optimizer_name);
|
|
|
|
|
ps::worker.SetKeyOptimId(key, optimizer_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1417,7 +1417,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
|
|
|
|
|
|
|
|
|
|
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
|
|
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) {
|
|
|
|
|
if (!parallel::ps::Util::IsRoleOfWorker()) {
|
|
|
|
|
if (!ps::Util::IsRoleOfWorker()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
|
|
|
@ -1440,7 +1440,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
|
|
|
|
auto pk_node = input_node->cast<ParameterPtr>();
|
|
|
|
|
parallel::ps::worker.InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
|
|
|
|
|
ps::worker.InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|