|
|
|
@ -104,10 +104,11 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
|
|
|
|
|
size_t reduce_scatter_out_lens = one_split_lens / 8;
|
|
|
|
|
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
|
|
|
|
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
|
|
|
|
for (int i = 0; i < split_num_; i++) {
|
|
|
|
|
device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
|
|
|
|
|
output_addr + i * reduce_scatter_out_lens, group,
|
|
|
|
|
one_split_lens / 8, "sum");
|
|
|
|
|
mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
|
|
|
|
|
output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|