|
|
|
@ -57,12 +57,12 @@ void FactorizationMachineLayer::forward(PassType passType) {
|
|
|
|
|
|
|
|
|
|
REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
|
|
|
|
|
tmpMul_->mul(*inputV, *latentVectors_->getW());
|
|
|
|
|
tmpOut_->pow2(*tmpMul_, 2);
|
|
|
|
|
tmpMul_->square2(*tmpOut_);
|
|
|
|
|
outV->sumRows(*tmpOut_, 0.5, 0);
|
|
|
|
|
|
|
|
|
|
x2_ = inputV->clone(0, 0, useGpu_);
|
|
|
|
|
x2_->pow2(*inputV, 2);
|
|
|
|
|
v2_->pow2(*latentVectors_->getW(), 2);
|
|
|
|
|
inputV->square2(*x2_);
|
|
|
|
|
latentVectors_->getW()->square2(*v2_);
|
|
|
|
|
tmpOut_->mul(*x2_, *v2_);
|
|
|
|
|
outV->sumRows(*tmpOut_, -0.5, 1.0);
|
|
|
|
|
|
|
|
|
|