|
|
@ -58,16 +58,22 @@ void FactorizationMachineLayer::forward(PassType passType) {
|
|
|
|
inputMulFactor_, batchSize, factorSize_, false, useGpu_);
|
|
|
|
inputMulFactor_, batchSize, factorSize_, false, useGpu_);
|
|
|
|
Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
|
|
|
|
Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_TIMER_INFO("InputMulFactorTimer", getName().c_str());
|
|
|
|
REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str());
|
|
|
|
inputMulFactor_->mul(*inputV, *latentVectors_->getW());
|
|
|
|
inputMulFactor_->mul(*inputV, *latentVectors_->getW());
|
|
|
|
inputMulFactor_->square2(*tmpOut_);
|
|
|
|
inputMulFactor_->square2(*tmpOut_);
|
|
|
|
outV->sumRows(*tmpOut_, 0.5, 0);
|
|
|
|
outV->sumRows(*tmpOut_, 0.5, 0);
|
|
|
|
|
|
|
|
|
|
|
|
inputSquare_ = inputV->clone(0, 0, useGpu_);
|
|
|
|
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
|
|
|
|
if (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get())) {
|
|
|
|
Matrix::resizeOrCreateSparseMatrix(inputSquare_,
|
|
|
|
|
|
|
|
inputV->getHeight(),
|
|
|
|
|
|
|
|
inputV->getWidth(),
|
|
|
|
|
|
|
|
inputV->getElementCnt(),
|
|
|
|
|
|
|
|
inputV->getValueType());
|
|
|
|
inputSquare_->copyFrom(*inputV);
|
|
|
|
inputSquare_->copyFrom(*inputV);
|
|
|
|
(dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
|
|
|
|
(dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
|
|
|
|
Matrix::resizeOrCreate(
|
|
|
|
|
|
|
|
inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
|
|
|
|
inputV->square2(*inputSquare_);
|
|
|
|
inputV->square2(*inputSquare_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
latentVectors_->getW()->square2(*latentVectorsSquare_);
|
|
|
|
latentVectors_->getW()->square2(*latentVectorsSquare_);
|
|
|
@ -75,7 +81,7 @@ void FactorizationMachineLayer::forward(PassType passType) {
|
|
|
|
outV->sumRows(*tmpOut_, -0.5, 1.0);
|
|
|
|
outV->sumRows(*tmpOut_, -0.5, 1.0);
|
|
|
|
|
|
|
|
|
|
|
|
/* activation */ {
|
|
|
|
/* activation */ {
|
|
|
|
REGISTER_TIMER_INFO("FmAtvTimer", getName().c_str());
|
|
|
|
REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str());
|
|
|
|
forwardActivation();
|
|
|
|
forwardActivation();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|