|
|
|
@ -345,27 +345,10 @@ class OpKernel : public OpKernelBase {
|
|
|
|
|
using ELEMENT_TYPE = T;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
struct OpKernelKey {
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
DataType data_type_;
|
|
|
|
|
|
|
|
|
|
OpKernelKey(DataType data_type, platform::Place place)
|
|
|
|
|
: place_(place), data_type_(data_type) {}
|
|
|
|
|
|
|
|
|
|
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
|
|
|
|
|
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
|
|
|
|
|
|
|
|
|
|
bool operator==(const OpKernelKey& o) const {
|
|
|
|
|
return platform::places_are_same_class(place_, o.place_) &&
|
|
|
|
|
data_type_ == o.data_type_;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpKernelHash {
|
|
|
|
|
struct OpKernelType {
|
|
|
|
|
struct Hash {
|
|
|
|
|
std::hash<int> hash_;
|
|
|
|
|
size_t operator()(const OpKernelKey& key) const {
|
|
|
|
|
size_t operator()(const OpKernelType& key) const {
|
|
|
|
|
int place = key.place_.which();
|
|
|
|
|
int data_type = static_cast<int>(key.data_type_);
|
|
|
|
|
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
|
|
|
|
@ -374,9 +357,26 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
DataType data_type_;
|
|
|
|
|
|
|
|
|
|
OpKernelType(DataType data_type, platform::Place place)
|
|
|
|
|
: place_(place), data_type_(data_type) {}
|
|
|
|
|
|
|
|
|
|
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
|
|
|
|
|
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
|
|
|
|
|
|
|
|
|
|
bool operator==(const OpKernelType& o) const {
|
|
|
|
|
return platform::places_are_same_class(place_, o.place_) &&
|
|
|
|
|
data_type_ == o.data_type_;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using OpKernelMap =
|
|
|
|
|
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
|
|
|
|
|
OpKernelHash>;
|
|
|
|
|
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
|
|
|
|
|
OpKernelType::Hash>;
|
|
|
|
|
|
|
|
|
|
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
@ -404,40 +404,15 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// indicate kernel DataType by input data. Defaultly all input data must be
|
|
|
|
|
// same.
|
|
|
|
|
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
|
|
|
|
|
auto& scope = ctx.scope();
|
|
|
|
|
int data_type = -1;
|
|
|
|
|
for (auto& input : this->inputs_) {
|
|
|
|
|
for (auto& ipt_name : input.second) {
|
|
|
|
|
auto* var = scope.FindVar(ipt_name);
|
|
|
|
|
if (var != nullptr) {
|
|
|
|
|
const Tensor* t = nullptr;
|
|
|
|
|
if (var->IsType<Tensor>()) {
|
|
|
|
|
t = &var->Get<Tensor>();
|
|
|
|
|
} else if (var->IsType<LoDTensor>()) {
|
|
|
|
|
t = &var->Get<LoDTensor>();
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
t = &(var->Get<SelectedRows>().value());
|
|
|
|
|
}
|
|
|
|
|
if (t != nullptr) {
|
|
|
|
|
int tmp = static_cast<int>(ToDataType(t->type()));
|
|
|
|
|
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
|
|
|
|
|
"DataType of Paddle Op %s must be the same.",
|
|
|
|
|
Type());
|
|
|
|
|
data_type = tmp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
|
|
|
|
|
return static_cast<DataType>(data_type);
|
|
|
|
|
}
|
|
|
|
|
DataType IndicateDataType(const ExecutionContext& ctx) const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::ostream& operator<<(std::ostream& os,
|
|
|
|
|
const OperatorWithKernel::OpKernelKey& kernel_key);
|
|
|
|
|
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
|
|
|
|
|
|
|
|
|
|
extern bool OpSupportGPU(const std::string& op_type);
|
|
|
|
|
|
|
|
|
|