Fix embedding layer

Fix input check
pull/9834/head
huangxinjing 4 years ago
parent 10d326ee62
commit 996ee72c50

@ -189,6 +189,7 @@ class EmbeddingLookup(Cell):
target='CPU', slice_mode='batch_slice', manual_shapes=None,
max_norm=None, sparse=True, vocab_cache_size=0):
super(EmbeddingLookup, self).__init__()
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
self.target = target
if target not in ('CPU', 'DEVICE'):
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
@ -200,9 +201,9 @@ class EmbeddingLookup(Cell):
else:
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.cache_enable = self.vocab_cache_size > 0
@ -355,7 +356,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
slice_mode, feature_num_list, max_norm, sparse)
self.field_size = validator.check_value_type('field_size', field_size, [int], self.cls_name)
self.field_size = validator.check_positive_int(field_size, 'field_size')
self.operator = operator
self.mul = P.Mul()
@ -429,7 +430,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
batch_size = self.shape(input_indices)[0]
num_segments = batch_size * self.field_size
bias = Range(0, num_segments, self.field_size)()
bias = self.reshape(bias, (self.field_size, -1))
bias = self.reshape(bias, (batch_size, -1))
field_ids = self.bias_add(field_ids, bias)
if self.target == "CPU":

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@ from mindspore import Tensor, context
from mindspore.nn import TrainOneStepCell, Adam
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
@ -48,10 +47,11 @@ class NetWithLoss(nn.Cell):
class Net(nn.Cell):
def __init__(self, shape, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device", operator='SUM'):
def __init__(self, shape, field_size=10, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device",
operator='SUM'):
super().__init__()
self.embedding = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=64, target=target,
field_size=shape[1], slice_mode=slice_mode, operator=operator)
field_size=field_size, slice_mode=slice_mode, operator=operator)
self.reshape = P.Reshape().shard(((8, 1, 1),))
self.batch_size = shape[0]
@ -77,28 +77,28 @@ def compile_net(net, shape):
def test_embeddinglookup_batch_parallel_sum():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE'))
net = NetWithLoss(Net(shape, field_size=10, target='DEVICE'))
compile_net(net, shape)
def test_embeddinglookup_row_parallel_sum():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
net = NetWithLoss(Net(shape, field_size=9, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
compile_net(net, shape)
def test_embeddinglookup_column_parallel_sum():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
net = NetWithLoss(Net(shape, field_size=10, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
compile_net(net, shape)
def test_embeddinglookup_batch_parallel_mean():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', operator='MEAN'))
net = NetWithLoss(Net(shape, field_size=1, target='DEVICE', operator='MEAN'))
compile_net(net, shape)

Loading…
Cancel
Save