|
|
|
|
@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
|
|
|
|
|
|
|
|
|
|
/**
 * Forward pass: writes the per-sample multi binary label cross entropy
 * of `output` against `label` into `target`.
 *
 * The label must carry exactly one of `label.ids` or `label.value`
 * (enforced with CHECK). Ids are expanded into a local sparse matrix so
 * that the caller's `label` argument is not modified.
 *
 * @param output  network output; its width is used as the label width.
 * @param label   label argument holding either ids or a value matrix.
 * @param target  receives the cost computed for `output`.
 */
void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
                                              Matrix& target) {
  MatrixPtr value = nullptr;
  if (label.ids) {
    CHECK(!label.value);
    // Height and non-zero count are both ids->getSize(); presumably one
    // non-zero entry per row -- confirm against Argument::idsToSparseMatrix.
    value = Matrix::createSparseMatrix(
        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
        NO_VALUE, SPARSE_CSR, false, useGpu_);
    label.idsToSparseMatrix(value);
  } else {
    CHECK(label.value);
    value = label.value;
  }

  // Dispatch on the resolved label matrix rather than label.value, which is
  // null whenever the labels were supplied as ids.
  const bool isSparseLabel =
      dynamic_cast<CpuSparseMatrix*>(value.get()) != nullptr ||
      dynamic_cast<GpuSparseMatrix*>(value.get()) != nullptr;

  if (isSparseLabel) {
    target.multiBinaryLabelCrossEntropy(output, *value);
  } else {
    // Non-sparse label: compute the per-dimension cost into targetPerDim_
    // and reduce it row-wise into target.
    Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
                           false, useGpu_);

    targetPerDim_->binaryLabelCrossEntropy(output, *value);
    targetPerDim_->rowSum(target);
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * Backward pass: computes the gradient of the multi binary label cross
 * entropy with respect to `output` and passes it to `outputG` via the
 * matching *Bp routine.
 *
 * The label must carry exactly one of `label.ids` or `label.value`
 * (enforced with CHECK). Ids are expanded into a local sparse matrix so
 * that the caller's `label` argument is not modified.
 *
 * @param output   network output; its width is used as the label width.
 * @param label    label argument holding either ids or a value matrix.
 * @param outputG  gradient matrix updated by the backward routine.
 */
void MultiBinaryLabelCrossEntropy::backwardImp(
    Matrix& output, Argument& label, Matrix& outputG) {
  MatrixPtr value = nullptr;
  if (label.ids) {
    // Check the label itself: the local `value` was just set to nullptr, so
    // testing it would never fire. This matches the check in forwardImp.
    CHECK(!label.value);
    // Height and non-zero count are both ids->getSize(); presumably one
    // non-zero entry per row -- confirm against Argument::idsToSparseMatrix.
    value = Matrix::createSparseMatrix(
        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
        NO_VALUE, SPARSE_CSR, false, useGpu_);
    label.idsToSparseMatrix(value);
  } else {
    CHECK(label.value);
    value = label.value;
  }

  // Dispatch on the resolved label matrix rather than label.value, which is
  // null whenever the labels were supplied as ids.
  const bool isSparseLabel =
      dynamic_cast<CpuSparseMatrix*>(value.get()) != nullptr ||
      dynamic_cast<GpuSparseMatrix*>(value.get()) != nullptr;

  if (isSparseLabel) {
    outputG.multiBinaryLabelCrossEntropyBp(output, *value);
  } else {
    outputG.binaryLabelCrossEntropyBp(output, *value);
  }
}
|
|
|
|
|
|
|
|
|
|
|