Fix reshape on GE graph. (#31084)

Fix reshape on GE graph
revert-31562-mean
gongweibao 4 years ago committed by GitHub
parent a6edbc478b
commit c687edecd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer):
parameter_list=None, parameter_list=None,
no_grad_set=None, no_grad_set=None,
auto_dp=False, auto_dp=False,
rank_table_file=None): rank_table_file=None,
precision_mode="must_keep_origin_dtype"):
minimized = None minimized = None
if self.inner_opt: if self.inner_opt:
minimized = self.inner_opt.minimize( minimized = self.inner_opt.minimize(
@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer):
config = { config = {
"ge.exec.deviceId": str(fleet.local_device_ids()), "ge.exec.deviceId": str(fleet.local_device_ids()),
"ge.graphRunMode": "1", "ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype", "ge.exec.precision_mode": precision_mode,
} }
# if multi trainers # if multi trainers
if rank_table_file and fleet.world_size() > 1: if rank_table_file and fleet.world_size() > 1:

@ -203,7 +203,8 @@ class AscendParserBase(object):
def _accumulated_op_id(self): def _accumulated_op_id(self):
global global_cnt global global_cnt
global_cnt += 1 global_cnt += 1
return "." + str(global_cnt) name = "." + str(global_cnt)
return name
def _create_ge_tensor(self, shape, dtype, value): def _create_ge_tensor(self, shape, dtype, value):
tensor_desc = core.GETensorDesc( tensor_desc = core.GETensorDesc(
@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase):
"unsqueeze" + self._accumulated_op_id(), "unsqueeze" + self._accumulated_op_id(),
"Unsqueeze").set_input("x", "Unsqueeze").set_input("x",
y).set_attr_vec_int32("axes", [0]) y).set_attr_vec_int32("axes", [0])
y_stack = core.GEOperatorFactory.create_operator(
"stack" + self._accumulated_op_id(),
"TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
"axis", 0).set_attr_int32("tiles", shape_out_grad[0])
x_grad = core.GEOperatorFactory.create_operator( x_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(), self.parser_name + self._accumulated_op_id(),
"BatchMatMul").set_input("x1", out_grad).set_input( "BatchMatMul").set_input("x1", out_grad).set_input(
"x2", y_unsqueeze).set_attr_bool( "x2", y_stack).set_attr_bool(
"adj_x1", False).set_attr_bool("adj_x2", True) "adj_x1", False).set_attr_bool("adj_x2", True)
y_grad = core.GEOperatorFactory.create_operator( y_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(), self.parser_name + self._accumulated_op_id(),

Loading…
Cancel
Save