modify applr momentum output 2

pull/981/head
changzherui 5 years ago
parent e01692adda
commit 48735c25ca

@ -201,6 +201,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
if (AnfAlgo::GetCNodeName(kernel) == "ApplyMomentum") {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
AnfAlgo::SetOutputAddr(device_address, 0, kernel.get());
AnfAlgo::SetOutputAddr(device_address, 1, kernel.get());
return;
}

@ -29,23 +29,24 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
.input(2, "lr", False, "required", "all") \
.input(3, "grad", False, "required", "all") \
.input(4, "momentum", False, "required", "all") \
.output(0, "out", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
DataType.F16_Default, DataType.F16_5HD) \
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
DataType.F16_Default, DataType.F16_C1HWNCoC0) \
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
DataType.F16_Default, DataType.F16_FracZ) \
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
DataType.F32_Default, DataType.F32_5HD) \
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
DataType.F32_Default, DataType.F32_C1HWNCoC0) \
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
DataType.F32_Default, DataType.F32_FracZ) \
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()

@ -1427,8 +1427,11 @@ class ApplyMomentum(PrimitiveWithInfer):
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
outputs=['output'])
self.is_tbe = context.get_context("device_target") == "Ascend"
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
if self.is_tbe:
return v_shape, v_shape
return v_shape
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
@ -1439,6 +1442,8 @@ class ApplyMomentum(PrimitiveWithInfer):
validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
if self.is_tbe:
return g_dtype, g_dtype
return g_dtype

Loading…
Cancel
Save