|
|
|
@ -579,8 +579,40 @@ void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad,
|
|
|
|
|
size_t outer_dim, std::vector<std::pair<int, size_t>> *sorted_indices,
|
|
|
|
|
std::vector<size_t> *slice_positions) {
|
|
|
|
|
MS_LOG(DEBUG) << "Start";
|
|
|
|
|
size_t thread_num = 24;
|
|
|
|
|
if (slice_positions->size() < thread_num) {
|
|
|
|
|
thread_num = slice_positions->size();
|
|
|
|
|
}
|
|
|
|
|
size_t stride = (slice_positions->size() + thread_num - 1) / thread_num;
|
|
|
|
|
thread_num = (slice_positions->size() + stride - 1) / stride;
|
|
|
|
|
std::vector<std::thread> threads;
|
|
|
|
|
size_t max_length = sorted_indices->size() * outer_dim;
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
size_t slice_start = i * stride;
|
|
|
|
|
size_t slice_end = 0;
|
|
|
|
|
if (i == thread_num - 1) {
|
|
|
|
|
slice_end = slice_positions->size();
|
|
|
|
|
} else {
|
|
|
|
|
slice_end = slice_start + stride;
|
|
|
|
|
}
|
|
|
|
|
WorkerParamsForReduceSparseGradient params{
|
|
|
|
|
slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_,
|
|
|
|
|
unique_grad};
|
|
|
|
|
threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
threads[i].join();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "End";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
|
|
|
|
|
size_t outer_dim) {
|
|
|
|
|
size_t outer_dim, bool use_multi_threads) {
|
|
|
|
|
MS_LOG(DEBUG) << "Start";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(unique_grad);
|
|
|
|
@ -599,42 +631,35 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie
|
|
|
|
|
[](const std::pair<int, size_t> &left, const std::pair<int, size_t> &right) { return left.first < right.first; });
|
|
|
|
|
int last_index = 0;
|
|
|
|
|
std::vector<size_t> slice_positions;
|
|
|
|
|
slice_positions.reserve(sorted_indices.size());
|
|
|
|
|
for (size_t i = 0; i < sorted_indices.size(); ++i) {
|
|
|
|
|
if (i == 0 || last_index != sorted_indices[i].first) {
|
|
|
|
|
slice_positions.emplace_back(i);
|
|
|
|
|
}
|
|
|
|
|
last_index = sorted_indices[i].first;
|
|
|
|
|
}
|
|
|
|
|
size_t thread_num = 8;
|
|
|
|
|
if (slice_positions.size() < thread_num) {
|
|
|
|
|
thread_num = slice_positions.size();
|
|
|
|
|
}
|
|
|
|
|
size_t stride = (slice_positions.size() + thread_num - 1) / thread_num;
|
|
|
|
|
thread_num = (slice_positions.size() + stride - 1) / stride;
|
|
|
|
|
std::vector<std::thread> threads;
|
|
|
|
|
size_t max_length = sorted_indices.size() * outer_dim;
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
size_t slice_start = i * stride;
|
|
|
|
|
size_t slice_end = 0;
|
|
|
|
|
if (i == thread_num - 1) {
|
|
|
|
|
slice_end = slice_positions.size();
|
|
|
|
|
} else {
|
|
|
|
|
slice_end = slice_start + stride;
|
|
|
|
|
}
|
|
|
|
|
WorkerParamsForReduceSparseGradient params{
|
|
|
|
|
slice_start, slice_end, max_length, outer_dim, &sorted_indices, &slice_positions, origin_sparse_grad.value_,
|
|
|
|
|
unique_grad};
|
|
|
|
|
threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
threads[i].join();
|
|
|
|
|
if (use_multi_threads) {
|
|
|
|
|
RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions);
|
|
|
|
|
} else {
|
|
|
|
|
size_t max_length = sorted_indices.size() * outer_dim;
|
|
|
|
|
WorkerParamsForReduceSparseGradient params{0,
|
|
|
|
|
slice_positions.size(),
|
|
|
|
|
max_length,
|
|
|
|
|
outer_dim,
|
|
|
|
|
&sorted_indices,
|
|
|
|
|
&slice_positions,
|
|
|
|
|
origin_sparse_grad.value_,
|
|
|
|
|
unique_grad};
|
|
|
|
|
WorkerForReduceSparseGradient(params);
|
|
|
|
|
}
|
|
|
|
|
unique_grad->indices_size_ = slice_positions.size();
|
|
|
|
|
MS_LOG(DEBUG) << "End";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReduceMultiSparseGradient(const std::vector<std::shared_ptr<SparseGradient>> &unique_slice_grads,
|
|
|
|
|
SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim,
|
|
|
|
|
size_t outer_dim) {
|
|
|
|
|
MS_LOG(DEBUG) << "Start";
|
|
|
|
|
if (unique_slice_grads.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -658,10 +683,12 @@ void ReduceMultiSparseGradient(const std::vector<std::shared_ptr<SparseGradient>
|
|
|
|
|
}
|
|
|
|
|
tmp_grad->indices_size_ = unique_indices_size;
|
|
|
|
|
ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim);
|
|
|
|
|
MS_LOG(DEBUG) << "End";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad,
|
|
|
|
|
SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) {
|
|
|
|
|
MS_LOG(DEBUG) << "Start";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(unique_grad);
|
|
|
|
@ -693,12 +720,13 @@ void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, Spar
|
|
|
|
|
unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset;
|
|
|
|
|
unique_slice_grads[i]->indices_size_ = indices_size;
|
|
|
|
|
threads.emplace_back(
|
|
|
|
|
std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim));
|
|
|
|
|
std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
threads[i].join();
|
|
|
|
|
}
|
|
|
|
|
ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim);
|
|
|
|
|
MS_LOG(DEBUG) << "End";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
|
|
|
|
|