|
|
|
@ -308,27 +308,39 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t axis = -1;
|
|
|
|
|
std::vector<int64_t> axis_list;
|
|
|
|
|
auto iter = ops[iter_ops]->attrs().find(AXIS);
|
|
|
|
|
if (iter != ops[iter_ops]->attrs().end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter->second);
|
|
|
|
|
if (iter->second->isa<Int64Imm>()) {
|
|
|
|
|
axis = iter->second->cast<Int64ImmPtr>()->value();
|
|
|
|
|
axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
|
|
|
|
|
} else if (iter->second->isa<ValueTuple>()) {
|
|
|
|
|
ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
|
|
|
|
|
if (value_tuple == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value_tuple is nullptr.";
|
|
|
|
|
}
|
|
|
|
|
std::vector<ValuePtr> value_vector = value_tuple->value();
|
|
|
|
|
(void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
|
|
|
|
|
[](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
axis_list.push_back(-1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
|
|
|
|
|
axis = input_dim + axis;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (strategies[0][axis] != 1) {
|
|
|
|
|
strategies[0][axis] = 1;
|
|
|
|
|
MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
|
|
|
|
|
for (auto &axis : axis_list) {
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
|
|
|
|
|
axis = input_dim + axis;
|
|
|
|
|
}
|
|
|
|
|
if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
|
|
|
|
|
}
|
|
|
|
|
if (strategies[0][axis] != 1) {
|
|
|
|
|
strategies[0][axis] = 1;
|
|
|
|
|
MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return strategies;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|