|
|
|
@ -102,7 +102,7 @@ class FleetDistRunnerBase(object):
|
|
|
|
|
def run_pserver(self, args):
|
|
|
|
|
fleet.init(self.build_role(args))
|
|
|
|
|
strategy = self.build_strategy(args)
|
|
|
|
|
avg_cost = self.net()
|
|
|
|
|
avg_cost = self.net(args)
|
|
|
|
|
self.build_optimizer(avg_cost, strategy)
|
|
|
|
|
|
|
|
|
|
fleet.init_server()
|
|
|
|
@ -111,24 +111,18 @@ class FleetDistRunnerBase(object):
|
|
|
|
|
def run_dataset_trainer(self, args):
|
|
|
|
|
fleet.init(self.build_role(args))
|
|
|
|
|
strategy = self.build_strategy(args)
|
|
|
|
|
avg_cost = self.net()
|
|
|
|
|
avg_cost = self.net(args)
|
|
|
|
|
self.build_optimizer(avg_cost, strategy)
|
|
|
|
|
out = self.do_dataset_training(fleet)
|
|
|
|
|
|
|
|
|
|
def run_pyreader_trainer(self, args):
|
|
|
|
|
fleet.init(self.build_role(args))
|
|
|
|
|
strategy = self.build_strategy(args)
|
|
|
|
|
avg_cost = self.net()
|
|
|
|
|
self.reader = fluid.io.PyReader(
|
|
|
|
|
feed_list=self.feeds,
|
|
|
|
|
capacity=64,
|
|
|
|
|
iterable=False,
|
|
|
|
|
use_double_buffer=False)
|
|
|
|
|
|
|
|
|
|
avg_cost = self.net(args)
|
|
|
|
|
self.build_optimizer(avg_cost, strategy)
|
|
|
|
|
out = self.do_pyreader_training(fleet)
|
|
|
|
|
|
|
|
|
|
def net(self, batch_size=4, lr=0.01):
|
|
|
|
|
def net(self, args, batch_size=4, lr=0.01):
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"get_model should be implemented by child classes.")
|
|
|
|
|
|
|
|
|
|