add ternarybert to model zoo

pull/11157/head
w00517672 4 years ago
parent 21addb331d
commit 65c7eb2461

File diff suppressed because it is too large.

@@ -0,0 +1,107 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval standalone script"""
import os
import re
import argparse
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_tinybert_dataset
from src.config import eval_cfg, student_net_cfg, task_cfg
from src.tinybert_model import BertModelCLS
DATA_NAME = 'eval.tf_record'
def parse_args():
"""
parse args
"""
parser = argparse.ArgumentParser(description='ternarybert evaluation')
parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
help='Device where the code will be implemented. (Default: GPU)')
parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
parser.add_argument('--model_dir', type=str, default='', help='The checkpoint directory of model.')
parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
help='The name of the task to train. (Default: sts-b)')
parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
help='The type of the dataset. (Default: tfrecord)')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for evaluating')
return parser.parse_args()
def get_ckpt(ckpt_dir):
"""Return the most recently modified checkpoint file in ckpt_dir."""
files = os.listdir(ckpt_dir)
files.sort(key=lambda fn: os.path.getmtime(os.path.join(ckpt_dir, fn)))
return os.path.join(ckpt_dir, files[-1])
def do_eval_standalone(args_opt):
"""
do eval standalone
"""
ckpt_file = os.path.join(args_opt.model_dir, args_opt.task_name)
ckpt_file = get_ckpt(ckpt_file)
print('ckpt file:', ckpt_file)
task = task_cfg[args_opt.task_name]
student_net_cfg.seq_length = task.seq_length
eval_cfg.batch_size = args_opt.batch_size
eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, DATA_NAME)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
eval_dataset = create_tinybert_dataset(batch_size=eval_cfg.batch_size,
device_num=1,
rank=0,
do_shuffle='false',
data_dir=eval_data_dir,
data_type=args_opt.dataset_type,
seq_length=task.seq_length,
task_type=task.task_type,
drop_remainder=False)
print('eval dataset size:', eval_dataset.get_dataset_size())
print('eval dataset batch size:', eval_dataset.get_batch_size())
eval_model = BertModelCLS(student_net_cfg, False, task.num_labels, 0.0, phase_type='student')
param_dict = load_checkpoint(ckpt_file)
new_param_dict = {}
for key, value in param_dict.items():
new_key = re.sub('tinybert_', 'bert_', key)
new_key = re.sub('^bert.', '', new_key)
new_param_dict[new_key] = value
load_param_into_net(eval_model, new_param_dict)
eval_model.set_train(False)
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
callback = task.metrics()
for step, data in enumerate(eval_dataset.create_dict_iterator()):
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, token_type_id, label_ids = input_data
_, _, logits, _ = eval_model(input_ids, token_type_id, input_mask)
callback.update(logits, label_ids)
print('eval step: {}, {}: {}'.format(step, callback.name, callback.get_metrics()))
metrics = callback.get_metrics()
print('The final {}: {}'.format(callback.name, metrics))
if __name__ == '__main__':
args = parse_args()
do_eval_standalone(args)

@@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert hub interface for bert base"""
from src.tinybert_model import BertModel
from src.tinybert_model import BertConfig
import mindspore.common.dtype as mstype
tinybert_student_net_cfg = BertConfig(
seq_length=128,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=6,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
dtype=mstype.float32,
compute_type=mstype.float32,
do_quant=True,
embedding_bits=2,
weight_bits=2,
weight_clip_value=3.0,
cls_dropout_prob=0.1,
activation_init=2.5,
is_lgt_fit=False
)
def create_network(name, *args, **kwargs):
"""
Create tinybert network.
"""
if name == "ternarybert":
if "seq_length" in kwargs:
tinybert_student_net_cfg.seq_length = kwargs["seq_length"]
is_training = kwargs.get("is_training", False)
return BertModel(tinybert_student_net_cfg, is_training, *args)
raise NotImplementedError(f"{name} is not implemented in the repo")
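
For reference, a minimal usage sketch of this hub entry point (the import path below is an assumption; this diff does not show the file's name):

from mindspore_hub_conf import create_network  # assumed module name

# Build the ternary student model with the default config, overriding seq_length.
net = create_network('ternarybert', is_training=False, seq_length=128)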

@@ -0,0 +1,26 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
CUR_DIR=$(pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../eval.py \
--task_name=sts-b \
--device_id=0 \
--model_dir="" \
--data_dir="" > log.txt

@@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
CUR_DIR=$(pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../train.py \
--task_name=sts-b \
--device_id=0 \
--teacher_model_dir="" \
--student_model_dir="" \
--data_dir="" > log.txt

@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""assessment methods"""
import numpy as np
class Accuracy:
"""Accuracy"""
def __init__(self):
self.acc_num = 0
self.total_num = 0
self.name = 'Accuracy'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logit_id = np.argmax(logits, axis=-1)
self.acc_num += np.sum(labels == logit_id)
self.total_num += len(labels)
def get_metrics(self):
return self.acc_num / self.total_num * 100.0
class F1:
"""F1"""
def __init__(self):
self.logits_array = np.array([])
self.labels_array = np.array([])
self.name = 'F1'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.argmax(logits, axis=1)
self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
self.logits_array = np.concatenate([self.logits_array, logits]).astype(bool)
def get_metrics(self):
if len(self.labels_array) < 2:
return 0.0
tp = np.sum(self.labels_array & self.logits_array)
fp = np.sum((~self.labels_array) & self.logits_array)
fn = np.sum(self.labels_array & (~self.logits_array))
p = tp / (tp + fp)
r = tp / (tp + fn)
return 2.0 * p * r / (p + r) * 100.0
class Pearsonr:
"""Pearsonr"""
def __init__(self):
self.logits_array = np.array([])
self.labels_array = np.array([])
self.name = 'Pearsonr'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.reshape(logits, -1)
self.labels_array = np.concatenate([self.labels_array, labels])
self.logits_array = np.concatenate([self.logits_array, logits])
def get_metrics(self):
if len(self.labels_array) < 2:
return 0.0
x_mean = self.logits_array.mean()
y_mean = self.labels_array.mean()
xm = self.logits_array - x_mean
ym = self.labels_array - y_mean
norm_xm = np.linalg.norm(xm)
norm_ym = np.linalg.norm(ym)
return np.dot(xm / norm_xm, ym / norm_ym) * 100.0
class Matthews:
"""Matthews"""
def __init__(self):
self.logits_array = np.array([])
self.labels_array = np.array([])
self.name = 'Matthews'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.argmax(logits, axis=1)
self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
self.logits_array = np.concatenate([self.logits_array, logits]).astype(bool)
def get_metrics(self):
if len(self.labels_array) < 2:
return 0.0
tp = np.sum(self.labels_array & self.logits_array)
fp = np.sum((~self.labels_array) & self.logits_array)
fn = np.sum(self.labels_array & (~self.logits_array))
tn = np.sum((~self.labels_array) & (~self.logits_array))
return (tp * tn - fp * fn) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) * 100.0
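
As a quick sanity check, the Pearsonr class above computes a normalized centered dot product, which is exactly the Pearson correlation coefficient; a standalone verification against numpy.corrcoef (illustrative only, up to the *100 scaling):

import numpy as np

x = np.array([0.1, 0.4, 0.35, 0.8])  # stand-in for accumulated logits
y = np.array([0.0, 0.5, 0.3, 1.0])   # stand-in for accumulated labels
xm, ym = x - x.mean(), y - y.mean()
manual = np.dot(xm / np.linalg.norm(xm), ym / np.linalg.norm(ym)) * 100.0
assert np.isclose(manual, np.corrcoef(x, y)[0, 1] * 100.0)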

File diff suppressed because it is too large.

@@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config script"""
import mindspore.common.dtype as mstype
from easydict import EasyDict as edict
from .tinybert_model import BertConfig
from .assessment_method import Accuracy, F1, Pearsonr, Matthews
gradient_cfg = edict({
'clip_type': 1,
'clip_value': 1.0
})
task_cfg = edict({
"sst-2": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Accuracy}),
"qnli": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
"mnli": edict({"num_labels": 3, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
"cola": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Matthews}),
"mrpc": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
"sts-b": edict({"num_labels": 1, "seq_length": 128, "task_type": "regression", "metrics": Pearsonr}),
"qqp": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
"rte": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy})
})
train_cfg = edict({
'batch_size': 16,
'loss_scale_value': 2 ** 16,
'scale_factor': 2,
'scale_window': 50,
'optimizer_cfg': edict({
'AdamWeightDecay': edict({
'learning_rate': 5e-5,
'end_learning_rate': 1e-14,
'power': 1.0,
'weight_decay': 1e-4,
'eps': 1e-6,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'warmup_ratio': 0.1
}),
}),
})
eval_cfg = edict({
'batch_size': 32,
})
teacher_net_cfg = BertConfig(
seq_length=128,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=6,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
dtype=mstype.float32,
compute_type=mstype.float32,
do_quant=False
)
student_net_cfg = BertConfig(
seq_length=128,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=6,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
dtype=mstype.float32,
compute_type=mstype.float32,
do_quant=True,
embedding_bits=2,
weight_bits=2,
weight_clip_value=3.0,
cls_dropout_prob=0.1,
activation_init=2.5,
is_lgt_fit=False
)
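
The task table above is what train.py and eval.py use to reconfigure the networks per GLUE task; a typical lookup:

from src.config import task_cfg, student_net_cfg

task = task_cfg['sts-b']
student_net_cfg.seq_length = task.seq_length  # 128 for STS-B
metric = task.metrics()  # a Pearsonr instance, since STS-B is a regression task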

@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create tinybert dataset"""
from enum import Enum
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
class DataType(Enum):
"""Enumerate supported dataset format"""
TFRECORD = 1
MINDRECORD = 2
def create_tinybert_dataset(batch_size=32, device_num=1, rank=0, do_shuffle="true", data_dir=None,
data_type='tfrecord', seq_length=128, task_type='classification', drop_remainder=True):
"""create tinybert dataset"""
if isinstance(data_dir, list):
data_files = data_dir
else:
data_files = [data_dir]
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
shuffle = (do_shuffle == "true")
if data_type == 'mindrecord':
ds = de.MindDataset(data_files, columns_list=columns_list, shuffle=shuffle, num_shards=device_num,
shard_id=rank)
else:
ds = de.TFRecordDataset(data_files, columns_list=columns_list, shuffle=shuffle, num_shards=device_num,
shard_id=rank, shard_equal_rows=(device_num == 1))
if device_num == 1 and shuffle is True:
ds = ds.shuffle(10000)
type_cast_op = C.TypeCast(mstype.int32)
slice_op = C.Slice(slice(0, seq_length, 1))
label_type = mstype.int32 if task_type == 'classification' else mstype.float32
ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["segment_ids"])
ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["input_mask"])
ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["input_ids"])
ds = ds.map(operations=[C.TypeCast(label_type), slice_op], input_columns=["label_ids"])
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds
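
A minimal single-device call, mirroring how eval.py builds its evaluation dataset (the path is a placeholder):

from src.dataset import create_tinybert_dataset

ds = create_tinybert_dataset(batch_size=32, device_num=1, rank=0, do_shuffle='false',
                             data_dir='/path/to/sts-b/eval.tf_record', data_type='tfrecord',
                             seq_length=128, task_type='regression', drop_remainder=False)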

@@ -0,0 +1,171 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization function."""
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import nn
class QuantizeWeightCell(nn.Cell):
"""
The ternary fake quant op for weight.
Args:
num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 8.
compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeightCell. Default: mstype.float32.
clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
Inputs:
- **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
"""
def __init__(self, num_bits=8, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
super(QuantizeWeightCell, self).__init__()
self.num_bits = num_bits
self.compute_type = compute_type
self.clip_value = clip_value
self.per_channel = per_channel
self.clamp = C.clip_by_value
self.abs = P.Abs()
self.sum = P.ReduceSum()
self.nelement = F.size
self.div = P.Div()
self.cast = P.Cast()
self.max = P.ReduceMax()
self.min = P.ReduceMin()
self.round = P.Round()
self.reshape = P.Reshape()
def construct(self, weight):
"""quantize weight cell"""
tensor = self.clamp(weight, -self.clip_value, self.clip_value)
if self.num_bits == 2:
if self.per_channel:
n = self.nelement(tensor[0])
m = self.div(self.sum(self.abs(tensor), 1), n)
thres = 0.7 * m
pos = self.cast(tensor[:] > thres[0], self.compute_type)
neg = self.cast(tensor[:] < -thres[0], self.compute_type)
mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
output = alpha * pos - alpha * neg
else:
n = self.nelement(tensor)
m = self.div(self.sum(self.abs(tensor)), n)
thres = 0.7 * m
pos = self.cast(tensor > thres, self.compute_type)
neg = self.cast(tensor < -thres, self.compute_type)
mask = self.cast(self.abs(tensor) > thres, self.compute_type)
alpha = self.sum(self.abs(mask * self.cast(tensor, self.compute_type))) / self.sum(mask)
output = alpha * pos - alpha * neg
else:
tensor_max = self.cast(self.max(tensor), self.compute_type)
tensor_min = self.cast(self.min(tensor), self.compute_type)
s = (tensor_max - tensor_min) / (2 ** self.cast(self.num_bits, self.compute_type) - 1)
output = self.round(self.div(tensor - tensor_min, s)) * s + tensor_min
return output
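
The num_bits == 2 branch above is TWN-style ternarization: clip the weights, threshold at 0.7 times the mean absolute value, and scale the surviving weights by their mean magnitude. A plain-NumPy sketch of the per-layer case (illustrative, not part of this diff; assumes at least one weight exceeds the threshold):

import numpy as np

def ternarize(w, clip_value=1.0):
    """Per-layer ternary quantization: alpha * sign(w) where |w| > thres, else 0."""
    w = np.clip(w, -clip_value, clip_value)
    thres = 0.7 * np.abs(w).mean()  # delta = 0.7 * E|w|
    mask = np.abs(w) > thres
    alpha = np.abs(w[mask]).mean()  # scale factor over the surviving weights
    return alpha * ((w > thres).astype(w.dtype) - (w < -thres).astype(w.dtype))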
class QuantizeWeight:
"""
Quantize weight into specified bit.
Args:
num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 2.
compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeightCell. Default: mstype.float32.
clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
Inputs:
- **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
"""
def __init__(self, num_bits=2, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
self.num_bits = num_bits
self.compute_type = compute_type
self.clip_value = clip_value
self.per_channel = per_channel
self.clamp = C.clip_by_value
self.abs = P.Abs()
self.sum = P.ReduceSum()
self.nelement = F.size
self.div = P.Div()
self.cast = P.Cast()
self.max = P.ReduceMax()
self.min = P.ReduceMin()
self.floor = P.Floor()
self.reshape = P.Reshape()
def construct(self, weight):
"""quantize weight"""
tensor = self.clamp(weight, -self.clip_value, self.clip_value)
if self.num_bits == 2:
if self.per_channel:
n = self.nelement(tensor[0])
m = self.div(self.sum(self.abs(tensor), 1), n)
thres = 0.7 * m
pos = self.cast(tensor[:] > thres[0], self.compute_type)
neg = self.cast(tensor[:] < -thres[0], self.compute_type)
mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
output = alpha * pos - alpha * neg
else:
n = self.nelement(tensor)
m = self.div(self.sum(self.abs(tensor)), n)
thres = 0.7 * m
pos = self.cast(tensor > thres, self.compute_type)
neg = self.cast(tensor < -thres, self.compute_type)
mask = self.cast(self.abs(tensor) > thres, self.compute_type)
alpha = self.sum(self.abs(mask * tensor)) / self.sum(mask)
output = alpha * pos - alpha * neg
else:
tensor_max = self.max(tensor)
tensor_min = self.min(tensor)
s = (tensor_max - tensor_min) / (2 ** self.num_bits - 1)
output = self.floor(self.div((tensor - tensor_min), s) + 0.5) * s + tensor_min
return output
def convert_network(network, embedding_bits=2, weight_bits=2, clip_value=1.0):
"""Quantize the embedding and weight parameters of network in place."""
quantize_embedding = QuantizeWeight(num_bits=embedding_bits, clip_value=clip_value)
quantize_weight = QuantizeWeight(num_bits=weight_bits, clip_value=clip_value)
for name, param in network.parameters_and_names():
if 'bert_embedding_lookup' in name and 'min' not in name and 'max' not in name:
quantized_param = quantize_embedding.construct(param)
param.set_data(quantized_param)
elif 'weight' in name and 'dense_1' not in name:
quantized_param = quantize_weight.construct(param)
param.set_data(quantized_param)
def save_params(network):
"""Return a full-precision copy of all parameters, keyed by name."""
return {name: Parameter(param, 'saved_params') for name, param in network.parameters_and_names()}
def restore_params(network, params_dict):
"""Write the parameters saved by save_params back into network."""
for name, param in network.parameters_and_names():
param.set_data(params_dict[name])
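
Together, these helpers implement the cycle the training callbacks depend on: snapshot the full-precision weights, quantize the network in place for evaluation or checkpointing, then restore the snapshot so training continues in full precision. In outline:

fp_params = save_params(network)  # snapshot full-precision weights
convert_network(network, embedding_bits=2, weight_bits=2, clip_value=3.0)
# ... evaluate or save_checkpoint on the ternarized network here ...
restore_params(network, fp_params)  # resume full-precision training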

File diff suppressed because it is too large.

@@ -0,0 +1,187 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ternarybert utils"""
import os
import time
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from .quant import convert_network, save_params, restore_params
class ModelSaveCkpt(Callback):
"""
Saves checkpoints during training, keeping at most max_ckpt_num of them.
Args:
network (Network): The train network.
save_ckpt_step (int): Save a checkpoint every save_ckpt_step steps.
max_ckpt_num (int): The maximum number of checkpoints to keep.
output_dir (str): Directory in which checkpoints are saved.
"""
def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir, embedding_bits=2, weight_bits=2,
clip_value=1.0):
super(ModelSaveCkpt, self).__init__()
self.count = 0
self.network = network
self.save_ckpt_step = save_ckpt_step
self.max_ckpt_num = max_ckpt_num
self.output_dir = output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
self.embedding_bits = embedding_bits
self.weight_bits = weight_bits
self.clip_value = clip_value
def step_end(self, run_context):
"""step end and save ckpt"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % self.save_ckpt_step == 0:
saved_ckpt_num = cb_params.cur_step_num // self.save_ckpt_step
if saved_ckpt_num > self.max_ckpt_num:
oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
path = os.path.join(self.output_dir, "ternary_bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
self.save_ckpt_step))
if os.path.exists(path):
os.remove(path)
params_dict = save_params(self.network)
convert_network(self.network, self.embedding_bits, self.weight_bits, self.clip_value)
save_checkpoint(self.network, os.path.join(self.output_dir,
"ternary_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
self.save_ckpt_step)))
restore_params(self.network, params_dict)
class LossCallBack(Callback):
"""
Monitor the loss in training.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("per_print_times must be an int and >= 0")
self._per_print_times = per_print_times
def step_end(self, run_context):
"""step end and print loss"""
cb_params = run_context.original_args()
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs)))
class StepCallBack(Callback):
"""
Monitor the time cost of each step in training.
"""
def __init__(self):
super(StepCallBack, self).__init__()
self.start_time = 0.0
def step_begin(self, run_context):
self.start_time = time.time()
def step_end(self, run_context):
time_cost = time.time() - self.start_time
cb_params = run_context.original_args()
print("step: {}, second_per_step: {}".format(cb_params.cur_step_num, time_cost))
class EvalCallBack(Callback):
"""Evaluation callback"""
def __init__(self, network, dataset, eval_ckpt_step, save_ckpt_dir, embedding_bits=2, weight_bits=2,
clip_value=1.0, metrics=None):
super(EvalCallBack, self).__init__()
self.network = network
self.global_metrics = 0.0
self.dataset = dataset
self.eval_ckpt_step = eval_ckpt_step
self.save_ckpt_dir = save_ckpt_dir
self.embedding_bits = embedding_bits
self.weight_bits = weight_bits
self.clip_value = clip_value
self.metrics = metrics
if not os.path.exists(save_ckpt_dir):
os.makedirs(save_ckpt_dir)
def step_end(self, run_context):
"""step end and do evaluation"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % self.eval_ckpt_step == 0:
params_dict = save_params(self.network)
convert_network(self.network, self.embedding_bits, self.weight_bits, self.clip_value)
self.network.set_train(False)
callback = self.metrics()
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
for data in self.dataset:
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, token_type_id, label_ids = input_data
_, _, logits, _ = self.network(input_ids, token_type_id, input_mask)
callback.update(logits, label_ids)
metrics = callback.get_metrics()
if metrics > self.global_metrics:
self.global_metrics = metrics
eval_model_ckpt_file = os.path.join(self.save_ckpt_dir, 'eval_model.ckpt')
if os.path.exists(eval_model_ckpt_file):
os.remove(eval_model_ckpt_file)
save_checkpoint(self.network, eval_model_ckpt_file)
print('step {}, {} {}, best_{} {}'.format(cb_params.cur_step_num,
callback.name,
metrics,
callback.name,
self.global_metrics))
restore_params(self.network, params_dict)
self.network.set_train(True)
class BertLearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for Bert network.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
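
Numerically, the schedule ramps linearly up to the peak rate over the warmup steps, then decays polynomially (power=1.0 gives linear decay). A plain-Python sketch with illustrative step counts (e.g. warmup_ratio 0.1 of 1000 total steps):

def lr_at(step, base_lr=5e-5, end_lr=1e-14, warmup_steps=100, decay_steps=1000, power=1.0):
    """Approximate the WarmUpLR / PolynomialDecayLR combination in BertLearningRate."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    step = min(step, decay_steps)
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr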

@@ -0,0 +1,165 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""task distill script"""
import os
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.nn.optim import AdamWeightDecay
from mindspore import set_seed
from src.dataset import create_tinybert_dataset
from src.utils import StepCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.config import train_cfg, eval_cfg, teacher_net_cfg, student_net_cfg, task_cfg
from src.cell_wrapper import BertNetworkWithLoss, BertTrainCell
WEIGHTS_NAME = 'eval_model.ckpt'
EVAL_DATA_NAME = 'eval.tf_record'
TRAIN_DATA_NAME = 'train.tf_record'
def parse_args():
"""
parse args
"""
parser = argparse.ArgumentParser(description='ternarybert task distill')
parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
help='Device where the code will be implemented. (Default: GPU)')
parser.add_argument('--do_eval', type=str, default='true', choices=['true', 'false'],
help='Do eval task during training or not. (Default: true)')
parser.add_argument('--epoch_size', type=int, default=3, help='Epoch size for train phase. (Default: 3)')
parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
parser.add_argument('--do_shuffle', type=str, default='true', choices=['true', 'false'],
help='Enable shuffle for train dataset. (Default: true)')
parser.add_argument('--enable_data_sink', type=str, default='true', choices=['true', 'false'],
help='Enable data sink. (Default: true)')
parser.add_argument('--save_ckpt_step', type=int, default=50,
help='If do_eval is false, the checkpoint will be saved every save_ckpt_step. (Default: 50)')
parser.add_argument('--eval_ckpt_step', type=int, default=50,
help='If do_eval is true, the evaluation will be run every eval_ckpt_step. (Default: 50)')
parser.add_argument('--max_ckpt_num', type=int, default=10,
help='The number of checkpoints will not be larger than max_ckpt_num. (Default: 10)')
parser.add_argument('--data_sink_steps', type=int, default=1, help='Sink steps for each epoch. (Default: 1)')
parser.add_argument('--teacher_model_dir', type=str, default='', help='The checkpoint directory of teacher model.')
parser.add_argument('--student_model_dir', type=str, default='', help='The checkpoint directory of student model.')
parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
parser.add_argument('--output_dir', type=str, default='./', help='The output checkpoint directory.')
parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
help='The name of the task to train. (Default: sts-b)')
parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
help='The type of the dataset. (Default: tfrecord)')
parser.add_argument('--seed', type=int, default=1, help='The random seed')
parser.add_argument('--train_batch_size', type=int, default=16, help='Batch size for training')
parser.add_argument('--eval_batch_size', type=int, default=32, help='Eval Batch size in callback')
return parser.parse_args()
def run_task_distill(args_opt):
"""
run task distill
"""
task = task_cfg[args_opt.task_name]
teacher_net_cfg.seq_length = task.seq_length
student_net_cfg.seq_length = task.seq_length
train_cfg.batch_size = args_opt.train_batch_size
eval_cfg.batch_size = args_opt.eval_batch_size
teacher_ckpt = os.path.join(args_opt.teacher_model_dir, args_opt.task_name, WEIGHTS_NAME)
student_ckpt = os.path.join(args_opt.student_model_dir, args_opt.task_name, WEIGHTS_NAME)
train_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, TRAIN_DATA_NAME)
eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, EVAL_DATA_NAME)
save_ckpt_dir = os.path.join(args_opt.output_dir, args_opt.task_name)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
rank = 0
device_num = 1
train_dataset = create_tinybert_dataset(batch_size=train_cfg.batch_size,
device_num=device_num,
rank=rank,
do_shuffle=args_opt.do_shuffle,
data_dir=train_data_dir,
data_type=args_opt.dataset_type,
seq_length=task.seq_length,
task_type=task.task_type,
drop_remainder=True)
dataset_size = train_dataset.get_dataset_size()
print('train dataset size:', dataset_size)
eval_dataset = create_tinybert_dataset(batch_size=eval_cfg.batch_size,
device_num=device_num,
rank=rank,
do_shuffle=args_opt.do_shuffle,
data_dir=eval_data_dir,
data_type=args_opt.dataset_type,
seq_length=task.seq_length,
task_type=task.task_type,
drop_remainder=False)
print('eval dataset size:', eval_dataset.get_dataset_size())
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
else:
repeat_count = args_opt.epoch_size
netwithloss = BertNetworkWithLoss(teacher_config=teacher_net_cfg, teacher_ckpt=teacher_ckpt,
student_config=student_net_cfg, student_ckpt=student_ckpt,
is_training=True, task_type=task.task_type, num_labels=task.num_labels)
params = netwithloss.trainable_params()
optimizer_cfg = train_cfg.optimizer_cfg
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(dataset_size * args_opt.epoch_size *
optimizer_cfg.AdamWeightDecay.warmup_ratio),
decay_steps=int(dataset_size * args_opt.epoch_size),
power=optimizer_cfg.AdamWeightDecay.power)
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
if args_opt.do_eval == 'true':
eval_dataset = list(eval_dataset.create_dict_iterator())
callback = [EvalCallBack(network=netwithloss.bert,
dataset=eval_dataset,
eval_ckpt_step=args_opt.eval_ckpt_step,
save_ckpt_dir=save_ckpt_dir,
embedding_bits=student_net_cfg.embedding_bits,
weight_bits=student_net_cfg.weight_bits,
clip_value=student_net_cfg.weight_clip_value,
metrics=task.metrics)]
else:
callback = [StepCallBack(),
ModelSaveCkpt(network=netwithloss.bert,
save_ckpt_step=args_opt.save_ckpt_step,
max_ckpt_num=args_opt.max_ckpt_num,
output_dir=save_ckpt_dir,
embedding_bits=student_net_cfg.embedding_bits,
weight_bits=student_net_cfg.weight_bits,
clip_value=student_net_cfg.weight_clip_value)]
model = Model(netwithgrads)
model.train(repeat_count, train_dataset, callbacks=callback,
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
sink_size=args_opt.data_sink_steps)
if __name__ == '__main__':
args = parse_args()
set_seed(args.seed)
run_task_distill(args)