|
|
|
@ -21,6 +21,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "op_info.h"
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/data_type.h"
|
|
|
|
|
#include "paddle/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
@ -407,7 +408,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpKernel {
|
|
|
|
|
class OpKernelBase {
|
|
|
|
|
public:
|
|
|
|
|
/**
|
|
|
|
|
* ExecutionContext is the only parameter of Kernel Run function.
|
|
|
|
@ -418,33 +419,47 @@ class OpKernel {
|
|
|
|
|
|
|
|
|
|
virtual void Compute(const ExecutionContext& context) const = 0;
|
|
|
|
|
|
|
|
|
|
virtual ~OpKernel() {}
|
|
|
|
|
virtual ~OpKernelBase() = default;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class OpKernel : public OpKernelBase {
|
|
|
|
|
public:
|
|
|
|
|
using ELEMENT_TYPE = T;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
struct OpKernelKey {
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
DataType data_type_;
|
|
|
|
|
|
|
|
|
|
OpKernelKey() = default;
|
|
|
|
|
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
|
|
|
|
|
place_ = dev_ctx.GetPlace();
|
|
|
|
|
}
|
|
|
|
|
OpKernelKey(DataType data_type, platform::Place place)
|
|
|
|
|
: place_(place), data_type_(data_type) {}
|
|
|
|
|
|
|
|
|
|
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
|
|
|
|
|
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
|
|
|
|
|
|
|
|
|
|
bool operator==(const OpKernelKey& o) const {
|
|
|
|
|
return platform::places_are_same_class(place_, o.place_);
|
|
|
|
|
return platform::places_are_same_class(place_, o.place_) &&
|
|
|
|
|
data_type_ == o.data_type_;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpKernelHash {
|
|
|
|
|
std::hash<bool> hash_;
|
|
|
|
|
std::hash<int> hash_;
|
|
|
|
|
size_t operator()(const OpKernelKey& key) const {
|
|
|
|
|
return hash_(platform::is_gpu_place(key.place_));
|
|
|
|
|
int place = key.place_.which();
|
|
|
|
|
int data_type = static_cast<int>(key.data_type_);
|
|
|
|
|
// NOTE: Number of places limit to 16.
|
|
|
|
|
int pre_hash = data_type << 4 | (place & 0x0F);
|
|
|
|
|
return hash_(pre_hash);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
using OpKernelMap =
|
|
|
|
|
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
|
|
|
|
|
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
|
|
|
|
|
OpKernelHash>;
|
|
|
|
|
|
|
|
|
|
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
@ -458,8 +473,10 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const final {
|
|
|
|
|
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
|
|
|
|
|
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
|
|
|
|
|
ExecutionContext ctx(*this, scope, dev_ctx);
|
|
|
|
|
auto& opKernel = AllOpKernels().at(type_).at(
|
|
|
|
|
OpKernelKey(IndicateDataType(ctx), dev_ctx));
|
|
|
|
|
opKernel->Compute(ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
|
|
|
|
@ -469,13 +486,43 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SupportGPU() const override {
|
|
|
|
|
OperatorWithKernel::OpKernelKey key;
|
|
|
|
|
key.place_ = platform::GPUPlace();
|
|
|
|
|
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
|
|
|
|
|
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
|
|
|
|
|
return std::any_of(op_kernels.begin(), op_kernels.end(),
|
|
|
|
|
[](OpKernelMap::const_reference kern_pair) {
|
|
|
|
|
return platform::is_gpu_place(kern_pair.first.place_);
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Pure virtual hook taking an InferShapeContextBase*; every concrete operator
// must override it.
// NOTE(review): presumably checks inputs and sets output shapes through ctx —
// confirm against the concrete operator implementations.
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
|
|
|
|
|
|
|
|
|
|
// indicate kernel DataType by input data. Defaultly all input data must be
|
|
|
|
|
// same.
|
|
|
|
|
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
|
|
|
|
|
auto& scope = ctx.scope();
|
|
|
|
|
int data_type = -1;
|
|
|
|
|
for (auto& input : this->inputs_) {
|
|
|
|
|
for (auto& ipt_name : input.second) {
|
|
|
|
|
auto* var = scope.FindVar(ipt_name);
|
|
|
|
|
if (var != nullptr) {
|
|
|
|
|
const Tensor* t = nullptr;
|
|
|
|
|
if (var->IsType<Tensor>()) {
|
|
|
|
|
t = &var->Get<Tensor>();
|
|
|
|
|
} else if (var->IsType<LoDTensor>()) {
|
|
|
|
|
t = &var->Get<LoDTensor>();
|
|
|
|
|
}
|
|
|
|
|
if (t != nullptr) {
|
|
|
|
|
int tmp = static_cast<int>(ToDataType(t->type()));
|
|
|
|
|
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
|
|
|
|
|
"DataType of Paddle Op must be same.");
|
|
|
|
|
data_type = tmp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
|
|
|
|
|
return static_cast<DataType>(data_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|