From 069d0004dc1334987e40d151b6e8521ebc80f661 Mon Sep 17 00:00:00 2001
From: Haonan
Date: Fri, 4 Nov 2016 11:46:07 -0700
Subject: [PATCH 1/3] multi_binary_cross_entropy when ids vector is provided

---
 paddle/cuda/include/hl_matrix.h           | 30 +++++++++
 paddle/cuda/include/stub/hl_matrix_stub.h | 12 ++++
 paddle/cuda/src/hl_cuda_matrix.cu         | 78 +++++++++++++++++++++++
 paddle/gserver/layers/CostLayer.cpp       |  4 ++
 paddle/gserver/tests/test_LayerGrad.cpp   |  7 +-
 paddle/math/Matrix.cpp                    | 36 +++++++++++
 paddle/math/Matrix.h                      |  4 ++
 paddle/math/tests/test_matrixCompare.cpp  | 66 ++++++++++++++++++-
 paddle/parameter/Argument.cpp             | 22 +++++++
 paddle/parameter/Argument.h               |  8 +++
 10 files changed, 263 insertions(+), 4 deletions(-)

diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 71e8f8e3a6..6195e30b99 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -126,6 +126,36 @@ extern void hl_matrix_cross_entropy_bp(real* grad_d,
                                        int dimM,
                                        int dimN);
 
+/**
+ * @brief Matrix multi-binary label cross entropy
+ *
+ * @param[in]  output  input matrix (M x N).
+ * @param[out] entropy output matrix (M x 1).
+ * @param[in]  mat     input sparse matrix.
+ * @param[in]  dimM    matrix height.
+ * @param[in]  dimN    matrix width.
+ */
+extern void hl_matrix_multi_binary_cross_entropy(real* output,
+                                                 real* entropy,
+                                                 hl_sparse_matrix_s mat,
+                                                 int dimM,
+                                                 int dimN);
+
+/**
+ * @brief Matrix multi-binary label cross entropy backprop
+ *
+ * @param[in]  output  input matrix (M x N).
+ * @param[out] grad    output matrix (M x N).
+ * @param[in]  mat     input sparse matrix.
+ * @param[in]  dimM    matrix height.
+ * @param[in]  dimN    matrix width.
+ */
+extern void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                                    real* grad,
+                                                    hl_sparse_matrix_s mat,
+                                                    int dimM,
+                                                    int dimN);
+
 /**
  * @brief Matrix zero memory.
  *
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index e37b127543..76cac2e577 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -57,6 +57,18 @@ inline void hl_matrix_cross_entropy_bp(real* grad_d,
                                        int dimM,
                                        int dimN) {}
 
+inline void hl_matrix_multi_binary_cross_entropy(real* output,
+                                                 real* entropy,
+                                                 hl_sparse_matrix_s mat,
+                                                 int dimM,
+                                                 int dimN) {}
+
+inline void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                                    real* grad,
+                                                    hl_sparse_matrix_s mat,
+                                                    int dimM,
+                                                    int dimN) {}
+
 inline void hl_matrix_zero_mem(real* data, int num) {}
 
 inline void hl_param_relu_forward(real* output,
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 3df9f63f9e..001b62a6b9 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "hl_matrix_ops.cuh"
 #include "hl_matrix_apply.cuh"
 #include "hl_sequence.h"
+#include "hl_sparse.ph"
 #include "paddle/utils/Logging.h"
 #include "hl_device_functions.cuh"
 #include "hl_gpu_matrix_kernel.cuh"
@@ -317,6 +318,83 @@ void hl_matrix_classification_error(real* A_d,
   CHECK_SYNC("hl_matrix_classification_error");
 }
 
+__global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
+                                                real* entropy,
+                                                int* row,
+                                                int* col,
+                                                int dimM,
+                                                int dimN) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < dimM) {
+    for (int i = 0; i < dimN; i ++) {
+      entropy[index] -= log(1 - output[index * dimN + i]);
+    }
+    int *row_col = col + row[index];
+    int col_num = row[index + 1] - row[index];
+    for (int i = 0; i < col_num; i ++) {
+      real o = output[index * dimN + row_col[i]];
+      entropy[index] -= log(o / (1 - o));
+    }
+  }
+}
+
+void hl_matrix_multi_binary_cross_entropy(real* output,
+                                          real* entropy,
+                                          hl_sparse_matrix_s csr_mat,
+                                          int dimM,
+                                          int dimN) {
+  CHECK_NOTNULL(output);
+  CHECK_NOTNULL(entropy);
+  CHECK_NOTNULL(csr_mat);
+  int n_threads = 1024;
+  int blocks = (dimM + n_threads - 1) / n_threads;
+  dim3 threads(n_threads);
+  dim3 grid(blocks);
+  hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
+  KeMatrixMultiBinaryCrossEntropy<<< grid, threads, 0, STREAM_DEFAULT >>>
+      (output, entropy, mat->csr_row, mat->csr_col, dimM, dimN);
+  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed");
+}
+
+__global__ void KeMatrixMultiBinaryCrossEntropyBp(real* output,
+                                                  real* grad,
+                                                  int* row,
+                                                  int* col,
+                                                  int dimM,
+                                                  int dimN) {
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (row_idx < dimM) {
+    for (int i = 0; i < dimN; i ++) {
+      int index = row_idx * dimN + i;
+      grad[index] += 1.0 / (1 - output[index]);
+    }
+    int col_num = row[row_idx + 1] - row[row_idx];
+    int *row_col = col + row[row_idx];
+    for (int i = 0; i < col_num; i ++) {
+      int index = row_idx * dimN + row_col[i];
+      grad[index] -= 1.0 / (output[index] * (1 - output[index]));
+    }
+  }
+}
+
+void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                             real* grad,
+                                             hl_sparse_matrix_s csr_mat,
+                                             int dimM,
+                                             int dimN) {
+  CHECK_NOTNULL(output);
+  CHECK_NOTNULL(grad);
+  CHECK_NOTNULL(csr_mat);
+  int n_threads = 1024;
+  int blocks = (dimM + n_threads - 1) / n_threads;
+  dim3 threads(n_threads);
+  dim3 grid(blocks);
+  hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
+  KeMatrixMultiBinaryCrossEntropyBp<<< grid, threads, 0, STREAM_DEFAULT >>>
+      (output, grad, mat->csr_row, mat->csr_col, dimM, dimN);
+  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed");
+}
+
 __global__ void KeMatrixCrossEntropy(real* O,
                                      real* E,
                                      int* label,
diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp
index 949788be49..c86e562d0e 100644
--- a/paddle/gserver/layers/CostLayer.cpp
+++ b/paddle/gserver/layers/CostLayer.cpp
@@ -462,6 +462,8 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
 
 void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
                                               Matrix& target) {
+  label.idsToSparseMatrix(output.getWidth(), useGpu_);
+
   if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
       dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
     target.multiBinaryLabelCrossEntropy(output, *label.value);
@@ -476,6 +478,8 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
 
 void MultiBinaryLabelCrossEntropy::backwardImp(
     Matrix& output, Argument& label, Matrix& outputG) {
+  label.idsToSparseMatrix(output.getWidth(), useGpu_);
+
   if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
       dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
     outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e7e07e9e69..f19c14f569 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -538,9 +538,10 @@ TEST(Layer, multi_binary_label) {
   config.layerConfig.add_inputs();
   config.layerConfig.add_inputs();
 
-  // Not support GPU now
-  testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
-                /* trans */ false, /* useGpu */ false);
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
 }
 
 TEST(Layer, multi_cross_with_selfnorm) {
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 950c3bb6cc..9acc600553 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1268,6 +1268,42 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
   }
 }
 
+void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
+  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
+  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
+  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == output_ptr->height_ && width_ == 1
+        && output_ptr->width_ == label_ptr->getWidth()
+        && output_ptr->height_ == label_ptr->getHeight())
+        << "Matrix dimensions are not equal";
+
+  real* output_d = output_ptr->data_;
+  real* entropy_d = data_;
+  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_matrix_multi_binary_cross_entropy(
+      output_d, entropy_d, mat_d, height_, output_ptr->width_);
+}
+
+void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) {
+  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
+  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
+  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == output_ptr->height_ && width_ == output_ptr->width_
+        && output_ptr->width_ == label_ptr->getWidth()
+        && output_ptr->height_ == label_ptr->getHeight())
+        << "Matrix dimensions are not equal";
+
+  real* output_d = output_ptr->data_;
+  real* grad_d = data_;
+  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_matrix_multi_binary_cross_entropy_bp(
+      output_d, grad_d, mat_d, height_, width_);
+}
+
 /**
  * CpuMatrix
  */
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 700be75902..6c3c4804d2 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -1303,6 +1303,10 @@ public:
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW);
+
+  void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
+
+  void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
 };
 
 class CpuMatrix : public Matrix {
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index b3ee4bc349..a41e21903f 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -2208,7 +2208,6 @@ void testCollectSharedBias(int numSamples, int dim, int channel) {
   MatrixCheckErr(*cpuBias, *check);
 }
 
-
 TEST(Matrix, sharedBias) {
   for (auto numSamples : {1, 100, 520}) {
     for (auto dim : {100 * 16, 100 * 32}) {
@@ -2222,6 +2221,71 @@ TEST(Matrix, sharedBias) {
   }
 }
 
+void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
+  MatrixPtr output = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuOutput = std::make_shared<GpuMatrix>(numSamples, dim);
+
+  MatrixPtr cpuEntropy = std::make_shared<CpuMatrix>(numSamples, 1);
+  MatrixPtr gpuEntropy = std::make_shared<GpuMatrix>(numSamples, 1);
+
+  MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
+
+  auto cpuRows = IVector::create(numSamples + 1, false);
+  auto cpuCols = IVector::create(numSamples, false);
+  auto gpuRows = IVector::create(numSamples + 1, true);
+  auto gpuCols = IVector::create(numSamples, true);
+  cpuRows->setElement(0, 0);
+  gpuRows->setElement(0, 0);
+  for (int i = 0; i < numSamples; i ++) {
+    int id = rand() % dim;  // NOLINT
+    cpuRows->setElement(i + 1, i + 1);
+    gpuRows->setElement(i + 1, i + 1);
+    cpuCols->setElement(i, id);
+    gpuCols->setElement(i, id);
+  }
+
+  MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
+      (nullptr, cpuRows->getData(), cpuCols->getData(),
+       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+  MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
+      (nullptr, gpuRows->getData(), gpuCols->getData(),
+       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+
+  output->randomizeUniform();
+  cpuOutput->zeroMem();
+  output->softmax(*cpuOutput);
+  gpuOutput->copyFrom(*cpuOutput);
+
+  cpuEntropy->zeroMem();
+  gpuEntropy->zeroMem();
+  cpuEntropy->multiBinaryLabelCrossEntropy(*cpuOutput, *cpuLabel);
+  gpuEntropy->multiBinaryLabelCrossEntropy(*gpuOutput, *gpuLabel);
+
+  MatrixPtr check1 = std::make_shared<CpuMatrix>(numSamples, 1);
+  check1->copyFrom(*gpuEntropy);
+  MatrixCheckErr(*cpuEntropy, *check1);
+
+  cpuGrad->zeroMem();
+  gpuGrad->zeroMem();
+  cpuGrad->multiBinaryLabelCrossEntropyBp(*cpuOutput, *cpuLabel);
+  gpuGrad->multiBinaryLabelCrossEntropyBp(*gpuOutput, *gpuLabel);
+
+  MatrixPtr check2 = std::make_shared<CpuMatrix>(numSamples, dim);
+  check2->copyFrom(*gpuGrad);
+  MatrixCheckErr(*cpuGrad, *check2);
+}
+
+TEST(Matrix, multiBinaryCrossEntropy) {
+  for (auto numSamples : {1, 100, 500}) {
+    for (auto dim : {1000, 10000, 100000}) {
+      VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
+      testMultiBinaryLabelCrossEntropy(numSamples, dim);
+    }
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index 42c74661d2..a5a96742e4 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -572,4 +572,26 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
   }
 }
 
+void Argument::idsToSparseMatrix(int width, bool useGpu) {
+  if (ids) {
+    CHECK(!value);
+    int height = ids->getSize();
+    int nnz = height;
+    auto rows = IVector::create(height + 1, useGpu);
+    auto cols = IVector::create(nnz, useGpu);
+    rows->setElement(0, 0);
+    for (int i = 0; i < height; i ++) {
+      int id = ids->getElement(i);
+      CHECK_LT(id, width);
+      rows->setElement(i + 1, i + 1);
+      cols->setElement(i, id);
+    }
+    value = Matrix::createSparseMatrix(
+        nullptr, rows->getData(), cols->getData(),
+        height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu);
+  } else {
+    CHECK(value);
+  }
+}
+
 }  // namespace paddle
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 81ff9029bc..48e1551258 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -286,6 +286,14 @@ struct Argument {
      sequence has sub-sequence degrades to a sequence.
    */
   void degradeSequence(const Argument& input, bool useGpu);
+
+  /*
+    @brief convert the ids vector to value as a sparse matrix;
+    the ids vector remains valid
+    @param width the matrix width (id range)
+    @param useGpu whether to build the matrix on GPU
+  */
+  void idsToSparseMatrix(int width, bool useGpu);
 };
 
 }  // namespace paddle

From 728defbec90a162ee5d7f8521106ded7797e72fa Mon Sep 17 00:00:00 2001
From: Haonan
Date: Fri, 11 Nov 2016 16:31:56 -0800
Subject: [PATCH 2/3] copy the data when createSparseMatrix

---
 paddle/gserver/layers/CostLayer.cpp | 40 ++++++++++++++++------
 paddle/parameter/Argument.cpp       | 52 +++++++++++++++++++----------
 paddle/parameter/Argument.h         |  6 ++--
 3 files changed, 66 insertions(+), 32 deletions(-)

diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp
index c86e562d0e..900981d1e7 100644
--- a/paddle/gserver/layers/CostLayer.cpp
+++ b/paddle/gserver/layers/CostLayer.cpp
@@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
 
 void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
                                               Matrix& target) {
-  label.idsToSparseMatrix(output.getWidth(), useGpu_);
+  MatrixPtr value = nullptr;
+  if (label.ids) {
+    CHECK(!label.value);
+    value = Matrix::createSparseMatrix(
+        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
+        NO_VALUE, SPARSE_CSR, false, useGpu_);
+    label.idsToSparseMatrix(value);
+  } else {
+    CHECK(label.value);
+    value = label.value;
+  }
 
-  if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
-      dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
-    target.multiBinaryLabelCrossEntropy(output, *label.value);
+  if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
+      dynamic_cast<GpuSparseMatrix*>(value.get())) {
+    target.multiBinaryLabelCrossEntropy(output, *value);
   } else {
     Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
                            false, useGpu_);
-    targetPerDim_->binaryLabelCrossEntropy(output, *label.value);
+    targetPerDim_->binaryLabelCrossEntropy(output, *value);
     targetPerDim_->rowSum(target);
   }
 }
 
 void MultiBinaryLabelCrossEntropy::backwardImp(
     Matrix& output, Argument& label, Matrix& outputG) {
-  label.idsToSparseMatrix(output.getWidth(), useGpu_);
+  MatrixPtr value = nullptr;
+  if (label.ids) {
+    CHECK(!value);
+    value = Matrix::createSparseMatrix(
+        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
+        NO_VALUE, SPARSE_CSR, false, useGpu_);
+    label.idsToSparseMatrix(value);
+  } else {
+    CHECK(label.value);
+    value = label.value;
+  }
 
-  if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
-      dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
-    outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);
+  if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
+      dynamic_cast<GpuSparseMatrix*>(value.get())) {
+    outputG.multiBinaryLabelCrossEntropyBp(output, *value);
   } else {
-    outputG.binaryLabelCrossEntropyBp(output, *label.value);
+    outputG.binaryLabelCrossEntropyBp(output, *value);
   }
 }
 
diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index a5a96742e4..354d0ead07 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -572,25 +572,41 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
   }
 }
 
-void Argument::idsToSparseMatrix(int width, bool useGpu) {
-  if (ids) {
-    CHECK(!value);
-    int height = ids->getSize();
-    int nnz = height;
-    auto rows = IVector::create(height + 1, useGpu);
-    auto cols = IVector::create(nnz, useGpu);
-    rows->setElement(0, 0);
-    for (int i = 0; i < height; i ++) {
-      int id = ids->getElement(i);
-      CHECK_LT(id, width);
-      rows->setElement(i + 1, i + 1);
-      cols->setElement(i, id);
-    }
-    value = Matrix::createSparseMatrix(
-        nullptr, rows->getData(), cols->getData(),
-        height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu);
+void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) {
+  int height = ids->getSize();
+  int width = sparse_mat->getWidth();
+
+  CpuIVector cpu_ids(height);
+  cpu_ids.copyFrom(*ids);
+  int *id_data = cpu_ids.getData();
+
+  int *rows = nullptr;
+  int *cols = nullptr;
+  if (sparse_mat->useGpu()) {
+    auto gpu_sparse_mat =
+        dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
+    rows = gpu_sparse_mat->rows_;
+    cols = gpu_sparse_mat->cols_;
   } else {
-    CHECK(value);
+    rows = sparse_mat->getRows();
+    cols = sparse_mat->getCols();
+  }
+
+  rows[0] = 0;
+  for (int i = 0; i < height; i ++) {
+    int id = id_data[i];
+    CHECK_LT(id, width);
+    rows[i + 1] = i + 1;
+    cols[i] = id;
+  }
+
+  if (sparse_mat->useGpu()) {
+    auto gpu_sparse_mat =
+        dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
+    hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(),
+                         nullptr, rows, cols,
+                         HPPL_STREAM_DEFAULT);
+    hl_stream_synchronize(HPPL_STREAM_DEFAULT);
   }
 }
 
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 48e1551258..695033138b 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -289,11 +289,9 @@ struct Argument {
 
   /*
     @brief convert the ids vector to value as a sparse matrix;
-    the ids vector remains valid
-    @param width the matrix width (id range)
-    @param useGpu whether to build the matrix on GPU
+    @param[out] sparse_mat the output sparse matrix (already allocated)
   */
-  void idsToSparseMatrix(int width, bool useGpu);
+  void idsToSparseMatrix(MatrixPtr sparse_mat);
 };
 
 }  // namespace paddle

From 5591292b7a0c9fa9224fcea033870b656e0a08bb Mon Sep 17 00:00:00 2001
From: Haonan
Date: Fri, 11 Nov 2016 21:31:32 -0800
Subject: [PATCH 3/3] modifications according to comments

---
 paddle/cuda/src/hl_cuda_matrix.cu        |  6 ++--
 paddle/gserver/layers/CostLayer.cpp      | 10 ++----
 paddle/gserver/tests/test_LayerGrad.cpp  | 18 +++++++++-
 paddle/math/CpuSparseMatrix.cpp          |  3 --
 paddle/math/Matrix.cpp                   | 42 ++++++++++++------------
 paddle/math/Vector.cpp                   | 26 +++++++++++++++
 paddle/math/Vector.h                     |  8 +++++
 paddle/math/tests/test_matrixCompare.cpp | 29 +++++-----------
 paddle/parameter/Argument.cpp            | 38 ---------------------
 paddle/parameter/Argument.h              |  6 ----
 10 files changed, 87 insertions(+), 99 deletions(-)

diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 001b62a6b9..0b7cd33756 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -346,6 +346,7 @@ void hl_matrix_multi_binary_cross_entropy(real* output,
   CHECK_NOTNULL(output);
   CHECK_NOTNULL(entropy);
   CHECK_NOTNULL(csr_mat);
+  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
   int n_threads = 1024;
   int blocks = (dimM + n_threads - 1) / n_threads;
   dim3 threads(n_threads);
@@ -385,6 +386,7 @@ void hl_matrix_multi_binary_cross_entropy_bp(real* output,
   CHECK_NOTNULL(output);
   CHECK_NOTNULL(grad);
   CHECK_NOTNULL(csr_mat);
+  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
   int n_threads = 1024;
   int blocks = (dimM + n_threads - 1) / n_threads;
   dim3 threads(n_threads);
@@ -763,7 +765,7 @@ __global__ void KeMatrixAddSharedBias(real* A,
   int dim = N / channel;
   if (index < M * N) {
     int i = index % N;
-    i = i / dim; 
+    i = i / dim;
     A[index] += scale * B[i];
   }
 }
@@ -791,7 +793,7 @@ __global__ void KeMatrixCollectSharedBias(real *B,
                                           const int dim,
                                           const int limit,
                                           real scale) {
-  if (dim < limit) { 
+  if (dim < limit) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < channel) {
       real sum = 0.0;
diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp
index 900981d1e7..2f85dd3c3b 100644
--- a/paddle/gserver/layers/CostLayer.cpp
+++ b/paddle/gserver/layers/CostLayer.cpp
@@ -465,10 +465,7 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
   MatrixPtr value = nullptr;
   if (label.ids) {
     CHECK(!label.value);
-    value = Matrix::createSparseMatrix(
-        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
-        NO_VALUE, SPARSE_CSR, false, useGpu_);
-    label.idsToSparseMatrix(value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
   } else {
     CHECK(label.value);
     value = label.value;
@@ -491,10 +488,7 @@ void MultiBinaryLabelCrossEntropy::backwardImp(
   MatrixPtr value = nullptr;
   if (label.ids) {
     CHECK(!value);
-    value = Matrix::createSparseMatrix(
-        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
-        NO_VALUE, SPARSE_CSR, false, useGpu_);
-    label.idsToSparseMatrix(value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
   } else {
     CHECK(label.value);
     value = label.value;
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index f19c14f569..f3cd2b4faf 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -528,7 +528,7 @@ TEST(Layer, multi_cross) {
   }
 }
 
-TEST(Layer, multi_binary_label) {
+TEST(Layer, multi_binary_label_sparse_mat) {
   TestConfig config;
   config.layerConfig.set_type("multi_binary_label_cross_entropy");
   config.biasSize = 0;
@@ -544,6 +544,22 @@ TEST(Layer, multi_binary_label) {
   }
 }
 
+TEST(layer, multi_binary_label_id) {
+  TestConfig config;
+  config.layerConfig.set_type("multi_binary_label_cross_entropy");
+  config.biasSize = 0;
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 50, 0});
+  config.inputDefs.push_back({INPUT_LABEL, "layer_1", 10, 0});
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
+}
+
 TEST(Layer, multi_cross_with_selfnorm) {
   TestConfig config;
   config.layerConfig.set_type("multi_class_cross_entropy_with_selfnorm");
diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp
index 842efdbe3d..64ee124a56 100644
--- a/paddle/math/CpuSparseMatrix.cpp
+++ b/paddle/math/CpuSparseMatrix.cpp
@@ -409,9 +409,6 @@ void CpuSparseMatrix::setRow(size_t row, size_t colNum,
   if (format_ == SPARSE_CSR) {
     CHECK_LT(row, height_);
     CHECK(NULL != cols);
-    for (size_t i = row; i < height_; i++) {
-      CHECK_EQ(rows_[i + 1], rows_[i]);
-    }
     if (0 == row) {
       rows_[row] = 0;
     }
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 9acc600553..5ee8fbebfc 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1269,37 +1269,37 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
 }
 
 void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
-  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
-
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
-  CHECK(height_ == output_ptr->height_ && width_ == 1
-        && output_ptr->width_ == label_ptr->getWidth()
-        && output_ptr->height_ == label_ptr->getHeight())
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
format not supported"; + CHECK(height_ == outputPtr->height_ && width_ == 1 + && outputPtr->width_ == labelPtr->getWidth() + && outputPtr->height_ == labelPtr->getHeight()) << "Matrix dimensions are not equal"; - real* output_d = output_ptr->data_; + real* output_d = outputPtr->data_; real* entropy_d = data_; - hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get(); + hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get(); hl_matrix_multi_binary_cross_entropy( - output_d, entropy_d, mat_d, height_, output_ptr->width_); + output_d, entropy_d, mat_d, height_, outputPtr->width_); } void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) { - GpuMatrix* output_ptr = dynamic_cast(&output); - auto label_ptr = dynamic_cast(&label); - - CHECK(output_ptr && label_ptr) << "Invalid argument pointer"; - CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported"; - CHECK(height_ == output_ptr->height_ && width_ == output_ptr->width_ - && output_ptr->width_ == label_ptr->getWidth() - && output_ptr->height_ == label_ptr->getHeight()) + GpuMatrix* outputPtr = dynamic_cast(&output); + auto labelPtr = dynamic_cast(&label); + + CHECK(outputPtr && labelPtr) << "Invalid argument pointer"; + CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported"; + CHECK(height_ == outputPtr->height_ && width_ == outputPtr->width_ + && outputPtr->width_ == labelPtr->getWidth() + && outputPtr->height_ == labelPtr->getHeight()) << "Matrix dimensions are not equal"; - real* output_d = output_ptr->data_; + real* output_d = outputPtr->data_; real* grad_d = data_; - hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get(); + hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get(); hl_matrix_multi_binary_cross_entropy_bp( output_d, grad_d, mat_d, height_, width_); } diff --git a/paddle/math/Vector.cpp b/paddle/math/Vector.cpp index 7553ea25e0..23c9caccea 100644 --- a/paddle/math/Vector.cpp +++ b/paddle/math/Vector.cpp @@ -21,6 +21,7 @@ limitations under the License. 
 #include "paddle/utils/ThreadLocal.h"
 #include "paddle/utils/Thread.h"
 #include "paddle/utils/Flags.h"
+#include "Matrix.h"
 #include "hl_gpu.h"
 #include "hl_table_apply.h"
 
@@ -73,6 +74,31 @@ std::shared_ptr<VectorT<T>> VectorT<T>::create(size_t size,
   }
 }
 
+template <>
+MatrixPtr VectorT<real>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
+  LOG(FATAL) << "Wrong for real vector";
+  return nullptr;
+}
+
+template <>
+MatrixPtr VectorT<int>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
+  int height = getSize();
+  int width = idRange;
+  MatrixPtr mat = Matrix::createSparseMatrix(
+      height, idRange, height, NO_VALUE, SPARSE_CSR, false, useGpu);
+
+  CpuIVector cpuIds(height);
+  cpuIds.copyFrom(*this);
+  int *idData = cpuIds.getData();
+
+  for (int i = 0; i < height; i ++) {
+    const unsigned int id = idData[i];
+    CHECK_LT(id, width);
+    mat->setRow(i, 1, &id, nullptr);
+  }
+  return mat;
+}
+
 template <class T>
 GpuVectorT<T>::GpuVectorT(size_t size)
     : VectorT<T>(size, std::make_shared<GpuMemoryHandle>(sizeof(T) * size),
diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h
index ee0a83bf03..faf8186b6d 100644
--- a/paddle/math/Vector.h
+++ b/paddle/math/Vector.h
@@ -37,6 +37,8 @@ class BaseVector;
 
 class SyncThreadPool;
 
+class Matrix;
+
 template <class T>
 class BaseVector : public BaseMatrixT<T> {
 public:
@@ -155,6 +157,12 @@ public:
     subVecFrom(src, interval.first, interval.second - interval.first);
   }
 
+  /**
+   * convert the vector to a sparse one_hot matrix of width idRange
+   * only applies to IVector
+   */
+  std::shared_ptr<Matrix> toOneHotSparseMatrix(size_t idRange, bool useGpu);
+
   /**
    * This function will crash if the size of src and dest is different.
   */
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index a41e21903f..9c03695ba5 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -2232,26 +2232,15 @@ void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
   MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
   MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
 
-  auto cpuRows = IVector::create(numSamples + 1, false);
-  auto cpuCols = IVector::create(numSamples, false);
-  auto gpuRows = IVector::create(numSamples + 1, true);
-  auto gpuCols = IVector::create(numSamples, true);
-  cpuRows->setElement(0, 0);
-  gpuRows->setElement(0, 0);
-  for (int i = 0; i < numSamples; i ++) {
-    int id = rand() % dim;  // NOLINT
-    cpuRows->setElement(i + 1, i + 1);
-    gpuRows->setElement(i + 1, i + 1);
-    cpuCols->setElement(i, id);
-    gpuCols->setElement(i, id);
-  }
-
   MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
-      (nullptr, cpuRows->getData(), cpuCols->getData(),
-       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+      (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
   MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
-      (nullptr, gpuRows->getData(), gpuCols->getData(),
-       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+      (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+  for (int i = 0; i < numSamples; i ++) {
+    const unsigned int id = rand() % dim;  // NOLINT
+    cpuLabel->setRow(i, 1, &id, nullptr);
+    gpuLabel->setRow(i, 1, &id, nullptr);
+  }
 
   output->randomizeUniform();
   cpuOutput->zeroMem();
@@ -2278,8 +2267,8 @@ void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
 }
 
 TEST(Matrix, multiBinaryCrossEntropy) {
-  for (auto numSamples : {1, 100, 500}) {
-    for (auto dim : {1000, 10000, 100000}) {
+  for (auto numSamples : {100, 1000, 10000}) {
+    for (auto dim : {100, 1000, 10000}) {
       VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
       testMultiBinaryLabelCrossEntropy(numSamples, dim);
     }
diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index 354d0ead07..42c74661d2 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -572,42 +572,4 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
   }
 }
 
-void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) {
-  int height = ids->getSize();
-  int width = sparse_mat->getWidth();
-
-  CpuIVector cpu_ids(height);
-  cpu_ids.copyFrom(*ids);
-  int *id_data = cpu_ids.getData();
-
-  int *rows = nullptr;
-  int *cols = nullptr;
-  if (sparse_mat->useGpu()) {
-    auto gpu_sparse_mat =
-        dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
-    rows = gpu_sparse_mat->rows_;
-    cols = gpu_sparse_mat->cols_;
-  } else {
-    rows = sparse_mat->getRows();
-    cols = sparse_mat->getCols();
-  }
-
-  rows[0] = 0;
-  for (int i = 0; i < height; i ++) {
-    int id = id_data[i];
-    CHECK_LT(id, width);
-    rows[i + 1] = i + 1;
-    cols[i] = id;
-  }
-
-  if (sparse_mat->useGpu()) {
-    auto gpu_sparse_mat =
-        dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
-    hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(),
-                         nullptr, rows, cols,
-                         HPPL_STREAM_DEFAULT);
-    hl_stream_synchronize(HPPL_STREAM_DEFAULT);
-  }
-}
-
 }  // namespace paddle
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 695033138b..81ff9029bc 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -286,12 +286,6 @@ struct Argument {
      sequence has sub-sequence degrades to a sequence.
    */
   void degradeSequence(const Argument& input, bool useGpu);
-
-  /*
-    @brief convert the ids vector to value as a sparse matrix;
-    @param[out] sparse_mat the output sparse matrix (already allocated)
-  */
-  void idsToSparseMatrix(MatrixPtr sparse_mat);
 };
 
 }  // namespace paddle
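Usage note. After PATCH 3/3, an integer id vector is converted to a one-hot CSR
label matrix through IVector::toOneHotSparseMatrix(), and
MultiBinaryLabelCrossEntropy consumes either form of label. For sample i with
output row o_i and label column set labels(i), the kernels above compute

    E_i = - sum_{j in labels(i)} log(o_ij) - sum_{j not in labels(i)} log(1 - o_ij)

with gradient dE_i/do_ij = -1/o_ij for label columns and +1/(1 - o_ij)
otherwise. Below is a minimal CPU-side sketch of the resulting API, mirroring
testMultiBinaryLabelCrossEntropy above; the sizes are arbitrary, and
randomizeUniform() merely stands in for real sigmoid activations, which must lie
strictly inside (0, 1) for the logs to stay finite:

    const int numSamples = 100;
    const int dim = 1000;

    // One class id per sample -- the "ids vector" case this series targets.
    IVectorPtr ids = IVector::create(numSamples, /* useGpu */ false);
    for (int i = 0; i < numSamples; i++) {
      ids->setElement(i, rand() % dim);  // NOLINT
    }

    // numSamples x dim one-hot CSR label matrix, one nonzero per row.
    MatrixPtr label = ids->toOneHotSparseMatrix(dim, /* useGpu */ false);

    // Placeholder network output in (0, 1).
    MatrixPtr output = std::make_shared<CpuMatrix>(numSamples, dim);
    output->randomizeUniform();

    // Forward cost, one entry per sample (numSamples x 1); the kernel
    // accumulates with -=, so the target must start zeroed.
    MatrixPtr entropy = std::make_shared<CpuMatrix>(numSamples, 1);
    entropy->zeroMem();
    entropy->multiBinaryLabelCrossEntropy(*output, *label);

    // Gradient w.r.t. the output (numSamples x dim), also accumulated.
    MatrixPtr grad = std::make_shared<CpuMatrix>(numSamples, dim);
    grad->zeroMem();
    grad->multiBinaryLabelCrossEntropyBp(*output, *label);

The GPU path is the same with useGpu = true: toOneHotSparseMatrix() then builds
a GpuSparseMatrix, and the GpuMatrix overloads added in PATCH 1/3 dispatch to
hl_matrix_multi_binary_cross_entropy / _bp.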