@@ -614,16 +614,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
   void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) override {
-    PADDLE_ENFORCE_LT(i, Inputs(in).size());
-    PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    const std::string& input_n = Inputs(in)[i];
-    const std::string& output_n = Outputs(out)[j];
+    auto in_it = ctx_.inputs.find(in);
+    auto out_it = ctx_.outputs.find(out);
+    PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
+                   "Inputs %s should have %llu argument", in, i);
+    PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
+                   "Outputs %s should have %llu argument", out, j);
 
-    Variable* in_var = scope_.FindVar(input_n);
-    Variable* out_var = scope_.FindVar(output_n);
+    Variable* in_var = in_it->second[i];
+    Variable* out_var = out_it->second[j];
+
     PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
-                   "The type of %s and %s is not the same.", output_n,
-                   GetDim(input_n));
+                   "The type of %s and %s is not the same.", in_var->Type(),
+                   out_var->Type());
 
     if (in_var->IsType<framework::SelectedRows>()) {
       auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
@@ -644,13 +647,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
   void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) const override {
-    const std::vector<std::string>& inputs = Inputs(in);
-    const std::vector<std::string>& outputs = Outputs(out);
-    PADDLE_ENFORCE_LT(i, inputs.size());
-    PADDLE_ENFORCE_LT(j, outputs.size());
-    Variable* in_var = scope_.FindVar(inputs.at(i));
+    auto in_it = ctx_.inputs.find(in);
+    auto out_it = ctx_.outputs.find(out);
+    PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
+                   "Inputs %s should have %llu argument", in, i);
+    PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
+                   "Outputs %s should have %llu argument", out, j);
+
+    Variable* in_var = in_it->second.at(i);
     if (!in_var->IsType<LoDTensor>()) return;
-    Variable* out_var = scope_.FindVar(outputs.at(j));
+    Variable* out_var = out_it->second.at(j);
     PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                    "The %d-th output of Output(%s) must be LoDTensor.", j, out);
     auto in_tensor = in_var->Get<LoDTensor>();
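
Both hunks follow the same pattern: instead of resolving variables by name through `scope_.FindVar(...)`, the methods take the `Variable*` pointers already collected in the `ctx_` member (a runtime context whose `inputs`/`outputs` maps slot names to resolved `Variable*` lists), guarded by the bounds-checking `PADDLE_ENFORCE` calls. A minimal, self-contained sketch of that lookup pattern, using mock `Variable` and context types in place of the real PaddlePaddle framework classes, could look like this:

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Mock stand-ins for the PaddlePaddle types; the real definitions live in
// paddle/fluid/framework and carry much more state.
struct Variable {
  std::string name;
};

// The runtime context already holds the resolved Variable* lists per slot
// name, so shape-inference helpers no longer need a by-name scope lookup.
struct RuntimeContextSketch {
  std::map<std::string, std::vector<Variable*>> inputs;
  std::map<std::string, std::vector<Variable*>> outputs;
};

// Bounds-checked lookup mirroring the PADDLE_ENFORCE guards in the diff
// (assert here stands in for PADDLE_ENFORCE).
Variable* GetInputVar(const RuntimeContextSketch& ctx, const std::string& slot,
                      size_t i) {
  auto it = ctx.inputs.find(slot);
  assert(it != ctx.inputs.end() && it->second.size() > i &&
         "Inputs should have the requested argument");
  return it->second[i];
}

int main() {
  Variable x{"x0"};
  RuntimeContextSketch ctx;
  ctx.inputs["X"] = {&x};
  std::cout << GetInputVar(ctx, "X", 0)->name << "\n";  // prints "x0"
  return 0;
}
```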