|
|
|
@ -896,6 +896,16 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj
|
|
|
|
|
// Update input for cast struct
|
|
|
|
|
auto cast_struct = cast_struct_pair->second;
|
|
|
|
|
cast_struct->op_inputs[0] = obj;
|
|
|
|
|
auto grad = this->grad();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(grad);
|
|
|
|
|
if (grad->grad_flag()) {
|
|
|
|
|
// Get forward op index
|
|
|
|
|
if (!grad->cell_op_info_stack().empty()) {
|
|
|
|
|
std::string &cell_op_info = grad->cell_op_info_stack().top();
|
|
|
|
|
cell_op_info += cast_struct->op_index;
|
|
|
|
|
}
|
|
|
|
|
grad->op_index_map()[cast_struct->op_name]++;
|
|
|
|
|
}
|
|
|
|
|
py::object ret = py::none();
|
|
|
|
|
RunOpInner(&ret, cast_struct);
|
|
|
|
|
return ret;
|
|
|
|
|