|
|
|
@ -63,6 +63,26 @@ struct TensorCopyVisitor {
|
|
|
|
|
int64_t size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct TensorFillVisitor {
|
|
|
|
|
TensorFillVisitor(framework::Tensor* dst, int64_t dst_offset, int64_t size,
|
|
|
|
|
float value)
|
|
|
|
|
: dst_(dst), dst_offset_(dst_offset), size_(size) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void apply() const {
|
|
|
|
|
// TODO(Yancey1989): support other place
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
|
auto* tensor_data = dst_->mutable_data<T>(cpu);
|
|
|
|
|
auto* start = tensor_data + dst_offset_;
|
|
|
|
|
auto* end = start + size_;
|
|
|
|
|
std::fill(start, end, static_cast<T>(0.0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Tensor* dst_;
|
|
|
|
|
int64_t dst_offset_;
|
|
|
|
|
int64_t size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) {
|
|
|
|
|
{ // the 1st field, uint32_t version
|
|
|
|
@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
|
|
|
|
|
: true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
|
|
|
|
|
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
|
|
|
|
|
bool is_test) {
|
|
|
|
|
if (is_test) {
|
|
|
|
|
auto iter = id_to_index_.find(key);
|
|
|
|
|
if (iter == id_to_index_.end()) {
|
|
|
|
|
return -1;
|
|
|
|
|
} else {
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rwlock_->RDLock();
|
|
|
|
|
auto iter = id_to_index_.find(key);
|
|
|
|
|
if (iter == id_to_index_.end()) {
|
|
|
|
@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
|
|
|
|
|
bool auto_grown) {
|
|
|
|
|
bool auto_grown, bool is_test) {
|
|
|
|
|
PADDLE_ENFORCE(value->IsInitialized(),
|
|
|
|
|
"The value tensor should be initialized.");
|
|
|
|
|
if (ids.numel() == 0) {
|
|
|
|
@ -183,11 +213,18 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
|
|
|
|
|
"output tensor should have the same shape with table "
|
|
|
|
|
"except the dims[0].");
|
|
|
|
|
for (int i = 0; i < ids.numel(); ++i) {
|
|
|
|
|
int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorCopyVisitor(value, i * value_width, *value_.get(),
|
|
|
|
|
index * value_width, value_width));
|
|
|
|
|
int64_t index =
|
|
|
|
|
AutoGrownIndex(ids.data<int64_t>()[i], auto_grown, is_test);
|
|
|
|
|
if (index < 0) {
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorFillVisitor(value, i * value_width, value_width, 0.0));
|
|
|
|
|
} else {
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorCopyVisitor(value, i * value_width, *value_.get(),
|
|
|
|
|
index * value_width, value_width));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|