DeepFM network: use MindRecord dataset

pull/10592/head
wsq3 committed 4 years ago
parent 1a9f4625ed
commit 435a524f06

(File diff suppressed because it is too large.)

@@ -27,7 +27,7 @@ class DataConfig:
     batch_size = 16000
     data_field_size = 39
     # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
-    data_format = 3
+    data_format = 1

 class ModelConfig:
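
For reference, data_format selects which reader create_dataset builds (1 = mindrecord, 2 = tfrecord, 3 = h5, per the comment above). Below is a minimal sketch of what the MindRecord branch could look like; the helper name _get_mindrecord_dataset, the shard file naming, and the column names feat_ids / feat_vals / label are assumptions for illustration, not the model zoo's exact code in src/dataset.py.

```python
# Illustrative sketch only: the real reader lives in src/dataset.py and may differ.
import os
import mindspore.dataset as ds

def _get_mindrecord_dataset(directory, train_mode=True, batch_size=16000,
                            rank_size=None, rank_id=None):
    """Read Criteo MindRecord files when data_format == 1 (assumed layout)."""
    prefix = "train_input_part.mindrecord" if train_mode else "test_input_part.mindrecord"
    file_name = os.path.join(directory, prefix + "0")   # assumed shard naming
    data_set = ds.MindDataset(file_name,
                              columns_list=["feat_ids", "feat_vals", "label"],
                              num_shards=rank_size, shard_id=rank_id,
                              shuffle=train_mode)
    return data_set.batch(batch_size, drop_remainder=True)
```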

@@ -14,7 +14,7 @@
 # ============================================================================
 """train_criteo."""
 import os
-# import pytest
+import pytest
 from mindspore import context
 from mindspore.train.model import Model
@@ -27,10 +27,10 @@ from src.callback import EvalCallBack, LossCallBack, TimeMonitor
 set_seed(1)
-# @pytest.mark.level0
-# @pytest.mark.platform_arm_ascend_training
-# @pytest.mark.platform_x86_ascend_training
-# @pytest.mark.env_onecard
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_deepfm():
     data_config = DataConfig()
     train_config = TrainConfig()
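
Enabling the marks above is what puts this case back into the Ascend CI gate. If the marks are not already registered, pytest emits unknown-marker warnings; a conftest.py along these lines (illustrative, the real CI configuration may register them elsewhere) declares them:

```python
# conftest.py (sketch): register the custom CI marks so pytest accepts them
# without warnings. The actual MindSpore CI setup may differ.
def pytest_configure(config):
    for mark in ("level0",
                 "platform_arm_ascend_training",
                 "platform_x86_ascend_training",
                 "env_onecard"):
        config.addinivalue_line("markers", "%s: MindSpore CI gating mark" % mark)
```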
@@ -39,7 +39,7 @@ def test_deepfm():
     rank_size = None
     rank_id = None
-    dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
+    dataset_path = "/home/workspace/mindspore_dataset/criteo_data/mindrecord/"
     print("dataset_path:", dataset_path)
     ds_train = create_dataset(dataset_path,
                               train_mode=True,
@@ -71,10 +71,10 @@ def test_deepfm():
     print("train_config.train_epochs:", train_config.train_epochs)
     model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
-    export_loss_value = 0.51
+    export_loss_value = 0.52
     print("loss_callback.loss:", loss_callback.loss)
     assert loss_callback.loss < export_loss_value
-    export_per_step_time = 40.0
+    export_per_step_time = 30.0
     print("time_callback:", time_callback.per_step_time)
     assert time_callback.per_step_time < export_per_step_time
     print("*******test case pass!********")
