fix reshape strategy search bug in auto parallel

pull/7380/head
yao_yf 4 years ago
parent 022005b94a
commit 4c1d4924cb

@ -423,6 +423,7 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_); std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
swc->cost_list.push_back(result); swc->cost_list.push_back(result);
strategy_cost_.emplace_back(swc); strategy_cost_.emplace_back(swc);
ResetQueueMember();
} }
Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {

@ -223,7 +223,8 @@ Status TensorRedistribution::ComputeCost() {
} else { } else {
prev_shape = from_.tensor_shape().array(); prev_shape = from_.tensor_shape().array();
} }
double prev_prod = std::accumulate(prev_shape.begin(), prev_shape.end(), 1, std::multiplies<int>()); double prev_prod =
std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
computation_cost_ += 2.0 * prev_prod; computation_cost_ += 2.0 * prev_prod;
memory_cost_ += 2.0 * prev_prod; memory_cost_ += 2.0 * prev_prod;
} }

Loading…
Cancel
Save