|
|
@ -15,15 +15,27 @@ limitations under the License. */
|
|
|
|
#include "PaddleAPI.h"
|
|
|
|
#include "PaddleAPI.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include "PaddleAPIPrivate.h"
|
|
|
|
#include "PaddleAPIPrivate.h"
|
|
|
|
|
|
|
|
#include "paddle/trainer/RemoteParameterUpdater.h"
|
|
|
|
#include "paddle/trainer/ThreadParameterUpdater.h"
|
|
|
|
#include "paddle/trainer/ThreadParameterUpdater.h"
|
|
|
|
|
|
|
|
|
|
|
|
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
|
|
|
|
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
|
|
|
|
|
|
|
|
|
|
|
|
ParameterUpdater *ParameterUpdater::createLocalUpdater(
|
|
|
|
ParameterUpdater *ParameterUpdater::createLocalUpdater(
|
|
|
|
OptimizationConfig *config) {
|
|
|
|
OptimizationConfig *config) {
|
|
|
|
auto param = new ParameterUpdater();
|
|
|
|
auto updater = new ParameterUpdater();
|
|
|
|
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
|
|
|
|
updater->m->updater.reset(
|
|
|
|
return param;
|
|
|
|
new paddle::SgdThreadUpdater(config->m->getConfig()));
|
|
|
|
|
|
|
|
return updater;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ParameterUpdater *ParameterUpdater::createRemoteUpdater(
|
|
|
|
|
|
|
|
OptimizationConfig *config, int passCount) {
|
|
|
|
|
|
|
|
auto updater = new ParameterUpdater();
|
|
|
|
|
|
|
|
std::unique_ptr<paddle::ParameterUpdater> localUpdater;
|
|
|
|
|
|
|
|
localUpdater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
|
|
|
|
|
|
|
|
updater->m->updater.reset(new paddle::ConcurrentRemoteParameterUpdater(
|
|
|
|
|
|
|
|
config->m->getConfig(), passCount, std::move(localUpdater)));
|
|
|
|
|
|
|
|
return updater;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ParameterUpdater::~ParameterUpdater() { delete m; }
|
|
|
|
ParameterUpdater::~ParameterUpdater() { delete m; }
|
|
|
|