You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
2197 lines
92 KiB
2197 lines
92 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import paddle.fluid.framework as framework
|
|
from paddle.fluid.optimizer import Optimizer
|
|
import paddle.fluid.core as core
|
|
import numpy as np
|
|
from paddle.distributed import fleet
|
|
from functools import reduce
|
|
|
|
# Maps a Paddle op type to the name (a string) of the parser class that
# handles it.
registerd_op = {
    # forward ops
    "elementwise_add": "AddParser",
    "matmul": "MatMulParser",
    "mul": "MulParser",
    "relu": "ReluParser",
    "softmax_with_cross_entropy": "SoftmaxWithCrossEntropyParser",
    "shape": "ShapeParser",
    "fill_constant": "FillConstantParser",
    "reduce_sum": "ReduceSumParser",
    "elementwise_mul": "DotMulParser",
    "elementwise_div": "DotDivParser",
    "elementwise_pow": "DotPowParser",
    "elementwise_max": "MaxParser",
    "elementwise_min": "MinParser",
    "elementwise_sub": "DotSubParser",
    "pow": "PowParser",
    "gelu": "GeluParser",
    "sqrt": "SqrtParser",
    "log": "LogParser",
    "sum": "SumParser",
    "logical_not": "LogicalNotParser",
    "gather": "GatherParser",
    "scatter": "ScatterParser",
    "cast": "CastParser",
    "tanh": "TanhParser",
    "stack": "StackParser",
    "square": "SquareParser",
    "unsqueeze2": "UnSqueezeParser",
    "assign": "AssignParser",
    "softmax": "SoftMaxParser",
    "reshape2": "ReshapeParser",
    "transpose2": "TransposeParser",
    "layer_norm": "LayerNormParser",
    "less_than": "LessParser",
    "mean": "MeanParser",
    "scale": "ScaleParser",
    "slice": "SliceParser",
    "top_k": "TopkParser",
    "accuracy": "AccuracyParser",
    # "increment": "IncrementParser",  (disabled)
    "lookup_table": "LookupTableParser",
    "truncated_gaussian_random": "TruncatedNormalParser",
    "c_allgather": "AllGatherParser",
    "c_allreduce_sum": "AllReduceSumParser",
    "c_allreduce_max": "AllReduceMaxParser",
    "c_broadcast": "BroadcastParser",
    "c_reduce_scatter": "ReduceScatterParser",
    "c_send": "SendParser",
    "c_receive": "ReceiveParser",
    "uniform_random": "UniformRandomParser",
    "range": "RangeParser",
    "equal": "EqualParser",
    "expand": "ExpandParser",
    "squeeze2": "SqueezeParser",

    # backward (gradient) ops
    "matmul_grad": "MatMulGradParser",
    "mul_grad": "MulGradParser",
    "relu_grad": "ReluGradParser",
    "reduce_sum_grad": "ReduceSumGradParser",
    "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser",
    "tanh_grad": "TanhGradParser",
    "log_grad": "LogGradParser",
    "pow_grad": "PowGradParser",
    "sqrt_grad": "SqrtGradParser",
    "gelu_grad": "GeluGradParser",
    "mean_grad": "MeanGradParser",
    "lookup_table_grad": "LookUpTableGradParser",
    "elementwise_mul_grad": "DotMulGradParser",
    "elementwise_add_grad": "DotAddGradParser",
    "elementwise_div_grad": "DotDivGradParser",
    "softmax_grad": "SoftmaxGradParser",
    "slice_grad": "SliceGradParser",
    "reshape2_grad": "ReshapeGradParser",
    "gather_grad": "GatherGradParser",
    "transpose2_grad": "TransposeGradParser",
    "layer_norm_grad": "LayerNormGradParser",

    # optimizer ops
    "sgd": "SGDParser",
    # "adam": "AdamParser",  (disabled)
}

# Module-wide counters; both start at -1 so that the first increment yields 0.
global_cnt = -1
global_input_cnt = -1
|
|
|
|
|
|
class AscendHelper(object):
    """Lookup tables that translate integer dtype codes to GE / NumPy dtypes."""

    def __init__(self):
        # Integer dtype code -> GE data type enum value.
        self.dtype2ge_map = {
            0: core.GEDataType.DT_BOOL,
            1: core.GEDataType.DT_INT16,
            2: core.GEDataType.DT_INT32,
            3: core.GEDataType.DT_INT64,
            4: core.GEDataType.DT_FLOAT16,
            5: core.GEDataType.DT_FLOAT,
            6: core.GEDataType.DT_DOUBLE
        }
        # Integer dtype code -> NumPy dtype name.
        self.dtype2np_map = {
            0: "bool",
            1: "int16",
            2: "int32",
            3: "int64",
            4: "float16",
            5: "float32",
            6: "float64"
        }
        # str() of a Paddle VarType -> integer code used by callers of this map.
        self.dtype2paddle_inv_map = {"VarType.FP32": 0, "VarType.FP16": 1}

    def dtype2ge(self, dtype):
        """Return the GE data type registered for the integer code ``dtype``.

        Raises:
            AssertionError: if ``dtype`` is not a key of ``dtype2ge_map``.
        """
        # The message takes exactly one argument; the previous two-placeholder
        # format string raised TypeError instead of reporting the bad dtype.
        assert dtype in self.dtype2ge_map, "dtype[%s] is not supported" % (
            dtype, )
        return self.dtype2ge_map[dtype]

    def dtype2np(self, index):
        """Return the NumPy dtype name registered for the integer code ``index``.

        Raises:
            AssertionError: if ``index`` is not a key of ``dtype2np_map``.
        """
        # Report ``index`` itself; the old message referenced an undefined name.
        assert index in self.dtype2np_map, "index[%s] is not supported" % (
            index, )
        return self.dtype2np_map[index]
|
|
|
|
|
|
class AscendParserFactory(object):
    """Creates parser instances from their class name."""

    def __init__(self, graph, var2geop):
        self.graph = graph
        self.var2geop = var2geop

    def create_parse(self, parser_class):
        """Instantiate the parser class named ``parser_class``.

        The name is resolved against this module's globals and the class is
        called with ``(self.graph, self.var2geop)``.

        Args:
            parser_class: name of a parser class defined in this module.

        Returns:
            The newly constructed parser instance.

        Raises:
            ValueError: if no module-level object with that name exists.
        """
        # Only the lookup is guarded. Errors raised by the parser's own
        # constructor propagate unchanged instead of being misreported as
        # "does not exist" by a bare ``except``.
        parser_cls = globals().get(parser_class)
        if parser_cls is None:
            raise ValueError("parser class %s does not exist" % parser_class)
        return parser_cls(self.graph, self.var2geop)
|
|
|
|
|
|
class AscendParserBase(object):
    """Common machinery for parsers that translate one Paddle op into GE ops.

    Subclasses set ``self.parser_name`` and implement ``_apply()``, which
    returns ``(geop_list, index_list)``; ``apply()`` then records the outputs
    in ``var2geop`` and adds every created operator to ``graph``.
    """

    def __init__(self, graph, var2geop):
        self.graph = graph
        self.var2geop = var2geop
        self.op = None
        self.ascend_helper = AscendHelper()

    def _get_ge_input(self, input_var_name):
        """Return the GE operator previously registered for ``input_var_name``."""
        assert input_var_name in self.var2geop, "var %s not created before" % (
            input_var_name)
        return self.var2geop[input_var_name]

    def update_output(self, geop_list, index_list):
        """Bind the op's output variables to GE operators and add ops to the graph.

        Args:
            geop_list: GE operators produced by ``_apply()``.
            index_list: one list per op output name; element ``i`` is the index
                into ``geop_list`` of the operator bound to that output's
                ``i``-th argument.

        Raises:
            AssertionError: if the number of outputs or of arguments per output
                does not match ``index_list``.
        """
        output_num = len(self.op.output_names)
        # Arguments are passed in the order the message names them (the
        # original swapped the two counts).
        assert output_num == len(
            index_list
        ), "Parser[%s]'s output number[%d] is not equal to parameters number[%d]" % (
            self.parser_name, output_num, len(index_list))
        for output_id in range(output_num):
            arguments = self.op.output(self.op.output_names[output_id])
            # Outputs without arguments are skipped.
            if len(arguments) > 0:
                assert len(arguments) == len(
                    index_list[output_id]
                ), "Parser[%s]'s %dth argument number[%d] is not equal to paddle's number[%d]" % (
                    self.parser_name, output_id, len(index_list[output_id]),
                    len(arguments))
                for arg_name, geop_index in zip(arguments,
                                                index_list[output_id]):
                    self.var2geop[arg_name] = geop_list[geop_index]

        for geop in geop_list:
            self.graph.add_op(geop)

    def apply(self, op):
        """Parse ``op``: run ``_apply()`` and register its results.

        Raises:
            AssertionError: if ``op.type`` differs from ``self.parser_name``.
        """
        self.op = op
        assert self.op.type == self.parser_name, "op [%s] != parser_name[%s]" % (
            self.op.type, self.parser_name)
        geop_list, index_list = self._apply()
        self.update_output(geop_list, index_list)

    def _mark_as_input(self, ge_tensor):
        """Register ``ge_tensor`` in ``var2geop`` under a fresh ``geinput.N`` key."""
        global global_input_cnt
        global_input_cnt += 1
        self.var2geop["geinput." + str(global_input_cnt)] = ge_tensor

    def _accumulated_op_id(self):
        """Increment the module-wide op counter and return it as ``".N"``."""
        global global_cnt
        global_cnt += 1
        name = "." + str(global_cnt)
        return name

    def _create_ge_tensor(self, shape, dtype, value):
        """Build a GE tensor of ``shape`` filled with ``value``.

        Args:
            shape: tensor shape.
            dtype: integer dtype code understood by ``AscendHelper``.
            value: multiplied with ``np.ones(shape)``; a scalar fills the
                tensor, a sequence is combined under NumPy broadcasting rules.

        Returns:
            The populated ``core.GETensor``.
        """
        tensor_desc = core.GETensorDesc(
            core.GEShape(shape), core.GEFormat.FORMAT_ND,
            self.ascend_helper.dtype2ge(dtype))
        tensor = core.GETensor(tensor_desc)

        data = (value * np.ones(
            (shape))).reshape(shape).astype(self.ascend_helper.dtype2np(dtype))
        # The tensor is fed as a flat uint8 view of the raw bytes.
        buf = data.tobytes()
        data_8 = np.frombuffer(buf, dtype=np.uint8)
        tensor.set_data(data_8)
        return tensor

    def _get_ge_tensor(self, shape, dtype, value_list):
        """Build a GE ``Const`` operator holding ``value_list`` reshaped to ``shape``.

        Args:
            shape: tensor shape.
            dtype: integer dtype code understood by ``AscendHelper``.
            value_list: values converted with ``np.array`` before reshaping.

        Returns:
            The ``Const`` operator whose ``value`` attribute is the tensor.
        """
        tensor_desc = core.GETensorDesc(
            core.GEShape(shape), core.GEFormat.FORMAT_ND,
            self.ascend_helper.dtype2ge(dtype))
        tensor = core.GETensor(tensor_desc)

        data = np.array(value_list).reshape(shape).astype(
            self.ascend_helper.dtype2np(dtype))
        buf = data.tobytes()
        data_8 = np.frombuffer(buf, dtype=np.uint8)
        tensor.set_data(data_8)

        tensor_const = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor)

        return tensor_const

    def _get_variable(self, shape, dtype, tensor):
        """Create a GE ``Variable`` of ``shape`` and an ``Assign`` of ``tensor`` to it.

        Args:
            shape: shape of the variable's output descriptor.
            dtype: ``"int32"`` or ``"float32"``.
            tensor: GE operator used as the assigned value.

        Returns:
            The ``Assign`` operator.

        Raises:
            ValueError: if ``dtype`` is neither ``"int32"`` nor ``"float32"``.
        """
        # Validate before touching the GE API; previously an unsupported dtype
        # left the local unbound and failed later with UnboundLocalError.
        if dtype == "int32":
            ge_dtype = core.GEDataType.DT_INT32
        elif dtype == "float32":
            ge_dtype = core.GEDataType.DT_FLOAT
        else:
            raise ValueError("dtype %s is not supported by _get_variable" %
                             (dtype, ))

        var = core.GEOperatorFactory.create_operator(
            "variable" + self._accumulated_op_id(), "Variable")
        var.update_output_desc("y",
                               core.GETensorDesc(
                                   core.GEShape(shape), core.GEFormat.FORMAT_ND,
                                   ge_dtype))
        assign = core.GEOperatorFactory.create_operator(
            "assign" + self._accumulated_op_id(), "Assign").set_input(
                "value", tensor).set_input("ref", var)

        return assign

    def _create_shape_tensor(self):
        """Return an int32 GE tensor of shape ``[2]`` holding ``[64, 1]``."""
        tensor_desc = core.GETensorDesc(
            core.GEShape([2]), core.GEFormat.FORMAT_ND,
            core.GEDataType.DT_INT32)
        tensor = core.GETensor(tensor_desc)

        data = np.ones((2)).astype("int32").reshape([2])
        data[0] = 64
        buf = data.tobytes()
        data_8 = np.frombuffer(buf, dtype=np.uint8)
        tensor.set_data(data_8)
        return tensor

    def _get_GEtensor_shape(self, tensor):
        """Return a GE op computing the shape of ``tensor``, cast with ``dst_type`` 0."""
        tensor_shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape").set_input("x", tensor)
        # NOTE(review): dst_type 0 is presumably float32 - confirm against the
        # GE data type enumeration.
        tensor_shape = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast").set_input(
                "x", tensor_shape).set_attr_int32("dst_type", 0)
        return tensor_shape
