support RemoteSparseUpdater

feature/design_of_v2_layer_converter
qiaolongfei 8 years ago
parent 6802b65cd2
commit bad503ff08

@ -21,6 +21,7 @@ limitations under the License. */
#include <vector>
#include "paddle/utils/Common.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
/// Import PaddlePaddle's enumeration into global namespace.
using namespace paddle::enumeration_wrapper; // NOLINT
@ -468,9 +469,9 @@ private:
};
enum GradientMatchineCreateMode {
CREATE_MODE_NORMAL = 0,
CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3,
CREATE_MODE_TESTING = 4
CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
};
struct ParameterConfigPrivate;
@ -818,7 +819,7 @@ private:
public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount);
int passCount, bool userSparseUpdater);
~ParameterUpdater();
/**

@ -29,10 +29,19 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
}
ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount) {
OptimizationConfig *config, int passCount, bool userSparseUpdater) {
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr));
auto remoteUpdater = new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr);
if (userSparseUpdater) {
std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
remoteUpdaterPtr.reset(remoteUpdater);
auto sparseRemoteUpdater = new paddle::SparseRemoteParameterUpdaterComposite(
config->m->getConfig(), passCount, false, std::move(remoteUpdaterPtr));
updater->m->updater.reset(sparseRemoteUpdater);
} else {
updater->m->updater.reset(remoteUpdater);
}
return updater;
}

@ -41,9 +41,9 @@ class Optimizer(object):
def create_local_updater(self):
return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
def create_remote_updater(self, pass_num):
def create_remote_updater(self, pass_num, use_sparse_updater):
return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
pass_num)
pass_num, use_sparse_updater)
class Momentum(Optimizer):

@ -97,7 +97,8 @@ class SGD(object):
if self.__is_local__:
updater = self.__optimizer__.create_local_updater()
else:
updater = self.__optimizer__.create_remote_updater(num_passes)
updater = self.__optimizer__.create_remote_updater(num_passes,
self.__use_sparse_updater__)
updater.init(self.__gradient_machine__)
self.__gradient_machine__.start()

Loading…
Cancel
Save