|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
#include <nccl.h>
|
|
|
|
|
#endif
|
|
|
|
|
#include <sys/time.h>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
@ -31,7 +32,12 @@ namespace distributed {
|
|
|
|
|
|
|
|
|
|
class IOBufWriter {
|
|
|
|
|
public:
|
|
|
|
|
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
|
|
|
|
|
static void Append(const std::string& varname, butil::IOBuf* iobuf, int k,
|
|
|
|
|
const char* v, int64_t vlen) {
|
|
|
|
|
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
|
|
|
|
|
LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
iobuf->append(reinterpret_cast<char*>(&k), 4);
|
|
|
|
|
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
|
|
|
|
|
iobuf->append(v, vlen);
|
|
|
|
@ -87,6 +93,10 @@ class IOBufWriter {
|
|
|
|
|
int k, const char* v, int64_t vlen,
|
|
|
|
|
bool in_cuda_pinned, void (*destroy)(void*),
|
|
|
|
|
void* user_data) {
|
|
|
|
|
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
|
|
|
|
|
LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_BRPC_RDMA
|
|
|
|
|
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
|
|
|
|
|
destroy, user_data);
|
|
|
|
@ -134,7 +144,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
|
|
|
|
|
request->set_type(::sendrecv::NCCL_ID);
|
|
|
|
|
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
|
|
|
|
|
// TODO(gongwb): use append_zero to avoid data copy.
|
|
|
|
|
IOBufWriter::Append(iobuf,
|
|
|
|
|
IOBufWriter::Append(name, iobuf,
|
|
|
|
|
sendrecv::VariableMessage::kSerializedFieldNumber,
|
|
|
|
|
uid.internal, NCCL_UNIQUE_ID_BYTES);
|
|
|
|
|
return;
|
|
|
|
@ -149,7 +159,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
|
|
|
|
|
// FIXME(gongwb): it seems that can use zero copy.
|
|
|
|
|
if (var_is_not_stable) {
|
|
|
|
|
IOBufWriter::Append(
|
|
|
|
|
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
|
|
|
|
|
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
|
|
|
|
|
static_cast<const char*>(payload->ptr()), payload->memory_size());
|
|
|
|
|
} else {
|
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
@ -171,10 +181,11 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
|
|
|
|
|
|
|
|
|
|
if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto* slr = var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
size_t rows_memory_size =
|
|
|
|
|
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
|
|
|
|
|
PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
|
|
|
|
|
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
|
|
|
|
|
|
|
|
|
|
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
|
|
|
|
|
IOBufWriter::Append(name, iobuf,
|
|
|
|
|
::sendrecv::VariableMessage::kRowsFieldNumber,
|
|
|
|
|
reinterpret_cast<const char*>(slr->rows().data()),
|
|
|
|
|
static_cast<int64_t>(rows_memory_size));
|
|
|
|
|
}
|
|
|
|
|