You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
3.8 KiB
127 lines
3.8 KiB
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
|
|
#pragma once
|
|
|
|
#include <stdio.h>
#include <stdlib.h>

#include <fstream>
#include <memory>
#include <string>
#include <utility>

#include "paddle/utils/Util.h"

#include "hl_gpu.h"
#include "paddle/gserver/dataproviders/DataProvider.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"

#include "TrainerConfig.pb.h"
#include "TrainerConfigHelper.h"
#include "ParameterUpdater.h"
|
|
|
|
namespace paddle {
|
|
|
|
/**
|
|
* Configuration for parameter utils.
|
|
*/
|
|
struct ParameterUtilConfig {
|
|
DISABLE_COPY(ParameterUtilConfig);
|
|
|
|
ParameterUtilConfig(bool save_only_one, int saving_period,
|
|
bool load_save_parameters_in_pserver,
|
|
std::string config):
|
|
save_only_one_(save_only_one),
|
|
saving_period_(saving_period),
|
|
load_save_param_pserver_(load_save_parameters_in_pserver),
|
|
config_(config) {
|
|
}
|
|
|
|
bool save_only_one_;
|
|
int saving_period_;
|
|
bool load_save_param_pserver_;
|
|
std::string config_;
|
|
};
|
|
|
|
|
|
/**
|
|
* ParameterUtil
|
|
* Utility class for loading and saving parameters
|
|
*/
|
|
class ParameterUtil {
|
|
public:
|
|
/**
|
|
* Ctor.
|
|
*
|
|
* @param config
|
|
* @param intconfig
|
|
* @param gradientMachine
|
|
* @param parameterUpdater
|
|
* @return
|
|
*/
|
|
ParameterUtil(const std::shared_ptr<TrainerConfigHelper> &config,
|
|
std::unique_ptr<ParameterUtilConfig> &&intconfig,
|
|
const GradientMachinePtr &gradientMachine,
|
|
const std::shared_ptr<ParameterUpdater> ¶meterUpdater);
|
|
|
|
/// Load parameter from the saved parameter file as pass passId
|
|
/// if loadsave_parameters_in_pserver is set, some parameters MUST
|
|
/// load in pserver, which is "remote".
|
|
/// loadParameters can choose to load local/remote parameter, or both.
|
|
bool loadParameters(int passId, bool local = true, bool remote = false);
|
|
|
|
/// load parameters given path info
|
|
void loadParametersWithPath(const std::string& dir, bool local = true,
|
|
bool remote = false);
|
|
|
|
/// Save parameter to dist for pass passId
|
|
/// passInnerId means saving times in one pass, some users want to
|
|
/// save parameters when have processed some batches in one pass
|
|
/// passInnerId = 0 means do not need to save in one inner pass
|
|
void saveParameters(int passId, int passInnerId = 0);
|
|
|
|
/// save parameters for one pass, when passInnerId > 0 means saving
|
|
/// the passInnerId times in one pass
|
|
void saveParametersOnePass(int passId, int passInnerId = 0);
|
|
|
|
/// delete parameter from disk via passId
|
|
void deleteParameters(int passId, int passInnerId = 0);
|
|
|
|
/// save config given path info
|
|
void saveConfigWithPath(const std::string& path);
|
|
|
|
/**
|
|
* Try to load parameter from config.
|
|
* @return true if can load from trainer config.
|
|
*/
|
|
inline bool tryLoadParametersFromConfig() {
|
|
auto& c = config_->getConfig();
|
|
if (!c.init_model_path().empty()) {
|
|
loadParametersWithPath(c.init_model_path());
|
|
return true;
|
|
} else if (c.start_pass() > 0) {
|
|
CHECK(loadParameters(c.start_pass() - 1));
|
|
return true;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<TrainerConfigHelper> config_;
|
|
std::unique_ptr<ParameterUtilConfig> intConfig_;
|
|
GradientMachinePtr gserver_;
|
|
std::shared_ptr<ParameterUpdater> pUpdater_;
|
|
};
|
|
|
|
} // namespace paddle
|