|
|
|
@ -377,12 +377,10 @@ class PyLayer {
|
|
|
|
|
class PYBIND11_HIDDEN RuntimeInferVarTypeContext
|
|
|
|
|
: public framework::InferVarTypeContext {
|
|
|
|
|
public:
|
|
|
|
|
RuntimeInferVarTypeContext(imperative::OpBase* op,
|
|
|
|
|
const imperative::VarBasePtrMap* inputs,
|
|
|
|
|
RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs,
|
|
|
|
|
imperative::VarBasePtrMap* outputs,
|
|
|
|
|
const framework::AttributeMap* attrs_map)
|
|
|
|
|
: InferVarTypeContext(nullptr, nullptr),
|
|
|
|
|
op_(op),
|
|
|
|
|
inputs_(inputs),
|
|
|
|
|
outputs_(outputs),
|
|
|
|
|
attrs_(attrs_map),
|
|
|
|
@ -406,83 +404,86 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Attribute GetAttr(const std::string& name) const {
|
|
|
|
|
virtual ~RuntimeInferVarTypeContext() {}
|
|
|
|
|
|
|
|
|
|
framework::Attribute GetAttr(const std::string& name) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(attrs_);
|
|
|
|
|
return attrs_->at(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline bool HasVar(const std::string& name) const {
|
|
|
|
|
bool HasVar(const std::string& name) const override {
|
|
|
|
|
return var_set_.count(name) > 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline bool HasInput(const std::string& name) const {
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(inputs_);
|
|
|
|
|
return inputs_->count(name) > 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline bool HasOutput(const std::string& name) const {
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(outputs_);
|
|
|
|
|
return outputs_->count(name) > 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline const std::vector<std::string>& Input(const std::string& name) const {
|
|
|
|
|
const std::vector<std::string>& Input(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return input_names_.at(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline const std::vector<std::string>& Output(const std::string& name) const {
|
|
|
|
|
const std::vector<std::string>& Output(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return output_names_.at(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline framework::proto::VarType::Type GetType(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
framework::proto::VarType::Type GetType(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return var_set_.at(name)->DType();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void SetType(const std::string& name,
|
|
|
|
|
framework::proto::VarType::Type type) {
|
|
|
|
|
void SetType(const std::string& name,
|
|
|
|
|
framework::proto::VarType::Type type) override {
|
|
|
|
|
var_set_[name]->SetDType(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline framework::proto::VarType::Type GetDataType(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
framework::proto::VarType::Type GetDataType(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return var_set_.at(name)->DType();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void SetDataType(const std::string& name,
|
|
|
|
|
framework::proto::VarType::Type type) {
|
|
|
|
|
void SetDataType(const std::string& name,
|
|
|
|
|
framework::proto::VarType::Type type) override {
|
|
|
|
|
var_set_[name]->SetDType(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::vector<framework::proto::VarType::Type> GetDataTypes(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
std::vector<framework::proto::VarType::Type> GetDataTypes(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void SetDataTypes(
|
|
|
|
|
const std::string& name,
|
|
|
|
|
const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
|
|
|
|
|
void SetDataTypes(const std::string& name,
|
|
|
|
|
const std::vector<framework::proto::VarType::Type>&
|
|
|
|
|
multiple_data_type) override {
|
|
|
|
|
PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::vector<int64_t> GetShape(const std::string& name) const {
|
|
|
|
|
std::vector<int64_t> GetShape(const std::string& name) const override {
|
|
|
|
|
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void SetShape(const std::string& name,
|
|
|
|
|
const std::vector<int64_t>& dims) {
|
|
|
|
|
void SetShape(const std::string& name,
|
|
|
|
|
const std::vector<int64_t>& dims) override {
|
|
|
|
|
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline int32_t GetLoDLevel(const std::string& name) const {
|
|
|
|
|
int32_t GetLoDLevel(const std::string& name) const override {
|
|
|
|
|
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
|
|
|
|
|
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
|
|
|
|
|
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
imperative::OpBase* op_;
|
|
|
|
|
const imperative::VarBasePtrMap* inputs_;
|
|
|
|
|
imperative::VarBasePtrMap* outputs_;
|
|
|
|
|
const framework::AttributeMap* attrs_;
|
|
|
|
|