|
|
|
@ -180,7 +180,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
|
|
|
|
|
const std::vector<int>& height_sections,
|
|
|
|
|
const framework::ExecutionContext& context,
|
|
|
|
|
const framework::Scope& scope) {
|
|
|
|
|
auto& local_scope = scope.NewScope();
|
|
|
|
|
framework::Scope* local_scope = scope.NewTmpScope();
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto& cpu_ctx = *pool.Get(platform::CPUPlace());
|
|
|
|
@ -224,22 +224,22 @@ void prefetch(const std::string& id_name, const std::string& out_name,
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto splited_ids = SplitIds(ids_vector, height_sections, &local_scope);
|
|
|
|
|
auto splited_ids = SplitIds(ids_vector, height_sections, local_scope);
|
|
|
|
|
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
|
|
|
|
|
&local_scope);
|
|
|
|
|
local_scope);
|
|
|
|
|
|
|
|
|
|
// create output var in local scope
|
|
|
|
|
for (auto& name : out_var_names) {
|
|
|
|
|
local_scope.Var(name)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<distributed::VarHandlePtr> rets;
|
|
|
|
|
for (size_t i = 0; i < in_var_names.size(); i++) {
|
|
|
|
|
if (NeedSend(local_scope, in_var_names[i])) {
|
|
|
|
|
if (NeedSend(*local_scope, in_var_names[i])) {
|
|
|
|
|
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
|
|
|
|
|
<< " to get " << out_var_names[i] << " back";
|
|
|
|
|
rets.push_back(rpc_client->AsyncPrefetchVar(
|
|
|
|
|
epmap[i], cpu_ctx, local_scope, in_var_names[i], out_var_names[i],
|
|
|
|
|
epmap[i], cpu_ctx, *local_scope, in_var_names[i], out_var_names[i],
|
|
|
|
|
table_names[i]));
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
|
|
|
|
@ -252,8 +252,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
|
|
|
|
|
|
|
|
|
|
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
|
|
|
|
|
out_var_names, height_sections, splited_ids,
|
|
|
|
|
context, &local_scope, &actual_ctx);
|
|
|
|
|
scope.DeleteScope(&local_scope);
|
|
|
|
|
context, local_scope, &actual_ctx);
|
|
|
|
|
delete local_scope;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}; // namespace distributed
|
|
|
|
|