parent
d113e7a694
commit
9ebf8e2362
@ -0,0 +1,110 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Callbacks defined for DeepFM.
|
||||
"""
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
def add_write(file_path, out_str):
    """
    Append one line of text to a file.

    The file is opened in ``'a+'`` mode with UTF-8 encoding, so existing
    content is kept and new text goes to the end of the file.

    Args:
        file_path (str): Path of the file to append to.
        out_str (str): Text to write; a newline is added after it.
    """
    with open(file_path, mode='a+', encoding='utf-8') as log_file:
        log_file.write(out_str + '\n')
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluate the model at the end of every epoch and log the outcome.

    On each ``epoch_end`` the callback runs ``model.eval`` on the evaluation
    dataset, then prints a timestamped line containing the returned metric
    values and the evaluation duration, and appends the same line to
    ``eval_file_path``.

    Args:
        model: Object whose ``eval(dataset)`` result exposes ``values()``.
        eval_dataset: Dataset handed to ``model.eval``.
        auc_metric: Metric object; its ``clear()`` is called once here.
        eval_file_path (str): File the evaluation lines are appended to.
    """

    def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
        super(EvalCallBack, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.eval_file_path = eval_file_path
        # Keep a handle on the metric and reset it once at construction.
        self.aucMetric = auc_metric
        self.aucMetric.clear()

    def epoch_end(self, run_context):
        """Run evaluation, then print and persist the metrics and timing."""
        began = time.time()
        metrics = self.model.eval(self.eval_dataset)
        # Duration is truncated to whole seconds.
        elapsed_seconds = int(time.time() - began)

        timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        message = "{} EvalCallBack metric{}; eval_time{}s".format(
            timestamp, metrics.values(), elapsed_seconds)

        print(message)
        add_write(self.eval_file_path, message)
|
||||
|
||||
|
||||
class LossCallBack(Callback):
    """
    Record the training loss after every step and log it periodically.

    Every ``per_print_times`` steps the loss is printed and appended, with a
    timestamp prefix, to ``loss_file_path``. The most recent loss is always
    stored in ``self.loss`` so it can be inspected after training.

    Note:
        If per_print_times is 0 do not print loss.
        The loss is not checked for NAN or INF; this callback never
        terminates training.

    Args:
        loss_file_path (str): The file absolute path, to save as loss_file.
        per_print_times (int): Print loss every times. Default: 1.

    Raises:
        ValueError: If per_print_times is not an int or is negative.
    """

    def __init__(self, loss_file_path, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            # Name the actual parameter so the error is actionable.
            raise ValueError("per_print_times must be int and >= 0.")
        self.loss_file_path = loss_file_path
        self._per_print_times = per_print_times
        self.loss = 0

    def step_end(self, run_context):
        """Store the current loss and, on logging steps, print and save it."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()
        # 1-based position of the current step inside the current epoch.
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num
        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            message = "epoch: {} step: {}, loss is {}".format(
                cb_params.cur_epoch_num, cur_step_in_epoch, loss)
            time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            # Reuse the shared helper so both log files are written the same
            # way (append mode, explicit UTF-8 encoding).
            add_write(self.loss_file_path, "{} {}".format(time_str, message))
            # The extra "\n" keeps the blank line after each console entry.
            print(message + "\n")
        self.loss = loss
|
||||
|
||||
class TimeMonitor(Callback):
    """
    Time monitor for calculating cost of each epoch and each step.

    The average per-step time of the last finished epoch, in milliseconds,
    is kept in ``per_step_time``.

    Args:
        data_size (int): step size of an epoch; the epoch time is divided by
            this value to obtain the per-step time.

    Raises:
        ValueError: If data_size is not greater than 0.
    """

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        # Fail fast: a zero size would otherwise only surface as a
        # ZeroDivisionError in epoch_end, after a whole epoch has run.
        if data_size <= 0:
            raise ValueError(
                "data_size must be greater than 0, but got {}.".format(data_size))
        self.data_size = data_size
        self.per_step_time = 0
        # Define the timestamps up front so an *_end hook that fires without
        # its matching *_begin hook does not raise AttributeError.
        now = time.time()
        self.epoch_time = now
        self.step_time = now

    def epoch_begin(self, run_context):
        """Remember when the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Print the epoch and average per-step time in milliseconds."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
        self.per_step_time = per_step_mseconds

    def step_begin(self, run_context):
        """Remember when the step started."""
        self.step_time = time.time()

    def step_end(self, run_context):
        """Print the duration of the step in milliseconds."""
        step_mseconds = (time.time() - self.step_time) * 1000
        print(f"step time {step_mseconds}", flush=True)
|
||||
@ -0,0 +1,62 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
|
||||
|
||||
class DataConfig:
    """
    Dataset settings, exposed as class-level constants.
    """
    # Vocabulary and field sizes.
    data_vocab_size = 184_965
    data_field_size = 39

    # Part counts for the train and test splits.
    train_num_of_parts = 21
    test_num_of_parts = 3

    batch_size = 1_000

    # Dataset format selector: 1 = mindrecord, 2 = tfrecord, 3 = h5.
    data_format = 3
|
||||
|
||||
|
||||
class ModelConfig:
    """
    Model settings, exposed as class-level constants.
    """
    # Copied from DataConfig so model and dataset settings stay in sync.
    data_vocab_size = DataConfig.data_vocab_size
    data_field_size = DataConfig.data_field_size
    batch_size = DataConfig.batch_size

    data_emb_dim = 80
    keep_prob = 0.9

    # Kept as lists (not tuples) in case consumers rely on list operations.
    init_args = [-0.01, 0.01]
    weight_bias_init = ['normal', 'normal']
    deep_layer_args = [[400, 400, 512], "relu"]
|
||||
|
||||
|
||||
class TrainConfig:
    """
    Training settings, exposed as class-level constants.
    """
    # Copied from DataConfig so training and dataset settings stay in sync.
    batch_size = DataConfig.batch_size

    # Optimisation settings.
    train_epochs = 3
    learning_rate = 1e-5
    l2_coef = 1e-6
    epsilon = 1e-8
    loss_scale = 1024.0

    # Checkpoint settings.
    save_checkpoint = True
    save_checkpoint_steps = 1
    keep_checkpoint_max = 15
    ckpt_file_name_prefix = "deepfm"

    # Callback switches.
    eval_callback = True
    loss_callback = True
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,80 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_criteo."""
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset, DataType
|
||||
from src.callback import EvalCallBack, LossCallBack, TimeMonitor
|
||||
|
||||
# Seed with a fixed value at import time (presumably so the loss threshold
# asserted in test_deepfm is reproducible between runs; confirm).
set_seed(1)
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_deepfm():
    """
    Train DeepFM on the Criteo h5 dataset and check loss and step time.

    The test trains for ``TrainConfig.train_epochs`` epochs on an Ascend
    device in graph mode, evaluating after every epoch, and then asserts that
    the final loss and the average per-step time (in milliseconds) are below
    fixed thresholds.
    """
    data_config = DataConfig()
    train_config = TrainConfig()
    # DEVICE_ID may be unset outside CI; int(None) would raise TypeError,
    # so fall back to device 0.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    # No rank information is passed to the training dataset.
    rank_size = None
    rank_id = None

    dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
    print("dataset_path:", dataset_path)
    ds_train = create_dataset(dataset_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
                              rank_id=rank_id)

    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    loss_file_name = './loss.log'
    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=loss_file_name)
    callback_list = [time_callback, loss_callback]

    eval_file_name = './auc.log'
    ds_eval = create_dataset(dataset_path, train_mode=False,
                             epochs=1,
                             batch_size=train_config.batch_size,
                             data_type=DataType(data_config.data_format))
    eval_callback = EvalCallBack(model, ds_eval, auc_metric,
                                 eval_file_path=eval_file_name)
    callback_list.append(eval_callback)

    print("train_config.train_epochs:", train_config.train_epochs)
    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)

    # Upper bound for the last recorded training loss.
    expected_loss_value = 0.51
    print("loss_callback.loss:", loss_callback.loss)
    assert loss_callback.loss < expected_loss_value
    # Upper bound for the average per-step time, in milliseconds.
    expected_per_step_time = 10.4
    print("time_callback:", time_callback.per_step_time)
    assert time_callback.per_step_time < expected_per_step_time
    print("*******test case pass!********")
|
||||
Loading…
Reference in new issue