|
|
|
@ -203,7 +203,8 @@ class AscendParserBase(object):
|
|
|
|
|
def _accumulated_op_id(self):
|
|
|
|
|
global global_cnt
|
|
|
|
|
global_cnt += 1
|
|
|
|
|
return "." + str(global_cnt)
|
|
|
|
|
name = "." + str(global_cnt)
|
|
|
|
|
return name
|
|
|
|
|
|
|
|
|
|
def _create_ge_tensor(self, shape, dtype, value):
|
|
|
|
|
tensor_desc = core.GETensorDesc(
|
|
|
|
@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase):
|
|
|
|
|
"unsqueeze" + self._accumulated_op_id(),
|
|
|
|
|
"Unsqueeze").set_input("x",
|
|
|
|
|
y).set_attr_vec_int32("axes", [0])
|
|
|
|
|
y_stack = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"stack" + self._accumulated_op_id(),
|
|
|
|
|
"TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
|
|
|
|
|
"axis", 0).set_attr_int32("tiles", shape_out_grad[0])
|
|
|
|
|
x_grad = core.GEOperatorFactory.create_operator(
|
|
|
|
|
self.parser_name + self._accumulated_op_id(),
|
|
|
|
|
"BatchMatMul").set_input("x1", out_grad).set_input(
|
|
|
|
|
"x2", y_unsqueeze).set_attr_bool(
|
|
|
|
|
"x2", y_stack).set_attr_bool(
|
|
|
|
|
"adj_x1", False).set_attr_bool("adj_x2", True)
|
|
|
|
|
y_grad = core.GEOperatorFactory.create_operator(
|
|
|
|
|
self.parser_name + self._accumulated_op_id(),
|
|
|
|
|