enable PE return numpy (#11704)

port
chengduo 7 years ago committed by GitHub
parent 991cedb4c3
commit a64844ad00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -78,6 +78,8 @@ def as_numpy(tensor):
Returns: Returns:
numpy.ndarray numpy.ndarray
""" """
if isinstance(tensor, core.LoDTensorArray):
return [as_numpy(t) for t in tensor]
if isinstance(tensor, list): if isinstance(tensor, list):
return [as_numpy(t) for t in tensor] return [as_numpy(t) for t in tensor]
assert isinstance(tensor, core.LoDTensor) assert isinstance(tensor, core.LoDTensor)

@ -160,7 +160,7 @@ class ParallelExecutor(object):
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None): def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=False):
""" """
Run a parallel executor with fetch_list. Run a parallel executor with fetch_list.
@ -196,6 +196,8 @@ class ParallelExecutor(object):
to each device. Default None. to each device. Default None.
feed_dict: Alias for feed parameter, for backward compatibility. feed_dict: Alias for feed parameter, for backward compatibility.
This parameter has been deprecated. Default None. This parameter has been deprecated. Default None.
return_numpy(bool): Whether converts the fetched tensor to numpy.
Default: False.
Returns: Returns:
List: The fetched result list. List: The fetched result list.
@ -270,6 +272,9 @@ class ParallelExecutor(object):
if self.is_dist: if self.is_dist:
self.bcast_params() self.bcast_params()
if return_numpy:
return executor.as_numpy(arr)
return [arr[i] for i in range(len(arr))] return [arr[i] for i in range(len(arr))]
def bcast_params(self): def bcast_params(self):

@ -75,7 +75,9 @@ class TestFetchOp(unittest.TestCase):
fetch_list.append(k) fetch_list.append(k)
for data in train_inputs: for data in train_inputs:
ret = pe.run(fetch_list, feed=feeder.feed(data)) ret = pe.run(fetch_list,
feed=feeder.feed(data),
return_numpy=True)
for i in range(len(fetch_list)): for i in range(len(fetch_list)):
assert not math.isnan(np.sum(ret[i])) and \ assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i])) not math.isinf(np.sum(ret[i]))

Loading…
Cancel
Save