|
|
|
@ -317,6 +317,104 @@ class ExecutionContext : public InferShapeContext {
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CompileTimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
public:
|
|
|
|
|
CompileTimeInferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
|
|
|
|
|
|
bool HasInput(const std::string& name) const {
|
|
|
|
|
auto ipt = op_.Input(name);
|
|
|
|
|
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
return var != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutput(const std::string& name) const {
|
|
|
|
|
auto ipt = op_.Output(name);
|
|
|
|
|
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
return var != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasInputs(const std::string& name) const {
|
|
|
|
|
auto inputs = op_.Inputs(name);
|
|
|
|
|
if (inputs.size() == 0UL) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto& input : inputs) {
|
|
|
|
|
if (scope_.FindVar(input) == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutputs(const std::string& name) const {
|
|
|
|
|
auto outputs = op_.Outputs(name);
|
|
|
|
|
if (outputs.size() == 0UL) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto& output : outputs) {
|
|
|
|
|
if (scope_.FindVar(output) == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetInputDim(const std::string& name) const {
|
|
|
|
|
return GetDim(op_.Input(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetInputDim(const std::string& name, const DDim& dim) {
|
|
|
|
|
SetDim(op_.Input(name), dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetOutputDim(const std::string& name) const {
|
|
|
|
|
return GetDim(op_.Output(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetOutputDim(const std::string& name, const DDim& dim) {
|
|
|
|
|
SetDim(op_.Output(name), dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Inputs(const std::string& name) const {
|
|
|
|
|
return op_.Inputs(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Outputs(const std::string& name) const {
|
|
|
|
|
return op_.Outputs(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
template <bool Allocate>
|
|
|
|
|
Tensor* GetTensor(const std::string& name) const {
|
|
|
|
|
Tensor* t = nullptr;
|
|
|
|
|
auto* var = scope_.FindVar(name);
|
|
|
|
|
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
|
|
|
|
|
if (Allocate) {
|
|
|
|
|
t = var->GetMutable<LoDTensor>();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Variable(%s) should be tensor", name);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
t = GetTensorFromVar(scope_.FindVar(name));
|
|
|
|
|
}
|
|
|
|
|
return t;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetDim(const std::string& name) const {
|
|
|
|
|
return GetTensor<false>(name)->dims();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDim(const std::string& name, const DDim& dim) {
|
|
|
|
|
GetTensor<true>(name)->Resize(dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const OperatorBase& op_;
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
public:
|
|
|
|
|
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
|