Paddle/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.fluid import core


class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
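        """Return True when ``op`` is a non-optimizer cast that converts an
        FP32 parameter listed in ``params`` into its FP16 copy."""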
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
                output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
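        """Return True when ``op`` is an optimizer-side cast from FP16 back to
        FP32 (used for the FP16 grads that feed the optimizer)."""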
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
                output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
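        """Remove every FP32->FP16 parameter cast inside the given segment and
        return the resulting change in the block's op count (always <= 0)."""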
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
        """
        1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
        2. revise the amp infinite-gradient check for sharding
        """
        # remove the optimizer-side fp16 -> fp32 grad casts (and their output
        # vars) for params that are not kept by this shard
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError(
                    "Output 'X' of cast_op must be a grad of "
                    "model param, but {} is not a grad".format(output_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)
        block._sync_with_cpp()
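
        # Rewrite the AMP infinite-value check so that each rank only inspects
        # the grads it owns: point update_loss_scaling at a per-shard
        # FoundInfinite variable (renamed with the "@sharding" suffix) and
        # shrink the X/Out lists of check_finite_and_unscale and
        # update_loss_scaling to the grads whose params live on this shard.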
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
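
        # Aggregate the per-rank FoundInfinite flags across the sharding
        # group: cast the local flag to int32 (inf_var_fp32 holds int32
        # despite its name), allreduce it with max over ring_id (with the
        # calc/comm sync ops inserted around it), then cast the result back
        # into the "@sharding" variable consumed by update_loss_scaling above.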
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': ring_id,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
                                          ring_id, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()