!2676 Mass text summarization update.

Merge pull request !2676 from linqingke/mass
pull/2676/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit c9ba849969

@ -18,7 +18,7 @@ export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab -- "$@"` options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
eval set -- "$options" eval set -- "$options"
echo $options echo $options
@ -129,6 +129,7 @@ do
esac esac
done done
file_path=$(cd "$(dirname $0)" || exit; pwd)
for((i=0; i < $RANK_SIZE; i++)) for((i=0; i < $RANK_SIZE; i++))
do do
if [ $RANK_SIZE -gt 1 ] if [ $RANK_SIZE -gt 1 ]
@ -139,7 +140,6 @@ do
fi fi
echo "Working on device $i" echo "Working on device $i"
file_path=$(cd "$(dirname $0)" || exit; pwd)
cd $file_path || exit cd $file_path || exit
cd ../ || exit cd ../ || exit

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Dataset loader to feed into model.""" """Dataset loader to feed into model."""
import os
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC import mindspore.dataset.transforms.c_transforms as deC
@ -40,12 +39,6 @@ def _load_dataset(input_files, batch_size, epoch_count=1,
if not input_files: if not input_files:
raise FileNotFoundError("Require at least one dataset.") raise FileNotFoundError("Require at least one dataset.")
if not (schema_file and
os.path.exists(schema_file)
and os.path.isfile(schema_file)
and os.path.basename(schema_file).endswith(".json")):
raise FileNotFoundError("`dataset_schema` must be a existed json file.")
if not isinstance(sink_mode, bool): if not isinstance(sink_mode, bool):
raise ValueError("`sink` must be type of bool.") raise ValueError("`sink` must be type of bool.")

@ -47,7 +47,7 @@ def rouge(hypothesis: List[str], target: List[str]):
edited_ref.append(r + "\n") edited_ref.append(r + "\n")
_rouge = Rouge() _rouge = Rouge()
scores = _rouge.get_scores(edited_hyp, target, avg=True) scores = _rouge.get_scores(edited_hyp, edited_ref, avg=True)
print(" | ROUGE Score:") print(" | ROUGE Score:")
print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}") print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}")
print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}") print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}")

@ -120,6 +120,7 @@ def _build_training_pipeline(config: TransformerConfig,
test_dataset (Dataset): Test dataset. test_dataset (Dataset): Test dataset.
""" """
net_with_loss = TransformerNetworkWithLoss(config, is_training=True) net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
net_with_loss.init_parameters_data()
if config.existed_ckpt: if config.existed_ckpt:
if config.existed_ckpt.endswith(".npz"): if config.existed_ckpt.endswith(".npz"):

Loading…
Cancel
Save