|
|
|
@ -31,35 +31,14 @@ namespace detail {
|
|
|
|
|
|
|
|
|
|
using VarMsg = sendrecv::VariableMessage;
|
|
|
|
|
|
|
|
|
|
VarMsg::Type DataTypeToEnum(std::type_index type) {
|
|
|
|
|
if (typeid(platform::float16).hash_code() == type.hash_code()) {
|
|
|
|
|
return VarMsg::FP16;
|
|
|
|
|
} else if (typeid(const float).hash_code() == type.hash_code()) {
|
|
|
|
|
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
|
|
|
|
|
// One fix to this is to replace float with const float because
|
|
|
|
|
// typeid(T) == typeid(const T)
|
|
|
|
|
// http://en.cppreference.com/w/cpp/language/typeid
|
|
|
|
|
return VarMsg::FP32;
|
|
|
|
|
} else if (typeid(const double).hash_code() == type.hash_code()) {
|
|
|
|
|
return VarMsg::FP64;
|
|
|
|
|
} else if (typeid(const int).hash_code() == type.hash_code()) {
|
|
|
|
|
return VarMsg::INT32;
|
|
|
|
|
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
|
|
|
|
|
return VarMsg::INT64;
|
|
|
|
|
} else if (typeid(const bool).hash_code() == type.hash_code()) {
|
|
|
|
|
return VarMsg::BOOL;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Not supported");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetTensorPayload(framework::Variable* var,
|
|
|
|
|
const platform::DeviceContext& ctx, VarMsg* request,
|
|
|
|
|
void** payload, size_t* payload_size) {
|
|
|
|
|
auto tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
// FIXME(wuyi): data types in send_recv.proto is not synced with
|
|
|
|
|
// FIXME(wuyi): data types in send_recv.proto is copied from
|
|
|
|
|
// framework.proto
|
|
|
|
|
request->set_data_type(DataTypeToEnum(tensor.type()));
|
|
|
|
|
request->set_data_type(
|
|
|
|
|
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
|
|
|
|
|
for (auto& dim : framework::vectorize(tensor.dims())) {
|
|
|
|
|
request->add_dims(dim);
|
|
|
|
|
}
|
|
|
|
@ -96,7 +75,8 @@ void GetSelectedRowsPayload(framework::Variable* var,
|
|
|
|
|
const platform::DeviceContext& ctx, VarMsg* request,
|
|
|
|
|
void** payload, size_t* payload_size) {
|
|
|
|
|
auto* slr = var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
request->set_data_type(DataTypeToEnum(slr->value().type()));
|
|
|
|
|
request->set_data_type(
|
|
|
|
|
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
|
|
|
|
|
request->set_lod_level(0);
|
|
|
|
|
request->set_slr_height(slr->height());
|
|
|
|
|
|
|
|
|
@ -170,7 +150,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
|
|
|
|
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
|
|
|
|
|
e.WriteRawBytes(std::string(header.data(), header.size()));
|
|
|
|
|
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
|
|
|
|
|
|
|
|
|
|
// steal reference of tensor data
|
|
|
|
|
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
|
|
|
|
|
int num_slices = 2; // only SelectedRows have rows buffer
|
|
|
|
|