You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/transpiler/geo_sgd_transpiler.py

349 lines
14 KiB

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
"""
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. create delta variables in the global scope, which are used for sending
3. add send op to send sparse ids to communicator
Steps to transpile pserver:
1. create new program for parameter server.
2. create params variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append sum ops that should run on current server instance.
5. add listen_and_serv op
"""
import sys
import collections
import six
import numpy as np
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, Block, Parameter
from .details import wait_server_ready, VarsDistributed
from .details import delete_ops
from ..distribute_lookup_table import find_distributed_lookup_table
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var
# Value returned by core's kOpRoleAttrName(); presumably the attribute key
# under which an op's role is stored (confirm against the core bindings).
# The same value is exported under two module-level names.
_OP_PROTO_MAKER = core.op_proto_and_checker_maker
RPC_OP_ROLE_ATTR_NAME = _OP_PROTO_MAKER.kOpRoleAttrName()
op_role_attr_name = RPC_OP_ROLE_ATTR_NAME
# OpRole enum member used for RPC ops.
RPC_OP_ROLE_ATTR_VALUE = _OP_PROTO_MAKER.OpRole.RPC
class GeoSgdTranspiler(DistributeTranspiler):
    """Transpile a program for Geo-SGD parameter-server training.

    Trainer side: parameters are sliced into blocks, a ``<name>.delta``
    variable is created for every parameter (and for every slice when a
    parameter is split into several pieces), and a ``send`` op carrying the
    ids of sparse lookup tables is appended to the main program.

    Pserver side: for every parameter slice assigned to an endpoint, one
    optimize block containing a ``sum`` op (``param = param + delta``) is
    created, followed by a single ``listen_and_serv`` op.
    """

    def __init__(self, config=None):
        """Store and validate the transpiler configuration.

        Args:
            config: configuration object. When ``None`` a new
                ``DistributeTranspilerConfig`` is created. If its
                ``split_method`` is ``None`` it is set to ``RoundRobin``.

        Raises:
            AssertionError: if ``config.min_block_size`` is below 8192 or the
                first base class of ``config.split_method`` is not
                ``PSDispatcher``.
        """
        if config is None:
            config = DistributeTranspilerConfig()
        self.config = config

        if self.config.split_method is None:
            self.config.split_method = RoundRobin

        # Kept as assertions so callers that catch AssertionError still work.
        assert (self.config.min_block_size >= 8192)
        assert (self.config.split_method.__bases__[0] == PSDispatcher)

    def transpile(self,
                  trainer_id,
                  program=None,
                  pservers="127.0.0.1:6174",
                  trainers=1,
                  sync_mode=False,
                  startup_program=None,
                  current_endpoint="127.0.0.1:6174"):
        """Rewrite ``program`` in place for Geo-SGD training.

        Args:
            trainer_id: id of this trainer; trainer 0 is marked as chief.
            program: main program to transpile; defaults to
                ``default_main_program()``.
            pservers: comma-separated list of pserver endpoints.
            trainers: number of trainers, stored as ``self.trainer_num``.
            sync_mode: accepted for interface compatibility but ignored;
                ``self.sync_mode`` is always set to ``False``.
            startup_program: startup program; defaults to
                ``default_startup_program()``.
            current_endpoint: endpoint stored on the program as
                ``_ps_endpoint``.
        """
        if program is None:
            program = default_main_program()
        if startup_program is None:
            startup_program = default_startup_program()
        self.origin_program = program
        self.startup_program = startup_program
        self.origin_startup_program = self.startup_program.clone()

        self.trainer_num = trainers
        # geo-sgd only supports async mode, whatever `sync_mode` was passed
        self.sync_mode = False
        self.trainer_id = trainer_id
        self.pserver_endpoints = pservers.split(",")
        self.vars_overview = VarsDistributed()
        self.optimize_ops, self.params_grads = self._get_optimize_pass()
        ps_dispatcher = self.config.split_method(self.pserver_endpoints)

        self.param_name_to_grad_name = {}
        self.grad_name_to_param_name = {}
        for param_var, grad_var in self.params_grads:
            self.param_name_to_grad_name[param_var.name] = grad_var.name
            self.grad_name_to_param_name[grad_var.name] = param_var.name

        # distributed lookup table
        self.table_name = find_distributed_lookup_table(self.origin_program)
        self.has_distributed_lookup_table = self.table_name is not None
        # any falsy table name is normalised to None
        self.origin_program._distributed_lookup_table = self.table_name or None

        # add distributed attrs to program
        self.origin_program._is_distributed = True
        self.origin_program._endpoints = self.pserver_endpoints
        self.origin_program._ps_endpoint = current_endpoint
        self.origin_program._is_chief = self.trainer_id == 0

        # program info sent to the geo-sgd communicator
        self.vars_info = collections.OrderedDict()
        self.split_to_origin_mapping = collections.OrderedDict()
        self.delta_vars_list = []
        self.sparse_var_list = []
        self.sparse_var_splited_list = []

        # step 1. split and create vars, then put the split vars in dicts
        # for later use.
        self._init_splited_vars()

        # step 2. collect send/recv vars (params after optimize).
        # send_vars holds the slices produced by the split, not the origin
        # parameters.
        send_vars = []
        ps_dispatcher.reset()
        for splited_vars in six.itervalues(self.param_var_mapping):
            send_vars.extend(splited_vars)

        # the same list object is used for receiving
        recv_vars = send_vars

        # step 3. assign every slice to a pserver endpoint
        ps_dispatcher.reset()
        eplist = ps_dispatcher.dispatch(recv_vars)
        for recv_var, ep in zip(recv_vars, eplist):
            self.param_opt_ep_mapping[ep]["params"].append(recv_var)
            distributed_var = self.vars_overview.get_distributed_var_by_slice(
                recv_var.name)
            distributed_var.endpoint = ep
            origin_name = self.split_to_origin_mapping[recv_var.name]
            self.vars_info[origin_name]["epmap"].append(ep)
        self.origin_program._parameters_on_pservers = self.vars_overview

        # step 4. send sparse ids to the communicator
        global_block = program.global_block()
        self.sparse_var = []
        self.sparse_tables = []
        for op in self.origin_program.global_block().ops:
            if op.type != "lookup_table":
                continue
            op._set_attr('remote_prefetch', False)
            for input_var_name, sparse_var_name in zip(
                    op.input("Ids"), op.input("W")):
                if sparse_var_name in self.sparse_var_list:
                    self.sparse_var.append(global_block.var(input_var_name))
                    self.sparse_tables.append(sparse_var_name)

        # batch training loop end flag
        dummy_output = global_block.create_var(
            name=framework.generate_control_dev_var_name())
        global_block.append_op(
            type="send",
            inputs={"X": self.sparse_var},
            outputs={"Out": dummy_output},
            attrs={"send_varnames": self.sparse_tables})

        # step 5. add param_init flag in trainer startup program
        self.trainer_startup_program = self._get_trainer_startup_program(
            recv_vars=recv_vars, eplist=eplist)
        startup_block = self.trainer_startup_program.global_block()
        # mirror every origin delta var into the trainer startup program
        for delta_var in self.delta_vars_list:
            startup_block.create_var(
                name=delta_var.name,
                persistable=delta_var.persistable,
                dtype=delta_var.dtype,
                type=delta_var.type,
                shape=delta_var.shape)
        dummy_output = startup_block.create_var(
            name=framework.generate_control_dev_var_name())
        param_init = startup_block.create_var(name="param_init")
        startup_block.append_op(
            type="send",
            inputs={"X": [param_init]},
            outputs={"Out": dummy_output},
            attrs={"send_varnames": [param_init.name]})

    def _get_vars_info(self):
        """Return the per-parameter info dict built during ``transpile``."""
        return self.vars_info

    def get_trainer_program(self, wait_port=True):
        """Return the transpiled trainer program.

        Args:
            wait_port: when true, call ``wait_server_ready`` on the pserver
                endpoints before returning.
        """
        if wait_port:
            wait_server_ready(self.pserver_endpoints)
        return self.origin_program

    def get_pserver_programs(self, endpoint):
        """Return ``(pserver_program, pserver_startup_program)`` for ``endpoint``."""
        pserver_prog = self.get_pserver_program(endpoint)
        # Presumably read by the inherited get_startup_program, so it is set
        # before that call (confirm against DistributeTranspiler).
        self.param_grad_ep_mapping = self.param_opt_ep_mapping
        pserver_startup = self.get_startup_program(
            endpoint, pserver_program=pserver_prog)
        return pserver_prog, pserver_startup

    def get_pserver_program(self, endpoint):
        """Build the main program run by the pserver at ``endpoint``.

        For every parameter slice assigned to ``endpoint`` the slice is cloned
        into the new program, a ``<slice>.delta`` variable is created and a
        dedicated block with ``sum(param, delta) -> param`` is added.  A
        ``listen_and_serv`` op referencing those blocks is appended last.

        The built program is also stored as ``self.pserver_program``.
        """
        # step 1. create the pserver program
        pserver_program = Program()
        pserver_program.random_seed = self.origin_program.random_seed
        pserver_program._copy_dist_param_info_from(self.origin_program)

        # step 2. create vars to receive vars at parameter servers.
        assigned_params = self.param_opt_ep_mapping[endpoint]["params"]
        recv_inputs = []
        for v in assigned_params:
            self._clone_var(pserver_program.global_block(), v)

        optimize_block = []
        param_to_block_id = []
        sparse_grad_to_param = []

        # step 3/4. one sub-block with a sum op per assigned parameter slice
        pre_block_idx = pserver_program.num_blocks - 1
        for var in assigned_params:
            per_opt_block = pserver_program._create_block(pre_block_idx)
            optimize_block.append(per_opt_block)
            pserver_block = per_opt_block.program.global_block()
            param = pserver_block.vars[var.name]

            delta_var_name = "%s.delta" % (param.name)
            if var.name in self.sparse_var_splited_list:
                delta_type = core.VarDesc.VarType.SELECTED_ROWS
                sparse_grad_to_param.append(":".join(
                    [delta_var_name, param.name]))
            else:
                delta_type = param.type
            delta_var = pserver_block.create_var(
                name=delta_var_name,
                persistable=False,
                type=delta_type,
                dtype=param.dtype,
                shape=param.shape)

            # param = param + delta
            per_opt_block.append_op(
                type="sum",
                inputs={"X": [param, delta_var]},
                outputs={"Out": param})
            param_to_block_id.append(delta_var_name + ":" + str(
                per_opt_block.idx))

        attrs = {
            "optimize_blocks": optimize_block,
            "endpoint": endpoint,
            "Fanin": self.trainer_num,
            "sync_mode": self.sync_mode,
            "grad_to_block_id": param_to_block_id,
            "sparse_grad_to_param": sparse_grad_to_param
        }

        # step 5. append the listen_and_serv op
        pserver_program.global_block().append_op(
            type="listen_and_serv",
            inputs={'X': recv_inputs},
            outputs={},
            attrs=attrs)

        pserver_program._sync_with_cpp()
        # save pserver program to generate pserver side startup relatively.
        self.pserver_program = pserver_program
        return pserver_program

    def _init_splited_vars(self):
        """Slice parameters and create the delta variables and bookkeeping.

        Fills ``self.sparse_var_list``, ``self.param_var_mapping``,
        ``self.param_opt_ep_mapping``, ``self.vars_info``,
        ``self.delta_vars_list``, ``self.split_to_origin_mapping``,
        ``self.sparse_var_splited_list`` and registers every slice in
        ``self.vars_overview``.
        """
        param_list = []
        seen_names = set()

        # step 1. create param_list (deduplicated by name) and record which
        # parameters have a SELECTED_ROWS gradient.
        for p, g in self.params_grads:
            # Exact-type / equality test is preserved as written so that the
            # set of skipped parameters does not change.
            if type(p) == Parameter and p.trainable == False:
                continue
            if p.name not in seen_names:
                param_list.append(p)
                seen_names.add(p.name)
            if g.name not in seen_names:
                seen_names.add(g.name)
                if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                    self.sparse_var_list.append(p.name)

        # step 2. Slice vars into numbers of piece with block_size
        # when we slice var up into blocks, we will slice the var according to
        # pserver services' count. A pserver may have two or more listening ports.
        param_blocks = slice_variable(param_list,
                                      len(self.pserver_endpoints),
                                      self.config.min_block_size)

        # step 3. Create split params from the split blocks
        # origin_param_name -> [splited_param_vars]
        # TODO: update _create_vars_from_blocklist
        self.param_var_mapping = self._create_vars_from_blocklist(
            self.origin_program, param_blocks)

        # step 4. Create mapping of endpoint -> split var to create pserver
        # side program
        self.param_opt_ep_mapping = collections.OrderedDict()
        for ep in self.pserver_endpoints:
            self.param_opt_ep_mapping[ep] = {"params": []}

        # step 5. Create delta var of Geo-Sgd & record vars information
        origin_block = self.origin_program.global_block()
        for origin_name, splited_vars in self.param_var_mapping.items():
            origin_var = origin_block.var(origin_name)
            is_sparse = origin_name in self.sparse_var_list
            is_split = len(splited_vars) != 1

            # Key insertion order is kept: var_names, sections, epmap,
            # is_sparse.
            var_info = collections.OrderedDict()
            self.vars_info[origin_name] = var_info
            var_info["var_names"] = []
            var_info["sections"] = [
                str(i) for i in self._get_splited_var_sections(splited_vars)
            ]
            var_info["epmap"] = []
            # TODO: add var shape (may not be needed, the recv scope has it)
            var_info["is_sparse"] = ["True" if is_sparse else "False"]

            if is_sparse:
                delta_type = core.VarDesc.VarType.SELECTED_ROWS
            else:
                delta_type = origin_var.type

            delta_var = origin_block.create_var(
                name=".".join([origin_name, "delta"]),
                persistable=False,
                dtype=origin_var.dtype,
                type=delta_type,
                shape=origin_var.shape)
            self.delta_vars_list.append(delta_var)

            for splited_var in splited_vars:
                is_slice, block_id, offset = self._get_slice_var_info(
                    splited_var)
                self.vars_overview.add_distributed_var(
                    origin_var=origin_var,
                    slice_var=splited_var,
                    block_id=block_id,
                    offset=offset,
                    is_slice=is_slice,
                    vtype="Param")
                self.split_to_origin_mapping[splited_var.name] = origin_name
                if is_sparse:
                    self.sparse_var_splited_list.append(splited_var.name)
                var_info["var_names"].append(splited_var.name)
                # every slice of a split parameter gets its own delta var
                if is_split:
                    origin_block.create_var(
                        name=".".join([splited_var.name, "delta"]),
                        persistable=False,
                        dtype=splited_var.dtype,
                        type=delta_type,
                        shape=splited_var.shape)