enabling approximation in DP algorithms

pull/8016/head
Xiaoda Zhang 4 years ago
parent 3070e9c78b
commit aa84484049

@ -28,6 +28,10 @@ namespace mindspore {
namespace parallel {
Status Edge::InitEdgeCost() {
bool has_available_cost = false;
pre_op_output_.clear();
next_op_input_.clear();
cost_map_.clear();
for (auto &swc : prev_op_->GetStrategyCost()) {
MS_EXCEPTION_IF_NULL(swc);
pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
@ -332,5 +336,8 @@ void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map)
next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
}
}
// Return true if there are available strategies in this edge.
bool Edge::CheckStrategyCostPossibility() { return !cost_map_.empty(); }
} // namespace parallel
} // namespace mindspore

@ -140,6 +140,8 @@ class Edge {
// In the inference phase,
Status CalculateMemoryCostForInference();
void mark_output_critical() { is_output_critical_ = 1; }
// Whether there exists any available strategy in 'cost_map_'
bool CheckStrategyCostPossibility();
private:
std::string edge_name_;

@ -41,6 +41,8 @@ bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX;
double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON;
void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@ -170,6 +172,21 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
}
RUN_PHASE = phase;
MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi();
DP_ALGO_ENABLE_APPROX = enable_approx;
if (enable_approx) {
MS_LOG(INFO) << "dp_algo_enable_approx: true.";
} else {
MS_LOG(INFO) << "dp_algo_enable_approx: false.";
}
auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
if (epsilon <= 0 || epsilon > 1) {
MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
}
DP_ALGO_APPROX_EPSILON = epsilon;
MS_LOG(INFO) << "epsilon: " << epsilon << ".";
}
void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
@ -1901,5 +1918,31 @@ Status CostGraph::CalculateMemoryCost() {
}
return SUCCESS;
}
void CostGraph::CheckApproximateCostGraphEdges() {
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (!approximation) {
return;
}
for (auto &s_edge : edges_) {
auto &edges_vector = s_edge.second;
for (auto &edge_ptr : edges_vector) {
MS_EXCEPTION_IF_NULL(edge_ptr);
if (edge_ptr->CheckStrategyCostPossibility()) {
continue;
}
MS_LOG(INFO) << "Checking StrategyCost for edge: " << edge_ptr->edge_name()
<< " impossible, re-initing the operators and edges";
auto prev_op = edge_ptr->prev_operator();
MS_EXCEPTION_IF_NULL(prev_op);
auto next_op = edge_ptr->next_operator();
MS_EXCEPTION_IF_NULL(next_op);
// Check the 'prev_op'
prev_op->ExactStrategiesAndRelatedEdges();
// Check the 'next_op'
next_op->ExactStrategiesAndRelatedEdges();
}
}
}
} // namespace parallel
} // namespace mindspore

@ -45,6 +45,8 @@ extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
extern bool FULLY_USE_DEVICES;
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
extern bool MULTI_SUBGRAPHS;
extern bool DP_ALGO_ENABLE_APPROX;
extern double DP_ALGO_APPROX_EPSILON;
extern int32_t RUN_PHASE;
extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE;
@ -193,6 +195,9 @@ class CostGraph {
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
// once (instead of multiple times), this method is used to correct this.
Status CorrectOpsMemoryCost();
// When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies.
// This method is to re-init those edge involved operators.
void CheckApproximateCostGraphEdges();
// Needed by rec_parser
void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
inputs_tensor_name_list_.push_back(inputs_tensor_name);

@ -65,6 +65,8 @@ void CostModelContext::ResetAlgoParameters() {
fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
dp_algo_enable_approxi_ = DEFAULT_DP_ALGO_ENABLE_APPROX;
dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON;
}
void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
@ -73,6 +75,10 @@ void CostModelContext::set_costmodel_context_for_device(const std::string &devic
}
}
void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; }
void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; }
void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; }
void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; }

