!10197 Add input check for the embedding

From: @huangxinjing
Reviewed-by: @stsuteng,@zhunaipan
Signed-off-by: @stsuteng
pull/10197/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 441e283e21

@ -33,6 +33,16 @@ from ..cell import Cell
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup'] __all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
@constexpr
def _check_input_2d(input_shape, param_name, func_name):
if len(input_shape) != 2:
raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
return True
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
class Embedding(Cell): class Embedding(Cell):
r""" r"""
@ -449,6 +459,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
def construct(self, input_indices, input_values, field_ids): def construct(self, input_indices, input_values, field_ids):
_check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
_check_input_2d(F.shape(input_values), "input_values", self.cls_name)
_check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
_check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
_check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
_check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)
batch_size = self.shape(input_indices)[0] batch_size = self.shape(input_indices)[0]
num_segments = batch_size * self.field_size num_segments = batch_size * self.field_size
bias = Range(0, num_segments, self.field_size)() bias = Range(0, num_segments, self.field_size)()

@ -14,11 +14,12 @@
# ============================================================================ # ============================================================================
""" test nn embedding """ """ test nn embedding """
import numpy as np import numpy as np
import pytest
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype from mindspore.common import dtype
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import Embedding from mindspore.nn import Embedding, MultiFieldEmbeddingLookup
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
@ -43,6 +44,55 @@ def test_check_embedding_3():
_executor.compile(net, input_data) _executor.compile(net, input_data)
def compile_multi_field_embedding(shape_id, shape_value, shape_field,
type_id, type_value, type_field):
net = MultiFieldEmbeddingLookup(20000, 768, 3)
input_data = Tensor(np.ones(shape_id), type_id)
input_value = Tensor(np.ones(shape_value), type_value)
input_field = Tensor(np.ones(shape_field), type_field)
_executor.compile(net, input_data, input_value, input_field)
@non_graph_engine
def test_check_multifield_embedding_right_type():
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int64, dtype.float32, dtype.int32)
@non_graph_engine
def test_check_multifield_embedding_false_type_input():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int32)
@non_graph_engine
def test_check_multifield_embedding_false_type_value():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int16, dtype.float16, dtype.int32)
@non_graph_engine
def test_check_multifield_embedding_false_type_field_id():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine
def test_check_multifield_embedding_false_input_shape():
with pytest.raises(TypeError):
compile_multi_field_embedding((8,), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine
def test_check_multifield_embedding_false_value_shape():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8,), (8, 200),
dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine @non_graph_engine
def test_print_embedding(): def test_print_embedding():
net = Embedding(20000, 768, False) net = Embedding(20000, 768, False)

Loading…
Cancel
Save