parent
184768e070
commit
88a8eedda1
@ -0,0 +1,84 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
using platform::Place;
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename T>
|
||||
__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
|
||||
size_t index_size, size_t slice_size) {
|
||||
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
|
||||
int indices_i = i / slice_size;
|
||||
int slice_i = i - indices_i * slice_size; // offset inside the slice
|
||||
int gather_i = indices[indices_i];
|
||||
int params_i = gather_i * slice_size + slice_i;
|
||||
*(output + i) = *(params + params_i);
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of GPU copy:
|
||||
template <typename T>
|
||||
struct GPUGather {
|
||||
void operator()(const T* src, const int* index, const int slice_size,
|
||||
const int index_size, T* output) {
|
||||
int block = 512;
|
||||
int n = slice_size * index_size;
|
||||
int grid = (n + block - 1) / block;
|
||||
GatherCUDAKernel<T><<<grid, block>>>(src, index, output, index_size,
|
||||
slice_size);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* A thin wrapper on gpu tensor
|
||||
* Return a new tensor from source tensor, gathered according to index
|
||||
* input[src]: type-T source Tensor
|
||||
* input[index]: type-int index Tensor (1-D)
|
||||
* return: output tensor
|
||||
*/
|
||||
template <typename T>
|
||||
void GPUTGather(const Place& place, const Tensor* src, const Tensor* index,
|
||||
Tensor* output) {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(place));
|
||||
// check index of shape 1-D
|
||||
PADDLE_ENFORCE(index->dims().size() == 1);
|
||||
int index_size = index->dims()[0];
|
||||
|
||||
auto src_dims = src->dims();
|
||||
framework::DDim output_dims(src_dims);
|
||||
output_dims[0] = index_size;
|
||||
|
||||
// slice size
|
||||
int slice_size = 1;
|
||||
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
|
||||
|
||||
// Gathering
|
||||
GPUGather<T> gather_functor;
|
||||
gather_functor(src->data<T>(), index->data<int>(), slice_size, index_size,
|
||||
output->data<T>());
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,70 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "gather.cu.h"
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/operators/gather_op.h"
|
||||
#include "scatter.cu.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// template <typename T>
|
||||
__global__ void print_arr(const float *params, const int N) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); }
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class GatherOpCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
auto *x = ctx.Input<Tensor>("X");
|
||||
auto *index = ctx.Input<Tensor>("Index");
|
||||
auto *output = ctx.Output<Tensor>("Out");
|
||||
|
||||
output->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
GPUTGather<T>(ctx.GetPlace(), x, index, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class GatherGradOpCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
LOG(INFO) << "Gather grad here";
|
||||
auto *Index = ctx.Input<Tensor>("Index");
|
||||
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto *x = ctx.Input<Tensor>("X");
|
||||
|
||||
dX->mutable_data<T>(ctx.GetPlace());
|
||||
auto dxt = framework::EigenVector<T>::Flatten(*dX);
|
||||
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
|
||||
dxt.device(place) = dxt.constant(static_cast<T>(0));
|
||||
|
||||
GPUTScatter<T>(ctx.GetPlace(), dO, Index, dX);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
|
||||
REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
|
@ -0,0 +1,86 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename T>
|
||||
__global__ void ScatterCUDAKernel(const T* params, const int* indices,
|
||||
T* output, size_t index_size,
|
||||
size_t slice_size) {
|
||||
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
|
||||
int indices_i = i / slice_size;
|
||||
int slice_i = i - indices_i * slice_size; // offset inside the slice
|
||||
int scatter_i = indices[indices_i];
|
||||
int out_i = scatter_i * slice_size + slice_i;
|
||||
*(output + out_i) = *(params + i);
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of GPU copy:
|
||||
template <typename T>
|
||||
struct GPUScatterAssign {
|
||||
void operator()(const T* src, const int* index, const int slice_size,
|
||||
const int index_size, T* output) {
|
||||
int block = 512;
|
||||
int n = slice_size * index_size;
|
||||
int grid = (n + block - 1) / block;
|
||||
// printf("grid, block: %d %d\n", grid, block);
|
||||
ScatterCUDAKernel<T><<<grid, block>>>(src, index, output, index_size,
|
||||
slice_size);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* A thin wrapper on gpu tensor
|
||||
* Return a new updated tensor from source tensor, scatter-assigned according to
|
||||
* index
|
||||
* input[src]: type-T source Tensor
|
||||
* input[index]: type-int index Tensor (1-D)
|
||||
* return: output tensor
|
||||
*/
|
||||
template <typename T>
|
||||
void GPUTScatter(const platform::Place& place,
|
||||
const paddle::framework::Tensor* src,
|
||||
const paddle::framework::Tensor* index,
|
||||
paddle::framework::Tensor* output) {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(place));
|
||||
// check index of shape 1-D
|
||||
PADDLE_ENFORCE(index->dims().size() == 1);
|
||||
int index_size = index->dims()[0];
|
||||
|
||||
auto src_dims = src->dims();
|
||||
framework::DDim output_dims(src_dims);
|
||||
output_dims[0] = index_size;
|
||||
|
||||
// slice size
|
||||
int slice_size = 1;
|
||||
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
|
||||
|
||||
// Scatter Assign
|
||||
GPUScatterAssign<T> scatter_functor;
|
||||
scatter_functor(src->data<T>(), index->data<int>(), slice_size, index_size,
|
||||
output->data<T>());
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,63 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "gather.cu.h"
|
||||
#include "paddle/operators/gather_op.h"
|
||||
#include "scatter.cu.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
class ScatterOpCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
auto *Ref = ctx.Input<Tensor>("Ref");
|
||||
auto *Index = ctx.Input<Tensor>("Index");
|
||||
auto *Updates = ctx.Input<Tensor>("Updates");
|
||||
auto *Out = ctx.Output<Tensor>("Out");
|
||||
|
||||
Out->ShareDataWith<T>(*Ref);
|
||||
|
||||
GPUTScatter<T>(ctx.GetPlace(), Updates, Index, Out);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ScatterGradOpCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
|
||||
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
|
||||
auto *Index = ctx.Input<Tensor>("Index");
|
||||
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
|
||||
// In place gradient: dRef = dO
|
||||
dRef->ShareDataWith<T>(*dOut);
|
||||
dUpdates->mutable_data<T>(ctx.GetPlace());
|
||||
// Gradient by Gather: dUpdates = dO[Index]
|
||||
GPUTGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>);
|
||||
REGISTER_OP_GPU_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>);
|
Loading…
Reference in new issue