From 99ebe71b39e0ee035811f09f2eefa479df941eeb Mon Sep 17 00:00:00 2001 From: chenyijie6 Date: Thu, 18 Mar 2021 20:25:50 +0800 Subject: [PATCH] Fix Pynative AMP Backprop Bugs --- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 404fcf2a76..28e0fefcc8 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -896,6 +896,16 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj // Update input for cast struct auto cast_struct = cast_struct_pair->second; cast_struct->op_inputs[0] = obj; + auto grad = this->grad(); + MS_EXCEPTION_IF_NULL(grad); + if (grad->grad_flag()) { + // Get forward op index + if (!grad->cell_op_info_stack().empty()) { + std::string &cell_op_info = grad->cell_op_info_stack().top(); + cell_op_info += cast_struct->op_index; + } + grad->op_index_map()[cast_struct->op_name]++; + } py::object ret = py::none(); RunOpInner(&ret, cast_struct); return ret;