You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/framework/operator.cc

272 lines
8.2 KiB

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
namespace paddle {
namespace framework {
// CPU specialization: hands back the Eigen::DefaultDevice obtained from
// device_context_ for platform::CPUPlace (dereferenced, so never null here).
template <>
Eigen::DefaultDevice&
ExecutionContext::GetEigenDevice<platform::CPUPlace, Eigen::DefaultDevice>()
    const {
  auto* eigen_device = device_context_.GetEigenDevice<platform::CPUPlace>();
  return *eigen_device;
}
#ifdef PADDLE_WITH_CUDA
// GPU specialization (CUDA builds only): hands back the Eigen::GpuDevice
// obtained from device_context_ for platform::GPUPlace.
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
  auto* eigen_device = device_context_.GetEigenDevice<platform::GPUPlace>();
  return *eigen_device;
}
#endif
// Returns a read-only view of the tensor stored in `var`.
// A held LoDTensor is returned as its Tensor base; otherwise `var` must hold
// a plain Tensor, and any other held type fails the enforce.
// `var` is dereferenced unconditionally, so callers must pass non-null.
const Tensor* GetTensorFromVar(const Variable* var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_ENFORCE(var->IsType<Tensor>(),
                   "The Input must be LoDTensor or Tensor.");
    return &var->Get<Tensor>();
  }
  return &var->Get<LoDTensor>();
}
// Mutable variant: returns the LoDTensor held by `var` (as its Tensor base)
// if there is one, otherwise the plain Tensor it must hold; any other held
// type fails the enforce. `var` is dereferenced unconditionally (non-null).
Tensor* GetTensorFromVar(Variable* var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_ENFORCE(var->IsType<Tensor>(),
                   "The Input must be LoDTensor or Tensor.");
    return var->GetMutable<Tensor>();
  }
  return var->GetMutable<LoDTensor>();
}
// Returns the single variable name bound to input slot `name`.
// The slot must exist (enforced by Inputs()) and hold at most one name;
// an empty slot yields kEmptyVarName.
std::string OperatorBase::Input(const std::string& name) const {
  const auto& ins = Inputs(name);
  // Arguments kept verbatim: PADDLE_ENFORCE_LE embeds their text in the error.
  PADDLE_ENFORCE_LE(ins.size(), 1UL,
                    "Op %s input %s should contain only one variable", type_,
                    name);
  if (ins.empty()) {
    return kEmptyVarName;
  }
  return ins.front();
}
8 years ago
const std::vector<std::string>& OperatorBase::Inputs(
const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
name);
return it->second;
}
// Returns the single variable name bound to output slot `name`.
// The slot must exist (enforced by Outputs()) and hold at most one name;
// an empty slot yields kEmptyVarName.
std::string OperatorBase::Output(const std::string& name) const {
  const auto& outs = Outputs(name);
  // Arguments kept verbatim: PADDLE_ENFORCE_LE embeds their text in the error.
  PADDLE_ENFORCE_LE(outs.size(), 1UL,
                    "Op %s output %s should contain only one variable", type_,
                    name);
  if (outs.empty()) {
    return kEmptyVarName;
  }
  return outs.front();
}
8 years ago
const std::vector<std::string>& OperatorBase::Outputs(
const std::string& name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output called %s",
type_, name);
return it->second;
}
std::string OperatorBase::DebugString() const {
std::stringstream ss;
8 years ago
ss << "Op(" << type_ << "), inputs:{";
for (auto it = inputs_.begin(); it != inputs_.end();) {
auto& input = *it;
8 years ago
ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (i != input.second.size() - 1) {
ss << ", ";
}
}
8 years ago
ss << "]";
++it;
if (it != inputs_.end()) {
ss << ", ";
}
}
8 years ago
ss << "}, outputs:{";
for (auto it = outputs_.begin(); it != outputs_.end();) {
auto& output = *it;
8 years ago
ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (i != output.second.size() - 1) {
ss << ", ";
}
}
8 years ago
ss << "]";
++it;
if (it != outputs_.end()) {
ss << ", ";
}
}
8 years ago
ss << "}.";
return ss.str();
}
void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) {
8 years ago
for (auto& input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name);
}
for (auto& output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name,
new_name);
}
}
// Constructs an operator from its type name, input/output slot maps, and
// attributes (all copied into members).
// GenerateTemporaryNames() runs first: it rewrites output names equal to
// kTempVarName into unique names and advances a shared counter.
// CheckAllInputOutputSet() then enforces that every slot declared in the
// registered OpProto (if any) is present; it may throw.
OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
GenerateTemporaryNames();
CheckAllInputOutputSet();
}
8 years ago
std::vector<std::string> OperatorBase::InputVars() const {
std::vector<std::string> ret_val;
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
// Returns the names of output variables, flattened into one vector.
// With has_intermediate == true, every slot in outputs_ is included.
// Otherwise, the registered OpProto for Type() is walked in declaration order;
// slots marked intermediate, or absent from outputs_, are skipped.
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
  std::vector<std::string> result;
  auto append_names = [&result](const std::vector<std::string>& names) {
    result.reserve(result.size() + names.size());
    result.insert(result.end(), names.begin(), names.end());
  };

  if (has_intermediate) {
    for (auto& slot : outputs_) {
      append_names(slot.second);
    }
    return result;
  }

  const auto& op_info = OpInfoMap::Instance().Get(Type());
  for (auto& out_proto : op_info.Proto().outputs()) {
    if (out_proto.intermediate()) {
      continue;
    }
    auto found = outputs_.find(out_proto.name());
    if (found == outputs_.end()) {
      continue;
    }
    append_names(found->second);
  }
  return result;
}
// Enforces that every input and output slot declared in the registered
// OpProto for Type() has an entry in inputs_/outputs_.
// Op types with no registered info, or info without a proto, are not checked.
void OperatorBase::CheckAllInputOutputSet() const {
  auto* op_info = OpInfoMap::Instance().GetNullable(Type());
  if (op_info == nullptr || op_info->proto_ == nullptr) {
    return;
  }
  const auto& proto = op_info->Proto();
  for (auto& in : proto.inputs()) {
    PADDLE_ENFORCE(inputs_.count(in.name()) != 0,
                   "Type %s's input %s is not set", Type(), in.name());
  }
  for (auto& out : proto.outputs()) {
    PADDLE_ENFORCE(outputs_.count(out.name()) != 0,
                   "Type %s's output %s is not set", Type(), out.name());
  }
}
void OperatorBase::GenerateTemporaryNames() {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
// Returns the tensor bound to input slot `name`, or nullptr when InputVar()
// yields no variable.
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
  auto* var = InputVar(name);
  if (var == nullptr) {
    return nullptr;
  }
  return GetTensorFromVar(var);
}
// Returns one tensor pointer per variable bound to input slot `name`, in slot
// order. An entry is nullptr when the scope has no variable of that name.
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const {
  // Inputs() returns a reference into the operator's slot map; bind to it
  // instead of copying the whole name vector.
  const auto& names = op().Inputs(name);
  std::vector<const Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> const Tensor* {
                   auto var = scope_.FindVar(sub_name);
                   return var == nullptr ? nullptr : GetTensorFromVar(var);
                 });
  return res;
}
// Returns the tensor for output slot `name`, or nullptr when OutputVar()
// yields no variable. The variable is accessed as a LoDTensor via
// GetMutable<LoDTensor>(); NOTE(review): presumably GetMutable creates/resets
// the holder when it is not already a LoDTensor — confirm in Variable.
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
  auto* var = OutputVar(name);
  if (var == nullptr) {
    return nullptr;
  }
  return var->GetMutable<LoDTensor>();
}
// Returns one mutable tensor pointer per variable bound to output slot
// `name`, in slot order. Each found variable is accessed via
// GetMutable<LoDTensor>(); an entry is nullptr when the scope has no
// variable of that name.
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const {
  // Outputs() returns a reference into the operator's slot map; bind to it
  // instead of copying the whole name vector.
  const auto& names = op().Outputs(name);
  std::vector<Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> Tensor* {
                   auto var = scope_.FindVar(sub_name);
                   return var == nullptr ? nullptr
                                         : var->GetMutable<LoDTensor>();
                 });
  return res;
}
// Prints a kernel key as "place[<place>]:data_type[<data_type>]".
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key) {
  os << "place[" << kernel_key.place_ << "]";
  os << ":data_type[" << kernel_key.data_type_ << "]";
  return os;
}
bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operator must support GPU
return true;
}
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
return true;
}
}
return false;
}
} // namespace framework
} // namespace paddle