You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
133 lines
4.6 KiB
133 lines
4.6 KiB
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "paddle/fluid/distributed/table/tensor_table.h"
|
|
#include <chrono> // NOLINT
|
|
#include <map>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
#include "paddle/fluid/distributed/common/utils.h"
|
|
DECLARE_double(eager_delete_tensor_gb);
|
|
namespace paddle {
|
|
namespace distributed {
|
|
|
|
int32_t TensorTable::set_program_env(
|
|
framework::Scope *scope, platform::Place place,
|
|
const std::vector<framework::ProgramDesc> *sub_program) {
|
|
scope_ = scope;
|
|
place_ = place;
|
|
executor_ = new framework::Executor(place_);
|
|
sub_program_ = sub_program;
|
|
return 0;
|
|
}
|
|
|
|
int32_t GlobalStepTable::initialize() {
|
|
auto _program_config = _config.tensor();
|
|
auto trainers_ = _config.common().trainer_num();
|
|
FLAGS_eager_delete_tensor_gb = -1;
|
|
// Get Config
|
|
if (_program_config.has_startup_program_id()) {
|
|
startup_program_id_ = _program_config.startup_program_id();
|
|
}
|
|
if (_program_config.has_main_program_id()) {
|
|
main_program_id_ = _program_config.main_program_id();
|
|
}
|
|
if (_program_config.has_feed_var_name()) {
|
|
feed_var_name_ = _program_config.feed_var_name();
|
|
}
|
|
if (_program_config.has_fetch_var_name()) {
|
|
fetch_var_name_ = _program_config.fetch_var_name();
|
|
}
|
|
|
|
// Run startup program
|
|
if (startup_program_id_ != -1) {
|
|
std::map<std::string, const framework::LoDTensor *> fake_feed;
|
|
std::map<std::string, framework::FetchType *> fake_fetch;
|
|
auto startup_program_desc = sub_program_->at(startup_program_id_);
|
|
auto ctx = executor_->Prepare(startup_program_desc, 0);
|
|
executor_->RunPreparedContext(ctx.get(), scope_, false);
|
|
}
|
|
|
|
if (main_program_id_ != -1) {
|
|
// Run main porgram, if program is used for learning decay
|
|
auto main_program_desc = sub_program_->at(main_program_id_);
|
|
auto main_ctx = executor_->Prepare(main_program_desc, 0);
|
|
exec_context_ = std::move(main_ctx);
|
|
executor_->RunPreparedContext(exec_context_.get(), scope_, false);
|
|
// init decay_counters
|
|
decay_counters_.reserve(trainers_);
|
|
for (int32_t i = 0; i < trainers_; ++i) {
|
|
decay_counters_[i] = 0;
|
|
}
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
int32_t GlobalStepTable::set_table_map(
|
|
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) {
|
|
auto *lr_var = scope_->FindVar(fetch_var_name_);
|
|
auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
|
|
auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
|
|
VLOG(3) << "GlobalStepTable::set_table_map set global lr: " << *lr_value;
|
|
|
|
for (auto iter = table_map->begin(); iter != table_map->end(); iter++) {
|
|
auto table_id = iter->first;
|
|
if (table_id == _config.table_id()) {
|
|
continue;
|
|
}
|
|
iter->second->set_global_lr(lr_value);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int32_t GlobalStepTable::push_dense(const int64_t *values,
|
|
const int32_t trainer_id) {
|
|
return _run_program(values, trainer_id);
|
|
}
|
|
|
|
int32_t GlobalStepTable::_run_program(const int64_t *values,
|
|
const uint32_t trainer_id) {
|
|
FLAGS_eager_delete_tensor_gb = -1;
|
|
auto counter = decay_counters_.at(trainer_id);
|
|
counter += int(values[0]);
|
|
decay_counters_.at(trainer_id) = counter;
|
|
|
|
auto *global_step_var = scope_->FindVar(feed_var_name_);
|
|
auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
|
|
auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
|
|
|
|
auto global_counter = 0;
|
|
for (auto &trainer_counter : decay_counters_) {
|
|
global_counter += trainer_counter.second;
|
|
}
|
|
|
|
// Todo: hard code for increment op
|
|
value[0] = global_counter - 1;
|
|
VLOG(3) << "GlobalStepTable::_run_program global_counter " << value[0];
|
|
|
|
executor_->RunPreparedContext(exec_context_.get(), scope_, false, false);
|
|
auto *lr_var = scope_->FindVar(fetch_var_name_);
|
|
auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
|
|
auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
|
|
VLOG(3) << "GlobalStepTable::LR value: " << lr_value[0];
|
|
return 0;
|
|
}
|
|
|
|
} // namespace distributed
|
|
} // namespace paddle
|