|
|
|
@ -23,12 +23,12 @@ from mindspore import log as logger
|
|
|
|
|
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def graphdata_startserver():
|
|
|
|
|
def graphdata_startserver(server_port):
|
|
|
|
|
"""
|
|
|
|
|
start graphdata server
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test start server.\n')
|
|
|
|
|
ds.GraphData(DATASET_FILE, 1, 'server')
|
|
|
|
|
ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RandomBatchedSampler(ds.Sampler):
|
|
|
|
@ -83,11 +83,13 @@ def test_graphdata_distributed():
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test distributed.\n')
|
|
|
|
|
|
|
|
|
|
p1 = Process(target=graphdata_startserver)
|
|
|
|
|
server_port = random.randint(10000, 60000)
|
|
|
|
|
|
|
|
|
|
p1 = Process(target=graphdata_startserver, args=(server_port,))
|
|
|
|
|
p1.start()
|
|
|
|
|
time.sleep(2)
|
|
|
|
|
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 1, 'client')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
|
|
|
|
|
nodes = g.get_all_nodes(1)
|
|
|
|
|
assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
|
|
|
|
|
row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
|
|
|
|
|