You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py

2197 lines
92 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
from paddle.distributed import fleet
from functools import reduce
# Registry mapping a Paddle op type to the *name* of the parser class that
# translates it.  The values are plain strings; AscendParserFactory.create_parse
# looks a class name up in this module's globals().
registerd_op = {
    # ---- forward ops ----
    "elementwise_add": "AddParser",
    "matmul": "MatMulParser",
    "mul": "MulParser",
    "relu": "ReluParser",
    "softmax_with_cross_entropy": "SoftmaxWithCrossEntropyParser",
    "shape": "ShapeParser",
    "fill_constant": "FillConstantParser",
    "reduce_sum": "ReduceSumParser",
    "elementwise_mul": "DotMulParser",
    "elementwise_div": "DotDivParser",
    "elementwise_pow": "DotPowParser",
    "elementwise_max": "MaxParser",
    "elementwise_min": "MinParser",
    "elementwise_sub": "DotSubParser",
    "pow": "PowParser",
    "gelu": "GeluParser",
    "sqrt": "SqrtParser",
    "log": "LogParser",
    "sum": "SumParser",
    "logical_not": "LogicalNotParser",
    "gather": "GatherParser",
    "scatter": "ScatterParser",
    "cast": "CastParser",
    "tanh": "TanhParser",
    "stack": "StackParser",
    "square": "SquareParser",
    "unsqueeze2": "UnSqueezeParser",
    "assign": "AssignParser",
    "softmax": "SoftMaxParser",
    "reshape2": "ReshapeParser",
    "transpose2": "TransposeParser",
    "layer_norm": "LayerNormParser",
    "less_than": "LessParser",
    "mean": "MeanParser",
    "scale": "ScaleParser",
    "slice": "SliceParser",
    "top_k": "TopkParser",
    "accuracy": "AccuracyParser",
    # "increment": "IncrementParser",  (disabled)
    "lookup_table": "LookupTableParser",
    "truncated_gaussian_random": "TruncatedNormalParser",
    "c_allgather": "AllGatherParser",
    "c_allreduce_sum": "AllReduceSumParser",
    "c_allreduce_max": "AllReduceMaxParser",
    "c_broadcast": "BroadcastParser",
    "c_reduce_scatter": "ReduceScatterParser",
    "c_send": "SendParser",
    "c_receive": "ReceiveParser",
    "uniform_random": "UniformRandomParser",
    "range": "RangeParser",
    "equal": "EqualParser",
    "expand": "ExpandParser",
    "squeeze2": "SqueezeParser",

    # ---- backward (gradient) ops ----
    "matmul_grad": "MatMulGradParser",
    "mul_grad": "MulGradParser",
    "relu_grad": "ReluGradParser",
    "reduce_sum_grad": "ReduceSumGradParser",
    "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser",
    "tanh_grad": "TanhGradParser",
    "log_grad": "LogGradParser",
    "pow_grad": "PowGradParser",
    "sqrt_grad": "SqrtGradParser",
    "gelu_grad": "GeluGradParser",
    "mean_grad": "MeanGradParser",
    "lookup_table_grad": "LookUpTableGradParser",
    "elementwise_mul_grad": "DotMulGradParser",
    "elementwise_add_grad": "DotAddGradParser",
    "elementwise_div_grad": "DotDivGradParser",
    "softmax_grad": "SoftmaxGradParser",
    "slice_grad": "SliceGradParser",
    "reshape2_grad": "ReshapeGradParser",
    "gather_grad": "GatherGradParser",
    "transpose2_grad": "TransposeGradParser",
    "layer_norm_grad": "LayerNormGradParser",

    # ---- optimizer ops ----
    "sgd": "SGDParser",
    # "adam": "AdamParser",  (disabled)
}

# Module-wide counters.  Both start at -1 and are incremented before use, so
# the first generated suffix is ".0" (op names) / "geinput.0" (marked inputs).
global_cnt = -1
global_input_cnt = -1
class AscendHelper(object):
    """Lookup tables that translate integer dtype codes into GE / NumPy dtypes.

    The integer codes are the ones the parsers read from op attributes (for
    example ``fill_constant``'s ``dtype``) or hard-code when building tensors.
    """

    def __init__(self):
        # dtype code -> GE data type enum.
        self.dtype2ge_map = {
            0: core.GEDataType.DT_BOOL,
            1: core.GEDataType.DT_INT16,
            2: core.GEDataType.DT_INT32,
            3: core.GEDataType.DT_INT64,
            4: core.GEDataType.DT_FLOAT16,
            5: core.GEDataType.DT_FLOAT,
            6: core.GEDataType.DT_DOUBLE
        }
        # Same dtype codes -> NumPy dtype name.
        self.dtype2np_map = {
            0: "bool",
            1: "int16",
            2: "int32",
            3: "int64",
            4: "float16",
            5: "float32",
            6: "float64"
        }
        # str(<paddle VarType>) -> small integer code; LayerNormParser uses it
        # to choose the ``dst_type`` of its Cast operators.
        self.dtype2paddle_inv_map = {"VarType.FP32": 0, "VarType.FP16": 1}

    def dtype2ge(self, dtype):
        """Return the GE data type for dtype code ``dtype``.

        Raises:
            AssertionError: if ``dtype`` is not a key of ``dtype2ge_map``.
        """
        # The original message had two ``%d`` placeholders for one argument,
        # so a failing check raised TypeError instead of this message.
        assert dtype in self.dtype2ge_map, "dtype[%s] is not supported" % (
            dtype, )
        return self.dtype2ge_map[dtype]

    def dtype2np(self, index):
        """Return the NumPy dtype name for dtype code ``index``.

        Raises:
            AssertionError: if ``index`` is not a key of ``dtype2np_map``.
        """
        # The original message formatted the undefined name ``dtype``
        # (NameError); report the offending ``index`` instead.
        assert index in self.dtype2np_map, "index[%s] is not supported" % (
            index, )
        return self.dtype2np_map[index]
class AscendParserFactory(object):
    """Creates parser instances by class name for one graph / variable map."""

    def __init__(self, graph, var2geop):
        self.graph = graph
        self.var2geop = var2geop

    def create_parse(self, parser_class):
        """Instantiate the parser class called ``parser_class``.

        Args:
            parser_class: name of a class defined in this module's globals.

        Returns:
            ``parser_class(self.graph, self.var2geop)``.

        Raises:
            ValueError: if no global with that name exists.  Errors raised by
                the parser's own constructor propagate unchanged, so they are
                no longer misreported as "does not exist".
        """
        # Only the name lookup is guarded; the previous bare ``except:`` also
        # swallowed constructor failures and KeyboardInterrupt.
        try:
            parser_cls = globals()[parser_class]
        except KeyError:
            raise ValueError("parser class %s does not exist" % parser_class)
        return parser_cls(self.graph, self.var2geop)
class AscendParserBase(object):
    """Base class for parsers that translate one Paddle op into GE operators.

    A subclass sets ``self.parser_name`` to the Paddle op type it handles and
    implements ``_apply()``, which returns ``(geop_list, index_list)``.
    ``apply()`` runs ``_apply()`` and hands the result to ``update_output()``.
    """

    def __init__(self, graph, var2geop):
        self.graph = graph  # graph object that receives operators via add_op()
        self.var2geop = var2geop  # variable name -> GE operator bound to it
        self.op = None  # Paddle op being parsed; set by apply()
        self.ascend_helper = AscendHelper()

    def _get_ge_input(self, input_var_name):
        """Return the GE operator registered under ``input_var_name``.

        Raises:
            AssertionError: if the name has not been registered yet.
        """
        assert input_var_name in self.var2geop, "var %s not created before" % (
            input_var_name)
        return self.var2geop[input_var_name]

    def update_output(self, geop_list, index_list):
        """Bind the op's output variables to GE operators and add the
        operators to the graph.

        Args:
            geop_list: GE operators produced by ``_apply()``; every one of
                them is passed to ``self.graph.add_op``.
            index_list: one list per entry of ``self.op.output_names``.
                ``index_list[k][i]`` is the position in ``geop_list`` of the
                operator bound to the i-th variable of output slot ``k``.
                Slots that hold no variables are skipped.

        Raises:
            AssertionError: if the slot count or a slot's variable count does
                not match ``index_list``.
        """
        output_num = len(self.op.output_names)
        # The counts were passed in swapped order before, so the message
        # reported each number under the wrong label.
        assert output_num == len(
            index_list
        ), "Parser[%s]'s output number[%d] is not equal to parameters number[%d]" % (
            self.parser_name, output_num, len(index_list))
        for output_id in range(output_num):
            arguments = self.op.output(self.op.output_names[output_id])
            if len(arguments) > 0:
                assert len(arguments) == len(
                    index_list[output_id]
                ), "Parser[%s]'s %dth argument number[%d] is not equal to paddle's number[%d]" % (
                    self.parser_name, output_id, len(index_list[output_id]),
                    len(arguments))
                for i in range(len(arguments)):
                    self.var2geop[arguments[i]] = geop_list[index_list[
                        output_id][i]]

        for geop in geop_list:
            self.graph.add_op(geop)

    def apply(self, op):
        """Parse ``op``: check its type, run ``_apply()`` and record outputs.

        Raises:
            AssertionError: if ``op.type`` differs from ``self.parser_name``.
        """
        self.op = op
        assert self.op.type == self.parser_name, "op [%s] != parser_name[%s]" % (
            self.op.type, self.parser_name)
        geop_list, index_list = self._apply()
        self.update_output(geop_list, index_list)

    def _mark_as_input(self, ge_tensor):
        """Register ``ge_tensor`` in ``var2geop`` under a fresh
        ``"geinput.<n>"`` key taken from the module-wide input counter."""
        global global_input_cnt
        global_input_cnt += 1
        self.var2geop["geinput." + str(global_input_cnt)] = ge_tensor

    def _accumulated_op_id(self):
        """Return the next ``".<n>"`` suffix from the module-wide op counter;
        parsers append it to operator names."""
        global global_cnt
        global_cnt += 1
        name = "." + str(global_cnt)
        return name

    def _create_ge_tensor(self, shape, dtype, value):
        """Build a raw GETensor of ``shape`` from ``value``.

        Args:
            shape: list of dimension sizes.
            dtype: integer dtype code understood by ``AscendHelper``.
            value: scalar or sequence; it is multiplied with ``np.ones(shape)``
                so it must broadcast against ``shape``.

        Returns:
            The GETensor (not wrapped in an operator).
        """
        tensor_desc = core.GETensorDesc(
            core.GEShape(shape), core.GEFormat.FORMAT_ND,
            self.ascend_helper.dtype2ge(dtype))
        tensor = core.GETensor(tensor_desc)

        data = (value * np.ones((
            shape))).reshape(shape).astype(self.ascend_helper.dtype2np(dtype))
        # The tensor payload is handed over as a flat uint8 view of the bytes.
        buf = data.tobytes()
        data_8 = np.frombuffer(buf, dtype=np.uint8)
        tensor.set_data(data_8)
        return tensor

    def _get_ge_tensor(self, shape, dtype, value_list):
        """Build a GETensor from ``value_list`` (reshaped to ``shape``) and
        return a ``Const`` operator that carries it as its ``value``."""
        tensor_desc = core.GETensorDesc(
            core.GEShape(shape), core.GEFormat.FORMAT_ND,
            self.ascend_helper.dtype2ge(dtype))
        tensor = core.GETensor(tensor_desc)

        data = np.array(value_list).reshape(shape).astype(
            self.ascend_helper.dtype2np(dtype))
        buf = data.tobytes()
        data_8 = np.frombuffer(buf, dtype=np.uint8)
        tensor.set_data(data_8)

        tensor_const = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor)

        return tensor_const

    def _get_variable(self, shape, dtype, tensor):
        """Create a ``Variable`` of ``shape`` and an ``Assign`` of ``tensor``
        into it.

        Args:
            shape: list of dimension sizes for the variable's output desc.
            dtype: ``"int32"`` or ``"float32"``.
            tensor: GE operator used as the assigned value.

        Returns:
            The ``Assign`` operator.

        Raises:
            ValueError: if ``dtype`` is neither ``"int32"`` nor ``"float32"``
                (previously this surfaced as an UnboundLocalError).
        """
        if dtype == "int32":
            ge_dtype = core.GEDataType.DT_INT32
        elif dtype == "float32":
            ge_dtype = core.GEDataType.DT_FLOAT
        else:
            raise ValueError(
                "unsupported dtype %s, expected 'int32' or 'float32'" %
                (dtype, ))
        var = core.GEOperatorFactory.create_operator(
            "variable" + self._accumulated_op_id(), "Variable")
        var.update_output_desc("y",
                               core.GETensorDesc(
                                   core.GEShape(shape), core.GEFormat.FORMAT_ND,
                                   ge_dtype))
        assign = core.GEOperatorFactory.create_operator(
            "assign" + self._accumulated_op_id(), "Assign").set_input(
                "value", tensor).set_input("ref", var)

        return assign

    def _create_shape_tensor(self):
        """Return an int32 GETensor of shape ``[2]`` holding ``[64, 1]``."""
        tensor_desc = core.GETensorDesc(
            core.GEShape([2]), core.GEFormat.FORMAT_ND,
            core.GEDataType.DT_INT32)
        tensor = core.GETensor(tensor_desc)

        data = np.ones((2)).astype("int32").reshape([2])
        data[0] = 64
        buf = data.tobytes()
        data_8 = np.frombuffer(buf, dtype=np.uint8)
        tensor.set_data(data_8)
        return tensor

    def _get_GEtensor_shape(self, tensor):
        """Return a ``Cast`` (``dst_type`` 0) applied to ``Shape(tensor)``."""
        tensor_shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape").set_input("x", tensor)
        tensor_shape = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast").set_input(
                "x", tensor_shape).set_attr_int32("dst_type", 0)
        return tensor_shape
