parent
21addb331d
commit
65c7eb2461
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,107 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""eval standalone script"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.config import eval_cfg, student_net_cfg, task_cfg
|
||||
from src.tinybert_model import BertModelCLS
|
||||
|
||||
|
||||
DATA_NAME = 'eval.tf_record'
|
||||
|
||||
|
||||
def parse_args():
    """
    Parse the command line options of the standalone evaluation script.

    Returns:
        argparse.Namespace, with the attributes ``device_target``, ``device_id``, ``model_dir``,
        ``data_dir``, ``task_name``, ``dataset_type`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='ternarybert evaluation')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
                        help='Device where the code will be implemented. (Default: GPU)')
    parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
    parser.add_argument('--model_dir', type=str, default='', help='The checkpoint directory of model.')
    parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
    # This script only evaluates, so the help text says "evaluate" rather than "train".
    parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
                        help='The name of the task to evaluate. (Default: sts-b)')
    # The help text used to be a copy of the --task_name description.
    parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
                        help='Format of the evaluation dataset. (Default: tfrecord)')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for evaluating')
    return parser.parse_args()
|
||||
|
||||
|
||||
def get_ckpt(ckpt_file):
    """
    Return the path of the most recently modified checkpoint inside a directory.

    Args:
        ckpt_file (str): Directory to search (despite its name, this is a directory path).

    Returns:
        str, the directory joined with the name of the newest ``.ckpt`` file.

    Raises:
        FileNotFoundError: If the directory holds no regular file ending in ``.ckpt``.
    """
    # Only regular *.ckpt files are candidates, so a newer log file or sub-directory is never returned.
    candidates = [os.path.join(ckpt_file, fn) for fn in os.listdir(ckpt_file) if fn.endswith('.ckpt')]
    candidates = [path for path in candidates if os.path.isfile(path)]
    if not candidates:
        raise FileNotFoundError('No .ckpt file found in directory: {}'.format(ckpt_file))
    candidates.sort(key=os.path.getmtime)
    return candidates[-1]
|
||||
|
||||
|
||||
def do_eval_standalone(args_opt):
    """
    Evaluate the newest checkpoint of one task on a single device and print the metric.

    Args:
        args_opt (argparse.Namespace): Options providing ``model_dir``, ``data_dir``, ``task_name``,
            ``dataset_type``, ``batch_size``, ``device_target`` and ``device_id``.
    """
    ckpt_file = os.path.join(args_opt.model_dir, args_opt.task_name)
    ckpt_file = get_ckpt(ckpt_file)
    print('ckpt file:', ckpt_file)
    task = task_cfg[args_opt.task_name]
    # The shared config objects are updated in place with the task / command line settings.
    student_net_cfg.seq_length = task.seq_length
    eval_cfg.batch_size = args_opt.batch_size
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, DATA_NAME)

    # Read the device id from the options passed in; a module-level ``args`` only exists when this
    # file is executed as a script, so relying on it breaks any other caller.
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)

    eval_dataset = create_tinybert_dataset(batch_size=eval_cfg.batch_size,
                                           device_num=1,
                                           rank=0,
                                           do_shuffle='false',
                                           data_dir=eval_data_dir,
                                           data_type=args_opt.dataset_type,
                                           seq_length=task.seq_length,
                                           task_type=task.task_type,
                                           drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())
    print('eval dataset batch size:', eval_dataset.get_batch_size())

    eval_model = BertModelCLS(student_net_cfg, False, task.num_labels, 0.0, phase_type='student')
    param_dict = load_checkpoint(ckpt_file)
    # Rename checkpoint keys: 'tinybert_' becomes 'bert_' and one leading 'bert.' prefix is dropped.
    new_param_dict = {}
    for key, value in param_dict.items():
        new_key = re.sub('tinybert_', 'bert_', key)
        # The dot is escaped so that only a literal 'bert.' prefix is removed.
        new_key = re.sub(r'^bert\.', '', new_key)
        new_param_dict[new_key] = value
    load_param_into_net(eval_model, new_param_dict)
    eval_model.set_train(False)

    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    callback = task.metrics()
    for step, data in enumerate(eval_dataset.create_dict_iterator()):
        input_ids, input_mask, token_type_id, label_ids = [data[column] for column in columns_list]
        _, _, logits, _ = eval_model(input_ids, token_type_id, input_mask)
        callback.update(logits, label_ids)
        # Running value of the metric over all batches seen so far.
        print('eval step: {}, {}: {}'.format(step, callback.name, callback.get_metrics()))
    metrics = callback.get_metrics()
    print('The best {}: {}'.format(callback.name, metrics))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: parse the command line options and run the evaluation once.
    args = parse_args()
    do_eval_standalone(args)
|
@ -0,0 +1,57 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Bert hub interface for bert base"""
|
||||
|
||||
from src.tinybert_model import BertModel
|
||||
from src.tinybert_model import BertConfig
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
# BertConfig handed to BertModel by create_network() below. The quantization-related keyword
# arguments enable quantization (do_quant=True) with 2-bit embeddings and weights.
tinybert_student_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    do_quant=True,
    embedding_bits=2,
    weight_bits=2,
    weight_clip_value=3.0,
    cls_dropout_prob=0.1,
    activation_init=2.5,
    is_lgt_fit=False
)
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
    """
    Build the network registered under ``name``.

    Only ``"ternarybert"`` is known. An optional ``seq_length`` keyword overwrites the sequence
    length of the module-level student config, and ``is_training`` (default ``False``) is forwarded
    to ``BertModel`` together with any extra positional arguments.

    Raises:
        NotImplementedError: If ``name`` is not ``"ternarybert"``.
    """
    if name != "ternarybert":
        raise NotImplementedError(f"{name} is not implemented in the repo")

    if "seq_length" in kwargs:
        # The shared config object is modified in place.
        tinybert_student_net_cfg.seq_length = kwargs["seq_length"]
    training_flag = kwargs.get("is_training", False)
    return BertModel(tinybert_student_net_cfg, training_flag, *args)
