|
|
|
@ -627,6 +627,183 @@ class Variable(object):
|
|
|
|
|
"""
|
|
|
|
|
self.error_clip = error_clip
|
|
|
|
|
|
|
|
|
|
def _slice_indices(self, slice, length):
|
|
|
|
|
"""
|
|
|
|
|
Reference implementation for the slice.indices method.
|
|
|
|
|
"""
|
|
|
|
|
# Compute step and length as integers.
|
|
|
|
|
step = 1 if slice.step is None else slice.step
|
|
|
|
|
|
|
|
|
|
# Raise ValueError for negative length or zero step.
|
|
|
|
|
if length < 0:
|
|
|
|
|
raise ValueError("length should not be negative")
|
|
|
|
|
if step == 0:
|
|
|
|
|
raise ValueError("slice step cannot be zero")
|
|
|
|
|
|
|
|
|
|
# Find lower and upper bounds for start and stop.
|
|
|
|
|
lower = -1 if step < 0 else 0
|
|
|
|
|
upper = length - 1 if step < 0 else length
|
|
|
|
|
|
|
|
|
|
# Compute start.
|
|
|
|
|
if slice.start is None:
|
|
|
|
|
start = upper if step < 0 else lower
|
|
|
|
|
else:
|
|
|
|
|
start = slice.start
|
|
|
|
|
start = max(start + length, lower) if start < 0 else min(start,
|
|
|
|
|
upper)
|
|
|
|
|
|
|
|
|
|
# Compute stop.
|
|
|
|
|
if slice.stop is None:
|
|
|
|
|
stop = lower if step < 0 else upper
|
|
|
|
|
else:
|
|
|
|
|
stop = slice.stop
|
|
|
|
|
stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
|
|
|
|
|
|
|
|
|
|
return start, stop, step
|
|
|
|
|
|
|
|
|
|
def _detectEllipsis(self, item):
|
|
|
|
|
has_ellipsis = False
|
|
|
|
|
start = 0
|
|
|
|
|
end = len(self.shape)
|
|
|
|
|
for index, o in enumerate(item):
|
|
|
|
|
if o is Ellipsis:
|
|
|
|
|
if has_ellipsis:
|
|
|
|
|
raise ValueError("Index can have one ellipsis only.")
|
|
|
|
|
has_ellipsis = True
|
|
|
|
|
start = index
|
|
|
|
|
else:
|
|
|
|
|
if has_ellipsis:
|
|
|
|
|
end = index
|
|
|
|
|
return has_ellipsis, start, end
|
|
|
|
|
|
|
|
|
|
def _reconstructSliceinfo(self, item):
|
|
|
|
|
has_ellipsis, start, end = self._detectEllipsis(item)
|
|
|
|
|
if has_ellipsis:
|
|
|
|
|
newitem = []
|
|
|
|
|
for i in range(start):
|
|
|
|
|
newitem.append(item[i])
|
|
|
|
|
for i in range(start, end):
|
|
|
|
|
newitem.append(slice(None, None, None))
|
|
|
|
|
for i in range(end, len(item)):
|
|
|
|
|
newitem.append(item[i])
|
|
|
|
|
return newitem
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _detectContinuesSlice(self, item):
|
|
|
|
|
starts = []
|
|
|
|
|
ends = []
|
|
|
|
|
for index, o in enumerate(item):
|
|
|
|
|
if isinstance(o, int):
|
|
|
|
|
start = int(o)
|
|
|
|
|
if (index > 0 and index >= self.shape[index]) \
|
|
|
|
|
or (index < 0 and (index + self.shape[index]) < 0):
|
|
|
|
|
raise IndexError("invalid index")
|
|
|
|
|
start = max(start + self.shape[index], 0) if start < 0 else min(
|
|
|
|
|
start, self.shape[index])
|
|
|
|
|
starts.append(start)
|
|
|
|
|
ends.append(start + 1)
|
|
|
|
|
elif isinstance(o, slice):
|
|
|
|
|
start, stop, step = self._slice_indices(o, self.shape[index])
|
|
|
|
|
if step == 1 or step == -1:
|
|
|
|
|
starts.append(start)
|
|
|
|
|
ends.append(stop)
|
|
|
|
|
else:
|
|
|
|
|
return False, None
|
|
|
|
|
else:
|
|
|
|
|
raise IndexError("Valid index accept int or slice or ellipsis")
|
|
|
|
|
return True, [starts, ends]
|
|
|
|
|
|
|
|
|
|
def _cloneVar(self, copy=False):
|
|
|
|
|
if not copy:
|
|
|
|
|
return self.block.create_var(
|
|
|
|
|
name=unique_name.generate(".".join(self.name)),
|
|
|
|
|
dtype=self.dtype,
|
|
|
|
|
persistable=self.persistable,
|
|
|
|
|
stop_gradient=self._stop_gradient, )
|
|
|
|
|
else:
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def _sliceVar(self, axes, starts, ends):
|
|
|
|
|
new_var = self._cloneVar()
|
|
|
|
|
self.block.append_op(
|
|
|
|
|
type="slice",
|
|
|
|
|
inputs={'Input': [self]},
|
|
|
|
|
outputs={'Out': [new_var]},
|
|
|
|
|
attrs={'axes': axes,
|
|
|
|
|
'starts': starts,
|
|
|
|
|
'ends': ends})
|
|
|
|
|
return new_var
|
|
|
|
|
|
|
|
|
|
def _concatVar(self, inputs, axis):
|
|
|
|
|
new_var = self._cloneVar()
|
|
|
|
|
self.block.append_op(
|
|
|
|
|
type="concat",
|
|
|
|
|
inputs={'X': inputs},
|
|
|
|
|
outputs={'Out': [new_var]},
|
|
|
|
|
attrs={'axis': axis, })
|
|
|
|
|
return new_var
|
|
|
|
|
|
|
|
|
|
def _sliceAndConcatVar(self, item, axis):
|
|
|
|
|
if isinstance(item, slice):
|
|
|
|
|
if self.shape[axis] < 0:
|
|
|
|
|
return self._cloneVar(True)
|
|
|
|
|
start, stop, step = self._slice_indices(item, self.shape[axis])
|
|
|
|
|
if step == 1:
|
|
|
|
|
return self._sliceVar([axis], [start], [stop])
|
|
|
|
|
else:
|
|
|
|
|
vars = []
|
|
|
|
|
if step > 0:
|
|
|
|
|
while start < stop:
|
|
|
|
|
vars.append(
|
|
|
|
|
self._sliceVar([axis], [start], [start + 1]))
|
|
|
|
|
start += step
|
|
|
|
|
else:
|
|
|
|
|
while start > stop:
|
|
|
|
|
vars.append(
|
|
|
|
|
self._sliceVar([axis], [start], [start + 1]))
|
|
|
|
|
start += step
|
|
|
|
|
return self._concatVar(vars, axis)
|
|
|
|
|
elif isinstance(item, int):
|
|
|
|
|
if self.shape[axis] < 0:
|
|
|
|
|
return self._cloneVar(True)
|
|
|
|
|
index = int(item)
|
|
|
|
|
if (index > 0 and index >= self.shape[axis])\
|
|
|
|
|
or (index < 0 and (index + self.shape[axis]) < 0):
|
|
|
|
|
raise IndexError("invalid index")
|
|
|
|
|
return self._sliceVar([axis], [index], [index + 1])
|
|
|
|
|
else:
|
|
|
|
|
raise IndexError("Valid index accept int or slice or tuple")
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, item):
|
|
|
|
|
"""
|
|
|
|
|
Slice the variable.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
item(int/slice/tuple) : the index.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Sliced variable
|
|
|
|
|
"""
|
|
|
|
|
new_var = None
|
|
|
|
|
if isinstance(item, tuple):
|
|
|
|
|
if len(item) > len(self.shape):
|
|
|
|
|
raise IndexError("Too many indexes")
|
|
|
|
|
newitem = self._reconstructSliceinfo(item) or item
|
|
|
|
|
check, info = self._detectContinuesSlice(newitem)
|
|
|
|
|
if check:
|
|
|
|
|
starts = info[0]
|
|
|
|
|
ends = info[1]
|
|
|
|
|
axes = [i for i in range(len(starts))]
|
|
|
|
|
return self._sliceVar(axes, starts, ends)
|
|
|
|
|
else:
|
|
|
|
|
new_var = self
|
|
|
|
|
for index, o in enumerate(newitem):
|
|
|
|
|
new_var = new_var._sliceAndConcatVar(o, index)
|
|
|
|
|
else:
|
|
|
|
|
new_var = self._sliceAndConcatVar(item, 0)
|
|
|
|
|
return new_var
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_all_op_protos():
|
|
|
|
|
"""
|
|
|
|
|