|
|
|
@ -29,6 +29,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/inference/analysis/device.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/dot.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/helper.h"
|
|
|
|
|
#include "paddle/fluid/platform/variant.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
@ -38,39 +39,35 @@ class NodeMap;
|
|
|
|
|
|
|
|
|
|
// A helper class to maintain the status from Pass.
|
|
|
|
|
struct NodeAttr {
|
|
|
|
|
using any_t =
|
|
|
|
|
boost::variant<bool, float, int32_t, int64_t, void *, std::string>;
|
|
|
|
|
// NOTE T should be a primary type or a struct combined by several primary
|
|
|
|
|
// types.
|
|
|
|
|
// NOTE the STL containers should not use here.
|
|
|
|
|
// Some usages
|
|
|
|
|
// Attr attr;
|
|
|
|
|
// attr.Bool() = true;
|
|
|
|
|
|
|
|
|
|
bool &Bool() { return As<bool>(); }
|
|
|
|
|
float &Float() { return As<float>(); }
|
|
|
|
|
int32_t &Int32() { return As<int32_t>(); }
|
|
|
|
|
int64_t &Int64() { return As<int64_t>(); }
|
|
|
|
|
void *&Pointer() { return As<void *>(); }
|
|
|
|
|
std::string &String();
|
|
|
|
|
std::string &String() { return As<std::string>(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
template <typename T>
|
|
|
|
|
T &As() {
|
|
|
|
|
// init storage in the first usage.
|
|
|
|
|
if (data_.empty()) {
|
|
|
|
|
VLOG(4) << "resize data to " << sizeof(T);
|
|
|
|
|
type_index_ = std::type_index(typeid(T));
|
|
|
|
|
data_.resize(sizeof(T));
|
|
|
|
|
if (type_index_ == typeid(NodeAttr)) {
|
|
|
|
|
type_index_ = typeid(T);
|
|
|
|
|
any_data_ = T();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(type_index_ == typeid(T), "fetch error type");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(framework::IsType<T>(type_index_),
|
|
|
|
|
"type not matched, origin is %s, want %s",
|
|
|
|
|
DataTypeNamer::Global().repr(type_index_),
|
|
|
|
|
DataTypeNamer::Global().repr<T>());
|
|
|
|
|
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
|
|
|
|
|
return *reinterpret_cast<T *>(&data_[0]);
|
|
|
|
|
return boost::get<T>(any_data_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::string data_;
|
|
|
|
|
any_t any_data_;
|
|
|
|
|
std::type_index type_index_{typeid(NodeAttr)};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|