|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
#pragma once
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
#include "ParameterOptimizer.h"
|
|
|
|
#include "ParameterOptimizer.h"
|
|
|
|
|
|
|
|
#include "ParameterUpdateFunctions.h"
|
|
|
|
#include "Regularizer.h"
|
|
|
|
#include "Regularizer.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
@ -37,6 +38,15 @@ public:
|
|
|
|
real torch_learningRate = optConfig_.learning_method() == "torch_momentum"
|
|
|
|
real torch_learningRate = optConfig_.learning_method() == "torch_momentum"
|
|
|
|
? 1.0 - paraConfig.momentum()
|
|
|
|
? 1.0 - paraConfig.momentum()
|
|
|
|
: 1.0;
|
|
|
|
: 1.0;
|
|
|
|
|
|
|
|
#ifdef PADDLE_USE_MKLDNN
|
|
|
|
|
|
|
|
sgdUpdate(learningRate_ * paraConfig.learning_rate() *
|
|
|
|
|
|
|
|
(firstTime_ ? 1.0 : torch_learningRate),
|
|
|
|
|
|
|
|
paraConfig.momentum(),
|
|
|
|
|
|
|
|
applyDecay_ ? paraConfig.decay_rate() : 0,
|
|
|
|
|
|
|
|
vecs[PARAMETER_VALUE].get(),
|
|
|
|
|
|
|
|
vecs[PARAMETER_GRADIENT].get(),
|
|
|
|
|
|
|
|
vecs[PARAMETER_MOMENTUM].get());
|
|
|
|
|
|
|
|
#else
|
|
|
|
vecs[PARAMETER_VALUE]->sgdUpdate(
|
|
|
|
vecs[PARAMETER_VALUE]->sgdUpdate(
|
|
|
|
*vecs[PARAMETER_GRADIENT],
|
|
|
|
*vecs[PARAMETER_GRADIENT],
|
|
|
|
*vecs[PARAMETER_MOMENTUM],
|
|
|
|
*vecs[PARAMETER_MOMENTUM],
|
|
|
@ -44,6 +54,7 @@ public:
|
|
|
|
(firstTime_ ? 1.0 : torch_learningRate),
|
|
|
|
(firstTime_ ? 1.0 : torch_learningRate),
|
|
|
|
paraConfig.momentum(),
|
|
|
|
paraConfig.momentum(),
|
|
|
|
applyDecay_ ? paraConfig.decay_rate() : 0);
|
|
|
|
applyDecay_ ? paraConfig.decay_rate() : 0);
|
|
|
|
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Clears the firstTime_ flag. The update code in this class reads that flag
// to pick between a factor of 1.0 and torch_learningRate.
virtual void finishBatch() {
  firstTime_ = false;
}
|
|
|
|
// Clears the firstTime_ flag. The update code in this class reads that flag
// to pick between a factor of 1.0 and torch_learningRate.
virtual void finishBatch() {
  firstTime_ = false;
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|