|
|
@ -644,7 +644,7 @@ class Cell(Cell_):
|
|
|
|
param.set_cast_dtype(mstype.float32)
|
|
|
|
param.set_cast_dtype(mstype.float32)
|
|
|
|
elif self._mindspore_flags.get('fp16'):
|
|
|
|
elif self._mindspore_flags.get('fp16'):
|
|
|
|
param.set_cast_dtype(mstype.float16)
|
|
|
|
param.set_cast_dtype(mstype.float16)
|
|
|
|
else:
|
|
|
|
elif hasattr(param, "set_cast_dtype"):
|
|
|
|
# retest dtype
|
|
|
|
# retest dtype
|
|
|
|
param.set_cast_dtype()
|
|
|
|
param.set_cast_dtype()
|
|
|
|
return param
|
|
|
|
return param
|
|
|
|