|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
@ -44,16 +44,13 @@ class NetOneHot(nn.Cell):
|
|
|
|
|
self.one_hot_3(indices3), self.one_hot_4(indices4))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_x86_gpu_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_one_hot():
|
|
|
|
|
one_hot = NetOneHot()
|
|
|
|
|
indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32))
|
|
|
|
|
indices2 = Tensor(np.array([1, 2, 3]).astype(np.int32))
|
|
|
|
|
indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
|
|
|
|
|
indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32))
|
|
|
|
|
output = one_hot(indices1, indices2, indices3, indices4)
|
|
|
|
|
def one_hot(nptype):
|
|
|
|
|
one_hot_net = NetOneHot()
|
|
|
|
|
indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(nptype))
|
|
|
|
|
indices2 = Tensor(np.array([1, 2, 3]).astype(nptype))
|
|
|
|
|
indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(nptype))
|
|
|
|
|
indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(nptype))
|
|
|
|
|
output = one_hot_net(indices1, indices2, indices3, indices4)
|
|
|
|
|
expect_0 = np.array([
|
|
|
|
|
[[2., 3., 3., 3., 3., 3.], [3., 2., 3., 3., 3., 3.]],
|
|
|
|
|
[[3., 3., 3., 3., 2., 3.], [3., 3., 3., 3., 3., 2.]],
|
|
|
|
@ -80,3 +77,15 @@ def test_one_hot():
|
|
|
|
|
assert (output[1].asnumpy() == expect_1).all()
|
|
|
|
|
assert (output[2].asnumpy() == expect_2).all()
|
|
|
|
|
assert (output[3].asnumpy() == expect_3).all()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_x86_gpu_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_one_hot_int32():
|
|
|
|
|
one_hot(np.int32)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_x86_gpu_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_one_hot_int64():
|
|
|
|
|
one_hot(np.int64)
|
|
|
|
|