|
|
|
@ -37,6 +37,15 @@ public:
|
|
|
|
|
real torch_learningRate = optConfig_.learning_method() == "torch_momentum"
|
|
|
|
|
? 1.0 - paraConfig.momentum()
|
|
|
|
|
: 1.0;
|
|
|
|
|
#ifdef PADDLE_USE_MKLDNN
|
|
|
|
|
vecs[PARAMETER_VALUE]->sgdUpdateWithOMP(
|
|
|
|
|
*vecs[PARAMETER_GRADIENT],
|
|
|
|
|
*vecs[PARAMETER_MOMENTUM],
|
|
|
|
|
learningRate_ * paraConfig.learning_rate() *
|
|
|
|
|
(firstTime_ ? 1.0 : torch_learningRate),
|
|
|
|
|
paraConfig.momentum(),
|
|
|
|
|
applyDecay_ ? paraConfig.decay_rate() : 0);
|
|
|
|
|
#else
|
|
|
|
|
vecs[PARAMETER_VALUE]->sgdUpdate(
|
|
|
|
|
*vecs[PARAMETER_GRADIENT],
|
|
|
|
|
*vecs[PARAMETER_MOMENTUM],
|
|
|
|
@ -44,6 +53,7 @@ public:
|
|
|
|
|
(firstTime_ ? 1.0 : torch_learningRate),
|
|
|
|
|
paraConfig.momentum(),
|
|
|
|
|
applyDecay_ ? paraConfig.decay_rate() : 0);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
virtual void finishBatch() { firstTime_ = false; }
|
|
|
|
|
};
|
|
|
|
|