|
|
|
@ -179,8 +179,8 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
|
|
|
|
|
const std::string &op_type, float *output) {
|
|
|
|
|
bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
|
|
|
|
|
size_t output_size, const std::string &op_type, float *output) {
|
|
|
|
|
int scatter_index = GetScatterIndex(rank_id_, ranks_group);
|
|
|
|
|
auto group = AddGroup(ranks_group);
|
|
|
|
|
if (group == MPI_GROUP_NULL) {
|
|
|
|
@ -193,7 +193,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MPI_Win window;
|
|
|
|
|
auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
|
|
|
|
|
auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
|
|
|
|
|
if (ret != MPI_SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
|
|
|
|
|
return false;
|
|
|
|
@ -205,18 +205,21 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto op = GetMpiOp(op_type);
|
|
|
|
|
ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
|
|
|
|
|
window);
|
|
|
|
|
ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
|
|
|
|
|
input_data_num, MPI_FLOAT, op, window);
|
|
|
|
|
if (ret != MPI_SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MPI_Win_fence(0, window);
|
|
|
|
|
if (output != nullptr) {
|
|
|
|
|
auto data_size = data_num * sizeof(float);
|
|
|
|
|
auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
|
|
|
|
|
auto data_size = input_data_num * sizeof(float);
|
|
|
|
|
if (output_size < data_size) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
|
|
|
|
|
}
|
|
|
|
|
auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
|
|
|
|
|
if (copy_ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "copy output memory fail!";
|
|
|
|
|
MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MPI_Win_free(&window);
|
|
|
|
@ -224,7 +227,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
|
|
|
|
|
bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
|
|
|
|
|
if (ranks_group.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "input rank group is empty!";
|
|
|
|
|
return false;
|
|
|
|
|