|
|
|
@ -104,15 +104,21 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
CpuSparseMatrix* tmpIn_s = dynamic_cast<CpuSparseMatrix*>(tmpIn.get());
|
|
|
|
|
tmpIn_s->copyFrom(*inputV_s);
|
|
|
|
|
tmpIn_s->rowScale(0, *inputV_s, *oGrad);
|
|
|
|
|
latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
|
|
|
|
|
latentVectors_->getWGrad()->mul(*tmpIn_s->getTranspose(), *tmpMul_, 1, 1);
|
|
|
|
|
tmpIn_s->rowScale(0, *x2_s, *oGrad);
|
|
|
|
|
|
|
|
|
|
MatrixPtr ones = Matrix::create(1, inputV->getHeight(), false, useGpu_);
|
|
|
|
|
ones->zeroMem();
|
|
|
|
|
ones->add(-1);
|
|
|
|
|
tmpSum->mul(*ones, *tmpIn_s, 1, 0);
|
|
|
|
|
} else {
|
|
|
|
|
tmpIn->rowScale(0, *inputV, *oGrad);
|
|
|
|
|
latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
|
|
|
|
|
tmpIn->rowScale(0, *x2_, *oGrad);
|
|
|
|
|
|
|
|
|
|
tmpSum->sumCols(*tmpIn, -1, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tmpSum->sumCols(*tmpIn, -1, 0);
|
|
|
|
|
latentVectors_->getWGrad()->addRowScale(
|
|
|
|
|
0, *latentVectors_->getW(), *tmpSum_T);
|
|
|
|
|
|
|
|
|
|