no getmutable nccl_com

yu239-patch-1
Yang Yang 7 years ago
parent 0e2deaa5fd
commit 0c45eab7ff

@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // platform::Communicator
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
@ -54,15 +53,15 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::NCCL_COM) {
var->GetMutable<platform::Communicator>();
} else if (var_type == proto::VarDesc::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarDesc::NCCL_COM) {
// GetMutable will be called in ncclInit
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER]",
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
var_type);
}
}

@ -212,5 +212,5 @@ class ParallelOpTestMultipleInput(BaseParallelForTest):
fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD'])
#if __name__ == '__main__':
# unittest.main()
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save