|
|
|
@ -2365,14 +2365,15 @@ void testFactorizationMachineLayer(InputType type, bool useGpu) {
|
|
|
|
|
config.layerConfig.set_factor_size(FACTOR_SIZE);
|
|
|
|
|
config.layerConfig.set_size(1);
|
|
|
|
|
config.biasSize = 0;
|
|
|
|
|
config.inputDefs.push_back({type, "layer_0", 1024, 10240});
|
|
|
|
|
config.inputDefs.push_back({type, "layer_0", 128, 1280});
|
|
|
|
|
config.layerConfig.add_inputs();
|
|
|
|
|
testLayerGrad(config, "factorization_machine", 16, false, useGpu, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Layer, FactorizationMachineLayer) {
|
|
|
|
|
testFactorizationMachineLayer(INPUT_DATA, false);
|
|
|
|
|
testFactorizationMachineLayer(INPUT_DATA, true);
|
|
|
|
|
for (auto useGpu : {false, true}) {
|
|
|
|
|
testFactorizationMachineLayer(INPUT_DATA, useGpu);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int main(int argc, char** argv) {
|
|
|
|
|