|
|
|
@ -33,22 +33,24 @@ class InferShapeContext {
|
|
|
|
|
virtual bool HasInput(const std::string &name) const = 0;
|
|
|
|
|
virtual bool HasOutput(const std::string &name) const = 0;
|
|
|
|
|
|
|
|
|
|
std::vector<proto::VarType::Type> GetInputsVarType(
|
|
|
|
|
virtual std::vector<proto::VarType::Type> GetInputsVarType(
|
|
|
|
|
const std::string &name) const;
|
|
|
|
|
std::vector<proto::VarType::Type> GetOutputsVarType(
|
|
|
|
|
virtual std::vector<proto::VarType::Type> GetOutputsVarType(
|
|
|
|
|
const std::string &name) const;
|
|
|
|
|
|
|
|
|
|
virtual bool HasInputs(const std::string &name) const = 0;
|
|
|
|
|
virtual bool HasOutputs(const std::string &name) const = 0;
|
|
|
|
|
|
|
|
|
|
DDim GetInputDim(const std::string &name) const;
|
|
|
|
|
std::vector<DDim> GetInputsDim(const std::string &name) const;
|
|
|
|
|
std::vector<DDim> GetReaderDims(const std::string &name) const;
|
|
|
|
|
DDim GetInputsElementDim(const std::string &name, int idx) const;
|
|
|
|
|
virtual DDim GetInputDim(const std::string &name) const;
|
|
|
|
|
virtual std::vector<DDim> GetInputsDim(const std::string &name) const;
|
|
|
|
|
virtual std::vector<DDim> GetReaderDims(const std::string &name) const;
|
|
|
|
|
virtual DDim GetInputsElementDim(const std::string &name, int idx) const;
|
|
|
|
|
|
|
|
|
|
void SetOutputDim(const std::string &name, const DDim &dim);
|
|
|
|
|
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);
|
|
|
|
|
void SetReaderDims(const std::string &name, const std::vector<DDim> &dims);
|
|
|
|
|
virtual void SetOutputDim(const std::string &name, const DDim &dim);
|
|
|
|
|
virtual void SetOutputsDim(const std::string &name,
|
|
|
|
|
const std::vector<DDim> &dims);
|
|
|
|
|
virtual void SetReaderDims(const std::string &name,
|
|
|
|
|
const std::vector<DDim> &dims);
|
|
|
|
|
|
|
|
|
|
virtual AttrReader Attrs() const = 0;
|
|
|
|
|
virtual const std::vector<std::string> &Inputs(
|
|
|
|
@ -67,13 +69,14 @@ class InferShapeContext {
|
|
|
|
|
|
|
|
|
|
virtual bool IsRuntime() const = 0;
|
|
|
|
|
|
|
|
|
|
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
|
|
|
|
|
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
|
|
|
|
|
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
|
|
|
|
|
virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
|
|
|
|
|
const std::string &name) = 0;
|
|
|
|
|
virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
|
|
|
|
|
const std::string &name) = 0;
|
|
|
|
|
|
|
|
|
|
// Note: In while op, we need this to be public
|
|
|
|
|
void SetDims(const std::vector<std::string> &names,
|
|
|
|
|
const std::vector<DDim> &dims);
|
|
|
|
|
virtual void SetDims(const std::vector<std::string> &names,
|
|
|
|
|
const std::vector<DDim> &dims);
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual DDim GetDim(const std::string &name) const = 0;
|
|
|
|
|