|
|
@ -12,6 +12,7 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
from mindspore import log as logger
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
|
|
|
@ -382,6 +383,35 @@ def test_weighted_random_sampler():
|
|
|
|
logger.info("Number of data in data1: {}".format(num_iter))
|
|
|
|
logger.info("Number of data in data1: {}".format(num_iter))
|
|
|
|
assert num_iter == 11
|
|
|
|
assert num_iter == 11
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_weighted_random_sampler_exception():
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Test error cases for WeightedRandomSampler
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
logger.info("Test error cases for WeightedRandomSampler")
|
|
|
|
|
|
|
|
error_msg_1 = "type of weights element should be number"
|
|
|
|
|
|
|
|
with pytest.raises(TypeError, match=error_msg_1):
|
|
|
|
|
|
|
|
weights = ""
|
|
|
|
|
|
|
|
ds.WeightedRandomSampler(weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_msg_2 = "type of weights element should be number"
|
|
|
|
|
|
|
|
with pytest.raises(TypeError, match=error_msg_2):
|
|
|
|
|
|
|
|
weights = (0.9, 0.8, 1.1)
|
|
|
|
|
|
|
|
ds.WeightedRandomSampler(weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_msg_3 = "weights size should not be 0"
|
|
|
|
|
|
|
|
with pytest.raises(ValueError, match=error_msg_3):
|
|
|
|
|
|
|
|
weights = []
|
|
|
|
|
|
|
|
ds.WeightedRandomSampler(weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_msg_4 = "weights should not contain negative numbers"
|
|
|
|
|
|
|
|
with pytest.raises(ValueError, match=error_msg_4):
|
|
|
|
|
|
|
|
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
|
|
|
|
|
|
|
|
ds.WeightedRandomSampler(weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_msg_5 = "elements of weights should not be all zero"
|
|
|
|
|
|
|
|
with pytest.raises(ValueError, match=error_msg_5):
|
|
|
|
|
|
|
|
weights = [0, 0, 0, 0, 0]
|
|
|
|
|
|
|
|
ds.WeightedRandomSampler(weights)
|
|
|
|
|
|
|
|
|
|
|
|
def test_imagefolder_rename():
|
|
|
|
def test_imagefolder_rename():
|
|
|
|
logger.info("Test Case rename")
|
|
|
|
logger.info("Test Case rename")
|
|
|
@ -465,6 +495,9 @@ if __name__ == '__main__':
|
|
|
|
test_weighted_random_sampler()
|
|
|
|
test_weighted_random_sampler()
|
|
|
|
logger.info('test_weighted_random_sampler Ended.\n')
|
|
|
|
logger.info('test_weighted_random_sampler Ended.\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_weighted_random_sampler_exception()
|
|
|
|
|
|
|
|
logger.info('test_weighted_random_sampler_exception Ended.\n')
|
|
|
|
|
|
|
|
|
|
|
|
test_imagefolder_numshards()
|
|
|
|
test_imagefolder_numshards()
|
|
|
|
logger.info('test_imagefolder_numshards Ended.\n')
|
|
|
|
logger.info('test_imagefolder_numshards Ended.\n')
|
|
|
|
|
|
|
|
|
|
|
|