|
|
|
@ -674,12 +674,6 @@ test_case_nn_ops = [
|
|
|
|
|
'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
|
|
|
|
|
'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
|
|
|
|
|
'skip': ['backward']}),
|
|
|
|
|
('ApplyMomentum', {
|
|
|
|
|
'block': P.ApplyMomentum(),
|
|
|
|
|
'desc_inputs': [[128, 32, 32, 64], [128, 32, 32, 64],
|
|
|
|
|
[32, 32, 64], [32, 32, 64], [32, 32, 64]],
|
|
|
|
|
'desc_bprop': [[128, 32, 32, 64]],
|
|
|
|
|
'skip': ['backward']}),
|
|
|
|
|
('TopK', {
|
|
|
|
|
'block': P.TopK(),
|
|
|
|
|
'desc_const': [5],
|
|
|
|
@ -1113,12 +1107,6 @@ test_case_other_ops = [
|
|
|
|
|
'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),
|
|
|
|
|
Tensor(np.ones((2, 4), np.int32))),
|
|
|
|
|
'desc_bprop': [[2]]}),
|
|
|
|
|
('ScatterNdUpdate', {
|
|
|
|
|
'block': P.ScatterNdUpdate(),
|
|
|
|
|
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
|
|
|
|
|
Tensor(np.ones((2, 2), np.int32)),
|
|
|
|
|
Tensor(np.ones((2,), np.float32))),
|
|
|
|
|
'desc_bprop': [[2, 3]]}),
|
|
|
|
|
('ScatterNd', {
|
|
|
|
|
'block': P.ScatterNd(),
|
|
|
|
|
'desc_const': [(3, 3)],
|
|
|
|
@ -1178,7 +1166,7 @@ import mindspore.context as context
|
|
|
|
|
@non_graph_engine
|
|
|
|
|
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
|
|
|
|
def test_exec():
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
|
|
|
|
return test_exec_case
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1207,6 +1195,12 @@ raise_set = [
|
|
|
|
|
'block': (NetForFlatten0D(), {'exception': ValueError}),
|
|
|
|
|
'desc_inputs': [Tensor(np.array(0).astype(np.int32))],
|
|
|
|
|
'desc_bprop': [Tensor(np.array(0).astype(np.int32))]}),
|
|
|
|
|
('ScatterNdUpdate', {
|
|
|
|
|
'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
|
|
|
|
|
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
|
|
|
|
|
Tensor(np.ones((2, 2), np.int32)),
|
|
|
|
|
Tensor(np.ones((2,), np.float32))),
|
|
|
|
|
'desc_bprop': [[2, 3]]}),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|