|
@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Directory (relative to where the script is launched) that the GLOG_* variables below point at.
mkdir -p ms_log
# Absolute directory of this script; '&&' stops pwd from reporting the wrong place if cd fails.
PROJECT_DIR="$(cd "$(dirname "$0")" && pwd)"
CUR_DIR="$(pwd)"
export GLOG_log_dir="${CUR_DIR}/ms_log"
export GLOG_logtostderr=0
# Paths are quoted so directories containing spaces work. --model_dir and --data_dir are empty
# here; standard output of eval.py is redirected to log.txt.
python "${PROJECT_DIR}/../eval.py" \
    --task_name=sts-b \
    --device_id=0 \
    --model_dir="" \
    --data_dir="" > log.txt
|
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Directory (relative to where the script is launched) that the GLOG_* variables below point at.
mkdir -p ms_log
# Absolute directory of this script; '&&' stops pwd from reporting the wrong place if cd fails.
PROJECT_DIR="$(cd "$(dirname "$0")" && pwd)"
CUR_DIR="$(pwd)"
export GLOG_log_dir="${CUR_DIR}/ms_log"
export GLOG_logtostderr=0
# Paths are quoted so directories containing spaces work. The model and data directories are
# empty here; standard output of train.py is redirected to log.txt.
python "${PROJECT_DIR}/../train.py" \
    --task_name=sts-b \
    --device_id=0 \
    --teacher_model_dir="" \
    --student_model_dir="" \
    --data_dir="" > log.txt
|
@ -0,0 +1,115 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""assessment methods"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Accuracy:
    """Accuracy, in percent, of arg-max predictions accumulated over batches."""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0
        self.name = 'Accuracy'

    def update(self, logits, labels):
        """
        Add one batch to the running counts.

        Args:
            logits: Object whose ``asnumpy()`` returns the scores; the arg-max over the last axis is
                taken as the prediction.
            labels: Object whose ``asnumpy()`` returns the labels; they are flattened before comparison.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)

    def get_metrics(self):
        """Return the accuracy in percent, or 0.0 when no sample has been seen yet."""
        # Guard against division by zero, matching the 0.0 fallback of the other metrics here.
        if self.total_num == 0:
            return 0.0
        return self.acc_num / self.total_num * 100.0
|
||||
|
||||
|
||||
class F1:
    """Binary F1 score, in percent, computed over all accumulated predictions."""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'F1'

    def update(self, logits, labels):
        """
        Append one batch of predictions and labels.

        Args:
            logits: Object whose ``asnumpy()`` returns the scores; the arg-max over axis 1 is the
                predicted class.
            labels: Object whose ``asnumpy()`` returns the labels; they are flattened.

        Both arrays are stored as booleans, i.e. any non-zero class counts as positive.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.argmax(logits, axis=1)
        # Builtin bool is used because the ``np.bool`` alias was removed in NumPy 1.24.
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, logits]).astype(bool)

    def get_metrics(self):
        """Return F1 in percent; 0.0 with fewer than two samples or without any true positive."""
        if len(self.labels_array) < 2:
            return 0.0
        tp = np.sum(self.labels_array & self.logits_array)
        # Positive label predicted negative is a false negative; negative label predicted positive
        # is a false positive.
        fn = np.sum(self.labels_array & (~self.logits_array))
        fp = np.sum((~self.labels_array) & self.logits_array)
        # Without true positives precision and recall are zero or undefined; report 0.0, not NaN.
        if tp == 0:
            return 0.0
        p = tp / (tp + fp)
        r = tp / (tp + fn)
        return 2.0 * p * r / (p + r) * 100.0
