|
|
|
@ -112,7 +112,6 @@ BEGIN_DEFINE_ACTIVATION(softmax)
|
|
|
|
|
private:
|
|
|
|
|
MatrixPtr sftMaxSum_;
|
|
|
|
|
MatrixPtr sftMaxDot_;
|
|
|
|
|
MatrixPtr one_;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
Error __must_check forward(Argument& act) {
|
|
|
|
@ -138,14 +137,6 @@ Error __must_check backward(Argument& act) {
|
|
|
|
|
1,
|
|
|
|
|
/* trans */ false,
|
|
|
|
|
useGpu(act.deviceId));
|
|
|
|
|
if (!one_ || one_->getWidth() != outputG->getWidth()) {
|
|
|
|
|
Matrix::resizeOrCreate(one_,
|
|
|
|
|
1,
|
|
|
|
|
outputG->getWidth(),
|
|
|
|
|
/* trans */ false,
|
|
|
|
|
useGpu(act.deviceId));
|
|
|
|
|
one_->one();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sftMaxDot_->dotMul(*outputG, *outputV);
|
|
|
|
|
sftMaxSum_->colMerge(*sftMaxDot_);
|
|
|
|
|