Merge pull request from jacquesqiao/fix-ctr-reader-svm

fix ctr reader read svm data
revert-15661-fix-cpu-broadcast
乔龙飞 Qiao Longfei 6 years ago committed by GitHub
commit 7ddf4e2c55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -213,7 +213,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
framework::LoD lod{lod_data};
lod_tensor.set_lod(lod);
int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
framework::make_ddim({1, static_cast<int64_t>(batch_feasign.size())}),
framework::make_ddim({static_cast<int64_t>(batch_feasign.size()), 1}),
platform::CPUPlace());
memcpy(tensor_data, batch_feasign.data(),
batch_feasign.size() * sizeof(int64_t));
@ -223,7 +223,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
// insert label tensor
framework::LoDTensor label_tensor;
auto* label_tensor_data = label_tensor.mutable_data<int64_t>(
framework::make_ddim({1, static_cast<int64_t>(batch_label.size())}),
framework::make_ddim({static_cast<int64_t>(batch_label.size()), 1}),
platform::CPUPlace());
memcpy(label_tensor_data, batch_label.data(),
batch_label.size() * sizeof(int64_t));

@ -123,7 +123,7 @@ TEST(CTR_READER, read_data) {
std::vector<std::tuple<LoD, std::vector<int64_t>>> data_slot_6003{b1, b2, b3,
b4};
std::vector<DDim> label_dims = {{1, 3}, {1, 3}, {1, 3}, {1, 1}};
std::vector<DDim> label_dims = {{3, 1}, {3, 1}, {3, 1}, {1, 1}};
LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64;

Loading…
Cancel
Save