|
||||
|
||||
|
||||
class Pearsonr:
    """Pearson correlation coefficient, times 100, between predictions and labels."""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'Pearsonr'

    def update(self, logits, labels):
        """
        Append one batch; both inputs are converted with ``asnumpy()`` and flattened.

        Args:
            logits: Object whose ``asnumpy()`` returns the predicted values.
            labels: Object whose ``asnumpy()`` returns the target values.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.reshape(logits, -1)
        self.labels_array = np.concatenate([self.labels_array, labels])
        self.logits_array = np.concatenate([self.logits_array, logits])

    def get_metrics(self):
        """
        Return the correlation times 100.

        0.0 is returned with fewer than two samples, or when either series is constant (the
        coefficient is undefined there and the plain formula would produce NaN).
        """
        if len(self.labels_array) < 2:
            return 0.0
        x = self.logits_array
        y = self.labels_array
        # A constant series has zero variance; comparing against the first element avoids relying
        # on the rounded mean to detect it.
        if np.all(x == x[0]) or np.all(y == y[0]):
            return 0.0
        xm = x - x.mean()
        ym = y - y.mean()
        norm_xm = np.linalg.norm(xm)
        norm_ym = np.linalg.norm(ym)
        return np.dot(xm / norm_xm, ym / norm_ym) * 100.0
|
||||
|
||||
|
||||
class Matthews:
    """Matthews correlation coefficient, times 100, for binary predictions."""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'Matthews'

    def update(self, logits, labels):
        """
        Append one batch of predictions and labels.

        Args:
            logits: Object whose ``asnumpy()`` returns the scores; the arg-max over axis 1 is the
                predicted class.
            labels: Object whose ``asnumpy()`` returns the labels; they are flattened.

        Both arrays are stored as booleans, i.e. any non-zero class counts as positive.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.argmax(logits, axis=1)
        # Builtin bool is used because the ``np.bool`` alias was removed in NumPy 1.24.
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, logits]).astype(bool)

    def get_metrics(self):
        """Return the coefficient times 100; 0.0 with fewer than two samples or a zero denominator."""
        if len(self.labels_array) < 2:
            return 0.0
        # Counts are converted to Python floats so the four-way product below cannot overflow int64.
        tp = float(np.sum(self.labels_array & self.logits_array))
        fn = float(np.sum(self.labels_array & (~self.logits_array)))
        fp = float(np.sum((~self.labels_array) & self.logits_array))
        tn = float(np.sum((~self.labels_array) & (~self.logits_array)))
        denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        # A whole row or column of the confusion matrix is empty: report 0.0 instead of NaN.
        if denominator == 0.0:
            return 0.0
        return (tp * tn - fp * fn) / denominator * 100.0
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,103 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""config script"""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from easydict import EasyDict as edict
|
||||
from .tinybert_model import BertConfig
|
||||
from .assessment_method import Accuracy, F1, Pearsonr, Matthews
|
||||
|
||||
|
||||
# Gradient clipping settings: a clip type selector and the clip value.
# NOTE(review): the meaning of clip_type 1 is defined by the consumer of this config — confirm there.
gradient_cfg = edict({
    'clip_type': 1,
    'clip_value': 1.0
})
|
||||
|
||||
# Per-task settings keyed by task name: number of labels, sequence length, task type
# ("classification" or "regression") and the metric class imported from assessment_method.
task_cfg = edict({
    "sst-2": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Accuracy}),
    "qnli": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
    "mnli": edict({"num_labels": 3, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
    "cola": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Matthews}),
    "mrpc": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
    "sts-b": edict({"num_labels": 1, "seq_length": 128, "task_type": "regression", "metrics": Pearsonr}),
    "qqp": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
    "rte": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy})
})
|
||||
|
||||
# Training settings: batch size, loss-scale parameters and the AdamWeightDecay optimizer options.
train_cfg = edict({
    'batch_size': 16,
    'loss_scale_value': 2 ** 16,
    'scale_factor': 2,
    'scale_window': 50,
    'optimizer_cfg': edict({
        'AdamWeightDecay': edict({
            'learning_rate': 5e-5,
            'end_learning_rate': 1e-14,
            'power': 1.0,
            'weight_decay': 1e-4,
            'eps': 1e-6,
            # True for parameters whose lower-cased name contains neither 'layernorm' nor 'bias'.
            'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
            'warmup_ratio': 0.1
        }),
    }),
})
|
||||
|
||||
# Evaluation settings; only the batch size is configured here.
eval_cfg = edict({
    'batch_size': 32,
})
|
||||
|
||||
# BertConfig of the teacher network: same architecture values as the student config in this file,
# but with quantization switched off (do_quant=False).
teacher_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    do_quant=False
)
|
||||
# BertConfig of the student network: quantization is enabled (do_quant=True) with 2-bit
# embeddings and weights plus the clip / activation settings listed at the end.
student_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    do_quant=True,
    embedding_bits=2,
    weight_bits=2,
    weight_clip_value=3.0,
    cls_dropout_prob=0.1,
    activation_init=2.5,
    is_lgt_fit=False
)
|
@ -0,0 +1,62 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""create tinybert dataset"""
|
||||
|
||||
from enum import Enum
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
|
||||
|
||||
class DataType(Enum):
    """Dataset file formats known to this module: TFRecord (1) and MindRecord (2)."""
    TFRECORD, MINDRECORD = 1, 2
