extract embedding table from unified interface

pull/10318/head
shibeiji 4 years ago
parent 0212e19bc9
commit 812b4b0eab

@@ -166,11 +166,10 @@ class EmbeddingPostprocessor(nn.Cell):
         self.token_type_vocab_size = token_type_vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
         self.max_position_embeddings = max_position_embeddings
-        self.embedding_table = Parameter(initializer
-                                         (TruncatedNormal(initializer_range),
-                                          [token_type_vocab_size,
-                                           embedding_size]))
+        self.token_type_embedding = nn.Embedding(
+            vocab_size=token_type_vocab_size,
+            embedding_size=embedding_size,
+            use_one_hot=use_one_hot_embeddings)
         self.shape_flat = (-1,)
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
@@ -178,35 +177,28 @@ class EmbeddingPostprocessor(nn.Cell):
         self.array_mul = P.MatMul()
         self.reshape = P.Reshape()
         self.shape = tuple(embedding_shape)
-        self.layernorm = nn.LayerNorm((embedding_size,))
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.gather = P.GatherV2()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
-        self.full_position_embeddings = Parameter(initializer
-                                                  (TruncatedNormal(initializer_range),
-                                                   [max_position_embeddings,
-                                                    embedding_size]))
+        _, seq, _ = self.shape
+        self.full_position_embedding = nn.Embedding(
+            vocab_size=max_position_embeddings,
+            embedding_size=embedding_size,
+            use_one_hot=False)
+        self.layernorm = nn.LayerNorm((embedding_size,))
+        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
+        self.add = P.TensorAdd()
 
     def construct(self, token_type_ids, word_embeddings):
         """Postprocessors apply positional and token type embeddings to word embeddings."""
         output = word_embeddings
         if self.use_token_type:
-            flat_ids = self.reshape(token_type_ids, self.shape_flat)
-            if self.use_one_hot_embeddings:
-                one_hot_ids = self.one_hot(flat_ids,
-                                           self.token_type_vocab_size, self.on_value, self.off_value)
-                token_type_embeddings = self.array_mul(one_hot_ids,
-                                                       self.embedding_table)
-            else:
-                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
-            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
-            output += token_type_embeddings
+            token_type_embeddings = self.token_type_embedding(token_type_ids)
+            output = self.add(output, token_type_embeddings)
         if not self.use_relative_positions:
-            _, seq, width = self.shape
-            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
-            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
-            output += position_embeddings
+            position_embeddings = self.full_position_embedding(self.position_ids)
+            output = self.add(output, position_embeddings)
         output = self.layernorm(output)
         output = self.dropout(output)
         return output
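
For readers following the refactor, here is a minimal, self-contained sketch (not taken from the commit; the sizes and variable names are illustrative, and it assumes a MindSpore 1.x-style `nn.Embedding` that accepts `use_one_hot` and holds its table internally) of how the new cells replace the removed one-hot/gather lookup and the StridedSlice over a position table:

```python
# Illustrative sketch only: how nn.Embedding subsumes the removed one-hot/gather
# and StridedSlice code paths. Sizes below are made up for the example.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

token_type_vocab_size, max_position_embeddings, embedding_size, seq = 2, 512, 8, 4

# Token-type lookup: ids go in, embeddings come out; the table is an internal Parameter.
token_type_embedding = nn.Embedding(vocab_size=token_type_vocab_size,
                                    embedding_size=embedding_size,
                                    use_one_hot=False)
token_type_ids = Tensor(np.zeros((1, seq), dtype=np.int32))
token_type_embeddings = token_type_embedding(token_type_ids)    # (1, seq, embedding_size)

# Position lookup: a fixed position_ids tensor replaces the old StridedSlice over the table.
full_position_embedding = nn.Embedding(vocab_size=max_position_embeddings,
                                       embedding_size=embedding_size,
                                       use_one_hot=False)
position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
position_embeddings = full_position_embedding(position_ids)     # (1, seq, embedding_size)

print(token_type_embeddings.shape, position_embeddings.shape)
```

Both lookups broadcast-add onto the word embeddings via `P.TensorAdd()`, exactly as the new `construct` above does.
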
@@ -771,6 +763,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
     def __init__(self, config):
         super(CreateAttentionMaskFromInputMask, self).__init__()
         self.input_mask = None
         self.cast = P.Cast()
         self.reshape = P.Reshape()
         self.shape = (-1, 1, config.seq_length)
@@ -808,12 +801,11 @@ class BertModel(nn.Cell):
         self.last_idx = self.num_hidden_layers - 1
         output_embedding_shape = [-1, self.seq_length, self.embedding_size]
 
-        self.bert_embedding_lookup = EmbeddingLookup(
+        self.bert_embedding_lookup = nn.Embedding(
             vocab_size=config.vocab_size,
             embedding_size=self.embedding_size,
-            embedding_shape=output_embedding_shape,
-            use_one_hot_embeddings=use_one_hot_embeddings,
-            initializer_range=config.initializer_range)
+            use_one_hot=use_one_hot_embeddings)
+        self.embedding_tables = self.bert_embedding_lookup.embedding_table
 
         self.bert_embedding_postprocessor = EmbeddingPostprocessor(
             embedding_size=self.embedding_size,
@@ -855,7 +847,8 @@ class BertModel(nn.Cell):
     def construct(self, input_ids, token_type_ids, input_mask):
         """Bidirectional Encoder Representations from Transformers."""
         # embedding
-        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
+        embedding_tables = self.embedding_tables
+        word_embeddings = self.bert_embedding_lookup(input_ids)
         embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                              word_embeddings)
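
To see what the `embedding_tables` attribute buys downstream code, here is a small hypothetical cell (the name `TinyModel` and its sizes are mine, not the repo's) that mirrors the new wiring: the lookup's forward pass returns only the embeddings, while the table is exposed once as an attribute, e.g. for tying a masked-LM head to it:

```python
# Hypothetical example mirroring the new BertModel wiring; not code from the repo.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class TinyModel(nn.Cell):
    def __init__(self, vocab_size=16, embedding_size=8):
        super(TinyModel, self).__init__()
        self.embedding_lookup = nn.Embedding(vocab_size=vocab_size,
                                             embedding_size=embedding_size,
                                             use_one_hot=False)
        # Expose the table once instead of returning it from every forward call.
        self.embedding_tables = self.embedding_lookup.embedding_table

    def construct(self, input_ids):
        return self.embedding_lookup(input_ids)

model = TinyModel()
input_ids = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
word_embeddings = model(input_ids)
# A decoder/MLM head can now read model.embedding_tables for weight tying.
print(word_embeddings.shape, model.embedding_tables.shape)
```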

@@ -38,11 +38,11 @@ cfg = edict({
     'Lamb': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,
-        'power': 10.0,
+        'power': 5.0,
         'warmup_steps': 10000,
         'weight_decay': 0.01,
         'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
+        'eps': 1e-8,
     }),
     'Momentum': edict({
         'learning_rate': 2e-5,
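
The `power` and `eps` changes only make sense next to the schedule and optimizer they feed. Below is a minimal sketch, my assumption of the typical wiring rather than the repo's actual training code (`total_steps` and the bare `trainable_params()` call are placeholders), of how these Lamb fields would map onto MindSpore's built-in `nn.PolynomialDecayLR` and `nn.Lamb`:

```python
# Minimal sketch (assumptions, not the repo's exact training code): wiring the
# Lamb config fields above into MindSpore's built-in schedule and optimizer.
import mindspore.nn as nn

cfg_lamb = {
    'learning_rate': 3e-5,
    'end_learning_rate': 0.0,
    'power': 5.0,          # exponent of the polynomial decay (changed from 10.0)
    'warmup_steps': 10000,
    'weight_decay': 0.01,
    'eps': 1e-8,           # Lamb epsilon (changed from 1e-6)
}
total_steps = 87500        # illustrative value only

decay_lr = nn.PolynomialDecayLR(learning_rate=cfg_lamb['learning_rate'],
                                end_learning_rate=cfg_lamb['end_learning_rate'],
                                decay_steps=total_steps,
                                power=cfg_lamb['power'])

def build_optimizer(net):
    # net is assumed to be an nn.Cell defined elsewhere; the config's decay_filter
    # would normally split its parameters into decayed / non-decayed groups.
    params = net.trainable_params()
    return nn.Lamb(params,
                   learning_rate=decay_lr,
                   eps=cfg_lamb['eps'],
                   weight_decay=cfg_lamb['weight_decay'])
```

Lowering `power` from 10.0 to 5.0 gives a less abrupt drop at the start of the polynomial decay, and `eps=1e-8` is the Adam-style default epsilon.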

@@ -124,16 +124,16 @@ Note: 1.the first run of training will generate the mindrecord file, which will
 ```shell
 # create dataset in mindrecord format
-bash scripts/convert_dataset_to_mindrecord.sh
+bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR]
 
 # standalone training on Ascend
-bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
+bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
 
 # standalone training on CPU
-bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
+bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
 
 # distributed training on Ascend
-bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] [RANK_TABLE_FILE]
+bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional)
 
 # eval on Ascend
 bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
@@ -354,7 +354,7 @@ Parameters for optimizer and learning rate:
 Before your first training, you need to convert the COCO-format dataset to mindrecord files to improve performance on the host.
 
 ```bash
-bash scripts/convert_dataset_to_mindrecord.sh
+bash scripts/convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir
 ```
 
 The command above will run in the background; after conversion, the mindrecord files will be located in the path you specified.
@@ -364,7 +364,7 @@ The command above will run in the background, after converting mindrecord files
 #### Running on Ascend
 
 ```bash
-bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt
+bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt(optional)
 ```
 
 The command above will run in the background; you can view training logs in training_log.txt. After training finishes, you will find checkpoint files under the script folder by default. The loss values will be displayed as follows:
@@ -380,7 +380,7 @@ epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shap
 #### Running on CPU
 
 ```bash
-bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt
+bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt(optional)
 ```
 
 The command above will run in the background; you can view training logs in training_log.txt. After training finishes, you will find checkpoint files under the script folder by default. The loss values will be displayed as follows (resumed from a pretrained checkpoint, with batch_size set to 8):
@@ -401,7 +401,7 @@ epoch: 0.0, current epoch percent: 0.00, step: 5, time of per steps: 45.213 s, o
 #### Running on Ascend
 
 ```bash
-bash scripts/run_distributed_pretrain_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json
+bash scripts/run_distributed_pretrain_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt(optional)
 ```
 
 The command above will run in the background; you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finishes, you will find checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:

@@ -16,13 +16,16 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "bash convert_dataset_to_mindrecord.sh"
+echo "bash convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir"
 echo "=============================================================================================================="
 
+COCO_DIR=$1
+MINDRECORD_DIR=$2
+
 export GLOG_v=1
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
 
 python ${PROJECT_DIR}/../src/dataset.py \
-    --coco_data_dir="" \
-    --mindrecord_dir="" \
+    --coco_data_dir=$COCO_DIR \
+    --mindrecord_dir=$MINDRECORD_DIR \
     --mindrecord_prefix="coco_hp.train.mind" > create_dataset.log 2>&1 &

@@ -16,17 +16,23 @@
 echo "================================================================================================================"
 echo "Please run the script as: "
-echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH RANK_TABLE_FILE"
-echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json"
-echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset \"\" /path/hccl.json"
+echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR RANK_TABLE_FILE LOAD_CHECKPOINT_PATH"
+echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt"
+echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json"
 echo "It is better to use the absolute path."
 echo "For hyper parameter, please note that you should customize the scripts:
     '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
 echo "================================================================================================================"
 CUR_DIR=`pwd`
 MINDRECORD_DIR=$1
-LOAD_CHECKPOINT_PATH=$2
-HCCL_RANK_FILE=$3
+HCCL_RANK_FILE=$2
+if [ $# == 3 ];
+then
+    LOAD_CHECKPOINT_PATH=$3
+else
+    LOAD_CHECKPOINT_PATH=""
+fi
 
 python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
     --run_script_dir=${CUR_DIR}/train.py \

@@ -18,12 +18,17 @@ echo "==========================================================================
 echo "Please run the script as: "
 echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
 echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt"
-echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset \"\" "
+echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset"
 echo "=============================================================================================================="
 DEVICE_ID=$1
 MINDRECORD_DIR=$2
-LOAD_CHECKPOINT_PATH=$3
+if [ $# == 3 ];
+then
+    LOAD_CHECKPOINT_PATH=$3
+else
+    LOAD_CHECKPOINT_PATH=""
+fi
 mkdir -p ms_log
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)

@@ -18,11 +18,17 @@ echo "==========================================================================
 echo "Please run the script as: "
 echo "bash run_standalone_train_cpu.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
 echo "for example: bash run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt"
-echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset \"\" "
+echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset"
 echo "=============================================================================================================="
 MINDRECORD_DIR=$1
-LOAD_CHECKPOINT_PATH=$2
+if [ $# == 2 ];
+then
+    LOAD_CHECKPOINT_PATH=$2
+    echo
+else
+    LOAD_CHECKPOINT_PATH=""
+fi
 mkdir -p ms_log
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
