|
|
|
@ -22,6 +22,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "op_info.h"
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/data_type.h"
|
|
|
|
|
#include "paddle/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/framework/lod_tensor.h"
|
|
|
|
@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext {
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CompileTimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
public:
|
|
|
|
|
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
|
|
|
|
|
: op_(op), block_(block) {}
|
|
|
|
|
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& input_names = op_.Input(name);
|
|
|
|
|
auto length = input_names.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL,
|
|
|
|
|
"Input(%s) should have only one value, "
|
|
|
|
|
"but it have %d now",
|
|
|
|
|
name, length);
|
|
|
|
|
return block_.HasVar(input_names[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& output_names = op_.Output(name);
|
|
|
|
|
auto length = output_names.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL,
|
|
|
|
|
"Output(%s) should have only one value, "
|
|
|
|
|
"but it have %d now",
|
|
|
|
|
name, length);
|
|
|
|
|
return block_.HasVar(output_names[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasInputs(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& input_names = op_.Input(name);
|
|
|
|
|
PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name);
|
|
|
|
|
for (auto& input : input_names) {
|
|
|
|
|
if (!block_.HasVar(input)) return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutputs(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& output_names = op_.Output(name);
|
|
|
|
|
PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name);
|
|
|
|
|
for (auto& output : output_names) {
|
|
|
|
|
if (!block_.HasVar(output)) return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetInputDim(const std::string& name) const override {
|
|
|
|
|
std::vector<DDim> ddims = GetInputsDim(name);
|
|
|
|
|
auto length = ddims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL,
|
|
|
|
|
"Input(%s) should have 1 value, "
|
|
|
|
|
"but it has %d now",
|
|
|
|
|
name, length);
|
|
|
|
|
return ddims[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetInputDim(const std::string& name, const DDim& dim) override {
|
|
|
|
|
SetInputsDim(name, {dim});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetOutputDim(const std::string& name) const override {
|
|
|
|
|
std::vector<DDim> ddims = GetOutputsDim(name);
|
|
|
|
|
auto length = ddims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL,
|
|
|
|
|
"Output(%s) should have 1 value, "
|
|
|
|
|
"but it has %d now",
|
|
|
|
|
name, length);
|
|
|
|
|
return ddims[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetOutputDim(const std::string& name, const DDim& dim) override {
|
|
|
|
|
SetOutputsDim(name, {dim});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Inputs(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return op_.Input(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Outputs(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return op_.Output(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
DDim GetDim(const std::string& name) const override {
|
|
|
|
|
return framework::make_ddim(block_.Var(name)->Shape());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDim(const std::string& name, const DDim& dim) override {
|
|
|
|
|
block_.Var(name)->SetShape(framework::vectorize(dim));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const OpDescBind& op_;
|
|
|
|
|
const BlockDescBind& block_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
public:
|
|
|
|
|
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
|
|
|
|
|
|
bool HasInput(const std::string& name) const {
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
auto ipt = op_.Input(name);
|
|
|
|
|
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
return var != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutput(const std::string& name) const {
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
auto ipt = op_.Output(name);
|
|
|
|
|
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
return var != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasInputs(const std::string& name) const {
|
|
|
|
|
bool HasInputs(const std::string& name) const override {
|
|
|
|
|
auto inputs = op_.Inputs(name);
|
|
|
|
|
if (inputs.size() == 0UL) {
|
|
|
|
|
if (inputs.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto& input : inputs) {
|
|
|
|
@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutputs(const std::string& name) const {
|
|
|
|
|
bool HasOutputs(const std::string& name) const override {
|
|
|
|
|
auto outputs = op_.Outputs(name);
|
|
|
|
|
if (outputs.size() == 0UL) {
|
|
|
|
|
if (outputs.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto& output : outputs) {
|
|
|
|
@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetInputDim(const std::string& name) const {
|
|
|
|
|
DDim GetInputDim(const std::string& name) const override {
|
|
|
|
|
return GetDim(op_.Input(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetInputDim(const std::string& name, const DDim& dim) {
|
|
|
|
|
void SetInputDim(const std::string& name, const DDim& dim) override {
|
|
|
|
|
SetDim(op_.Input(name), dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetOutputDim(const std::string& name) const {
|
|
|
|
|
DDim GetOutputDim(const std::string& name) const override {
|
|
|
|
|
return GetDim(op_.Output(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetOutputDim(const std::string& name, const DDim& dim) {
|
|
|
|
|
void SetOutputDim(const std::string& name, const DDim& dim) override {
|
|
|
|
|
SetDim(op_.Output(name), dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
|
|
|
|
|
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Inputs(const std::string& name) const {
|
|
|
|
|
const std::vector<std::string>& Inputs(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return op_.Inputs(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Outputs(const std::string& name) const {
|
|
|
|
|
const std::vector<std::string>& Outputs(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return op_.Outputs(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
return t;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim GetDim(const std::string& name) const {
|
|
|
|
|
DDim GetDim(const std::string& name) const override {
|
|
|
|
|
return GetTensor<false>(name)->dims();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDim(const std::string& name, const DDim& dim) {
|
|
|
|
|
void SetDim(const std::string& name, const DDim& dim) override {
|
|
|
|
|
GetTensor<true>(name)->Resize(dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// indicate kernel DataType by input data. Defaultly all input data must be
|
|
|
|
|
// same.
|
|
|
|
|
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
|
|
|
|
|