@@ -544,11 +544,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
   void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) const override {
-    PADDLE_ENFORCE_LT(i, Inputs(in).size());
-    PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
-    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
+    const std::vector<std::string>& inputs = Inputs(in);
+    const std::vector<std::string>& outputs = Outputs(out);
+    PADDLE_ENFORCE_LT(i, inputs.size());
+    PADDLE_ENFORCE_LT(j, outputs.size());
+    Variable* in_var = scope_.FindVar(inputs.at(i));
     if (!in_var->IsType<LoDTensor>()) return;
+    Variable* out_var = scope_.FindVar(outputs.at(j));
     PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                    "The %d-th output of Output(%s) must be LoDTensor.", j, out);
     auto in_tensor = in_var->Get<LoDTensor>();
@@ -576,20 +578,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
       out_tensor->set_layout(in_tensor.layout());
   }
 
-  void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
-                   size_t j = 0) const {
-    PADDLE_ENFORCE_LT(i, Inputs(in).size());
-    PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
-    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
-    if (!in_var->IsType<LoDTensor>()) return;
-    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
-                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
-    auto in_tensor = in_var->Get<LoDTensor>();
-    auto* out_tensor = out_var->GetMutable<LoDTensor>();
-    out_tensor->set_layout(in_tensor.layout());
-  }
-
   bool IsRuntime() const override { return true; }
 
  protected:
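The first hunk does a few things that are easy to miss: it caches the vectors returned by Inputs(in) and Outputs(out) in local const references so each accessor is called only once, it switches element access from operator[] to the bounds-checked at(), and it moves the output-variable lookup below the early return so the output is not resolved at all when the input is not a LoDTensor. The second hunk deletes the ShareLayout helper; layout propagation itself remains inside ShareLoD, as the retained out_tensor->set_layout(in_tensor.layout()); context line shows. Below is a minimal standalone sketch of the caching/at() pattern only, with a toy Inputs() helper standing in for the context accessor (the helper and names are illustrative, not Paddle's API):

// Illustrative only: a toy stand-in for the context's Inputs() accessor.
// Each call performs a map lookup, which is what the refactor avoids
// repeating by caching the returned reference.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using NameMap = std::map<std::string, std::vector<std::string>>;

const std::vector<std::string>& Inputs(const NameMap& io,
                                       const std::string& name) {
  return io.at(name);
}

int main() {
  NameMap input_names{{"X", {"x0", "x1"}}};

  // Look the vector up once and keep a reference, instead of calling
  // Inputs() again for every size check and element access.
  const std::vector<std::string>& inputs = Inputs(input_names, "X");

  size_t i = 1;
  assert(i < inputs.size());          // analogue of PADDLE_ENFORCE_LT
  std::cout << inputs.at(i) << "\n";  // at() throws on a bad index, while
                                      // operator[] would be undefined behavior
  return 0;
}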