|
|
|
@ -48,7 +48,7 @@ class KLDivLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
if ("none" == reduction) {
|
|
|
|
|
ctx->SetOutputDim("Loss", dim_x);
|
|
|
|
|
} else {
|
|
|
|
|
ctx->SetOutputDim("Loss", framework::make_ddim({1}));
|
|
|
|
|
ctx->SetOutputDim("Loss", {1});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -81,7 +81,7 @@ class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"The reduction type to apply to the output, available types "
|
|
|
|
|
"are 'none' | 'batchmean' | 'mean' | 'sum', 'none' for no "
|
|
|
|
|
"reduction, 'batchmean' for the sum of output divided by "
|
|
|
|
|
"batchmean size, 'mean' for the average valud of all output, "
|
|
|
|
|
"batch size, 'mean' for the average valud of all output, "
|
|
|
|
|
"'sum' for the sum of the output.")
|
|
|
|
|
.SetDefault("mean");
|
|
|
|
|
|
|
|
|
|