Fix bug for multi-GPU inference.

cblas_new
dangqingqing 8 years ago
parent 55115ac682
commit 7c13292cff

@ -35,6 +35,13 @@ class Inference(object):
name = param.getName()
assert isinstance(val, api.Vector)
val.copyFromNumpyArray(parameters.get(name).flatten())
# the setValueUpdated function is called in randomize, zeroMem,
# load function in paddle/parameter/Parameter.cpp. But in the
# inference mode, the setValueUpdated is never called, it will
# cause the parameter will not be dispatched
# in MultiGradientMachine for multi-GPU. So setValueUpdated is
# called here, but it's better to call this function in one place.
param.setValueUpdated()
self.__gradient_machine__ = gm
self.__data_types__ = topo.data_type()

Loading…
Cancel
Save