# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from functools import reduce

import collections
import math
import os
import warnings

import six
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.core import CommContext
import paddle.fluid.framework as framework
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools
from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin, PSDispatcher
from paddle.fluid.transpiler.details.program_utils import delete_ops

OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP"
STEP_COUNTER = "@PS_STEP_COUNTER@"
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"

OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize

SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}


def _get_lr_ops(program):
    lr_ops = []
    for op in program.global_block().ops:
        role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
        # an op belongs to the LR schedule if its role is LRSched, or LRSched
        # combined (bitwise OR) with Optimize
        if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
                role_id == (int(LR_SCHED_OP_ROLE_ATTR_VALUE) |
                            int(OPT_OP_ROLE_ATTR_VALUE)):
            lr_ops.append(op)
    return lr_ops


def _has_global_step(lr_ops):
    if len(lr_ops) > 0:
        for op in lr_ops:
            if op.type != 'increment':
                continue
            counter = op.input("X")[0]
            if counter == LEARNING_RATE_DECAY_COUNTER:
                return True
    return False


def is_sparse_op(op):
    if op.type in SPARSE_OP_LIST and op.attr('is_sparse') is True and op.attr(
            'is_distributed') is False:
        return True

    if op.type == "distributed_lookup_table" and op.attr(
            'is_distributed') is False:
        return True

    return False


def is_distributed_sparse_op(op):
    if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True:
        return True

    if op.type == "distributed_lookup_table" and op.attr(
            'is_distributed') is True:
        return True

    return False


def get_sparse_tablename(op):
    return op.input("W")[0]


def get_sparse_tablenames(program, is_distributed):
    tablenames = set()
    if is_distributed:
        for op in program.global_block().ops:
            if is_distributed_sparse_op(op):
                tablenames.add(get_sparse_tablename(op))
    else:
        for op in program.global_block().ops:
            if is_sparse_op(op):
                tablenames.add(get_sparse_tablename(op))
    return list(tablenames)


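# Illustrative usage sketch (the variable name "emb" is hypothetical): for a
# program whose embedding lookup runs with is_sparse=True and
# is_distributed=False, get_sparse_tablenames(main_program, False) would return
# something like ["emb"], i.e. the "W" input of the lookup_table op.
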
class MergedVariable:
    """
    Groups one or more original variables into a single merged variable:
    `merged_var` is the fused variable, `ordered_vars` holds the originals in
    order, and `offsets` records their offsets inside the merged variable.
    """

    def __init__(self, merged, ordered, offsets):
        self.merged_var = merged
        self.ordered_vars = ordered
        self.offsets = offsets


def Singleton(cls):
    _instance = {}

    def _singleton(*args, **kargs):
        if cls not in _instance:
            _instance[cls] = cls(*args, **kargs)
        return _instance[cls]

    return _singleton


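# Usage sketch for the Singleton decorator above (the class name is
# illustrative only): every construction of a decorated class returns the same
# cached instance, keyed by the class object.
#
#     @Singleton
#     class Config(object):
#         pass
#
#     assert Config() is Config()
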
@Singleton
class CompileTimeStrategy(object):
    def __init__(self, main_program, startup_program, strategy, role_maker):
        self.min_block_size = 8192

        self.origin_main_program = main_program
        self.origin_startup_program = startup_program

        self.strategy = strategy
        self.role_maker = role_maker

        self.origin_sparse_pairs = []
        self.origin_dense_pairs = []

        self.merged_variables_pairs = []
        self.merged_dense_pairs = []
        self.merged_sparse_pairs = []

        self.merged_variable_map = {}
        self.param_name_to_grad_name = {}
        self.grad_name_to_param_name = {}

        self.param_grad_ep_mapping = collections.OrderedDict()
        self.grad_param_mapping = collections.OrderedDict()

        self._build_var_distributed()

    def get_distributed_mode(self):
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode

    def is_sync_mode(self):
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode == DistributedMode.SYNC

    def is_geo_mode(self):
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode == DistributedMode.GEO

    def is_async_mode(self):
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode == DistributedMode.ASYNC

    # The role_maker API differs between fleet versions; prefer the newer
    # underscore-prefixed accessors and fall back to the legacy public names.
    def get_role_id(self):
        try:
            return self.role_maker._role_id()
        except Exception:
            return self.role_maker.role_id()

    def get_trainers(self):
        try:
            return self.role_maker._worker_num()
        except Exception:
            return self.role_maker.worker_num()

    def get_ps_endpoint(self):
        try:
            return self.role_maker._get_pserver_endpoints()[self.get_role_id()]
        except Exception:
            return self.role_maker.get_pserver_endpoints()[self.get_role_id()]

    def get_ps_endpoints(self):
        try:
            return self.role_maker._get_pserver_endpoints()
        except Exception:
            return self.role_maker.get_pserver_endpoints()

    def get_heter_worker_endpoints(self):
        try:
            return self.role_maker._get_heter_worker_endpoints()
        except Exception:
            return self.role_maker.get_heter_worker_endpoints()

    def get_heter_worker_endpoint(self):
        try:
            return self.role_maker._get_heter_worker_endpoint()
        except Exception:
            return self.role_maker.get_heter_worker_endpoint()

    def get_origin_programs(self):
        return self.origin_main_program, self.origin_startup_program

    def get_origin_main_program(self):
        return self.origin_main_program

    def get_origin_startup_program(self):
        return self.origin_startup_program

    def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
        if not endpoint:
            endpoint = self.get_ps_endpoint()

        varnames = get_sparse_tablenames(self.get_origin_main_program(),
                                         is_distributed)
        ps_sparse_varnames = []
        for varname in varnames:
            tables = self.get_var_distributed(varname, True)
            for table, ep, _ in tables:
                if ep == endpoint:
                    ps_sparse_varnames.append(table)
        return ps_sparse_varnames

    def build_ctx(self,
                  vars,
                  mapping,
                  is_grad,
                  is_sparse,
                  is_send,
                  is_distributed=False):
        """
        Build a CommContext describing how one (merged) variable is split
        across the parameter server endpoints: the per-slice names, the
        endpoint each slice lives on and the section (dim0) sizes.
        """

        def get_grad_var_ep(slices):
            names = []
            eps = []
            sections = []

            for slice in slices:
                if self.is_geo_mode():
                    if is_send:
                        names.append("{}.delta".format(slice.name))
                    else:
                        names.append(slice.name)
                elif is_grad and self.is_sync_mode() and \
                        self.get_trainers() > 1:
                    names.append("{}.trainer_{}".format(slice.name,
                                                        self.get_role_id()))
                else:
                    names.append(slice.name)

                sections.append(slice.shape[0])

                for ep, pairs in self.param_grad_ep_mapping.items():
                    params, grads = pairs["params"], pairs["grads"]

                    for var in params + grads:
                        if slice.name == var.name:
                            eps.append(ep)
                            break
            return names, eps, sections

        if isinstance(vars, MergedVariable):
            name = vars.merged_var.name
            slices = mapping[name]
            names, eps, sections = get_grad_var_ep(slices)
            origin_varnames = [var.name for var in vars.ordered_vars]
        else:
            name = vars.name
            slices = mapping[name]
            names, eps, sections = get_grad_var_ep(slices)
            origin_varnames = [vars.name]

        trainer_id = self.get_role_id()
        aggregate = True
        ctx = CommContext(name, names, eps, sections, origin_varnames,
                          trainer_id, aggregate, is_sparse, is_distributed)
        return ctx

    def get_trainer_send_context(self):
        send_ctx = {}
        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)

        if not self.is_geo_mode():
            for merged in self.merged_dense_pairs:
                grad = merged[1]
                ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
                                     True)
                send_ctx[ctx.var_name()] = ctx

            for merged in self.merged_sparse_pairs:
                param = merged[0]
                grad = merged[1]

                param_name = param.merged_var.name

                is_distributed = param_name in distributed_varnames

                ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                     True, is_distributed)
                send_ctx[ctx.var_name()] = ctx

            if self.is_async_mode():
                name, ctx = self._step_ctx()
                send_ctx[name] = ctx
        else:
            for pairs in self.origin_sparse_pairs:
                param, grad = pairs
                param_name = param.name
                is_distributed = param_name in distributed_varnames

                param_ctx = self.build_ctx(param, self.param_var_mapping,
                                           False, True, True, is_distributed)
                grad_ctx = self.build_ctx(grad, self.grad_var_mapping, True,
                                          True, True, is_distributed)

                ctx = CommContext(param_ctx.var_name(),
                                  param_ctx.split_varnames(),
                                  param_ctx.split_endpoints(),
                                  param_ctx.sections(),
                                  grad_ctx.origin_varnames(),
                                  param_ctx.trainer_id(),
                                  param_ctx.aggregate(),
                                  param_ctx.is_sparse(),
                                  param_ctx.is_distributed())

                send_ctx[ctx.var_name()] = ctx
            name, ctx = self._step_ctx()
            send_ctx[name] = ctx
        return send_ctx

    def get_communicator_send_context(self):
        send_ctx = {}
        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)

        if self.is_geo_mode():
            for pairs in self.merged_dense_pairs:
                param = pairs[0]
                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     False, True)
                send_ctx[ctx.var_name()] = ctx

            for pairs in self.merged_sparse_pairs:
                param = pairs[0]
                param_name = param.merged_var.name
                is_distributed = param_name in distributed_varnames

                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     True, True, is_distributed)
                send_ctx[ctx.var_name()] = ctx
            name, ctx = self._step_ctx()
            send_ctx[name] = ctx
        else:
            for merged in self.merged_dense_pairs:
                grad = merged[1]
                ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
                                     True)
                send_ctx[ctx.var_name()] = ctx

            for merged in self.merged_sparse_pairs:
                param, grad = merged
                param_name = param.merged_var.name

                is_distributed = param_name in distributed_varnames

                ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                     True, is_distributed)
                send_ctx[ctx.var_name()] = ctx

            name, ctx = self._step_ctx()
            send_ctx[name] = ctx
        return send_ctx

    def get_communicator_recv_context(self, recv_type=1):
        # recv_type
        # 1 : DENSE  2 : SPARSE  3 : DISTRIBUTED  4 : ALL
        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)
        sparse_varnames = []
        for pairs in self.origin_sparse_pairs:
            param, grad = pairs
            sparse_varnames.append(param.name)

        dense_recv_ctx = {}
        sparse_recv_ctx = {}
        distributed_recv_ctx = {}

        for merged in self.merged_variables_pairs:
            params = merged[0]
            if params.merged_var.name in sparse_varnames:
                continue

            ctx = self.build_ctx(params, self.param_var_mapping, False, False,
                                 False)
            dense_recv_ctx[ctx.var_name()] = ctx

        for pairs in self.origin_sparse_pairs:
            param, grad = pairs

            if param.name in distributed_varnames:
                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     True, False, True)
                distributed_recv_ctx[ctx.var_name()] = ctx
            else:
                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     True, False, False)
                sparse_recv_ctx[ctx.var_name()] = ctx

        if recv_type == 1:
            return dense_recv_ctx
        if recv_type == 2:
            return sparse_recv_ctx
        if recv_type == 3:
            return distributed_recv_ctx
        if recv_type == 4:
            dense_recv_ctx.update(sparse_recv_ctx)
            dense_recv_ctx.update(distributed_recv_ctx)
            return dense_recv_ctx
        raise ValueError(
            "recv_type can only be 1/2/3/4, 1 : DENSE  2 : SPARSE  3 : DISTRIBUTED  4 : ALL"
        )

    def get_server_runtime_config(self):
        return self.strategy.get_server_runtime_config()

    def get_var_distributed(self, varname, is_param):
        """
        Return a list of (slice_name, endpoint, dim0_size) tuples describing
        where the slices of `varname` are placed.
        """
        var_distributed = []
        offset = 0
        if is_param:
            params = self.param_var_mapping[varname]
            param_varnames = [var.name for var in params]
            for ep, pairs in self.param_grad_ep_mapping.items():
                for p in pairs["params"]:
                    if p.name in param_varnames:
                        offset += p.shape[0]
                        var_distributed.append((p.name, ep, p.shape[0]))
        else:
            grads = self.grad_var_mapping[varname]
            grad_varnames = [var.name for var in grads]
            for ep, pairs in self.param_grad_ep_mapping.items():
                for g in pairs["grads"]:
                    if g.name in grad_varnames:
                        var_distributed.append((g.name, ep, g.shape[0]))
        return var_distributed

    def _step_ctx(self):
        name = STEP_COUNTER
        trainer_id = self.get_role_id()
        endpoints = self.get_ps_endpoints()
        sections = [1] * len(endpoints)
        names = [name] * len(endpoints)
        ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
                          True, False, False)
        return name, ctx

    def _create_vars_from_blocklist(self, block_list):
        """
        Create vars for each split.
        NOTE: only grads need to be named differently for each trainer; the
              ".trainer_<id>" suffix is appended when the send names are built.
        Args:
            block_list (list[str]): list of "varname:block_id:block_size"
                strings describing gradient blocks.
        Returns:
            var_mapping (collections.OrderedDict(varname->[new_varname_variable])):
                A dict mapping from original var name to each var split.
        """

        # varname->[(block_id, current_block_size)]
        block_map = collections.OrderedDict()
        var_mapping = collections.OrderedDict()

        for block_str in block_list:
            varname, offset, size = block_str.split(":")
            if varname not in block_map:
                block_map[varname] = []
            block_map[varname].append((int(offset), int(size)))

        for varname, split in six.iteritems(block_map):
            orig_var = self.merged_variable_map[varname]

            if len(split) == 1:
                var_mapping[varname] = [orig_var]
                self.var_distributed.add_distributed_var(
                    origin_var=orig_var,
                    slice_var=orig_var,
                    block_id=0,
                    offset=0,
                    is_slice=False,
                    vtype="Param")
            else:
                var_mapping[varname] = []
                orig_shape = orig_var.shape
                orig_dim1_flatten = 1

                if len(orig_shape) >= 2:
                    orig_dim1_flatten = reduce(lambda x, y: x * y,
                                               orig_shape[1:])

                for i, block in enumerate(split):
                    size = block[1]
                    rows = size // orig_dim1_flatten
                    splited_shape = [rows]
                    if len(orig_shape) >= 2:
                        splited_shape.extend(orig_shape[1:])

                    new_var_name = "%s.block%d" % (varname, i)
                    slice_var = vars_metatools.VarStruct(
                        name=new_var_name,
                        shape=splited_shape,
                        dtype=orig_var.dtype,
                        type=orig_var.type,
                        lod_level=orig_var.lod_level,
                        persistable=False)
                    var_mapping[varname].append(slice_var)

                    self.var_distributed.add_distributed_var(
                        origin_var=orig_var,
                        slice_var=slice_var,
                        block_id=i,
                        offset=-1,
                        is_slice=False,
                        vtype="Param")

        return var_mapping

    def _dispatcher(self):
        ps_dispatcher = RoundRobin(self.get_ps_endpoints())
        ps_dispatcher.reset()
        grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping))

        sparse_gradnames = [grad.name for _, grad in self.origin_sparse_pairs]

        # dense grads are dispatched round-robin across all pserver endpoints
        for grad_varname, splited_vars in grad_var_mapping_items:
            if grad_varname in sparse_gradnames:
                continue

            send_vars = list(splited_vars)
            recv_vars = [self.grad_param_mapping[var] for var in send_vars]

            eps = ps_dispatcher.dispatch(recv_vars)

            for i, ep in enumerate(eps):
                self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
                self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

        # sparse grads reset the dispatcher per variable, so every sparse table
        # is laid out starting from the first endpoint
        for grad_varname, splited_vars in grad_var_mapping_items:
            if grad_varname not in sparse_gradnames:
                continue

            ps_dispatcher.reset()

            send_vars = list(splited_vars)
            recv_vars = [self.grad_param_mapping[var] for var in send_vars]

            eps = ps_dispatcher.dispatch(recv_vars)

            for i, ep in enumerate(eps):
                self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
                self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

    def _slice_variable(self,
                        var_list,
                        slice_count,
                        min_block_size,
                        uniform=False):
        """
        We may need to split a dense tensor into one or more blocks and put
        them equally onto the parameter servers. One block is a sub-tensor
        aligned by dim[0] of the tensor.

        We need a minimal block size so that the calculations on the
        parameter server side can gain better performance. The default
        minimum block size is 8K elements (each element may be 16, 32 or
        64 bits).

        Args:
            var_list (list): List of variables.
            slice_count (int): Number of slices each variable will be split
                into, typically the number of pserver endpoints.
            min_block_size (int): Minimum split block size.
        Returns:
            blocks (list[(varname, block_id, current_block_size)]): A list
                of VarBlocks. Each VarBlock specifies a shard of the var.
        """
        blocks = []
        for var in var_list:
            if not uniform:
                var_numel = reduce(lambda x, y: x * y, var.shape)

                split_count = 1

                # if min_block_size == -1:
                #     split_count = 1
                # else:
                #     split_count = slice_count
                #     max_pserver_count = int(
                #         math.floor(var_numel / float(min_block_size)))
                #     if max_pserver_count == 0:
                #         max_pserver_count = 1
                #     if max_pserver_count < slice_count:
                #         split_count = max_pserver_count
                block_size = int(math.ceil(var_numel / float(split_count)))

                if len(var.shape) >= 2:
                    # align by dim1 (width)
                    dim1 = reduce(lambda x, y: x * y, var.shape[1:])
                    remains = block_size % dim1
                    if remains != 0:
                        block_size += dim1 - remains
                # update split_count after aligning
                split_count = int(math.ceil(var_numel / float(block_size)))
                for block_id in range(split_count):
                    curr_block_size = min(block_size,
                                          var_numel - block_id * block_size)
                    block = vars_metatools.VarBlock(var.name, block_id,
                                                    curr_block_size)
                    blocks.append(str(block))
            else:
                block_size = var.shape[0] // slice_count
                remainder = var.shape[0] % slice_count

                if block_size == 0:
                    dim0s = [block_size] * remainder
                else:
                    dim0s = [block_size] * slice_count
                for i in range(remainder):
                    dim0s[i] = dim0s[i] + 1

                dim1 = reduce(lambda x, y: x * y, var.shape[1:])

                for block_id in range(len(dim0s)):
                    numel = dim0s[block_id] * dim1
                    block = vars_metatools.VarBlock(var.name, block_id, numel)
                    blocks.append(str(block))
        return blocks

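    # Illustrative sketch of the uniform split above (names are hypothetical):
    # a var of shape [10, 8] sliced across 3 pservers gets dim0 splits of
    # [4, 3, 3], i.e. block numels of [32, 24, 24]; assuming str(VarBlock)
    # renders as "name:block_id:size", the returned blocks would look like
    # ["emb:0:32", "emb:1:24", "emb:2:24"].
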
    def _get_param_grad_blocks(self, pairs, min_block_size, uniform=False):
        param_list = []
        grad_list = []
        param_grad_set = set()
        for p, g in pairs:
            # todo(tangwei12) skip parameter marked not trainable
            # if type(p) == Parameter and p.trainable == False:
            #     continue
            p = p.merged_var
            g = g.merged_var

            if p.name not in param_grad_set:
                param_list.append(p)
                param_grad_set.add(p.name)
            if g.name not in param_grad_set:
                grad_list.append(g)
                param_grad_set.add(g.name)

        # when we slice vars up into blocks, we slice according to the
        # pserver services' count. A pserver may have two or more listening ports.
        grad_blocks = self._slice_variable(grad_list,
                                           len(self.get_ps_endpoints()),
                                           min_block_size, uniform)

        param_blocks = self._slice_variable(param_list,
                                            len(self.get_ps_endpoints()),
                                            min_block_size, uniform)
        return param_blocks, grad_blocks

    def _var_slice_and_distribute(self):
        # update these mappings for further transpile:
        # 1. param_var_mapping : param var name -> [split param vars]
        # 2. grad_var_mapping : grad var name -> [split grad vars]
        # 3. grad_param_mapping : grad.blockx -> param.blockx
        # 4. param_grad_ep_mapping : ep -> {"params": [], "grads": []}

        dps, dgs = self._get_param_grad_blocks(self.merged_dense_pairs, -1,
                                               False)
        sps, sgs = self._get_param_grad_blocks(self.merged_sparse_pairs,
                                               self.min_block_size, True)

        param_blocks = dps + sps
        grad_blocks = dgs + sgs

        assert (len(grad_blocks) == len(param_blocks))

        # origin_param_name -> [splited_param_vars]
        self.param_var_mapping = self._create_vars_from_blocklist(param_blocks)
        self.grad_var_mapping = self._create_vars_from_blocklist(grad_blocks)

        # dict(grad_splited_var -> param_splited_var)
        self.grad_param_mapping = collections.OrderedDict()
        for g, p in zip(grad_blocks, param_blocks):
            g_name, g_bid, _ = g.split(":")
            p_name, p_bid, _ = p.split(":")
            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
                self.param_var_mapping[p_name][int(p_bid)]

        print_maps = {}
        for k, v in self.grad_param_mapping.items():
            print_maps[str(k)] = str(v)

        # create mapping of endpoint -> split var to create the pserver side program
        self.param_grad_ep_mapping = collections.OrderedDict()
        for ep in self.get_ps_endpoints():
            self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}

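    # Illustrative mapping (hypothetical names): a dense pair
    # ("fc_w", "fc_w@GRAD") kept in one piece maps "fc_w@GRAD" -> "fc_w" in
    # grad_param_mapping, while a variable split into blocks maps
    # "fc_w@GRAD.block0" -> "fc_w.block0", "fc_w@GRAD.block1" -> "fc_w.block1",
    # and so on (the keys and values are the corresponding VarStructs).
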
    def _build_var_distributed(self):
        self.var_distributed = vars_metatools.VarsDistributed()

        sparse_pairs, dense_pairs = self.get_param_grads()
        origin_for_sparse = []
        origin_for_dense = []
        param_name_grad_name = dict()
        grad_name_to_param_name = dict()

        for param, grad in sparse_pairs:
            param = vars_metatools.create_var_struct(param)
            grad = vars_metatools.create_var_struct(grad)
            origin_for_sparse.append((param, grad))

        for param, grad in dense_pairs:
            param = vars_metatools.create_var_struct(param)
            grad = vars_metatools.create_var_struct(grad)
            origin_for_dense.append((param, grad))

        for dense_pair in origin_for_dense:
            param, grad = dense_pair

            m_param = MergedVariable(param, [param], [0])
            m_grad = MergedVariable(grad, [grad], [0])
            self.merged_variables_pairs.append((m_param, m_grad))
            self.merged_dense_pairs.append((m_param, m_grad))

        for sparse_pair in origin_for_sparse:
            param, grad = sparse_pair

            m_param = MergedVariable(param, [param], [0])
            m_grad = MergedVariable(grad, [grad], [0])
            self.merged_variables_pairs.append((m_param, m_grad))
            self.merged_sparse_pairs.append((m_param, m_grad))

        for merged in self.merged_variables_pairs:
            m_param, m_grad = merged
            self.merged_variable_map[
                m_param.merged_var.name] = m_param.merged_var
            self.merged_variable_map[m_grad.merged_var.name] = m_grad.merged_var

        param_merges = []
        param_merges.extend(origin_for_sparse)
        param_merges.extend(origin_for_dense)

        for param, grad in param_merges:
            param_name_grad_name[param.name] = grad.name
            grad_name_to_param_name[grad.name] = param.name

        self.origin_sparse_pairs = origin_for_sparse
        self.origin_dense_pairs = origin_for_dense
        self.param_name_to_grad_name = param_name_grad_name
        self.grad_name_to_param_name = grad_name_to_param_name

        sparse_pair_map = collections.OrderedDict()

        for pair in self.origin_sparse_pairs + self.origin_dense_pairs:
            param, grad = pair
            sparse_pair_map[param.name] = str(param)
            sparse_pair_map[grad.name] = str(grad)

        self._var_slice_and_distribute()
        self._dispatcher()

    def get_param_grads(self):
        origin_program = self.origin_main_program

        def _get_params_grads(sparse_varnames):
            block = origin_program.global_block()

            dense_param_grads = []
            sparse_param_grads = []

            optimize_params = set()
            origin_var_dict = origin_program.global_block().vars
            role_id = int(core.op_proto_and_checker_maker.OpRole.Backward)
            for op in block.ops:
                if _is_opt_role_op(op):
                    # delete clip op from opt_ops when running in Parameter Server mode
                    if OP_NAME_SCOPE in op.all_attrs() \
                            and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE):
                        op._set_attr("op_role", role_id)
                        continue
                    if op.attr(OP_ROLE_VAR_ATTR_NAME):
                        param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
                        grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
                        if param_name not in optimize_params:
                            optimize_params.add(param_name)
                            param_grad = (origin_var_dict[param_name],
                                          origin_var_dict[grad_name])

                            if param_name in sparse_varnames:
                                sparse_param_grads.append(param_grad)
                            else:
                                dense_param_grads.append(param_grad)
            return sparse_param_grads, dense_param_grads

        def _get_sparse_varnames():
            varnames = []
            for op in origin_program.global_block().ops:
                if op.type in SPARSE_OP_TYPE_DICT.keys() \
                        and op.attr('remote_prefetch') is True:
                    param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
                    varnames.append(param_name)

            return list(set(varnames))

        sparse_varnames = _get_sparse_varnames()
        sparse_param_grads, dense_param_grads = _get_params_grads(
            sparse_varnames)

        return sparse_param_grads, dense_param_grads

    def remove_var_pair_by_grad(self, var_name):

        for index, pair in enumerate(self.merged_variables_pairs):
            var = pair[0]
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del self.merged_variables_pairs[index]

        for index, pair in enumerate(self.merged_dense_pairs):
            var = pair[0]
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del self.merged_dense_pairs[index]
                return

        for index, pair in enumerate(self.merged_sparse_pairs):
            var = pair[0]
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del self.merged_sparse_pairs[index]
                return

        print("Can not find {} in merged variable pairs".format(var_name))


def _is_opt_role_op(op):
    # NOTE: rely on the op role attribute to find out whether this op is an
    # optimize op
    op_maker = core.op_proto_and_checker_maker
    optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
    if op_maker.kOpRoleAttrName() in op.attr_names and \
            int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
        return True
    return False


def _get_optimize_ops(_program):
    block = _program.global_block()
    opt_ops = []
    for op in block.ops:
        if _is_opt_role_op(op):
            # delete clip op from opt_ops when running in Parameter Server mode
            if OP_NAME_SCOPE in op.all_attrs() \
                    and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE):
                op._set_attr(
                    "op_role",
                    int(core.op_proto_and_checker_maker.OpRole.Backward))
                continue
            opt_ops.append(op)
    return opt_ops


def _get_varname_parts(varname):
    # returns origin, blockid, trainerid
    orig_var_name = ""
    trainer_part = ""
    block_part = ""
    trainer_idx = varname.find(".trainer_")
    if trainer_idx >= 0:
        trainer_part = varname[trainer_idx + 1:]
    else:
        trainer_idx = len(varname)
    block_index = varname.find(".block")
    if block_index >= 0:
        block_part = varname[block_index + 1:trainer_idx]
    else:
        block_index = len(varname)
    orig_var_name = varname[0:min(block_index, trainer_idx)]
    return orig_var_name, block_part, trainer_part


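# Example for _get_varname_parts above (illustrative variable names):
# "emb@GRAD.block0.trainer_2" splits into ("emb@GRAD", "block0", "trainer_2"),
# while a plain name such as "emb" yields ("emb", "", "").

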
def _orig_varname(varname):
    orig, _, _ = _get_varname_parts(varname)
    return orig