!14528 add ssd, resnet,unet evaluation while training process
	
		
	
				
					
				
			From: @zhao_ting_v Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34pull/14528/MERGE
						commit
						c9c8d5fe44
					
				| @ -0,0 +1,90 @@ | |||||||
|  | # Copyright 2021 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """Evaluation callback when training""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import stat | ||||||
|  | from mindspore import save_checkpoint | ||||||
|  | from mindspore import log as logger | ||||||
|  | from mindspore.train.callback import Callback | ||||||
|  | 
 | ||||||
|  | class EvalCallBack(Callback): | ||||||
|  |     """ | ||||||
|  |     Evaluation callback when training. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         eval_function (function): evaluation function. | ||||||
|  |         eval_param_dict (dict): evaluation parameters' configure dict. | ||||||
|  |         interval (int): run evaluation interval, default is 1. | ||||||
|  |         eval_start_epoch (int): evaluation start epoch, default is 1. | ||||||
|  |         save_best_ckpt (bool): Whether to save best checkpoint, default is True. | ||||||
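|  |         ckpt_directory (str): directory the best checkpoint is saved in, created if missing, default is `./`. | ||||||
|  |         besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`. | ||||||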
|  |         metrics_name (str): evaluation metrics name, default is `acc`. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         None | ||||||
|  | 
 | ||||||
|  |     Examples: | ||||||
|  |         >>> EvalCallBack(eval_function, eval_param_dict) | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, | ||||||
|  |                  ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): | ||||||
|  |         super(EvalCallBack, self).__init__() | ||||||
|  |         self.eval_param_dict = eval_param_dict | ||||||
|  |         self.eval_function = eval_function | ||||||
|  |         self.eval_start_epoch = eval_start_epoch | ||||||
|  |         if interval < 1: | ||||||
|  |             raise ValueError("interval should >= 1.") | ||||||
|  |         self.interval = interval | ||||||
|  |         self.save_best_ckpt = save_best_ckpt | ||||||
|  |         self.best_res = 0 | ||||||
|  |         self.best_epoch = 0 | ||||||
|  |         if not os.path.isdir(ckpt_directory): | ||||||
|  |             os.makedirs(ckpt_directory) | ||||||
|  |         self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) | ||||||
|  |         self.metrics_name = metrics_name | ||||||
|  | 
 | ||||||
|  |     def remove_ckpoint_file(self, file_name): | ||||||
|  |         """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" | ||||||
|  |         try: | ||||||
|  |             os.chmod(file_name, stat.S_IWRITE) | ||||||
|  |             os.remove(file_name) | ||||||
|  |         except OSError: | ||||||
|  |             logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) | ||||||
|  |         except ValueError: | ||||||
|  |             logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) | ||||||
|  | 
 | ||||||
|  |     def epoch_end(self, run_context): | ||||||
|  |         """Callback when epoch end.""" | ||||||
|  |         cb_params = run_context.original_args() | ||||||
|  |         cur_epoch = cb_params.cur_epoch_num | ||||||
|  |         if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: | ||||||
|  |             res = self.eval_function(self.eval_param_dict) | ||||||
|  |             print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) | ||||||
|  |             if res >= self.best_res: | ||||||
|  |                 self.best_res = res | ||||||
|  |                 self.best_epoch = cur_epoch | ||||||
|  |                 print("update best result: {}".format(res), flush=True) | ||||||
|  |                 if self.save_best_ckpt: | ||||||
|  |                     if os.path.exists(self.bast_ckpt_path): | ||||||
|  |                         self.remove_ckpoint_file(self.bast_ckpt_path) | ||||||
|  |                     save_checkpoint(cb_params.train_network, self.bast_ckpt_path) | ||||||
|  |                     print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) | ||||||
|  | 
 | ||||||
|  |     def end(self, run_context): | ||||||
|  |         print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, | ||||||
|  |                                                                                      self.best_res, | ||||||
|  |                                                                                      self.best_epoch), flush=True) | ||||||
| @ -0,0 +1,132 @@ | |||||||
|  | # Copyright 2021 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """evaluation metric.""" | ||||||
|  | 
 | ||||||
|  | from mindspore.communication.management import GlobalComm | ||||||
|  | from mindspore.ops import operations as P | ||||||
|  | import mindspore.nn as nn | ||||||
|  | import mindspore.common.dtype as mstype | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ClassifyCorrectCell(nn.Cell): | ||||||
|  |     r""" | ||||||
|  |     Cell that returns correct count of the prediction in classification network. | ||||||
|  |     This Cell accepts a network as arguments. | ||||||
|  |     It returns the correct count of the prediction to calculate the metrics. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         network (Cell): The network Cell. | ||||||
|  | 
 | ||||||
|  |     Inputs: | ||||||
|  |         - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | ||||||
|  |         - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | ||||||
|  | 
 | ||||||
|  |     Outputs: | ||||||
|  |         Tuple, containing a scalar correct count of the prediction | ||||||
|  | 
 | ||||||
|  |     Examples: | ||||||
|  |         >>> # For a defined network Net without loss function | ||||||
|  |         >>> net = Net() | ||||||
|  |         >>> eval_net = nn.ClassifyCorrectCell(net) | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, network): | ||||||
|  |         super(ClassifyCorrectCell, self).__init__(auto_prefix=False) | ||||||
|  |         self._network = network | ||||||
|  |         self.argmax = P.Argmax() | ||||||
|  |         self.equal = P.Equal() | ||||||
|  |         self.cast = P.Cast() | ||||||
|  |         self.reduce_sum = P.ReduceSum() | ||||||
|  |         self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) | ||||||
|  | 
 | ||||||
|  |     def construct(self, data, label): | ||||||
|  |         outputs = self._network(data) | ||||||
|  |         y_pred = self.argmax(outputs) | ||||||
|  |         y_pred = self.cast(y_pred, mstype.int32) | ||||||
|  |         y_correct = self.equal(y_pred, label) | ||||||
|  |         y_correct = self.cast(y_correct, mstype.float32) | ||||||
|  |         y_correct = self.reduce_sum(y_correct) | ||||||
|  |         total_correct = self.allreduce(y_correct) | ||||||
|  |         return (total_correct,) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DistAccuracy(nn.Metric): | ||||||
|  |     r""" | ||||||
|  |     Calculates the accuracy for classification data in distributed mode. | ||||||
|  |     The accuracy class creates two local variables, correct number and total number that are used to compute the | ||||||
|  |     frequency with which predictions match labels. This frequency is ultimately returned as the accuracy: an | ||||||
|  |     idempotent operation that simply divides correct number by total number. | ||||||
|  | 
 | ||||||
|  |     .. math:: | ||||||
|  | 
 | ||||||
|  |         \text{accuracy} =\frac{\text{true_positive} + \text{true_negative}} | ||||||
|  | 
 | ||||||
|  |         {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}} | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
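|  |         batch_size (int): Batch size; each `update` call adds `batch_size * device_num` to the sample total. | ||||||
|  |         device_num (int): Number of devices; multiplied with `batch_size` to get the samples per `update` call. | ||||||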
|  | 
 | ||||||
|  |     Examples: | ||||||
|  |         >>> y_correct = Tensor(np.array([20])) | ||||||
|  |         >>> metric = nn.DistAccuracy(batch_size=3, device_num=8) | ||||||
|  |         >>> metric.clear() | ||||||
|  |         >>> metric.update(y_correct) | ||||||
|  |         >>> accuracy = metric.eval() | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, batch_size, device_num): | ||||||
|  |         super(DistAccuracy, self).__init__() | ||||||
|  |         self.clear() | ||||||
|  |         self.batch_size = batch_size | ||||||
|  |         self.device_num = device_num | ||||||
|  | 
 | ||||||
|  |     def clear(self): | ||||||
|  |         """Clears the internal evaluation result.""" | ||||||
|  |         self._correct_num = 0 | ||||||
|  |         self._total_num = 0 | ||||||
|  | 
 | ||||||
|  |     def update(self, *inputs): | ||||||
|  |         """ | ||||||
|  |         Updates the internal evaluation result :math:`y_{pred}` and :math:`y`. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`. | ||||||
|  |                 `y_correct` is the count of correct predictions gathered from all devices; | ||||||
|  |                 it is a scalar of float type. | ||||||
|  | 
 | ||||||
|  |         Raises: | ||||||
|  |             ValueError: If the number of the input is not 1. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         if len(inputs) != 1: | ||||||
|  |             raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs))) | ||||||
|  |         y_correct = self._convert_data(inputs[0]) | ||||||
|  |         self._correct_num += y_correct | ||||||
|  |         self._total_num += self.batch_size * self.device_num | ||||||
|  | 
 | ||||||
|  |     def eval(self): | ||||||
|  |         """ | ||||||
|  |         Computes the accuracy. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Float, the computed result. | ||||||
|  | 
 | ||||||
|  |         Raises: | ||||||
|  |             RuntimeError: If the sample size is 0. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         if self._total_num == 0: | ||||||
|  |             raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.') | ||||||
|  |         return self._correct_num / self._total_num | ||||||
| @ -0,0 +1,90 @@ | |||||||
|  | # Copyright 2021 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """Evaluation callback when training""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import stat | ||||||
|  | from mindspore import save_checkpoint | ||||||
|  | from mindspore import log as logger | ||||||
|  | from mindspore.train.callback import Callback | ||||||
|  | 
 | ||||||
|  | class EvalCallBack(Callback): | ||||||
|  |     """ | ||||||
|  |     Evaluation callback when training. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         eval_function (function): evaluation function. | ||||||
|  |         eval_param_dict (dict): evaluation parameters' configure dict. | ||||||
|  |         interval (int): run evaluation interval, default is 1. | ||||||
|  |         eval_start_epoch (int): evaluation start epoch, default is 1. | ||||||
|  |         save_best_ckpt (bool): Whether to save best checkpoint, default is True. | ||||||
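|  |         ckpt_directory (str): directory the best checkpoint is saved in, created if missing, default is `./`. | ||||||
|  |         besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`. | ||||||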
|  |         metrics_name (str): evaluation metrics name, default is `acc`. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         None | ||||||
|  | 
 | ||||||
|  |     Examples: | ||||||
|  |         >>> EvalCallBack(eval_function, eval_param_dict) | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, | ||||||
|  |                  ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): | ||||||
|  |         super(EvalCallBack, self).__init__() | ||||||
|  |         self.eval_param_dict = eval_param_dict | ||||||
|  |         self.eval_function = eval_function | ||||||
|  |         self.eval_start_epoch = eval_start_epoch | ||||||
|  |         if interval < 1: | ||||||
|  |             raise ValueError("interval should >= 1.") | ||||||
|  |         self.interval = interval | ||||||
|  |         self.save_best_ckpt = save_best_ckpt | ||||||
|  |         self.best_res = 0 | ||||||
|  |         self.best_epoch = 0 | ||||||
|  |         if not os.path.isdir(ckpt_directory): | ||||||
|  |             os.makedirs(ckpt_directory) | ||||||
|  |         self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) | ||||||
|  |         self.metrics_name = metrics_name | ||||||
|  | 
 | ||||||
|  |     def remove_ckpoint_file(self, file_name): | ||||||
|  |         """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" | ||||||
|  |         try: | ||||||
|  |             os.chmod(file_name, stat.S_IWRITE) | ||||||
|  |             os.remove(file_name) | ||||||
|  |         except OSError: | ||||||
|  |             logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) | ||||||
|  |         except ValueError: | ||||||
|  |             logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) | ||||||
|  | 
 | ||||||
|  |     def epoch_end(self, run_context): | ||||||
|  |         """Callback when epoch end.""" | ||||||
|  |         cb_params = run_context.original_args() | ||||||
|  |         cur_epoch = cb_params.cur_epoch_num | ||||||
|  |         if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: | ||||||
|  |             res = self.eval_function(self.eval_param_dict) | ||||||
|  |             print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) | ||||||
|  |             if res >= self.best_res: | ||||||
|  |                 self.best_res = res | ||||||
|  |                 self.best_epoch = cur_epoch | ||||||
|  |                 print("update best result: {}".format(res), flush=True) | ||||||
|  |                 if self.save_best_ckpt: | ||||||
|  |                     if os.path.exists(self.bast_ckpt_path): | ||||||
|  |                         self.remove_ckpoint_file(self.bast_ckpt_path) | ||||||
|  |                     save_checkpoint(cb_params.train_network, self.bast_ckpt_path) | ||||||
|  |                     print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) | ||||||
|  | 
 | ||||||
|  |     def end(self, run_context): | ||||||
|  |         print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, | ||||||
|  |                                                                                      self.best_res, | ||||||
|  |                                                                                      self.best_epoch), flush=True) | ||||||
| @ -0,0 +1,90 @@ | |||||||
|  | # Copyright 2021 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """Evaluation callback when training""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import stat | ||||||
|  | from mindspore import save_checkpoint | ||||||
|  | from mindspore import log as logger | ||||||
|  | from mindspore.train.callback import Callback | ||||||
|  | 
 | ||||||
|  | class EvalCallBack(Callback): | ||||||
|  |     """ | ||||||
|  |     Evaluation callback when training. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         eval_function (function): evaluation function. | ||||||
|  |         eval_param_dict (dict): evaluation parameters' configure dict. | ||||||
|  |         interval (int): run evaluation interval, default is 1. | ||||||
|  |         eval_start_epoch (int): evaluation start epoch, default is 1. | ||||||
|  |         save_best_ckpt (bool): Whether to save best checkpoint, default is True. | ||||||
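|  |         ckpt_directory (str): directory the best checkpoint is saved in, created if missing, default is `./`. | ||||||
|  |         besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`. | ||||||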
|  |         metrics_name (str): evaluation metrics name, default is `acc`. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         None | ||||||
|  | 
 | ||||||
|  |     Examples: | ||||||
|  |         >>> EvalCallBack(eval_function, eval_param_dict) | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, | ||||||
|  |                  ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): | ||||||
|  |         super(EvalCallBack, self).__init__() | ||||||
|  |         self.eval_param_dict = eval_param_dict | ||||||
|  |         self.eval_function = eval_function | ||||||
|  |         self.eval_start_epoch = eval_start_epoch | ||||||
|  |         if interval < 1: | ||||||
|  |             raise ValueError("interval should >= 1.") | ||||||
|  |         self.interval = interval | ||||||
|  |         self.save_best_ckpt = save_best_ckpt | ||||||
|  |         self.best_res = 0 | ||||||
|  |         self.best_epoch = 0 | ||||||
|  |         if not os.path.isdir(ckpt_directory): | ||||||
|  |             os.makedirs(ckpt_directory) | ||||||
|  |         self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) | ||||||
|  |         self.metrics_name = metrics_name | ||||||
|  | 
 | ||||||
|  |     def remove_ckpoint_file(self, file_name): | ||||||
|  |         """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" | ||||||
|  |         try: | ||||||
|  |             os.chmod(file_name, stat.S_IWRITE) | ||||||
|  |             os.remove(file_name) | ||||||
|  |         except OSError: | ||||||
|  |             logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) | ||||||
|  |         except ValueError: | ||||||
|  |             logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) | ||||||
|  | 
 | ||||||
|  |     def epoch_end(self, run_context): | ||||||
|  |         """Callback when epoch end.""" | ||||||
|  |         cb_params = run_context.original_args() | ||||||
|  |         cur_epoch = cb_params.cur_epoch_num | ||||||
|  |         if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: | ||||||
|  |             res = self.eval_function(self.eval_param_dict) | ||||||
|  |             print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) | ||||||
|  |             if res >= self.best_res: | ||||||
|  |                 self.best_res = res | ||||||
|  |                 self.best_epoch = cur_epoch | ||||||
|  |                 print("update best result: {}".format(res), flush=True) | ||||||
|  |                 if self.save_best_ckpt: | ||||||
|  |                     if os.path.exists(self.bast_ckpt_path): | ||||||
|  |                         self.remove_ckpoint_file(self.bast_ckpt_path) | ||||||
|  |                     save_checkpoint(cb_params.train_network, self.bast_ckpt_path) | ||||||
|  |                     print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) | ||||||
|  | 
 | ||||||
|  |     def end(self, run_context): | ||||||
|  |         print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, | ||||||
|  |                                                                                      self.best_res, | ||||||
|  |                                                                                      self.best_epoch), flush=True) | ||||||
					Loading…
					
					
				
		Reference in new issue