@ -45,6 +45,8 @@ namespace parallel {
#define TRAINING_PHASE 0
#define INFERENCE_PHASE 1
#define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true;
#define DEFAULT_DP_ALGO_ENABLE_APPROX false
#define DEFAULT_DP_ALGO_APPROX_EPSILON 0.1
class CostModelContext {
public:
@ -141,6 +143,12 @@ class CostModelContext {
void set_run_phase(int32_t);
int32_t run_phase() const { return run_phase_; }
void set_dp_algo_approxi_epsilon(double);
double dp_algo_approxi_epsilon() const { return dp_algo_approxi_epsilon_; }
void set_dp_algo_enable_approxi(bool);
bool dp_algo_enable_approxi() const { return dp_algo_enable_approxi_; }
private:
CostModelContext();
static std::shared_ptr<CostModelContext> cm_context_inst_;
@ -176,6 +184,12 @@ class CostModelContext {
// whether overwrite the right-node strategy
bool triangle_star_strategy_overwrite_;
// Whether to enable APPROXIMATION in the DP algorithm.
bool dp_algo_enable_approxi_;
// When APPROXIMATION is enabled in the DP algorithm, the 'epsilon' value used in the APPROXIMATION.
double dp_algo_approxi_epsilon_;
int32_t run_phase_; // 0: 'training', 1: 'inference'
int32_t costmodel_allreduce_fusion_algorithm_;

@ -1130,6 +1130,67 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
return SUCCESS;
}
// Keep at most (1.0 / epsilon) number of available strategies for each operator.
void OperatorInfo::ApproximateStrategies() {
auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (!enable_approxi) {
return;
}
MS_LOG(INFO) << "Approximating strategy-cost for: " << name_;
auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon));
if (strategy_cost_.size() <= target_num) {
MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size()
<< ", no greater than target-num: " << target_num;
return;
}
std::vector<std::shared_ptr<StrategyWithCost>> ret;
auto &origin_stra_cost = strategy_cost_;
auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
auto beta = CostModelContext::GetInstance()->costmodel_beta();
// sort
std::sort(
origin_stra_cost.begin(), origin_stra_cost.end(),
[&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) {
if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ <
alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) {
return true;
}
return false;
});
size_t step_length = origin_stra_cost.size() / target_num;
for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) {
ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]);
}
strategy_cost_ = ret;
is_strategy_cost_exact_ = false;
}
void OperatorInfo::ExactStrategiesAndRelatedEdges() {
if (is_strategy_cost_exact()) {
return;
}
ClearStrategyCost();
if (GenerateStrategies(0) != SUCCESS) {
MS_LOG(EXCEPTION) << "Strategy search for Operator " << name() << " failed.";
return;
}
SetIsStrategyCostExactTrue();
// re-init the previous edges
for (auto &prev_edge : prev_edges()) {
if (prev_edge->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge: " << prev_edge->edge_name() << " cost init failed.";
}
}
// re-init the successive edges
for (auto &next_edge : succ_edges()) {
if (next_edge->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge: " << next_edge->edge_name() << " cost init failed.";
}
}
}
int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
if (is_output_parameter_involve_ != -1) {
return is_output_parameter_involve_;

@ -138,6 +138,13 @@ class OperatorInfo {
}
StrategyPtr selected_strategy() const { return selected_strategy_; }
CostPtr selected_cost() const { return selected_cost_; }
// Approximate the list of available strategies
void ApproximateStrategies();
// Make the list of available strategies exact and re-init the related edges incident to this operator
void ExactStrategiesAndRelatedEdges();
bool is_strategy_cost_exact() { return is_strategy_cost_exact_; }
void SetIsStrategyCostExactTrue() { is_strategy_cost_exact_ = true; }
void ClearStrategyCost() { strategy_cost_.clear(); }
void CheckSelectedStrategy(const StrategyPtr &);
Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
@ -263,6 +270,8 @@ class OperatorInfo {
int32_t used_devices_ = -1;
// the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
bool repeated_num_in_dev_matrix_right_ = true;
// Whether the list of available strategies is exact or approximate
bool is_strategy_cost_exact_ = true;
private:
OperatorCostPtr operator_cost_;

@ -408,6 +408,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
return nullptr;
}
// If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
operator_info->ApproximateStrategies();
MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
}
} else {
// In this case, the configured strategy should be extracted to help setting cost
StrategyPtr strategyPtr;
@ -695,6 +701,11 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
}
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
}
// If 'approximation' is enabled, the edges need to be checked have effective costs.
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
entire_costgraph->CheckApproximateCostGraphEdges();
}
MS_LOG(INFO) << "Constructing edges for cost graph ends.";
}
@ -800,6 +811,11 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
}
std::shared_ptr<Edge> edge_ptr =
std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
// If 'approximation' is enabled, the edges need to be checked have effective costs.
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
target_op_info->ExactStrategiesAndRelatedEdges();
}
if (edge_ptr->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge cost initialization failed";

@ -246,6 +246,14 @@ PYBIND11_MODULE(_c_expression, m) {
"Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
.def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,
"Get the parameter elementwise_op_strategy_follow in the DP algorithm.")
.def("set_dp_algo_enable_approxi", &CostModelContext::set_dp_algo_enable_approxi,
"Set the flag whether enabling approximation in the DP algorithm.")
.def("get_dp_algo_enable_approxi", &CostModelContext::dp_algo_enable_approxi,
"Get the flag whether enabling approximation in the DP algorithm.")
.def("set_dp_algo_approxi_epsilon", &CostModelContext::set_dp_algo_approxi_epsilon,
"Set the epsilon which is used in the approximation of DP algorithm.")
.def("get_dp_algo_approxi_epsilon", &CostModelContext::dp_algo_approxi_epsilon,
"Get the epsilon which is used in the approximation of DP algorithm.")
.def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.")
.def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters.");

@ -88,6 +88,22 @@ class _AlgoParameterConfig():
self.check_config_handle()
return self._config_handle.get_tensor_slice_align_size()
def set_dp_algo_enable_approxi(self, enable_flag):
self.check_config_handle()
self._config_handle.set_dp_algo_enable_approxi(enable_flag)
def get_dp_algo_enable_approxi(self):
self.check_config_handle()
return self._config_handle.get_dp_algo_enable_approxi()
def set_dp_algo_approxi_epsilon(self, epsilon):
self.check_config_handle()
self._config_handle.set_dp_algo_approxi_epsilon(epsilon)
def get_dp_algo_approxi_epsilon(self):
self.check_config_handle()
return self._config_handle.get_dp_algo_approxi_epsilon()
def reset_algo_parameters(self):
self.check_config_handle()
self._config_handle.reset_algo_parameters()
@ -113,18 +129,23 @@ set_algo_parameters_config_func_map = {
"fully_use_devices": _algo_parameter_config().set_fully_use_devices,
"elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
"tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
"tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size}
"tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size,
"enable_algo_approxi": _algo_parameter_config().set_dp_algo_enable_approxi,
"algo_approxi_epsilon": _algo_parameter_config().set_dp_algo_approxi_epsilon}
get_algo_parameters_config_func_map = {
"fully_use_devices": _algo_parameter_config().get_fully_use_devices,
"elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
"tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
"tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}
"tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size,
"enable_algo_approxi": _algo_parameter_config().get_dp_algo_enable_approxi,
"algo_approxi_epsilon": _algo_parameter_config().get_dp_algo_approxi_epsilon}
@args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int,
fully_use_devices=bool, elementwise_op_strategy_follow=bool)
fully_use_devices=bool, elementwise_op_strategy_follow=bool,
enable_algo_approxi=bool, algo_approxi_epsilon=float)
def set_algo_parameters(**kwargs):
"""
Set algo parameter config.
@ -139,6 +160,8 @@ def set_algo_parameters(**kwargs):
fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices. Default: True
elementwise_op_strategy_follow (bool): Whether the elementwise operator has the same strategies as its
subsequent operators. Default: False
enable_algo_approxi (bool): Whether to enable the approximation in the DP algorithms.
algo_approxi_epsilon (float): The epsilon value used int the approximation DP algorithm.
Raises:
ValueError: If context keyword is not recognized.

@ -686,6 +686,33 @@ def test_train_8k_8p_gpu(batch_size=32, num_classes=8192):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
set_algo_parameters(elementwise_op_strategy_follow=True)
#set_algo_parameters(enable_algo_approxi=True)
resset_op_id()
np.random.seed(6)
input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
label_np = np.zeros([batch_size]).astype(np.int32)
for i in range(0, batch_size):
label_np[i] = i % num_classes
dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(5, dataset, dataset_sink_mode=False)
strategies = _executor._get_shard_strategy(model._train_network)
for (k, v) in strategies.items():
if re.search('Conv2D-op', k) is not None:
assert v[0][0] == dev_num
elif re.search('MatMul-op', k) is not None:
assert v == [[1, 1], [dev_num, 1]]
elif re.search('ReduceSum-op', k) is not None:
assert v == [[1, dev_num]]
def test_train_8k_8p_gpu_approxi(batch_size=32, num_classes=8192):
dev_num = 8
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
set_algo_parameters(enable_algo_approxi=True)
resset_op_id()
np.random.seed(6)
input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)

@ -105,7 +105,8 @@ def test_two_matmul():
assert costmodel_communi_bias == 1024.0
set_algo_parameters(tensor_slice_align_enable=False, tensor_slice_align_size=32,
fully_use_devices=False, elementwise_op_strategy_follow=False)
fully_use_devices=False, elementwise_op_strategy_follow=False,
enable_algo_approxi=True, algo_approxi_epsilon=0.001)
para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
assert not para_slice_align_enable
para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
@ -114,6 +115,10 @@ def test_two_matmul():
assert not fully_use_devices
elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow")
assert not elementwise_op_strategy_follow
enable_approxi = get_algo_parameters("enable_algo_approxi")
assert enable_approxi
algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
assert algo_epsilon == 0.001
reset_algo_parameters()
para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
@ -124,6 +129,10 @@ def test_two_matmul():
assert fully_use_devices
elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow")
assert not elementwise_op_strategy_follow
enable_approxi = get_algo_parameters("enable_algo_approxi")
assert not enable_approxi
algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
assert algo_epsilon == 0.1
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)

Loading…
Cancel
Save