@@ -14,6 +14,7 @@

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
@@ -21,6 +22,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"

const static int max_port = 65535;
@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");

namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle

namespace paddle {
namespace distributed {
@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
  return fut;
}

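// recv_and_save_table pulls every row of one sparse table from the parameter
// servers and saves it locally as a serialized LoDTensor named after the table.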
int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
                                          const std::string &path) {
  // get var information
  std::string var_name = "";
  int64_t var_num = 0;
  int64_t var_shape = 0;
  const auto &worker_param = _config.worker_param().downpour_worker_param();
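  // Scan the worker's table configs for the matching table_id; the accessor
  // supplies the row count (fea_dim) and per-row width (embedx_dim).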
  for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    if (worker_param.downpour_table_param(i).table_id() == table_id) {
      var_name = worker_param.downpour_table_param(i).common().table_name();
      var_num = worker_param.downpour_table_param(i).accessor().fea_dim();
      var_shape = worker_param.downpour_table_param(i).accessor().embedx_dim();
      break;
    }
  }

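  // An empty var_name means table_id never appeared in the worker config.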
  PADDLE_ENFORCE_NE(
      var_name, "",
      platform::errors::InvalidArgument(
          "Cannot find table id %d to save variables.", table_id));

  std::string var_store = string::Sprintf("%s", path);
  MkDirRecursively(var_store.c_str());

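  // The whole table is materialized on the client: save_huge_vec holds
  // var_num * var_shape floats before anything is written to disk.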
  // pull sparse from server
  std::vector<float> save_huge_vec(var_num * var_shape);
  std::vector<uint64_t> save_key(var_num);
  std::vector<float *> save_vec;
  for (size_t i = 0; i < save_key.size(); ++i) {
    save_key[i] = i;
    save_vec.push_back(save_huge_vec.data() + i * var_shape);
  }

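  // Keys 0 .. var_num - 1 cover the full table; pull_sparse writes the
  // var_shape floats for key i into save_vec[i], a slice of save_huge_vec,
  // and wait() blocks until the RPC completes.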
  auto status = pull_sparse((float **)save_vec.data(), table_id,
                            save_key.data(), save_key.size());
  status.wait();

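  // A temporary Scope owns the output variable; the CPU DeviceContext is only
  // needed by SerializeToStream below.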
  // create lod tensor
  std::shared_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  auto place = platform::CPUPlace();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);

  framework::Variable *var = scope->Var(var_name);
  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();

  std::vector<int64_t> vec_dim = {var_num, var_shape};
  var_tensor->Resize(framework::make_ddim(vec_dim));

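  // The tensor shape [var_num, var_shape] matches the row-major layout of
  // save_huge_vec, so a single memcpy transfers the whole table.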
  // copy and save
  float *tensor_data = var_tensor->mutable_data<float>(place);
  memcpy(tensor_data, save_huge_vec.data(),
         var_num * var_shape * sizeof(float));

  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
  std::ofstream fout(file_name, std::ios::binary);
  PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
                    platform::errors::Unavailable(
                        "Cannot open %s to save variables.", file_name));

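  // SerializeToStream writes the tensor in Paddle's binary LoDTensor format,
  // so the file can later be loaded with framework::DeserializeFromStream.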
  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
  fout.close();

  return 0;
}

} // namespace distributed
} // namespace paddle