|
|
|
@ -14,7 +14,8 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/utils/Util.h"
|
|
|
|
|
|
|
|
|
|
#include "Layer.h"
|
|
|
|
|
#include "CostLayer.h"
|
|
|
|
|
#include "ValidationLayer.h"
|
|
|
|
|
#include "paddle/math/SparseMatrix.h"
|
|
|
|
|
#include "paddle/utils/Error.h"
|
|
|
|
|
#include "paddle/utils/Logging.h"
|
|
|
|
@ -93,6 +94,20 @@ ClassRegistrar<Layer, LayerConfig> Layer::registrar_;
|
|
|
|
|
|
|
|
|
|
LayerPtr Layer::create(const LayerConfig& config) {
|
|
|
|
|
std::string type = config.type();
|
|
|
|
|
|
|
|
|
|
// NOTE: As following types have illegal character '-',
|
|
|
|
|
// they can not use REGISTER_LAYER to registrar.
|
|
|
|
|
// Besides, to fit with old training models,
|
|
|
|
|
// they can not use '_' instead.
|
|
|
|
|
if (type == "multi-class-cross-entropy")
|
|
|
|
|
return LayerPtr(new MultiClassCrossEntropy(config));
|
|
|
|
|
else if (type == "rank-cost")
|
|
|
|
|
return LayerPtr(new RankingCost(config));
|
|
|
|
|
else if (type == "auc-validation")
|
|
|
|
|
return LayerPtr(new AucValidation(config));
|
|
|
|
|
else if (type == "pnpair-validation")
|
|
|
|
|
return LayerPtr(new PnpairValidation(config));
|
|
|
|
|
|
|
|
|
|
return LayerPtr(registrar_.createByType(config.type(), config));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|