|
|
|
@ -102,8 +102,7 @@ static void MergeMultipleVarsIntoOneBySection(
|
|
|
|
|
const std::string& out_name, const std::vector<std::string>& out_var_names,
|
|
|
|
|
const std::vector<int>& height_section,
|
|
|
|
|
const std::vector<std::vector<int64_t>>& splited_ids,
|
|
|
|
|
const framework::ExecutionContext& context,
|
|
|
|
|
const framework::Scope& actual_scope, framework::Scope* scope,
|
|
|
|
|
const framework::ExecutionContext& context, framework::Scope* scope,
|
|
|
|
|
platform::DeviceContext* actual_ctx) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
|
|
|
|
|
|
|
|
|
@ -115,9 +114,9 @@ static void MergeMultipleVarsIntoOneBySection(
|
|
|
|
|
id_to_offset[ids_vector[i]].push_back(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto& id_tensor = actual_scope.FindVar(id_name)->Get<framework::LoDTensor>();
|
|
|
|
|
auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
|
|
|
|
|
auto* out_tensor =
|
|
|
|
|
actual_scope.FindVar(out_name)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
scope.FindVar(out_name)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
|
|
|
|
|
|
|
|
|
|
bool is_on_cpu_place = true;
|
|
|
|
@ -175,7 +174,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
|
|
|
|
|
const std::vector<int>& height_sections,
|
|
|
|
|
const framework::ExecutionContext& context,
|
|
|
|
|
const framework::Scope& scope) {
|
|
|
|
|
auto& local_scope = context.scope().NewScope();
|
|
|
|
|
auto& local_scope = scope.NewScope();
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto& cpu_ctx = *pool.Get(platform::CPUPlace());
|
|
|
|
@ -247,8 +246,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
|
|
|
|
|
|
|
|
|
|
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
|
|
|
|
|
out_var_names, height_sections, splited_ids,
|
|
|
|
|
context, scope, &local_scope, &actual_ctx);
|
|
|
|
|
context.scope().DeleteScope(&local_scope);
|
|
|
|
|
context, &local_scope, &actual_ctx);
|
|
|
|
|
scope.DeleteScope(&local_scope);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}; // namespace distributed
|
|
|
|
|