Correct CPU gradients of the argsort op (#22739)

* Correct CPU gradients of the argsort op, form a network to test its forward and backward process, test=develop

* fix dynamic threshold error in test_argsort_op, test=develop
revert-22710-feature/integrated_ps_api
FlyingQianMM 5 years ago committed by GitHub
parent 2b80e9a719
commit 79d712346f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -81,13 +81,13 @@ static void FullAssign(Type input_height, Type input_width, int input_dim,
auto e_input = EigenVector<T>::Flatten(*input);
auto e_indices = EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + e_indices(j)] = e_input(e_indices(j));
t_out[i * input_width + e_indices(j)] = e_input(j);
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + e_indices(i, j)] = e_input(i, e_indices(i, j));
t_out[i * input_width + e_indices(i, j)] = e_input(i, j);
}
}
}

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save