|
|
|
@ -42,23 +42,17 @@ static void PrepareData(const platform::Place& place,
|
|
|
|
|
for (const auto& var_base : name_pair.second) {
|
|
|
|
|
const auto* tensor = GetTensorFromVar(var_base->Var());
|
|
|
|
|
if (tensor && tensor->IsInitialized()) {
|
|
|
|
|
auto tmp_place = tensor->place();
|
|
|
|
|
|
|
|
|
|
// TODO(jiabin): Support transform data layout when we Verify it on more
|
|
|
|
|
// tests
|
|
|
|
|
if (!(tmp_place == place)) {
|
|
|
|
|
auto kernel_type_for_var = op.GetKernelTypeForVar(
|
|
|
|
|
name_pair.first, *tensor, expected_kernel_key);
|
|
|
|
|
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
|
|
|
|
|
<< kernel_type_for_var << " to " << expected_kernel_key;
|
|
|
|
|
framework::Tensor out;
|
|
|
|
|
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
|
|
|
|
|
&out);
|
|
|
|
|
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
|
|
|
|
|
}
|
|
|
|
|
auto kernel_type_for_var = op.GetKernelTypeForVar(
|
|
|
|
|
name_pair.first, *tensor, expected_kernel_key);
|
|
|
|
|
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
|
|
|
|
|
<< kernel_type_for_var << " to " << expected_kernel_key;
|
|
|
|
|
framework::Tensor out;
|
|
|
|
|
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
|
|
|
|
|
&out);
|
|
|
|
|
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -93,6 +87,13 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
|
|
|
|
|
auto& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
framework::RuntimeContext ctx({}, {});
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
|
|
|
|
|
// GetKernelType functions, so we need to copy the attributes there.
|
|
|
|
|
// Const qualifier of Attrs had to be discarded to overwrite it.
|
|
|
|
|
auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
|
|
|
|
|
mutable_op_attrs = attrs;
|
|
|
|
|
#endif
|
|
|
|
|
auto expected_kernel_key =
|
|
|
|
|
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
|
|
|
|
|
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
|
|
|
|
|