|
|
|
@ -133,9 +133,13 @@ Status ReshapeInfo::GetParameterInput() {
|
|
|
|
|
|
|
|
|
|
Status ReshapeInfo::ComputeReplaceOp() {
|
|
|
|
|
RankList dev_list = global_device_list();
|
|
|
|
|
TensorRedistribution tensor_redistribution(true, true);
|
|
|
|
|
TensorRedistribution tensor_redistribution(!is_generating_costs_, true);
|
|
|
|
|
if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
|
|
|
|
|
if (is_generating_costs_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
|
|
|
|
|
}
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
|
|
|
|
@ -143,7 +147,11 @@ Status ReshapeInfo::ComputeReplaceOp() {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
|
|
|
|
|
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
|
|
|
|
|
if (redistribution_oplist_ptr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
|
|
|
|
|
if (is_generating_costs_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
|
|
|
|
|
}
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
replace_op_ = redistribution_oplist_ptr->first;
|
|
|
|
@ -444,6 +452,7 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
|
|
|
|
|
Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
|
|
|
|
|
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
|
|
|
|
|
int32_t out_index, int32_t in_index, bool is_prev_param) {
|
|
|
|
|
is_generating_costs_ = true;
|
|
|
|
|
for (auto pre_stra_cost : pre_stra_costs) {
|
|
|
|
|
std::vector<TensorInfo> pre_out_tensor_infos;
|
|
|
|
|
if (is_prev_param) {
|
|
|
|
@ -488,6 +497,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
|
|
|
|
|
SetCostForReshape(reshape_stra);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
is_generating_costs_ = false;
|
|
|
|
|
if (strategy_cost_.empty()) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|