|
|
|
@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
return op_.Outputs(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
|
|
|
|
|
size_t j = 0) const override {
|
|
|
|
|
PADDLE_ENFORCE_LT(i, Inputs(in).size());
|
|
|
|
|
PADDLE_ENFORCE_LT(j, Outputs(out).size());
|
|
|
|
|
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
|
|
|
|
|
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
|
|
|
|
|
if (!in_var->IsType<LoDTensor>()) return;
|
|
|
|
|
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
|
|
|
|
|
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
|
|
|
|
|
auto in_tensor = in_var->Get<LoDTensor>();
|
|
|
|
|
auto* out_tensor = out_var->GetMutable<LoDTensor>();
|
|
|
|
|
out_tensor->set_lod(in_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
DDim GetDim(const std::string& name) const override {
|
|
|
|
|
Variable* var = scope_.FindVar(name);
|
|
|
|
|