Mass text summarization fix bug.

pull/2676/head
linqingke 5 years ago
parent d6d93f16b1
commit eb5c4cf2ea

@ -18,7 +18,7 @@ export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab -- "$@"` options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
eval set -- "$options" eval set -- "$options"
echo $options echo $options
@ -129,6 +129,7 @@ do
esac esac
done done
file_path=$(cd "$(dirname $0)" || exit; pwd)
for((i=0; i < $RANK_SIZE; i++)) for((i=0; i < $RANK_SIZE; i++))
do do
if [ $RANK_SIZE -gt 1 ] if [ $RANK_SIZE -gt 1 ]
@ -139,7 +140,6 @@ do
fi fi
echo "Working on device $i" echo "Working on device $i"
file_path=$(cd "$(dirname $0)" || exit; pwd)
cd $file_path || exit cd $file_path || exit
cd ../ || exit cd ../ || exit

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Dataset loader to feed into model.""" """Dataset loader to feed into model."""
import os
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC import mindspore.dataset.transforms.c_transforms as deC
@ -40,12 +39,6 @@ def _load_dataset(input_files, batch_size, epoch_count=1,
if not input_files: if not input_files:
raise FileNotFoundError("Require at least one dataset.") raise FileNotFoundError("Require at least one dataset.")
if not (schema_file and
os.path.exists(schema_file)
and os.path.isfile(schema_file)
and os.path.basename(schema_file).endswith(".json")):
raise FileNotFoundError("`dataset_schema` must be a existed json file.")
if not isinstance(sink_mode, bool): if not isinstance(sink_mode, bool):
raise ValueError("`sink` must be type of bool.") raise ValueError("`sink` must be type of bool.")

@ -47,7 +47,7 @@ def rouge(hypothesis: List[str], target: List[str]):
edited_ref.append(r + "\n") edited_ref.append(r + "\n")
_rouge = Rouge() _rouge = Rouge()
scores = _rouge.get_scores(edited_hyp, target, avg=True) scores = _rouge.get_scores(edited_hyp, edited_ref, avg=True)
print(" | ROUGE Score:") print(" | ROUGE Score:")
print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}") print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}")
print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}") print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}")

@ -120,6 +120,7 @@ def _build_training_pipeline(config: TransformerConfig,
test_dataset (Dataset): Test dataset. test_dataset (Dataset): Test dataset.
""" """
net_with_loss = TransformerNetworkWithLoss(config, is_training=True) net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
net_with_loss.init_parameters_data()
if config.existed_ckpt: if config.existed_ckpt:
if config.existed_ckpt.endswith(".npz"): if config.existed_ckpt.endswith(".npz"):

Loading…
Cancel
Save