|
|
|
@ -23,7 +23,6 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor_array.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/reader.h"
|
|
|
|
|
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // platform::Communicator
|
|
|
|
|
#include "paddle/fluid/platform/place.h"
|
|
|
|
|
#include "paddle/fluid/platform/profiler.h"
|
|
|
|
|
|
|
|
|
@ -54,15 +53,15 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
|
|
|
|
|
var->GetMutable<LoDTensorArray>();
|
|
|
|
|
} else if (var_type == proto::VarDesc::PLACE_LIST) {
|
|
|
|
|
var->GetMutable<platform::PlaceList>();
|
|
|
|
|
} else if (var_type == proto::VarDesc::NCCL_COM) {
|
|
|
|
|
var->GetMutable<platform::Communicator>();
|
|
|
|
|
} else if (var_type == proto::VarDesc::READER) {
|
|
|
|
|
var->GetMutable<ReaderHolder>();
|
|
|
|
|
} else if (var_type == proto::VarDesc::NCCL_COM) {
|
|
|
|
|
// GetMutable will be called in ncclInit
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Variable type %d is not in "
|
|
|
|
|
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
|
|
|
|
|
"LOD_RANK_TABLE, PLACE_LIST, READER]",
|
|
|
|
|
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
|
|
|
|
|
var_type);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|