|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""train_criteo."""
|
|
|
|
|
import os
|
|
|
|
|
# import pytest
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.train.model import Model
|
|
|
|
@ -27,10 +27,10 @@ from src.callback import EvalCallBack, LossCallBack, TimeMonitor
|
|
|
|
|
|
|
|
|
|
set_seed(1)
|
|
|
|
|
|
|
|
|
|
# @pytest.mark.level0
|
|
|
|
|
# @pytest.mark.platform_arm_ascend_training
|
|
|
|
|
# @pytest.mark.platform_x86_ascend_training
|
|
|
|
|
# @pytest.mark.env_onecard
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_arm_ascend_training
|
|
|
|
|
@pytest.mark.platform_x86_ascend_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_deepfm():
|
|
|
|
|
data_config = DataConfig()
|
|
|
|
|
train_config = TrainConfig()
|
|
|
|
@ -39,7 +39,7 @@ def test_deepfm():
|
|
|
|
|
rank_size = None
|
|
|
|
|
rank_id = None
|
|
|
|
|
|
|
|
|
|
dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
|
|
|
|
|
dataset_path = "/home/workspace/mindspore_dataset/criteo_data/mindrecord/"
|
|
|
|
|
print("dataset_path:", dataset_path)
|
|
|
|
|
ds_train = create_dataset(dataset_path,
|
|
|
|
|
train_mode=True,
|
|
|
|
@ -71,10 +71,10 @@ def test_deepfm():
|
|
|
|
|
print("train_config.train_epochs:", train_config.train_epochs)
|
|
|
|
|
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
|
|
|
|
|
|
|
|
|
|
export_loss_value = 0.51
|
|
|
|
|
export_loss_value = 0.52
|
|
|
|
|
print("loss_callback.loss:", loss_callback.loss)
|
|
|
|
|
assert loss_callback.loss < export_loss_value
|
|
|
|
|
export_per_step_time = 40.0
|
|
|
|
|
export_per_step_time = 30.0
|
|
|
|
|
print("time_callback:", time_callback.per_step_time)
|
|
|
|
|
assert time_callback.per_step_time < export_per_step_time
|
|
|
|
|
print("*******test case pass!********")
|
|
|
|
|