# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parameter Server utils"""

import os

import numpy as np
import paddle


class DistributedInfer:
    """
    Utility class for distributed inference with PaddlePaddle's parameter
    server (Fleet). It initializes the distributed runtime, loads dense
    parameters on workers, and rewrites the main program so that sparse
    embedding lookups are served from parameter-server tables.
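
    Examples:
        A minimal usage sketch; the network producing ``loss`` and the
        ``"./model"`` checkpoint directory are assumptions for
        illustration.

        .. code-block:: python

            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            exe = paddle.static.Executor(paddle.CPUPlace())
            # `loss` comes from a user-defined static-graph network.
            dist_infer = DistributedInfer()
            dist_infer.init_distributed_infer_env(
                exe, loss, dirname="./model")
            if fleet.is_worker():
                infer_program = dist_infer.get_dist_infer_program()
    """
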
    def __init__(self, main_program=None, startup_program=None):
        if main_program:
            self.origin_main_program = main_program.clone()
        else:
            self.origin_main_program = paddle.static.default_main_program(
            ).clone()

        if startup_program:
            self.origin_startup_program = startup_program
        else:
            self.origin_startup_program = paddle.static.default_startup_program(
            )
        self.sparse_table_maps = None

    def init_distributed_infer_env(self,
                                   exe,
                                   loss,
                                   role_maker=None,
                                   dirname=None):
        """Initialize the Fleet runtime for distributed inference."""
        import paddle.distributed.fleet as fleet

        if fleet.fleet._runtime_handle is None:
            fleet.init(role_maker=role_maker)

            # A throwaway SGD optimizer pass makes Fleet compile the program
            # for async parameter-server mode; no training is performed.
            fake_optimizer = paddle.optimizer.SGD()
            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True
            optimizer = fleet.distributed_optimizer(
                fake_optimizer, strategy=strategy)
            optimizer.minimize(
                loss, startup_program=self.origin_startup_program)

            if fleet.is_server():
                fleet.init_server(dirname=dirname)
                fleet.run_server()
            else:
                exe.run(paddle.static.default_startup_program())
                fleet.init_worker()
                self._init_dense_params(exe, dirname)

            global_startup_program = self.origin_startup_program
            global_main_program = self.origin_main_program

    def _get_sparse_table_map(self):
        """Build (once) and return the sparse param name -> table id map."""
        import paddle.distributed.fleet as fleet

        if self.sparse_table_maps is None:
            self.sparse_table_maps = {}
            send_ctx = fleet.fleet._runtime_handle._communicator.send_ctx_
            for gradname, ctx in send_ctx.items():
                if ctx.is_sparse:
                    # Drop the "@GRAD" suffix to recover the parameter name.
                    param = gradname[:-len("@GRAD")]
                    self.sparse_table_maps[param] = ctx.table_id()
        return self.sparse_table_maps

    def _init_dense_params(self, exe=None, dirname=None):
        """Load dense persistable parameters from `dirname` via `exe`."""
        sparse_table_maps = self._get_sparse_table_map()

        if dirname is not None and exe is not None:
            all_persist_vars = [
                v for v in self.origin_main_program.list_vars()
                if paddle.static.io.is_persistable(v)
            ]
            # Dense parameters are the persistable variables not served by
            # a sparse parameter-server table.
            dense_persist_vars = [(v.name, v) for v in all_persist_vars
                                  if v.name not in sparse_table_maps]
            need_load_vars = [
                v[1] for v in dense_persist_vars
                if os.path.isfile(os.path.join(dirname, v[0]))
            ]
            paddle.static.load_vars(
                exe,
                dirname,
                main_program=self.origin_main_program,
                vars=need_load_vars)

    def get_dist_infer_program(self):
        """Return the main program rewritten for distributed inference."""
        varname2tables = self._get_sparse_table_map()
        convert_program = self._convert_program(self.origin_main_program,
                                                varname2tables)
        return convert_program

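    # The pass below groups `lookup_table` / `lookup_table_v2` ops by their
    # embedding parameter, removes them from the block, and inserts one fused
    # `distributed_lookup_table` op per parameter that fetches embedding rows
    # from the corresponding parameter-server table.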
    def _convert_program(self, main_program, varname2tables):
        def distributed_ops_pass(program):
            SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}

            def _get_pull_sparse_ops(_program):
                # Group remote-prefetch sparse lookup ops by the embedding
                # parameter they read.
                pull_sparse_ops = {}
                for op in _program.global_block().ops:
                    if op.type in SPARSE_OP_TYPE_DICT \
                            and op.attr('remote_prefetch') is True:
                        param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
                        ops = pull_sparse_ops.get(param_name, [])
                        ops.append(op)
                        pull_sparse_ops[param_name] = ops
                return pull_sparse_ops

            def _pull_sparse_fuse(_program, pull_sparse_ops):
                for param, ops in pull_sparse_ops.items():
                    all_ops = program.global_block().ops
                    op_idxs = [all_ops.index(op) for op in ops]

                    inputs = [
                        program.global_block().vars[op.input("Ids")[0]]
                        for op in ops
                    ]

                    w = program.global_block().vars[ops[0].input("W")[0]]

                    if w.name not in varname2tables:
                        raise ValueError(
                            "cannot find variable {}, please check your "
                            "configuration".format(w.name))

                    table_id = varname2tables[w.name]

                    padding_idx = ops[0].attr("padding_idx")
                    is_distributed = ops[0].attr("is_distributed")
                    op_type = ops[0].type

                    outputs = [
                        program.global_block().vars[op.output("Out")[0]]
                        for op in ops
                    ]

                    # Remove the original lookup ops from back to front so
                    # that earlier op indices remain valid.
                    for idx in op_idxs[::-1]:
                        program.global_block()._remove_op(idx)

                    inputs_idxs = [-1] * len(inputs)
                    outputs_idxs = [-1] * len(outputs)

                    # Locate where each Ids input is produced and where each
                    # Out output is consumed, so the fused op can be placed
                    # between the two.
                    for idx, op in enumerate(program.global_block().ops):
                        for i in range(0, len(op.output_names)):
                            outs = op.output(op.output_names[i])
                            for in_id, in_var in enumerate(inputs):
                                if in_var.name in outs:
                                    inputs_idxs[in_id] = idx
                        for i in range(0, len(op.input_names)):
                            ins = op.input(op.input_names[i])
                            for out_id, out_var in enumerate(outputs):
                                if out_var.name in ins:
                                    outputs_idxs[out_id] = idx

                    if min(outputs_idxs) - max(inputs_idxs) >= 1:
                        distributed_idx = max(inputs_idxs) + 1

                        program.global_block()._insert_op(
                            index=distributed_idx,
                            type="distributed_lookup_table",
                            inputs={"Ids": inputs,
                                    "W": w},
                            outputs={"Outputs": outputs},
                            attrs={
                                "is_distributed": is_distributed,
                                "padding_idx": padding_idx,
                                "table_id": table_id,
                                "lookup_table_version": op_type
                            })
                    else:
                        raise ValueError(
                            "something went wrong with Fleet; submitting an "
                            "issue is recommended")

            pull_sparse_ops = _get_pull_sparse_ops(program)
            _pull_sparse_fuse(program, pull_sparse_ops)
            return program

        convert_program = distributed_ops_pass(main_program)
        return convert_program