Fix reshape on GE graph. (#31084)

Fix reshape on GE graph
revert-31562-mean
gongweibao 4 years ago committed by GitHub
parent a6edbc478b
commit c687edecd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer):
parameter_list=None,
no_grad_set=None,
auto_dp=False,
rank_table_file=None):
rank_table_file=None,
precision_mode="must_keep_origin_dtype"):
minimized = None
if self.inner_opt:
minimized = self.inner_opt.minimize(
@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer):
config = {
"ge.exec.deviceId": str(fleet.local_device_ids()),
"ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype",
"ge.exec.precision_mode": precision_mode,
}
# if multi trainers
if rank_table_file and fleet.world_size() > 1:

@ -203,7 +203,8 @@ class AscendParserBase(object):
def _accumulated_op_id(self):
global global_cnt
global_cnt += 1
return "." + str(global_cnt)
name = "." + str(global_cnt)
return name
def _create_ge_tensor(self, shape, dtype, value):
tensor_desc = core.GETensorDesc(
@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase):
"unsqueeze" + self._accumulated_op_id(),
"Unsqueeze").set_input("x",
y).set_attr_vec_int32("axes", [0])
y_stack = core.GEOperatorFactory.create_operator(
"stack" + self._accumulated_op_id(),
"TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
"axis", 0).set_attr_int32("tiles", shape_out_grad[0])
x_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(),
"BatchMatMul").set_input("x1", out_grad).set_input(
"x2", y_unsqueeze).set_attr_bool(
"x2", y_stack).set_attr_bool(
"adj_x1", False).set_attr_bool("adj_x2", True)
y_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(),

Loading…
Cancel
Save