callback module is encapsulated

pull/2236/head
Li Hongzhang 5 years ago
parent fc74606211
commit ecc459158e

@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 from model.dataset_helper import DatasetHelper
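The hunk above switches the call site from the private submodule path to the package-level re-export. A minimal sketch of the new import and the internal objects it provides (construction details are assumptions based on the code shown further down):

from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager

cb_params = _InternalCallbackParam()  # dict-backed bag of run state, attribute access assumed
run_context = RunContext(cb_params)   # wrapper handed to each callback hook
manager = _CallbackManager([])        # sequences a list of callbacks (assumed signature)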

@@ -26,9 +26,9 @@
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;

@@ -25,9 +25,9 @@
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
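These constants are how the C++ runtime locates the Python callback entry points; the rename moves the leading underscore from the function names onto the module name, so the functions become public names inside a private module. A sketch of the Python symbols the constants now resolve to (import path taken from the test changes below):

from mindspore.train.callback._callback import checkpoint_cb_for_save_op, summary_cb_for_save_op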

@@ -14,7 +14,15 @@
 # ============================================================================
 """Callback related classes and functions."""
-from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
+from ._callback import Callback
+from ._callback import CallbackManager as _CallbackManager
+from ._callback import InternalCallbackParam as _InternalCallbackParam
+from ._callback import RunContext
+from ._checkpoint import CheckpointConfig
+from ._checkpoint import CheckpointManager as _CheckpointManager
+from ._checkpoint import ModelCheckpoint
+from ._loss_monitor import LossMonitor
+from ._summary_step import SummaryStep
+from ._time_monitor import TimeMonitor
 
-__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
-           "SummaryStep", "CheckpointConfig", "RunContext"]
+__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]

File diff suppressed because it is too large

@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LossMonitor Callback class."""
+
+import numpy as np
+
+from mindspore.common.tensor import Tensor
+
+from ._callback import Callback
+
+
+class LossMonitor(Callback):
+    """
+    Monitor the loss in training.
+
+    If the loss is NaN or Inf, it terminates training.
+
+    Note:
+        If per_print_times is 0, the loss is not printed.
+
+    Args:
+        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
+
+    Raises:
+        ValueError: If per_print_times is not an int or is less than zero.
+    """
+
+    def __init__(self, per_print_times=1):
+        super(LossMonitor, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be int and >= 0.")
+        self._per_print_times = per_print_times
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)

@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SummaryStep Callback class."""
+
+from ._callback import Callback
+
+
+class SummaryStep(Callback):
+    """
+    The summary callback class.
+
+    Args:
+        summary (Object): Summary record object.
+        flush_step (int): Number of interval steps to execute. Default: 10.
+    """
+
+    def __init__(self, summary, flush_step=10):
+        super(SummaryStep, self).__init__()
+        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
+            raise ValueError("`flush_step` should be int and greater than 0")
+        self._summary = summary
+        self._flush_step = flush_step
+
+    def __enter__(self):
+        self._summary.__enter__()
+        return self
+
+    def __exit__(self, *err):
+        return self._summary.__exit__(*err)
+
+    def step_end(self, run_context):
+        """
+        Save summary.
+
+        Args:
+            run_context (RunContext): Context of the train running.
+        """
+        cb_params = run_context.original_args()
+        if cb_params.cur_step_num % self._flush_step == 0:
+            self._summary.record(cb_params.cur_step_num, cb_params.train_network)
+
+    @property
+    def summary_file_name(self):
+        return self._summary.full_file_name
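SummaryStep delegates the context-manager protocol and the periodic record() call to whatever summary object it wraps, so any object with that shape works. A hedged sketch with a stand-in recorder (SummaryRecord would be the real one):

from mindspore.train.callback import SummaryStep, RunContext, _InternalCallbackParam

class FakeRecord:
    full_file_name = "./summary.log"
    def __enter__(self): return self
    def __exit__(self, *err): return False
    def record(self, step, train_network=None):
        print("flushed summary at step", step)

cb_params = _InternalCallbackParam()
cb_params.cur_step_num = 10
cb_params.train_network = None
with SummaryStep(FakeRecord(), flush_step=10) as cb:
    cb.step_end(RunContext(cb_params))   # 10 % 10 == 0, so record() fires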

@@ -0,0 +1,35 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""TimeMonitor Callback class."""
+
+import time
+
+from ._callback import Callback
+
+
+class TimeMonitor(Callback):
+    """Time Monitor."""
+
+    def __init__(self, data_size):
+        super(TimeMonitor, self).__init__()
+        self.data_size = data_size
+
+    def epoch_begin(self, run_context):
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / self.data_size
+        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)

@@ -19,7 +19,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
-from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check

@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 from .dataset_helper import DatasetHelper

@@ -26,10 +26,10 @@ from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
-from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
-    _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
-    _CallbackManager, Callback, CheckpointConfig, _set_cur_net
+from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
+    _CallbackManager, Callback, CheckpointConfig
+from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
+from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
 
 
 class Net(nn.Cell):
     """Net definition."""
@@ -187,7 +187,7 @@ def test_checkpoint_cb_for_save_op():
     one_param['name'] = "conv1.weight"
     one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
     parameter_list.append(one_param)
-    _checkpoint_cb_for_save_op(parameter_list)
+    checkpoint_cb_for_save_op(parameter_list)
 
 
 def test_checkpoint_cb_for_save_op_update_net():
@@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net():
     one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
     parameter_list.append(one_param)
     net = Net()
-    _set_cur_net(net)
-    _checkpoint_cb_for_save_op(parameter_list)
+    set_cur_net(net)
+    checkpoint_cb_for_save_op(parameter_list)
     assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1

@@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.train.callback.callback import _CheckpointManager
+from mindspore.train.callback import _CheckpointManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
     _exec_save_checkpoint, export, _save_graph
 from ..ut_filter import non_graph_engine
