|
|
|
@ -999,18 +999,6 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
|
|
|
|
|
Shape shape = pre_out_tensor_info.shape();
|
|
|
|
|
Shape slice_shape = pre_out_tensor_info.slice_shape();
|
|
|
|
|
for (size_t i = 0; i < shape.size(); ++i) {
|
|
|
|
|
if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
|
|
|
|
|
}
|
|
|
|
|
int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
|
|
|
|
|
(*stra).push_back(dim);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
for (auto node : all_nodes) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
@ -1054,46 +1042,10 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
reshape_info->set_next_operator_name(next_operator_info->name());
|
|
|
|
|
reshape_info->set_next_operator_index(in_index);
|
|
|
|
|
}
|
|
|
|
|
for (auto pre_stra_cost : pre_stra_costs) {
|
|
|
|
|
std::vector<TensorInfo> pre_out_tensor_infos;
|
|
|
|
|
if (pre_node->isa<Parameter>()) {
|
|
|
|
|
pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
|
|
|
|
|
} else {
|
|
|
|
|
pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
|
|
|
|
|
}
|
|
|
|
|
if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
|
|
|
|
|
}
|
|
|
|
|
TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
|
|
|
|
|
TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
|
|
|
|
|
reshape_info->SetInputLayout(pre_out_tensor_layout);
|
|
|
|
|
// infer pre_node output strategy from output_layout.
|
|
|
|
|
Dimensions stra;
|
|
|
|
|
InferStraByTensorInfo(pre_out_tensor_info, &stra);
|
|
|
|
|
std::vector<Dimensions> stra_inputs = {stra};
|
|
|
|
|
StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
|
|
|
|
|
if (next_stra_costs.empty()) {
|
|
|
|
|
if (reshape_info->Init(nullptr) == FAILED) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
|
|
|
|
|
}
|
|
|
|
|
// set cost for each input_layout and output_layout pairs.
|
|
|
|
|
reshape_info->SetCostForReshape(reshape_stra);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto next_stra_cost : next_stra_costs) {
|
|
|
|
|
std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
|
|
|
|
|
if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
|
|
|
|
|
}
|
|
|
|
|
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
|
|
|
|
|
TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
|
|
|
|
|
reshape_info->SetOutputLayout(next_in_tensor_layout);
|
|
|
|
|
if (reshape_info->Init(nullptr) == FAILED) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
|
|
|
|
|
}
|
|
|
|
|
// set cost for each input_layout and output_layout pairs.
|
|
|
|
|
reshape_info->SetCostForReshape(reshape_stra);
|
|
|
|
|
}
|
|
|
|
|
bool is_prev_param = pre_node->isa<Parameter>();
|
|
|
|
|
if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
|
|
|
|
|
SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|