@@ -20,6 +20,19 @@ namespace paddle {
 namespace operators {
 
 namespace {
+
+template <typename T>
+__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
+                          const int num, T* p_out) {
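+  // LearningRate is read as a single-element tensor: every thread loads the
+  // same scalar from learning_rate[0].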
+  T lr = learning_rate[0];
+  int grid_size = blockDim.x * gridDim.x;
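+  // Grid-stride loop: step by the total launched thread count so any grid
+  // size covers all num elements.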
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
+    T g_data = g[i];
+    T p_data = p[i];
+    p_out[i] = p_data - lr * g_data;
+  }
+}
+
 template <typename T, int block_size>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
@@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
 }  // namespace
 
 template <typename T>
-struct SparseSGDFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input,
-                  const framework::Tensor& learning_rate,
-                  framework::Tensor* output) {
-    auto in_height = input.height();
-    auto out_dims = output->dims();
-    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
-
-    auto& in_value = input.value();
-    auto& in_rows = input.rows();
-
-    int64_t in_row_numel = in_value.numel() / in_rows.size();
-    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
-
-    auto* in_data = in_value.data<T>();
-    auto* out_data = output->data<T>();
-
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid(1, in_rows.size());
-    SparseSGDFunctorKernel<T, 256><<<grid, threads, 0, context.stream()>>>(
-        in_data, in_rows.data(), learning_rate.data<T>(), out_data,
-        in_row_numel);
+class SGDOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* param = ctx.Input<framework::Tensor>("Param");
+    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
+    auto* grad_var = ctx.InputVar("Grad");
+    // Actually, all tensors are LoDTensor except SelectedRows.
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      param_out->mutable_data<T>(ctx.GetPlace());
+      auto* grad = ctx.Input<framework::Tensor>("Grad");
+      auto* grad_data = grad->data<T>();
+      auto* param_data = param->data<T>();
+      auto* param_out_data = param_out->data<T>();
+
+      int block = 512;
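+      // Ceiling division: round the grid up so 512-thread blocks cover every
+      // element of Param.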
+      int grid = (param->numel() + block - 1) / block;
+
+      SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+          grad_data, param_data, learning_rate->data<T>(), param->numel(),
+          param_out_data);
+
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
+      // This manual optimization brings difficulty to track data dependency.
+      // It's better to find a more elegant solution.
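+      // Param and ParamOut must be the same tensor: the scatter kernel below
+      // writes the update directly into the parameter.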
+      PADDLE_ENFORCE_EQ(param, param_out);
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+
+      auto in_height = grad->height();
+      auto out_dims = param_out->dims();
+      PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+      auto& in_value = grad->value();
+      auto& in_rows = grad->rows();
+
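+      // in_row_numel is the width of one SelectedRows entry; it must match
+      // the dense row width of ParamOut.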
+      int64_t in_row_numel = in_value.numel() / in_rows.size();
+      PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
+
+      auto* in_data = in_value.data<T>();
+      auto* out_data = param_out->data<T>();
+
+      const int block_size = 256;
+      dim3 threads(block_size, 1);
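+      // Launch one 256-thread block per sparse row, laid out along the
+      // grid's y dimension.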
+      dim3 grid(1, in_rows.size());
+      SparseSGDFunctorKernel<
+          T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          in_data, in_rows.data(), learning_rate->data<T>(), out_data,
+          in_row_numel);
+
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
 
-template struct SparseSGDFunctor<platform::CUDADeviceContext, float>;
-template struct SparseSGDFunctor<platform::CUDADeviceContext, double>;
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
+                        ops::SGDOpCUDAKernel<double>);