You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/parallel/_cost_model_context.py

535 lines
21 KiB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context of cost_model in auto_parallel"""
import threading
from mindspore._c_expression import CostModelContext
from mindspore._checkparam import args_type_check
class _CostModelContext:
    """
    _CostModelContext is the environment in which operations are executed.

    Every getter, setter and reset method forwards to the handle obtained from
    CostModelContext.get_instance().

    Note:
        Creating a context through instantiating Context object is not recommended.
        Use cost_model_context() to get the context since Context is singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        # Runs on every _CostModelContext() call because __new__ always hands back the shared instance.
        self._context_handle = CostModelContext.get_instance()

    def __new__(cls):
        # The unlocked test keeps the common path cheap; the second test under the lock stops two
        # threads that both saw None from each building their own instance. The with-statement
        # releases the lock even if object.__new__ raises.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def _checked_handle(self):
        """
        Return the context handle after verifying that it exists.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle

    def set_device_memory_capacity(self, dev_mem_cap):
        """
        Set device memory capacity.

        Args:
            dev_mem_cap (float): The memory capacity for each device.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_device_memory_capacity(dev_mem_cap)

    def get_device_memory_capacity(self):
        """
        Get device memory capacity.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_device_memory_capacity()

    def set_costmodel_alpha(self, alpha):
        """
        Set costmodel alpha.

        Args:
            alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_alpha(alpha)

    def get_costmodel_alpha(self):
        """
        Get costmodel alpha.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_alpha()

    def set_costmodel_beta(self, beta):
        """
        Set costmodel beta.

        Args:
            beta (float): The parameter costmodel_beta used in strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_beta(beta)

    def get_costmodel_beta(self):
        """
        Get costmodel beta.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_beta()

    def set_costmodel_gamma(self, gamma):
        """
        Set costmodel gamma.

        Args:
            gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_gamma(gamma)

    def get_costmodel_gamma(self):
        """
        Get costmodel gamma.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_gamma()

    def set_costmodel_communi_threshold(self, threshold):
        """
        Set costmodel communication threshold.

        Args:
            threshold (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_communi_threshold(threshold)

    def get_costmodel_communi_threshold(self):
        """
        Get costmodel communication threshold.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_communi_threshold()

    def set_costmodel_communi_const(self, communi_const):
        """
        Set costmodel communication const.

        Args:
            communi_const (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_communi_const(communi_const)

    def get_costmodel_communi_const(self):
        """
        Get costmodel communication const.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_communi_const()

    def set_costmodel_communi_bias(self, communi_bias):
        """
        Set costmodel communication bias.

        Args:
            communi_bias (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_communi_bias(communi_bias)

    def get_costmodel_communi_bias(self):
        """
        Get costmodel communication bias.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_communi_bias()

    def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
        """
        Set costmodel allreduce fusion algorithm.

        Args:
            algorithm (int): The AllReduce fusion algorithm of parameter gradients.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_algorithm(algorithm)

    def get_costmodel_allreduce_fusion_algorithm(self):
        """
        Get costmodel allreduce fusion algorithm.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_algorithm()

    def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
        """
        Set costmodel allreduce fusion times.

        Args:
            allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_times(allreduce_fusion_times)

    def get_costmodel_allreduce_fusion_times(self):
        """
        Get costmodel allreduce fusion times.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_times()

    def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
        """
        Set costmodel allreduce fusion tail percent.

        Args:
            tail_percent (float): The percentage of backward computing time corresponding to the last parameter
                gradients AllReduce in the whole backward computing time.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_tail_percent(tail_percent)

    def get_costmodel_allreduce_fusion_tail_percent(self):
        """
        Get costmodel allreduce fusion tail percent.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_tail_percent()

    def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
        """
        Set costmodel allreduce fusion tail time.

        Args:
            tail_time (float): The tail time of the last parameter gradients AllReduce after the end of backward
                computation.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_tail_time(tail_time)

    def get_costmodel_allreduce_fusion_tail_time(self):
        """
        Get costmodel allreduce fusion tail time.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_tail_time()

    def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
        """
        Set costmodel allreduce fusion allreduce inherent time.

        Args:
            allreduce_inherent_time (float): The inherent cost time of AllReduce.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)

    def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
        """
        Get costmodel allreduce fusion allreduce inherent time.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_allreduce_inherent_time()

    def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
        """
        Set costmodel allreduce fusion allreduce bandwidth.

        Args:
            allreduce_bandwidth (float): The bandwidth of AllReduce.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)

    def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
        """
        Get costmodel allreduce fusion allreduce bandwidth.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_allreduce_bandwidth()

    def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
        """
        Set costmodel allreduce fusion computation time parameter.

        Args:
            computation_time_parameter (float): The parameter used to compute backward computation time.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)

    def get_costmodel_allreduce_fusion_computation_time_parameter(self):
        """
        Get costmodel allreduce fusion computation time parameter.

        Returns:
            The value reported by the context handle.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_computation_time_parameter()

    def reset_cost_model(self):
        """
        Reset cost model settings.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().reset_cost_model()
