You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
117 lines
4.1 KiB
117 lines
4.1 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#include "paddle/fluid/framework/data_transform.h"
|
|
|
|
#include "paddle/fluid/framework/data_device_transform.h"
|
|
#include "paddle/fluid/framework/data_layout_transform.h"
|
|
#include "paddle/fluid/framework/data_type_transform.h"
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
#endif
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
|
|
// Hands the data held by `from` over to `to`: `to` is made to share `from`'s
// data via ShareDataWith, after which `from` is reset to a default-constructed
// Tensor so it no longer refers to that data.
static void PassTensorData(Tensor *from, Tensor *to) {
  Tensor &source = *from;
  to->ShareDataWith(source);
  source = Tensor();
}
|
|
|
|
void TransformData(const OpKernelType &expected_kernel_type,
|
|
const OpKernelType &kernel_type_for_var,
|
|
const Tensor &input_tensor, Tensor *output_tensor) {
|
|
bool transformed = false;
|
|
Tensor in;
|
|
in.ShareDataWith(input_tensor);
|
|
Tensor out;
|
|
DataLayout lin = kernel_type_for_var.data_layout_;
|
|
DataLayout lout = expected_kernel_type.data_layout_;
|
|
|
|
// do layout transform
|
|
if (NeedTransformLayout(lout, lin)) {
|
|
if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
|
|
PADDLE_ENFORCE(
|
|
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN),
|
|
"No layout transform needed between two MKLDNN OPKernels");
|
|
|
|
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
|
|
// Just set layout/format. No real transform occur
|
|
|
|
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
|
|
ToMKLDNNFormat(lin));
|
|
|
|
out.ShareDataWith(input_tensor);
|
|
out.set_layout(DataLayout::kMKLDNN);
|
|
out.set_format(out_format);
|
|
#endif
|
|
} else {
|
|
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
|
|
// Do transform via MKLDNN lib
|
|
TransDataLayoutFromMKLDNN(kernel_type_for_var, expected_kernel_type, in,
|
|
&out);
|
|
}
|
|
} else {
|
|
// Case3 - transfrom between Non-MKLDNN OPKernels
|
|
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
|
|
}
|
|
transformed = true;
|
|
PassTensorData(&out, &in);
|
|
}
|
|
|
|
// do data type transform
|
|
if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
|
|
TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
|
|
transformed = true;
|
|
PassTensorData(&out, &in);
|
|
}
|
|
|
|
// do device transform
|
|
if (!platform::is_same_place(kernel_type_for_var.place_,
|
|
expected_kernel_type.place_)) {
|
|
TransDataDevice(in, expected_kernel_type.place_, &out);
|
|
transformed = true;
|
|
PassTensorData(&out, &in);
|
|
}
|
|
|
|
PADDLE_ENFORCE(transformed, "No transform is applied, please check!");
|
|
// get output data
|
|
output_tensor->ShareDataWith(in);
|
|
}
|
|
|
|
void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
|
|
Variable *out_var) {
|
|
if (in_var.IsType<LoDTensor>()) {
|
|
auto &in_lod_tensor = in_var.Get<LoDTensor>();
|
|
auto *tran_lod_tensor = out_var->GetMutable<LoDTensor>();
|
|
tran_lod_tensor->set_lod(in_lod_tensor.lod());
|
|
tran_lod_tensor->set_layout(in_lod_tensor.layout());
|
|
tran_lod_tensor->ShareDataWith(tensor);
|
|
} else if (in_var.IsType<SelectedRows>()) {
|
|
auto &in_selected_rows = in_var.Get<SelectedRows>();
|
|
auto *trans_selected_rows = out_var->GetMutable<SelectedRows>();
|
|
trans_selected_rows->set_height(in_selected_rows.height());
|
|
trans_selected_rows->set_rows(in_selected_rows.rows());
|
|
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
|
|
} else {
|
|
PADDLE_THROW("unknown var type");
|
|
}
|
|
}
|
|
|
|
} // namespace framework
|
|
} // namespace paddle
|