|
|
|
@ -67,7 +67,7 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
x += offset_i * x_stride;
|
|
|
|
|
for (size_t j = 0; j < x_dim_i; ++j) {
|
|
|
|
|
for (size_t j = 0; j < out_dim_i; ++j) {
|
|
|
|
|
StridedMemcpy<T>(x, x_dims, out, out_dims, i + 1, rank, x_stride,
|
|
|
|
|
out_stride, offsets);
|
|
|
|
|
x += x_stride;
|
|
|
|
@ -86,8 +86,6 @@ struct RandomCropFunctor {
|
|
|
|
|
int rank_;
|
|
|
|
|
int64_t seed_;
|
|
|
|
|
|
|
|
|
|
size_t prod_x_dims_;
|
|
|
|
|
size_t prod_out_dims_;
|
|
|
|
|
size_t prod_batchsize_dims_;
|
|
|
|
|
size_t prod_x_ins_dims_;
|
|
|
|
|
size_t prod_out_ins_dims_;
|
|
|
|
@ -118,8 +116,6 @@ struct RandomCropFunctor {
|
|
|
|
|
prod_out_ins_dims_ *= out_dim_i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
prod_x_dims_ = prod_batchsize_dims_ * prod_x_ins_dims_;
|
|
|
|
|
prod_out_dims_ = prod_batchsize_dims_ * prod_out_ins_dims_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(size_t ins_idx) {
|
|
|
|
@ -146,7 +142,17 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class RandomCropKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
virtual void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
int64_t seed = *ctx.Input<framework::LoDTensor>("Seed")->data<int64_t>();
|
|
|
|
|
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
|
|
|
|
|
int64_t seed = 0;
|
|
|
|
|
if (platform::is_cpu_place(seed_tensor.place())) {
|
|
|
|
|
seed = *seed_tensor.data<int64_t>();
|
|
|
|
|
} else {
|
|
|
|
|
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
|
|
|
|
|
"your program";
|
|
|
|
|
framework::LoDTensor cpu_seed;
|
|
|
|
|
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
|
|
|
|
|
seed = *cpu_seed.data<int64_t>();
|
|
|
|
|
}
|
|
|
|
|
auto shape = ctx.Attr<std::vector<int>>("shape");
|
|
|
|
|
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
|
|
|
|
|
auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));
|
|
|
|
|