Make code more efficient

fea/anakin-support-x86
minqiyang 7 years ago
parent 9812bb8b48
commit c4d000a990

@ -57,23 +57,18 @@ static DDim GetDims(const Scope& scope, const std::string& name,
return DDim({-1});
}
if (var->IsInitialized()) {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
if (tensor.IsInitialized()) {
return tensor.dims();
} else {
return DDim({-1});
}
} else if (var->IsType<SelectedRows>()) {
if (get_actual_dim) {
return var->Get<SelectedRows>().value().dims();
} else {
return var->Get<SelectedRows>().GetCompleteDims();
}
} else {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
if (UNLIKELY(tensor.IsInitialized())) {
return DDim({-1});
}
return tensor.dims();
} else if (var->IsType<SelectedRows>()) {
if (get_actual_dim) {
return var->Get<SelectedRows>().value().dims();
} else {
return var->Get<SelectedRows>().GetCompleteDims();
}
} else {
return DDim({-1});
}
@ -85,20 +80,15 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
return "";
}
if (var->IsInitialized()) {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
if (tensor.IsInitialized()) {
return DataTypeToString(ToDataType(tensor.type()));
} else {
return "";
}
} else if (var->IsType<SelectedRows>()) {
return DataTypeToString(
ToDataType(var->Get<SelectedRows>().value().type()));
} else {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
if (UNLIKELY(!tensor.IsInitialized())) {
return "";
}
return DataTypeToString(ToDataType(tensor.type()));
} else if (var->IsType<SelectedRows>()) {
return DataTypeToString(
ToDataType(var->Get<SelectedRows>().value().type()));
} else {
return "";
}
@ -110,10 +100,8 @@ static int GetRowSize(const Scope& scope, const std::string& name) {
return -1;
}
if (var->IsInitialized()) {
if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().rows().size();
}
if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().rows().size();
}
return -1;
@ -127,17 +115,12 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
return default_lod;
}
if (var->IsInitialized()) {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
if (tensor.IsInitialized()) {
return tensor.lod();
} else {
return default_lod;
}
} else {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
if (UNLIKELY(!tensor.IsInitialized())) {
return default_lod;
}
return tensor.lod();
} else {
return default_lod;
}

@ -82,7 +82,7 @@ class Tensor {
template <typename T>
const T* data() const;
bool IsInitialized() const;
inline bool IsInitialized() const;
/**
* @brief Return a pointer to mutable memory block.

Loading…
Cancel
Save