|
|
|
@ -21,6 +21,7 @@ limitations under the License. */
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/utils/Common.h"
|
|
|
|
|
#include "paddle/utils/GlobalConstants.h"
|
|
|
|
|
#include "paddle/gserver/gradientmachines/GradientMachine.h"
|
|
|
|
|
|
|
|
|
|
/// Import PaddlePaddle's enumeration into global namespace.
|
|
|
|
|
using namespace paddle::enumeration_wrapper; // NOLINT
|
|
|
|
@ -468,9 +469,9 @@ private:
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
enum GradientMatchineCreateMode {
|
|
|
|
|
CREATE_MODE_NORMAL = 0,
|
|
|
|
|
CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3,
|
|
|
|
|
CREATE_MODE_TESTING = 4
|
|
|
|
|
CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
|
|
|
|
|
CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
|
|
|
|
|
CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ParameterConfigPrivate;
|
|
|
|
@ -818,7 +819,7 @@ private:
|
|
|
|
|
public:
|
|
|
|
|
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
|
|
|
|
|
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
|
|
|
|
|
int passCount);
|
|
|
|
|
int passCount, bool userSparseUpdater);
|
|
|
|
|
~ParameterUpdater();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|