revert-12469-sum_op_dim_fix
tangwei12 7 years ago
parent 60dda7bf9f
commit 470fb7c5c3

@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
static_cast<T>(context.Attr<float>("min")),
static_cast<T>(context.Attr<float>("max")));
std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {

@ -39,7 +39,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SamplingIdKernel : public framework::OpKernel<T> {
class SamplingIdGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("X");
@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
paddle::operators::SamplingIdKernel<double>);
REGISTER_OP_CUDA_KERNEL(sampling_id,
paddle::operators::SamplingIdGPUKernel<float>,
paddle::operators::SamplingIdGPUKernel<double>);

Loading…
Cancel
Save