|
||||
|
||||
|
||||
def create_tinybert_dataset(batch_size=32, device_num=1, rank=0, do_shuffle="true", data_dir=None,
                            data_type='tfrecord', seq_length=128, task_type=mstype.int32, drop_remainder=True):
    """
    Build the batched dataset with the four model input columns.

    Args:
        batch_size (int): Batch size passed to ``batch``.
        device_num (int): Number of shards.
        rank (int): Shard id of this process.
        do_shuffle (str): Shuffling is enabled only for the exact string ``"true"``.
        data_dir (Union[str, list]): One data file or a list of data files.
        data_type (str): ``'mindrecord'`` selects MindDataset; any other value selects TFRecordDataset.
        seq_length (int): Every column is sliced to ``[0:seq_length]``.
        task_type: Labels are cast to int32 when this equals ``'classification'``, else to float32.
        drop_remainder (bool): Forwarded to ``batch``.

    Returns:
        The mapped and batched dataset object.
    """
    # NOTE(review): the default task_type (mstype.int32) never equals 'classification', so the
    # default yields float32 labels — confirm that callers always pass the task type explicitly.
    data_files = data_dir if isinstance(data_dir, list) else [data_dir]
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    shuffle = do_shuffle == "true"

    if data_type != 'mindrecord':
        ds = de.TFRecordDataset(data_files, columns_list=columns_list, shuffle=shuffle, num_shards=device_num,
                                shard_id=rank, shard_equal_rows=(device_num == 1))
    else:
        ds = de.MindDataset(data_files, columns_list=columns_list, shuffle=shuffle, num_shards=device_num,
                            shard_id=rank)

    # Extra sample-level shuffle, applied only on a single device.
    if device_num == 1 and shuffle:
        ds = ds.shuffle(10000)

    type_cast_op = C.TypeCast(mstype.int32)
    slice_op = C.Slice(slice(0, seq_length, 1))
    if task_type == 'classification':
        label_type = mstype.int32
    else:
        label_type = mstype.float32

    # The three integer input columns share the same cast + slice pipeline.
    for column in ("segment_ids", "input_mask", "input_ids"):
        ds = ds.map(operations=[type_cast_op, slice_op], input_columns=[column])
    ds = ds.map(operations=[C.TypeCast(label_type), slice_op], input_columns=["label_ids"])

    # apply batch operations
    return ds.batch(batch_size, drop_remainder=drop_remainder)
