You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
7.2 KiB
205 lines
7.2 KiB
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import sys
|
|
import signal
|
|
import subprocess
|
|
import argparse
|
|
import time
|
|
import math
|
|
import random
|
|
from multiprocessing import Process
|
|
from functools import reduce
|
|
|
|
import numpy as np
|
|
import pickle
|
|
import unittest
|
|
import six
|
|
|
|
import paddle
|
|
import paddle.fluid as fluid
|
|
from paddle.fluid import core
|
|
from paddle.fluid import io
|
|
|
|
from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
|
|
from dist_simnet_bow import TestDistSimnetBow2x2, DATA_URL, DATA_MD5
|
|
|
|
|
|
class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
    """Simnet-bow distributed runner that saves and reloads persistables.

    The behaviour is selected through environment variables read at run time:

    * ``SAVE`` / ``LOAD`` -- integer flags (default ``"0"``) enabling the save
      step in the trainer and the load step in the parameter server.
    * ``MODEL_DIR`` -- directory used for saving/loading (default ``""``,
      which disables both).
    * ``SAVE_MODE`` -- ``"LOCAL"`` or ``"DIST"``; anything else makes
      ``run_trainer`` raise.
    * ``SKIP_STEPS`` -- integer step index, required when ``SAVE_MODE`` is
      ``"DIST"``.
    """

    def _load_persistable_vars(self, executor, dirname, program):
        """Load the checkpoint-eligible variables of ``program`` from disk.

        :param executor: executor passed through to ``io.load_vars``.
        :param dirname: directory the variables are read from.
        :param program: program whose variables are filtered and loaded.
        """

        def _is_checkpoint_var(var):
            """
            Tell whether ``var`` belongs to a checkpoint.

            Variables whose type is FEED_MINIBATCH/FETCH_LIST/RAW are
            discarded, and so are variables whose name contains "@GRAD",
            ".trainer_", ".block" or "tmp_". Every other variable is kept
            only when it is persistable.

            :param var: the Variable to inspect.
            :return: False for discarded variables, else ``var.persistable``.
            """
            skipped_types = (
                core.VarDesc.VarType.FEED_MINIBATCH,
                core.VarDesc.VarType.FETCH_LIST,
                core.VarDesc.VarType.RAW,
            )
            if var.desc.type() in skipped_types:
                return False

            # "@GRAD" marks gradient variables; ".trainer_" and ".block" mark
            # variables created for distributed training; "tmp_" is filtered
            # as well. None of them are saved in a checkpoint.
            skipped_markers = ("@GRAD", ".trainer_", ".block", "tmp_")
            if any(marker in var.name for marker in skipped_markers):
                return False

            return var.persistable

        io.load_vars(
            executor,
            dirname=dirname,
            main_program=program,
            predicate=_is_checkpoint_var,
            filename=None)

    def run_pserver(self, args):
        """Build, optionally restore, and run the parameter-server program.

        The startup program is executed on CPU first. When ``LOAD`` is
        non-zero and ``MODEL_DIR`` is non-empty, persistables are loaded from
        that directory before the pserver program itself is run.

        :param args: parsed runner arguments; reads ``trainer_id``,
            ``endpoints``, ``trainers``, ``sync_mode`` and
            ``current_endpoint``.
        """
        self.get_model(batch_size=2)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, False,
                                args.current_endpoint)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        need_load = bool(int(os.getenv("LOAD", "0")))
        model_dir = os.getenv("MODEL_DIR", "")

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(startup_prog)

        if need_load and model_dir:
            io.load_persistables(exe, model_dir, pserver_prog)

        exe.run(pserver_prog)

    def run_trainer(self, args):
        """Train for a few steps, optionally save, and emit a pickled result.

        With ``SAVE_MODE=LOCAL`` the trainer runs ``RUN_STEP`` steps and saves
        persistables when ``SAVE`` is set, then writes the flattened
        ``__fc_b__`` tensor to stdout. With ``SAVE_MODE=DIST`` it iterates over
        8 batches: when saving, every batch is trained on and trainer 0 saves
        after step ``SKIP_STEPS``; otherwise batches up to and including
        ``SKIP_STEPS`` are consumed but not trained on. The last fetched loss
        is written to stdout.

        :param args: parsed runner arguments; reads ``update_method``,
            ``trainer_id``, ``endpoints``, ``trainers``, ``sync_mode``,
            ``use_cuda``, ``use_reduce`` and ``use_reader_alloc``.
        :raises ValueError: if ``SAVE_MODE`` is neither "LOCAL" nor "DIST", if
            ``SKIP_STEPS`` is unset in "DIST" mode, or if ``SKIP_STEPS`` skips
            every step so that there is no loss to report.
        """
        # Only the loss variable and the training reader are used here.
        _, avg_cost, train_reader, _, _, _ = self.get_model(batch_size=2)

        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode)
            trainer_prog = t.get_trainer_program()
        else:
            trainer_prog = fluid.default_main_program()

        place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()

        startup_exe = fluid.Executor(place)
        startup_exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1

        build_stra = fluid.BuildStrategy()
        if args.use_reduce:
            build_stra.reduce_strategy = \
                fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = \
                fluid.BuildStrategy.ReduceStrategy.AllReduce

        exe = fluid.ParallelExecutor(
            args.use_cuda,
            loss_name=avg_cost.name,
            exec_strategy=strategy,
            build_strategy=build_stra)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
            """Return the next batch, split between trainers if requested."""
            origin_batch = next(reader_generator)
            if args.update_method == "pserver" and args.use_reader_alloc:
                # Each of the two trainers keeps the samples whose position
                # in the batch has the same parity as its trainer id.
                return [
                    item for offset, item in enumerate(origin_batch)
                    if offset % 2 == args.trainer_id
                ]
            return origin_batch

        def run_step(batch):
            """Run one training step on ``batch`` and return the loss."""
            step_loss, = exe.run(fetch_list=[avg_cost.name],
                                 feed=feeder.feed(batch))
            return step_loss

        def emit(values):
            """Write ``values`` to stdout in pickled form."""
            # Presumably the launching test unpickles this from the captured
            # stdout -- verify against the caller before changing the format.
            payload = pickle.dumps(values)
            if six.PY2:
                print(payload)
            else:
                sys.stdout.buffer.write(payload)

        need_save = bool(int(os.getenv("SAVE", "0")))
        model_dir = os.getenv("MODEL_DIR", "")
        save_mode = os.getenv("SAVE_MODE", "")

        if save_mode == "LOCAL":
            if need_save:
                for _ in six.moves.xrange(RUN_STEP):
                    run_step(get_data())
                if model_dir:
                    io.save_persistables(startup_exe, model_dir, trainer_prog)

            var = np.array(
                fluid.global_scope().find_var('__fc_b__').get_tensor())
            emit(np.ravel(var).tolist())

        elif save_mode == "DIST":
            raw_skip_steps = os.getenv("SKIP_STEPS")
            if raw_skip_steps is None:
                # int(None) would fail with a TypeError that does not name
                # the missing variable.
                raise ValueError(
                    "SKIP_STEPS must be set when SAVE_MODE is DIST")
            skip_steps = int(raw_skip_steps)

            dist_steps = 8
            loss = None
            if need_save:
                for idx in six.moves.xrange(dist_steps):
                    loss = run_step(get_data())
                    if model_dir and idx == skip_steps and \
                            args.trainer_id == 0:
                        io.save_persistables(startup_exe, model_dir,
                                             trainer_prog)
            else:
                for idx in six.moves.xrange(dist_steps):
                    # The batch is always drawn so the reader advances even
                    # for the steps that are skipped.
                    data = get_data()
                    if idx <= skip_steps:
                        continue
                    loss = run_step(data)

            if loss is None:
                # Every step was skipped, so there is nothing to report.
                raise ValueError(
                    "SKIP_STEPS=%d skips all %d steps; no loss was computed" %
                    (skip_steps, dist_steps))
            emit(loss.tolist())

        else:
            raise ValueError("save_mode must be LOCAL or DIST")
|
|
|
|
|
|
if __name__ == "__main__":
    # Fetch the simnet data first (DATA_URL / DATA_MD5 come from
    # dist_simnet_bow), then hand control to the shared runner entry point
    # with this test class.
    download_args = (DATA_URL, 'simnet', DATA_MD5, "train")
    paddle.dataset.common.download(*download_args)

    runtime_main(TestDistSaveLoad2x2)
|