@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
+#include <algorithm>
 #include "paddle/fluid/operators/sgd_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
|
@@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
   }
 }
 
-template <typename T, int block_size>
+template <typename T>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
                                        const T* learning_rate, T* tensor_out,
-                                       int64_t row_numel) {
-  const int ty = blockIdx.y;
-  int tid = threadIdx.x;
-
-  selected_rows += ty * row_numel;
-  tensor_out += rows[ty] * row_numel;
-
-  for (int index = tid; index < row_numel; index += block_size) {
-    // Since index in rows of SelectedRows can be duplicate, we have to use
-    // Atomic Operation to avoid concurrent write error.
-    paddle::platform::CudaAtomicAdd(
-        tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
+                                       int64_t row_numel, int64_t limit) {
+  for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
+    const T* selected_rows_ptr = selected_rows + i * row_numel;
+    T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
+    for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
+      // Since index in rows of SelectedRows can be duplicate, we have to use
+      // Atomic Operation to avoid concurrent write error.
+      paddle::platform::CudaAtomicAdd(
+          tensor_out_ptr + index,
+          -1.0 * learning_rate[0] * selected_rows_ptr[index]);
+    }
   }
 }
 }  // namespace
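Note on the hunk above: the rewritten SparseSGDFunctorKernel walks the selected rows with a grid-stride loop, and because `rows` can list the same output row more than once, different blocks may update the same parameter row, hence the CudaAtomicAdd. For reference, a minimal serial sketch of the same update; the function name and plain C++ form are illustrative only and not part of the patch:

#include <cstdint>

// Serial reference for the sparse SGD update performed by the kernel:
// each selected row i scatters -lr * grad_row(i) into parameter row rows[i].
// Duplicate entries in `rows` accumulate into the same output row, which is
// what forces the CUDA version to use atomic adds.
void SparseSGDReference(const float* selected_rows, const int64_t* rows,
                        const float* learning_rate, float* tensor_out,
                        int64_t row_numel, int64_t limit) {
  for (int64_t i = 0; i < limit; ++i) {
    const float* src = selected_rows + i * row_numel;
    float* dst = tensor_out + rows[i] * row_numel;
    for (int64_t j = 0; j < row_numel; ++j) {
      dst[j] += -1.0f * learning_rate[0] * src[j];
    }
  }
}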
@@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
     auto* in_data = in_value.data<T>();
     auto* out_data = param_out->data<T>();
 
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid(1, in_rows.size());
-    SparseSGDFunctorKernel<
-        T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+    const int kThreadsPerBlock = 256;
+    int thread_x = kThreadsPerBlock;
+    int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
+                             ctx.cuda_device_context().stream()>>>(
         in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
-        out_data, in_row_numel);
+        out_data, in_row_numel, in_rows.size());
 
   } else {
     PADDLE_THROW("Unsupported Variable Type of Grad");
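Note on the hunk above: instead of launching one block per selected row (dim3 grid(1, in_rows.size())), the grid is now sized from the device's physical thread capacity, and each block strides over the rows. A rough sketch of the arithmetic; the device numbers below are made up for illustration, and GetMaxPhysicalThreadCount() is only stubbed with a constant here:

#include <algorithm>
#include <cstdio>

int main() {
  const int kThreadsPerBlock = 256;
  // Stand-in for ctx.cuda_device_context().GetMaxPhysicalThreadCount(),
  // e.g. 30 SMs x 2048 resident threads each (illustrative numbers only).
  int max_threads = 30 * 2048;
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  std::printf("launch %d blocks of %d threads\n", max_blocks, kThreadsPerBlock);
  // Each block then processes rows blockIdx.x, blockIdx.x + gridDim.x, ...
  // (the grid-stride loop in the kernel), so the grid size no longer grows
  // with the number of selected rows.
  return 0;
}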