|
@ -0,0 +1,171 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Quantization function."""
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import nn
|
||||
|
||||
|
||||
class QuantizeWeightCell(nn.Cell):
    """
    The ternary fake quant op for weight.

    The weight is first clipped to [-clip_value, clip_value]. With ``num_bits == 2`` every element is
    mapped to one of {-alpha, 0, +alpha}, using a threshold of 0.7 times the mean absolute value;
    for any other bit width a uniform min/max quantization with ``2 ** num_bits - 1`` steps is applied.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 8.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeightCell. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - The weight to quantize. The per-channel branch reduces over axis 1,
          so it presumably expects a 2-D weight - TODO confirm.

    Outputs:
        Tensor, the fake-quantized weight.
    """

    def __init__(self, num_bits=8, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        super(QuantizeWeightCell, self).__init__()
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel

        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.round = P.Round()
        # Needed by the per-channel branch of construct(); it was referenced there without
        # ever being created.
        self.reshape = P.Reshape()

    def construct(self, weight):
        """quantize weight cell"""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                # NOTE(review): only thres[0] is compared against every element although thres
                # holds one value per row - confirm this is intended for per-channel mode.
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                # One scaling factor per row: mean magnitude of the elements above the threshold.
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                # Single scaling factor: mean magnitude of the elements above the threshold.
                alpha = self.sum(self.abs(mask * self.cast(tensor, self.compute_type))) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            tensor_max = self.cast(self.max(tensor), self.compute_type)
            tensor_min = self.cast(self.min(tensor), self.compute_type)
            # Step size of the uniform grid between the minimum and maximum value.
            s = (tensor_max - tensor_min) / (2 ** self.cast(self.num_bits, self.compute_type) - 1)
            output = self.round(self.div(tensor - tensor_min, s)) * s + tensor_min
        return output
|
||||
|
||||
|
||||
class QuantizeWeight:
    """
    Quantize weight into specified bit.

    Plain-Python counterpart of the quantization cell: the weight is clipped to
    [-clip_value, clip_value]; with ``num_bits == 2`` it is mapped to {-alpha, 0, +alpha} using a
    threshold of 0.7 times the mean absolute value, otherwise a uniform min/max quantization with
    ``2 ** num_bits - 1`` steps is applied.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type used for the masks. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - The weight to quantize. The per-channel branch reduces over axis 1,
          so it presumably expects a 2-D weight - TODO confirm.

    Outputs:
        Tensor, the quantized weight.
    """

    def __init__(self, num_bits=2, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel

        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.floor = P.Floor()
        # Needed by the per-channel branch of construct(); it was referenced there without
        # ever being created.
        self.reshape = P.Reshape()

    def construct(self, weight):
        """quantize weight"""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                # NOTE(review): only thres[0] is compared against every element although thres
                # holds one value per row - confirm this is intended for per-channel mode.
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                # One scaling factor per row: mean magnitude of the elements above the threshold.
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                # Single scaling factor: mean magnitude of the elements above the threshold.
                alpha = self.sum(self.abs(mask * tensor)) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            tensor_max = self.max(tensor)
            tensor_min = self.min(tensor)
            # Step size of the uniform grid between the minimum and maximum value.
            s = (tensor_max - tensor_min) / (2 ** self.num_bits - 1)
            # floor(x + 0.5) rounds to the nearest grid index.
            output = self.floor(self.div((tensor - tensor_min), s) + 0.5) * s + tensor_min
        return output
|
||||
|
||||
|
||||
def convert_network(network, embedding_bits=2, weight_bits=2, clip_value=1.0):
    """
    Overwrite the network's embedding and weight parameters with their quantized values, in place.

    Parameters whose name contains 'bert_embedding_lookup' (but neither 'min' nor 'max') use the
    embedding quantizer; other parameters whose name contains 'weight' but not 'dense_1' use the
    weight quantizer. Everything else is left untouched.

    Args:
        network: Object providing ``parameters_and_names()``.
        embedding_bits (int): Bit width for the embedding quantizer. Default: 2.
        weight_bits (int): Bit width for the weight quantizer. Default: 2.
        clip_value (float): Clip value handed to both quantizers. Default: 1.0.
    """
    embedding_quantizer = QuantizeWeight(num_bits=embedding_bits, clip_value=clip_value)
    weight_quantizer = QuantizeWeight(num_bits=weight_bits, clip_value=clip_value)
    for param_name, parameter in network.parameters_and_names():
        is_embedding = 'bert_embedding_lookup' in param_name and not ('min' in param_name or 'max' in param_name)
        if is_embedding:
            quantizer = embedding_quantizer
        elif 'weight' in param_name and 'dense_1' not in param_name:
            quantizer = weight_quantizer
        else:
            continue
        parameter.set_data(quantizer.construct(parameter))
|
||||
|
||||
|
||||
def save_params(network):
    """Return a dict mapping each parameter name to a new ``Parameter`` named 'saved_params'."""
    snapshot = {}
    for param_name, parameter in network.parameters_and_names():
        snapshot[param_name] = Parameter(parameter, 'saved_params')
    return snapshot
|
||||
|
||||
|
||||
def restore_params(network, params_dict):
    """
    Write the saved values back into the network's parameters.

    Args:
        network: Object providing ``parameters_and_names()``.
        params_dict (dict): Maps every parameter name to the value handed to ``set_data``.

    Raises:
        KeyError: If a parameter name of the network is missing from ``params_dict``.
    """
    for param_name, parameter in network.parameters_and_names():
        saved_value = params_dict[param_name]
        parameter.set_data(saved_value)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,187 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ternarybert utils"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
from .quant import convert_network, save_params, restore_params
|
||||
|
||||
|
||||
class ModelSaveCkpt(Callback):
    """
    Saves a quantized checkpoint every ``save_ckpt_step`` steps and removes the oldest one once
    more than ``max_ckpt_num`` have been written.

    Args:
        network (Network): The train network for training.
        save_ckpt_step (int): The step to save checkpoint.
        max_ckpt_num (int): The max checkpoint number.
        output_dir (str): Directory the checkpoints are written to; created if missing.
        embedding_bits (int): Bit width passed to ``convert_network`` for embeddings. Default: 2.
        weight_bits (int): Bit width passed to ``convert_network`` for weights. Default: 2.
        clip_value (float): Clip value passed to ``convert_network``. Default: 1.0.
    """
    def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir, embedding_bits=2, weight_bits=2,
                 clip_value=1.0):
        super(ModelSaveCkpt, self).__init__()
        self.count = 0
        self.network = network
        self.save_ckpt_step = save_ckpt_step
        self.max_ckpt_num = max_ckpt_num
        self.output_dir = output_dir
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(output_dir, exist_ok=True)
        self.embedding_bits = embedding_bits
        self.weight_bits = weight_bits
        self.clip_value = clip_value

    def _ckpt_path(self, index):
        """Return the checkpoint file path for the given checkpoint index."""
        return os.path.join(self.output_dir, "ternary_bert_{}_{}.ckpt".format(int(index), self.save_ckpt_step))

    def step_end(self, run_context):
        """step end and save ckpt"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.save_ckpt_step != 0:
            return
        # Exact integer index of the checkpoint written at this step.
        saved_ckpt_num = cb_params.cur_step_num // self.save_ckpt_step
        if saved_ckpt_num > self.max_ckpt_num:
            oldest_path = self._ckpt_path(saved_ckpt_num - self.max_ckpt_num)
            if os.path.exists(oldest_path):
                os.remove(oldest_path)
        params_dict = save_params(self.network)
        try:
            # The network is quantized in place only for the duration of the save.
            convert_network(self.network, self.embedding_bits, self.weight_bits, self.clip_value)
            save_checkpoint(self.network, self._ckpt_path(saved_ckpt_num))
        finally:
            # Always put the saved parameters back, even if conversion or saving failed.
            restore_params(self.network, params_dict)