|
|
|
|
|
|
### elementwise_op
|
|
class AddParser(AscendParserBase):
    """Translates Paddle's ``elementwise_add`` op into a GE ``Add`` operator."""

    def __init__(self, graph, var2geop):
        super(AddParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_add"

    def _apply(self):
        """Build ``Add`` from the op's first two input arguments.

        Returns:
            ``([add_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "add" + self._accumulated_op_id()
        add_op = core.GEOperatorFactory.create_operator(op_name, "Add")
        add_op = add_op.set_input("x1", lhs).set_input("x2", rhs)
        return [add_op], [[0]]
|
|
|
|
|
|
class DotSubParser(AscendParserBase):
    """Translates Paddle's ``elementwise_sub`` op into a GE ``Sub`` operator."""

    def __init__(self, graph, var2geop):
        super(DotSubParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_sub"

    def _apply(self):
        """Build ``Sub`` from the op's first two input arguments.

        Returns:
            ``([sub_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "sub" + self._accumulated_op_id()
        sub_op = core.GEOperatorFactory.create_operator(op_name, "Sub")
        sub_op = sub_op.set_input("x1", lhs).set_input("x2", rhs)
        return [sub_op], [[0]]
|
|
|
|
|
|
class DotMulParser(AscendParserBase):
    """Translates Paddle's ``elementwise_mul`` op into a GE ``Mul`` operator."""

    def __init__(self, graph, var2geop):
        super(DotMulParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_mul"

    def _apply(self):
        """Build ``Mul`` from the op's first two input arguments.

        Returns:
            ``([mul_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "dotmul" + self._accumulated_op_id()
        mul_op = core.GEOperatorFactory.create_operator(op_name, "Mul")
        mul_op = mul_op.set_input("x1", lhs).set_input("x2", rhs)
        return [mul_op], [[0]]
|
|
|
|
|
|
class DotDivParser(AscendParserBase):
    """Translates Paddle's ``elementwise_div`` op into a GE ``Div`` operator."""

    def __init__(self, graph, var2geop):
        super(DotDivParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_div"

    def _apply(self):
        """Build ``Div`` from the op's first two input arguments.

        Returns:
            ``([div_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        numerator = self._get_ge_input(self.op.input_arg_names[0])
        denominator = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "dotdiv" + self._accumulated_op_id()
        div_op = core.GEOperatorFactory.create_operator(op_name, "Div")
        div_op = div_op.set_input("x1", numerator).set_input("x2", denominator)
        return [div_op], [[0]]
|
|
|
|
|
|
class DotPowParser(AscendParserBase):
    """Translates Paddle's ``elementwise_pow`` op into a GE ``Pow`` operator."""

    def __init__(self, graph, var2geop):
        super(DotPowParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_pow"

    def _apply(self):
        """Build ``Pow`` from the op's first two input arguments.

        Returns:
            ``([pow_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        x = self._get_ge_input(self.op.input_arg_names[0])
        y = self._get_ge_input(self.op.input_arg_names[1])
        # "x1" is fed from ``x``; the original passed the undefined name
        # ``x1`` and raised NameError on every call.
        pow_op = core.GEOperatorFactory.create_operator(
            "dotpow" + self._accumulated_op_id(),
            "Pow").set_input("x1", x).set_input("x2", y)
        return [pow_op], [[0]]
|
|
|
|
|
|
class LessParser(AscendParserBase):
    """Translates Paddle's ``less_than`` op into a GE ``Less`` operator."""

    def __init__(self, graph, var2geop):
        super(LessParser, self).__init__(graph, var2geop)
        self.parser_name = "less_than"

    def _apply(self):
        """Build ``Less`` from the op's first two input arguments.

        Returns:
            ``([less_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "less_than" + self._accumulated_op_id()
        less_op = core.GEOperatorFactory.create_operator(op_name, "Less")
        less_op = less_op.set_input("x1", lhs).set_input("x2", rhs)
        return [less_op], [[0]]
|
|
|
|
|
|
class MaxParser(AscendParserBase):
    """Translates Paddle's ``elementwise_max`` op into a GE ``Maximum`` operator."""

    def __init__(self, graph, var2geop):
        super(MaxParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_max"

    def _apply(self):
        """Build ``Maximum`` from the op's first two input arguments.

        Returns:
            ``([maximum_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "max" + self._accumulated_op_id()
        maximum_op = core.GEOperatorFactory.create_operator(op_name, "Maximum")
        maximum_op = maximum_op.set_input("x1", lhs).set_input("x2", rhs)
        return [maximum_op], [[0]]
|
|
|
|
|
|
class MinParser(AscendParserBase):
    """Translates Paddle's ``elementwise_min`` op into a GE ``Minimum`` operator."""

    def __init__(self, graph, var2geop):
        super(MinParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_min"

    def _apply(self):
        """Build ``Minimum`` from the op's first two input arguments.

        Returns:
            ``([minimum_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        op_name = "min" + self._accumulated_op_id()
        minimum_op = core.GEOperatorFactory.create_operator(op_name, "Minimum")
        minimum_op = minimum_op.set_input("x1", lhs).set_input("x2", rhs)
        return [minimum_op], [[0]]
|
|
|
|
|
|
## cal
|
|
class LogParser(AscendParserBase):
    """Translates Paddle's ``log`` op into a GE ``Log`` operator."""

    def __init__(self, graph, var2geop):
        super(LogParser, self).__init__(graph, var2geop)
        self.parser_name = "log"

    def _apply(self):
        """Build ``Log`` from the op's first input argument.

        Returns:
            ``([log_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "log" + self._accumulated_op_id()
        log_op = core.GEOperatorFactory.create_operator(op_name, "Log")
        log_op = log_op.set_input("x", operand)
        return [log_op], [[0]]
|
|
|
|
|
|
class SqrtParser(AscendParserBase):
    """Translates Paddle's ``sqrt`` op into a GE ``Sqrt`` operator."""

    def __init__(self, graph, var2geop):
        super(SqrtParser, self).__init__(graph, var2geop)
        self.parser_name = "sqrt"

    def _apply(self):
        """Build ``Sqrt`` from the op's first input argument.

        Returns:
            ``([sqrt_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "sqrt" + self._accumulated_op_id()
        sqrt_op = core.GEOperatorFactory.create_operator(op_name, "Sqrt")
        sqrt_op = sqrt_op.set_input("x", operand)
        return [sqrt_op], [[0]]
|
|
|
|
|
|
class PowParser(AscendParserBase):
    """Translates Paddle's ``pow`` op into a GE ``Power`` operator."""

    def __init__(self, graph, var2geop):
        super(PowParser, self).__init__(graph, var2geop)
        self.parser_name = "pow"

    def _apply(self):
        """Build ``Power`` with the op's ``factor`` attribute as the exponent.

        ``scale`` and ``shift`` are fixed to 1.0 and 0.0.

        Returns:
            ``([power_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        base = self._get_ge_input(self.op.input_arg_names[0])
        exponent = self.op.attr("factor")
        op_name = "pow" + self._accumulated_op_id()
        power_op = core.GEOperatorFactory.create_operator(op_name, "Power")
        power_op = power_op.set_input("x", base)
        power_op = power_op.set_attr_float("power", exponent)
        power_op = power_op.set_attr_float("scale", 1.0)
        power_op = power_op.set_attr_float("shift", 0.0)
        return [power_op], [[0]]
|
|
|
|
|
|
class SquareParser(AscendParserBase):
    """Translates Paddle's ``square`` op into a GE ``Square`` operator."""

    def __init__(self, graph, var2geop):
        super(SquareParser, self).__init__(graph, var2geop)
        self.parser_name = "square"

    def _apply(self):
        """Build ``Square`` from the op's first input argument.

        Returns:
            ``([square_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "square" + self._accumulated_op_id()
        square_op = core.GEOperatorFactory.create_operator(op_name, "Square")
        square_op = square_op.set_input("x", operand)
        return [square_op], [[0]]
|
|
|
|
|
|
class SumParser(AscendParserBase):
    """Translates Paddle's ``sum`` op into a chain of GE ``Add`` operators."""

    def __init__(self, graph, var2geop):
        super(SumParser, self).__init__(graph, var2geop)
        self.parser_name = "sum"

    def _apply(self):
        """Add all input arguments together, left to right.

        The first two inputs feed one ``Add``; each further input is added to
        the running result with another ``Add``.

        Returns:
            ``([total], [[0]])``: only the last ``Add`` of the chain, plus the
            index list binding the op's single output to it.

        Raises:
            AssertionError: if the op has fewer than two input arguments.
        """
        input_count = len(self.op.input_arg_names)
        assert input_count >= 2, "the size of input list must large or equal 2"

        first = self._get_ge_input(self.op.input_arg_names[0])
        second = self._get_ge_input(self.op.input_arg_names[1])
        total = core.GEOperatorFactory.create_operator(
            "sum" + self._accumulated_op_id(), "Add")
        total = total.set_input("x1", first).set_input("x2", second)

        position = 2
        while position < input_count:
            addend = self._get_ge_input(self.op.input_arg_names[position])
            next_add = core.GEOperatorFactory.create_operator(
                "sum" + self._accumulated_op_id(), "Add")
            total = next_add.set_input("x1", total).set_input("x2", addend)
            position += 1
        return [total], [[0]]
|
|
|
|
|
|
class LogicalNotParser(AscendParserBase):
    """Translates Paddle's ``logical_not`` op into a GE ``LogicalNot`` operator."""

    def __init__(self, graph, var2geop):
        super(LogicalNotParser, self).__init__(graph, var2geop)
        self.parser_name = "logical_not"

    def _apply(self):
        """Build ``LogicalNot`` from the op's first input argument.

        Returns:
            ``([not_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "logical_not" + self._accumulated_op_id()
        not_op = core.GEOperatorFactory.create_operator(op_name, "LogicalNot")
        not_op = not_op.set_input("x", operand)
        return [not_op], [[0]]
|
|
|
|
|
|
class MeanParser(AscendParserBase):
    """Translates Paddle's ``mean`` op into a GE ``ReduceMeanD`` operator."""

    def __init__(self, graph, var2geop):
        super(MeanParser, self).__init__(graph, var2geop)
        self.parser_name = "mean"

    def _apply(self):
        """Build ``ReduceMeanD`` with ``keep_dims`` False and an empty ``axes`` list.

        Returns:
            ``([mean_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "mean" + self._accumulated_op_id()
        mean_op = core.GEOperatorFactory.create_operator(op_name, "ReduceMeanD")
        mean_op = mean_op.set_input("x", operand)
        mean_op = mean_op.set_attr_bool("keep_dims", False)
        mean_op = mean_op.set_attr_vec_int32("axes", [])
        return [mean_op], [[0]]
|
|
|
|
|
|
class ReduceSumParser(AscendParserBase):
    """Translates Paddle's ``reduce_sum`` op into a GE ``ReduceSumD`` operator."""

    def __init__(self, graph, var2geop):
        super(ReduceSumParser, self).__init__(graph, var2geop)
        self.parser_name = "reduce_sum"

    def _apply(self):
        """Build ``ReduceSumD`` from the op's ``dim`` / ``keep_dim`` / ``reduce_all``.

        When ``reduce_all`` is set, the axes are every dimension of the input
        variable's declared shape instead of ``dim``.

        Returns:
            ``([reduce_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        input_name = self.op.input_arg_names[0]
        operand = self._get_ge_input(input_name)
        dim_attr = self.op.attr("dim")
        keep_dims = self.op.attr("keep_dim")
        reduce_all = self.op.attr("reduce_all")
        input_shape = self.op.block.var(self.op.input_arg_names[0]).shape

        axes = list(range(len(input_shape))) if reduce_all else dim_attr

        op_name = "reduce_sum" + self._accumulated_op_id()
        reduce_op = core.GEOperatorFactory.create_operator(op_name, "ReduceSumD")
        reduce_op = reduce_op.set_input("x", operand, 0)
        reduce_op = reduce_op.set_attr_vec_int32("axes", axes)
        reduce_op = reduce_op.set_attr_bool("keep_dims", keep_dims)
        return [reduce_op], [[0]]
|
|
|
|
|
|
#class IncrementParser(AscendParserBase):
|
|
# def __init__(self, graph, var2geop):
|
|
# super(IncrementParser, self).__init__(graph, var2geop)
|
|
# self.parser_name = "increment"
|
|
#
|
|
# def _apply(self):
|
|
# x = self._get_ge_input(self.op.input_arg_names[0])
|
|
# step = self.op.attr("step") #self._get_ge_input(self.op.input_arg_names[1])
|
|
# print("step: ", step)
|
|
#
|
|
# increment = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", step) #set_input("x2", bias)
|
|
#
|
|
# return [increment]
|
|
|
|
|
|
## matrix cal
|
|
class MatMulParser(AscendParserBase):
    """Translates Paddle's ``matmul`` op into GE ``BatchMatMul`` or ``MatMul``."""

    def __init__(self, graph, var2geop):
        super(MatMulParser, self).__init__(graph, var2geop)
        self.parser_name = "matmul"

    def _apply(self):
        """Pick the GE operator from the rank of the first input's declared shape.

        Rank > 2 uses ``BatchMatMul`` (``adj_x1`` / ``adj_x2``); rank 2 uses
        ``MatMul`` (``transpose_x1`` / ``transpose_x2``). The flags come from
        the op's ``transpose_X`` / ``transpose_Y`` attributes.

        Returns:
            ``([matmul], [[0]])``: the created operator and the index list
            binding the op's single output to it.

        Raises:
            AssertionError: if the first input's rank is below 2.
        """
        lhs = self._get_ge_input(self.op.input_arg_names[0])
        rhs = self._get_ge_input(self.op.input_arg_names[1])
        flip_lhs = self.op.attr("transpose_X")
        flip_rhs = self.op.attr("transpose_Y")

        lhs_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        # Looked up as in the original code; the value itself is not used.
        rhs_shape = self.op.block.var(self.op.input_arg_names[1]).shape

        lhs_rank = len(lhs_shape)
        if lhs_rank > 2:
            batch_op = core.GEOperatorFactory.create_operator(
                "matmul" + self._accumulated_op_id(), "BatchMatMul")
            batch_op = batch_op.set_input("x1", lhs).set_input("x2", rhs)
            batch_op = batch_op.set_attr_bool("adj_x1", flip_lhs)
            matmul = batch_op.set_attr_bool("adj_x2", flip_rhs)
        elif lhs_rank == 2:
            plain_op = core.GEOperatorFactory.create_operator(
                "matmul" + self._accumulated_op_id(), "MatMul")
            plain_op = plain_op.set_input("x1", lhs).set_input("x2", rhs)
            plain_op = plain_op.set_attr_bool("transpose_x1", flip_lhs)
            matmul = plain_op.set_attr_bool("transpose_x2", flip_rhs)
        else:
            assert False, "not support"
        return [matmul], [[0]]
|
|
|
|
|
|
class MulParser(AscendParserBase):
    """Translates Paddle's ``mul`` op into GE ``MatMul`` plus reshaping ops."""

    def __init__(self, graph, var2geop):
        super(MulParser, self).__init__(graph, var2geop)
        self.parser_name = "mul"

    def _apply(self):
        """Build the GE ops for ``mul`` based on input ranks and col-dims attrs.

        Supported combinations (ranks taken from the declared variable shapes):
          * both ``*_num_col_dims`` == 1, ranks (2, 2): a single ``MatMul``;
          * both ``*_num_col_dims`` == 1, ranks (3, 2): ``Flatten`` then ``MatMul``;
          * otherwise, ranks (3, 2) with ``x_num_col_dims`` == 2: ``FlattenV2``,
            ``MatMul``, ``TransposeD``, ``Reshape`` and a final ``TransposeD``.

        Returns:
            ``([matmul], [[0]])``: only the final operator, plus the index list
            binding the op's single output to it.

        Raises:
            AssertionError: for any other combination.
        """
        x = self._get_ge_input(self.op.input_arg_names[0])
        y = self._get_ge_input(self.op.input_arg_names[1])
        x_num_col_dims = self.op.attr("x_num_col_dims")
        y_num_col_dims = self.op.attr("y_num_col_dims")
        shape_x1 = self.op.block.var(self.op.input_arg_names[0]).shape
        shape_x2 = self.op.block.var(self.op.input_arg_names[1]).shape

        factory = core.GEOperatorFactory
        ranks = (len(shape_x1), len(shape_x2))
        unit_col_dims = x_num_col_dims == 1 and y_num_col_dims == 1

        if unit_col_dims and ranks == (2, 2):
            matmul = factory.create_operator(
                "mul" + self._accumulated_op_id(), "MatMul")
            matmul = matmul.set_input("x1", x).set_input("x2", y)
        elif unit_col_dims and ranks == (3, 2):
            flat_x = factory.create_operator(
                "flatten" + self._accumulated_op_id(), "Flatten")
            flat_x = flat_x.set_input("x", x)
            matmul = factory.create_operator(
                "mul" + self._accumulated_op_id(), "MatMul")
            matmul = matmul.set_input("x1", flat_x, 0).set_input("x2", y, 0)
        elif (not unit_col_dims) and ranks == (3, 2):
            assert x_num_col_dims == 2, "only support 2"
            flat_x = factory.create_operator(
                "flatten" + self._accumulated_op_id(), "FlattenV2")
            flat_x = flat_x.set_input("x", x).set_attr_int32("axis", 0)
            flat_x = flat_x.set_attr_int32("end_axis", 1)

            product = factory.create_operator(
                "mul" + self._accumulated_op_id(), "MatMul")
            product = product.set_input("x1", flat_x, 0).set_input("x2", y, 0)

            product_t = factory.create_operator(
                "transpose" + self._accumulated_op_id(), "TransposeD")
            product_t = product_t.set_input("x", product)
            product_t = product_t.set_attr_vec_int32("perm", [1, 0])

            # Target shape for the Reshape below, stored as a constant.
            target_dims = [shape_x2[1], shape_x1[0], shape_x1[1]]
            shape_data = self._create_ge_tensor([3], 2, target_dims)
            shape_const = factory.create_operator(
                "shape" + self._accumulated_op_id(), "Const")
            shape_const = shape_const.set_attr_tensor("value", shape_data)

            reshaped = factory.create_operator(
                "reshape" + self._accumulated_op_id(), "Reshape")
            reshaped = reshaped.set_input("x", product_t)
            reshaped = reshaped.set_input("shape", shape_const)
            reshaped = reshaped.set_attr_int32("axis", 0)

            matmul = factory.create_operator(
                "transpose" + self._accumulated_op_id(), "TransposeD")
            matmul = matmul.set_input("x", reshaped)
            matmul = matmul.set_attr_vec_int32("perm", [1, 2, 0])
        else:
            assert False, "not support"

        return [matmul], [[0]]
|
|
|
|
|
|
class LayerNormParser(AscendParserBase):
    """Translates Paddle's ``layer_norm`` op into GE ``LayerNorm`` plus casts."""

    def __init__(self, graph, var2geop):
        super(LayerNormParser, self).__init__(graph, var2geop)
        self.parser_name = "layer_norm"

    def _apply(self):
        """Broadcast scale/bias to the input shape, run ``LayerNorm``, cast outputs.

        Input arguments are read as: index 2 -> x, index 1 -> scale,
        index 0 -> bias. The three ``LayerNorm`` outputs (0, 1, 2) are each
        passed through a ``Cast`` whose ``dst_type`` is 0 when the input's
        dtype maps to 0 in ``dtype2paddle_inv_map`` and 1 otherwise.

        Returns:
            ``([y, mean, variance], [[1], [2], [0]])``: the three cast
            operators and the index lists binding the op's outputs to them.
        """
        arg_names = self.op.input_arg_names
        x = self._get_ge_input(arg_names[2])
        scale = self._get_ge_input(self.op.input_arg_names[1])
        bias = self._get_ge_input(self.op.input_arg_names[0])
        epsilon = self.op.attr("epsilon")
        begin_norm_axis = self.op.attr("begin_norm_axis")
        x_dtype = self.op.block.var(self.op.input_arg_names[2]).dtype

        factory = core.GEOperatorFactory

        x_shape_op = factory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape")
        x_shape_op = x_shape_op.set_input("x", x)

        gamma = factory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo")
        gamma = gamma.set_input("x", scale).set_input("shape", x_shape_op)

        beta = factory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo")
        beta = beta.set_input("x", bias).set_input("shape", x_shape_op)

        layer_norm = factory.create_operator(
            "layer_norm" + self._accumulated_op_id(), "LayerNorm")
        layer_norm = layer_norm.set_input("x", x)
        layer_norm = layer_norm.set_input("gamma", gamma)
        layer_norm = layer_norm.set_input("beta", beta)
        layer_norm = layer_norm.set_attr_int32("begin_norm_axis",
                                               begin_norm_axis)
        layer_norm = layer_norm.set_attr_int32("begin_params_axis",
                                               begin_norm_axis)
        layer_norm = layer_norm.set_attr_float("epsilon", epsilon)

        paddle_code = self.ascend_helper.dtype2paddle_inv_map[str(x_dtype)]
        if paddle_code == 0:
            cast_dtype = 0
        else:
            cast_dtype = 1

        # One Cast per LayerNorm output, created in output order 0, 1, 2.
        casts = []
        for out_index in (0, 1, 2):
            cast_op = factory.create_operator(
                "cast" + self._accumulated_op_id(), "Cast")
            cast_op = cast_op.set_input("x", layer_norm, out_index)
            casts.append(cast_op.set_attr_int32("dst_type", cast_dtype))
        y, mean, variance = casts
        return [y, mean, variance], [[1], [2], [0]]
|
|
|
|
|
|
## activate function
|
|
class ReluParser(AscendParserBase):
    """Translates Paddle's ``relu`` op into a GE ``Relu`` operator."""

    def __init__(self, graph, var2geop):
        super(ReluParser, self).__init__(graph, var2geop)
        self.parser_name = "relu"

    def _apply(self):
        """Build ``Relu`` from the op's first input argument.

        Returns:
            ``([relu_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "relu" + self._accumulated_op_id()
        relu_op = core.GEOperatorFactory.create_operator(op_name, "Relu")
        relu_op = relu_op.set_input("x", operand)
        return [relu_op], [[0]]
|
|
|
|
|
|
class GeluParser(AscendParserBase):
    """Translates Paddle's ``gelu`` op into a GE ``Gelu`` operator."""

    def __init__(self, graph, var2geop):
        super(GeluParser, self).__init__(graph, var2geop)
        self.parser_name = "gelu"

    def _apply(self):
        """Build ``Gelu`` from the op's first input argument.

        Returns:
            ``([gelu_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "gelu" + self._accumulated_op_id()
        gelu_op = core.GEOperatorFactory.create_operator(op_name, "Gelu")
        gelu_op = gelu_op.set_input("x", operand)
        return [gelu_op], [[0]]
|
|
|
|
|
|
class TanhParser(AscendParserBase):
    """Translates Paddle's ``tanh`` op into a GE ``Tanh`` operator."""

    def __init__(self, graph, var2geop):
        super(TanhParser, self).__init__(graph, var2geop)
        self.parser_name = "tanh"

    def _apply(self):
        """Build ``Tanh`` from the op's first input argument.

        Returns:
            ``([tanh_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "tanh" + self._accumulated_op_id()
        tanh_op = core.GEOperatorFactory.create_operator(op_name, "Tanh")
        tanh_op = tanh_op.set_input("x", operand)
        return [tanh_op], [[0]]
|
|
|
|
|
|
## loss function
|
|
class SoftmaxWithCrossEntropyParser(AscendParserBase):
    """Translates Paddle's ``softmax_with_cross_entropy`` op into GE operators."""

    def __init__(self, graph, var2geop):
        super(SoftmaxWithCrossEntropyParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_with_cross_entropy"

    def _apply(self):
        """Build softmax and a one-hot based cross-entropy loss.

        Input argument 0 is read as the label and argument 1 as the logits.
        The label is cast (``dst_type`` 3), one-hot encoded with depth equal to
        dimension 1 of the logits' declared shape, squeezed, and fed together
        with the logits into ``SoftmaxCrossEntropyWithLogits``. Output 0 of
        that op is cast (``dst_type`` 0) and unsqueezed on axis 1.

        Returns:
            ``([label, softmax, loss_expand], [[2], [1]])``: the operators and
            the index lists binding the op's two outputs to ``loss_expand``
            and ``softmax``.
        """
        factory = core.GEOperatorFactory

        label = self._get_ge_input(self.op.input_arg_names[0])
        logits = self._get_ge_input(self.op.input_arg_names[1])
        cls_num = self.op.block.var(self.op.input_arg_names[1]).shape[1]

        softmax = factory.create_operator(
            "softmax" + self._accumulated_op_id(), "SoftmaxV2")
        softmax = softmax.set_input("x", logits)

        label = factory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast").set_input(
                "x", label).set_attr_int32("dst_type", 3)

        # Constants used as the one-hot "on" and "off" values.
        on_data = self._create_ge_tensor([1], 5, 1)
        on_const = factory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        on_const = on_const.set_attr_tensor("value", on_data)
        off_data = self._create_ge_tensor([1], 5, 0)
        off_const = factory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        off_const = off_const.set_attr_tensor("value", off_data)
        self._mark_as_input(on_const)
        self._mark_as_input(off_const)

        one_hot = factory.create_operator(
            "onehot" + self._accumulated_op_id(), "OneHotD")
        one_hot = one_hot.set_input("x", label)
        one_hot = one_hot.set_input("on_value", on_const)
        one_hot = one_hot.set_input("off_value", off_const)
        one_hot = one_hot.set_attr_int32("depth", cls_num)

        # The "mul" name prefix is kept from the original code.
        squeezed_labels = factory.create_operator(
            "mul" + self._accumulated_op_id(), "Squeeze")
        squeezed_labels = squeezed_labels.set_input("x", one_hot)

        loss_all = factory.create_operator(
            "loss" + self._accumulated_op_id(),
            "SoftmaxCrossEntropyWithLogits")
        loss_all = loss_all.set_input("features", logits)
        loss_all = loss_all.set_input("labels", squeezed_labels)

        loss = factory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast")
        loss = loss.set_input("x", loss_all, 0).set_attr_int32("dst_type", 0)

        loss_expand = factory.create_operator(
            "unsqueeze" + self._accumulated_op_id(), "Unsqueeze")
        loss_expand = loss_expand.set_input("x", loss)
        loss_expand = loss_expand.set_attr_vec_int32("axes", [1])
        return [label, softmax, loss_expand], [[2], [1]]
|
|
|
|
|
|
class SoftMaxParser(AscendParserBase):
    """Translates Paddle's ``softmax`` op into a GE ``SoftmaxV2`` operator."""

    def __init__(self, graph, var2geop):
        super(SoftMaxParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax"

    def _apply(self):
        """Build ``SoftmaxV2`` over the op's ``axis`` attribute.

        Returns:
            ``([softmax_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        logits = self._get_ge_input(self.op.input_arg_names[0])
        axis = self.op.attr("axis")

        op_name = "softmax" + self._accumulated_op_id()
        softmax_op = core.GEOperatorFactory.create_operator(op_name,
                                                            "SoftmaxV2")
        softmax_op = softmax_op.set_input("x", logits)
        softmax_op = softmax_op.set_attr_vec_int32("axes", [axis])
        return [softmax_op], [[0]]
|
|
|
|
|
|
## general
|
|
class ShapeParser(AscendParserBase):
    """Translates Paddle's ``shape`` op into a GE ``Shape`` operator."""

    def __init__(self, graph, var2geop):
        super(ShapeParser, self).__init__(graph, var2geop)
        self.parser_name = "shape"

    def _apply(self):
        """Build ``Shape`` from the op's first input argument.

        Returns:
            ``([shape_op], [[0]])``: the created operator and the index list
            binding the op's single output to it.
        """
        operand = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "shape" + self._accumulated_op_id()
        shape_op = core.GEOperatorFactory.create_operator(op_name, "Shape")
        shape_op = shape_op.set_input("x", operand)
        return [shape_op], [[0]]
|
|
|
|
|
|
class FillConstantParser(AscendParserBase):
    """Translates Paddle's ``fill_constant`` op into a GE ``Const`` operator."""

    def __init__(self, graph, var2geop):
        super(FillConstantParser, self).__init__(graph, var2geop)
        self.parser_name = "fill_constant"

    def _apply(self):
        """Build a ``Const`` from the op's ``shape`` / ``dtype`` / ``value`` attrs.

        The constant is also registered as a graph input. When the op's
        ``Out`` variable is persistable, a ``Variable`` named after it and an
        ``Assign`` of the constant to that variable are created as well.

        Returns:
            ``([const_op], [[0]])`` in both cases.
        """
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")
        value = self.op.attr("value")

        fill_data = self._create_ge_tensor(shape, dtype, value)
        const_op = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        const_op = const_op.set_attr_tensor("value", fill_data)
        self._mark_as_input(const_op)

        out_name = self.op.output('Out')[0]
        if self.op.block.var(out_name).persistable:
            variable = core.GEOperatorFactory.create_operator(
                self.op.output('Out')[0], "Variable")
            # NOTE(review): the descriptor is always DT_FLOAT, regardless of
            # the op's dtype attribute - confirm this is intended.
            out_desc = core.GETensorDesc(
                core.GEShape(shape), core.GEFormat.FORMAT_ND,
                core.GEDataType.DT_FLOAT)
            variable.update_output_desc("y", out_desc)
            # NOTE(review): the Assign op is created but not returned, so it is
            # not among the operators handed back to the caller - confirm.
            assign = core.GEOperatorFactory.create_operator(
                "assign" + self._accumulated_op_id(), "Assign").set_input(
                    "value", const_op).set_input("ref", variable)

        return [const_op], [[0]]
|
|
|
|
|
|
class TruncatedNormalParser(AscendParserBase):
    """Translates ``truncated_gaussian_random`` into GE ``ParameterizedTruncatedNormal``."""

    def __init__(self, graph, var2geop):
        super(TruncatedNormalParser, self).__init__(graph, var2geop)
        self.parser_name = "truncated_gaussian_random"

    def _apply(self):
        """Build the constants and the truncated-normal operator.

        Five ``Const`` inputs are created (shape, mean, std, ``mean - 2 * std``
        and ``mean + 2 * std``), each registered as a graph input. When the
        op's ``Out`` variable is persistable, a ``Variable`` named after it and
        an ``Assign`` of the result to that variable are created as well.

        Returns:
            Persistable output: all five constants followed by the
            truncated-normal operator, with index list ``[[-1]]``.
            Otherwise: ``([truncated_normal], [[0]])``.
        """
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")
        mean = self.op.attr("mean")
        std = self.op.attr("std")
        # NOTE(review): the seed attribute is read but the operator below is
        # always given seed 0 - confirm this is intended.
        seed = self.op.attr("seed")

        def make_const(tensor_shape, tensor_dtype, tensor_value):
            # Fill a tensor, then wrap it in a Const op (same order as before).
            data = self._create_ge_tensor(tensor_shape, tensor_dtype,
                                          tensor_value)
            const_op = core.GEOperatorFactory.create_operator(
                "const" + self._accumulated_op_id(), "Const")
            return const_op.set_attr_tensor("value", data)

        shape_tensor = make_const([len(shape)], 2, shape)
        mean_tensor = make_const([1], dtype, mean)
        std_tensor = make_const([1], dtype, std)
        min_tensor = make_const([1], dtype, mean - 2 * std)
        max_tensor = make_const([1], dtype, mean + 2 * std)

        const_inputs = [
            shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor
        ]
        for const_input in const_inputs:
            self._mark_as_input(const_input)

        truncated_normal = core.GEOperatorFactory.create_operator(
            "truncated_normal" + self._accumulated_op_id(),
            "ParameterizedTruncatedNormal")
        truncated_normal = truncated_normal.set_input("shape", shape_tensor)
        truncated_normal = truncated_normal.set_input("means", mean_tensor)
        truncated_normal = truncated_normal.set_input("stdevs", std_tensor)
        truncated_normal = truncated_normal.set_input("min", min_tensor)
        truncated_normal = truncated_normal.set_input("max", max_tensor)
        truncated_normal = truncated_normal.set_attr_int32("seed", 0)

        # Write the output of truncatedNormal from startup_program to
        # main_program.
        if not self.op.block.var(self.op.output('Out')[0]).persistable:
            return [truncated_normal], [[0]]

        variable = core.GEOperatorFactory.create_operator(
            self.op.output('Out')[0], "Variable")
        out_desc = core.GETensorDesc(
            core.GEShape(shape), core.GEFormat.FORMAT_ND,
            core.GEDataType.DT_FLOAT)
        variable.update_output_desc("y", out_desc)
        assign = core.GEOperatorFactory.create_operator(
            "assign" + self._accumulated_op_id(), "Assign").set_input(
                "value", truncated_normal).set_input("ref", variable)
        return const_inputs + [truncated_normal], [[-1]]
|
|
|
|
|
|
class GatherParser(AscendParserBase):
    """Parser for the ``gather`` op; emits a GE ``Gather`` operator."""

    def __init__(self, graph, var2geop):
        super(GatherParser, self).__init__(graph, var2geop)
        self.parser_name = "gather"

    def _apply(self):
        """Create a ``Gather`` op with ``validate_indices`` enabled.

        The first positional input is used as the indices and the second as
        the data to gather from.

        Returns:
            ``([gather], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        index = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])

        # The previous implementation also read the last dimension of the
        # data variable's shape into an unused local; that lookup could fail
        # (e.g. on an empty shape) without contributing to the result.
        gather = core.GEOperatorFactory.create_operator(
            "gather" + self._accumulated_op_id(), "Gather").set_input(
                "x", x).set_input("indices", index).set_attr_bool(
                    "validate_indices", True)
        return [gather], [[0]]
|
|
|
|
|
|
class ScatterParser(AscendParserBase):
    """Parser for the ``scatter`` op.

    Emits a GE ``TensorScatterUpdate`` (``overwrite`` set) or
    ``TensorScatterAdd`` (``overwrite`` unset) operator.
    """

    def __init__(self, graph, var2geop):
        super(ScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "scatter"

    def _apply(self):
        """Build the scatter operator for the current op.

        A rank-1 index tensor is first expanded with ``Unsqueeze`` on axis 1
        before being fed to the scatter operator.

        Returns:
            ``([x, index, updates, scatter_value], [[-1]])``.
        """
        arg_names = self.op.input_arg_names
        # NOTE(review): inputs are taken positionally (0 -> index, 1 -> x,
        # 2 -> updates) exactly as before; confirm this matches the order in
        # which ``input_arg_names`` lists the scatter op's inputs.
        index = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])
        updates = self._get_ge_input(arg_names[2])
        overwrite = self.op.attr("overwrite")
        index_shape = self.op.block.var(arg_names[0]).shape

        if len(index_shape) == 1:
            # Fixed: use the same id generator as every other operator here
            # (the old code called ``self.getid()``).
            index = core.GEOperatorFactory.create_operator(
                "unsqueeze" + self._accumulated_op_id(),
                "Unsqueeze").set_input("x", index).set_attr_vec_int32(
                    "axes", [1])

        ge_type = "TensorScatterUpdate" if overwrite else "TensorScatterAdd"
        # Fixed: the old code referenced undefined names (``x_var``,
        # ``index_var``, ``updatesi_var``, ``updates_var``) and raised
        # NameError on every call.
        scatter_value = core.GEOperatorFactory.create_operator(
            "scatter" + self._accumulated_op_id(), ge_type).set_input(
                "x", x).set_input("indices", index).set_input(
                    "updates", updates)
        return [x, index, updates, scatter_value], [[-1]]
|
|
|
|
|
|
class CastParser(AscendParserBase):
    """Parser for the ``cast`` op; emits a GE ``Cast`` operator."""

    def __init__(self, graph, var2geop):
        super(CastParser, self).__init__(graph, var2geop)
        self.parser_name = "cast"

    def _apply(self):
        """Create a ``Cast`` op whose ``dst_type`` is the ``out_dtype`` attr.

        Returns:
            ``([cast], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        target_dtype = self.op.attr("out_dtype")

        op_name = "cast" + self._accumulated_op_id()
        cast = core.GEOperatorFactory.create_operator(op_name, "Cast")
        cast = cast.set_input("x", source)
        cast = cast.set_attr_int32("dst_type", target_dtype)
        return [cast], [[0]]
|
|
|
|
|
|
class AssignParser(AscendParserBase):
    """Parser for the ``assign`` op; emits a GE ``Assign`` operator."""

    def __init__(self, graph, var2geop):
        super(AssignParser, self).__init__(graph, var2geop)
        self.parser_name = "assign"

    def _apply(self):
        """Assign the first positional input (value) to the second (ref).

        Returns:
            ``([assign], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        value_op = self._get_ge_input(arg_names[0])
        ref_op = self._get_ge_input(arg_names[1])

        op_name = "assign" + self._accumulated_op_id()
        assign = core.GEOperatorFactory.create_operator(op_name, "Assign")
        assign = assign.set_input("value", value_op).set_input("ref", ref_op)
        return [assign], [[0]]
|
|
|
|
|
|
class ScaleParser(AscendParserBase):
    """Parser for the ``scale`` op, lowered onto the GE ``Power`` operator."""

    def __init__(self, graph, var2geop):
        super(ScaleParser, self).__init__(graph, var2geop)
        self.parser_name = "scale"

    def _apply(self):
        """Emit ``Power`` with ``power=1.0`` to apply scale and bias.

        With ``bias_after_scale`` set, the bias is passed as the ``shift``
        attribute.  Otherwise the bias is added first through an ``Adds``
        operator and ``shift`` is 0.0.

        Returns:
            ``([scale_value], [[0]])``.
        """
        power_input = self._get_ge_input(self.op.input_arg_names[0])
        scale = self.op.attr("scale")
        bias = self.op.attr("bias")
        bias_after_scale = self.op.attr("bias_after_scale")

        new_op = core.GEOperatorFactory.create_operator
        shift = bias
        if not bias_after_scale:
            # Apply the bias before scaling, so nothing is left to shift.
            power_input = new_op(
                "adds" + self._accumulated_op_id(),
                "Adds").set_input("x", power_input).set_attr_float(
                    "value", bias)
            shift = 0.0

        scale_value = new_op("scale" + self._accumulated_op_id(), "Power")
        scale_value = scale_value.set_input("x", power_input)
        scale_value = scale_value.set_attr_float("power", 1.0)
        scale_value = scale_value.set_attr_float("scale", scale)
        scale_value = scale_value.set_attr_float("shift", shift)
        return [scale_value], [[0]]
|
|
|
|
|
|
class SliceParser(AscendParserBase):
    """Parser for the ``slice`` op; emits a GE ``SliceD`` operator."""

    def __init__(self, graph, var2geop):
        super(SliceParser, self).__init__(graph, var2geop)
        self.parser_name = "slice"

    def _apply(self):
        """Expand the per-axis ``starts``/``ends`` into full-rank offsets/sizes.

        Axes not listed in ``axes`` are kept whole.  Compared with the
        previous implementation this also handles:

        * negative axes (previously ignored, so no slicing happened);
        * axes given in non-ascending order (previously paired with the
          wrong ``starts``/``ends`` entries);
        * negative ``starts``/``ends`` on dimensions of known size, which
          are counted from the end of the dimension and clamped to it.

        Dimensions whose static size is negative (unknown) are treated
        exactly as before.

        Returns:
            ``([slice_value], [[0]])``.
        """
        x = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr("axes")
        starts = self.op.attr("starts")
        ends = self.op.attr("ends")

        x_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        rank = len(x_shape)

        # Map each (normalised) axis to its requested [start, end) range.
        bounds = {}
        for axis, start, end in zip(axes, starts, ends):
            bounds[axis + rank if axis < 0 else axis] = (start, end)

        offsets, size = [], []
        for axis, dim in enumerate(x_shape):
            start, end = bounds.get(axis, (0, dim))
            if dim >= 0:
                # Known extent: resolve negative indices and clamp into
                # [0, dim].
                if start < 0:
                    start = max(start + dim, 0)
                if end < 0:
                    end = max(end + dim, 0)
                start = min(start, dim)
                end = min(end, dim)
                extent = max(end - start, 0)
            else:
                # Unknown extent: keep the historical behaviour, where an
                # end beyond the (negative) static size falls back to it.
                if end > dim:
                    end = dim
                extent = end - start
            offsets.append(start)
            size.append(extent)

        slice_value = core.GEOperatorFactory.create_operator(
            "slice" + self._accumulated_op_id(), "SliceD").set_input(
                "x", x).set_attr_vec_int32(
                    "offsets", offsets).set_attr_vec_int32("size", size)

        return [slice_value], [[0]]
|
|
|
|
|
|
class ReshapeParser(AscendParserBase):
    """Parser for the ``reshape2`` op; emits GE ``Reshape`` and ``Shape`` ops."""

    def __init__(self, graph, var2geop):
        super(ReshapeParser, self).__init__(graph, var2geop)
        self.parser_name = "reshape2"

    def _apply(self):
        """Resolve the target shape statically and build the reshape.

        A ``0`` entry in the ``shape`` attribute copies the input dimension
        at the same position; a single ``-1`` entry is inferred from the
        element count of the input.  The input shape must not contain ``-1``.

        Returns:
            ``([x_shape, reshape], [[1], [0]])`` where ``x_shape`` is a
            ``Shape`` op on the input.
        """
        input_name = self.op.input_arg_names[0]
        org_shape = self.op.block.var(input_name).shape
        assert org_shape.count(-1) == 0, "do not allow the dim is -1"

        shape = self.op.attr("shape")
        for pos, dim in enumerate(shape):
            if dim == 0:
                shape[pos] = org_shape[pos]

        if -1 in shape:
            assert shape.count(-1) == 1, "only allow one dim is -1"
            total_elements = reduce(lambda a, b: a * b, org_shape)
            # The product still includes the -1 placeholder, so flip its sign
            # to get the product of the known dimensions.
            known_elements = reduce(lambda a, b: a * b, shape) * -1
            shape[shape.index(-1)] = total_elements // known_elements

        new_op = core.GEOperatorFactory.create_operator
        x = self._get_ge_input(self.op.input_arg_names[0])
        shape_tensor = self._create_ge_tensor([len(shape)], 2, shape)
        const_shape = new_op("shape" + self._accumulated_op_id(),
                             "Const").set_attr_tensor("value", shape_tensor)
        reshape = new_op("reshape" + self._accumulated_op_id(), "Reshape")
        reshape = reshape.set_input("x", x).set_input("shape", const_shape)
        reshape = reshape.set_attr_int32("axis", 0)
        x_shape = new_op("shape" + self._accumulated_op_id(),
                         "Shape").set_input("x", x)

        return [x_shape, reshape], [[1], [0]]
|
|
|
|
|
|
class TransposeParser(AscendParserBase):
    """Parser for the ``transpose2`` op; emits GE ``TransposeD`` and ``Shape``."""

    def __init__(self, graph, var2geop):
        super(TransposeParser, self).__init__(graph, var2geop)
        self.parser_name = "transpose2"

    def _apply(self):
        """Permute the input according to the ``axis`` attribute.

        Returns:
            ``([x_shape, transpose], [[1], [0]])`` where ``x_shape`` is a
            ``Shape`` op on the (untransposed) input.
        """
        new_op = core.GEOperatorFactory.create_operator
        source = self._get_ge_input(self.op.input_arg_names[0])
        perm = self.op.attr("axis")

        transpose = new_op("transpose" + self._accumulated_op_id(),
                           "TransposeD")
        transpose = transpose.set_input("x", source).set_attr_vec_int32(
            "perm", perm)
        x_shape = new_op("shape" + self._accumulated_op_id(),
                         "Shape").set_input("x", source)

        return [x_shape, transpose], [[1], [0]]
|
|
|
|
|
|
class AccuracyParser(AscendParserBase):
    """Parser for the ``accuracy`` op, composed from basic GE operators."""

    def __init__(self, graph, var2geop):
        super(AccuracyParser, self).__init__(graph, var2geop)
        self.parser_name = "accuracy"

    def _apply(self):
        """Compute accuracy, number of correct items and total item count.

        The first two positional inputs are cast to ``dst_type`` 3 and
        compared element-wise; the comparison result is cast to ``dst_type``
        0 and reduced with a mean (accuracy) and a sum (correct).  The total
        is the sum of a ones-like tensor shaped after the cast label.

        Returns:
            ``([acc, correct, total], [[0], [1], [2]])``.
        """
        arg_names = self.op.input_arg_names
        pred = self._get_ge_input(arg_names[0])
        label = self._get_ge_input(arg_names[1])
        # The third input is looked up (as before) but not used below.
        logits = self._get_ge_input(arg_names[2])

        new_op = core.GEOperatorFactory.create_operator

        def cast_to(source, dst_type):
            return new_op("cast" + self._accumulated_op_id(),
                          "Cast").set_input("x", source).set_attr_int32(
                              "dst_type", dst_type)

        def reduce_all(prefix, ge_type, source):
            # Reduction with an empty ``axes`` list and ``keep_dims`` off.
            return new_op(prefix + self._accumulated_op_id(),
                          ge_type).set_input("x", source).set_attr_bool(
                              "keep_dims", False).set_attr_vec_int32(
                                  "axes", [])

        pred = cast_to(pred, 3)
        label = cast_to(label, 3)
        equal = new_op("equal" + self._accumulated_op_id(),
                       "Equal").set_input("x1", pred).set_input("x2", label)
        matches = cast_to(equal, 0)
        acc = reduce_all("mean", "ReduceMeanD", matches)
        correct = reduce_all("sum", "ReduceSumD", matches)

        ones_tensor = new_op("oneslike" + self._accumulated_op_id(),
                             "OnesLike").set_input("x", label)
        ones_tensor = cast_to(ones_tensor, 0)
        total = reduce_all("sum", "ReduceSumD", ones_tensor)

        return [acc, correct, total], [[0], [1], [2]]
|
|
|
|
|
|
class TopkParser(AscendParserBase):
    """Parser for the ``top_k`` op; emits a GE ``TopK`` operator."""

    def __init__(self, graph, var2geop):
        super(TopkParser, self).__init__(graph, var2geop)
        self.parser_name = "top_k"

    def _apply(self):
        """Run ``TopK`` on the input cast to ``dst_type`` 1.

        ``k`` is supplied through a ``Const`` op.  Both ``TopK`` outputs
        (indices 0 and 1) are cast to ``dst_type`` 0.

        Returns:
            ``([value, index], [[1], [0]])``.
        """
        new_op = core.GEOperatorFactory.create_operator
        source = self._get_ge_input(self.op.input_arg_names[0])
        k = self.op.attr("k")

        k_tensor = self._create_ge_tensor([1], 2, k)
        const_k = new_op("const" + self._accumulated_op_id(),
                         "Const").set_attr_tensor("value", k_tensor)
        cast_x = new_op("cast" + self._accumulated_op_id(),
                        "Cast").set_input("x", source).set_attr_int32(
                            "dst_type", 1)
        topk = new_op("topk" + self._accumulated_op_id(),
                      "TopK").set_input("x", cast_x).set_input("k", const_k)

        value = new_op("cast" + self._accumulated_op_id(),
                       "Cast").set_input("x", topk, 0).set_attr_int32(
                           "dst_type", 0)
        index = new_op("cast" + self._accumulated_op_id(),
                       "Cast").set_input("x", topk, 1).set_attr_int32(
                           "dst_type", 0)
        return [value, index], [[1], [0]]
|
|
|
|
|
|
class LookupTableParser(AscendParserBase):
    """Parser for the ``lookup_table`` op, lowered to ``Squeeze`` + ``Gather``."""

    def __init__(self, graph, var2geop):
        super(LookupTableParser, self).__init__(graph, var2geop)
        self.parser_name = "lookup_table"

    def _apply(self):
        """Squeeze the last axis of the ids and gather rows from the table.

        Returns:
            ``([out], [[0]])``.
        """
        new_op = core.GEOperatorFactory.create_operator
        arg_names = self.op.input_arg_names
        ids = self._get_ge_input(arg_names[0])
        table = self._get_ge_input(arg_names[1])

        ids_squeeze = new_op("squeeze" + self._accumulated_op_id(),
                             "Squeeze")
        ids_squeeze = ids_squeeze.set_input("x", ids).set_attr_vec_int32(
            "axes", [-1])

        out = new_op("lookup" + self._accumulated_op_id(), "Gather")
        out = out.set_input("x", table).set_input("indices", ids_squeeze)
        return [out], [[0]]
|
|
|
|
|
|
class StackParser(AscendParserBase):
    """Parser for the ``stack`` op, built from ``ExpandDims`` + ``TileWithAxis``."""

    def __init__(self, graph, var2geop):
        super(StackParser, self).__init__(graph, var2geop)
        self.parser_name = "stack"

    def _apply(self):
        """Expand the first input along ``axis`` and tile it once per input.

        NOTE(review): every input is looked up, but only the first one is
        fed into the graph; the result repeats that input ``len(inputs)``
        times.  Confirm this is intended for inputs that differ.

        Returns:
            ``([stack], [[0]])``.
        """
        new_op = core.GEOperatorFactory.create_operator
        inputs = [
            self._get_ge_input(name) for name in self.op.input_arg_names
        ]
        tiles = len(inputs)
        axis = self.op.attr("axis")

        first_input = inputs[0]
        axis_value = self._create_ge_tensor([1], 2, axis)
        tensor_axis = new_op("axis" + self._accumulated_op_id(),
                             "Const").set_attr_tensor("value", axis_value)
        expand = new_op("expand" + self._accumulated_op_id(), "ExpandDims")
        expand = expand.set_input("x", first_input).set_input(
            "axis", tensor_axis)

        stack = new_op("stack" + self._accumulated_op_id(), "TileWithAxis")
        stack = stack.set_input("x", expand)
        stack = stack.set_attr_int32("axis", axis).set_attr_int32(
            "tiles", tiles)

        return [stack], [[0]]
|
|
|
|
|
|
class UnSqueezeParser(AscendParserBase):
    """Parser for the ``unsqueeze2`` op; emits GE ``Unsqueeze`` and ``Shape``."""

    def __init__(self, graph, var2geop):
        super(UnSqueezeParser, self).__init__(graph, var2geop)
        self.parser_name = "unsqueeze2"

    def _apply(self):
        """Insert the dimensions listed in the ``axes`` attribute.

        Returns:
            ``([shape, output], [[1], [0]])`` where ``shape`` is a ``Shape``
            op on the unsqueezed result.
        """
        new_op = core.GEOperatorFactory.create_operator
        source = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr('axes')

        output = new_op("unsqueeze" + self._accumulated_op_id(), "Unsqueeze")
        output = output.set_input("x", source).set_attr_vec_int32(
            "axes", axes)
        shape = new_op("shape" + self._accumulated_op_id(),
                       "Shape").set_input("x", output)
        return [shape, output], [[1], [0]]
|
|
|
|
|
|
## parallel
|
|
class AllGatherParser(AscendParserBase):
    """Parser for ``c_allgather``; emits a GE ``HcomAllGather`` operator."""

    def __init__(self, graph, var2geop):
        super(AllGatherParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allgather"

    def _apply(self):
        """Forward the ``rank_size`` and ``group`` attributes to the GE op.

        Returns:
            ``([allgather], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        rank_size = self.op.attr("rank_size")
        group = self.op.attr("group")

        op_name = "allgather" + self._accumulated_op_id()
        allgather = core.GEOperatorFactory.create_operator(
            op_name, "HcomAllGather")
        allgather = allgather.set_input("x", source)
        allgather = allgather.set_attr_int32("rank_size", rank_size)
        allgather = allgather.set_attr_string("group", group)
        return [allgather], [[0]]
|
|
|
|
|
|
class AllReduceParser(AscendParserBase):
    """Base parser for ``c_allreduce_<reduction>`` ops (GE ``HcomAllReduce``).

    Args:
        graph: forwarded to the base parser.
        var2geop: forwarded to the base parser.
        reduction: reduction name; appended to the parser name and passed
            as the GE ``reduction`` attribute.
    """

    def __init__(self, graph, var2geop, reduction):
        super(AllReduceParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allreduce_" + reduction
        self.reduction = reduction

    def _apply(self):
        """Create the all-reduce op in group ``hcom_group_<ring_id>``.

        Returns:
            ``([allreduce], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.reduction
        group = "hcom_group_" + str(self.op.attr("ring_id"))
        # Fusion settings are currently disabled (never read from the op),
        # so the two optional attributes below are not set.
        fusion = None  # self.op.attr("fusion")
        fusion_id = None  # self.op.attr("fusion_id")

        op_name = "allreduce" + self._accumulated_op_id()
        allreduce = core.GEOperatorFactory.create_operator(
            op_name, "HcomAllReduce")
        allreduce = allreduce.set_input("x", source)
        allreduce = allreduce.set_attr_string("reduction", reduction)
        allreduce = allreduce.set_attr_string("group", group)

        optional_attrs = (("fusion", fusion), ("fusion_id", fusion_id))
        for attr_name, attr_value in optional_attrs:
            if attr_value is not None:
                allreduce.set_attr_int32(attr_name, attr_value)
        return [allreduce], [[0]]
|
|
|
|
|
|
class AllReduceSumParser(AllReduceParser):
    """All-reduce parser fixed to the ``sum`` reduction."""

    def __init__(self, graph, var2geop):
        super(AllReduceSumParser, self).__init__(
            graph, var2geop, reduction='sum')
|
|
|
|
|
|
class AllReduceMaxParser(AllReduceParser):
    """All-reduce parser fixed to the ``max`` reduction."""

    def __init__(self, graph, var2geop):
        super(AllReduceMaxParser, self).__init__(
            graph, var2geop, reduction='max')
|
|
|
|
|
|
class BroadcastParser(AscendParserBase):
    """Parser for ``c_broadcast``; emits a GE ``HcomBroadcast`` operator."""

    def __init__(self, graph, var2geop):
        super(BroadcastParser, self).__init__(graph, var2geop)
        self.parser_name = "c_broadcast"

    def _apply(self):
        """Forward the ``root_rank`` and ``group`` attributes to the GE op.

        Returns:
            ``([broadcast], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        root_rank = self.op.attr("root_rank")
        group = self.op.attr("group")

        op_name = "broadcast" + self._accumulated_op_id()
        broadcast = core.GEOperatorFactory.create_operator(
            op_name, "HcomBroadcast")
        broadcast = broadcast.set_input("x", source)
        broadcast = broadcast.set_attr_int32("root_rank", root_rank)
        broadcast = broadcast.set_attr_string("group", group)
        return [broadcast], [[0]]
|
|
|
|
|
|
class ReduceScatterParser(AscendParserBase):
    """Parser for ``c_reduce_scatter``; emits GE ``HcomReduceScatter``."""

    def __init__(self, graph, var2geop):
        super(ReduceScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "c_reduce_scatter"

    def _apply(self):
        """Forward ``reduction``, ``group`` and ``rank_size`` to the GE op.

        Returns:
            ``([reduce_scatter], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.op.attr("reduction")
        group = self.op.attr("group")
        rank_size = self.op.attr("rank_size")

        op_name = "reducescatter" + self._accumulated_op_id()
        reduce_scatter = core.GEOperatorFactory.create_operator(
            op_name, "HcomReduceScatter")
        reduce_scatter = reduce_scatter.set_input("x", source)
        reduce_scatter = reduce_scatter.set_attr_string(
            "reduction", reduction)
        reduce_scatter = reduce_scatter.set_attr_string("group", group)
        reduce_scatter = reduce_scatter.set_attr_int32(
            "rank_size", rank_size)
        return [reduce_scatter], [[0]]
|
|
|
|
|
|
class SendParser(AscendParserBase):
    """Parser for ``c_send``; emits a GE ``HcomSend`` operator."""

    def __init__(self, graph, var2geop):
        super(SendParser, self).__init__(graph, var2geop)
        self.parser_name = "c_send"

    def _apply(self):
        """Forward ``sr_tag``, ``dest_rank`` and ``group`` to the GE op.

        Returns:
            ``([send], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        sr_tag = self.op.attr("sr_tag")
        dest_rank = self.op.attr("dest_rank")
        group = self.op.attr("group")

        op_name = "send" + self._accumulated_op_id()
        send = core.GEOperatorFactory.create_operator(op_name, "HcomSend")
        send = send.set_input("x", source)
        send = send.set_attr_int32("sr_tag", sr_tag)
        send = send.set_attr_int32("dest_rank", dest_rank)
        send = send.set_attr_string("group", group)
        return [send], [[0]]
|
|
|
|
|
|
class ReceiveParser(AscendParserBase):
    """Parser for ``c_receive``; emits a GE ``HcomReceive`` operator."""

    def __init__(self, graph, var2geop):
        super(ReceiveParser, self).__init__(graph, var2geop)
        self.parser_name = "c_receive"

    def _apply(self):
        """Forward tag, source rank, group, shape and dtype to the GE op.

        Returns:
            ``([receive], [[0]])``.
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        sr_tag = self.op.attr("sr_tag")
        src_rank = self.op.attr("src_rank")
        group = self.op.attr("group")
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")

        op_name = "receive" + self._accumulated_op_id()
        receive = core.GEOperatorFactory.create_operator(
            op_name, "HcomReceive")
        receive = receive.set_input("x", source)
        receive = receive.set_attr_int32("sr_tag", sr_tag)
        receive = receive.set_attr_int32("src_rank", src_rank)
        receive = receive.set_attr_string("group", group)
        receive = receive.set_attr_vec_int32("shape", shape)
        receive = receive.set_attr_int32("dtype", dtype)
        return [receive], [[0]]
|
|
|
|
|
|
class RangeParser(AscendParserBase):
    """Parser for the ``range`` op; emits a GE ``Range`` operator."""

    def __init__(self, graph, var2geop):
        super(RangeParser, self).__init__(graph, var2geop)
        self.parser_name = "range"

    def _apply(self):
        """Wire the three positional inputs into ``Range``.

        The FIRST positional input feeds ``limit`` and the SECOND feeds
        ``start``; the third feeds ``delta``.

        NOTE(review): presumably ``input_arg_names`` lists the inputs in an
        order where the end value comes before the start value - confirm
        before "fixing" the apparent swap.

        Returns:
            ``([ge_range], [[0]])``.
        """
        # TODO not support range type yet
        arg_names = self.op.input_arg_names
        first_input = self._get_ge_input(arg_names[0])
        second_input = self._get_ge_input(arg_names[1])
        delta = self._get_ge_input(arg_names[2])

        op_name = "range" + self._accumulated_op_id()
        ge_range = core.GEOperatorFactory.create_operator(op_name, "Range")
        ge_range = ge_range.set_input("start", second_input)
        ge_range = ge_range.set_input("limit", first_input)
        ge_range = ge_range.set_input("delta", delta)

        return [ge_range], [[0]]
|
|
|
|
|
|
class UniformRandomParser(AscendParserBase):
    """Parser for ``uniform_random``: GE ``RandomUniform`` rescaled by ``Power``."""

    def __init__(self, graph, var2geop):
        super(UniformRandomParser, self).__init__(graph, var2geop)
        self.parser_name = "uniform_random"

    def _apply(self):
        """Sample with ``RandomUniform`` and map the result into the range.

        The sample is passed through ``Power`` with ``power=1.0``,
        ``scale=max - min`` and ``shift=min``.  The op's ``seed`` attribute
        is used for both ``seed`` and ``seed2``.

        Returns:
            ``([scale_value], [[0]])``.

        Raises:
            AssertionError: if ``max`` is not strictly greater than ``min``.
        """
        attr = self.op.attr
        shape = attr("shape")

        min_v = attr("min")
        max_v = attr("max")
        seed = attr("seed")
        dtype = attr("dtype")
        assert max_v > min_v, "assert max_v > min_v, but recieved " + \
            "as max_v={}, min_v={} ".format(max_v, min_v)

        new_op = core.GEOperatorFactory.create_operator
        shape_value = self._create_ge_tensor([len(shape)], 2, shape)
        shape_tensor = new_op("const" + self._accumulated_op_id(),
                              "Const").set_attr_tensor("value", shape_value)

        ge_ur = new_op("uniform_random" + self._accumulated_op_id(),
                       "RandomUniform")
        ge_ur = ge_ur.set_input("shape", shape_tensor)
        ge_ur = ge_ur.set_attr_dtype("dtype",
                                     self.ascend_helper.dtype2ge(dtype))
        ge_ur = ge_ur.set_attr_int32("seed", seed)
        ge_ur = ge_ur.set_attr_int32("seed2", seed)

        scale = max_v - min_v

        scale_value = new_op("scale" + self._accumulated_op_id(), "Power")
        scale_value = scale_value.set_input("x", ge_ur)
        scale_value = scale_value.set_attr_float("power", 1.0)
        scale_value = scale_value.set_attr_float("scale", scale)
        scale_value = scale_value.set_attr_float("shift", min_v)

        return [scale_value], [[0]]
|
|
|
|
|
|
class EqualParser(AscendParserBase):
    """Parser for the ``equal`` op; emits a GE ``Equal`` operator."""

    def __init__(self, graph, var2geop):
        super(EqualParser, self).__init__(graph, var2geop)
        self.parser_name = "equal"

    def _apply(self):
        """Compare the first two positional inputs element-wise.

        Returns:
            ``([equal], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])

        op_name = "equal" + self._accumulated_op_id()
        equal = core.GEOperatorFactory.create_operator(op_name, "Equal")
        equal = equal.set_input("x1", lhs).set_input("x2", rhs)
        return [equal], [[0]]
|
|
|
|
|
|
class ExpandParser(AscendParserBase):
    """Parser for the ``expand`` op; emits a GE ``Tile`` operator."""

    def __init__(self, graph, var2geop):
        super(ExpandParser, self).__init__(graph, var2geop)
        self.parser_name = "expand"

    def _apply(self):
        """Tile the input by the ``expand_times`` attribute.

        The multiples are supplied to ``Tile`` through a ``Const`` op.

        Returns:
            ``([tile], [[0]])``.
        """
        new_op = core.GEOperatorFactory.create_operator
        source = self._get_ge_input(self.op.input_arg_names[0])
        expand_times = self.op.attr('expand_times')

        multiples = self._create_ge_tensor([len(expand_times)], 2,
                                           expand_times)
        expand_tensor = new_op("const" + self._accumulated_op_id(),
                               "Const").set_attr_tensor("value", multiples)

        tile = new_op("tile" + self._accumulated_op_id(), "Tile")
        tile = tile.set_input("x", source).set_input("multiples",
                                                     expand_tensor)
        return [tile], [[0]]
|
|
|
|
|
|
class SqueezeParser(AscendParserBase):
    """Parser for the ``squeeze2`` op; emits GE ``Squeeze`` and ``Shape``."""

    def __init__(self, graph, var2geop):
        super(SqueezeParser, self).__init__(graph, var2geop)
        self.parser_name = "squeeze2"

    def _apply(self):
        """Remove the dimensions listed in the ``axes`` attribute.

        Returns:
            ``([shape, data_squeezed], [[1], [0]])`` where ``shape`` is a
            ``Shape`` op on the squeezed result.
        """
        new_op = core.GEOperatorFactory.create_operator
        source = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr("axes")

        data_squeezed = new_op("squeeze" + self._accumulated_op_id(),
                               "Squeeze")
        data_squeezed = data_squeezed.set_input("x", source)
        data_squeezed = data_squeezed.set_attr_vec_int32("axes", axes)

        shape = new_op("shape" + self._accumulated_op_id(),
                       "Shape").set_input("x", data_squeezed)
        return [shape, data_squeezed], [[1], [0]]
|
|
|
|
|
|
#****************************************************************#
|
|
#*************************** *************************#
|
|
#*************************** *************************#
|
|
#*************************** GradParser *************************#
|
|
#*************************** *************************#
|
|
#*************************** *************************#
|
|
#****************************************************************#
|
|
## grad
|
|
class ReduceSumGradParser(AscendParserBase):
    """Parser for ``reduce_sum_grad``; broadcasts the gradient with ``BroadcastTo``."""

    def __init__(self, graph, var2geop):
        super(ReduceSumGradParser, self).__init__(graph, var2geop)
        self.parser_name = "reduce_sum_grad"

    def _apply(self):
        """Broadcast the first input to the shape of the second input.

        A ``Const`` holding ``-1`` is also created and marked as a graph
        input; it is not connected to the returned operator (the
        ``ExpandDims`` step that consumed it is disabled).

        Returns:
            ``([reduce_sum], [[0]])``.
        """
        new_op = core.GEOperatorFactory.create_operator
        arg_names = self.op.input_arg_names
        grad = self._get_ge_input(arg_names[0])
        reference = self._get_ge_input(arg_names[1])

        shape_tensor = new_op("shape" + self._accumulated_op_id(),
                              "Shape").set_input("x", reference, 0)

        minus_one = self._create_ge_tensor([1], 2, -1)
        const = new_op("const" + self._accumulated_op_id(),
                       "Const").set_attr_tensor("value", minus_one)
        self._mark_as_input(const)

        reduce_sum = new_op("broadcast_to_d" + self._accumulated_op_id(),
                            "BroadcastTo")
        reduce_sum = reduce_sum.set_input("x", grad).set_input(
            "shape", shape_tensor)

        return [reduce_sum], [[0]]
|
|
|
|
|
|
class MatMulGradParser(AscendParserBase):
    """Parser for ``matmul_grad``; emits GE ``MatMul`` or ``BatchMatMul`` ops."""

    def __init__(self, graph, var2geop):
        super(MatMulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "matmul_grad"

    def _apply(self):
        """Build the gradients for both matmul operands.

        ``BatchMatMul`` (flags ``adj_x1``/``adj_x2``) is used when the second
        positional input has rank greater than 2, otherwise ``MatMul``
        (flags ``transpose_x1``/``transpose_x2``).  Operand order and
        transposition flags depend only on ``transpose_Y``.

        NOTE(review): ``transpose_X`` is read but never influences the
        generated operators.

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.
        """
        arg_names = self.op.input_arg_names
        out_grad = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])
        y = self._get_ge_input(arg_names[2])
        transpose_x = self.op.attr("transpose_X")  # unused, see note above
        transpose_y = self.op.attr("transpose_Y")

        block = self.op.block
        out_grad_shape = block.var(arg_names[0]).shape  # unused
        x_shape = block.var(arg_names[1]).shape
        y_shape = block.var(arg_names[2]).shape  # unused

        if len(x_shape) > 2:
            ge_type, lhs_flag, rhs_flag = "BatchMatMul", "adj_x1", "adj_x2"
        else:
            ge_type, lhs_flag, rhs_flag = ("MatMul", "transpose_x1",
                                           "transpose_x2")

        def grad_op(lhs, rhs, flip_lhs, flip_rhs):
            # One (batched) matmul with the requested transposition flags.
            return core.GEOperatorFactory.create_operator(
                self.parser_name + self._accumulated_op_id(),
                ge_type).set_input("x1", lhs).set_input(
                    "x2", rhs).set_attr_bool(lhs_flag,
                                             flip_lhs).set_attr_bool(
                                                 rhs_flag, flip_rhs)

        if transpose_y:
            x_grad = grad_op(out_grad, y, False, False)
            y_grad = grad_op(out_grad, x, True, False)
        else:
            x_grad = grad_op(out_grad, y, False, True)
            y_grad = grad_op(x, out_grad, True, False)

        return [x_grad, y_grad], [[0], [1]]
|
|
|
|
|
|
class MulGradParser(AscendParserBase):
    """Parser for ``mul_grad``; emits GE ``MatMul``/``BatchMatMul`` ops."""

    def __init__(self, graph, var2geop):
        super(MulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "mul_grad"

    def _apply(self):
        """Build the gradients for both operands of ``mul``.

        Supported configurations (x rank, y rank):

        * ``x_num_col_dims == y_num_col_dims == 1`` with ranks (2, 2) or
          (3, 2);
        * any other col-dims setting with ranks (3, 2), which additionally
          requires ``x_num_col_dims == 2``.

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.

        Raises:
            NotImplementedError: for any other rank combination.  The old
                code fell through to ``return`` with the gradients unbound
                and failed with ``UnboundLocalError``.
            AssertionError: if the (3, 2) fallback path is reached with
                ``x_num_col_dims != 2``.
        """
        new_op = core.GEOperatorFactory.create_operator
        arg_names = self.op.input_arg_names
        out_grad = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])
        y = self._get_ge_input(arg_names[2])
        x_num_col_dims = self.op.attr("x_num_col_dims")
        y_num_col_dims = self.op.attr("y_num_col_dims")

        block = self.op.block
        shape_out_grad = block.var(arg_names[0]).shape
        shape_x = block.var(arg_names[1]).shape
        shape_y = block.var(arg_names[2]).shape
        ranks = (len(shape_x), len(shape_y))

        def matmul(lhs, rhs, transpose_lhs, transpose_rhs):
            # Plain 2-D MatMul named after this parser.
            return new_op(self.parser_name + self._accumulated_op_id(),
                          "MatMul").set_input("x1", lhs).set_input(
                              "x2", rhs).set_attr_bool(
                                  "transpose_x1",
                                  transpose_lhs).set_attr_bool(
                                      "transpose_x2", transpose_rhs)

        def unsupported():
            return NotImplementedError(
                "mul_grad is not supported for x rank {}, y rank {} with "
                "x_num_col_dims={}, y_num_col_dims={}".format(
                    ranks[0], ranks[1], x_num_col_dims, y_num_col_dims))

        if x_num_col_dims == 1 and y_num_col_dims == 1:
            if ranks == (2, 2):
                x_grad = matmul(out_grad, y, False, True)
                y_grad = matmul(x, out_grad, True, False)
            elif ranks == (3, 2):
                flatten_x = new_op("flatten" + self._accumulated_op_id(),
                                   "Flatten").set_input("x", x)
                x_grad = matmul(out_grad, y, False, True)
                if len(shape_out_grad) == 2:
                    # Add an axis at position 1 when the incoming gradient
                    # is 2-D.
                    x_grad = new_op(
                        "unsqueeze" + self._accumulated_op_id(),
                        "Unsqueeze").set_input(
                            "x", x_grad).set_attr_vec_int32("axes", [1])
                y_grad = matmul(flatten_x, out_grad, True, False)
            else:
                raise unsupported()
        else:
            if ranks != (3, 2):
                raise unsupported()
            assert x_num_col_dims == 2, "only support 2"

            def flatten_leading(source):
                return new_op("flatten" + self._accumulated_op_id(),
                              "FlattenV2").set_input(
                                  "x", source).set_attr_int32(
                                      "axis", 0).set_attr_int32(
                                          "end_axis", 1)

            flatten_x = flatten_leading(x)
            flatten_out_grad = flatten_leading(out_grad)

            # Replicate y along a new leading axis, once per entry of the
            # first dimension of out_grad, so BatchMatMul can consume it.
            y_unsqueeze = new_op("unsqueeze" + self._accumulated_op_id(),
                                 "Unsqueeze").set_input(
                                     "x", y).set_attr_vec_int32("axes", [0])
            y_stack = new_op("stack" + self._accumulated_op_id(),
                             "TileWithAxis").set_input(
                                 "x", y_unsqueeze).set_attr_int32(
                                     "axis", 0).set_attr_int32(
                                         "tiles", shape_out_grad[0])
            x_grad = new_op(self.parser_name + self._accumulated_op_id(),
                            "BatchMatMul").set_input(
                                "x1", out_grad).set_input(
                                    "x2", y_stack).set_attr_bool(
                                        "adj_x1", False).set_attr_bool(
                                            "adj_x2", True)
            y_grad = matmul(flatten_x, flatten_out_grad, True, False)

        return [x_grad, y_grad], [[0], [1]]
|
|
|
|
|
|
class ReluGradParser(AscendParserBase):
    """Parser for ``relu_grad``; emits a GE ``ReluGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(ReluGradParser, self).__init__(graph, var2geop)
        self.parser_name = "relu_grad"

    def _apply(self):
        """Feed the first input as ``features``, the second as ``gradients``.

        Returns:
            ``([relu_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        features = self._get_ge_input(arg_names[0])
        gradients = self._get_ge_input(arg_names[1])

        op_name = self.parser_name + self._accumulated_op_id()
        relu_grad = core.GEOperatorFactory.create_operator(
            op_name, "ReluGrad")
        relu_grad = relu_grad.set_input("gradients", gradients)
        relu_grad = relu_grad.set_input("features", features)
        return [relu_grad], [[0]]
|
|
|
|
|
|
class SoftmaxWithCrossEntropyGradParser(AscendParserBase):
    """Parser for ``softmax_with_cross_entropy_grad``."""

    def __init__(self, graph, var2geop):
        super(SoftmaxWithCrossEntropyGradParser, self).__init__(graph,
                                                                var2geop)
        self.parser_name = "softmax_with_cross_entropy_grad"

    def _apply(self):
        """Compute ``loss_grad * (softmax - squeeze(one_hot(label)))``.

        The label is cast to ``dst_type`` 3 and one-hot encoded with depth
        equal to dimension 1 of the softmax variable, using ``Const`` on/off
        values of 1 and 0 that are marked as graph inputs.

        Returns:
            ``([on, off, label, onehot, grad], [[-1]])``.
        """
        new_op = core.GEOperatorFactory.create_operator
        arg_names = self.op.input_arg_names
        label = self._get_ge_input(arg_names[0])
        loss_grad = self._get_ge_input(arg_names[1])
        softmax = self._get_ge_input(arg_names[2])

        block = self.op.block
        cls_num = block.var(arg_names[2]).shape[1]

        # Shapes are looked up as before; none of them is used below.
        label_shape = block.var(arg_names[0]).shape
        loss_grad_shape = block.var(arg_names[1]).shape
        softmax_shape = block.var(arg_names[2]).shape

        def const_from(ge_tensor):
            return new_op("const" + self._accumulated_op_id(),
                          "Const").set_attr_tensor("value", ge_tensor)

        on = const_from(self._create_ge_tensor([1], 5, 1))
        off = const_from(self._create_ge_tensor([1], 5, 0))
        self._mark_as_input(on)
        self._mark_as_input(off)

        label = new_op("cast" + self._accumulated_op_id(),
                       "Cast").set_input("x", label).set_attr_int32(
                           "dst_type", 3)
        onehot = new_op("onehot" + self._accumulated_op_id(), "OneHotD")
        onehot = onehot.set_input("x", label)
        onehot = onehot.set_input("on_value", on)
        onehot = onehot.set_input("off_value", off)
        onehot = onehot.set_attr_int32("depth", cls_num)

        squeeze = new_op("suqeeze" + self._accumulated_op_id(),
                         "Squeeze").set_input("x", onehot)
        sub = new_op("sub" + self._accumulated_op_id(),
                     "Sub").set_input("x1", softmax).set_input(
                         "x2", squeeze)
        grad = new_op("mul" + self._accumulated_op_id(),
                      "Mul").set_input("x1", loss_grad).set_input("x2", sub)

        return [on, off, label, onehot, grad], [[-1]]
|
|
|
|
|
|
class DotMulGradParser(AscendParserBase):
    """Translate ``elementwise_mul_grad`` into two GE ``Mul`` operators."""

    def __init__(self, graph, var2geop):
        super(DotMulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_mul_grad"

    def _apply(self):
        """Emit the gradients of both multiplication operands.

        Returns:
            tuple: ``([x_grad, y_grad], [[0], [1]])``.
        """
        dout = self._get_ge_input(self.op.input_arg_names[0])
        lhs = self._get_ge_input(self.op.input_arg_names[1])
        rhs = self._get_ge_input(self.op.input_arg_names[2])

        make_op = core.GEOperatorFactory.create_operator

        # Gradient w.r.t. the first operand: dout * second operand.
        x_grad_op = make_op(self.parser_name + self._accumulated_op_id(), "Mul")
        x_grad_op = x_grad_op.set_input("x1", dout)
        x_grad_op = x_grad_op.set_input("x2", rhs)

        # Gradient w.r.t. the second operand: first operand * dout.
        y_grad_op = make_op(self.parser_name + self._accumulated_op_id(), "Mul")
        y_grad_op = y_grad_op.set_input("x1", lhs)
        y_grad_op = y_grad_op.set_input("x2", dout)

        return [x_grad_op, y_grad_op], [[0], [1]]
|
|
|
|
|
|
class DotAddGradParser(AscendParserBase):
    """Translate ``elementwise_add_grad`` into GE ``ReduceSumD`` operators.

    Each operand's gradient is the output gradient summed over the dimensions
    along which that operand was broadcast.
    """

    def __init__(self, graph, var2geop):
        super(DotAddGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_add_grad"

    def _reduce_to_shape(self, grad, grad_rank, target_shape):
        """Sum ``grad`` down to the rank and size-1 dims of ``target_shape``.

        Args:
            grad: GE operator producing the output gradient.
            grad_rank: number of dimensions of the output gradient.
            target_shape: static shape of the operand receiving the gradient.

        Returns:
            The last ``ReduceSumD`` operator created, or ``grad`` itself when
            no reduction is required.
        """
        # Drop the extra leading dimensions one at a time.
        for _ in range(grad_rank - len(target_shape)):
            grad = core.GEOperatorFactory.create_operator(
                self.parser_name + self._accumulated_op_id(),
                "ReduceSumD").set_input("x", grad).set_attr_vec_int32(
                    "axes", [0]).set_attr_bool("keep_dims", False)
        # Sum over every dimension that has size 1 in the operand, keeping it.
        for axis, size in enumerate(target_shape):
            if size == 1:
                grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "ReduceSumD").set_input("x", grad).set_attr_vec_int32(
                        "axes", [axis]).set_attr_bool("keep_dims", True)
        return grad

    def _apply(self):
        """Emit the gradients of both addition operands.

        Returns:
            tuple: ``([x_grad, y_grad], [[0], [1]])``.
        """
        arg_names = self.op.input_arg_names
        out_grad = self._get_ge_input(arg_names[0])

        block = self.op.block
        grad_rank = len(block.var(arg_names[0]).shape)
        x_shape = block.var(arg_names[1]).shape
        y_shape = block.var(arg_names[2]).shape

        # X reductions are created before Y reductions so operator ids are
        # assigned in the same order as before.
        x_grad = self._reduce_to_shape(out_grad, grad_rank, x_shape)
        y_grad = self._reduce_to_shape(out_grad, grad_rank, y_shape)

        return [x_grad, y_grad], [[0], [1]]
|
|
|
|
|
|
class DotDivGradParser(AscendParserBase):
    """Translate ``elementwise_div_grad`` into GE operators.

    For ``out = x / y`` the emitted graph computes
    ``x_grad = (x != 0) * dout / y`` and ``y_grad = -out * dout / y``.
    """

    def __init__(self, graph, var2geop):
        super(DotDivGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_div_grad"

    def _apply(self):
        """Emit the gradients of the dividend and the divisor.

        Returns:
            tuple: ``([x_grad, y_grad], [[0], [1]])``.
        """
        arg_names = self.op.input_arg_names
        out = self._get_ge_input(arg_names[0])
        out_grad = self._get_ge_input(arg_names[1])
        x = self._get_ge_input(arg_names[2])
        y = self._get_ge_input(arg_names[3])

        # 1 / y, shared by both gradients.
        y_power = core.GEOperatorFactory.create_operator(
            "power" + self._accumulated_op_id(), "Power").set_input(
                "x", y).set_attr_float("power", -1)

        # Mask that is 1 where x != 0 and 0 elsewhere.
        # NOTE(review): zeroing dX where x == 0 is kept from the original
        # implementation; confirm that this masking is intended.
        tensor_zeros = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike").set_input("x", x)
        x_zero = core.GEOperatorFactory.create_operator(
            "equal" + self._accumulated_op_id(), "Equal").set_input(
                "x1", x).set_input("x2", tensor_zeros)
        x_nozero = core.GEOperatorFactory.create_operator(
            "logical_not" + self._accumulated_op_id(),
            "LogicalNot").set_input("x", x_zero)
        # dst_type 0 is presumably the GE float32 code -- TODO confirm.
        x_nozero_f = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast").set_input(
                "x", x_nozero).set_attr_int32("dst_type", 0)
        x_grad_w = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(), "Mul").set_input(
                "x1", x_nozero_f).set_input("x2", y_power)
        x_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "Mul").set_input("x1", x_grad_w).set_input("x2", out_grad)

        # d(x / y) / dy = -x / y**2 = -out / y.  The product below yields
        # out / y * dout, so it has to be negated to get the gradient.
        y_grad_w = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(), "Mul").set_input(
                "x1", out).set_input("x2", y_power)
        y_grad_pos = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(), "Mul").set_input(
                "x1", y_grad_w).set_input("x2", out_grad)
        y_grad = core.GEOperatorFactory.create_operator(
            "neg" + self._accumulated_op_id(),
            "Neg").set_input("x", y_grad_pos)

        return [x_grad, y_grad], [[0], [1]]
|
|
|
|
|
|
class SoftmaxGradParser(AscendParserBase):
    """Translate ``softmax_grad`` into a single GE ``SoftmaxGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(SoftmaxGradParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_grad"

    def _apply(self):
        """Wire the softmax output and its gradient into ``SoftmaxGrad``.

        Returns:
            tuple: ``([x_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        softmax_out = self._get_ge_input(arg_names[0])
        softmax_out_grad = self._get_ge_input(arg_names[1])

        op_name = self.parser_name + self._accumulated_op_id()
        grad_op = core.GEOperatorFactory.create_operator(op_name, "SoftmaxGrad")
        grad_op = grad_op.set_input("softmax", softmax_out)
        grad_op = grad_op.set_input("grad_softmax", softmax_out_grad)
        return [grad_op], [[0]]
|
|
|
|
|
|
class ReshapeGradParser(AscendParserBase):
    """Translate ``reshape2_grad`` into a GE ``Reshape`` operator.

    The output gradient is reshaped to the shape recorded in the op's second
    input variable, with that variable's leading ``0`` dimension removed.
    """

    def __init__(self, graph, var2geop):
        super(ReshapeGradParser, self).__init__(graph, var2geop)
        self.parser_name = "reshape2_grad"

    def _apply(self):
        """Reshape the output gradient back to the input's shape.

        Returns:
            tuple: ``([x_grad], [[0]])``.

        Raises:
            ValueError: if the recorded shape does not start with ``0``.
                Previously this case fell through to the ``return`` with
                ``x_grad`` unbound and failed with an ``UnboundLocalError``.
        """
        arg_names = self.op.input_arg_names
        out_grad = self._get_ge_input(arg_names[0])
        x_shape_list = self.op.block.var(arg_names[1]).shape

        if len(x_shape_list) == 0 or x_shape_list[0] != 0:
            raise ValueError(
                "reshape2_grad expects the shape of input '%s' to start with "
                "0, but got %s" % (arg_names[1], list(x_shape_list)))

        # Strip the leading 0 marker; the rest is the target shape.
        x_shape_delzero = x_shape_list[1:]
        # dtype code 2 is presumably int32 -- TODO confirm.
        tensor = self._create_ge_tensor([len(x_shape_delzero)], 2,
                                        x_shape_delzero)
        const_shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor)
        x_grad = core.GEOperatorFactory.create_operator(
            "reshape" + self._accumulated_op_id(), "Reshape").set_input(
                "x", out_grad).set_input("shape", const_shape)

        return [x_grad], [[0]]
|
|
|
|
|
|
class GatherGradParser(AscendParserBase):
    """Translate ``gather_grad`` into GE operators.

    The output gradient is scattered into a zero tensor shaped like ``x`` at
    the positions given by ``index``.
    """

    def __init__(self, graph, var2geop):
        super(GatherGradParser, self).__init__(graph, var2geop)
        self.parser_name = "gather_grad"

    def _apply(self):
        """Build the scatter sub-graph.

        Returns:
            tuple: ``([tensor_zeros, x_grad], [[-1]])``.
        """
        arg_names = self.op.input_arg_names
        index = self._get_ge_input(arg_names[0])
        out_grad = self._get_ge_input(arg_names[1])
        x = self._get_ge_input(arg_names[2])

        index_shape = self.op.block.var(arg_names[0]).shape

        # A 1-D index gets a trailing axis before being used as "indices".
        if len(index_shape) == 1:
            index = core.GEOperatorFactory.create_operator(
                "unsqueeze" + self._accumulated_op_id(), "Unsqueeze").set_input(
                    "x", index).set_attr_vec_int32("axes", [1])

        tensor_zeros = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike").set_input("x", x)
        # NOTE(review): TensorScatterUpdate presumably overwrites rather than
        # accumulates when an index repeats -- confirm this is intended.
        x_grad = core.GEOperatorFactory.create_operator(
            "scatter" + self._accumulated_op_id(),
            "TensorScatterUpdate").set_input("x", tensor_zeros).set_input(
                "indices", index).set_input("updates", out_grad)

        return [tensor_zeros, x_grad], [[-1]]
|
|
|
|
|
|
class TransposeGradParser(AscendParserBase):
    """Translate ``transpose2_grad`` into a GE ``TransposeD`` operator.

    The output gradient is transposed with the inverse of the forward
    permutation, which restores the layout of the forward input.
    """

    def __init__(self, graph, var2geop):
        super(TransposeGradParser, self).__init__(graph, var2geop)
        self.parser_name = "transpose2_grad"

    def _apply(self):
        """Transpose the output gradient back to the input layout.

        Returns:
            tuple: ``([x_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        out_grad = self._get_ge_input(arg_names[0])
        perm = self.op.attr("axis")

        # The second input's shape carries one extra leading entry.
        x_shape = self.op.block.var(arg_names[1]).shape[1:]
        out_grad_shape = self.op.block.var(arg_names[0]).shape

        # Undoing a transpose requires the inverse permutation.  It equals
        # `perm` only for self-inverse permutations (e.g. swapping two axes),
        # which are the only cases the forward `perm` handled correctly.
        inverse_perm = [0] * len(perm)
        for position, axis in enumerate(perm):
            inverse_perm[axis] = position

        # Sanity check on the static shapes.
        assert [out_grad_shape[axis]
                for axis in inverse_perm] == list(x_shape), (
                    "transpose2_grad: shape %s permuted by %s does not match "
                    "the input shape %s" % (list(out_grad_shape), inverse_perm,
                                            list(x_shape)))

        x_grad = core.GEOperatorFactory.create_operator(
            "transpose" + self._accumulated_op_id(), "TransposeD").set_input(
                "x", out_grad).set_attr_vec_int32("perm", inverse_perm)

        return [x_grad], [[0]]
|
|
|
|
|
|
class LayerNormGradParser(AscendParserBase):
    """Translate ``layer_norm_grad`` into a GE ``LayerNormGrad`` operator.

    Each of the three ``LayerNormGrad`` outputs is passed through a ``Cast``.
    """

    def __init__(self, graph, var2geop):
        super(LayerNormGradParser, self).__init__(graph, var2geop)
        self.parser_name = "layer_norm_grad"

    def _apply(self):
        """Build ``LayerNormGrad`` followed by one ``Cast`` per output.

        Returns:
            tuple: ``([out_x_grad, out_scale_grad, out_bias_grad],
            [[2], [1], [0]])``.
        """
        arg_names = self.op.input_arg_names
        # arg_names[0] is the bias; LayerNormGrad does not take it as input.
        mean = self._get_ge_input(arg_names[1])
        scale = self._get_ge_input(arg_names[2])
        variance = self._get_ge_input(arg_names[3])
        x = self._get_ge_input(arg_names[4])
        out_grad = self._get_ge_input(arg_names[5])
        x_dtype = self.op.block.var(arg_names[4]).dtype

        x_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "LayerNormGrad").set_input("dy", out_grad).set_input(
                "x", x).set_input("variance", variance).set_input(
                    "mean", mean).set_input("gamma", scale)

        # Cast to type code 0 when x maps to code 0, otherwise to code 1
        # (presumably float32 / float16 -- TODO confirm against the helper).
        dtype_code = self.ascend_helper.dtype2paddle_inv_map[str(x_dtype)]
        cast_dtype = 0 if dtype_code == 0 else 1

        # Output indices 0, 1, 2 are cast in order and returned as the
        # x / scale / bias gradients respectively.
        casted = []
        for output_index in range(3):
            casted.append(
                core.GEOperatorFactory.create_operator(
                    "cast" + self._accumulated_op_id(), "Cast").set_input(
                        "x", x_grad, output_index).set_attr_int32(
                            "dst_type", cast_dtype))
        out_x_grad, out_scale_grad, out_bias_grad = casted

        return [out_x_grad, out_scale_grad, out_bias_grad], [[2], [1], [0]]
|
|
|
|
|
|
class TanhGradParser(AscendParserBase):
    """Translate ``tanh_grad`` into a single GE ``TanhGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(TanhGradParser, self).__init__(graph, var2geop)
        self.parser_name = 'tanh_grad'

    def _apply(self):
        """Feed the forward output and its gradient into ``TanhGrad``.

        Returns:
            tuple: ``([tanh_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        forward_out = self._get_ge_input(arg_names[0])
        upstream_grad = self._get_ge_input(arg_names[1])

        op_name = "tanh_grad" + self._accumulated_op_id()
        grad_op = core.GEOperatorFactory.create_operator(op_name, "TanhGrad")
        grad_op = grad_op.set_input("y", forward_out)
        grad_op = grad_op.set_input("dy", upstream_grad)

        return [grad_op], [[0]]
|
|
|
|
|
|
class LogGradParser(AscendParserBase):
    """Translate ``log_grad`` into a GE ``DivNoNan`` operator."""

    def __init__(self, graph, var2geop):
        super(LogGradParser, self).__init__(graph, var2geop)
        self.parser_name = 'log_grad'

    def _apply(self):
        """Divide the incoming gradient by the forward input.

        Returns:
            tuple: ``([log_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        upstream_grad = self._get_ge_input(arg_names[0])
        forward_in = self._get_ge_input(arg_names[1])

        op_name = "log_grad" + self._accumulated_op_id()
        div_op = core.GEOperatorFactory.create_operator(op_name, "DivNoNan")
        div_op = div_op.set_input("x1", upstream_grad)
        div_op = div_op.set_input("x2", forward_in)
        return [div_op], [[0]]
|
|
|
|
|
|
class SqrtGradParser(AscendParserBase):
    """Translate ``sqrt_grad`` into a single GE ``SqrtGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(SqrtGradParser, self).__init__(graph, var2geop)
        self.parser_name = "sqrt_grad"

    def _apply(self):
        """Feed the forward output and its gradient into ``SqrtGrad``.

        Returns:
            tuple: ``([sqrt_grad], [[0]])``.  The index list was missing
            before, unlike every other parser's ``_apply`` in this file.
        """
        y = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        sqrt_grad = core.GEOperatorFactory.create_operator(
            "sqrt_grad" + self._accumulated_op_id(),
            "SqrtGrad").set_input("y", y).set_input("dy", out_grad)
        return [sqrt_grad], [[0]]
|
|
|
|
|
|
class PowGradParser(AscendParserBase):
    """Translate ``pow_grad`` into GE operators.

    For ``out = x ** factor`` the emitted graph computes
    ``factor * x ** (factor - 1) * dout``.
    """

    def __init__(self, graph, var2geop):
        super(PowGradParser, self).__init__(graph, var2geop)
        self.parser_name = "pow_grad"

    def _apply(self):
        """Build the power-rule gradient sub-graph.

        Returns:
            tuple: ``([x_power_mul_factor_grad], [[0]])``.
        """
        grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        factor = self.op.attr("factor")

        # Broadcast the scalar exponent to the runtime shape of x
        # (dtype code 5 is presumably float32 -- TODO confirm).
        shape_tensor = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape").set_input("x", x)
        factor_scale = self._create_ge_tensor([1], 5, factor)
        factor_scale = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", factor_scale)
        factor_tensor = core.GEOperatorFactory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(),
            "BroadcastTo").set_input(
                "x", factor_scale).set_input("shape", shape_tensor)

        # x ** (factor - 1)
        x_power = core.GEOperatorFactory.create_operator(
            "x_power" + self._accumulated_op_id(), "Power").set_input(
                "x", x).set_attr_float("power", factor - 1)
        # factor * x ** (factor - 1).  This must use `x_power`, not `x`:
        # multiplying by `x` would yield factor * x instead of the derivative.
        x_power_mul_factor = core.GEOperatorFactory.create_operator(
            "x_power_mul_factor" + self._accumulated_op_id(), "Mul").set_input(
                "x1", x_power).set_input("x2", factor_tensor)
        x_power_mul_factor_grad = core.GEOperatorFactory.create_operator(
            "x_power_mul_factor_grad" + self._accumulated_op_id(),
            "Mul").set_input("x1", x_power_mul_factor).set_input("x2", grad)

        return [x_power_mul_factor_grad], [[0]]
|
|
|
|
|
|
class GeluGradParser(AscendParserBase):
    """Translate ``gelu_grad`` into GE ``Gelu`` and ``GeluGrad`` operators."""

    def __init__(self, graph, var2geop):
        super(GeluGradParser, self).__init__(graph, var2geop)
        self.parser_name = "gelu_grad"

    def _apply(self):
        """Recompute the forward Gelu and feed it to ``GeluGrad``.

        Returns:
            tuple: ``([gelu_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        upstream_grad = self._get_ge_input(arg_names[0])
        forward_in = self._get_ge_input(arg_names[1])

        make_op = core.GEOperatorFactory.create_operator

        # GeluGrad takes the forward output as "y", so it is rebuilt here.
        forward_out = make_op("gelu" + self._accumulated_op_id(), "Gelu")
        forward_out = forward_out.set_input("x", forward_in)

        grad_op = make_op("gelu_grad" + self._accumulated_op_id(), "GeluGrad")
        grad_op = grad_op.set_input("x", forward_in)
        grad_op = grad_op.set_input("dy", upstream_grad)
        grad_op = grad_op.set_input("y", forward_out)

        return [grad_op], [[0]]
|
|
|
|
|
|
class MeanGradParser(AscendParserBase):
    """Translate ``mean_grad`` into GE operators.

    The emitted graph sums a ones-like copy of ``x``, raises that sum to the
    power ``-1`` and multiplies the result with the incoming gradient.
    """

    def __init__(self, graph, var2geop):
        super(MeanGradParser, self).__init__(graph, var2geop)
        self.parser_name = "mean_grad"

    def _apply(self):
        """Build the mean gradient sub-graph.

        Returns:
            tuple: ``([mean_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        upstream_grad = self._get_ge_input(arg_names[0])
        forward_in = self._get_ge_input(arg_names[1])

        make_op = core.GEOperatorFactory.create_operator

        ones_like_x = make_op("one_tensor" + self._accumulated_op_id(),
                              "OnesLike").set_input("x", forward_in)

        # Sum of ones, presumably the element count of x (an empty "axes"
        # list is assumed to mean "reduce everything" -- TODO confirm).
        element_count = make_op("mean" + self._accumulated_op_id(),
                                "ReduceSumD")
        element_count = element_count.set_input("x", ones_like_x)
        element_count = element_count.set_attr_bool("keep_dims", False)
        element_count = element_count.set_attr_vec_int32("axes", [])

        # Reciprocal of the sum above.
        reciprocal = make_op("x_power" + self._accumulated_op_id(), "Power")
        reciprocal = reciprocal.set_input("x", element_count)
        reciprocal = reciprocal.set_attr_float("power", -1)

        result = make_op("mean_grad" + self._accumulated_op_id(), "Mul")
        result = result.set_input("x1", reciprocal)
        result = result.set_input("x2", upstream_grad)

        return [result], [[0]]
|
|
|
|
|
|
class SliceGradParser(AscendParserBase):
    """Translate ``slice_grad`` into a GE ``PadD`` operator.

    The gradient of the slice is padded so that it regains the extent of the
    sliced input along every sliced axis.
    """

    def __init__(self, graph, var2geop):
        super(SliceGradParser, self).__init__(graph, var2geop)
        self.parser_name = "slice_grad"

    def _apply(self):
        """Compute per-dimension paddings and emit ``PadD``.

        Returns:
            tuple: ``([slice_value], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        grad = self._get_ge_input(arg_names[1])
        axes = self.op.attr("axes")
        starts = self.op.attr("starts")
        ends = self.op.attr("ends")

        x_shape = self.op.block.var(arg_names[0]).shape

        # Look the bounds up by axis rather than by a running counter, so the
        # pairing stays correct even when `axes` is not sorted ascending.
        bounds = {
            axis: (start, end)
            for axis, start, end in zip(axes, starts, ends)
        }

        paddings = []
        for dim, size in enumerate(x_shape):
            if dim not in bounds:
                paddings.append([0, 0])
                continue
            start, end = bounds[dim]
            # Pad the tail only when `end` lies inside the dimension.
            tail = size - end if end <= size else 0
            paddings.append([start, tail])

        # Dimension 0 is never padded (presumably the batch dimension --
        # TODO confirm).
        paddings[0] = [0, 0]

        slice_value = core.GEOperatorFactory.create_operator(
            "slice_grad" + self._accumulated_op_id(), "PadD").set_input(
                "x", grad).set_attr_vec_vec_int64("paddings", paddings)

        return [slice_value], [[0]]
|
|
|
|
|
|
class LookUpTableGradParser(AscendParserBase):
    """Translate ``lookup_table_grad`` into GE operators.

    The flattened gradient is scatter-added into a zero tensor shaped like
    the embedding table, at the positions given by the flattened ids.
    """

    def __init__(self, graph, var2geop):
        super(LookUpTableGradParser, self).__init__(graph, var2geop)
        self.parser_name = "lookup_table_grad"

    def _apply(self):
        """Build the embedding gradient sub-graph.

        Returns:
            tuple: ``([embedding_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        ids = self._get_ge_input(arg_names[0])
        grad = self._get_ge_input(arg_names[1])
        embedding = self._get_ge_input(arg_names[2])

        # Merge axes 0..1 of both the ids and the gradient.
        ids_flatten = core.GEOperatorFactory.create_operator(
            "flatten" + self._accumulated_op_id(), "FlattenV2").set_input(
                "x",
                ids).set_attr_int32("axis", 0).set_attr_int32("end_axis", 1)
        grad_flatten = core.GEOperatorFactory.create_operator(
            "flatten" + self._accumulated_op_id(), "FlattenV2").set_input(
                "x",
                grad).set_attr_int32("axis", 0).set_attr_int32("end_axis", 1)

        tensor_zeros = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike").set_input("x", embedding)
        embedding_grad = core.GEOperatorFactory.create_operator(
            "scatteradd" + self._accumulated_op_id(),
            "TensorScatterAdd").set_input(
                "x", tensor_zeros).set_input("indices", ids_flatten).set_input(
                    "updates", grad_flatten)

        return [embedding_grad], [[0]]
|
|
|
|
|
|
class SGDParser(AscendParserBase):
    """Translate ``sgd`` into a GE ``ApplyGradientDescent`` operator."""

    def __init__(self, graph, var2geop):
        super(SGDParser, self).__init__(graph, var2geop)
        self.parser_name = "sgd"

    def _apply(self):
        """Wire parameter, learning rate and gradient into the update op.

        Returns:
            tuple: ``([sgd], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        gradient = self._get_ge_input(arg_names[0])
        learning_rate = self._get_ge_input(arg_names[1])
        parameter = self._get_ge_input(arg_names[2])

        # The operator name prefix is "momentum" in the original code and is
        # left untouched so generated names stay the same.
        op_name = "momentum" + self._accumulated_op_id()
        update_op = core.GEOperatorFactory.create_operator(
            op_name, "ApplyGradientDescent")
        update_op = update_op.set_input("var", parameter)
        update_op = update_op.set_input("alpha", learning_rate)
        update_op = update_op.set_input("delta", gradient)
        return [update_op], [[0]]
|
|
|
|
|
|
class AdamParser(AscendParserBase):
    """Translate ``adam`` into a GE ``ApplyAdam`` operator."""

    def __init__(self, graph, var2geop):
        super(AdamParser, self).__init__(graph, var2geop)
        self.parser_name = "adam"

    def _apply(self):
        """Create constants for the hyper-parameters and emit ``ApplyAdam``.

        Returns:
            tuple: ``([adam], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        ge_inputs = [self._get_ge_input(arg_names[i]) for i in range(7)]
        (beta1_power, beta2_power, grad, lr, moment1, moment2,
         param) = ge_inputs

        beta1_value = self.op.attr('beta1')
        beta2_value = self.op.attr('beta2')
        epsilon_value = self.op.attr('epsilon')

        def scalar_const(value):
            # One-element Const holding `value`
            # (dtype code 5 is presumably float32 -- TODO confirm).
            const_op = core.GEOperatorFactory.create_operator(
                "const" + self._accumulated_op_id(), "Const")
            return const_op.set_attr_tensor(
                "value", self._create_ge_tensor([1], 5, value))

        beta1 = scalar_const(beta1_value)
        beta2 = scalar_const(beta2_value)
        epsilon = scalar_const(epsilon_value)

        # Inputs are attached in the same order as the original call chain.
        wiring = [
            ("var", param),
            ("m", moment1),
            ("v", moment2),
            ("beta1_power", beta1_power),
            ("beta2_power", beta2_power),
            ("lr", lr),
            ("beta1", beta1),
            ("beta2", beta2),
            ("epsilon", epsilon),
            ("grad", grad),
        ]
        adam = core.GEOperatorFactory.create_operator(
            "adam" + self._accumulated_op_id(), "ApplyAdam")
        for slot, operand in wiring:
            adam = adam.set_input(slot, operand)

        return [adam], [[0]]
|