!14528 add ssd, resnet,unet evaluation while training process
	
		
	
				
					
				
			From: @zhao_ting_v Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34pull/14528/MERGE
						commit
						c9c8d5fe44
					
				| @ -0,0 +1,90 @@ | ||||
| # Copyright 2021 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """Evaluation callback when training""" | ||||
| 
 | ||||
| import os | ||||
| import stat | ||||
| from mindspore import save_checkpoint | ||||
| from mindspore import log as logger | ||||
| from mindspore.train.callback import Callback | ||||
| 
 | ||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of selected epochs this callback runs `eval_function(eval_param_dict)`,
    remembers the best result seen so far and, optionally, saves the training network
    as the "best" checkpoint.

    Args:
        eval_function (function): evaluation function, called as `eval_function(eval_param_dict)`.
            Its return value is compared with `>=` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to, created when
            it does not exist yet, default is `./`.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
            (The misspelt parameter name is kept for backward compatibility.)
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok=True avoids the check-then-create race: when several training
        # processes start at once, another one may create the directory between
        # an isdir() check and makedirs(), which would raise FileExistsError.
        os.makedirs(ckpt_directory, exist_ok=True)
        # NOTE: the attribute name `bast_ckpt_path` (sic) is kept so existing callers still work.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """
        Make `file_name` writable and delete it.

        Failures are deliberately not propagated: an `OSError` or `ValueError` raised by
        `os.chmod` / `os.remove` is only logged as a warning.
        """
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """
        Callback when epoch end.

        Evaluation runs on `eval_start_epoch` and then every `interval` epochs after it.
        A result that is `>=` the current best replaces it; with `save_best_ckpt` enabled
        the previous best checkpoint file is removed and the training network is saved
        in its place.
        """
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        epochs_since_start = cur_epoch - self.eval_start_epoch
        if epochs_since_start < 0 or epochs_since_start % self.interval != 0:
            return

        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        if res < self.best_res:
            return

        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached in."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
| @ -0,0 +1,132 @@ | ||||
| # Copyright 2021 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """evaluation metric.""" | ||||
| 
 | ||||
| from mindspore.communication.management import GlobalComm | ||||
| from mindspore.ops import operations as P | ||||
| import mindspore.nn as nn | ||||
| import mindspore.common.dtype as mstype | ||||
| 
 | ||||
| 
 | ||||
class ClassifyCorrectCell(nn.Cell):
    r"""
    Cell that returns the correct count of the prediction in a classification network.

    The wrapped network is run on the data, the argmax of its output is compared with
    the label, the matches are summed, and the sum is all-reduced (SUM) over
    `GlobalComm.WORLD_COMM_GROUP`.

    Args:
        network (Cell): The network Cell.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar correct count of the prediction

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> eval_net = ClassifyCorrectCell(net)
    """

    def __init__(self, network):
        super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
        self._network = network
        # Operators used by construct().
        self.argmax = P.Argmax()
        self.equal = P.Equal()
        self.cast = P.Cast()
        self.reduce_sum = P.ReduceSum()
        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)

    def construct(self, data, label):
        """Run the network on `data` and return a 1-tuple with the all-reduced correct count."""
        logits = self._network(data)
        # The predicted index is cast to int32 before it is compared with the label.
        predicted = self.cast(self.argmax(logits), mstype.int32)
        hits = self.cast(self.equal(predicted, label), mstype.float32)
        local_correct = self.reduce_sum(hits)
        return (self.allreduce(local_correct),)
| 
 | ||||
| 
 | ||||
class DistAccuracy(nn.Metric):
    r"""
    Calculates the accuracy for classification data in distributed mode.

    The metric keeps two running values, the correct number and the total number.
    Every `update` call adds the given correct count to the former and
    `batch_size * device_num` to the latter; `eval` divides one by the other:

    .. math::

        \text{accuracy} = \frac{\text{correct number}}{\text{total number}}

    Args:
        batch_size (int): number of samples counted per device for each `update` call.
        device_num (int): number of devices; the total grows by
            `batch_size * device_num` on every `update` call.

    Examples:
        >>> y_correct = Tensor(np.array([20]))
        >>> metric = DistAccuracy(batch_size=3, device_num=8)
        >>> metric.clear()
        >>> metric.update(y_correct)
        >>> accuracy = metric.eval()
    """

    def __init__(self, batch_size, device_num):
        super(DistAccuracy, self).__init__()
        self.clear()
        self.batch_size = batch_size
        self.device_num = device_num

    def clear(self):
        """Clears the internal evaluation result."""
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        """
        Updates the internal evaluation result with one correct count.

        Args:
            inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`.
                `y_correct` is the right prediction count that gathered from all devices
                it's a scalar in float type

        Raises:
            ValueError: If the number of the input is not 1.
        """

        if len(inputs) != 1:
            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
        # `_convert_data` is not defined in this class; presumably inherited from
        # nn.Metric -- TODO confirm what type it returns.
        y_correct = self._convert_data(inputs[0])
        self._correct_num += y_correct
        # Each call is counted as one full batch on every device.
        self._total_num += self.batch_size * self.device_num

    def eval(self):
        """
        Computes the accuracy.

        Returns:
            The accumulated correct number divided by the accumulated total number.

        Raises:
            RuntimeError: If the sample size is 0.
        """

        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        return self._correct_num / self._total_num
| @ -0,0 +1,90 @@ | ||||
| # Copyright 2021 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """Evaluation callback when training""" | ||||
| 
 | ||||
| import os | ||||
| import stat | ||||
| from mindspore import save_checkpoint | ||||
| from mindspore import log as logger | ||||
| from mindspore.train.callback import Callback | ||||
| 
 | ||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of selected epochs this callback runs `eval_function(eval_param_dict)`,
    remembers the best result seen so far and, optionally, saves the training network
    as the "best" checkpoint.

    Args:
        eval_function (function): evaluation function, called as `eval_function(eval_param_dict)`.
            Its return value is compared with `>=` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to, created when
            it does not exist yet, default is `./`.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
            (The misspelt parameter name is kept for backward compatibility.)
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok=True avoids the check-then-create race: when several training
        # processes start at once, another one may create the directory between
        # an isdir() check and makedirs(), which would raise FileExistsError.
        os.makedirs(ckpt_directory, exist_ok=True)
        # NOTE: the attribute name `bast_ckpt_path` (sic) is kept so existing callers still work.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """
        Make `file_name` writable and delete it.

        Failures are deliberately not propagated: an `OSError` or `ValueError` raised by
        `os.chmod` / `os.remove` is only logged as a warning.
        """
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """
        Callback when epoch end.

        Evaluation runs on `eval_start_epoch` and then every `interval` epochs after it.
        A result that is `>=` the current best replaces it; with `save_best_ckpt` enabled
        the previous best checkpoint file is removed and the training network is saved
        in its place.
        """
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        epochs_since_start = cur_epoch - self.eval_start_epoch
        if epochs_since_start < 0 or epochs_since_start % self.interval != 0:
            return

        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        if res < self.best_res:
            return

        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached in."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
| @ -0,0 +1,90 @@ | ||||
| # Copyright 2021 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """Evaluation callback when training""" | ||||
| 
 | ||||
| import os | ||||
| import stat | ||||
| from mindspore import save_checkpoint | ||||
| from mindspore import log as logger | ||||
| from mindspore.train.callback import Callback | ||||
| 
 | ||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of selected epochs this callback runs `eval_function(eval_param_dict)`,
    remembers the best result seen so far and, optionally, saves the training network
    as the "best" checkpoint.

    Args:
        eval_function (function): evaluation function, called as `eval_function(eval_param_dict)`.
            Its return value is compared with `>=` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to, created when
            it does not exist yet, default is `./`.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
            (The misspelt parameter name is kept for backward compatibility.)
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok=True avoids the check-then-create race: when several training
        # processes start at once, another one may create the directory between
        # an isdir() check and makedirs(), which would raise FileExistsError.
        os.makedirs(ckpt_directory, exist_ok=True)
        # NOTE: the attribute name `bast_ckpt_path` (sic) is kept so existing callers still work.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """
        Make `file_name` writable and delete it.

        Failures are deliberately not propagated: an `OSError` or `ValueError` raised by
        `os.chmod` / `os.remove` is only logged as a warning.
        """
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """
        Callback when epoch end.

        Evaluation runs on `eval_start_epoch` and then every `interval` epochs after it.
        A result that is `>=` the current best replaces it; with `save_best_ckpt` enabled
        the previous best checkpoint file is removed and the training network is saved
        in its place.
        """
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        epochs_since_start = cur_epoch - self.eval_start_epoch
        if epochs_since_start < 0 or epochs_since_start % self.interval != 0:
            return

        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        if res < self.best_res:
            return

        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached in."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
					Loading…
					
					
				
		Reference in new issue