|
||||
|
||||
|
||||
class LossCallBack(Callback):
    """
    Monitor the loss in training.

    Note:
        if per_print_times is 0 do not print loss.

    Args:
        per_print_times (int): Print the outputs every ``per_print_times`` steps. Default: 1.

    Raises:
        ValueError: If ``per_print_times`` is not an int or is negative.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """step end and print loss"""
        cb_params = run_context.original_args()
        # Honour the configured interval; the default of 1 still prints on every step.
        if self._per_print_times == 0 or cb_params.cur_step_num % self._per_print_times != 0:
            return
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
                                                           cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))
|
||||
|
||||
|
||||
class StepCallBack(Callback):
    """
    Print how many seconds each training step took.

    The start time is recorded in ``step_begin`` and the elapsed time is printed in ``step_end``.
    """
    def __init__(self):
        super(StepCallBack, self).__init__()
        self.start_time = 0.0

    def step_begin(self, run_context):
        """Remember the time at which the step starts."""
        self.start_time = time.time()

    def step_end(self, run_context):
        """Print the current step number and the seconds elapsed since ``step_begin``."""
        elapsed = time.time() - self.start_time
        step_num = run_context.original_args().cur_step_num
        print("step: {}, second_per_step: {}".format(step_num, elapsed))
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluate the converted network every eval_ckpt_step steps and keep the best checkpoint.

    On an evaluation step the current parameters are saved with save_params, the network
    is passed through convert_network (presumably weight/embedding quantization - confirm
    against its definition), evaluated on the whole dataset, and written to
    ``<save_ckpt_dir>/eval_model.ckpt`` when the metric beats the best value seen so far.
    The saved parameters and training mode are always restored afterwards.

    Args:
        network (Cell): Network to evaluate; called as network(input_ids, token_type_id, input_mask)
            and expected to return four values, the third being the logits.
        dataset (iterable): Items indexable by "input_ids", "input_mask", "segment_ids" and "label_ids".
        eval_ckpt_step (int): Evaluation runs when cur_step_num is a multiple of this value.
        save_ckpt_dir (str): Directory for the best checkpoint; created if missing.
        embedding_bits (int): Forwarded to convert_network. Default: 2.
        weight_bits (int): Forwarded to convert_network. Default: 2.
        clip_value (float): Forwarded to convert_network. Default: 1.0.
        metrics (callable): Zero-argument factory returning an object with ``update(logits, labels)``,
            ``get_metrics()`` and a ``name`` attribute. Must be provided before the first
            evaluation step, because it is called there. Default: None.
    """
    def __init__(self, network, dataset, eval_ckpt_step, save_ckpt_dir, embedding_bits=2, weight_bits=2,
                 clip_value=1.0, metrics=None):
        super(EvalCallBack, self).__init__()
        self.network = network
        # Best metric seen so far; a checkpoint is only written once a metric exceeds 0.0.
        self.global_metrics = 0.0
        self.dataset = dataset
        self.eval_ckpt_step = eval_ckpt_step
        self.save_ckpt_dir = save_ckpt_dir
        self.embedding_bits = embedding_bits
        self.weight_bits = weight_bits
        self.clip_value = clip_value
        self.metrics = metrics
        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        os.makedirs(save_ckpt_dir, exist_ok=True)

    def step_end(self, run_context):
        """step end and do evaluation"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.eval_ckpt_step != 0:
            return

        params_dict = save_params(self.network)
        try:
            convert_network(self.network, self.embedding_bits, self.weight_bits, self.clip_value)
            self.network.set_train(False)
            callback = self.metrics()
            for data in self.dataset:
                input_ids = data["input_ids"]
                input_mask = data["input_mask"]
                token_type_id = data["segment_ids"]
                label_ids = data["label_ids"]
                _, _, logits, _ = self.network(input_ids, token_type_id, input_mask)
                callback.update(logits, label_ids)
            metrics = callback.get_metrics()

            if metrics > self.global_metrics:
                self.global_metrics = metrics
                eval_model_ckpt_file = os.path.join(self.save_ckpt_dir, 'eval_model.ckpt')
                if os.path.exists(eval_model_ckpt_file):
                    os.remove(eval_model_ckpt_file)
                save_checkpoint(self.network, eval_model_ckpt_file)
            print('step {}, {} {}, best_{} {}'.format(cb_params.cur_step_num,
                                                      callback.name,
                                                      metrics,
                                                      callback.name,
                                                      self.global_metrics))
        finally:
            # Always put the saved parameters and training mode back, even if evaluation
            # raised, so the network is never left in its converted / eval state.
            restore_params(self.network, params_dict)
            self.network.set_train(True)
|
||||
|
||||
|
||||
class BertLearningRate(LearningRateSchedule):
    """
    Learning rate schedule for the Bert network: optional warmup followed by polynomial decay.

    Args:
        learning_rate (float): Passed to both WarmUpLR and PolynomialDecayLR.
        end_learning_rate (float): Passed to PolynomialDecayLR.
        warmup_steps (int): Warmup is enabled only when this is greater than 0.
        decay_steps (int): Passed to PolynomialDecayLR.
        power (float): Passed to PolynomialDecayLR.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        else:
            self.warmup_flag = False
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        """Return the warmup rate while warmup_steps > global_step, otherwise the decayed rate."""
        lr = self.decay_lr(global_step)
        if self.warmup_flag:
            in_warmup = self.greater(self.warmup_steps, global_step)
            # 1.0 during warmup, 0.0 afterwards; used to blend the two schedules.
            warmup_mask = self.cast(in_warmup, mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = warmup_mask * warmup_lr + (self.one - warmup_mask) * lr
        return lr
|
@ -0,0 +1,165 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""task distill script"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore import set_seed
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.utils import StepCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
|
||||
from src.config import train_cfg, eval_cfg, teacher_net_cfg, student_net_cfg, task_cfg
|
||||
from src.cell_wrapper import BertNetworkWithLoss, BertTrainCell
|
||||
|
||||
# Checkpoint file name joined onto <model_dir>/<task_name>/ for both the teacher and the student.
WEIGHTS_NAME = 'eval_model.ckpt'
# Evaluation data file name joined onto <data_dir>/<task_name>/.
EVAL_DATA_NAME = 'eval.tf_record'
# Training data file name joined onto <data_dir>/<task_name>/.
TRAIN_DATA_NAME = 'train.tf_record'
|
||||
|
||||
|
||||
def parse_args():
    """
    Parse the command-line options of the task distill script.

    Boolean-like switches (do_eval, do_shuffle, enable_data_sink) are kept as the
    strings 'true' / 'false'; callers compare against those literals.

    Returns:
        argparse.Namespace: The parsed options, read from sys.argv.
    """
    parser = argparse.ArgumentParser(description='ternarybert task distill')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
                        help='Device where the code will be implemented. (Default: GPU)')
    parser.add_argument('--do_eval', type=str, default='true', choices=['true', 'false'],
                        help='Do eval task during training or not. (Default: true)')
    parser.add_argument('--epoch_size', type=int, default=3, help='Epoch size for train phase. (Default: 3)')
    parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
    parser.add_argument('--do_shuffle', type=str, default='true', choices=['true', 'false'],
                        help='Enable shuffle for train dataset. (Default: true)')
    parser.add_argument('--enable_data_sink', type=str, default='true', choices=['true', 'false'],
                        help='Enable data sink. (Default: true)')
    parser.add_argument('--save_ckpt_step', type=int, default=50,
                        help='If do_eval is false, the checkpoint will be saved every save_ckpt_step. (Default: 50)')
    parser.add_argument('--eval_ckpt_step', type=int, default=50,
                        help='If do_eval is true, the evaluation will be run every eval_ckpt_step. (Default: 50)')
    parser.add_argument('--max_ckpt_num', type=int, default=10,
                        help='The number of checkpoints will not be larger than max_ckpt_num. (Default: 10)')
    parser.add_argument('--data_sink_steps', type=int, default=1, help='Sink steps for each epoch. (Default: 1)')
    parser.add_argument('--teacher_model_dir', type=str, default='', help='The checkpoint directory of teacher model.')
    parser.add_argument('--student_model_dir', type=str, default='', help='The checkpoint directory of student model.')
    parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
    parser.add_argument('--output_dir', type=str, default='./', help='The output checkpoint directory.')
    parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
                        help='The name of the task to train. (Default: sts-b)')
    # Help text previously duplicated the --task_name description.
    parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
                        help='Format of the dataset files. (Default: tfrecord)')
    parser.add_argument('--seed', type=int, default=1, help='The random seed')
    parser.add_argument('--train_batch_size', type=int, default=16, help='Batch size for training')
    parser.add_argument('--eval_batch_size', type=int, default=32, help='Eval Batch size in callback')
    return parser.parse_args()
