|
|
|
@ -34,8 +34,7 @@ namespace distributed {
|
|
|
|
|
|
|
|
|
|
static void SerializeDestroyCallback(void* payload) {
|
|
|
|
|
if (payload != nullptr) {
|
|
|
|
|
auto* shared_payload =
|
|
|
|
|
reinterpret_cast<std::shared_ptr<memory::Allocation>*>(payload);
|
|
|
|
|
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
|
|
|
|
|
delete shared_payload;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -46,7 +45,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
|
|
|
|
const std::string& out_name) {
|
|
|
|
|
platform::RecordRPCEvent record_event("serial", &ctx);
|
|
|
|
|
VarMsg request;
|
|
|
|
|
std::shared_ptr<memory::Allocation>* payload = nullptr;
|
|
|
|
|
TensorPayload* payload = nullptr;
|
|
|
|
|
|
|
|
|
|
request.set_varname(name);
|
|
|
|
|
// Note: normally the profiler is enabled in 1 trainer, hence only
|
|
|
|
@ -65,12 +64,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
|
|
|
|
}
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
request.set_type(::sendrecv::LOD_TENSOR);
|
|
|
|
|
payload = new std::shared_ptr<memory::Allocation>(
|
|
|
|
|
GetTensorPayload(var, ctx, &request));
|
|
|
|
|
payload = new TensorPayload(GetTensorPayload(var, ctx, &request));
|
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
request.set_type(::sendrecv::SELECTED_ROWS);
|
|
|
|
|
payload = new std::shared_ptr<memory::Allocation>(
|
|
|
|
|
GetSelectedRowsPayload(var, ctx, &request));
|
|
|
|
|
payload = new TensorPayload(GetSelectedRowsPayload(var, ctx, &request));
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
} else if (var->IsType<ncclUniqueId>()) {
|
|
|
|
|
request.set_type(::sendrecv::NCCL_ID);
|
|
|
|
@ -106,16 +103,16 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(payload);
|
|
|
|
|
|
|
|
|
|
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
|
|
|
|
|
payload->get()->size());
|
|
|
|
|
payload->memory_size());
|
|
|
|
|
// steal reference of tensor data
|
|
|
|
|
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
|
|
|
|
|
int num_slices = 2; // only SelectedRows have rows buffer
|
|
|
|
|
slices[0] = ::grpc::Slice(e.size());
|
|
|
|
|
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
|
|
|
|
|
slices[1] = ::grpc::Slice(grpc_slice_new_with_user_data(
|
|
|
|
|
payload->get()->ptr(), payload->get()->size(),
|
|
|
|
|
SerializeDestroyCallback, payload),
|
|
|
|
|
::grpc::Slice::STEAL_REF);
|
|
|
|
|
slices[1] = ::grpc::Slice(
|
|
|
|
|
grpc_slice_new_with_user_data(payload->ptr(), payload->memory_size(),
|
|
|
|
|
SerializeDestroyCallback, payload),
|
|
|
|
|
::grpc::Slice::STEAL_REF);
|
|
|
|
|
|
|
|
|
|
if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto* slr = var->GetMutable<framework::SelectedRows>();
|
|
|
|
|