@@ -106,10 +106,10 @@ inline void CheckAndUpdateSlice(const framework::DDim in_dims,
 }
 
 inline framework::DDim GetSliceDims(const framework::DDim in_dims,
-                                    const std::vector<int64_t> axes,
-                                    const std::vector<int64_t> starts,
-                                    const std::vector<int64_t> ends,
-                                    const std::vector<int64_t> steps) {
+                                    const std::vector<int64_t>& axes,
+                                    const std::vector<int64_t>& starts,
+                                    const std::vector<int64_t>& ends,
+                                    const std::vector<int64_t>& steps) {
   framework::DDim slice_dims(in_dims);
 
   for (size_t i = 0; i < axes.size(); ++i) {
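
Note: this hunk elides the body of GetSliceDims. For orientation, a minimal standalone sketch of the per-axis extent arithmetic such a helper typically performs (the name SliceExtent and the exact rounding are assumptions, not taken from this diff):

#include <cstdint>

// Hypothetical helper: number of elements selected on one axis by
// [start, end) with the given step; assumes start/end were already
// clamped by CheckAndUpdateSlice.
inline int64_t SliceExtent(int64_t start, int64_t end, int64_t step) {
  if (step > 0) {
    return (end - start + step - 1) / step;  // ceil((end - start) / step)
  }
  return (end - start + step + 1) / step;  // negative step walks backwards
}

// SliceExtent(0, 4, 2)   -> 2  (indices 0, 2)
// SliceExtent(3, -1, -1) -> 4  (indices 3, 2, 1, 0)
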
@@ -127,6 +127,38 @@ inline framework::DDim GetSliceDims(const framework::DDim in_dims,
   return slice_dims;
 }
 
+inline framework::DDim GetDecreasedDims(
+    const framework::DDim slice_dims,
+    const std::vector<int64_t>& decrease_axes) {
+  // Get dims after decreasing axes.
+  framework::DDim decreased_dims(slice_dims);
+  if (decrease_axes.size() > 0) {
+    for (size_t i = 0; i < decrease_axes.size(); ++i) {
+      int64_t axis = decrease_axes[i];
+      PADDLE_ENFORCE_EQ(
+          decreased_dims[axis], 1,
+          platform::errors::InvalidArgument("decrease dim should be 1"));
+      decreased_dims[axis] = 0;
+    }
+
+    std::vector<int64_t> new_shape;
+    for (int i = 0; i < decreased_dims.size(); ++i) {
+      if (decreased_dims[i] != 0) {
+        new_shape.push_back(decreased_dims[i]);
+      }
+    }
+
+    // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
+    // uses [1] instead.
+    if (new_shape.size() == 0) {
+      new_shape.push_back(1);
+    }
+
+    decreased_dims = framework::make_ddim(new_shape);
+  }
+  return decreased_dims;
+}
+
 template <typename DeviceContext, typename T>
 class SetValueKernel : public framework::OpKernel<T> {
  public:
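
To make the new helper's behavior concrete: a self-contained mirror of GetDecreasedDims using std::vector<int64_t> in place of framework::DDim (the name DecreaseDims and the use of assert are illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for GetDecreasedDims above.
std::vector<int64_t> DecreaseDims(std::vector<int64_t> dims,
                                  const std::vector<int64_t>& decrease_axes) {
  if (decrease_axes.empty()) return dims;
  for (int64_t axis : decrease_axes) {
    assert(dims[axis] == 1);  // mirrors the PADDLE_ENFORCE_EQ check
    dims[axis] = 0;           // mark the axis for removal
  }
  std::vector<int64_t> out;
  for (int64_t d : dims) {
    if (d != 0) out.push_back(d);
  }
  if (out.empty()) out.push_back(1);  // rank 0 is represented as [1]
  return out;
}

// DecreaseDims({3, 1}, {1}) -> {3}
// DecreaseDims({1}, {0})    -> {1}  (fully decreased, kept as rank 1)
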
@@ -179,6 +211,7 @@ class SetValueKernel : public framework::OpKernel<T> {
     auto ends = ctx.Attr<std::vector<int64_t>>("ends");
     auto steps = ctx.Attr<std::vector<int64_t>>("steps");
     auto shape = ctx.Attr<std::vector<int64_t>>("shape");
+    auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");
 
     auto dtype = in->type();
     if (!starts_tensor_list.empty()) {
@@ -194,6 +227,7 @@ class SetValueKernel : public framework::OpKernel<T> {
     auto in_dims = in->dims();
     CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps);
     auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps);
+    auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
 
     auto place = ctx.GetPlace();
     auto& eigen_place =
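
For orientation, the example used in the NOTE further down: the Python assignment x[:, 0] = value with x of shape [3, 4] is assumed to reach this kernel roughly as the attribute values below, which the two helpers then reduce:

// Assumed attribute lowering of `x[:, 0] = value`, x.shape = [3, 4]:
//   axes          = {1}
//   starts        = {0}
//   ends          = {1}
//   steps         = {1}
//   decrease_axes = {1}
// GetSliceDims     -> slice_dims          = [3, 1]
// GetDecreasedDims -> decrease_slice_dims = [3]
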
@@ -212,13 +246,13 @@ class SetValueKernel : public framework::OpKernel<T> {
     // set_value is what we want.
     TensorCopy(*in, place, out);
 
-    Tensor slice_t(dtype), pad_t(dtype);
-    slice_t.mutable_data<T>(slice_dims, place);
-    pad_t.mutable_data<T>(in_dims, place);
+    Tensor slice_tensor(dtype), pad_tensor(dtype);
+    slice_tensor.mutable_data<T>(slice_dims, place);
+    pad_tensor.mutable_data<T>(in_dims, place);
 
-    auto pad_e = framework::EigenTensor<T, D>::From(pad_t, in_dims);
+    auto pad_e = framework::EigenTensor<T, D>::From(pad_tensor, in_dims);
     auto out_e = framework::EigenTensor<T, D>::From(*out);
-    auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);
+    auto slice_e = framework::EigenTensor<T, D>::From(slice_tensor, slice_dims);
 
     // Step 1: Set the value of out at `_index` to zero
     slice_e.device(eigen_place) = slice_e.constant(T(0));
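
The lines elided after Step 1 presumably write the zeroed slice back into out through a strided Eigen view; a hedged, self-contained sketch of that pattern (this shows the generic Eigen stridedSlice API, not necessarily the exact call in the omitted code):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // Stand-ins for out_e / slice_e with D = 2.
  Eigen::Tensor<float, 2> out(3, 4);
  out.setConstant(1.f);
  Eigen::Tensor<float, 2> slice(3, 1);
  slice.setZero();

  // Assign into the strided sub-region out[0:3:1, 0:1:1].
  Eigen::DSizes<Eigen::DenseIndex, 2> starts(0, 0), ends(3, 1), strides(1, 1);
  out.stridedSlice(starts, ends, strides) = slice;
  return 0;
}
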
@@ -244,11 +278,26 @@ class SetValueKernel : public framework::OpKernel<T> {
     // Step 2: Set a tensor with the same shape as out tensor. And its data at
     // '_index' is the same as value_tensor, and data out of '_index' to zero
 
     // - Step 2.1 Set slice tensor with value
+
+    // NOTE(liym27): [ Why resize slice_tensor here? ]
+    // A: When do broadcasting on slice_tensor and value_tensor, the shape of
+    // slice_tensor should be decreased dims.
+    // e.g.
+    //  x[:,0] = value_tensor
+    // x's shape = [3, 4], value_tensor's shape = [3]
+    // We get slice_dims = [3, 1], decrease_slice_dims = [3]
+    // If do broadcasting on Tensor with shape [3, 1] and [3], the result's
+    // shape is [3, 3], which cross the border;
+    // If do broadcasting on Tensor with shape [3] and [3], the result's shape
+    // is [3], which is right.
+
+    slice_tensor.Resize(decrease_slice_dims);
     if (value_tensor != nullptr) {
       // ElementwiseComputeEx can do broadcasting
       ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
+          ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
     } else {
       Tensor value_t(dtype);
       auto value_dims = framework::make_ddim(shape);
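
The NOTE's broadcasting argument, made concrete with a small numpy-style shape-broadcast function (illustrative code, not Paddle's API):

#include <cstdint>
#include <vector>

// Numpy-style broadcast of two shapes; assumes the shapes are compatible.
std::vector<int64_t> BroadcastShape(std::vector<int64_t> a,
                                    std::vector<int64_t> b) {
  if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
  if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = (a[i] == 1) ? b[i] : a[i];
  }
  return out;
}

// BroadcastShape({3, 1}, {3}) -> {3, 3}  // "cross the border": too large
// BroadcastShape({3}, {3})    -> {3}     // right, hence the Resize above
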
@@ -257,8 +306,9 @@ class SetValueKernel : public framework::OpKernel<T> {
       CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
       value_t.Resize(value_dims);
       ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
+          ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
     }
+    slice_tensor.Resize(slice_dims);
 
     // - Step 2.2 Pad slice tensor with 0
     pad_e.device(eigen_place) = pad_e.constant(T(0));