@@ -24,123 +24,128 @@ namespace operators {
using LoDTensor = framework::LoDTensor;

template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, T* out_data,
                                       const size_t* lod,
                                       const size_t* out_offset,
                                       size_t lod_size, size_t element_len,
                                       size_t x_size) {
  int bid_x = blockIdx.x;
  if (bid_x > lod_size) return;
  int repeats = lod[bid_x];
  int offset = out_offset[bid_x];
  for (int tid_y = threadIdx.y; tid_y < repeats; tid_y += blockDim.y) {
    for (int tid_x = threadIdx.x; tid_x < element_len; tid_x += blockDim.x) {
      out_data[(offset + tid_y) * element_len + tid_x] =
          x_data[bid_x * element_len + tid_x];
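// Sketch of the mapping in the kernel above (based on how SequenceExpandFunctor
// further down launches it): each block bid_x owns one input row; `lod` carries
// the per-row repeat count and `out_offset` the cumulative output row offset, so
// the 2-D thread block copies the row's element_len values `repeats` times,
// starting at output row out_offset[bid_x].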

__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                       const size_t* ref_lod,
                                       const size_t lod_size,
                                       /* default=1,
                                          the instance length*/
                                       const int x_item_length, T* out_data) {
  constexpr int N = 1024;
  __shared__ int mem[N];
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
}
  __syncthreads();

template <typename T>
__global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data,
                                            const size_t* lod,
                                            const size_t* out_offset,
                                            size_t lod_size, size_t element_len,
                                            size_t dout_size, size_t dx_size) {
  // reduce visit memory time.
  // dout_shm = [0 - dout_size-1], dx_shm = [dout_size-1, dout_size + dx_size-1]
  if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 &&
      threadIdx.y == 0) {
    printf("lod_size=%ld, element_size=%ld, dout_size=%ld, dx_size=%ld\n",
           lod_size, element_len, dout_size, dx_size);
  }
  extern __shared__ T shm[];
  T* dout_shm = shm;
  T* dx_shm = &shm[dout_size];

  // int idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (int idx = 0; idx < dout_size; ++idx) {
    if (idx < dx_size) {
      dx_shm[idx] = 0.0;
    }
    if (idx < dout_size) {
      dout_shm[idx] = dout_data[idx];
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;

  int x_item_count = x_lod[bid + 1] - x_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = mem[bid];
  int x_offset = x_lod[bid];
  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
                 tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
      }
    }
  }
}
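// A rough walk-through of the ref_lod-based kernel above, assuming
// x_lod = {0, 2, 5} and ref_lod = {0, 3, 4} (two sequences):
//   block 0: x_item_count = 2, repeats = 3, mem[0] = 0     -> writes output rows 0..5
//   block 1: x_item_count = 3, repeats = 1, mem[1] = 3*2=6 -> writes output rows 6..8
// Every block first rebuilds the same prefix-offset table mem[] in shared memory,
// then scatters its own sequence with the 3-D thread loops.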
  int bid_x = blockIdx.x;
  if (bid_x > lod_size) return;
  int repeats = lod[bid_x];
  int offset = out_offset[bid_x];
  if (threadIdx.x == 0) {
    printf("repeats=%d, offset=%ld\n", repeats, offset);
  }
  for (int tid_y = threadIdx.y; tid_y < repeats; tid_y += blockDim.y) {
    for (int tid_x = threadIdx.x; tid_x < element_len; tid_x += blockDim.x) {
      T val = dout_shm[(offset + tid_y) * element_len + tid_x];
      platform::CudaAtomicAdd(&dx_shm[bid_x * element_len + tid_x], val);
      int dx_idx = bid_x * element_len + tid_x;
      int dout_idx = (offset + tid_y) * element_len + tid_x;
      printf("dx_idx=%d, dout_idx=%d, dx_data=%f, dout_data=%f, val=%f \n",
             dx_idx, dout_idx, dx_shm[dx_idx], dout_shm[dout_idx], val);

template <typename T>
__global__ void sequence_expand_grad_kernel(const T* dout_data,
                                            const size_t* ref_lod,
                                            const size_t* dx_lod,
                                            const size_t lod_size,
                                            /* default=1,
                                               the instance length*/
                                            const int x_item_length,
                                            T* dx_data) {
  // TODO(dzhwinter) : too many atomicAdd
  // use shared memory to reduce memory visits
  constexpr int N = 1024;
  __shared__ int mem[N];
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
    }
  }
  __syncthreads();
  // copy shared memory back to dx
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < dx_size;
       idx += blockDim.x * gridDim.x) {
    dx_data[idx] = dx_shm[idx];
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;
  int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = mem[bid];
  int x_offset = dx_lod[bid];

  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        platform::CudaAtomicAdd(
            &dx_data[(x_offset + tid_y) * x_item_length + tid_x],
            dout_data[(out_offset + tid_z * x_item_count + tid_y) *
                          x_item_length +
                      tid_x]);
      }
    }
  }
}
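// The gradient kernel above inverts the forward mapping: every expanded row
// (out_offset + tid_z * x_item_count + tid_y) of dout is folded back into source
// row (x_offset + tid_y) of dx via CudaAtomicAdd, since `repeats` output rows map
// to one input row. This assumes dx_data was zero-initialized by the caller; the
// functor below only calls mutable_data<T>() before the launch.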

template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const LoDTensor& x, LoDTensor* out) {
    auto x_dims = x.dims();
    size_t element_len = framework::product(x_dims) / x_dims[0];
    auto lod = out->lod().back();
    framework::Vector<size_t> out_lod;
    for (size_t i = 0; i < lod.size() - 1; ++i) {
      out_lod.push_back(lod[i + 1] - lod[i]);
    }

    int thread_x = std::max(static_cast<int>(element_len), 32);
    int block_x = static_cast<int>(out_lod.size());
    dim3 block_size(thread_x, 1024 / thread_x);
  void operator()(
      const platform::CUDADeviceContext& context, const LoDTensor& x,
      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
      LoDTensor* out) {
    int x_item_length = 1;
    x_item_length = x.numel() / x.dims()[0];
    VLOG(0) << "x_item_length" << x_item_length;
    int thread_x = std::max(static_cast<int>(ref_lod.size()), 32);
    int thread_y = std::max(1024 / thread_x, 16);
    int thread_z = std::min(1024 / thread_x / thread_y, 16);
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(block_x, 1);

    sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        x.data<T>(), out->mutable_data<T>(context.GetPlace()),
        out_lod.CUDAData(context.GetPlace()), lod.CUDAData(context.GetPlace()),
        out_lod.size(), element_len, framework::product(x_dims));
        x.data<T>(), x_lod.CUDAData(context.GetPlace()),
        ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
        out->mutable_data<T>(context.GetPlace()));
  }
};
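// Launch geometry used by the functor above, as the kernel consumes it:
// gridDim.x = ref_lod.size() gives one block per sequence (the extra block is
// discarded by the `bid >= lod_size - 1` guard); threadIdx.x strides over
// x_item_length, threadIdx.y over the rows of the sequence, and threadIdx.z over
// the repeat count. The thread_x/thread_y/thread_z arithmetic appears intended
// to keep the block within the 1024-thread limit.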

template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const LoDTensor& x, const LoDTensor& out,
                  const LoDTensor& dout, LoDTensor* dx) {
    auto x_dims = x.dims();
    size_t element_len = framework::product(x_dims) / x_dims[0];
    auto lod = out.lod().back();
    framework::Vector<size_t> out_lod;
    for (size_t i = 0; i < lod.size() - 1; ++i) {
      out_lod.push_back(lod[i + 1] - lod[i]);
    }
    size_t dout_size = framework::product(dout.dims());
    size_t dx_size = framework::product(dx->dims());

    int thread_x = std::max(static_cast<int>(element_len), 32);
    dim3 block_size(thread_x, 1024 / thread_x);
    int block_x = static_cast<int>(out_lod.size());
                  const LoDTensor& dout,
                  const framework::Vector<size_t>& x_lod,   /*expand source lod*/
                  const framework::Vector<size_t>& ref_lod, /*expand based lod*/
                  LoDTensor* dx) {
    int x_item_length = 1;
    x_item_length = framework::product(dx->dims()) / dx->dims()[0];

    int thread_x = std::max(static_cast<int>(ref_lod.size()), 32);
    int thread_y = std::max(1024 / thread_x, 16);
    int thread_z = std::min(1024 / thread_x / thread_y, 16);
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(block_x, 1);
    sequence_expand_grad_kernel<<<grid_size, block_size,
                                  (dout_size + dx_size) * sizeof(T),
                                  context.stream()>>>(
        dout.data<T>(), dx->mutable_data<T>(context.GetPlace()),
        out_lod.CUDAData(context.GetPlace()), lod.CUDAData(context.GetPlace()),
        out_lod.size(), element_len, dout_size, dx_size);
    sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
        x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
        dx->mutable_data<T>(context.GetPlace()));
  }
};