diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index d7784457d1..2df475c39d 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -624,8 +624,8 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { if (_hook_grad.find(cell_id) != _hook_grad.end()) { py::tuple hook_args = py::tuple(3); hook_args[0] = cell_id; - hook_args[1] = _hook_grad[cell_id]; - hook_args[2] = py_args[2]; + hook_args[1] = py::make_tuple(_hook_grad[cell_id]); + hook_args[2] = py::make_tuple(py_args[2]); py::function fn_hook = prim->hook(); obj = fn_hook(*hook_args); if (py::isinstance(obj)) { @@ -638,7 +638,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { } } else { py::function fn_hook = prim->hook(); - obj = fn_hook(py_args[2]); + obj = fn_hook(py::make_tuple(py_args[2])); if (py::isinstance(obj)) { obj = py_args[2]; } diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py index 63d7128762..f25fe97cb2 100644 --- a/tests/ut/python/pynative_mode/test_hook.py +++ b/tests/ut/python/pynative_mode/test_hook.py @@ -30,13 +30,13 @@ def weight_variable(): def cell_hook_function(cell_id, grad_input, grad_output): print(cell_id) - assert(grad_output.asnumpy().shape == (32, 6, 14, 14)) - assert(grad_input.asnumpy().shape == (32, 16, 10, 10)) + assert(grad_output[0].asnumpy().shape == (32, 6, 14, 14)) + assert(grad_input[0].asnumpy().shape == (32, 16, 10, 10)) def var_hook_function(grad_out): print("grad:", grad_out) - assert(grad_out.asnumpy().shape == (32, 120)) + assert(grad_out[0].asnumpy().shape == (32, 120)) class LeNet5(nn.Cell):