fix all bugs

revert-3824-remove_grad_op_type
zchen0211 8 years ago
parent 03d0040c59
commit 9430bc3207

@ -75,12 +75,12 @@ void ScatterUpdate(const platform::Place& place,
auto dst_dims = output->dims();
// check src shape and dst shape should match
for (size_t i = 1; i < src_dims.size(); i++)
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
// slice size
size_t slice_size = 1;
for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
if (platform::is_cpu_place(place)) {
CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);

Loading…
Cancel
Save