!253 change adam output numbers adapter to tbe

Merge pull request !253 from zhaoting/add-YOLOv3-infer-scipt-and-change-dataset-to-MindRecord
pull/253/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 50d8992177

@ -88,7 +88,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float" "float16","float16","float16","float16","float","float","float", "float"
], ],
"format": [ "format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
], ],
"name": "beta1_power", "name": "beta1_power",
"need_compile": false, "need_compile": false,
@ -101,7 +102,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float","float" "float16","float16","float16","float16","float","float","float","float"
], ],
"format": [ "format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
], ],
"name": "beta2_power", "name": "beta2_power",
"need_compile": false, "need_compile": false,
@ -114,7 +116,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float" "float16","float16","float16","float16","float","float","float", "float"
], ],
"format": [ "format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
], ],
"name": "lr", "name": "lr",
"need_compile": false, "need_compile": false,
@ -127,7 +130,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float" "float16","float16","float16","float16","float","float","float", "float"
], ],
"format": [ "format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
], ],
"name": "beta1", "name": "beta1",
"need_compile": false, "need_compile": false,
@ -140,7 +144,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float" "float16","float16","float16","float16","float","float","float", "float"
], ],
"format": [ "format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
], ],
"name": "beta2", "name": "beta2",
"need_compile": false, "need_compile": false,
@ -153,7 +158,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float" "float16","float16","float16","float16","float","float","float", "float"
], ],
"format": [ "format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
], ],
"name": "epsilon", "name": "epsilon",
"need_compile": false, "need_compile": false,
@ -161,7 +167,7 @@ from mindspore.ops.op_info_register import op_info_register
"shape": "all" "shape": "all"
}, },
{ {
"index": 8, "index": 9,
"dtype": [ "dtype": [
"float16","float16","float16","float16","float","float","float", "float" "float16","float16","float16","float16","float","float","float", "float"
], ],
@ -187,6 +193,32 @@ from mindspore.ops.op_info_register import op_info_register
"need_compile": false, "need_compile": false,
"param_type": "required", "param_type": "required",
"shape": "all" "shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float16","float","float","float","float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
],
"name": "m",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float16","float","float","float","float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
],
"name": "v",
"need_compile": false,
"param_type": "required",
"shape": "all"
} }
] ]
}""") }""")

@ -2149,7 +2149,7 @@ class Adam(PrimitiveWithInfer):
validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape) validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape)
validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape) validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape)
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
return var_shape return var_shape, m_shape, v_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
@ -2159,7 +2159,7 @@ class Adam(PrimitiveWithInfer):
args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype, args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype,
"beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype} "beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype}
validator.check_type_same(args, [mstype.float16, mstype.float32]) validator.check_type_same(args, [mstype.float16, mstype.float32])
return var_dtype return var_dtype, m_dtype, v_dtype
class BinaryCrossEntropy(PrimitiveWithInfer): class BinaryCrossEntropy(PrimitiveWithInfer):

Loading…
Cancel
Save