|
|
|
|
@ -81,13 +81,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
"The %s[%d] is @EMPTY@", out, j);
|
|
|
|
|
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
|
|
|
|
|
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
|
|
|
|
|
if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
|
|
|
|
|
VLOG(3) << "input " << in << " is not LodTensor";
|
|
|
|
|
if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
|
|
|
|
|
in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
|
|
|
|
|
VLOG(3) << "input " << in << " is not LodTensor or LodTensorArray.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
out_var->SetLoDLevel(in_var->GetLoDLevel());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DecreaseLoDLevel(const std::string &in, const std::string &out,
|
|
|
|
|
size_t i = 0, size_t j = 0) const override {
|
|
|
|
|
PADDLE_ENFORCE_LT(i, Inputs(in).size());
|
|
|
|
|
PADDLE_ENFORCE_LT(j, Outputs(out).size());
|
|
|
|
|
PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
|
|
|
|
|
"The %s[%d] is @EMPTY@", in, i);
|
|
|
|
|
PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
|
|
|
|
|
"The %s[%d] is @EMPTY@", out, j);
|
|
|
|
|
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
|
|
|
|
|
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
|
|
|
|
|
PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
|
|
|
|
|
out_var->GetType() == proto::VarType::LOD_TENSOR,
|
|
|
|
|
"The input %s should be LodTensorArray or LodTensor.",
|
|
|
|
|
out_var->Name());
|
|
|
|
|
PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
|
|
|
|
|
"The input %s should be LodTensor.", in_var->Name());
|
|
|
|
|
if (in_var->GetLoDLevel() > 0) {
|
|
|
|
|
out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsRuntime() const override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|