|
|
@ -722,6 +722,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
return GetDims(vars);
|
|
|
|
return GetDims(vars);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<proto::VarType::Type> GetInputsVarType(
|
|
|
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
|
|
|
return GetVarTypes(InputVars(name));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<proto::VarType::Type> GetOutputsVarType(
|
|
|
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
|
|
|
return GetVarTypes(OutputVars(name));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
DDim GetDim(Variable* var) const {
|
|
|
|
DDim GetDim(Variable* var) const {
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
@ -766,8 +776,17 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
PADDLE_THROW("Only compile time support this method");
|
|
|
|
PADDLE_THROW("Only compile time support this method");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type GetVarType(const std::string& name) const override {
|
|
|
|
std::vector<proto::VarType::Type> GetVarTypes(
|
|
|
|
auto* var = scope_.FindVar(name);
|
|
|
|
const std::vector<Variable*>& vars) const {
|
|
|
|
|
|
|
|
std::vector<proto::VarType::Type> retv;
|
|
|
|
|
|
|
|
retv.resize(vars.size());
|
|
|
|
|
|
|
|
std::transform(vars.begin(), vars.end(), retv.begin(),
|
|
|
|
|
|
|
|
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
|
|
|
|
|
|
|
|
this, std::placeholders::_1));
|
|
|
|
|
|
|
|
return retv;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type GetVarType(Variable* var) const {
|
|
|
|
return ToVarType(var->Type());
|
|
|
|
return ToVarType(var->Type());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|