|
||||
|
||||
|
||||
def run_task_distill(args_opt):
    """
    Run task distillation: build the datasets, the teacher/student network with loss,
    the optimizer and the callbacks, then train.

    The shared config objects (teacher_net_cfg, student_net_cfg, train_cfg, eval_cfg)
    are updated in place with the task's sequence length and the requested batch sizes.

    Args:
        args_opt (argparse.Namespace): Options as produced by parse_args().
    """
    task = task_cfg[args_opt.task_name]
    teacher_net_cfg.seq_length = task.seq_length
    student_net_cfg.seq_length = task.seq_length
    train_cfg.batch_size = args_opt.train_batch_size
    eval_cfg.batch_size = args_opt.eval_batch_size
    teacher_ckpt = os.path.join(args_opt.teacher_model_dir, args_opt.task_name, WEIGHTS_NAME)
    student_ckpt = os.path.join(args_opt.student_model_dir, args_opt.task_name, WEIGHTS_NAME)
    train_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, TRAIN_DATA_NAME)
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, EVAL_DATA_NAME)
    save_ckpt_dir = os.path.join(args_opt.output_dir, args_opt.task_name)

    # Read device_id from the function argument; the module-level ``args`` only exists
    # when this file is executed as a script, so using it breaks imported callers.
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)

    # Single-device run.
    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset(batch_size=train_cfg.batch_size,
                                            device_num=device_num,
                                            rank=rank,
                                            do_shuffle=args_opt.do_shuffle,
                                            data_dir=train_data_dir,
                                            data_type=args_opt.dataset_type,
                                            seq_length=task.seq_length,
                                            task_type=task.task_type,
                                            drop_remainder=True)
    dataset_size = train_dataset.get_dataset_size()
    print('train dataset size:', dataset_size)
    eval_dataset = create_tinybert_dataset(batch_size=eval_cfg.batch_size,
                                           device_num=device_num,
                                           rank=rank,
                                           do_shuffle=args_opt.do_shuffle,
                                           data_dir=eval_data_dir,
                                           data_type=args_opt.dataset_type,
                                           seq_length=task.seq_length,
                                           task_type=task.task_type,
                                           drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())

    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # With data sinking each "epoch" passed to Model.train covers data_sink_steps steps.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size

    netwithloss = BertNetworkWithLoss(teacher_config=teacher_net_cfg, teacher_ckpt=teacher_ckpt,
                                      student_config=student_net_cfg, student_ckpt=student_ckpt,
                                      is_training=True, task_type=task.task_type, num_labels=task.num_labels)
    params = netwithloss.trainable_params()
    optimizer_cfg = train_cfg.optimizer_cfg
    total_steps = dataset_size * args_opt.epoch_size
    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(total_steps * optimizer_cfg.AdamWeightDecay.warmup_ratio),
                                   decay_steps=int(total_steps),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    # Weight decay is applied only to the parameters selected by decay_filter.
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)

    if args_opt.do_eval == 'true':
        # Materialize the evaluation batches once so every evaluation reuses them.
        eval_dataset = list(eval_dataset.create_dict_iterator())
        callback = [EvalCallBack(network=netwithloss.bert,
                                 dataset=eval_dataset,
                                 eval_ckpt_step=args_opt.eval_ckpt_step,
                                 save_ckpt_dir=save_ckpt_dir,
                                 embedding_bits=student_net_cfg.embedding_bits,
                                 weight_bits=student_net_cfg.weight_bits,
                                 clip_value=student_net_cfg.weight_clip_value,
                                 metrics=task.metrics)]
    else:
        callback = [StepCallBack(),
                    ModelSaveCkpt(network=netwithloss.bert,
                                  save_ckpt_step=args_opt.save_ckpt_step,
                                  max_ckpt_num=args_opt.max_ckpt_num,
                                  output_dir=save_ckpt_dir,
                                  embedding_bits=student_net_cfg.embedding_bits,
                                  weight_bits=student_net_cfg.weight_bits,
                                  clip_value=student_net_cfg.weight_clip_value)]
    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: parse the command-line options, set the random seed
    # from --seed, then run task distillation with those options.
    args = parse_args()
    set_seed(args.seed)
    run_task_distill(args)
|
Loading…
Reference in new issue