From 1be752f0982360ccfabfa14ec998980f037b0e01 Mon Sep 17 00:00:00 2001
From: luopengting
Date: Fri, 23 Oct 2020 18:17:41 +0800
Subject: [PATCH] collect custom lineage data in optimizer auto

---
Review notes: relative to the original change, the last hunk (a) ignores a
hyper config that decodes to a non-object JSON value instead of failing on
`.get`, (b) lets `_check_custom_lineage_type` return early for None, which its
own type check accepts, and (c) expands the docstrings of the two new methods.
The last hunk header and the diffstat are updated accordingly; the post-image
hash on the `index` line is not recomputed.

 .../train/callback/_summary_collector.py      | 81 ++++++++++++++++---
 1 file changed, 72 insertions(+), 9 deletions(-)

diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py
index 3ce324bb66..9c87fca475 100644
--- a/mindspore/train/callback/_summary_collector.py
+++ b/mindspore/train/callback/_summary_collector.py
@@ -17,6 +17,7 @@
 import os
 import re
 import json
+from json.decoder import JSONDecodeError
 
 from importlib import import_module
 
@@ -34,6 +35,9 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.loss.loss import _Loss
 from mindspore.train._utils import check_value_type
 
+HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
+HYPER_CONFIG_LEN_LIMIT = 100000
+
 
 class LineageMetadata:
     """Initialize parameters used in model lineage management."""
@@ -188,8 +192,7 @@ class SummaryCollector(Callback):
         msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}."
         logger.info(msg)
 
-        self._check_custom_lineage_data(custom_lineage_data)
-        self._custom_lineage_data = custom_lineage_data
+        self._custom_lineage_data = self._process_custom_lineage_data(custom_lineage_data)
 
         self._temp_optimizer = None
         self._has_saved_graph = False
@@ -232,8 +235,7 @@ class SummaryCollector(Callback):
         if value <= 0:
             raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.')
 
-    @staticmethod
-    def _check_custom_lineage_data(custom_lineage_data):
+    def _process_custom_lineage_data(self, custom_lineage_data):
         """
         Check user custom lineage data.
 
@@ -244,12 +246,73 @@ class SummaryCollector(Callback):
             TypeError: If the type of parameters is invalid.
         """
         if custom_lineage_data is None:
-            return
+            custom_lineage_data = {}
+        self._check_custom_lineage_type('custom_lineage_data', custom_lineage_data)
+
+        auto_custom_lineage_data = self._collect_optimizer_custom_lineage_data()
+        self._check_custom_lineage_type('auto_custom_lineage_data', auto_custom_lineage_data)
+        # the priority of user defined info is higher than auto collected info
+        auto_custom_lineage_data.update(custom_lineage_data)
+        custom_lineage_data = auto_custom_lineage_data
+
+        return custom_lineage_data
+
+    def _check_custom_lineage_type(self, param_name, custom_lineage):
+        """
+        Check that custom lineage data is a dict of str keys to int, str or float values.
+
+        Args:
+            param_name (str): The name used for the data in error messages.
+            custom_lineage (dict, None): The data to check.
+
+        Raises:
+            TypeError: If the type of the data, a key or a value is invalid.
+        """
+        check_value_type(param_name, custom_lineage, [dict, type(None)])
+        if custom_lineage is None:
+            return
+        for key, value in custom_lineage.items():
+            check_value_type(f'{param_name} -> {key}', key, str)
+            check_value_type(f'the value of {param_name} -> {key}', value, (int, str, float))
+
+    def _collect_optimizer_custom_lineage_data(self):
+        """
+        Collect custom lineage data from the hyper config set by mindoptimizer.
+
+        Returns:
+            The `custom_lineage_data` entry of the hyper config (its type is checked by the caller),
+            or an empty dict if the environment variable is unset, too long, not valid JSON,
+            not a JSON object, or has no such entry.
+        """
+        auto_custom_lineage_data = {}
+
+        hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
+        if hyper_config is None:
+            logger.debug("Hyper config is not in system environment.")
+            return auto_custom_lineage_data
+        if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT:
+            logger.warning("Hyper config is too long. The length limit is %s, the length of "
+                           "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
+            return auto_custom_lineage_data
 
-        check_value_type('custom_lineage_data', custom_lineage_data, [dict, type(None)])
-        for key, value in custom_lineage_data.items():
-            check_value_type(f'custom_lineage_data -> {key}', key, str)
-            check_value_type(f'the value of custom_lineage_data -> {key}', value, (int, str, float))
+        try:
+            hyper_config = json.loads(hyper_config)
+        except (TypeError, JSONDecodeError) as exc:
+            logger.warning("Hyper config decode error. Detail: %s." % str(exc))
+            return auto_custom_lineage_data
+
+        # json.loads also yields list/str/number/bool/None; only a JSON object (dict) supports .get().
+        if not isinstance(hyper_config, dict):
+            logger.warning("Hyper config should be a JSON object, but got %s." % type(hyper_config).__name__)
+            return auto_custom_lineage_data
+
+        custom_lineage_data = hyper_config.get("custom_lineage_data")
+        if custom_lineage_data is None:
+            logger.info("No custom lineage data in hyper config. Please check the custom lineage data "
+                        "if custom parameters exist in the configuration file.")
+        auto_custom_lineage_data = custom_lineage_data if custom_lineage_data is not None else {}
+
+        return auto_custom_lineage_data
 
     @staticmethod
     def _check_action(action):