add enforce

revert-12469-sum_op_dim_fix
tangwei12 7 years ago
parent baa6273c54
commit 9f09d68678

@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]);
PADDLE_ENFORCE_GE(batch_size, 0,
"batch_size(dims[0]) must be nonnegative.");
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);

@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]);
PADDLE_ENFORCE_GE(batch_size, 0,
"batch_size(dims[0]) must be nonnegative.");
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
UniformGenerator<T> gen = UniformGenerator<T>(min, max, seed);
std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
T r = UniformGenerator<T>(min, max, seed);
T r = gen(0);
int idx = width - 1;
for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {

Loading…
Cancel
Save