@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
  return nullptr;
}

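// Infer a parameter's tensor layout from its first user that is a
// parallel-care operator, skipping Depend/Reshape users along the way.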
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
  FuncGraphManagerPtr manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
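  // Each entry in node_set is a (user CNode, input index) pair for this parameter.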
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
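    // Skip Reshape users, and Depend users where the parameter is not the real (index-1) input.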
    if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
      continue;
    }
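    // The first parallel-care user that carries OperatorInfo fixes the parameter's layout.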
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      auto layout = GetInputLayoutFromCNode(node_pair);
      return std::make_shared<TensorLayout>(layout);
    }
  }
  return nullptr;
}

std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  // Create a DataParallel tensor layout for the parameter (supports WideDeep).
  auto next_layout = FindParameterNextLayout(node);
  if (next_layout != nullptr) {
    return next_layout;
  }
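  // No layout propagated from any user: fall back to a data-parallel layout over the stage-0 devices.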
  CheckGlobalDeviceManager();
  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
  TensorLayout input_tensor_layout;