|
|
|
@ -222,6 +222,71 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool PyCheckInteger(PyObject *obj) {
|
|
|
|
|
#if PY_VERSION_HEX < 0x03000000
|
|
|
|
|
return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj);
|
|
|
|
|
#else
|
|
|
|
|
return PyLong_Check(obj) && !PyBool_Check(obj);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
|
|
|
|
|
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
|
|
|
|
|
// Original PySlice_GetIndices return wrong result when
|
|
|
|
|
// slice_item contains long int, such as arr[:180L].
|
|
|
|
|
// NOT sure why this happens !!!
|
|
|
|
|
// Besides, PySlice_GetIndices cannot raise error when float in slice item.
|
|
|
|
|
// So, I make a revised version of PySlice_GetIndices, named to
|
|
|
|
|
// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than
|
|
|
|
|
// PySlice_GetIndices in the future.
|
|
|
|
|
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
|
|
|
|
|
Py_ssize_t *start, Py_ssize_t *stop,
|
|
|
|
|
Py_ssize_t *step) {
|
|
|
|
|
/* XXX support long ints */
|
|
|
|
|
if (r->step == Py_None) {
|
|
|
|
|
*step = 1;
|
|
|
|
|
} else {
|
|
|
|
|
if (PyCheckInteger(r->step)) {
|
|
|
|
|
*step = PyLong_AsLong(r->step);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Currently, VarBase.__getitem__() only allows None or integers in "
|
|
|
|
|
"slice item, but received %s.",
|
|
|
|
|
std::string(Py_TYPE(r->step)->tp_name)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (r->start == Py_None) {
|
|
|
|
|
*start = *step < 0 ? length - 1 : 0;
|
|
|
|
|
} else {
|
|
|
|
|
if (PyCheckInteger(r->start)) {
|
|
|
|
|
*start = PyLong_AsLong(r->start);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Currently, VarBase.__getitem__() only allows None or integers in "
|
|
|
|
|
"slice item, but received %s.",
|
|
|
|
|
std::string(Py_TYPE(r->start)->tp_name)));
|
|
|
|
|
}
|
|
|
|
|
if (*start < 0) *start += length;
|
|
|
|
|
}
|
|
|
|
|
if (r->stop == Py_None) {
|
|
|
|
|
*stop = *step < 0 ? -1 : length;
|
|
|
|
|
} else {
|
|
|
|
|
if (PyCheckInteger(r->stop)) {
|
|
|
|
|
*stop = PyLong_AsLong(r->stop);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Currently, VarBase.__getitem__() only allows None or integers in "
|
|
|
|
|
"slice item, but received %s.",
|
|
|
|
|
std::string(Py_TYPE(r->stop)->tp_name)));
|
|
|
|
|
}
|
|
|
|
|
if (*stop < 0) *stop += length;
|
|
|
|
|
}
|
|
|
|
|
if (*stop > length) return -1;
|
|
|
|
|
if (*start >= length) return -1;
|
|
|
|
|
if (*step == 0) return -1;
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
|
|
|
|
|
std::vector<int> *slice_axes,
|
|
|
|
|
std::vector<int> *slice_starts,
|
|
|
|
@ -246,16 +311,17 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
|
|
|
|
|
"too many indices (%d) for tensor of dimension %d", size, rank));
|
|
|
|
|
for (int dim = 0; dim < size; ++dim) {
|
|
|
|
|
PyObject *slice_item = PyTuple_GetItem(index, dim);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
PyNumber_Check(slice_item) || PySlice_Check(slice_item), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"We allow indexing by Integers, Slices, and tuples of "
|
|
|
|
|
"these types, but received %s in %dth slice item",
|
|
|
|
|
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
|
|
|
|
|
PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
|
|
|
|
|
true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Currently, VarBase.__getitem__() only allows "
|
|
|
|
|
"indexing by Integers, Slices, and tuples of "
|
|
|
|
|
"these types, but received %s in %dth slice item",
|
|
|
|
|
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
|
|
|
|
|
infer_flags->push_back(1);
|
|
|
|
|
int dim_len = shape[dim];
|
|
|
|
|
if (PyNumber_Check(slice_item)) {
|
|
|
|
|
// integer
|
|
|
|
|
if (PyCheckInteger(slice_item)) {
|
|
|
|
|
// integer, PyLong_AsLong supports both int and long
|
|
|
|
|
int start = static_cast<int>(PyLong_AsLong(slice_item));
|
|
|
|
|
auto s_t = start;
|
|
|
|
|
start = start < 0 ? start + dim_len : start;
|
|
|
|
@ -275,17 +341,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
|
|
|
|
|
slice_strides->push_back(1);
|
|
|
|
|
decrease_axis->push_back(dim);
|
|
|
|
|
} else {
|
|
|
|
|
// slice
|
|
|
|
|
// slice item
|
|
|
|
|
Py_ssize_t start, end, step;
|
|
|
|
|
// The parameter type for the slice parameter was PySliceObject* before 3.2
|
|
|
|
|
#if PY_VERSION_HEX >= 0x03020000
|
|
|
|
|
PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
|
|
|
|
|
#else
|
|
|
|
|
PySlice_GetIndices(reinterpret_cast<PySliceObject *>(slice_item), dim_len,
|
|
|
|
|
&start, &end, &step);
|
|
|
|
|
#endif
|
|
|
|
|
PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
|
|
|
|
|
_PySlice_GetIndices(p, dim_len, &start, &end, &step);
|
|
|
|
|
|
|
|
|
|
// :: or : or 0:dim_len:1
|
|
|
|
|
if (start == 0 && end == dim_len && step == 1) continue;
|
|
|
|
|
if (start == 0 && end == dim_len && step == 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
slice_axes->push_back(dim);
|
|
|
|
|
slice_starts->push_back(start);
|
|
|
|
|
slice_ends->push_back(end);
|
|
|
|
@ -493,7 +557,6 @@ void BindImperative(py::module *m_ptr) {
|
|
|
|
|
ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
|
|
|
|
|
&slice_starts, &slice_ends, &slice_strides,
|
|
|
|
|
&decrease_axis, &infer_flags);
|
|
|
|
|
|
|
|
|
|
// release gil and do tracing
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
const auto &tracer = imperative::GetCurrentTracer();
|
|
|
|
@ -633,8 +696,8 @@ void BindImperative(py::module *m_ptr) {
|
|
|
|
|
[](imperative::VarBase &self,
|
|
|
|
|
const imperative::detail::BackwardStrategy &bckst,
|
|
|
|
|
const imperative::Tracer &tracer) {
|
|
|
|
|
// TODO(jiabin): when we impl more backward execution we can select
|
|
|
|
|
// them
|
|
|
|
|
// TODO(jiabin): when we impl more backward execution we can
|
|
|
|
|
// select them
|
|
|
|
|
auto *engine = tracer.GetEngine();
|
|
|
|
|
engine->Init(&self, bckst);
|
|
|
|
|
VLOG(3) << "Start backward";
|
|
|
|
|