|
|
|
@ -73,10 +73,16 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
|
|
|
|
|
template <typename T, typename IndexT = int>
|
|
|
|
|
void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
|
|
|
|
|
const Tensor& index, Tensor* output) {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
|
|
|
|
|
// check index of shape 1-D
|
|
|
|
|
PADDLE_ENFORCE(index.dims().size() == 1 ||
|
|
|
|
|
(index.dims().size() == 2 && index.dims()[1] == 1));
|
|
|
|
|
if (index.dims().size() == 2) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
|
|
|
|
|
"index.dims()[1] should be 1 when index.dims().size() == "
|
|
|
|
|
"2 in scatter_op.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(index.dims().size(), 1,
|
|
|
|
|
"index.dims().size() should be 1 or 2 in scatter_op.");
|
|
|
|
|
}
|
|
|
|
|
int index_size = index.dims()[0];
|
|
|
|
|
|
|
|
|
|
auto src_dims = src.dims();
|
|
|
|
@ -88,7 +94,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
|
|
|
|
|
|
|
|
|
|
// check src shape and dst shape should match
|
|
|
|
|
for (int i = 1; i < src_dims.size(); i++)
|
|
|
|
|
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
|
|
|
|
|
|
|
|
|
|
// slice size
|
|
|
|
|
size_t slice_size = 1;
|
|
|
|
@ -105,10 +111,12 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
|
|
|
|
|
template <typename T, typename IndexT = int>
|
|
|
|
|
void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
|
|
|
|
|
const Tensor& index, Tensor* output) {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.device_context().GetPlace()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
|
|
|
|
|
true);
|
|
|
|
|
// check index of shape 1-D
|
|
|
|
|
PADDLE_ENFORCE(index.dims().size() == 1 ||
|
|
|
|
|
(index.dims().size() == 2 && index.dims()[1] == 1));
|
|
|
|
|
(index.dims().size() == 2 && index.dims()[1] == 1),
|
|
|
|
|
"");
|
|
|
|
|
int index_size = index.dims()[0];
|
|
|
|
|
|
|
|
|
|
auto src_dims = src.dims();
|
|
|
|
@ -122,7 +130,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
|
|
|
|
|
|
|
|
|
|
// check src shape and dst shape should match
|
|
|
|
|
for (int i = 1; i < src_dims.size(); i++)
|
|
|
|
|
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
|
|
|
|
|
|
|
|
|
|
// slice size
|
|
|
|
|
size_t slice_size = 1;
|
|
|
|
|