!13952 fix NoRepeatNGram example

From: @yanzhenxiang2020
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @wuxuejian
pull/13952/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 33c9fda50a

@ -155,19 +155,28 @@ class NoRepeatNGram(PrimitiveWithInfer):
Examples:
>>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
>>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
... [9, 3, 9, 5, 4, 1, 5]],
... [[4, 8, 6, 4, 5, 6, 4],
... [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
>>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
... [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
... [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
... [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
>>> output = no_repeat_ngram(state_seq, log_probs)
>>> print(output)
[[[ 6.9999999e-01 -3.4028235e+38 6.0000002e-01 8.9999998e-01
2.0000000e-01 -3.4028235e+38 4.0000001e-01 6.0000002e-01
2.0000000e-01 6.9999999e-01]
[ 4.0000001e-01 5.0000000e-01 6.0000002e-01 6.9999999e-01
8.0000001e-01 1.0000000e-01 8.9999998e-01 8.0000001e-01
6.9999999e-01 1.0000000e-01]]
[[ 8.9999998e-01 6.9999999e-01 6.0000002e-01 3.0000001e-01
5.0000000e-01 -3.4028235e+38 5.0000000e-01 4.0000001e-01
8.0000001e-01 6.0000002e-01]
[ 5.0000000e-01 8.0000001e-01 8.0000001e-01 6.9999999e-01
6.9999999e-01 8.0000001e-01 2.0000000e-01 6.9999999e-01
-3.4028235e+38 6.9999999e-01]]]
"""
@prim_attr_register
@ -179,11 +188,11 @@ class NoRepeatNGram(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs'])
def infer_shape(self, seq_shape, log_shape):
    """Validate the input shapes and infer the output shape.

    Args:
        seq_shape: shape of `state_seq`; must have rank 3.
        log_shape: shape of `log_probs`; must have rank 3 and match
            `seq_shape` in the first two dimensions.

    Returns:
        `log_shape` unchanged — the output has the same shape as `log_probs`.
    """
    # NOTE(review): the diff scrape contained both the pre- and post-commit
    # validator lines back to back; only the post-commit checks (which produce
    # the clearer "state_seq"/"log_probs" error messages) are kept here.
    validator.check_int(len(seq_shape), 3, Rel.EQ, "rank of state_seq", self.name)
    validator.check_int(len(log_shape), 3, Rel.EQ, "rank of log_probs", self.name)
    validator.check("state_seq shape[0]", seq_shape[0], "log_probs shape[0]", log_shape[0], Rel.EQ, self.name)
    validator.check("state_seq shape[1]", seq_shape[1], "log_probs shape[1]", log_shape[1], Rel.EQ, self.name)
    # An n-gram can be at most one element longer than the stored sequence.
    validator.check("ngram_size", self.ngram_size, "state_seq shape[2] + 1", seq_shape[2] + 1, Rel.LE, self.name)
    return log_shape
def infer_dtype(self, seq_type, log_type):

Loading…
Cancel
Save