@@ -73,15 +73,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
 template <typename T, typename IndexT = int>
 void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                    const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.GetPlace()), true,
+      platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
   // check index of shape 1-D
   if (index.dims().size() == 2) {
     PADDLE_ENFORCE_EQ(index.dims()[1], 1,
-                      "index.dims()[1] should be 1 when index.dims().size() == "
-                      "2 in scatter_op.");
+                      platform::errors::InvalidArgument(
+                          "index.dims()[1] should be 1 when "
+                          "index.dims().size() == 2 in scatter_op."
+                          " But received value is [%d]",
+                          index.dims()[1]));
   } else {
     PADDLE_ENFORCE_EQ(index.dims().size(), 1,
-                      "index.dims().size() should be 1 or 2 in scatter_op.");
+                      platform::errors::InvalidArgument(
+                          "index.dims().size() should be 1 or 2 in scatter_op."
+                          " But received value is [%d]",
+                          index.dims().size()));
   }
   int index_size = index.dims()[0];
 
@@ -94,7 +102,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
 
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
-    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
+    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
+                      platform::errors::InvalidArgument(
+                          "src shape and dst shape should match"));
 
   // slice size
   size_t slice_size = 1;
@@ -111,12 +121,14 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
 template <typename T, typename IndexT = int>
 void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
                       const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
-                    true);
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.device_context().GetPlace()), true,
+      platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1 ||
-                     (index.dims().size() == 2 && index.dims()[1] == 1),
-                 "");
+  PADDLE_ENFORCE_EQ(
+      index.dims().size() == 1 ||
+          (index.dims().size() == 2 && index.dims()[1] == 1),
+      true, platform::errors::InvalidArgument("index's shape is error."));
   int index_size = index.dims()[0];
 
   auto src_dims = src.dims();
@@ -130,7 +142,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
 
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
-    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
+    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
+                      platform::errors::InvalidArgument(
+                          "src shape and dst shape should match"));
 
   // slice size
   size_t slice_size = 1;
@@ -156,8 +170,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
 template <typename T, typename IndexT = int>
 void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
                   const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
-                    true, "It should be running on the CPU");
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.device_context().GetPlace()), true,
+      platform::errors::PreconditionNotMet("It should be running on the CPU"));
 
   // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
   auto index_dims = index.dims();