@@ -12,6 +8,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
 
+#include <queue>
 
 #ifdef PADDLE_WITH_PSLIB
 
 namespace paddle {
 namespace framework {
@@ -182,53 +184,105 @@ void HeterComm<KeyType, ValType, GradType>::create_storage(
 }
 
 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
-                                                         int end_index,
-                                                         char* src_key,
-                                                         char* src_val) {
+void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
+    int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
+    GradType* src_val) {
   int need_copy_val = 0;
   if (src_val) {
     need_copy_val = 1;
   }
-  auto& nodes = path_[start_index][end_index].nodes_;
-  for (size_t i = 0; i < nodes.size(); ++i) {
-    cudaMemcpyAsync(nodes[i].key_storage, src_key, nodes[i].key_bytes_len,
-                    cudaMemcpyDefault, nodes[i].in_stream);
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    int size = path_[start_index][i].nodes_.size();
+    auto& node = path_[start_index][i].nodes_[0];
+    CopyTask t(&path_[start_index][i], 0);
+    que.push(t);
+    cudaMemcpyAsync(node.key_storage,
+                    reinterpret_cast<char*>(src_key + h_left[i]),
+                    node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
     if (need_copy_val) {
-      cudaMemcpyAsync(nodes[i].val_storage, src_val, nodes[i].val_bytes_len,
-                      cudaMemcpyDefault, nodes[i].in_stream);
+      cudaMemcpyAsync(node.val_storage,
+                      reinterpret_cast<char*>(src_val + h_left[i]),
+                      node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
     }
+  }
+  while (!que.empty()) {
+    CopyTask& cur_task = que.front();
+    que.pop();
+    if (cur_task.path->nodes_[cur_task.step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
+    }
+    if (cur_task.step != cur_task.path->nodes_.size() - 1) {
+      int cur_step = cur_task.step;
+      CopyTask c(cur_task.path, cur_step + 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
+                      cur_task.path->nodes_[cur_step].key_storage,
+                      cur_task.path->nodes_[cur_step + 1].key_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step + 1].in_stream);
+      if (need_copy_val) {
+        cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
+                        cur_task.path->nodes_[cur_step].val_storage,
+                        cur_task.path->nodes_[cur_step + 1].val_bytes_len,
+                        cudaMemcpyDefault,
+                        cur_task.path->nodes_[cur_step + 1].in_stream);
+      }
-    if (nodes[i].sync) {
-      cudaStreamSynchronize(nodes[i].in_stream);
     }
-    // cudaStreamSynchronize(nodes[i].in_stream);
-    src_key = nodes[i].key_storage;
-    src_val = nodes[i].val_storage;
   }
 }
 
 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
-                                                        int end_index,
-                                                        char* src_val) {
-  auto& nodes = path_[start_index][end_index].nodes_;
-  int len = nodes.size();
-  char* start = NULL;
-  for (int i = len - 1; i >= 0; --i) {
-    if (start == NULL) {
-      start = nodes[i].val_storage;
+void HeterComm<KeyType, ValType, GradType>::walk_to_src(
+    int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val) {
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
       continue;
     }
-    cudaMemcpyAsync(nodes[i].val_storage, start, nodes[i].val_bytes_len,
-                    cudaMemcpyDefault, nodes[i].out_stream);
-    if (nodes[i].sync) {
-      cudaStreamSynchronize(nodes[i].out_stream);
+    int cur_step = path_[start_index][i].nodes_.size() - 1;
+    auto& node = path_[start_index][i].nodes_[cur_step];
+    if (cur_step == 0) {
+      cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[i]),
+                      node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
+                      node.out_stream);
+    } else {
+      CopyTask t(&path_[start_index][i], cur_step - 1);
+      que.push(t);
+      cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
+                      node.val_storage,
+                      path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      path_[start_index][i].nodes_[cur_step - 1].out_stream);
    }
+  }
+  while (!que.empty()) {
+    CopyTask& cur_task = que.front();
+    que.pop();
+    int cur_step = cur_task.step;
+    if (cur_task.path->nodes_[cur_step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
+    }
+    if (cur_step > 0) {
+      CopyTask c(cur_task.path, cur_step - 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step - 1].out_stream);
+    } else if (cur_step == 0) {
+      int end_index = cur_task.path->nodes_.back().gpu_num;
+      cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[end_index]),
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step].out_stream);
+    }
-    start = nodes[i].val_storage;
   }
-  cudaMemcpyAsync(src_val, nodes[0].val_storage, nodes[0].val_bytes_len,
-                  cudaMemcpyDefault, nodes[0].out_stream);
-  // cudaStreamSynchronize(nodes[0].out_stream);
 }
 
 template <typename KeyType, typename ValType, typename GradType>
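
The two rewritten walkers above share one scheduling idea: issue the first asynchronous hop for every destination path, queue a CopyTask per path, then keep popping tasks and launching the next hop, so transfers on different paths proceed breadth-first instead of each path being walked to completion before the next one starts. The CPU-only sketch below illustrates that pattern; SimpleNode, SimplePath, SimpleTask, walk_all_paths, and the use of std::memcpy in place of cudaMemcpyAsync on per-node streams are illustrative stand-ins, not the real HeterComm types.

#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>

// Illustrative stand-ins for HeterComm's Node/Path/CopyTask; std::memcpy
// replaces cudaMemcpyAsync so the sketch runs without a GPU.
struct SimpleNode {
  std::vector<char> storage;  // staging buffer owned by this hop
};
struct SimplePath {
  std::vector<SimpleNode> nodes_;  // hops from the source towards one destination
};
struct SimpleTask {
  SimplePath* path;
  size_t step;  // index of the hop that has already been issued
};

// Breadth-first walk over all paths, mirroring the queue in walk_to_dest:
// first hop of every path up front, then one further hop per popped task.
void walk_all_paths(std::vector<SimplePath>& paths, const char* src, size_t len) {
  std::queue<SimpleTask> que;
  for (auto& p : paths) {
    if (p.nodes_.empty()) continue;  // analogous to skipping h_left[i] == -1
    std::memcpy(p.nodes_[0].storage.data(), src, len);
    que.push(SimpleTask{&p, 0});
  }
  while (!que.empty()) {
    SimpleTask cur = que.front();
    que.pop();
    if (cur.step + 1 < cur.path->nodes_.size()) {
      // Forward the payload from this hop's staging buffer to the next hop.
      std::memcpy(cur.path->nodes_[cur.step + 1].storage.data(),
                  cur.path->nodes_[cur.step].storage.data(), len);
      que.push(SimpleTask{cur.path, cur.step + 1});
    }
  }
}

int main() {
  const char msg[] = "payload";
  std::vector<SimplePath> paths(2);
  paths[0].nodes_.assign(1, SimpleNode{std::vector<char>(sizeof(msg))});  // direct path
  paths[1].nodes_.assign(3, SimpleNode{std::vector<char>(sizeof(msg))});  // relayed path
  walk_all_paths(paths, msg, sizeof(msg));
  std::printf("%s %s\n", paths[0].nodes_.back().storage.data(),
              paths[1].nodes_.back().storage.data());
  return 0;
}
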
@@ -462,14 +516,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
                    shard_len * sizeof(ValType), local_storage);
   }
 
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_right[i] - h_left[i] + 1;
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    walk_to_dest(num, i, reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
-                 NULL);
-  }
+  walk_to_dest(num, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
 
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
@@ -486,14 +533,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     cudaStreamSynchronize(resource_->remote_stream(i));
   }
 
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_right[i] - h_left[i] + 1;
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    platform::CUDADeviceGuard guard(resource_->dev_id(i));
-    walk_to_src(num, i, reinterpret_cast<char*>(d_shard_vals_ptr + h_left[i]));
-  }
+  walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
 
   for (int i = 0; i < total_gpu; ++i) {
     auto& node = path_[num][i].nodes_.front();
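
With the per-GPU loops above collapsed into single batched walk_to_dest/walk_to_src calls, the shard boundaries now travel through the h_left/h_right arrays: entry i holds the first and last index of GPU i's shard in the sorted key/value buffers, or -1 when that GPU owns no keys, which is why the walkers skip those entries. The small sketch below shows that bookkeeping; fill_shard_bounds and the precomputed owner_gpu array are hypothetical helpers for illustration, not the actual sharding logic in HeterComm.

#include <cstdio>
#include <vector>

// Hypothetical helper: given key indices already grouped by owning GPU,
// record the first (h_left) and last (h_right) index of each GPU's shard,
// or -1 when a GPU owns no keys; the walkers skip the -1 entries and copy
// the ranges src + h_left[i] .. src + h_right[i] for the rest.
void fill_shard_bounds(const std::vector<int>& owner_gpu, int gpu_num,
                       std::vector<int>* h_left, std::vector<int>* h_right) {
  h_left->assign(gpu_num, -1);
  h_right->assign(gpu_num, -1);
  for (int idx = 0; idx < static_cast<int>(owner_gpu.size()); ++idx) {
    int gpu = owner_gpu[idx];
    if ((*h_left)[gpu] == -1) (*h_left)[gpu] = idx;
    (*h_right)[gpu] = idx;
  }
}

int main() {
  // Keys sorted by owning GPU: GPU 0 owns indices 0..2, GPU 2 owns 3..4,
  // GPU 1 owns nothing and is skipped by the walkers.
  std::vector<int> owner_gpu = {0, 0, 0, 2, 2};
  std::vector<int> h_left, h_right;
  fill_shard_bounds(owner_gpu, 3, &h_left, &h_right);
  for (int i = 0; i < 3; ++i) {
    std::printf("gpu %d: h_left=%d h_right=%d\n", i, h_left[i], h_right[i]);
  }
  return 0;
}
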
@@ -561,7 +601,6 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
              cudaMemcpyDeviceToHost);
 
   std::vector<std::shared_ptr<memory::Allocation>> local_storage;
-
   for (int i = 0; i < total_gpu; ++i) {
     int shard_len = h_right[i] - h_left[i] + 1;
     if (h_left[i] == -1 || h_right[i] == -1) {
@@ -571,15 +610,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
                    shard_len * sizeof(GradType), local_storage);
   }
 
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_right[i] - h_left[i] + 1;
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    walk_to_dest(gpu_num, i,
-                 reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
-                 reinterpret_cast<char*>(d_shard_grads_ptr + h_left[i]));
-  }
+  walk_to_dest(gpu_num, total_gpu, h_left, h_right, d_shard_keys_ptr,
+               d_shard_grads_ptr);
 
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1 || h_right[i] == -1) {