@@ -932,206 +932,6 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
  return ret;
}

int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
  unsigned int number_dec = 0;
  for (size_t index = 0; index < number_bin.size(); index++) {
    number_dec |= number_bin[index] << index;
  }
  return static_cast<int>(number_dec);
}
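
// Illustrative note (not from the original source): the vector is read
// least-significant-bit first, so element i contributes number_bin[i] << i.
// For example, {1, 0, 1} -> 1*1 + 0*2 + 1*4 = 5 (0b101). The result is used
// below as the shrink_axis_mask bit field passed to StridedSlice.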
void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
                std::vector<int> *strides, int length) {
  MS_EXCEPTION_IF_NULL(slice);
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);
  if (length <= 0) {
    MS_LOG(EXCEPTION) << "Could not slice a dim when its length is less than 1";
  }

  int start_default = 0;
  int stop_default = length;
  int step_default = 1;
  int step_value = CheckSliceMember(slice->step(), step_default, "step");
  if (step_value < 0) {
    start_default = -1;
    stop_default = -(length + 1);
  }

  begin->push_back(CheckSliceMember(slice->start(), start_default, "begin"));
  end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop"));
  strides->push_back(step_value);
}
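
// Example (illustrative): for a dimension of length 5, a Python-style slice
// ::-1 has no explicit start/stop, so the negative-step defaults apply:
// begin gets -1, end gets -(5 + 1) = -6 and the stride is -1, which walks
// the whole dimension in reverse, matching Python slicing semantics.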
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
                                            std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
  MS_EXCEPTION_IF_NULL(slice_tuple);
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);

  size_t slice_tuple_size = slice_tuple->size();
  size_t shape_size = shape.size();
  if (slice_tuple_size > shape_size) {
    MS_LOG(EXCEPTION) << "The number of slice elements should not exceed the rank of the tensor: "
                         "the rank of the tensor is "
                      << shape_size << " but the number of slice elements is " << slice_tuple_size;
  }

  std::vector<unsigned int> shrink;
  auto slice_tuple_eles = slice_tuple->elements();
  size_t ellipsis_num = 0;

  for (size_t index = 0; index < slice_tuple_size; index++) {
    if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
      AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
      ParseSlice(slice, begin, end, strides, shape[index]);
      shrink.push_back(0);
      continue;
    }

    if (slice_tuple_eles[index]->isa<AbstractScalar>()) {
      int ele_index = GetArgScalarValue(dyn_cast<AbstractScalar>(slice_tuple_eles[index]), "slice_tuple");
      begin->push_back(ele_index);
      end->push_back(ele_index + 1);
      strides->push_back(1);
      shrink.push_back(1);
      continue;
    }

    if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
      ellipsis_num++;
      if (ellipsis_num > 1) {
        MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
      }
      size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
      begin->insert(begin->end(), ellipsis_len, 0);
      end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
      strides->insert(strides->end(), ellipsis_len, 1);
      shrink.insert(shrink.end(), ellipsis_len, 0);
      continue;
    }

    MS_LOG(EXCEPTION) << "Slice tuple can only contain slices, int numbers or an ellipsis, but got "
                      << slice_tuple_eles[index]->ToString();
  }

  if (ellipsis_num == 0) {
    for (size_t index = slice_tuple_size; index < shape_size; index++) {
      begin->push_back(0);
      end->push_back(shape[index]);
      strides->push_back(1);
    }
  }
  return ConvertBinaryToDecimal(shrink);
}
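
// Worked example (a sketch, not from the original source): slicing a tensor
// of shape (4, 5, 6) with the index tuple (1, 0:2) yields begin = {1, 0, 0},
// end = {2, 2, 6}, strides = {1, 1, 1} and shrink = {1, 0}; the returned
// shrink_axis_mask is 0b01 = 1, so StridedSlice squeezes axis 0 and the
// output shape is (2, 6).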
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
                                            std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);
  size_t shape_size = shape.size();
  if (shape_size == 0) {
    MS_LOG(EXCEPTION) << "Could not slice a scalar tensor";
  }

  ParseSlice(slice, begin, end, strides, shape[0]);

  for (size_t index = 1; index < shape_size; index++) {
    begin->push_back(0);
    end->push_back(shape[index]);
    strides->push_back(1);
  }

  return 0;
}
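
// Worked example (a sketch): slicing a tensor of shape (4, 5) with the single
// slice 1:3 produces begin = {1, 0}, end = {3, 5}, strides = {1, 1}; the
// trailing dimension is taken whole and no axis is shrunk (the returned mask
// is 0), so the output shape is (2, 5).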
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
                                             std::vector<int> *begin, std::vector<int> *end,
                                             std::vector<int> *strides) {
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);
  int ele_index = GetArgScalarValue(scalar, "slice_tuple");

  begin->push_back(ele_index);
  end->push_back(ele_index + 1);
  strides->push_back(1);

  for (size_t index = 1; index < shape.size(); index++) {
    begin->push_back(0);
    end->push_back(shape[index]);
    strides->push_back(1);
  }

  return 1;
}
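
// Worked example (a sketch): indexing a tensor of shape (4, 5) with the plain
// integer 2 produces begin = {2, 0}, end = {3, 5}, strides = {1, 1}; the
// returned mask 1 marks axis 0 for shrinking, so the output shape is (5,).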
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
  auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
  ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
  return ret_graph;
}
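
// Note (illustrative): this mirrors NumPy semantics, where indexing with True
// or None inserts a new leading axis of length 1, e.g. x[None] or x[True] on
// shape (4, 5) yields shape (1, 4, 5); the index parameter added to the graph
// below is accepted but unused in this path.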
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // slice a tensor
  // args: tensor, slice or slice tuple
  const std::string op_name = std::string("TensorSlice");
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
  ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  AnfNodePtr tensor_node = ret_graph->add_parameter();
  (void)ret_graph->add_parameter();

  auto shape = tensorPtr->shape()->shape();
  std::vector<int> begin;
  std::vector<int> end;
  std::vector<int> strides;
  int shrink_axis_mask;

  if (args_spec_list[1]->isa<AbstractTuple>()) {
    AbstractTuplePtr tuple_ptr = dyn_cast<AbstractTuple>(args_spec_list[1]);
    shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides);
  } else if (args_spec_list[1]->isa<AbstractSlice>()) {
    AbstractSlicePtr slice_ptr = dyn_cast<AbstractSlice>(args_spec_list[1]);
    shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
  } else if (args_spec_list[1]->isa<AbstractScalar>()) {
    AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
    if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
      if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
        return ExpandADim(ret_graph, tensor_node);
      }
      MS_LOG(EXCEPTION) << "TensorSlice does not support False as an index.";
    }
    shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
  } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
    ret_graph->set_output(tensor_node);
    return ret_graph;
  } else if (args_spec_list[1]->isa<AbstractNone>()) {
    return ExpandADim(ret_graph, tensor_node);
  } else {
    std::ostringstream args_info;
    for (const auto &arg : args_spec_list) {
      MS_EXCEPTION_IF_NULL(arg);
      args_info << arg->ToString() << "\n";
    }
    MS_LOG(EXCEPTION)
      << "TensorSlice requires the input to be one of [slice, ellipsis, int number, bool, none, tuple], but got "
      << args_info.str();
  }

  auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
  auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
                                               NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
  ret_graph->set_output(ret_graph->NewCNode(
    {PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)}));
  return ret_graph;
}
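
// Illustrative summary (an interpretation, not verbatim source): for the
// non-degenerate cases the emitted graph corresponds to the Python call
// P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(tensor, begin, end, strides),
// i.e. every mask except shrink_axis_mask is disabled.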
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // select indexed item
  // args: tuple of items, index

@@ -1162,11 +962,6 @@ REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                           .def(py::init<std::string &>());
                       }));
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                         (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                           .def(py::init<std::string &>());
                       }));
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));