|
|
@ -43,13 +43,16 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
|
|
|
void* buf = buffer.get();
|
|
|
|
void* buf = buffer.get();
|
|
|
|
|
|
|
|
|
|
|
|
void* payload = nullptr;
|
|
|
|
void* payload = nullptr;
|
|
|
|
size_t payload_size;
|
|
|
|
size_t payload_size = 0;
|
|
|
|
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
|
|
|
|
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
|
|
|
|
e.WriteString(VarMsg::kVarnameFieldNumber, name);
|
|
|
|
e.WriteString(VarMsg::kVarnameFieldNumber, name);
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
|
|
|
|
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
|
|
|
|
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
|
|
|
|
|
|
|
|
} else if (var->IsType<ncclUniqueId>()) {
|
|
|
|
|
|
|
|
// NOTE: sendrecv only support RAW type for NCCL_ID
|
|
|
|
|
|
|
|
e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (!out_name.empty()) {
|
|
|
|
if (!out_name.empty()) {
|
|
|
@ -139,11 +142,27 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
|
|
|
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
|
|
|
|
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
|
|
|
|
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
|
|
|
|
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
|
|
|
|
} break;
|
|
|
|
} break;
|
|
|
|
|
|
|
|
case framework::proto::VarType_Type_RAW: {
|
|
|
|
|
|
|
|
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
|
|
|
|
|
|
|
|
NCCL_UNIQUE_ID_BYTES);
|
|
|
|
|
|
|
|
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
|
|
|
|
|
|
|
|
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
|
|
|
|
|
|
|
|
} break;
|
|
|
|
default:
|
|
|
|
default:
|
|
|
|
PADDLE_THROW("Serialize does not support type: %s",
|
|
|
|
PADDLE_THROW("Serialize does not support type: %s",
|
|
|
|
typeid(var->Type()).name());
|
|
|
|
typeid(var->Type()).name());
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) {
|
|
|
|
|
|
|
|
// for serialize NCCL_ID
|
|
|
|
|
|
|
|
::grpc::Slice slices(e.size());
|
|
|
|
|
|
|
|
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
|
|
|
|
|
|
|
|
::grpc::ByteBuffer tmp(&slices, 1);
|
|
|
|
|
|
|
|
msg->Swap(&tmp);
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// steal reference of tensor data
|
|
|
|
// steal reference of tensor data
|
|
|
|
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
|
|
|
|
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
|
|
|
|
int num_slices = 2; // only SelectedRows have rows buffer
|
|
|
|
int num_slices = 2; // only SelectedRows have rows buffer
|
|
|
|