|
|
|
@ -701,125 +701,85 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct RecordTime {
|
|
|
|
|
RecordTime(const std::string& name, const std::string& type)
|
|
|
|
|
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
|
|
|
|
|
|
|
|
|
|
void inline stop() {
|
|
|
|
|
end_ = std::chrono::system_clock::now();
|
|
|
|
|
std::chrono::duration<double> diff = end_ - start_;
|
|
|
|
|
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~RecordTime() {
|
|
|
|
|
if (type_ == "elementwise_add") {
|
|
|
|
|
stop();
|
|
|
|
|
}
|
|
|
|
|
// stop();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string name_;
|
|
|
|
|
std::string type_;
|
|
|
|
|
std::chrono::system_clock::time_point start_;
|
|
|
|
|
std::chrono::system_clock::time_point end_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
RecordTime rt("OperatorWithKernel::All", type_);
|
|
|
|
|
{
|
|
|
|
|
RecordTime rt("OperatorWithKernel::InferShape", type_);
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
RecordTime* rt_1 = new RecordTime("OperatorWithKernel::Compute1", type_);
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|
if (kernels_iter == all_op_kernels.end()) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"There are no kernels which are registered in the %s operator.",
|
|
|
|
|
type_);
|
|
|
|
|
}
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|
if (kernels_iter == all_op_kernels.end()) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"There are no kernels which are registered in the %s operator.", type_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
|
|
|
|
|
// transform functions are ready.
|
|
|
|
|
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
|
|
|
|
|
// transform functions are ready.
|
|
|
|
|
|
|
|
|
|
// for (auto& candidate : kKernelPriority) {
|
|
|
|
|
// Do selection
|
|
|
|
|
// }
|
|
|
|
|
// for (auto& candidate : kKernelPriority) {
|
|
|
|
|
// Do selection
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key =
|
|
|
|
|
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
auto expected_kernel_key =
|
|
|
|
|
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
|
|
|
|
|
if (kernel_iter == kernels.end() &&
|
|
|
|
|
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
|
|
|
|
|
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
|
|
|
|
|
expected_kernel_key.library_type_ = LibraryType::kPlain;
|
|
|
|
|
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
|
|
|
|
|
kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
|
|
|
|
|
if (kernel_iter == kernels.end() &&
|
|
|
|
|
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
|
|
|
|
|
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
|
|
|
|
|
expected_kernel_key.library_type_ = LibraryType::kPlain;
|
|
|
|
|
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
|
|
|
|
|
kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op %s does not have kernel for %s", type_,
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
}
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op %s does not have kernel for %s", type_,
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// do data transformScope &transfer_scope;
|
|
|
|
|
std::vector<std::string> transfered_inplace_vars;
|
|
|
|
|
Scope* transfer_scope = nullptr;
|
|
|
|
|
// auto* transfer_scope =
|
|
|
|
|
// TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
|
|
|
|
|
// do data transformScope &transfer_scope;
|
|
|
|
|
std::vector<std::string> transfered_inplace_vars;
|
|
|
|
|
auto* transfer_scope =
|
|
|
|
|
TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
|
|
|
|
|
|
|
|
|
|
// exec scope is the scope that kernel actually executed on.
|
|
|
|
|
const Scope& exec_scope = scope;
|
|
|
|
|
// const Scope& exec_scope =
|
|
|
|
|
// (transfer_scope == nullptr ? scope : *transfer_scope);
|
|
|
|
|
// exec scope is the scope that kernel actually executed on.
|
|
|
|
|
const Scope& exec_scope =
|
|
|
|
|
(transfer_scope == nullptr ? scope : *transfer_scope);
|
|
|
|
|
|
|
|
|
|
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
|
|
|
|
|
dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
|
}
|
|
|
|
|
delete rt_1;
|
|
|
|
|
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
|
|
|
|
|
dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RecordTime* rt_2 = new RecordTime("OperatorWithKernel::Compute2", type_);
|
|
|
|
|
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
|
|
|
|
|
delete rt_2;
|
|
|
|
|
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
|
|
|
|
|
|
|
|
|
|
RecordTime* rt_3 = new RecordTime("OperatorWithKernel::Compute3", type_);
|
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
|
// there is inplace variable has been transfered.
|
|
|
|
|
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
|
|
|
|
|
}
|
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
|
// there is inplace variable has been transfered.
|
|
|
|
|
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*For profiling/benchmark only*/
|
|
|
|
|
if (FLAGS_benchmark) {
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
}
|
|
|
|
|
/*For profiling/benchmark only*/
|
|
|
|
|
if (FLAGS_benchmark) {
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_check_nan_inf) {
|
|
|
|
|
for (auto& vname : OutputVars(true)) {
|
|
|
|
|
auto* var = exec_scope.FindVar(vname);
|
|
|
|
|
if (var == nullptr) continue;
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
|
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
CheckTensorNANOrInf(vname,
|
|
|
|
|
var->Get<framework::SelectedRows>().value());
|
|
|
|
|
}
|
|
|
|
|
if (FLAGS_check_nan_inf) {
|
|
|
|
|
for (auto& vname : OutputVars(true)) {
|
|
|
|
|
auto* var = exec_scope.FindVar(vname);
|
|
|
|
|
if (var == nullptr) continue;
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
|
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
delete rt_3;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
|