|
|
|
@ -799,22 +799,61 @@ public:
|
|
|
|
|
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
|
|
|
|
|
~ParameterUpdater();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief initialize Parameter Updater by GradientMachine.
|
|
|
|
|
* @param gm
|
|
|
|
|
*/
|
|
|
|
|
void init(const GradientMachine& gm);
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief begin of a training/testing of one pass.
|
|
|
|
|
*/
|
|
|
|
|
void startPass();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief end of a traning/testing of one pass.
|
|
|
|
|
*/
|
|
|
|
|
void finishPass();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief begin of a training/testing of one batch.
|
|
|
|
|
* @param data batch's size
|
|
|
|
|
* @return PassType, mostly will be training.
|
|
|
|
|
*/
|
|
|
|
|
PassType startBatch(size_t batchSize);
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief end of a traning/testing of one batch
|
|
|
|
|
* @param cost current batch cost.
|
|
|
|
|
*/
|
|
|
|
|
void finishBatch(float cost);
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief update a parameter (by local optimizer or by cluster pserver)
|
|
|
|
|
* @param param
|
|
|
|
|
*/
|
|
|
|
|
void update(Parameter* param);
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief restore the average parameter.
|
|
|
|
|
* @note It is only used in AverageOptimizer. Restore will get the current
|
|
|
|
|
* PARAMETER_VALUE back.
|
|
|
|
|
*/
|
|
|
|
|
void restore();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief apply. Store the average parameter.
|
|
|
|
|
* @note It is only used in AverageOptimizer. Apply will store the current
|
|
|
|
|
* PARAMETER_VALUE to buffer, calcaualte current Average Parameter, and save
|
|
|
|
|
* it to PARAMETER_VALUE.
|
|
|
|
|
*/
|
|
|
|
|
void apply();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief catchUpWith The Regularization will be delayed in many situations(
|
|
|
|
|
* pserver, local sparse). Catch Up means catch the regularization up, apply
|
|
|
|
|
* regularization to all params.
|
|
|
|
|
*/
|
|
|
|
|
void catchUpWith();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -830,10 +869,21 @@ private:
|
|
|
|
|
public:
|
|
|
|
|
~Evaluator();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief begin an evaluate stage.
|
|
|
|
|
*/
|
|
|
|
|
void start();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief end an evaluate stage.
|
|
|
|
|
*/
|
|
|
|
|
void finish();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief toString will get a evaluate result.
|
|
|
|
|
*
|
|
|
|
|
* __repr__ method in python
|
|
|
|
|
*/
|
|
|
|
|
std::string toString();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|