Add support for sparse_binary_vector as input for fm layer

release/0.11.0
wangmeng28 7 years ago
parent 5392a503a7
commit 6fed6f2079

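For reference, the backward() changes below implement the standard factorization machine gradient for the latent matrix (Rendle, 2010); the notation here is the usual FM one, not identifiers from this commit. With input x and latent factors v_{i,f}, the model output is

    \hat{y}(x) = w_0 + \sum_i w_i x_i + \frac{1}{2} \sum_{f=1}^{k} \Big[ \big( \textstyle\sum_i v_{i,f} x_i \big)^2 - \sum_i v_{i,f}^2 x_i^2 \Big]

and the latent-vector gradient the hunks accumulate is

    \frac{\partial \hat{y}}{\partial v_{i,f}} = x_i \sum_j v_{j,f} x_j - v_{i,f} x_i^2

The first term corresponds to the mul(..., *inputMulFactor_, 1, 1) calls below and the second to the inputSquare_/tmpSum_ path.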
@@ -96,15 +96,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
 
   /* Calculate the gradients of the latentVectors_ matrix */
   if (latentVectors_->getWGrad()) {
-    MatrixPtr tmpInput = inputV->clone(0, 0, useGpu_);
     if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
+      Matrix::resizeOrCreateSparseMatrix(tmpInput_,
+                                         inputV->getHeight(),
+                                         inputV->getWidth(),
+                                         inputV->getElementCnt());
+
       CpuSparseMatrix* sparseInputV =
           dynamic_cast<CpuSparseMatrix*>(inputV.get());
       CpuSparseMatrix* sparseInputSquare =
           dynamic_cast<CpuSparseMatrix*>(inputSquare_.get());
       CpuSparseMatrix* sparseTmpInput =
-          dynamic_cast<CpuSparseMatrix*>(tmpInput.get());
+          dynamic_cast<CpuSparseMatrix*>(tmpInput_.get());
       sparseTmpInput->copyFrom(*sparseInputV);
+
       sparseTmpInput->rowScale(0, *sparseInputV, *oGrad);
       latentVectors_->getWGrad()->mul(
           *sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1);
@@ -115,12 +120,15 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
 
       negOnes_->add(-1);
       tmpSum_->mul(*negOnes_, *sparseTmpInput, 1, 0);
     } else {
-      tmpInput->rowScale(0, *inputV, *oGrad);
+      Matrix::resizeOrCreate(
+          tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
+
+      tmpInput_->rowScale(0, *inputV, *oGrad);
       latentVectors_->getWGrad()->mul(
-          *tmpInput->getTranspose(), *inputMulFactor_, 1, 1);
-      tmpInput->rowScale(0, *inputSquare_, *oGrad);
+          *tmpInput_->getTranspose(), *inputMulFactor_, 1, 1);
+      tmpInput_->rowScale(0, *inputSquare_, *oGrad);
-      tmpSum_->sumCols(*tmpInput, -1, 0);
+      tmpSum_->sumCols(*tmpInput_, -1, 0);
     }
 
     latentVectors_->getWGrad()->addRowScale(

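Promoting the scratch buffer from a per-call clone of inputV to the tmpInput_ member (declared in the header hunk below) lets resizeOrCreate reuse the allocation across backward() calls. To make the dense (else) branch concrete, here is a plain-loop sketch of the gradient it assembles; the function name, std::vector layout, and shape names n/m/k are illustrative assumptions, not Paddle's API:

#include <vector>

// x is the n×m input batch, xv = x·v is the cached inputMulFactor_,
// og is column 0 of oGrad, v is the m×k latent matrix, and vGrad
// accumulates latentVectors_->getWGrad().
void fmLatentGradDense(const std::vector<std::vector<float>>& x,
                       const std::vector<std::vector<float>>& xv,
                       const std::vector<float>& og,
                       const std::vector<std::vector<float>>& v,
                       std::vector<std::vector<float>>& vGrad) {
  const size_t n = x.size(), m = v.size(), k = v[0].size();
  for (size_t i = 0; i < m; ++i) {
    for (size_t f = 0; f < k; ++f) {
      float g = 0.f;
      for (size_t s = 0; s < n; ++s) {
        // (x row-scaled by og)^T · (x·v): the mul(...) call above
        g += og[s] * x[s][i] * xv[s][f];
        // the x² correction: the inputSquare_/tmpSum_ path above
        g -= og[s] * x[s][i] * x[s][i] * v[i][f];
      }
      vGrad[i][f] += g;
    }
  }
}
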
@@ -61,6 +61,7 @@ private:
   // Store temporary calculation result
   MatrixPtr tmpOut_;
   MatrixPtr tmpSum_;
+  MatrixPtr tmpInput_;
   // Negative identity matrix
   MatrixPtr negOnes_;
 

@@ -266,13 +266,25 @@ void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) {
   CHECK_EQ(width_, b.getWidth());
   real* A = getValue();
   real* B = b.getValue();
-  for (size_t i = 0; i < height_; i++) {
-    size_t start = getRowStartIdx(i);
-    size_t end = getRowStartIdx(i + 1);
-    CHECK_EQ(start, b.getRowStartIdx(i));
-    CHECK_EQ(end, b.getRowStartIdx(i + 1));
-    for (size_t j = start; j < end; j++) {
-      A[j] = B[j] * c.getElement(i, cCol);
+  if (b.getValueType() == FLOAT_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = B[j] * c.getElement(i, cCol);
+      }
+    }
+  } else if (b.getValueType() == NO_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = c.getElement(i, cCol);
+      }
     }
   }
 }
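The new NO_VALUE branch is the core of the commit: a sparse binary matrix (the sparse_binary_vector input) carries only its CSR structure, with every stored entry implicitly 1.0, so B[j] * c.getElement(i, cCol) degenerates to c.getElement(i, cCol). A minimal standalone illustration of that contract, using a hypothetical Csr type rather than Paddle's classes:

#include <cstdio>
#include <vector>

// Minimal CSR holder; an empty vals array marks a binary matrix
// (every stored entry is implicitly 1.0), mirroring NO_VALUE.
struct Csr {
  std::vector<int> rowStart;  // size height+1
  std::vector<int> cols;      // column index per stored entry
  std::vector<float> vals;    // empty for a binary matrix
};

// Returns b's values with each row i multiplied by scale[i]:
// the same contract as CpuSparseMatrix::rowScale(0, b, c).
std::vector<float> rowScale(const Csr& b, const std::vector<float>& scale) {
  std::vector<float> out(b.cols.size());
  for (size_t i = 0; i + 1 < b.rowStart.size(); ++i) {
    for (int j = b.rowStart[i]; j < b.rowStart[i + 1]; ++j) {
      float bj = b.vals.empty() ? 1.0f : b.vals[j];  // NO_VALUE => 1.0
      out[j] = bj * scale[i];
    }
  }
  return out;
}

int main() {
  // 2x3 binary matrix with stored entries at (0,1), (1,0), (1,2).
  Csr b{{0, 1, 3}, {1, 0, 2}, {}};
  for (float x : rowScale(b, {0.5f, 2.0f})) std::printf("%g ", x);
  std::printf("\n");  // prints: 0.5 2 2
}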
