@@ -217,6 +217,68 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
   return result;
 }

+static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
+                               std::vector<int> *slice_axes,
+                               std::vector<int> *slice_starts,
+                               std::vector<int> *slice_ends,
+                               std::vector<int> *slice_strides,
+                               std::vector<int> *decrease_axis,
+                               std::vector<int> *infer_flags) {
+  // We allow indexing by Integers, Slices, and tuples of those
+  // types.
+  // Ellipsis and None are not supported yet.
+  // wrap to tuple
+  PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
+  PADDLE_ENFORCE_EQ(
+      tensor->IsInitialized(), true,
+      platform::errors::InvalidArgument("tensor has not been initialized"));
+  const auto &shape = tensor->dims();
+  const int rank = shape.size();
+  const int size = PyTuple_GET_SIZE(index);
+  PADDLE_ENFORCE_EQ(
+      size <= rank, true,
+      platform::errors::InvalidArgument(
+          "too many indices (%d) for tensor of dimension %d", size, rank));
+  for (int dim = 0; dim < size; ++dim) {
+    PyObject *slice_item = PyTuple_GetItem(index, dim);
+    PADDLE_ENFORCE_EQ(
+        PyNumber_Check(slice_item) || PySlice_Check(slice_item), true,
+        platform::errors::InvalidArgument(
+            "We allow indexing by Integers, Slices, and tuples of "
+            "these types, but received %s in %dth slice item",
+            std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
+    infer_flags->push_back(1);
+    int dim_len = shape[dim];
+    if (PyNumber_Check(slice_item)) {
+      // integer
+      int start = static_cast<int>(PyLong_AsLong(slice_item));
+      start = start < 0 ? start + dim_len : start;
+      slice_axes->push_back(dim);
+      slice_starts->push_back(start);
+      slice_ends->push_back(start + 1);
+      slice_strides->push_back(1);
+      decrease_axis->push_back(dim);
+    } else {
+      // slice
+      Py_ssize_t start, end, step;
+// The parameter type for the slice parameter was PySliceObject* before 3.2
+#if PY_VERSION_HEX >= 0x03020000
+      PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
+#else
+      PySlice_GetIndices(reinterpret_cast<PySliceObject *>(slice_item), dim_len,
+                         &start, &end, &step);
+#endif
+      // :: or : or 0:dim_len:1
+      if (start == 0 && end == dim_len && step == 1) continue;
+      slice_axes->push_back(dim);
+      slice_starts->push_back(start);
+      slice_ends->push_back(end);
+      slice_strides->push_back(step);
+    }
+  }
+  if (!PyTuple_Check(_index)) Py_DecRef(index);
+}
+
 // Bind Methods
 void BindImperative(py::module *m_ptr) {
   auto &m = *m_ptr;
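Reviewer note (not part of the patch): the logic moved into ParseIndexingSlice is what backs Python-side indexing of a dygraph VarBase. A minimal usage sketch, assuming the usual fluid.dygraph entry points; the (3, 4) example data and variable names are illustrative only:

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(np.arange(12).reshape(3, 4).astype('float32'))
        row = x[1]           # integer index: sliced on axis 0, then that axis is dropped (decrease_axis)
        block = x[0:2, 1:3]  # per-axis slices: lowered to a slice op via starts/ends/strides
        same = x[:]          # a full 0:dim_len:1 slice is skipped; slice_axes stays empty and the VarBase is returned as-is
        # x[..., 0] or x[None] would be rejected: Ellipsis and None are not supported yet.

Integer and slice items may be mixed freely in one index tuple; any other item type trips the InvalidArgument check in ParseIndexingSlice.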
@@ -396,77 +458,22 @@ void BindImperative(py::module *m_ptr) {
       .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
       .def("__init__", &InitVarBaseFromNumpyWithKwargs)
       .def("__getitem__",
-           [](imperative::VarBase &self, py::handle _index) {
-             // We allow indexing by Integers, Slices, and tuples of those
-             // types.
-             // Ellipsis and None are not supported yet.
+           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
             std::vector<int> slice_axes, slice_starts, slice_ends,
-                slice_strides, decrease_axis;
-            // wrap to tuple
-            PyObject *index = !PyTuple_Check(_index.ptr())
-                                  ? PyTuple_Pack(1, _index.ptr())
-                                  : _index.ptr();
-            const auto &tensor = self.Var().Get<framework::LoDTensor>();
-            PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
-                              platform::errors::InvalidArgument(
-                                  "%s has not been initialized", self.Name()));
-            const auto &shape = tensor.dims();
-            const int rank = shape.size();
-            const int size = PyTuple_GET_SIZE(index);
-            PADDLE_ENFORCE_EQ(
-                size <= rank, true,
-                platform::errors::InvalidArgument(
-                    "too many indices (%d) for tensor of dimension %d", size,
-                    rank));
-            for (int dim = 0; dim < size; ++dim) {
-              PyObject *slice_item = PyTuple_GetItem(index, dim);
-              PADDLE_ENFORCE_EQ(
-                  PyNumber_Check(slice_item) || PySlice_Check(slice_item),
-                  true,
-                  platform::errors::InvalidArgument(
-                      "We allow indexing by Integers, Slices, and tuples of "
-                      "these types, but received %s in %dth slice item",
-                      std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
-              int dim_len = shape[dim];
-              if (PyNumber_Check(slice_item)) {
-                // integer
-                int start = static_cast<int>(PyLong_AsLong(slice_item));
-                start = start < 0 ? start + dim_len : start;
-                slice_axes.push_back(dim);
-                slice_starts.push_back(start);
-                slice_ends.push_back(start + 1);
-                slice_strides.push_back(1);
-                decrease_axis.push_back(dim);
-              } else {
-                // slice
-                Py_ssize_t start, end, step;
-// The parameter type for the slice parameter was PySliceObject* before 3.2
-#if PY_VERSION_HEX >= 0x03020000
-                PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
-#else
-                PySlice_GetIndices(
-                    reinterpret_cast<PySliceObject *>(slice_item), dim_len,
-                    &start, &end, &step);
-#endif
-                // :: or : or 0:dim_len:1
-                if (start == 0 && end == dim_len && step == 1) continue;
-                slice_axes.push_back(dim);
-                slice_starts.push_back(start);
-                slice_ends.push_back(end);
-                slice_strides.push_back(step);
-              }
-            }
-            if (!PyTuple_Check(_index.ptr())) Py_DecRef(index);
+                slice_strides, decrease_axis, infer_flags;
+            auto tensor =
+                self->MutableVar()->GetMutable<framework::LoDTensor>();
+            ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
+                               &slice_starts, &slice_ends, &slice_strides,
+                               &decrease_axis, &infer_flags);

             // release gil and do tracing
             py::gil_scoped_release release;
             const auto &tracer = imperative::GetCurrentTracer();
-            auto _self = self.NewVarBase(tensor.place(), false);
             if (slice_axes.empty()) {
-              return _self;
+              return self;
             } else {
-              std::vector<int> infer_flags(size, 1);
-              imperative::NameVarBaseMap ins = {{"Input", {_self}}};
+              imperative::NameVarBaseMap ins = {{"Input", {self}}};
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},