|
|
|
@ -1073,7 +1073,9 @@ Scope* OperatorWithKernel::PrepareData(
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type OperatorWithKernel::IndicateDataType(
|
|
|
|
|
const ExecutionContext& ctx) const {
|
|
|
|
|
int data_type = -1;
|
|
|
|
|
proto::VarType::Type dafault_data_type =
|
|
|
|
|
static_cast<proto::VarType::Type>(-1);
|
|
|
|
|
proto::VarType::Type data_type = dafault_data_type;
|
|
|
|
|
for (auto& input : this->inputs_) {
|
|
|
|
|
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
|
|
|
|
|
for (size_t i = 0; i < vars.size(); ++i) {
|
|
|
|
@ -1090,18 +1092,19 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
|
|
|
|
|
if (t != nullptr) {
|
|
|
|
|
PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized",
|
|
|
|
|
input.first, i);
|
|
|
|
|
int tmp = static_cast<int>(t->type());
|
|
|
|
|
proto::VarType::Type tmp = t->type();
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
tmp == data_type || data_type == -1,
|
|
|
|
|
tmp == data_type || data_type == dafault_data_type,
|
|
|
|
|
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
|
|
|
|
|
Type(), data_type, tmp);
|
|
|
|
|
Type(), DataTypeToString(data_type), DataTypeToString(tmp));
|
|
|
|
|
data_type = tmp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
|
|
|
|
|
return static_cast<proto::VarType::Type>(data_type);
|
|
|
|
|
PADDLE_ENFORCE(data_type != dafault_data_type,
|
|
|
|
|
"DataType should be indicated by input");
|
|
|
|
|
return data_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpKernelType OperatorWithKernel::GetExpectedKernelType(
|
|
|
|
|