|
|
@ -36,7 +36,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
axis_ = 4 - input_shape_.size();
|
|
|
|
axis_ = 4 - input_shape_.size();
|
|
|
|
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) {
|
|
|
|
|
|
|
|
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrReduceScatterFlag);
|
|
|
|
|
|
|
|
}
|
|
|
|
#ifdef ENABLE_MPI
|
|
|
|
#ifdef ENABLE_MPI
|
|
|
|
if (reduce_scatter_flag_) {
|
|
|
|
if (reduce_scatter_flag_) {
|
|
|
|
size_t gatherv2_out_lens = 1;
|
|
|
|
size_t gatherv2_out_lens = 1;
|
|
|
@ -65,7 +67,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true";
|
|
|
|
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
|
|
|
|
|
|
|
|
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
|
|
|
|
|
|
|
|
}
|
|
|
|
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
|
|
|
|
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
|
|
|
|
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
|
|
|
|
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|