# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context of cost_model in auto_parallel"""
import threading

from mindspore._c_expression import CostModelContext
from mindspore._checkparam import args_type_check


class _CostModelContext:
    """
    _CostModelContext is the environment in which cost model operations are executed.

    Note:
        Creating a context by instantiating the _CostModelContext object directly is not recommended.
        Use cost_model_context() to get the context, since _CostModelContext is a singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._context_handle = CostModelContext.get_instance()

    def __new__(cls):
        if cls._instance is None:
            cls._instance_lock.acquire()
            cls._instance = object.__new__(cls)
            cls._instance_lock.release()
        return cls._instance

    def set_device_memory_capacity(self, dev_mem_cap):
        """
        Set device memory capacity.

        Args:
            dev_mem_cap (float): The memory capacity for each device.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_device_memory_capacity(dev_mem_cap)

    def get_device_memory_capacity(self):
        """
        Get device memory capacity.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_device_memory_capacity()

    def set_costmodel_alpha(self, alpha):
        """
        Set costmodel alpha.

        Args:
            alpha (float): The parameter costmodel_alpha used in the strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_alpha(alpha)

    def get_costmodel_alpha(self):
        """
        Get costmodel alpha.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_alpha()

    def set_costmodel_beta(self, beta):
        """
        Set costmodel beta.

        Args:
            beta (float): The parameter costmodel_beta used in the strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_beta(beta)

    def get_costmodel_beta(self):
        """
        Get costmodel beta.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_beta()

    def set_costmodel_gamma(self, gamma):
        """
        Set costmodel gamma.

        Args:
            gamma (float): The parameter costmodel_gamma used in the strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_gamma(gamma)

    def get_costmodel_gamma(self):
        """
        Get costmodel gamma.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_gamma()

    def set_costmodel_communi_threshold(self, threshold):
        """
        Set costmodel communication threshold.

        Args:
            threshold (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_communi_threshold(threshold)

    def get_costmodel_communi_threshold(self):
        """
        Get costmodel communication threshold.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_communi_threshold()

    def set_costmodel_communi_const(self, communi_const):
        """
        Set costmodel communication const.

        Args:
            communi_const (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_communi_const(communi_const)

    def get_costmodel_communi_const(self):
        """
        Get costmodel communication const.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_communi_const()

    def set_costmodel_communi_bias(self, communi_bias):
        """
        Set costmodel communication bias.

        Args:
            communi_bias (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_communi_bias(communi_bias)

    def get_costmodel_communi_bias(self):
        """
        Get costmodel communication bias.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_communi_bias()

    def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
        """
        Set costmodel allreduce fusion algorithm.

        Args:
            algorithm (int): The AllReduce fusion algorithm of parameter gradients.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_algorithm(algorithm)

    def get_costmodel_allreduce_fusion_algorithm(self):
        """
        Get costmodel allreduce fusion algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_algorithm()

    def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
        """
        Set costmodel allreduce fusion times.

        Args:
            allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_times(allreduce_fusion_times)

    def get_costmodel_allreduce_fusion_times(self):
        """
        Get costmodel allreduce fusion times.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_times()

    def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
        """
        Set costmodel allreduce fusion tail percent.

        Args:
            tail_percent (float): The percentage of backward computing time corresponding to the last parameter
                gradients AllReduce in the whole backward computing time.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_tail_percent(tail_percent)

    def get_costmodel_allreduce_fusion_tail_percent(self):
        """
        Get costmodel allreduce fusion tail percent.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_tail_percent()

    def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
        """
        Set costmodel allreduce fusion tail time.

        Args:
            tail_time (float): The tail time of the last parameter gradients AllReduce after the end of backward
                computation.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_tail_time(tail_time)

    def get_costmodel_allreduce_fusion_tail_time(self):
        """
        Get costmodel allreduce fusion tail time.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_tail_time()

    def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
        """
        Set costmodel allreduce fusion allreduce inherent time.

        Args:
            allreduce_inherent_time (float): The inherent cost time of AllReduce.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)

    def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
        """
        Get costmodel allreduce fusion allreduce inherent time.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_allreduce_inherent_time()

    def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
        """
        Set costmodel allreduce fusion allreduce bandwidth.

        Args:
            allreduce_bandwidth (float): The bandwidth of AllReduce.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)

    def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
        """
        Get costmodel allreduce fusion allreduce bandwidth.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()

    def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
        """
        Set costmodel allreduce fusion computation time parameter.

        Args:
            computation_time_parameter (float): The parameter used to compute backward computation time.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)

    def get_costmodel_allreduce_fusion_computation_time_parameter(self):
        """
        Get costmodel allreduce fusion computation time parameter.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_costmodel_allreduce_fusion_computation_time_parameter()

    def reset_cost_model(self):
        """
        Reset cost model settings.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        self._context_handle.reset_cost_model()


_cost_model_context = None


def cost_model_context():
    """
    Get the global _cost_model_context. If it is not created, create a new one.

    Returns:
        The global cost_model context.
    """
    global _cost_model_context
    if _cost_model_context is None:
        _cost_model_context = _CostModelContext()
    return _cost_model_context


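# A minimal usage sketch of the singleton accessor above (illustrative only; the
# numeric values are placeholders chosen for the example, not recommended defaults):
#
#     ctx = cost_model_context()
#     ctx.set_device_memory_capacity(16.0 * 1024 * 1024 * 1024)
#     ctx.set_costmodel_alpha(1.0)
#     assert ctx is cost_model_context()  # every call returns the same instance

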
set_cost_model_context_func_map = {
    "device_memory_capacity": cost_model_context().set_device_memory_capacity,
    "costmodel_alpha": cost_model_context().set_costmodel_alpha,
    "costmodel_beta": cost_model_context().set_costmodel_beta,
    "costmodel_gamma": cost_model_context().set_costmodel_gamma,
    "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
    "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
    "costmodel_allreduce_fusion_allreduce_inherent_time":
        cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
    "costmodel_allreduce_fusion_allreduce_bandwidth":
        cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
    "costmodel_allreduce_fusion_computation_time_parameter":
        cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}


get_cost_model_context_func_map = {
    "device_memory_capacity": cost_model_context().get_device_memory_capacity,
    "costmodel_alpha": cost_model_context().get_costmodel_alpha,
    "costmodel_beta": cost_model_context().get_costmodel_beta,
    "costmodel_gamma": cost_model_context().get_costmodel_gamma,
    "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
    "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
    "costmodel_allreduce_fusion_allreduce_inherent_time":
        cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
    "costmodel_allreduce_fusion_allreduce_bandwidth":
        cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
    "costmodel_allreduce_fusion_computation_time_parameter":
        cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}


@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                 costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
                 costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                 costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                 costmodel_allreduce_fusion_allreduce_inherent_time=float,
                 costmodel_allreduce_fusion_allreduce_bandwidth=float,
                 costmodel_allreduce_fusion_computation_time_parameter=float)
def set_cost_model_context(**kwargs):
    """
    Set cost model context.

    Note:
        The attribute name is required when setting an attribute value.

    Args:
        device_memory_capacity (float): The memory capacity for each device.
        costmodel_alpha (float): The parameter costmodel_alpha used in the strategy-searching algorithm.
        costmodel_beta (float): The parameter costmodel_beta used in the strategy-searching algorithm.
        costmodel_gamma (float): The parameter costmodel_gamma used in the strategy-searching algorithm.
        costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
        costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
            0: bypass allreduce fusion;
            1: only use backward computation time to group allreduce;
            2: use backward computation time and parameter gradient allreduce time to group allreduce.
        costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
        costmodel_allreduce_fusion_tail_percent (float): A parameter used in the allreduce fusion algorithm. The
            percentage of backward computing time corresponding to the last parameter gradients AllReduce in the
            whole backward computing time.
        costmodel_allreduce_fusion_tail_time (float): A parameter used in the allreduce fusion algorithm. The tail
            time of the last parameter gradients AllReduce after the end of backward computation.
        costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in the allreduce fusion
            algorithm. The inherent cost time of AllReduce.
        costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in the allreduce fusion algorithm.
            The bandwidth of AllReduce.
        costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in the allreduce fusion
            algorithm. The parameter used to compute backward computation time.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    for key, value in kwargs.items():
        if key not in set_cost_model_context_func_map:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        set_func = set_cost_model_context_func_map[key]
        set_func(value)


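# Example (illustrative sketch; the values below are placeholders for the example,
# not tuned defaults): enable the gradient-AllReduce fusion algorithm that uses both
# backward computation time and AllReduce time (algorithm 2), then adjust its timing
# parameters through the keyword interface defined above.
#
#     set_cost_model_context(costmodel_allreduce_fusion_algorithm=2,
#                            costmodel_allreduce_fusion_tail_time=0.1,
#                            costmodel_allreduce_fusion_allreduce_inherent_time=0.05,
#                            costmodel_allreduce_fusion_allreduce_bandwidth=0.35)

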
def get_cost_model_context(attr_key):
    """
    Get cost model context attributes.

    Note:
        Returns the attribute value corresponding to the given key.

    Args:
        attr_key (str): The key of the attribute.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    if attr_key not in get_cost_model_context_func_map:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    get_func = get_cost_model_context_func_map[attr_key]
    return get_func()


def reset_cost_model_context():
    """Reset cost model context attributes."""
    cost_model_context().reset_cost_model()
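

# End-to-end usage sketch (illustrative; assumes this module is importable as
# mindspore.parallel._cost_model_context, per the repository layout):
#
#     from mindspore.parallel import _cost_model_context as cmc
#
#     cmc.set_cost_model_context(device_memory_capacity=16.0 * 1024 ** 3)
#     cap = cmc.get_cost_model_context("device_memory_capacity")
#     cmc.reset_cost_model_context()  # restore all cost model defaults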