|
|
|
@ -75,12 +75,12 @@ void ScatterUpdate(const platform::Place& place,
|
|
|
|
|
auto dst_dims = output->dims();
|
|
|
|
|
|
|
|
|
|
// check src shape and dst shape should match
|
|
|
|
|
for (size_t i = 1; i < src_dims.size(); i++)
|
|
|
|
|
for (int i = 1; i < src_dims.size(); i++)
|
|
|
|
|
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
|
|
|
|
|
|
|
|
|
|
// slice size
|
|
|
|
|
size_t slice_size = 1;
|
|
|
|
|
for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
|
|
|
|
|
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
|
|
|
|
|
|
|
|
|
|
if (platform::is_cpu_place(place)) {
|
|
|
|
|
CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
|
|
|
|
|