# Lazily created module-level context; populated by cost_model_context() on first use.
_cost_model_context = None


def cost_model_context():
    """
    Return the global _cost_model_context, building it on the first call.

    Returns:
        The global cost_model context.
    """
    global _cost_model_context
    if _cost_model_context is not None:
        return _cost_model_context
    _cost_model_context = _CostModelContext()
    return _cost_model_context
# Maps each keyword accepted by set_cost_model_context() to the matching "set_<keyword>" bound
# method of the global cost model context.
set_cost_model_context_func_map = {
    attr_name: getattr(cost_model_context(), "set_" + attr_name)
    for attr_name in (
        "device_memory_capacity",
        "costmodel_alpha",
        "costmodel_beta",
        "costmodel_gamma",
        "costmodel_communi_threshold",
        "costmodel_communi_const",
        "costmodel_communi_bias",
        "costmodel_allreduce_fusion_algorithm",
        "costmodel_allreduce_fusion_times",
        "costmodel_allreduce_fusion_tail_percent",
        "costmodel_allreduce_fusion_tail_time",
        "costmodel_allreduce_fusion_allreduce_inherent_time",
        "costmodel_allreduce_fusion_allreduce_bandwidth",
        "costmodel_allreduce_fusion_computation_time_parameter",
    )
}
# Maps each attribute key accepted by get_cost_model_context() to the matching "get_<key>" bound
# method of the global cost model context.
get_cost_model_context_func_map = {
    attr_name: getattr(cost_model_context(), "get_" + attr_name)
    for attr_name in (
        "device_memory_capacity",
        "costmodel_alpha",
        "costmodel_beta",
        "costmodel_gamma",
        "costmodel_communi_threshold",
        "costmodel_communi_const",
        "costmodel_communi_bias",
        "costmodel_allreduce_fusion_algorithm",
        "costmodel_allreduce_fusion_times",
        "costmodel_allreduce_fusion_tail_percent",
        "costmodel_allreduce_fusion_tail_time",
        "costmodel_allreduce_fusion_allreduce_inherent_time",
        "costmodel_allreduce_fusion_allreduce_bandwidth",
        "costmodel_allreduce_fusion_computation_time_parameter",
    )
}
@args_type_check(
    device_memory_capacity=float,
    costmodel_alpha=float,
    costmodel_beta=float,
    costmodel_gamma=float,
    costmodel_communi_threshold=float,
    costmodel_communi_const=float,
    costmodel_communi_bias=float,
    costmodel_allreduce_fusion_algorithm=int,
    costmodel_allreduce_fusion_times=int,
    costmodel_allreduce_fusion_tail_percent=float,
    costmodel_allreduce_fusion_tail_time=float,
    costmodel_allreduce_fusion_allreduce_inherent_time=float,
    costmodel_allreduce_fusion_allreduce_bandwidth=float,
    costmodel_allreduce_fusion_computation_time_parameter=float)
def set_cost_model_context(**kwargs):
    """
    Set cost model context.

    Keywords are applied one at a time in the order given; each is dispatched to its setter
    through set_cost_model_context_func_map.

    Note:
        Attribute name is needed.

    Args:
        device_memory_capacity (float): The memory capacity for each device.
        costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
        costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
        costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
        costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
        costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
            0: bypass allreduce fusion;
            1: only use backward computation time to group allreduce;
            2: use backward computation time and parameter gradient allreduce time to group allreduce.
        costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
        costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The
            percentage of backward computing time corresponding to the last parameter gradients AllReduce in the
            whole backward computing time.
        costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time
            of the last parameter gradients AllReduce after the end of backward computation.
        costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm.
            The inherent cost time of AllReduce.
        costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
            bandwidth of AllReduce.
        costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion
            algorithm. The parameter used to compute backward computation time.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    for name, setting in kwargs.items():
        setter = set_cost_model_context_func_map.get(name)
        if setter is None:
            # Keywords handled before the unknown one have already been applied at this point.
            raise ValueError("Set context keyword %s is not recognized!" % name)
        setter(setting)
def get_cost_model_context(attr_key):
    """
    Get cost model context attributes.

    Looks up the getter registered for attr_key in get_cost_model_context_func_map and returns
    whatever that getter returns.

    Note:
        Return value according to the attribute value.

    Args:
        attr_key (str): The key of the attribute.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    getter = get_cost_model_context_func_map.get(attr_key)
    if getter is None:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    return getter()
def reset_cost_model_context():
    """Reset cost model context attributes by calling reset_cost_model() on the global context."""
    context = cost_model_context()
    context.reset_cost_model()