|
|
@ -50,9 +50,11 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
|
|
|
|
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
|
|
|
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
|
|
|
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
|
|
|
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
|
|
|
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
|
|
|
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
|
|
|
|
|
|
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
|
|
|
for (int i = 0; i < split_num_; i++) {
|
|
|
|
for (int i = 0; i < split_num_; i++) {
|
|
|
|
device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
|
|
|
|
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
|
|
|
|
output_addr + i * output_split_lens, rank_group, input_split_lens);
|
|
|
|
input_split_lens);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#if defined(_WIN32) || defined(_WIN64)
|
|
|
|
#if defined(_WIN32) || defined(_WIN64)
|
|
|
|
auto end_time = std::chrono::steady_clock::now();
|
|
|
|
auto end_time = std::chrono::steady_clock::now();
|
|
|
|