|
|
@ -1086,7 +1086,9 @@ void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "The var[%s] should not be nullptr.",
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "The var[%s] should not be nullptr.",
|
|
|
|
var_name);
|
|
|
|
var_name);
|
|
|
|
auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
|
|
|
|
auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
|
|
|
|
|
|
|
|
auto original_dims = original_tensor->dims();
|
|
|
|
original_tensor->ShareDataWith(*transformed_tensor);
|
|
|
|
original_tensor->ShareDataWith(*transformed_tensor);
|
|
|
|
|
|
|
|
original_tensor->Resize(original_dims);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|