|
|
|
@ -19,9 +19,9 @@ limitations under the License. */
|
|
|
|
|
#include <stdexcept>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/gserver/gradientmachines/GradientMachine.h"
|
|
|
|
|
#include "paddle/utils/Common.h"
|
|
|
|
|
#include "paddle/utils/GlobalConstants.h"
|
|
|
|
|
#include "paddle/gserver/gradientmachines/GradientMachine.h"
|
|
|
|
|
|
|
|
|
|
/// Import PaddlePaddle's enumeration into global namespace.
|
|
|
|
|
using namespace paddle::enumeration_wrapper; // NOLINT
|
|
|
|
@ -470,7 +470,8 @@ private:
|
|
|
|
|
|
|
|
|
|
enum GradientMatchineCreateMode {
|
|
|
|
|
CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
|
|
|
|
|
CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
|
|
|
|
|
CREATE_MODE_SGD_SPARSE_CPU_TRAINING =
|
|
|
|
|
paddle::GradientMachine::kSgdSparseCpuTraining,
|
|
|
|
|
CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -819,7 +820,8 @@ private:
|
|
|
|
|
public:
|
|
|
|
|
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
|
|
|
|
|
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
|
|
|
|
|
int passCount, bool userSparseUpdater);
|
|
|
|
|
int passCount,
|
|
|
|
|
bool userSparseUpdater);
|
|
|
|
|
~ParameterUpdater();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|