!12870 add mixed precision support for cache

From: @fangzehua
Reviewed-by: 
Signed-off-by:
pull/12870/MERGE
mindspore-ci-bot committed by Gitee
commit 659b912f6d

@@ -446,6 +446,9 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
   for (size_t i = 0; i < cnodes_size; ++i) {
     if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
       auto load_node = cnodes[i]->input(1);
+      if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
+        load_node = load_node->cast<CNodePtr>()->input(1);
+      }
       if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
         auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
         if (param_set.find(param_node) != param_set.end()) {
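With mixed precision, auto-cast can insert a Cast between the Load of a cached embedding parameter and its SparseGatherV2 consumer; the added check lets the pass look through that Cast so the parameter is still recognized as cache-enabled. Below is a minimal Python sketch of the same look-through match; Node and its fields are hypothetical stand-ins for the real C++ ANF node API (where input(0) is the primitive; the sketch simply uses inputs[0] for the first operand):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass(eq=False)  # identity-based hashing, so nodes can live in sets
class Node:
    op: str                          # primitive name, e.g. "Cast", "Load"
    inputs: List["Node"] = field(default_factory=list)

def find_cached_param(gather: Node, param_set: set) -> Optional[Node]:
    """Mirror of the match above: skip an optional Cast, then expect Load."""
    node = gather.inputs[0]
    if node.op == "Cast":            # mixed precision may insert a Cast here
        node = node.inputs[0]        # look through to the Cast's operand
    if node.op == "Load":
        param = node.inputs[0]
        if param in param_set:
            return param
    return None

# Usage: the match succeeds with or without the intervening Cast.
param = Node("Parameter")
gather = Node("SparseGatherV2", [Node("Cast", [Node("Load", [param])])])
assert find_cached_param(gather, {param}) is param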

@@ -261,6 +261,7 @@ class ModelCheckpoint(Callback):
         self._manager = CheckpointManager()
         self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._graph_saved = False
+        self._need_flush_from_cache = True
 
     def step_end(self, run_context):
         """
@@ -326,7 +327,8 @@ class ModelCheckpoint(Callback):
             return
         # if param is cache enable, flush data from cache to host before save_ckpt
-        self._flush_from_cache(cb_params)
+        if self._need_flush_from_cache:
+            self._flush_from_cache(cb_params)
 
         save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
         step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
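The new guard makes the flush conditional: the parameter scan in _flush_from_cache (next hunk) runs only while the callback still assumes some parameter may be cache-enabled. Once that scan finds none, the flag is cleared and every later checkpoint save skips the walk over parameters entirely, as sketched after the next hunk.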
@@ -365,10 +367,14 @@ class ModelCheckpoint(Callback):
 
     def _flush_from_cache(self, cb_params):
         """Flush cache data to host if tensor is cache enable."""
+        has_cache_params = False
         params = cb_params.train_network.get_parameters()
         for param in params:
             if param.cache_enable:
+                has_cache_params = True
                 Tensor(param).flush_from_cache()
+        if not has_cache_params:
+            self._need_flush_from_cache = False
 
     @property
     def latest_ckpt_file_name(self):
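Taken together, the three Python hunks implement a one-time discovery: the first save scans all parameters, and if none is cache-enabled the flag is dropped so later saves pay nothing. A self-contained sketch of the same pattern; Param and its flush_from_cache method are hypothetical stand-ins for the MindSpore parameter API:

class Param:
    """Hypothetical stand-in for a MindSpore parameter."""
    def __init__(self, name, cache_enable=False):
        self.name = name
        self.cache_enable = cache_enable

    def flush_from_cache(self):
        print(f"flushing {self.name} from device cache to host")

class CheckpointFlusher:
    def __init__(self, params):
        self._params = params
        self._need_flush_from_cache = True   # assume caches until proven otherwise

    def save(self):
        if self._need_flush_from_cache:
            self._flush_from_cache()
        # ... write the checkpoint file here ...

    def _flush_from_cache(self):
        has_cache_params = False
        for param in self._params:
            if param.cache_enable:
                has_cache_params = True
                param.flush_from_cache()
        if not has_cache_params:
            # no cached parameters exist: skip this scan on all later saves
            self._need_flush_from_cache = False

# Usage: with no cache-enabled params, only the first save() scans.
flusher = CheckpointFlusher([Param("embedding"), Param("dense")])
flusher.save()   # scans, finds nothing, clears the flag
flusher.save()   # skips the scan entirely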
