|
|
|
@ -32,10 +32,10 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename VarType>
|
|
|
|
|
static void PrepareDataImpl(
|
|
|
|
|
const platform::Place& place, const NameVarMap<VarType>& ins,
|
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_key) {
|
|
|
|
|
static void PrepareData(const platform::Place& place,
|
|
|
|
|
const NameVarMap<VarType>& ins,
|
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_key) {
|
|
|
|
|
for (const auto& name_pair : ins) {
|
|
|
|
|
for (const auto& var_base : name_pair.second) {
|
|
|
|
|
const auto* tensor = GetTensorFromVar(var_base->Var());
|
|
|
|
@ -63,20 +63,6 @@ static void PrepareDataImpl(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PreparedOp::PrepareData(
|
|
|
|
|
const platform::Place& place, const NameVarMap<VarBase>& ins,
|
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_key) {
|
|
|
|
|
PrepareDataImpl<VarBase>(place, ins, op, expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PreparedOp::PrepareData(
|
|
|
|
|
const platform::Place& place, const NameVarMap<VariableWrapper>& ins,
|
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_key) {
|
|
|
|
|
PrepareDataImpl<VariableWrapper>(place, ins, op, expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PreparedOp::PreparedOp(const framework::OperatorBase& op,
|
|
|
|
|
const framework::RuntimeContext& ctx,
|
|
|
|
|
const framework::OperatorWithKernel::OpKernelFunc& func,
|
|
|
|
@ -122,7 +108,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
|
|
|
|
|
place = dev_ctx->GetPlace();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
|
|
|
|
|
PrepareData<VarType>(place, ins, op, expected_kernel_key);
|
|
|
|
|
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|