parent
							
								
									7ffcf05809
								
							
						
					
					
						commit
						c000d0c915
					
				@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					# Copyright 2021 Huawei Technologies Co., Ltd
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ============================================================================
 | 
				
			||||||
 | 
					"""Evaluation callback when training"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import stat
 | 
				
			||||||
 | 
					from mindspore import save_checkpoint
 | 
				
			||||||
 | 
					from mindspore import log as logger
 | 
				
			||||||
 | 
					from mindspore.train.callback import Callback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EvalCallBack(Callback):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Evaluation callback when training.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        eval_function (function): evaluation function.
 | 
				
			||||||
 | 
					        eval_param_dict (dict): evaluation parameters' configure dict.
 | 
				
			||||||
 | 
					        interval (int): run evaluation interval, default is 1.
 | 
				
			||||||
 | 
					        eval_start_epoch (int): evaluation start epoch, default is 1.
 | 
				
			||||||
 | 
					        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
 | 
				
			||||||
 | 
					        besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
 | 
				
			||||||
 | 
					        metrics_name (str): evaluation metrics name, default is `acc`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Examples:
 | 
				
			||||||
 | 
					        >>> EvalCallBack(eval_function, eval_param_dict)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
 | 
				
			||||||
 | 
					                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
 | 
				
			||||||
 | 
					        super(EvalCallBack, self).__init__()
 | 
				
			||||||
 | 
					        self.eval_param_dict = eval_param_dict
 | 
				
			||||||
 | 
					        self.eval_function = eval_function
 | 
				
			||||||
 | 
					        self.eval_start_epoch = eval_start_epoch
 | 
				
			||||||
 | 
					        if interval < 1:
 | 
				
			||||||
 | 
					            raise ValueError("interval should >= 1.")
 | 
				
			||||||
 | 
					        self.interval = interval
 | 
				
			||||||
 | 
					        self.save_best_ckpt = save_best_ckpt
 | 
				
			||||||
 | 
					        self.best_res = 0
 | 
				
			||||||
 | 
					        self.best_epoch = 0
 | 
				
			||||||
 | 
					        if not os.path.isdir(ckpt_directory):
 | 
				
			||||||
 | 
					            os.makedirs(ckpt_directory)
 | 
				
			||||||
 | 
					        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
 | 
				
			||||||
 | 
					        self.metrics_name = metrics_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_ckpoint_file(self, file_name):
 | 
				
			||||||
 | 
					        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            os.chmod(file_name, stat.S_IWRITE)
 | 
				
			||||||
 | 
					            os.remove(file_name)
 | 
				
			||||||
 | 
					        except OSError:
 | 
				
			||||||
 | 
					            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def epoch_end(self, run_context):
 | 
				
			||||||
 | 
					        """Callback when epoch end."""
 | 
				
			||||||
 | 
					        cb_params = run_context.original_args()
 | 
				
			||||||
 | 
					        cur_epoch = cb_params.cur_epoch_num
 | 
				
			||||||
 | 
					        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
 | 
				
			||||||
 | 
					            res = self.eval_function(self.eval_param_dict)
 | 
				
			||||||
 | 
					            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
 | 
				
			||||||
 | 
					            if res >= self.best_res:
 | 
				
			||||||
 | 
					                self.best_res = res
 | 
				
			||||||
 | 
					                self.best_epoch = cur_epoch
 | 
				
			||||||
 | 
					                print("update best result: {}".format(res), flush=True)
 | 
				
			||||||
 | 
					                if self.save_best_ckpt:
 | 
				
			||||||
 | 
					                    if os.path.exists(self.bast_ckpt_path):
 | 
				
			||||||
 | 
					                        self.remove_ckpoint_file(self.bast_ckpt_path)
 | 
				
			||||||
 | 
					                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
 | 
				
			||||||
 | 
					                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def end(self, run_context):
 | 
				
			||||||
 | 
					        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
 | 
				
			||||||
 | 
					                                                                                     self.best_res,
 | 
				
			||||||
 | 
					                                                                                     self.best_epoch), flush=True)
 | 
				
			||||||
@ -0,0 +1,132 @@
 | 
				
			|||||||
 | 
					# Copyright 2021 Huawei Technologies Co., Ltd
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ============================================================================
 | 
				
			||||||
 | 
					"""evaluation metric."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mindspore.communication.management import GlobalComm
 | 
				
			||||||
 | 
					from mindspore.ops import operations as P
 | 
				
			||||||
 | 
					import mindspore.nn as nn
 | 
				
			||||||
 | 
					import mindspore.common.dtype as mstype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ClassifyCorrectCell(nn.Cell):
 | 
				
			||||||
 | 
					    r"""
 | 
				
			||||||
 | 
					    Cell that returns correct count of the prediction in classification network.
 | 
				
			||||||
 | 
					    This Cell accepts a network as arguments.
 | 
				
			||||||
 | 
					    It returns orrect count of the prediction to calculate the metrics.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        network (Cell): The network Cell.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Inputs:
 | 
				
			||||||
 | 
					        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
 | 
				
			||||||
 | 
					        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Outputs:
 | 
				
			||||||
 | 
					        Tuple, containing a scalar correct count of the prediction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Examples:
 | 
				
			||||||
 | 
					        >>> # For a defined network Net without loss function
 | 
				
			||||||
 | 
					        >>> net = Net()
 | 
				
			||||||
 | 
					        >>> eval_net = nn.ClassifyCorrectCell(net)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, network):
 | 
				
			||||||
 | 
					        super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
 | 
				
			||||||
 | 
					        self._network = network
 | 
				
			||||||
 | 
					        self.argmax = P.Argmax()
 | 
				
			||||||
 | 
					        self.equal = P.Equal()
 | 
				
			||||||
 | 
					        self.cast = P.Cast()
 | 
				
			||||||
 | 
					        self.reduce_sum = P.ReduceSum()
 | 
				
			||||||
 | 
					        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def construct(self, data, label):
 | 
				
			||||||
 | 
					        outputs = self._network(data)
 | 
				
			||||||
 | 
					        y_pred = self.argmax(outputs)
 | 
				
			||||||
 | 
					        y_pred = self.cast(y_pred, mstype.int32)
 | 
				
			||||||
 | 
					        y_correct = self.equal(y_pred, label)
 | 
				
			||||||
 | 
					        y_correct = self.cast(y_correct, mstype.float32)
 | 
				
			||||||
 | 
					        y_correct = self.reduce_sum(y_correct)
 | 
				
			||||||
 | 
					        total_correct = self.allreduce(y_correct)
 | 
				
			||||||
 | 
					        return (total_correct,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DistAccuracy(nn.Metric):
 | 
				
			||||||
 | 
					    r"""
 | 
				
			||||||
 | 
					    Calculates the accuracy for classification data in distributed mode.
 | 
				
			||||||
 | 
					    The accuracy class creates two local variables, correct number and total number that are used to compute the
 | 
				
			||||||
 | 
					    frequency with which predictions matches labels. This frequency is ultimately returned as the accuracy: an
 | 
				
			||||||
 | 
					    idempotent operation that simply divides correct number by total number.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    .. math::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        \text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        eval_type (str): Metric to calculate the accuracy over a dataset, for classification (single-label).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Examples:
 | 
				
			||||||
 | 
					        >>> y_correct = Tensor(np.array([20]))
 | 
				
			||||||
 | 
					        >>> metric = nn.DistAccuracy(batch_size=3, device_num=8)
 | 
				
			||||||
 | 
					        >>> metric.clear()
 | 
				
			||||||
 | 
					        >>> metric.update(y_correct)
 | 
				
			||||||
 | 
					        >>> accuracy = metric.eval()
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, batch_size, device_num):
 | 
				
			||||||
 | 
					        super(DistAccuracy, self).__init__()
 | 
				
			||||||
 | 
					        self.clear()
 | 
				
			||||||
 | 
					        self.batch_size = batch_size
 | 
				
			||||||
 | 
					        self.device_num = device_num
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def clear(self):
 | 
				
			||||||
 | 
					        """Clears the internal evaluation result."""
 | 
				
			||||||
 | 
					        self._correct_num = 0
 | 
				
			||||||
 | 
					        self._total_num = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(self, *inputs):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`.
 | 
				
			||||||
 | 
					                `y_correct` is the right prediction count that gathered from all devices
 | 
				
			||||||
 | 
					                it's a scalar in float type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Raises:
 | 
				
			||||||
 | 
					            ValueError: If the number of the input is not 1.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(inputs) != 1:
 | 
				
			||||||
 | 
					            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
 | 
				
			||||||
 | 
					        y_correct = self._convert_data(inputs[0])
 | 
				
			||||||
 | 
					        self._correct_num += y_correct
 | 
				
			||||||
 | 
					        self._total_num += self.batch_size * self.device_num
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Computes the accuracy.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            Float, the computed result.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Raises:
 | 
				
			||||||
 | 
					            RuntimeError: If the sample size is 0.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self._total_num == 0:
 | 
				
			||||||
 | 
					            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
 | 
				
			||||||
 | 
					        return self._correct_num / self._total_num
 | 
				
			||||||
@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					# Copyright 2021 Huawei Technologies Co., Ltd
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ============================================================================
 | 
				
			||||||
 | 
					"""Evaluation callback when training"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import stat
 | 
				
			||||||
 | 
					from mindspore import save_checkpoint
 | 
				
			||||||
 | 
					from mindspore import log as logger
 | 
				
			||||||
 | 
					from mindspore.train.callback import Callback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EvalCallBack(Callback):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Evaluation callback when training.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        eval_function (function): evaluation function.
 | 
				
			||||||
 | 
					        eval_param_dict (dict): evaluation parameters' configure dict.
 | 
				
			||||||
 | 
					        interval (int): run evaluation interval, default is 1.
 | 
				
			||||||
 | 
					        eval_start_epoch (int): evaluation start epoch, default is 1.
 | 
				
			||||||
 | 
					        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
 | 
				
			||||||
 | 
					        besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
 | 
				
			||||||
 | 
					        metrics_name (str): evaluation metrics name, default is `acc`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Examples:
 | 
				
			||||||
 | 
					        >>> EvalCallBack(eval_function, eval_param_dict)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
 | 
				
			||||||
 | 
					                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
 | 
				
			||||||
 | 
					        super(EvalCallBack, self).__init__()
 | 
				
			||||||
 | 
					        self.eval_param_dict = eval_param_dict
 | 
				
			||||||
 | 
					        self.eval_function = eval_function
 | 
				
			||||||
 | 
					        self.eval_start_epoch = eval_start_epoch
 | 
				
			||||||
 | 
					        if interval < 1:
 | 
				
			||||||
 | 
					            raise ValueError("interval should >= 1.")
 | 
				
			||||||
 | 
					        self.interval = interval
 | 
				
			||||||
 | 
					        self.save_best_ckpt = save_best_ckpt
 | 
				
			||||||
 | 
					        self.best_res = 0
 | 
				
			||||||
 | 
					        self.best_epoch = 0
 | 
				
			||||||
 | 
					        if not os.path.isdir(ckpt_directory):
 | 
				
			||||||
 | 
					            os.makedirs(ckpt_directory)
 | 
				
			||||||
 | 
					        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
 | 
				
			||||||
 | 
					        self.metrics_name = metrics_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_ckpoint_file(self, file_name):
 | 
				
			||||||
 | 
					        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            os.chmod(file_name, stat.S_IWRITE)
 | 
				
			||||||
 | 
					            os.remove(file_name)
 | 
				
			||||||
 | 
					        except OSError:
 | 
				
			||||||
 | 
					            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def epoch_end(self, run_context):
 | 
				
			||||||
 | 
					        """Callback when epoch end."""
 | 
				
			||||||
 | 
					        cb_params = run_context.original_args()
 | 
				
			||||||
 | 
					        cur_epoch = cb_params.cur_epoch_num
 | 
				
			||||||
 | 
					        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
 | 
				
			||||||
 | 
					            res = self.eval_function(self.eval_param_dict)
 | 
				
			||||||
 | 
					            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
 | 
				
			||||||
 | 
					            if res >= self.best_res:
 | 
				
			||||||
 | 
					                self.best_res = res
 | 
				
			||||||
 | 
					                self.best_epoch = cur_epoch
 | 
				
			||||||
 | 
					                print("update best result: {}".format(res), flush=True)
 | 
				
			||||||
 | 
					                if self.save_best_ckpt:
 | 
				
			||||||
 | 
					                    if os.path.exists(self.bast_ckpt_path):
 | 
				
			||||||
 | 
					                        self.remove_ckpoint_file(self.bast_ckpt_path)
 | 
				
			||||||
 | 
					                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
 | 
				
			||||||
 | 
					                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def end(self, run_context):
 | 
				
			||||||
 | 
					        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
 | 
				
			||||||
 | 
					                                                                                     self.best_res,
 | 
				
			||||||
 | 
					                                                                                     self.best_epoch), flush=True)
 | 
				
			||||||
@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					# Copyright 2021 Huawei Technologies Co., Ltd
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ============================================================================
 | 
				
			||||||
 | 
					"""Evaluation callback when training"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import stat
 | 
				
			||||||
 | 
					from mindspore import save_checkpoint
 | 
				
			||||||
 | 
					from mindspore import log as logger
 | 
				
			||||||
 | 
					from mindspore.train.callback import Callback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EvalCallBack(Callback):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Evaluation callback when training.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        eval_function (function): evaluation function.
 | 
				
			||||||
 | 
					        eval_param_dict (dict): evaluation parameters' configure dict.
 | 
				
			||||||
 | 
					        interval (int): run evaluation interval, default is 1.
 | 
				
			||||||
 | 
					        eval_start_epoch (int): evaluation start epoch, default is 1.
 | 
				
			||||||
 | 
					        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
 | 
				
			||||||
 | 
					        besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
 | 
				
			||||||
 | 
					        metrics_name (str): evaluation metrics name, default is `acc`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Examples:
 | 
				
			||||||
 | 
					        >>> EvalCallBack(eval_function, eval_param_dict)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
 | 
				
			||||||
 | 
					                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
 | 
				
			||||||
 | 
					        super(EvalCallBack, self).__init__()
 | 
				
			||||||
 | 
					        self.eval_param_dict = eval_param_dict
 | 
				
			||||||
 | 
					        self.eval_function = eval_function
 | 
				
			||||||
 | 
					        self.eval_start_epoch = eval_start_epoch
 | 
				
			||||||
 | 
					        if interval < 1:
 | 
				
			||||||
 | 
					            raise ValueError("interval should >= 1.")
 | 
				
			||||||
 | 
					        self.interval = interval
 | 
				
			||||||
 | 
					        self.save_best_ckpt = save_best_ckpt
 | 
				
			||||||
 | 
					        self.best_res = 0
 | 
				
			||||||
 | 
					        self.best_epoch = 0
 | 
				
			||||||
 | 
					        if not os.path.isdir(ckpt_directory):
 | 
				
			||||||
 | 
					            os.makedirs(ckpt_directory)
 | 
				
			||||||
 | 
					        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
 | 
				
			||||||
 | 
					        self.metrics_name = metrics_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_ckpoint_file(self, file_name):
 | 
				
			||||||
 | 
					        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            os.chmod(file_name, stat.S_IWRITE)
 | 
				
			||||||
 | 
					            os.remove(file_name)
 | 
				
			||||||
 | 
					        except OSError:
 | 
				
			||||||
 | 
					            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def epoch_end(self, run_context):
 | 
				
			||||||
 | 
					        """Callback when epoch end."""
 | 
				
			||||||
 | 
					        cb_params = run_context.original_args()
 | 
				
			||||||
 | 
					        cur_epoch = cb_params.cur_epoch_num
 | 
				
			||||||
 | 
					        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
 | 
				
			||||||
 | 
					            res = self.eval_function(self.eval_param_dict)
 | 
				
			||||||
 | 
					            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
 | 
				
			||||||
 | 
					            if res >= self.best_res:
 | 
				
			||||||
 | 
					                self.best_res = res
 | 
				
			||||||
 | 
					                self.best_epoch = cur_epoch
 | 
				
			||||||
 | 
					                print("update best result: {}".format(res), flush=True)
 | 
				
			||||||
 | 
					                if self.save_best_ckpt:
 | 
				
			||||||
 | 
					                    if os.path.exists(self.bast_ckpt_path):
 | 
				
			||||||
 | 
					                        self.remove_ckpoint_file(self.bast_ckpt_path)
 | 
				
			||||||
 | 
					                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
 | 
				
			||||||
 | 
					                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def end(self, run_context):
 | 
				
			||||||
 | 
					        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
 | 
				
			||||||
 | 
					                                                                                     self.best_res,
 | 
				
			||||||
 | 
					                                                                                     self.best_epoch), flush=True)
 | 
				
			||||||
					Loading…
					
					
				
		Reference in new issue