### elementwise_op
class AddParser(AscendParserBase):
    """Parser for Paddle ``elementwise_add``; emits one GE ``Add`` operator."""

    def __init__(self, graph, var2geop):
        super(AddParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_add"

    def _apply(self):
        """Build ``Add`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([add_op], [[0]])``: the operator list and the index list that
            binds the op's output to entry 0.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        op_name = "add" + self._accumulated_op_id()
        add_op = core.GEOperatorFactory.create_operator(op_name, "Add")
        add_op = add_op.set_input("x1", lhs).set_input("x2", rhs)
        return [add_op], [[0]]
class DotSubParser(AscendParserBase):
    """Parser for Paddle ``elementwise_sub``; emits one GE ``Sub`` operator."""

    def __init__(self, graph, var2geop):
        super(DotSubParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_sub"

    def _apply(self):
        """Build ``Sub`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([sub_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        minuend = self._get_ge_input(arg_names[0])
        subtrahend = self._get_ge_input(arg_names[1])
        op_name = "sub" + self._accumulated_op_id()
        sub_op = core.GEOperatorFactory.create_operator(op_name, "Sub")
        sub_op = sub_op.set_input("x1", minuend).set_input("x2", subtrahend)
        return [sub_op], [[0]]
class DotMulParser(AscendParserBase):
    """Parser for Paddle ``elementwise_mul``; emits one GE ``Mul`` operator."""

    def __init__(self, graph, var2geop):
        super(DotMulParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_mul"

    def _apply(self):
        """Build ``Mul`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([mul_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        op_name = "dotmul" + self._accumulated_op_id()
        mul_op = core.GEOperatorFactory.create_operator(op_name, "Mul")
        mul_op = mul_op.set_input("x1", lhs).set_input("x2", rhs)
        return [mul_op], [[0]]
class DotDivParser(AscendParserBase):
    """Parser for Paddle ``elementwise_div``; emits one GE ``Div`` operator."""

    def __init__(self, graph, var2geop):
        super(DotDivParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_div"

    def _apply(self):
        """Build ``Div`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([div_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        numerator = self._get_ge_input(arg_names[0])
        denominator = self._get_ge_input(arg_names[1])
        op_name = "dotdiv" + self._accumulated_op_id()
        div_op = core.GEOperatorFactory.create_operator(op_name, "Div")
        div_op = div_op.set_input("x1", numerator).set_input("x2", denominator)
        return [div_op], [[0]]
class DotPowParser(AscendParserBase):
    """Parser for Paddle ``elementwise_pow``; emits one GE ``Pow`` operator."""

    def __init__(self, graph, var2geop):
        super(DotPowParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_pow"

    def _apply(self):
        """Build ``Pow`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([pow_op], [[0]])``: the operator list and the index list that
            binds the op's output to entry 0.
        """
        x = self._get_ge_input(self.op.input_arg_names[0])
        y = self._get_ge_input(self.op.input_arg_names[1])
        # The first operand used to be passed as the undefined name ``x1``,
        # which raised NameError on every call.
        pow_op = core.GEOperatorFactory.create_operator(
            "dotpow" + self._accumulated_op_id(),
            "Pow").set_input("x1", x).set_input("x2", y)
        return [pow_op], [[0]]
class LessParser(AscendParserBase):
    """Parser for Paddle ``less_than``; emits one GE ``Less`` operator."""

    def __init__(self, graph, var2geop):
        super(LessParser, self).__init__(graph, var2geop)
        self.parser_name = "less_than"

    def _apply(self):
        """Build ``Less`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([less_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        op_name = "less_than" + self._accumulated_op_id()
        less_op = core.GEOperatorFactory.create_operator(op_name, "Less")
        less_op = less_op.set_input("x1", lhs).set_input("x2", rhs)
        return [less_op], [[0]]
class MaxParser(AscendParserBase):
    """Parser for Paddle ``elementwise_max``; emits one GE ``Maximum`` op."""

    def __init__(self, graph, var2geop):
        super(MaxParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_max"

    def _apply(self):
        """Build ``Maximum`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([maximum_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        op_name = "max" + self._accumulated_op_id()
        maximum_op = core.GEOperatorFactory.create_operator(op_name, "Maximum")
        maximum_op = maximum_op.set_input("x1", lhs).set_input("x2", rhs)
        return [maximum_op], [[0]]
class MinParser(AscendParserBase):
    """Parser for Paddle ``elementwise_min``; emits one GE ``Minimum`` op."""

    def __init__(self, graph, var2geop):
        super(MinParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_min"

    def _apply(self):
        """Build ``Minimum`` with the op's first two input arguments as x1/x2.

        Returns:
            ``([minimum_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        op_name = "min" + self._accumulated_op_id()
        minimum_op = core.GEOperatorFactory.create_operator(op_name, "Minimum")
        minimum_op = minimum_op.set_input("x1", lhs).set_input("x2", rhs)
        return [minimum_op], [[0]]
## calculation ops
class LogParser(AscendParserBase):
    """Parser for Paddle ``log``; emits one GE ``Log`` operator."""

    def __init__(self, graph, var2geop):
        super(LogParser, self).__init__(graph, var2geop)
        self.parser_name = "log"

    def _apply(self):
        """Build ``Log`` on the op's first input argument.

        Returns:
            ``([log_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "log" + self._accumulated_op_id()
        log_op = core.GEOperatorFactory.create_operator(op_name, "Log")
        log_op = log_op.set_input("x", source)
        return [log_op], [[0]]
class SqrtParser(AscendParserBase):
    """Parser for Paddle ``sqrt``; emits one GE ``Sqrt`` operator."""

    def __init__(self, graph, var2geop):
        super(SqrtParser, self).__init__(graph, var2geop)
        self.parser_name = "sqrt"

    def _apply(self):
        """Build ``Sqrt`` on the op's first input argument.

        Returns:
            ``([sqrt_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "sqrt" + self._accumulated_op_id()
        sqrt_op = core.GEOperatorFactory.create_operator(op_name, "Sqrt")
        sqrt_op = sqrt_op.set_input("x", source)
        return [sqrt_op], [[0]]
class PowParser(AscendParserBase):
    """Parser for Paddle ``pow``; emits one GE ``Power`` operator."""

    def __init__(self, graph, var2geop):
        super(PowParser, self).__init__(graph, var2geop)
        self.parser_name = "pow"

    def _apply(self):
        """Build ``Power`` with ``power`` set to the op's ``factor`` attribute,
        ``scale`` fixed to 1.0 and ``shift`` fixed to 0.0.

        Returns:
            ``([power_op], [[0]])``.
        """
        base = self._get_ge_input(self.op.input_arg_names[0])
        exponent = self.op.attr("factor")
        op_name = "pow" + self._accumulated_op_id()
        power_op = core.GEOperatorFactory.create_operator(op_name, "Power")
        power_op = power_op.set_input("x", base)
        power_op = power_op.set_attr_float("power", exponent)
        power_op = power_op.set_attr_float("scale", 1.0)
        power_op = power_op.set_attr_float("shift", 0.0)
        return [power_op], [[0]]
class SquareParser(AscendParserBase):
    """Parser for Paddle ``square``; emits one GE ``Square`` operator."""

    def __init__(self, graph, var2geop):
        super(SquareParser, self).__init__(graph, var2geop)
        self.parser_name = "square"

    def _apply(self):
        """Build ``Square`` on the op's first input argument.

        Returns:
            ``([square_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "square" + self._accumulated_op_id()
        square_op = core.GEOperatorFactory.create_operator(op_name, "Square")
        square_op = square_op.set_input("x", source)
        return [square_op], [[0]]
class SumParser(AscendParserBase):
    """Parser for Paddle ``sum``; folds the inputs with chained GE ``Add`` ops."""

    def __init__(self, graph, var2geop):
        super(SumParser, self).__init__(graph, var2geop)
        self.parser_name = "sum"

    def _apply(self):
        """Add the first two inputs, then add each further input to the
        running result.

        Returns:
            ``([last_add_op], [[0]])``: only the final ``Add`` is returned.

        Raises:
            AssertionError: if the op has fewer than two input arguments.
        """
        arg_names = self.op.input_arg_names
        assert len(
            arg_names) >= 2, "the size of input list must large or equal 2"

        first = self._get_ge_input(arg_names[0])
        second = self._get_ge_input(arg_names[1])
        running = core.GEOperatorFactory.create_operator(
            "sum" + self._accumulated_op_id(), "Add")
        running = running.set_input("x1", first).set_input("x2", second)

        # Each remaining input is added to the accumulated result.
        for extra_name in arg_names[2:]:
            term = self._get_ge_input(extra_name)
            next_add = core.GEOperatorFactory.create_operator(
                "sum" + self._accumulated_op_id(), "Add")
            running = next_add.set_input("x1", running).set_input("x2", term)

        return [running], [[0]]
class LogicalNotParser(AscendParserBase):
    """Parser for Paddle ``logical_not``; emits one GE ``LogicalNot`` op."""

    def __init__(self, graph, var2geop):
        super(LogicalNotParser, self).__init__(graph, var2geop)
        self.parser_name = "logical_not"

    def _apply(self):
        """Build ``LogicalNot`` on the op's first input argument.

        Returns:
            ``([not_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "logical_not" + self._accumulated_op_id()
        not_op = core.GEOperatorFactory.create_operator(op_name, "LogicalNot")
        not_op = not_op.set_input("x", source)
        return [not_op], [[0]]
class MeanParser(AscendParserBase):
    """Parser for Paddle ``mean``; emits one GE ``ReduceMeanD`` operator."""

    def __init__(self, graph, var2geop):
        super(MeanParser, self).__init__(graph, var2geop)
        self.parser_name = "mean"

    def _apply(self):
        """Build ``ReduceMeanD`` with ``keep_dims`` False and an empty
        ``axes`` list.

        Returns:
            ``([mean_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "mean" + self._accumulated_op_id()
        mean_op = core.GEOperatorFactory.create_operator(op_name, "ReduceMeanD")
        mean_op = mean_op.set_input("x", source)
        mean_op = mean_op.set_attr_bool("keep_dims", False)
        mean_op = mean_op.set_attr_vec_int32("axes", [])
        return [mean_op], [[0]]
class ReduceSumParser(AscendParserBase):
    """Parser for Paddle ``reduce_sum``; emits one GE ``ReduceSumD`` op."""

    def __init__(self, graph, var2geop):
        super(ReduceSumParser, self).__init__(graph, var2geop)
        self.parser_name = "reduce_sum"

    def _apply(self):
        """Build ``ReduceSumD`` over the op's ``dim`` axes, or over every axis
        of the input variable when ``reduce_all`` is set.

        Returns:
            ``([reduce_op], [[0]])``.
        """
        input_name = self.op.input_arg_names[0]
        source = self._get_ge_input(input_name)
        reduce_axes = self.op.attr("dim")
        keep = self.op.attr("keep_dim")
        reduce_everything = self.op.attr("reduce_all")
        input_shape = self.op.block.var(self.op.input_arg_names[0]).shape

        if reduce_everything:
            # Replace the requested axes with all axes of the input.
            reduce_axes = list(range(len(input_shape)))

        op_name = "reduce_sum" + self._accumulated_op_id()
        reduce_op = core.GEOperatorFactory.create_operator(op_name,
                                                           "ReduceSumD")
        reduce_op = reduce_op.set_input("x", source, 0)
        reduce_op = reduce_op.set_attr_vec_int32("axes", reduce_axes)
        reduce_op = reduce_op.set_attr_bool("keep_dims", keep)
        return [reduce_op], [[0]]
#class IncrementParser(AscendParserBase):
# def __init__(self, graph, var2geop):
# super(IncrementParser, self).__init__(graph, var2geop)
# self.parser_name = "increment"
#
# def _apply(self):
# x = self._get_ge_input(self.op.input_arg_names[0])
# step = self.op.attr("step") #self._get_ge_input(self.op.input_arg_names[1])
# print("step: ", step)
#
# increment = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", step) #set_input("x2", bias)
#
# return [increment]
## matrix calculation ops
class MatMulParser(AscendParserBase):
    """Parser for Paddle ``matmul``; emits GE ``MatMul`` or ``BatchMatMul``."""

    def __init__(self, graph, var2geop):
        super(MatMulParser, self).__init__(graph, var2geop)
        self.parser_name = "matmul"

    def _apply(self):
        """Choose the GE operator from the rank of the first input variable.

        Rank 2 produces ``MatMul`` (``transpose_x1`` / ``transpose_x2``);
        rank above 2 produces ``BatchMatMul`` (``adj_x1`` / ``adj_x2``).

        Returns:
            ``([matmul_op], [[0]])``.

        Raises:
            AssertionError: if the first input has rank below 2.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        transpose_lhs = self.op.attr("transpose_X")
        transpose_rhs = self.op.attr("transpose_Y")

        lhs_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        # The second shape is looked up but not used for dispatch.
        rhs_shape = self.op.block.var(self.op.input_arg_names[1]).shape

        lhs_rank = len(lhs_shape)
        if lhs_rank == 2:
            matmul = core.GEOperatorFactory.create_operator(
                "matmul" + self._accumulated_op_id(), "MatMul")
            matmul = matmul.set_input("x1", lhs).set_input("x2", rhs)
            matmul = matmul.set_attr_bool("transpose_x1", transpose_lhs)
            matmul = matmul.set_attr_bool("transpose_x2", transpose_rhs)
        elif lhs_rank > 2:
            matmul = core.GEOperatorFactory.create_operator(
                "matmul" + self._accumulated_op_id(), "BatchMatMul")
            matmul = matmul.set_input("x1", lhs).set_input("x2", rhs)
            matmul = matmul.set_attr_bool("adj_x1", transpose_lhs)
            matmul = matmul.set_attr_bool("adj_x2", transpose_rhs)
        else:
            assert False, "not support"
        return [matmul], [[0]]
class MulParser(AscendParserBase):
    """Parser for Paddle ``mul``; lowers it to GE ``MatMul`` plus reshaping
    operators, depending on the input ranks and the ``*_num_col_dims``
    attributes."""

    def __init__(self, graph, var2geop):
        super(MulParser, self).__init__(graph, var2geop)
        self.parser_name = "mul"

    def _apply(self):
        """Build the operator chain for the supported rank combinations.

        Supported cases (anything else fails with ``assert False``):
          * both ``*_num_col_dims`` are 1, x rank 2, y rank 2: plain MatMul.
          * both ``*_num_col_dims`` are 1, x rank 3, y rank 2: Flatten x, then
            MatMul.
          * otherwise x rank 3, y rank 2 with ``x_num_col_dims == 2``:
            FlattenV2 -> MatMul -> TransposeD -> Reshape -> TransposeD.

        Returns:
            ``([matmul], [[0]])``: only the last operator of the chain.
        """
        x = self._get_ge_input(self.op.input_arg_names[0])
        y = self._get_ge_input(self.op.input_arg_names[1])
        x_num_col_dims = self.op.attr("x_num_col_dims")
        y_num_col_dims = self.op.attr("y_num_col_dims")
        shape_x1 = self.op.block.var(self.op.input_arg_names[0]).shape
        shape_x2 = self.op.block.var(self.op.input_arg_names[1]).shape
        if x_num_col_dims == 1 and y_num_col_dims == 1:
            if len(shape_x1) == 2 and len(shape_x2) == 2:
                matmul = core.GEOperatorFactory.create_operator(
                    "mul" + self._accumulated_op_id(),
                    "MatMul").set_input("x1", x).set_input("x2", y)
            elif len(shape_x1) == 3 and len(shape_x2) == 2:
                flatten_x1 = core.GEOperatorFactory.create_operator(
                    "flatten" + self._accumulated_op_id(),
                    "Flatten").set_input("x", x)
                matmul = core.GEOperatorFactory.create_operator(
                    "mul" + self._accumulated_op_id(), "MatMul").set_input(
                        "x1", flatten_x1, 0).set_input("x2", y, 0)
            else:
                assert False, "not support"
        else:
            if len(shape_x1) == 3 and len(shape_x2) == 2:
                assert x_num_col_dims == 2, "only support 2"
                # Merge the two leading axes of x (axis 0 through end_axis 1)
                # so a 2-D MatMul can be applied.
                flatten_x1 = core.GEOperatorFactory.create_operator(
                    "flatten" + self._accumulated_op_id(),
                    "FlattenV2").set_input("x", x).set_attr_int32(
                        "axis", 0).set_attr_int32("end_axis", 1)
                matmul_m = core.GEOperatorFactory.create_operator(
                    "mul" + self._accumulated_op_id(), "MatMul").set_input(
                        "x1", flatten_x1, 0).set_input("x2", y, 0)
                matmul_transpose = core.GEOperatorFactory.create_operator(
                    "transpose" + self._accumulated_op_id(),
                    "TransposeD").set_input(
                        "x", matmul_m).set_attr_vec_int32("perm", [1, 0])
                # int32 (dtype code 2) shape constant:
                # [y.shape[1], x.shape[0], x.shape[1]].
                tensor = self._create_ge_tensor(
                    [3], 2, [shape_x2[1], shape_x1[0], shape_x1[1]])
                const_shape = core.GEOperatorFactory.create_operator(
                    "shape" + self._accumulated_op_id(),
                    "Const").set_attr_tensor("value", tensor)
                reshape_matmul = core.GEOperatorFactory.create_operator(
                    "reshape" + self._accumulated_op_id(), "Reshape").set_input(
                        "x", matmul_transpose).set_input(
                            "shape", const_shape).set_attr_int32("axis", 0)
                # Permute [1, 2, 0] moves the y.shape[1] axis to the end.
                matmul = core.GEOperatorFactory.create_operator(
                    "transpose" + self._accumulated_op_id(),
                    "TransposeD").set_input(
                        "x",
                        reshape_matmul).set_attr_vec_int32("perm", [1, 2, 0])
            else:
                assert False, "not support"
        return [matmul], [[0]]
class LayerNormParser(AscendParserBase):
    """Parser for Paddle ``layer_norm``; emits GE ``LayerNorm`` plus casts."""

    def __init__(self, graph, var2geop):
        super(LayerNormParser, self).__init__(graph, var2geop)
        self.parser_name = "layer_norm"

    def _apply(self):
        """Broadcast scale/bias to the input's shape, run ``LayerNorm`` and
        cast its three outputs.

        Input arguments are read as bias (index 0), scale (index 1) and
        x (index 2).

        Returns:
            ``([y, mean, variance], [[1], [2], [0]])``: the three Cast
            operators and the index list mapping the op's output slots onto
            them.
        """
        arg_names = self.op.input_arg_names
        data = self._get_ge_input(arg_names[2])
        gamma = self._get_ge_input(arg_names[1])
        beta = self._get_ge_input(arg_names[0])
        eps = self.op.attr("epsilon")
        norm_axis = self.op.attr("begin_norm_axis")
        data_dtype = self.op.block.var(self.op.input_arg_names[2]).dtype

        factory = core.GEOperatorFactory

        data_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape")
        data_shape = data_shape.set_input("x", data)

        gamma_full = factory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo")
        gamma_full = gamma_full.set_input("x", gamma).set_input("shape",
                                                                data_shape)

        beta_full = factory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo")
        beta_full = beta_full.set_input("x", beta).set_input("shape",
                                                             data_shape)

        layer_norm = factory.create_operator(
            "layer_norm" + self._accumulated_op_id(), "LayerNorm")
        layer_norm = layer_norm.set_input("x", data)
        layer_norm = layer_norm.set_input("gamma", gamma_full)
        layer_norm = layer_norm.set_input("beta", beta_full)
        layer_norm = layer_norm.set_attr_int32("begin_norm_axis", norm_axis)
        layer_norm = layer_norm.set_attr_int32("begin_params_axis", norm_axis)
        layer_norm = layer_norm.set_attr_float("epsilon", eps)

        # dst_type is 0 when the helper maps the input dtype to 0, else 1.
        if self.ascend_helper.dtype2paddle_inv_map[str(data_dtype)] == 0:
            cast_dtype = 0
        else:
            cast_dtype = 1

        # One Cast per LayerNorm output (output indices 0, 1 and 2).
        y = factory.create_operator("cast" + self._accumulated_op_id(), "Cast")
        y = y.set_input("x", layer_norm, 0).set_attr_int32("dst_type",
                                                           cast_dtype)
        mean = factory.create_operator("cast" + self._accumulated_op_id(),
                                       "Cast")
        mean = mean.set_input("x", layer_norm, 1).set_attr_int32("dst_type",
                                                                 cast_dtype)
        variance = factory.create_operator("cast" + self._accumulated_op_id(),
                                           "Cast")
        variance = variance.set_input("x", layer_norm, 2).set_attr_int32(
            "dst_type", cast_dtype)

        return [y, mean, variance], [[1], [2], [0]]
## activate function
class ReluParser(AscendParserBase):
    """Parser for Paddle ``relu``; emits one GE ``Relu`` operator."""

    def __init__(self, graph, var2geop):
        super(ReluParser, self).__init__(graph, var2geop)
        self.parser_name = "relu"

    def _apply(self):
        """Build ``Relu`` on the op's first input argument.

        Returns:
            ``([relu_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "relu" + self._accumulated_op_id()
        relu_op = core.GEOperatorFactory.create_operator(op_name, "Relu")
        relu_op = relu_op.set_input("x", source)
        return [relu_op], [[0]]
class GeluParser(AscendParserBase):
    """Parser for Paddle ``gelu``; emits one GE ``Gelu`` operator."""

    def __init__(self, graph, var2geop):
        super(GeluParser, self).__init__(graph, var2geop)
        self.parser_name = "gelu"

    def _apply(self):
        """Build ``Gelu`` on the op's first input argument.

        Returns:
            ``([gelu_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "gelu" + self._accumulated_op_id()
        gelu_op = core.GEOperatorFactory.create_operator(op_name, "Gelu")
        gelu_op = gelu_op.set_input("x", source)
        return [gelu_op], [[0]]
class TanhParser(AscendParserBase):
    """Parser for Paddle ``tanh``; emits one GE ``Tanh`` operator."""

    def __init__(self, graph, var2geop):
        super(TanhParser, self).__init__(graph, var2geop)
        self.parser_name = "tanh"

    def _apply(self):
        """Build ``Tanh`` on the op's first input argument.

        Returns:
            ``([tanh_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "tanh" + self._accumulated_op_id()
        tanh_op = core.GEOperatorFactory.create_operator(op_name, "Tanh")
        tanh_op = tanh_op.set_input("x", source)
        return [tanh_op], [[0]]
## loss function
class SoftmaxWithCrossEntropyParser(AscendParserBase):
    """Parser for Paddle ``softmax_with_cross_entropy``."""

    def __init__(self, graph, var2geop):
        super(SoftmaxWithCrossEntropyParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_with_cross_entropy"

    def _apply(self):
        """Emit ``SoftmaxV2`` on the logits and a one-hot based
        ``SoftmaxCrossEntropyWithLogits`` loss.

        Input arguments are read as label (index 0) and logits (index 1); the
        one-hot depth is ``shape[1]`` of the logits variable.

        Returns:
            ``([label, softmax, loss_expand], [[2], [1]])``: the first output
            slot is bound to the unsqueezed loss and the second to the
            softmax.
        """
        arg_names = self.op.input_arg_names
        label = self._get_ge_input(arg_names[0])
        logits = self._get_ge_input(arg_names[1])
        cls_num = self.op.block.var(self.op.input_arg_names[1]).shape[1]

        factory = core.GEOperatorFactory

        softmax = factory.create_operator(
            "softmax" + self._accumulated_op_id(), "SoftmaxV2")
        softmax = softmax.set_input("x", logits)

        # Cast the label with dst_type 3 before one-hot encoding.
        label_cast = factory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast")
        label = label_cast.set_input("x", label).set_attr_int32("dst_type", 3)

        # Scalar constants (dtype code 5) for the one-hot on/off values.
        on_tensor = self._create_ge_tensor([1], 5, 1)
        on = factory.create_operator("const" + self._accumulated_op_id(),
                                     "Const")
        on = on.set_attr_tensor("value", on_tensor)
        off_tensor = self._create_ge_tensor([1], 5, 0)
        off = factory.create_operator("const" + self._accumulated_op_id(),
                                      "Const")
        off = off.set_attr_tensor("value", off_tensor)
        self._mark_as_input(on)
        self._mark_as_input(off)

        onehot = factory.create_operator(
            "onehot" + self._accumulated_op_id(), "OneHotD")
        onehot = onehot.set_input("x", label)
        onehot = onehot.set_input("on_value", on)
        onehot = onehot.set_input("off_value", off)
        onehot = onehot.set_attr_int32("depth", cls_num)

        # The Squeeze operator is named with a "mul" prefix.
        squeeze = factory.create_operator("mul" + self._accumulated_op_id(),
                                          "Squeeze")
        squeeze = squeeze.set_input("x", onehot)

        loss_all = factory.create_operator(
            "loss" + self._accumulated_op_id(),
            "SoftmaxCrossEntropyWithLogits")
        loss_all = loss_all.set_input("features", logits)
        loss_all = loss_all.set_input("labels", squeeze)

        # Take output 0 of the loss operator, cast it and add an axis at 1.
        loss = factory.create_operator("cast" + self._accumulated_op_id(),
                                       "Cast")
        loss = loss.set_input("x", loss_all, 0).set_attr_int32("dst_type", 0)
        loss_expand = factory.create_operator(
            "unsqueeze" + self._accumulated_op_id(), "Unsqueeze")
        loss_expand = loss_expand.set_input("x", loss).set_attr_vec_int32(
            "axes", [1])

        return [label, softmax, loss_expand], [[2], [1]]
class SoftMaxParser(AscendParserBase):
    """Parser for Paddle ``softmax``; emits one GE ``SoftmaxV2`` operator."""

    def __init__(self, graph, var2geop):
        super(SoftMaxParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax"

    def _apply(self):
        """Build ``SoftmaxV2`` whose ``axes`` is the single-element list
        holding the op's ``axis`` attribute.

        Returns:
            ``([softmax_op], [[0]])``.
        """
        logits = self._get_ge_input(self.op.input_arg_names[0])
        axis = self.op.attr("axis")
        op_name = "softmax" + self._accumulated_op_id()
        softmax_op = core.GEOperatorFactory.create_operator(op_name,
                                                            "SoftmaxV2")
        softmax_op = softmax_op.set_input("x", logits)
        softmax_op = softmax_op.set_attr_vec_int32("axes", [axis])
        return [softmax_op], [[0]]
## general
class ShapeParser(AscendParserBase):
    """Parser for Paddle ``shape``; emits one GE ``Shape`` operator."""

    def __init__(self, graph, var2geop):
        super(ShapeParser, self).__init__(graph, var2geop)
        self.parser_name = "shape"

    def _apply(self):
        """Build ``Shape`` on the op's first input argument.

        Returns:
            ``([shape_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "shape" + self._accumulated_op_id()
        shape_op = core.GEOperatorFactory.create_operator(op_name, "Shape")
        shape_op = shape_op.set_input("x", source)
        return [shape_op], [[0]]
class FillConstantParser(AscendParserBase):
    """Parser for Paddle ``fill_constant``; emits a GE ``Const`` operator."""

    def __init__(self, graph, var2geop):
        super(FillConstantParser, self).__init__(graph, var2geop)
        self.parser_name = "fill_constant"

    def _apply(self):
        """Build a ``Const`` from the op's ``shape`` / ``dtype`` / ``value``
        attributes and mark it as a graph input.

        When the output variable is persistable, a ``Variable`` named after it
        and an ``Assign`` into it are also created; neither is returned.

        Returns:
            ``([const], [[0]])`` in both cases.
        """
        fill_shape = self.op.attr("shape")
        fill_dtype = self.op.attr("dtype")
        fill_value = self.op.attr("value")

        tensor = self._create_ge_tensor(fill_shape, fill_dtype, fill_value)
        const = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        const = const.set_attr_tensor("value", tensor)
        self._mark_as_input(const)

        if self.op.block.var(self.op.output('Out')[0]).persistable:
            var = core.GEOperatorFactory.create_operator(
                self.op.output('Out')[0], "Variable")
            # The variable's output desc is always declared as DT_FLOAT.
            out_desc = core.GETensorDesc(
                core.GEShape(fill_shape), core.GEFormat.FORMAT_ND,
                core.GEDataType.DT_FLOAT)
            var.update_output_desc("y", out_desc)
            assign = core.GEOperatorFactory.create_operator(
                "assign" + self._accumulated_op_id(), "Assign")
            assign = assign.set_input("value", const).set_input("ref", var)

        return [const], [[0]]
class TruncatedNormalParser(AscendParserBase):
    """Parser for Paddle ``truncated_gaussian_random``; emits GE
    ``ParameterizedTruncatedNormal`` fed by five ``Const`` operators."""

    def __init__(self, graph, var2geop):
        super(TruncatedNormalParser, self).__init__(graph, var2geop)
        self.parser_name = "truncated_gaussian_random"

    def _apply(self):
        """Build shape / mean / std / min / max constants and the sampling
        operator; min and max are ``mean - 2 * std`` and ``mean + 2 * std``.

        Returns:
            For a persistable output: the five constants followed by the
            sampling operator, with index list ``[[-1]]`` (last entry).
            Otherwise ``([truncated_normal], [[0]])``.
        """
        out_shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")
        mean = self.op.attr("mean")
        std = self.op.attr("std")
        # NOTE(review): the seed attribute is read but the operator below is
        # always given seed 0 -- confirm whether that is intentional.
        seed = self.op.attr("seed")

        def make_const(tensor_shape, tensor_dtype, value):
            # Create the tensor first, then the Const operator that wraps it.
            tensor = self._create_ge_tensor(tensor_shape, tensor_dtype, value)
            const_op = core.GEOperatorFactory.create_operator(
                "const" + self._accumulated_op_id(), "Const")
            return const_op.set_attr_tensor("value", tensor)

        # The shape constant uses dtype code 2; the others use the op's dtype.
        shape_tensor = make_const([len(out_shape)], 2, out_shape)
        mean_tensor = make_const([1], dtype, mean)
        std_tensor = make_const([1], dtype, std)
        min_tensor = make_const([1], dtype, mean - 2 * std)
        max_tensor = make_const([1], dtype, mean + 2 * std)

        constants = [
            shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor
        ]
        for constant in constants:
            self._mark_as_input(constant)

        truncated_normal = core.GEOperatorFactory.create_operator(
            "truncated_normal" + self._accumulated_op_id(),
            "ParameterizedTruncatedNormal")
        truncated_normal = truncated_normal.set_input("shape", shape_tensor)
        truncated_normal = truncated_normal.set_input("means", mean_tensor)
        truncated_normal = truncated_normal.set_input("stdevs", std_tensor)
        truncated_normal = truncated_normal.set_input("min", min_tensor)
        truncated_normal = truncated_normal.set_input("max", max_tensor)
        truncated_normal = truncated_normal.set_attr_int32("seed", 0)

        if not self.op.block.var(self.op.output('Out')[0]).persistable:
            return [truncated_normal], [[0]]

        # Persistable output: also create a Variable named after it and an
        # Assign of the sample into it; neither is part of the returned list.
        var = core.GEOperatorFactory.create_operator(
            self.op.output('Out')[0], "Variable")
        out_desc = core.GETensorDesc(
            core.GEShape(out_shape), core.GEFormat.FORMAT_ND,
            core.GEDataType.DT_FLOAT)
        var.update_output_desc("y", out_desc)
        assign = core.GEOperatorFactory.create_operator(
            "assign" + self._accumulated_op_id(), "Assign")
        assign = assign.set_input("value", truncated_normal).set_input("ref",
                                                                       var)
        return constants + [truncated_normal], [[-1]]
class GatherParser(AscendParserBase):
    """Parser for Paddle ``gather``; emits one GE ``Gather`` operator."""

    def __init__(self, graph, var2geop):
        super(GatherParser, self).__init__(graph, var2geop)
        self.parser_name = "gather"

    def _apply(self):
        """Build ``Gather`` with ``validate_indices`` set to True.

        Input arguments are read as index (position 0) and x (position 1).

        Returns:
            ``([gather_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        indices = self._get_ge_input(arg_names[0])
        source = self._get_ge_input(arg_names[1])
        # Last dimension of x; looked up but not used below.
        last_dim = self.op.block.var(self.op.input_arg_names[1]).shape[-1]

        op_name = "gather" + self._accumulated_op_id()
        gather_op = core.GEOperatorFactory.create_operator(op_name, "Gather")
        gather_op = gather_op.set_input("x", source)
        gather_op = gather_op.set_input("indices", indices)
        gather_op = gather_op.set_attr_bool("validate_indices", True)
        return [gather_op], [[0]]
class ScatterParser(AscendParserBase):
    """Parser for Paddle ``scatter``; emits GE ``TensorScatterUpdate`` or
    ``TensorScatterAdd`` depending on the ``overwrite`` attribute."""

    def __init__(self, graph, var2geop):
        super(ScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "scatter"

    def _apply(self):
        """Build the scatter operator, unsqueezing a rank-1 index first.

        Returns:
            ``([x, index, updates, scatter_value], [[-1]])``: the op's output
            is bound to the last entry (the scatter operator).
        """
        # NOTE(review): positions 0/1/2 are read as index/x/updates -- confirm
        # that this matches the order of scatter's input_arg_names.
        index = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        updates = self._get_ge_input(self.op.input_arg_names[2])
        overwrite = self.op.attr("overwrite")
        index_shape = self.op.block.var(self.op.input_arg_names[0]).shape

        if len(index_shape) == 1:
            # Previously called the nonexistent ``self.getid()``.
            index = core.GEOperatorFactory.create_operator(
                "unsqueeze" + self._accumulated_op_id(),
                "Unsqueeze").set_input("x", index).set_attr_vec_int32("axes",
                                                                      [1])

        # Previously referenced the undefined names x_var / index_var /
        # updates_var / updatesi_var, so every call raised an error.
        scatter_type = "TensorScatterUpdate" if overwrite else "TensorScatterAdd"
        scatter_value = core.GEOperatorFactory.create_operator(
            "scatter" + self._accumulated_op_id(), scatter_type).set_input(
                "x", x).set_input("indices", index).set_input("updates",
                                                              updates)
        return [x, index, updates, scatter_value], [[-1]]
class CastParser(AscendParserBase):
    """Parser for Paddle ``cast``; emits one GE ``Cast`` operator."""

    def __init__(self, graph, var2geop):
        super(CastParser, self).__init__(graph, var2geop)
        self.parser_name = "cast"

    def _apply(self):
        """Build ``Cast`` whose ``dst_type`` is the op's ``out_dtype``
        attribute.

        Returns:
            ``([cast_op], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        target_dtype = self.op.attr("out_dtype")
        op_name = "cast" + self._accumulated_op_id()
        cast_op = core.GEOperatorFactory.create_operator(op_name, "Cast")
        cast_op = cast_op.set_input("x", source)
        cast_op = cast_op.set_attr_int32("dst_type", target_dtype)
        return [cast_op], [[0]]
class AssignParser(AscendParserBase):
    """Parser for Paddle ``assign``; emits one GE ``Assign`` operator."""

    def __init__(self, graph, var2geop):
        super(AssignParser, self).__init__(graph, var2geop)
        self.parser_name = "assign"

    def _apply(self):
        """Build ``Assign`` with input argument 0 as ``value`` and input
        argument 1 as ``ref``.

        Returns:
            ``([assign_op], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        value = self._get_ge_input(arg_names[0])
        target = self._get_ge_input(arg_names[1])
        op_name = "assign" + self._accumulated_op_id()
        assign_op = core.GEOperatorFactory.create_operator(op_name, "Assign")
        assign_op = assign_op.set_input("value", value).set_input("ref",
                                                                  target)
        return [assign_op], [[0]]
class ScaleParser(AscendParserBase):
    """Parser for Paddle ``scale``; emits GE ``Power`` (optionally after
    ``Adds``)."""

    def __init__(self, graph, var2geop):
        super(ScaleParser, self).__init__(graph, var2geop)
        self.parser_name = "scale"

    def _apply(self):
        """Build the scaling operator(s).

        With ``bias_after_scale`` set, a single ``Power`` carries both the
        scale and the bias (as ``shift``).  Otherwise the bias is added first
        with ``Adds`` and ``Power`` applies only the scale.

        Returns:
            ``([scale_value], [[0]])``: the ``Power`` operator.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        factor = self.op.attr("scale")
        offset = self.op.attr("bias")
        bias_after_scale = self.op.attr("bias_after_scale")

        factory = core.GEOperatorFactory
        if not bias_after_scale:
            shifted = factory.create_operator(
                "adds" + self._accumulated_op_id(), "Adds")
            shifted = shifted.set_input("x", source).set_attr_float("value",
                                                                    offset)
            power_input, shift = shifted, 0.0
        else:
            power_input, shift = source, offset

        scale_value = factory.create_operator(
            "scale" + self._accumulated_op_id(), "Power")
        scale_value = scale_value.set_input("x", power_input)
        scale_value = scale_value.set_attr_float("power", 1.0)
        scale_value = scale_value.set_attr_float("scale", factor)
        scale_value = scale_value.set_attr_float("shift", shift)
        return [scale_value], [[0]]
class SliceParser(AscendParserBase):
    """Translates the Paddle ``slice`` op into a GE ``SliceD`` operator."""

    def __init__(self, graph, var2geop):
        super(SliceParser, self).__init__(graph, var2geop)
        self.parser_name = "slice"

    @staticmethod
    def _offsets_and_sizes(x_shape, axes, starts, ends):
        """Expand per-axis slice bounds to full-rank offsets and sizes.

        Args:
            x_shape: static shape of the sliced variable.
            axes: axes being sliced; may be unsorted or negative.
            starts: start index for each entry of ``axes``.
            ends: end index for each entry of ``axes``.

        Returns:
            ``(offsets, sizes)``, two lists with one entry per dimension of
            ``x_shape``.  Axes that are not sliced get offset 0 and the
            full dimension as size.
        """
        rank = len(x_shape)
        # Pair each bound with its own axis rather than relying on ``axes``
        # being sorted in ascending order.
        bounds = {}
        for axis, start, end in zip(axes, starts, ends):
            bounds[axis + rank if axis < 0 else axis] = (start, end)

        offsets, sizes = [], []
        for axis, dim in enumerate(x_shape):
            start, end = bounds.get(axis, (0, dim))
            if dim >= 0:
                # Negative bounds count from the end of a known dimension.
                if start < 0:
                    start = max(start + dim, 0)
                if end < 0:
                    end = max(end + dim, 0)
            if end > dim:
                end = dim
            offsets.append(start)
            sizes.append(end - start)
        return offsets, sizes

    def _apply(self):
        """Emit ``SliceD`` with full-rank ``offsets`` and ``size`` attributes."""
        x = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr("axes")
        starts = self.op.attr("starts")
        ends = self.op.attr("ends")
        x_shape = self.op.block.var(self.op.input_arg_names[0]).shape

        offsets, size = self._offsets_and_sizes(x_shape, axes, starts, ends)

        slice_value = core.GEOperatorFactory.create_operator(
            "slice" + self._accumulated_op_id(), "SliceD").set_input(
                "x", x).set_attr_vec_int32(
                    "offsets", offsets).set_attr_vec_int32("size", size)
        return [slice_value], [[0]]
class ReshapeParser(AscendParserBase):
    """Translates the Paddle ``reshape2`` op into a GE ``Reshape`` operator."""

    def __init__(self, graph, var2geop):
        super(ReshapeParser, self).__init__(graph, var2geop)
        self.parser_name = "reshape2"

    def _apply(self):
        """Resolve the target shape statically, then emit Const/Reshape/Shape.

        ``0`` entries of the ``shape`` attribute copy the input dimension at
        the same position, and a single ``-1`` entry is inferred from the
        element count.  The input shape must not contain ``-1``.
        """
        factory = core.GEOperatorFactory
        source_name = self.op.input_arg_names[0]
        org_shape = self.op.block.var(source_name).shape
        assert org_shape.count(-1) == 0, "do not allow the dim is -1"

        target = self.op.attr("shape")
        for pos, dim in enumerate(target):
            if dim == 0:
                target[pos] = org_shape[pos]
        if -1 in target:
            assert target.count(-1) == 1, "only allow one dim is -1"
            element_count = reduce(lambda a, b: a * b, org_shape)
            # The product of the target still contains the -1 factor, so
            # negating it gives the product of the known dimensions.
            known_product = reduce(lambda a, b: a * b, target) * -1
            target[target.index(-1)] = element_count // known_product

        x = self._get_ge_input(source_name)
        shape_tensor = self._create_ge_tensor([len(target)], 2, target)
        const_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", shape_tensor)
        reshape = (
            factory.create_operator(
                "reshape" + self._accumulated_op_id(), "Reshape")
            .set_input("x", x)
            .set_input("shape", const_shape)
            .set_attr_int32("axis", 0)
        )
        x_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape",
        ).set_input("x", x)
        return [x_shape, reshape], [[1], [0]]
class TransposeParser(AscendParserBase):
    """Translates the Paddle ``transpose2`` op into a GE ``TransposeD``."""

    def __init__(self, graph, var2geop):
        super(TransposeParser, self).__init__(graph, var2geop)
        self.parser_name = "transpose2"

    def _apply(self):
        """Emit ``TransposeD`` using the ``axis`` attribute, plus a ``Shape`` of the input."""
        factory = core.GEOperatorFactory
        x = self._get_ge_input(self.op.input_arg_names[0])
        permutation = self.op.attr("axis")
        transpose = (
            factory.create_operator(
                "transpose" + self._accumulated_op_id(), "TransposeD")
            .set_input("x", x)
            .set_attr_vec_int32("perm", permutation)
        )
        x_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape",
        ).set_input("x", x)
        return [x_shape, transpose], [[1], [0]]
class AccuracyParser(AscendParserBase):
    """Translates the Paddle ``accuracy`` op into a small GE sub-graph."""

    def __init__(self, graph, var2geop):
        super(AccuracyParser, self).__init__(graph, var2geop)
        self.parser_name = "accuracy"

    def _apply(self):
        """Emit nodes for the accuracy, the correct count and the total count.

        The first two inputs are cast to ``dst_type`` 3 and compared with
        ``Equal``.  The comparison is cast to ``dst_type`` 0 and reduced with
        ``ReduceMeanD`` (accuracy) and ``ReduceSumD`` (correct).  The total
        is a ``ReduceSumD`` over a ``OnesLike`` of the second input.  The
        third input is fetched but not used.
        """
        factory = core.GEOperatorFactory

        def cast_to(node, ge_dtype):
            # One Cast node per call; each call consumes a fresh op id.
            return (
                factory.create_operator(
                    "cast" + self._accumulated_op_id(), "Cast")
                .set_input("x", node)
                .set_attr_int32("dst_type", ge_dtype)
            )

        def reduce_all(prefix, ge_type, node):
            return (
                factory.create_operator(
                    prefix + self._accumulated_op_id(), ge_type)
                .set_input("x", node)
                .set_attr_bool("keep_dims", False)
                .set_attr_vec_int32("axes", [])
            )

        pred = self._get_ge_input(self.op.input_arg_names[0])
        label = self._get_ge_input(self.op.input_arg_names[1])
        logits = self._get_ge_input(self.op.input_arg_names[2])

        pred = cast_to(pred, 3)
        label = cast_to(label, 3)
        equal = (
            factory.create_operator(
                "equal" + self._accumulated_op_id(), "Equal")
            .set_input("x1", pred)
            .set_input("x2", label)
        )
        matches = cast_to(equal, 0)
        acc = reduce_all("mean", "ReduceMeanD", matches)
        correct = reduce_all("sum", "ReduceSumD", matches)

        ones_tensor = factory.create_operator(
            "oneslike" + self._accumulated_op_id(),
            "OnesLike",
        ).set_input("x", label)
        ones_tensor = cast_to(ones_tensor, 0)
        total = reduce_all("sum", "ReduceSumD", ones_tensor)
        return [acc, correct, total], [[0], [1], [2]]
class TopkParser(AscendParserBase):
    """Translates the Paddle ``top_k`` op into a GE ``TopK`` operator."""

    def __init__(self, graph, var2geop):
        super(TopkParser, self).__init__(graph, var2geop)
        self.parser_name = "top_k"

    def _apply(self):
        """Emit ``TopK`` on the input cast to ``dst_type`` 1.

        ``k`` is supplied through a ``Const`` node.  Both ``TopK`` outputs
        (slots 0 and 1) are cast to ``dst_type`` 0 before being returned.
        """
        factory = core.GEOperatorFactory
        x = self._get_ge_input(self.op.input_arg_names[0])
        k = self.op.attr("k")

        k_tensor = self._create_ge_tensor([1], 2, k)
        const_k = factory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", k_tensor)
        cast_x = (
            factory.create_operator(
                "cast" + self._accumulated_op_id(), "Cast")
            .set_input("x", x)
            .set_attr_int32("dst_type", 1)
        )
        topk = (
            factory.create_operator(
                "topk" + self._accumulated_op_id(), "TopK")
            .set_input("x", cast_x)
            .set_input("k", const_k)
        )

        # Slot 0 is handled first, then slot 1, each in its own Cast node.
        value, index = [
            factory.create_operator(
                "cast" + self._accumulated_op_id(), "Cast")
            .set_input("x", topk, slot)
            .set_attr_int32("dst_type", 0)
            for slot in (0, 1)
        ]
        return [value, index], [[1], [0]]
class LookupTableParser(AscendParserBase):
    """Translates the Paddle ``lookup_table`` op into GE ``Squeeze`` + ``Gather``."""

    def __init__(self, graph, var2geop):
        super(LookupTableParser, self).__init__(graph, var2geop)
        self.parser_name = "lookup_table"

    def _apply(self):
        """Squeeze the last axis of the ids, then gather rows from the table."""
        factory = core.GEOperatorFactory
        arg_names = self.op.input_arg_names
        ids = self._get_ge_input(arg_names[0])
        table = self._get_ge_input(arg_names[1])

        flat_ids = (
            factory.create_operator(
                "squeeze" + self._accumulated_op_id(), "Squeeze")
            .set_input("x", ids)
            .set_attr_vec_int32("axes", [-1])
        )
        out = (
            factory.create_operator(
                "lookup" + self._accumulated_op_id(), "Gather")
            .set_input("x", table)
            .set_input("indices", flat_ids)
        )
        return [out], [[0]]
class StackParser(AscendParserBase):
    """Translates the Paddle ``stack`` op into ``ExpandDims`` + ``TileWithAxis``."""

    def __init__(self, graph, var2geop):
        super(StackParser, self).__init__(graph, var2geop)
        self.parser_name = "stack"

    def _apply(self):
        """Expand the first input along ``axis`` and tile it once per input.

        All inputs are fetched, but only the first one is wired into the
        graph; the number of inputs is used as the ``tiles`` attribute.
        """
        factory = core.GEOperatorFactory
        tiles = len(self.op.input_arg_names)
        inputs = [
            self._get_ge_input(self.op.input_arg_names[pos])
            for pos in range(tiles)
        ]
        axis = self.op.attr("axis")
        first_input = inputs[0]

        axis_tensor = self._create_ge_tensor([1], 2, axis)
        tensor_axis = factory.create_operator(
            "axis" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", axis_tensor)
        expand = (
            factory.create_operator(
                "expand" + self._accumulated_op_id(), "ExpandDims")
            .set_input("x", first_input)
            .set_input("axis", tensor_axis)
        )
        stack = (
            factory.create_operator(
                "stack" + self._accumulated_op_id(), "TileWithAxis")
            .set_input("x", expand)
            .set_attr_int32("axis", axis)
            .set_attr_int32("tiles", tiles)
        )
        return [stack], [[0]]
class UnSqueezeParser(AscendParserBase):
    """Translates the Paddle ``unsqueeze2`` op into a GE ``Unsqueeze``."""

    def __init__(self, graph, var2geop):
        super(UnSqueezeParser, self).__init__(graph, var2geop)
        self.parser_name = "unsqueeze2"

    def _apply(self):
        """Emit ``Unsqueeze`` with the op's ``axes`` and a ``Shape`` of its result."""
        factory = core.GEOperatorFactory
        x = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr('axes')
        output = (
            factory.create_operator(
                "unsqueeze" + self._accumulated_op_id(), "Unsqueeze")
            .set_input("x", x)
            .set_attr_vec_int32("axes", axes)
        )
        shape = factory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape",
        ).set_input("x", output)
        return [shape, output], [[1], [0]]
## parallel
class AllGatherParser(AscendParserBase):
    """Translates the Paddle ``c_allgather`` op into a GE ``HcomAllGather``."""

    def __init__(self, graph, var2geop):
        super(AllGatherParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allgather"

    def _apply(self):
        """Forward the ``rank_size`` and ``group`` attributes to ``HcomAllGather``."""
        x = self._get_ge_input(self.op.input_arg_names[0])
        rank_size = self.op.attr("rank_size")
        group = self.op.attr("group")
        allgather = (
            core.GEOperatorFactory.create_operator(
                "allgather" + self._accumulated_op_id(), "HcomAllGather")
            .set_input("x", x)
            .set_attr_int32("rank_size", rank_size)
            .set_attr_string("group", group)
        )
        return [allgather], [[0]]
class AllReduceParser(AscendParserBase):
    """Base parser mapping ``c_allreduce_<reduction>`` ops to ``HcomAllReduce``."""

    def __init__(self, graph, var2geop, reduction):
        super(AllReduceParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allreduce_" + reduction
        self.reduction = reduction

    def _apply(self):
        """Emit ``HcomAllReduce`` for the group derived from ``ring_id``.

        The group name is ``"hcom_group_"`` followed by the ring id.  The
        fusion attributes are currently disabled (both are ``None``), so
        they are never set on the operator.
        """
        x = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.reduction
        ring_id = self.op.attr("ring_id")
        group = "hcom_group_" + str(ring_id)
        fusion = None  # self.op.attr("fusion")
        fusion_id = None  # self.op.attr("fusion_id")

        allreduce = (
            core.GEOperatorFactory.create_operator(
                "allreduce" + self._accumulated_op_id(), "HcomAllReduce")
            .set_input("x", x)
            .set_attr_string("reduction", reduction)
            .set_attr_string("group", group)
        )
        optional_attrs = (("fusion", fusion), ("fusion_id", fusion_id))
        for attr_name, attr_value in optional_attrs:
            if attr_value is not None:
                allreduce.set_attr_int32(attr_name, attr_value)
        return [allreduce], [[0]]
class AllReduceSumParser(AllReduceParser):
    """All-reduce parser bound to the ``sum`` reduction."""

    def __init__(self, graph, var2geop):
        super(AllReduceSumParser, self).__init__(
            graph, var2geop, reduction='sum')
class AllReduceMaxParser(AllReduceParser):
    """All-reduce parser bound to the ``max`` reduction."""

    def __init__(self, graph, var2geop):
        super(AllReduceMaxParser, self).__init__(
            graph, var2geop, reduction='max')
class BroadcastParser(AscendParserBase):
    """Translates the Paddle ``c_broadcast`` op into a GE ``HcomBroadcast``."""

    def __init__(self, graph, var2geop):
        super(BroadcastParser, self).__init__(graph, var2geop)
        self.parser_name = "c_broadcast"

    def _apply(self):
        """Forward the ``root_rank`` and ``group`` attributes to ``HcomBroadcast``."""
        x = self._get_ge_input(self.op.input_arg_names[0])
        root_rank = self.op.attr("root_rank")
        group = self.op.attr("group")
        broadcast = (
            core.GEOperatorFactory.create_operator(
                "broadcast" + self._accumulated_op_id(), "HcomBroadcast")
            .set_input("x", x)
            .set_attr_int32("root_rank", root_rank)
            .set_attr_string("group", group)
        )
        return [broadcast], [[0]]
class ReduceScatterParser(AscendParserBase):
    """Translates ``c_reduce_scatter`` into a GE ``HcomReduceScatter``."""

    def __init__(self, graph, var2geop):
        super(ReduceScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "c_reduce_scatter"

    def _apply(self):
        """Forward ``reduction``, ``group`` and ``rank_size`` to the GE operator."""
        x = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.op.attr("reduction")
        group = self.op.attr("group")
        rank_size = self.op.attr("rank_size")
        reduce_scatter = (
            core.GEOperatorFactory.create_operator(
                "reducescatter" + self._accumulated_op_id(),
                "HcomReduceScatter")
            .set_input("x", x)
            .set_attr_string("reduction", reduction)
            .set_attr_string("group", group)
            .set_attr_int32("rank_size", rank_size)
        )
        return [reduce_scatter], [[0]]
class SendParser(AscendParserBase):
    """Translates the Paddle ``c_send`` op into a GE ``HcomSend``."""

    def __init__(self, graph, var2geop):
        super(SendParser, self).__init__(graph, var2geop)
        self.parser_name = "c_send"

    def _apply(self):
        """Forward ``sr_tag``, ``dest_rank`` and ``group`` to ``HcomSend``."""
        x = self._get_ge_input(self.op.input_arg_names[0])
        sr_tag = self.op.attr("sr_tag")
        dest_rank = self.op.attr("dest_rank")
        group = self.op.attr("group")
        send = (
            core.GEOperatorFactory.create_operator(
                "send" + self._accumulated_op_id(), "HcomSend")
            .set_input("x", x)
            .set_attr_int32("sr_tag", sr_tag)
            .set_attr_int32("dest_rank", dest_rank)
            .set_attr_string("group", group)
        )
        return [send], [[0]]
class ReceiveParser(AscendParserBase):
    """Translates the Paddle ``c_receive`` op into a GE ``HcomReceive``."""

    def __init__(self, graph, var2geop):
        super(ReceiveParser, self).__init__(graph, var2geop)
        self.parser_name = "c_receive"

    def _apply(self):
        """Forward tag, source rank, group, shape and dtype to ``HcomReceive``."""
        x = self._get_ge_input(self.op.input_arg_names[0])
        sr_tag = self.op.attr("sr_tag")
        src_rank = self.op.attr("src_rank")
        group = self.op.attr("group")
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")
        receive = (
            core.GEOperatorFactory.create_operator(
                "receive" + self._accumulated_op_id(), "HcomReceive")
            .set_input("x", x)
            .set_attr_int32("sr_tag", sr_tag)
            .set_attr_int32("src_rank", src_rank)
            .set_attr_string("group", group)
            .set_attr_vec_int32("shape", shape)
            .set_attr_int32("dtype", dtype)
        )
        return [receive], [[0]]
class RangeParser(AscendParserBase):
    """Translates the Paddle ``range`` op into a GE ``Range`` operator."""

    def __init__(self, graph, var2geop):
        super(RangeParser, self).__init__(graph, var2geop)
        self.parser_name = "range"

    def _apply(self):
        """Wire the three op inputs into ``Range``.

        The second op input feeds GE ``start``, the first feeds ``limit``
        and the third feeds ``delta``.

        NOTE(review): presumably ``input_arg_names`` is not in
        (start, end, step) order, which would explain the crossed wiring;
        confirm against the op definition before changing it.
        """
        # TODO not support range type yet
        arg_names = self.op.input_arg_names
        first_arg = self._get_ge_input(arg_names[0])
        second_arg = self._get_ge_input(arg_names[1])
        delta = self._get_ge_input(arg_names[2])

        ge_range = (
            core.GEOperatorFactory.create_operator(
                "range" + self._accumulated_op_id(), "Range")
            .set_input("start", second_arg)
            .set_input("limit", first_arg)
            .set_input("delta", delta)
        )
        return [ge_range], [[0]]
class UniformRandomParser(AscendParserBase):
    """Translates ``uniform_random`` into GE ``RandomUniform`` + ``Power``."""

    def __init__(self, graph, var2geop):
        super(UniformRandomParser, self).__init__(graph, var2geop)
        self.parser_name = "uniform_random"

    def _apply(self):
        """Emit ``RandomUniform`` and rescale its output with ``Power``.

        The ``Power`` node uses ``power=1.0``, ``scale=max - min`` and
        ``shift=min``.  The op's ``seed`` is used for both ``seed`` and
        ``seed2``.
        """
        factory = core.GEOperatorFactory
        shape = self.op.attr("shape")
        min_v = self.op.attr("min")
        max_v = self.op.attr("max")
        seed = self.op.attr("seed")
        dtype = self.op.attr("dtype")
        assert max_v > min_v, (
            "assert max_v > min_v, but recieved as max_v={}, min_v={} ".format(
                max_v, min_v))

        shape_values = self._create_ge_tensor([len(shape)], 2, shape)
        shape_tensor = factory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", shape_values)

        ge_ur = (
            factory.create_operator(
                "uniform_random" + self._accumulated_op_id(), "RandomUniform")
            .set_input("shape", shape_tensor)
            .set_attr_dtype("dtype", self.ascend_helper.dtype2ge(dtype))
            .set_attr_int32("seed", seed)
            .set_attr_int32("seed2", seed)
        )

        span = max_v - min_v
        scale_value = (
            factory.create_operator(
                "scale" + self._accumulated_op_id(), "Power")
            .set_input("x", ge_ur)
            .set_attr_float("power", 1.0)
            .set_attr_float("scale", span)
            .set_attr_float("shift", min_v)
        )
        return [scale_value], [[0]]
class EqualParser(AscendParserBase):
    """Translates the Paddle ``equal`` op into a GE ``Equal`` operator."""

    def __init__(self, graph, var2geop):
        super(EqualParser, self).__init__(graph, var2geop)
        self.parser_name = "equal"

    def _apply(self):
        """Compare the two op inputs element-wise with ``Equal``."""
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        equal = (
            core.GEOperatorFactory.create_operator(
                "equal" + self._accumulated_op_id(), "Equal")
            .set_input("x1", lhs)
            .set_input("x2", rhs)
        )
        return [equal], [[0]]
class ExpandParser(AscendParserBase):
    """Translates the Paddle ``expand`` op into a GE ``Tile`` operator."""

    def __init__(self, graph, var2geop):
        super(ExpandParser, self).__init__(graph, var2geop)
        self.parser_name = "expand"

    def _apply(self):
        """Tile the input by ``expand_times``, supplied through a ``Const`` node."""
        factory = core.GEOperatorFactory
        x = self._get_ge_input(self.op.input_arg_names[0])
        expand_times = self.op.attr('expand_times')

        multiples = self._create_ge_tensor(
            [len(expand_times)], 2, expand_times)
        expand_tensor = factory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", multiples)
        tile = (
            factory.create_operator("tile" + self._accumulated_op_id(), "Tile")
            .set_input("x", x)
            .set_input("multiples", expand_tensor)
        )
        return [tile], [[0]]
class SqueezeParser(AscendParserBase):
    """Translates the Paddle ``squeeze2`` op into a GE ``Squeeze`` operator."""

    def __init__(self, graph, var2geop):
        super(SqueezeParser, self).__init__(graph, var2geop)
        self.parser_name = "squeeze2"

    def _apply(self):
        """Emit ``Squeeze`` with the op's ``axes`` and a ``Shape`` of its result."""
        factory = core.GEOperatorFactory
        x = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr("axes")
        data_squeezed = (
            factory.create_operator(
                "squeeze" + self._accumulated_op_id(), "Squeeze")
            .set_input("x", x)
            .set_attr_vec_int32("axes", axes)
        )
        shape = factory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape",
        ).set_input("x", data_squeezed)
        return [shape, data_squeezed], [[1], [0]]
#****************************************************************#
#*************************** *************************#
#*************************** *************************#
#*************************** GradParser *************************#
#*************************** *************************#
#*************************** *************************#
#****************************************************************#
## grad
class ReduceSumGradParser(AscendParserBase):
    """Builds the input gradient of the Paddle ``reduce_sum`` op."""

    def __init__(self, graph, var2geop):
        super(ReduceSumGradParser, self).__init__(graph, var2geop)
        self.parser_name = "reduce_sum_grad"

    def _apply(self):
        """Broadcast the first input to the shape of the second input.

        A ``Const`` holding ``-1`` is also created and marked as a graph
        input, although no emitted operator consumes it.
        """
        factory = core.GEOperatorFactory
        arg_names = self.op.input_arg_names
        x = self._get_ge_input(arg_names[0])
        reference = self._get_ge_input(arg_names[1])

        shape_tensor = factory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape",
        ).set_input("x", reference, 0)

        minus_one = self._create_ge_tensor([1], 2, -1)
        const = factory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", minus_one)
        self._mark_as_input(const)

        reduce_sum = (
            factory.create_operator(
                "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo")
            .set_input("x", x)
            .set_input("shape", shape_tensor)
        )
        return [reduce_sum], [[0]]
class MatMulGradParser(AscendParserBase):
    """Builds the input gradients of the Paddle ``matmul`` op."""

    def __init__(self, graph, var2geop):
        super(MatMulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "matmul_grad"

    def _apply(self):
        """Emit two (Batch)MatMul nodes producing the X and Y gradients.

        ``BatchMatMul`` (flags ``adj_x1``/``adj_x2``) is used when X has more
        than two dimensions, ``MatMul`` (flags ``transpose_x1``/
        ``transpose_x2``) otherwise.  Operand order and flags depend on the
        ``transpose_Y`` attribute; ``transpose_X`` is read but does not
        influence the generated graph.
        """
        factory = core.GEOperatorFactory
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        y = self._get_ge_input(self.op.input_arg_names[2])
        transpose_x = self.op.attr("transpose_X")  # read only, unused below
        transpose_y = self.op.attr("transpose_Y")
        # The first and third shapes are looked up but unused below.
        out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        x_shape = self.op.block.var(self.op.input_arg_names[1]).shape
        y_shape = self.op.block.var(self.op.input_arg_names[2]).shape

        if len(x_shape) > 2:
            ge_type, lhs_flag, rhs_flag = "BatchMatMul", "adj_x1", "adj_x2"
        else:
            ge_type, lhs_flag, rhs_flag = (
                "MatMul", "transpose_x1", "transpose_x2")

        if transpose_y:
            x_rhs_transposed = False
            y_lhs, y_rhs = out_grad, x
        else:
            x_rhs_transposed = True
            y_lhs, y_rhs = x, out_grad

        x_grad = (
            factory.create_operator(
                self.parser_name + self._accumulated_op_id(), ge_type)
            .set_input("x1", out_grad)
            .set_input("x2", y)
            .set_attr_bool(lhs_flag, False)
            .set_attr_bool(rhs_flag, x_rhs_transposed)
        )
        y_grad = (
            factory.create_operator(
                self.parser_name + self._accumulated_op_id(), ge_type)
            .set_input("x1", y_lhs)
            .set_input("x2", y_rhs)
            .set_attr_bool(lhs_flag, True)
            .set_attr_bool(rhs_flag, False)
        )
        return [x_grad, y_grad], [[0], [1]]
class MulGradParser(AscendParserBase):
    """Builds the input gradients of the Paddle ``mul`` op."""

    def __init__(self, graph, var2geop):
        super(MulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "mul_grad"

    def _matmul(self, lhs, rhs, transpose_lhs, transpose_rhs):
        """Create a GE ``MatMul`` of ``lhs`` and ``rhs`` with the given transpose flags."""
        return core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "MatMul").set_input("x1", lhs).set_input(
                "x2", rhs).set_attr_bool(
                    "transpose_x1", transpose_lhs).set_attr_bool(
                        "transpose_x2", transpose_rhs)

    def _unsupported(self, shape_x, shape_y, x_num_col_dims, y_num_col_dims):
        """Build the error raised for configurations this parser cannot lower."""
        return NotImplementedError(
            "mul_grad is not supported for x rank {}, y rank {}, "
            "x_num_col_dims={}, y_num_col_dims={}".format(
                len(shape_x), len(shape_y), x_num_col_dims, y_num_col_dims))

    def _apply(self):
        """Emit the X and Y gradient sub-graphs.

        Supported configurations:
          * ``x_num_col_dims == y_num_col_dims == 1`` with X of rank 2 or 3
            and Y of rank 2;
          * any other ``num_col_dims`` pair with X of rank 3 and Y of rank 2
            (``x_num_col_dims`` must then be 2).

        Raises:
            NotImplementedError: for any other configuration.  These cases
                used to fail with an UnboundLocalError at the return.
            AssertionError: if the second configuration is selected with
                ``x_num_col_dims != 2``.
        """
        factory = core.GEOperatorFactory
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        y = self._get_ge_input(self.op.input_arg_names[2])
        x_num_col_dims = self.op.attr("x_num_col_dims")
        y_num_col_dims = self.op.attr("y_num_col_dims")
        shape_out_grad = self.op.block.var(self.op.input_arg_names[0]).shape
        shape_x = self.op.block.var(self.op.input_arg_names[1]).shape
        shape_y = self.op.block.var(self.op.input_arg_names[2]).shape
        rank_x, rank_y = len(shape_x), len(shape_y)

        if x_num_col_dims == 1 and y_num_col_dims == 1:
            if rank_x == 2 and rank_y == 2:
                x_grad = self._matmul(out_grad, y, False, True)
                y_grad = self._matmul(x, out_grad, True, False)
            elif rank_x == 3 and rank_y == 2:
                flatten_x = factory.create_operator(
                    "flatten" + self._accumulated_op_id(),
                    "Flatten").set_input("x", x)
                x_grad = self._matmul(out_grad, y, False, True)
                if len(shape_out_grad) == 2:
                    # Insert an axis at position 1 of the 2-D product.
                    x_grad = factory.create_operator(
                        "unsqueeze" + self._accumulated_op_id(),
                        "Unsqueeze").set_input(
                            "x", x_grad).set_attr_vec_int32("axes", [1])
                y_grad = self._matmul(flatten_x, out_grad, True, False)
            else:
                raise self._unsupported(shape_x, shape_y, x_num_col_dims,
                                        y_num_col_dims)
        elif rank_x == 3 and rank_y == 2:
            assert x_num_col_dims == 2, "only support 2"
            flatten_x = factory.create_operator(
                "flatten" + self._accumulated_op_id(),
                "FlattenV2").set_input("x", x).set_attr_int32(
                    "axis", 0).set_attr_int32("end_axis", 1)
            flatten_out_grad = factory.create_operator(
                "flatten" + self._accumulated_op_id(),
                "FlattenV2").set_input("x", out_grad).set_attr_int32(
                    "axis", 0).set_attr_int32("end_axis", 1)
            # Replicate Y along a new leading axis, once per entry of the
            # first out_grad dimension, for the BatchMatMul below.
            y_unsqueeze = factory.create_operator(
                "unsqueeze" + self._accumulated_op_id(),
                "Unsqueeze").set_input("x", y).set_attr_vec_int32("axes", [0])
            y_stack = factory.create_operator(
                "stack" + self._accumulated_op_id(),
                "TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
                    "axis", 0).set_attr_int32("tiles", shape_out_grad[0])
            x_grad = factory.create_operator(
                self.parser_name + self._accumulated_op_id(),
                "BatchMatMul").set_input("x1", out_grad).set_input(
                    "x2", y_stack).set_attr_bool(
                        "adj_x1", False).set_attr_bool("adj_x2", True)
            y_grad = self._matmul(flatten_x, flatten_out_grad, True, False)
        else:
            raise self._unsupported(shape_x, shape_y, x_num_col_dims,
                                    y_num_col_dims)

        return [x_grad, y_grad], [[0], [1]]
class ReluGradParser(AscendParserBase):
    """Builds the input gradient of the Paddle ``relu`` op."""

    def __init__(self, graph, var2geop):
        super(ReluGradParser, self).__init__(graph, var2geop)
        self.parser_name = "relu_grad"

    def _apply(self):
        """Emit ``ReluGrad``: first op input as ``features``, second as ``gradients``."""
        arg_names = self.op.input_arg_names
        features = self._get_ge_input(arg_names[0])
        gradients = self._get_ge_input(arg_names[1])
        relu_grad = (
            core.GEOperatorFactory.create_operator(
                self.parser_name + self._accumulated_op_id(), "ReluGrad")
            .set_input("gradients", gradients)
            .set_input("features", features)
        )
        return [relu_grad], [[0]]
class SoftmaxWithCrossEntropyGradParser(AscendParserBase):
    """Builds the logits gradient of ``softmax_with_cross_entropy``."""

    def __init__(self, graph, var2geop):
        super(SoftmaxWithCrossEntropyGradParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_with_cross_entropy_grad"

    def _apply(self):
        """Emit ``loss_grad * (softmax - squeeze(one_hot(label)))``.

        The one-hot depth is dimension 1 of the third input's static shape.
        The on/off ``Const`` nodes are marked as graph inputs.
        """
        factory = core.GEOperatorFactory
        label = self._get_ge_input(self.op.input_arg_names[0])
        loss_grad = self._get_ge_input(self.op.input_arg_names[1])
        softmax = self._get_ge_input(self.op.input_arg_names[2])
        cls_num = self.op.block.var(self.op.input_arg_names[2]).shape[1]
        # Shapes are looked up as before; they are not used further down.
        label_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        loss_grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape
        softmax_shape = self.op.block.var(self.op.input_arg_names[2]).shape

        on_tensor = self._create_ge_tensor([1], 5, 1)
        on = factory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", on_tensor)
        off_tensor = self._create_ge_tensor([1], 5, 0)
        off = factory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const",
        ).set_attr_tensor("value", off_tensor)
        self._mark_as_input(on)
        self._mark_as_input(off)

        label = (
            factory.create_operator("cast" + self._accumulated_op_id(), "Cast")
            .set_input("x", label)
            .set_attr_int32("dst_type", 3)
        )
        onehot = (
            factory.create_operator(
                "onehot" + self._accumulated_op_id(), "OneHotD")
            .set_input("x", label)
            .set_input("on_value", on)
            .set_input("off_value", off)
            .set_attr_int32("depth", cls_num)
        )
        squeeze = factory.create_operator(
            "suqeeze" + self._accumulated_op_id(),
            "Squeeze",
        ).set_input("x", onehot)
        sub = (
            factory.create_operator("sub" + self._accumulated_op_id(), "Sub")
            .set_input("x1", softmax)
            .set_input("x2", squeeze)
        )
        grad = (
            factory.create_operator("mul" + self._accumulated_op_id(), "Mul")
            .set_input("x1", loss_grad)
            .set_input("x2", sub)
        )
        return [on, off, label, onehot, grad], [[-1]]
class DotMulGradParser(AscendParserBase):
    """Builds the input gradients of the Paddle ``elementwise_mul`` op."""

    def __init__(self, graph, var2geop):
        super(DotMulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_mul_grad"

    def _apply(self):
        """Emit two ``Mul`` nodes pairing the first input with each of the other two.

        The first gradient is ``input0 * input2`` and the second is
        ``input1 * input0``.
        """
        factory = core.GEOperatorFactory
        arg_names = self.op.input_arg_names
        out_grad = self._get_ge_input(arg_names[0])
        operand_a = self._get_ge_input(arg_names[1])
        operand_b = self._get_ge_input(arg_names[2])

        x_grad = (
            factory.create_operator(
                self.parser_name + self._accumulated_op_id(), "Mul")
            .set_input("x1", out_grad)
            .set_input("x2", operand_b)
        )
        y_grad = (
            factory.create_operator(
                self.parser_name + self._accumulated_op_id(), "Mul")
            .set_input("x1", operand_a)
            .set_input("x2", out_grad)
        )
        return [x_grad, y_grad], [[0], [1]]
class DotAddGradParser(AscendParserBase):
    """Builds the input gradients of the Paddle ``elementwise_add`` op."""

    def __init__(self, graph, var2geop):
        super(DotAddGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_add_grad"

    def _sum_to_operand(self, grad, leading_dims, operand_shape):
        """Reduce ``grad`` to the layout described by ``operand_shape``.

        Axis 0 is summed away ``leading_dims`` times (``keep_dims=False``),
        then every axis whose size in ``operand_shape`` is 1 is summed with
        ``keep_dims=True``.
        """
        factory = core.GEOperatorFactory
        for _ in range(leading_dims):
            grad = (
                factory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "ReduceSumD")
                .set_input("x", grad)
                .set_attr_vec_int32("axes", [0])
                .set_attr_bool("keep_dims", False)
            )
        for axis, size in enumerate(operand_shape):
            if size != 1:
                continue
            grad = (
                factory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "ReduceSumD")
                .set_input("x", grad)
                .set_attr_vec_int32("axes", [axis])
                .set_attr_bool("keep_dims", True)
            )
        return grad

    def _apply(self):
        """Reduce the incoming gradient separately for each operand."""
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        out_1 = self._get_ge_input(self.op.input_arg_names[1])
        out_2 = self._get_ge_input(self.op.input_arg_names[2])
        out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        out_1_shape = self.op.block.var(self.op.input_arg_names[1]).shape
        out_2_shape = self.op.block.var(self.op.input_arg_names[2]).shape
        grad_rank = len(out_grad_shape)

        x_grad = self._sum_to_operand(
            out_grad, grad_rank - len(out_1_shape), out_1_shape)
        y_grad = self._sum_to_operand(
            out_grad, grad_rank - len(out_2_shape), out_2_shape)
        return [x_grad, y_grad], [[0], [1]]
class DotDivGradParser(AscendParserBase):
    """Builds the input gradients of the Paddle ``elementwise_div`` op."""

    def __init__(self, graph, var2geop):
        super(DotDivGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_div_grad"

    def _apply(self):
        """Emit the X and Y gradient sub-graphs.

        With the four op inputs named ``out``, ``out_grad``, ``x`` and ``y``
        in that order:
          * ``x_grad = cast(x != 0) * y**-1 * out_grad``
          * ``y_grad = out * y**-1 * out_grad``

        NOTE(review): the derivative of ``x / y`` with respect to ``y`` is
        ``-out / y``; no negation is applied to ``y_grad`` here.  Confirm
        whether the sign is handled elsewhere.
        """
        factory = core.GEOperatorFactory
        arg_names = self.op.input_arg_names
        out = self._get_ge_input(arg_names[0])
        out_grad = self._get_ge_input(arg_names[1])
        x = self._get_ge_input(arg_names[2])
        y = self._get_ge_input(arg_names[3])

        def multiply(name_prefix, lhs, rhs):
            return (
                factory.create_operator(
                    name_prefix + self._accumulated_op_id(), "Mul")
                .set_input("x1", lhs)
                .set_input("x2", rhs)
            )

        y_power = (
            factory.create_operator(
                "power" + self._accumulated_op_id(), "Power")
            .set_input("x", y)
            .set_attr_float("power", -1)
        )

        # Mask that is 1 where x is non-zero and 0 elsewhere.
        tensor_zeros = factory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike",
        ).set_input("x", x)
        x_zero = (
            factory.create_operator(
                "equal" + self._accumulated_op_id(), "Equal")
            .set_input("x1", x)
            .set_input("x2", tensor_zeros)
        )
        x_nozero = factory.create_operator(
            "logical_not" + self._accumulated_op_id(),
            "LogicalNot",
        ).set_input("x", x_zero)
        x_nozero_f = (
            factory.create_operator("cast" + self._accumulated_op_id(), "Cast")
            .set_input("x", x_nozero)
            .set_attr_int32("dst_type", 0)
        )

        x_grad_w = multiply("mul", x_nozero_f, y_power)
        x_grad = multiply(self.parser_name, x_grad_w, out_grad)
        y_grad_w = multiply("mul", out, y_power)
        y_grad = multiply("mul", y_grad_w, out_grad)
        return [x_grad, y_grad], [[0], [1]]
class SoftmaxGradParser(AscendParserBase):
    """Builds the input gradient of the Paddle ``softmax`` op."""

    def __init__(self, graph, var2geop):
        super(SoftmaxGradParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_grad"

    def _apply(self):
        """Emit ``SoftmaxGrad``: first input as ``softmax``, second as ``grad_softmax``."""
        arg_names = self.op.input_arg_names
        softmax_out = self._get_ge_input(arg_names[0])
        upstream_grad = self._get_ge_input(arg_names[1])
        x_grad = (
            core.GEOperatorFactory.create_operator(
                self.parser_name + self._accumulated_op_id(), "SoftmaxGrad")
            .set_input("softmax", softmax_out)
            .set_input("grad_softmax", upstream_grad)
        )
        return [x_grad], [[0]]
class ReshapeGradParser(AscendParserBase):
    """Builds the input gradient of the Paddle ``reshape2`` op."""

    def __init__(self, graph, var2geop):
        super(ReshapeGradParser, self).__init__(graph, var2geop)
        self.parser_name = "reshape2_grad"

    def _apply(self):
        """Reshape the incoming gradient back to the recorded input shape.

        The target shape comes from the static shape of the second op
        input.  A leading ``0`` entry is dropped; if there is none the shape
        is used unchanged.  The latter case used to fail with an
        UnboundLocalError.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        # Fetched as before; only the static shape of this input is used.
        x_shape = self._get_ge_input(self.op.input_arg_names[1])
        x_shape_list = self.op.block.var(self.op.input_arg_names[1]).shape

        if x_shape_list and x_shape_list[0] == 0:
            x_shape_delzero = x_shape_list[1:]
        else:
            x_shape_delzero = x_shape_list

        tensor = self._create_ge_tensor([len(x_shape_delzero)], 2,
                                        x_shape_delzero)
        const_shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor)
        x_grad = core.GEOperatorFactory.create_operator(
            "reshape" + self._accumulated_op_id(), "Reshape").set_input(
                "x", out_grad).set_input("shape", const_shape)
        return [x_grad], [[0]]
class GatherGradParser(AscendParserBase):
    """Builds the input gradient of the Paddle ``gather`` op."""

    def __init__(self, graph, var2geop):
        super(GatherGradParser, self).__init__(graph, var2geop)
        self.parser_name = "gather_grad"

    def _apply(self):
        """Scatter the incoming gradient into a zero tensor shaped like X.

        A rank-1 index is expanded with ``Unsqueeze`` on axis 1 before the
        ``TensorScatterUpdate``.

        NOTE(review): presumably repeated indices overwrite rather than
        accumulate with an update-style scatter; confirm this is intended.
        """
        factory = core.GEOperatorFactory
        index = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        x = self._get_ge_input(self.op.input_arg_names[2])
        index_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        # The next two shapes are looked up but unused below.
        out_grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape
        x_shape = self.op.block.var(self.op.input_arg_names[2]).shape

        if len(index_shape) == 1:
            index = (
                factory.create_operator(
                    "unsqueeze" + self._accumulated_op_id(), "Unsqueeze")
                .set_input("x", index)
                .set_attr_vec_int32("axes", [1])
            )

        tensor_zeros = factory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike",
        ).set_input("x", x)
        x_grad = (
            factory.create_operator(
                "scatter" + self._accumulated_op_id(), "TensorScatterUpdate")
            .set_input("x", tensor_zeros)
            .set_input("indices", index)
            .set_input("updates", out_grad)
        )
        return [tensor_zeros, x_grad], [[-1]]
class TransposeGradParser(AscendParserBase):
    """Lower a Paddle ``transpose2_grad`` op onto a GE ``TransposeD`` op."""

    def __init__(self, graph, var2geop):
        super(TransposeGradParser, self).__init__(graph, var2geop)
        self.parser_name = "transpose2_grad"

    def _apply(self):
        """Transpose the output gradient using the op's ``axis`` attribute.

        Asserts that permuting the gradient's shape with ``axis`` yields the
        shape of the second input with its first entry removed.

        Returns:
            tuple: ``([x_grad], [[0]])`` -- the created GE operator and the
            index list handed back to the base class.
        """
        grad_name = self.op.input_arg_names[0]
        out_grad = self._get_ge_input(grad_name)
        x_name = self.op.input_arg_names[1]
        # GE operator of the second input is looked up but not used further.
        _x = self._get_ge_input(x_name)
        perm = self.op.attr("axis")

        expected_dims = self.op.block.var(x_name).shape[1:]
        grad_dims = self.op.block.var(grad_name).shape
        permuted_dims = [grad_dims[axis] for axis in perm]
        assert permuted_dims == list(expected_dims)

        transpose_op = core.GEOperatorFactory.create_operator(
            "transpose" + self._accumulated_op_id(), "TransposeD")
        x_grad = transpose_op.set_input(
            "x", out_grad).set_attr_vec_int32("perm", perm)

        return [x_grad], [[0]]
class LayerNormGradParser(AscendParserBase):
    """Lower a Paddle ``layer_norm_grad`` op onto GE ``LayerNormGrad``."""

    def __init__(self, graph, var2geop):
        super(LayerNormGradParser, self).__init__(graph, var2geop)
        self.parser_name = "layer_norm_grad"

    def _apply(self):
        """Build ``LayerNormGrad`` and cast each of its three outputs.

        Outputs 0, 1 and 2 of the ``LayerNormGrad`` operator are each passed
        through a ``Cast`` whose ``dst_type`` is 0 when the helper's inverse
        dtype map yields 0 for the dtype of ``x``, and 1 otherwise.

        Returns:
            tuple: ``([out_x_grad, out_scale_grad, out_bias_grad],
            [[2], [1], [0]])`` -- the cast operators and the index list
            handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        # Bias is looked up like the other inputs but is not fed to the op.
        _bias = self._get_ge_input(arg_names[0])
        mean = self._get_ge_input(arg_names[1])
        scale = self._get_ge_input(arg_names[2])
        variance = self._get_ge_input(arg_names[3])
        x = self._get_ge_input(arg_names[4])
        out_grad = self._get_ge_input(arg_names[5])
        x_dtype = self.op.block.var(arg_names[4]).dtype

        ln_grad_op = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(), "LayerNormGrad")
        ln_grad = ln_grad_op.set_input("dy", out_grad).set_input(
            "x", x).set_input("variance", variance).set_input(
                "mean", mean).set_input("gamma", scale)

        dtype_code = self.ascend_helper.dtype2paddle_inv_map[str(x_dtype)]
        if dtype_code == 0:
            cast_dtype = 0
        else:
            cast_dtype = 1

        # One Cast per LayerNormGrad output, created in output order so the
        # accumulated op ids are assigned in the same sequence as before.
        casted = [
            core.GEOperatorFactory.create_operator(
                "cast" + self._accumulated_op_id(), "Cast").set_input(
                    "x", ln_grad, output_idx).set_attr_int32(
                        "dst_type", cast_dtype) for output_idx in (0, 1, 2)
        ]

        return casted, [[2], [1], [0]]
class TanhGradParser(AscendParserBase):
    """Lower a Paddle ``tanh_grad`` op onto a GE ``TanhGrad`` op."""

    def __init__(self, graph, var2geop):
        super(TanhGradParser, self).__init__(graph, var2geop)
        self.parser_name = 'tanh_grad'

    def _apply(self):
        """Feed the first input as ``y`` and the second as ``dy``.

        Returns:
            tuple: ``([tanh_grad], [[0]])`` -- the created GE operator and
            the index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        forward_out = self._get_ge_input(arg_names[0])
        upstream_grad = self._get_ge_input(arg_names[1])

        grad_op = core.GEOperatorFactory.create_operator(
            "tanh_grad" + self._accumulated_op_id(), "TanhGrad")
        tanh_grad = grad_op.set_input("y", forward_out).set_input(
            "dy", upstream_grad)

        return [tanh_grad], [[0]]
class LogGradParser(AscendParserBase):
    """Lower a Paddle ``log_grad`` op onto a GE ``DivNoNan`` op."""

    def __init__(self, graph, var2geop):
        super(LogGradParser, self).__init__(graph, var2geop)
        self.parser_name = 'log_grad'

    def _apply(self):
        """Divide the first input (gradient) by the second input.

        Returns:
            tuple: ``([log_grad], [[0]])`` -- the created GE operator and
            the index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        upstream_grad = self._get_ge_input(arg_names[0])
        forward_in = self._get_ge_input(arg_names[1])

        div_op = core.GEOperatorFactory.create_operator(
            "log_grad" + self._accumulated_op_id(), "DivNoNan")
        log_grad = div_op.set_input("x1", upstream_grad).set_input(
            "x2", forward_in)

        return [log_grad], [[0]]
class SqrtGradParser(AscendParserBase):
    """Lower a Paddle ``sqrt_grad`` op onto a GE ``SqrtGrad`` op."""

    def __init__(self, graph, var2geop):
        super(SqrtGradParser, self).__init__(graph, var2geop)
        self.parser_name = "sqrt_grad"

    def _apply(self):
        """Feed the first input as ``y`` and the second as ``dy``.

        Returns:
            tuple: ``([sqrt_grad], [[0]])`` -- the created GE operator and
            the index list handed back to the base class.
        """
        y = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        sqrt_grad = core.GEOperatorFactory.create_operator(
            "sqrt_grad" + self._accumulated_op_id(),
            "SqrtGrad").set_input("y", y).set_input("dy", out_grad)
        # Every other parser returns an (operators, index_list) pair; the
        # index list was missing here, so only a bare list was returned.
        return [sqrt_grad], [[0]]
class PowGradParser(AscendParserBase):
    """Lower a Paddle ``pow_grad`` op onto GE operators."""

    def __init__(self, graph, var2geop):
        super(PowGradParser, self).__init__(graph, var2geop)
        self.parser_name = "pow_grad"

    def _apply(self):
        """Compute ``factor * x ** (factor - 1) * grad`` with GE operators.

        The scalar ``factor`` attribute is materialised as a ``Const``,
        broadcast to the runtime shape of ``x`` and multiplied with
        ``x ** (factor - 1)`` and the incoming gradient.

        Returns:
            tuple: ``([x_grad], [[0]])`` -- the created GE operator and the
            index list handed back to the base class.
        """
        grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        factor = self.op.attr("factor")

        # A `_create_shape_tensor()` result used to be assigned here and was
        # overwritten on the very next statement; that dead call is removed.
        shape_tensor = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape").set_input("x", x)

        factor_scale = self._create_ge_tensor([1], 5, factor)
        factor_const = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", factor_scale)
        factor_tensor = core.GEOperatorFactory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(),
            "BroadcastTo").set_input(
                "x", factor_const).set_input("shape", shape_tensor)

        x_power = core.GEOperatorFactory.create_operator(
            "x_power" + self._accumulated_op_id(), "Power").set_input(
                "x", x).set_attr_float("power", factor - 1)
        # The product must use x ** (factor - 1). It previously used `x`
        # itself, leaving `x_power` unused and yielding x * factor * grad.
        x_power_mul_factor = core.GEOperatorFactory.create_operator(
            "x_power_mul_factor" + self._accumulated_op_id(), "Mul").set_input(
                "x1", x_power).set_input("x2", factor_tensor)
        x_power_mul_factor_grad = core.GEOperatorFactory.create_operator(
            "x_power_mul_factor_grad" + self._accumulated_op_id(),
            "Mul").set_input("x1", x_power_mul_factor).set_input("x2", grad)

        return [x_power_mul_factor_grad], [[0]]
class GeluGradParser(AscendParserBase):
    """Lower a Paddle ``gelu_grad`` op onto GE ``Gelu`` + ``GeluGrad``."""

    def __init__(self, graph, var2geop):
        super(GeluGradParser, self).__init__(graph, var2geop)
        self.parser_name = "gelu_grad"

    def _apply(self):
        """Recompute ``Gelu(x)`` and feed it, ``x`` and the gradient to ``GeluGrad``.

        Returns:
            tuple: ``([gelu_grad], [[0]])`` -- the created GE operator and
            the index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        upstream_grad = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])

        # GeluGrad takes the forward output as input "y", so it is rebuilt
        # here from x.
        forward_op = core.GEOperatorFactory.create_operator(
            "gelu" + self._accumulated_op_id(), "Gelu")
        forward_out = forward_op.set_input("x", x)

        grad_op = core.GEOperatorFactory.create_operator(
            "gelu_grad" + self._accumulated_op_id(), "GeluGrad")
        gelu_grad = grad_op.set_input("x", x).set_input(
            "dy", upstream_grad).set_input("y", forward_out)

        return [gelu_grad], [[0]]
class MeanGradParser(AscendParserBase):
    """Lower a Paddle ``mean_grad`` op onto GE operators."""

    def __init__(self, graph, var2geop):
        super(MeanGradParser, self).__init__(graph, var2geop)
        self.parser_name = "mean_grad"

    def _apply(self):
        """Multiply the gradient by the reciprocal of a reduced ones tensor.

        Builds ``OnesLike(x)``, reduces it with ``ReduceSumD`` (empty ``axes``
        list, ``keep_dims`` False), raises the result to the power -1 and
        multiplies it with the incoming gradient.

        Returns:
            tuple: ``([mean_grad], [[0]])`` -- the created GE operator and
            the index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        upstream_grad = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])

        ones_op = core.GEOperatorFactory.create_operator(
            "one_tensor" + self._accumulated_op_id(), "OnesLike")
        ones_tensor = ones_op.set_input("x", x)

        reduce_op = core.GEOperatorFactory.create_operator(
            "mean" + self._accumulated_op_id(), "ReduceSumD")
        ones_total = reduce_op.set_input("x", ones_tensor).set_attr_bool(
            "keep_dims", False).set_attr_vec_int32("axes", [])

        power_op = core.GEOperatorFactory.create_operator(
            "x_power" + self._accumulated_op_id(), "Power")
        reciprocal = power_op.set_input("x", ones_total).set_attr_float(
            "power", -1)

        mul_op = core.GEOperatorFactory.create_operator(
            "mean_grad" + self._accumulated_op_id(), "Mul")
        mean_grad = mul_op.set_input("x1", reciprocal).set_input(
            "x2", upstream_grad)

        return [mean_grad], [[0]]
class SliceGradParser(AscendParserBase):
    """Lower a Paddle ``slice_grad`` op onto a GE ``PadD`` op."""

    def __init__(self, graph, var2geop):
        super(SliceGradParser, self).__init__(graph, var2geop)
        self.parser_name = "slice_grad"

    def _apply(self):
        """Pad the gradient back out using the slice's starts and ends.

        For each dimension listed in ``axes`` the leading pad is the matching
        ``starts`` entry and the trailing pad is ``x_shape[dim] - ends`` when
        ``ends`` does not exceed that dimension (0 otherwise). Dimensions not
        in ``axes`` get no padding, and dimension 0 is always forced to
        ``[0, 0]``.

        Returns:
            tuple: ``([slice_value], [[0]])`` -- the created GE operator and
            the index list handed back to the base class.
        """
        x_name = self.op.input_arg_names[0]
        # The GE operator for x is looked up but not fed into the graph.
        _x = self._get_ge_input(x_name)
        grad_name = self.op.input_arg_names[1]
        grad = self._get_ge_input(grad_name)

        axes = self.op.attr("axes")
        starts = self.op.attr("starts")
        ends = self.op.attr("ends")

        x_dims = self.op.block.var(x_name).shape
        # Fetched like the original did, but not consulted afterwards.
        _grad_dims = self.op.block.var(grad_name).shape

        pad_before, pad_after = [], []
        attr_pos = 0  # position in starts/ends of the next sliced axis
        for dim in range(len(x_dims)):
            if dim not in axes:
                pad_before.append(0)
                pad_after.append(0)
                continue
            pad_before.append(starts[attr_pos])
            if ends[attr_pos] <= x_dims[dim]:
                pad_after.append(x_dims[dim] - ends[attr_pos])
            else:
                pad_after.append(0)
            attr_pos += 1

        # Dimension 0 is never padded, whatever the attributes say.
        pad_before[0] = 0
        pad_after[0] = 0

        paddings = [list(pair) for pair in zip(pad_before, pad_after)]

        pad_op = core.GEOperatorFactory.create_operator(
            "slice_grad" + self._accumulated_op_id(), "PadD")
        slice_value = pad_op.set_input("x", grad).set_attr_vec_vec_int64(
            "paddings", paddings)

        return [slice_value], [[0]]
class LookUpTableGradParser(AscendParserBase):
    """Lower a Paddle ``lookup_table_grad`` op onto GE operators."""

    def __init__(self, graph, var2geop):
        super(LookUpTableGradParser, self).__init__(graph, var2geop)
        self.parser_name = "lookup_table_grad"

    def _apply(self):
        """Scatter-add the flattened gradient into zeros shaped like the table.

        Both the ids and the gradient go through ``FlattenV2`` (``axis`` 0,
        ``end_axis`` 1) before ``TensorScatterAdd`` accumulates them into
        ``ZerosLike(embedding)``.

        Returns:
            tuple: ``([embedding_grad], [[0]])`` -- the created GE operator
            and the index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        ids_name, grad_name, table_name = (arg_names[0], arg_names[1],
                                           arg_names[2])

        ids = self._get_ge_input(ids_name)
        grad = self._get_ge_input(grad_name)
        embedding = self._get_ge_input(table_name)

        # Shapes are fetched like the original did but are not used below.
        block = self.op.block
        _ids_dims = block.var(ids_name).shape
        _grad_dims = block.var(grad_name).shape
        _table_dims = block.var(table_name).shape

        def flatten(source):
            # Collapse axes 0..1 of `source` with a fresh FlattenV2 operator.
            flatten_op = core.GEOperatorFactory.create_operator(
                "flatten" + self._accumulated_op_id(), "FlattenV2")
            return flatten_op.set_input("x", source).set_attr_int32(
                "axis", 0).set_attr_int32("end_axis", 1)

        ids_flatten = flatten(ids)
        grad_flatten = flatten(grad)

        zeros_op = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(), "ZerosLike")
        tensor_zeros = zeros_op.set_input("x", embedding)

        scatter_op = core.GEOperatorFactory.create_operator(
            "scatteradd" + self._accumulated_op_id(), "TensorScatterAdd")
        embedding_grad = scatter_op.set_input("x", tensor_zeros).set_input(
            "indices", ids_flatten).set_input("updates", grad_flatten)

        return [embedding_grad], [[0]]
class SGDParser(AscendParserBase):
    """Lower a Paddle ``sgd`` op onto a GE ``ApplyGradientDescent`` op."""

    def __init__(self, graph, var2geop):
        super(SGDParser, self).__init__(graph, var2geop)
        self.parser_name = "sgd"

    def _apply(self):
        """Wire param, learning rate and gradient into ``ApplyGradientDescent``.

        The inputs are read in the order gradient, learning rate, parameter
        and bound to ``delta``, ``alpha`` and ``var`` respectively.

        Returns:
            tuple: ``([sgd], [[0]])`` -- the created GE operator and the
            index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        grad = self._get_ge_input(arg_names[0])
        lr = self._get_ge_input(arg_names[1])
        param = self._get_ge_input(arg_names[2])

        # The operator name prefix is "momentum" in the original and is kept.
        descent_op = core.GEOperatorFactory.create_operator(
            "momentum" + self._accumulated_op_id(), "ApplyGradientDescent")
        sgd = descent_op.set_input("var", param).set_input(
            "alpha", lr).set_input("delta", grad)

        return [sgd], [[0]]
class AdamParser(AscendParserBase):
    """Lower a Paddle ``adam`` op onto a GE ``ApplyAdam`` op."""

    def __init__(self, graph, var2geop):
        super(AdamParser, self).__init__(graph, var2geop)
        self.parser_name = "adam"

    def _apply(self):
        """Wire the seven inputs and three scalar attributes into ``ApplyAdam``.

        ``beta1``, ``beta2`` and ``epsilon`` are op attributes; each is
        wrapped in a one-element ``Const`` operator before being passed on.

        Returns:
            tuple: ``([adam], [[0]])`` -- the created GE operator and the
            index list handed back to the base class.
        """
        arg_names = self.op.input_arg_names
        (beta1_power, beta2_power, grad, lr, moment1, moment2,
         param) = [self._get_ge_input(arg_names[pos]) for pos in range(7)]

        beta1_value = self.op.attr('beta1')
        beta2_value = self.op.attr('beta2')
        epsilon_value = self.op.attr('epsilon')

        def scalar_const(value):
            # Operator is created first, then the tensor, mirroring the
            # evaluation order of the original chained expression.
            const_op = core.GEOperatorFactory.create_operator(
                "const" + self._accumulated_op_id(), "Const")
            return const_op.set_attr_tensor(
                "value", self._create_ge_tensor([1], 5, value))

        beta1 = scalar_const(beta1_value)
        beta2 = scalar_const(beta2_value)
        epsilon = scalar_const(epsilon_value)

        adam_op = core.GEOperatorFactory.create_operator(
            "adam" + self._accumulated_op_id(), "ApplyAdam")
        adam = adam_op.set_input("var", param)
        adam = adam.set_input("m", moment1).set_input("v", moment2)
        adam = adam.set_input("beta1_power", beta1_power).set_input(
            "beta2_power", beta2_power)
        adam = adam.set_input("lr", lr)
        adam = adam.set_input("beta1", beta1).set_input("beta2", beta2)
        adam = adam.set_input("epsilon", epsilon).set_input("grad", grad)

        return [adam], [[0]]