|
|
|
@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase;
|
|
|
|
|
using mindspore::abstract::AbstractClass;
|
|
|
|
|
using mindspore::abstract::AbstractDictionary;
|
|
|
|
|
using mindspore::abstract::AbstractDictionaryPtr;
|
|
|
|
|
using mindspore::abstract::AbstractEllipsis;
|
|
|
|
|
using mindspore::abstract::AbstractEllipsisPtr;
|
|
|
|
|
using mindspore::abstract::AbstractFunction;
|
|
|
|
|
using mindspore::abstract::AbstractFunctionPtr;
|
|
|
|
|
using mindspore::abstract::AbstractList;
|
|
|
|
@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
|
|
|
|
|
|
|
|
|
std::vector<unsigned int> shrink;
|
|
|
|
|
auto slice_tuple_eles = slice_tuple->elements();
|
|
|
|
|
size_t ellipsis_num = 0;
|
|
|
|
|
for (size_t index = 0; index < slice_tuple_size; index++) {
|
|
|
|
|
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
|
|
|
|
|
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
|
|
|
|
@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got "
|
|
|
|
|
if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
|
|
|
|
|
ellipsis_num++;
|
|
|
|
|
if (ellipsis_num > 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
|
|
|
|
|
}
|
|
|
|
|
size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
|
|
|
|
|
begin->insert(begin->end(), ellipsis_len, 0);
|
|
|
|
|
end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
|
|
|
|
|
strides->insert(strides->end(), ellipsis_len, 1);
|
|
|
|
|
shrink.insert(shrink.end(), ellipsis_len, 0);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
|
|
|
|
|
<< slice_tuple_eles[index]->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|
|
|
|
abstract::CheckArgsSize(op_name, args_spec_list, 2);
|
|
|
|
|
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
AnfNodePtr tensor_node = ret_graph->add_parameter();
|
|
|
|
|
(void)ret_graph->add_parameter();
|
|
|
|
|
|
|
|
|
|
auto shape = tensorPtr->shape()->shape();
|
|
|
|
|
std::vector<int> begin;
|
|
|
|
|
std::vector<int> end;
|
|
|
|
@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|
|
|
|
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
|
|
|
|
|
} else if (args_spec_list[1]->isa<AbstractScalar>()) {
|
|
|
|
|
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
|
|
|
|
|
if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
|
|
|
|
|
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
|
|
|
|
|
return ExpandADim(ret_graph, tensor_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
|
|
|
|
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
|
|
|
|
|
ret_graph->set_output(tensor_node);
|
|
|
|
|
return ret_graph;
|
|
|
|
|
} else if (args_spec_list[1]->isa<AbstractNone>()) {
|
|
|
|
|
return ExpandADim(ret_graph, tensor_node);
|
|
|
|
|
} else {
|
|
|
|
|
std::ostringstream args_info;
|
|
|
|
|
for (const auto &arg : args_spec_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(arg);
|
|
|
|
|
args_info << arg->ToString() << "\n";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got "
|
|
|
|
|
<< args_info.str();
|
|
|
|
|
MS_LOG(EXCEPTION)
|
|
|
|
|
<< "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
|
|
|
|
|
<< args_info.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
|
|
|
|
|
AnfNodePtr tensor_node = ret_graph->add_parameter();
|
|
|
|
|
(void)ret_graph->add_parameter();
|
|
|
|
|
|
|
|
|
|
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
|
|
|
|
|
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
|
|
|
|
|
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
|
|
|
|
@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|
|
|
|
return ret_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
|
|
|
|
|
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
|
|
|
|
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
|
|
|
|
return ret_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
|
|
|
|
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
|
|
|
|
|
.def(py::init<std::string &>());
|
|
|
|
|