|
|
|
@ -22,6 +22,7 @@
|
|
|
|
|
#include "mindspore/ccsrc/mindrecord/include/shard_error.h"
|
|
|
|
|
#include "dataset/engine/gnn/local_edge.h"
|
|
|
|
|
#include "dataset/engine/gnn/local_node.h"
|
|
|
|
|
#include "dataset/util/task_manager.h"
|
|
|
|
|
|
|
|
|
|
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
|
|
|
|
|
|
|
|
|
@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() {
|
|
|
|
|
n_feature_maps_.resize(num_workers_);
|
|
|
|
|
e_feature_maps_.resize(num_workers_);
|
|
|
|
|
default_feature_maps_.resize(num_workers_);
|
|
|
|
|
std::vector<std::future<Status>> r_codes(num_workers_);
|
|
|
|
|
TaskGroup vg;
|
|
|
|
|
|
|
|
|
|
shard_reader_ = std::make_unique<ShardReader>();
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
|
|
|
|
@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() {
|
|
|
|
|
|
|
|
|
|
// launching worker threads
|
|
|
|
|
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
|
|
|
|
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id);
|
|
|
|
|
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
|
|
|
|
|
}
|
|
|
|
|
// wait for threads to finish and check its return code
|
|
|
|
|
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
|
|
|
|
RETURN_IF_NOT_OK(r_codes[wkr_id].get());
|
|
|
|
|
}
|
|
|
|
|
vg.join_all(Task::WaitFlag::kBlocking);
|
|
|
|
|
RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GraphLoader::WorkerEntry(int32_t worker_id) {
|
|
|
|
|
// Handshake
|
|
|
|
|
TaskManager::FindMe()->Post();
|
|
|
|
|
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
|
|
|
|
|
while (rows.empty() == false) {
|
|
|
|
|
RETURN_IF_INTERRUPTED();
|
|
|
|
|
for (const auto &tupled_row : rows) {
|
|
|
|
|
std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
|
|
|
|
|
mindrecord::json col_jsn = std::get<1>(tupled_row);
|
|
|
|
|