|
|
|
@ -122,7 +122,7 @@ bool SelectedRows::HasKey(int64_t key) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SelectedRows::Get(int64_t key, framework::Tensor* value,
|
|
|
|
|
int64_t row) const {
|
|
|
|
|
int64_t offset) const {
|
|
|
|
|
int64_t index = Index(key);
|
|
|
|
|
PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table.");
|
|
|
|
|
PADDLE_ENFORCE(value->IsInitialized(),
|
|
|
|
@ -138,7 +138,7 @@ bool SelectedRows::Get(int64_t key, framework::Tensor* value,
|
|
|
|
|
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorCopyVisitor(cpu, value, row * value_width, *value_.get(),
|
|
|
|
|
TensorCopyVisitor(cpu, value, offset * value_width, *value_.get(),
|
|
|
|
|
index * value_width, value_width));
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|