You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
7.6 KiB
249 lines
7.6 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#include "paddle/framework/operator.h"
|
|
#include <algorithm>
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
|
|
template <>
|
|
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
|
|
platform::CPUPlace, Eigen::DefaultDevice>() const {
|
|
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
|
|
}
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
template <>
|
|
Eigen::GpuDevice&
|
|
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
|
|
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
|
|
}
|
|
#endif
|
|
|
|
// Returns a read-only Tensor pointer for the value stored in `var`.
//
// A variable holding a LoDTensor is returned as a Tensor pointer; otherwise
// the variable is enforced to hold a plain Tensor. `var` is dereferenced
// unconditionally, so callers must pass a non-null pointer.
const Tensor* GetTensorFromVar(const Variable* var) {
  if (!var->IsType<LoDTensor>()) {
    // Not a LoDTensor: the only other accepted payload is a plain Tensor.
    PADDLE_ENFORCE(var->IsType<Tensor>(),
                   "The Input must be LoDTensor or Tensor.");
    return &var->Get<Tensor>();
  }
  return &var->Get<LoDTensor>();
}
|
|
|
|
// Returns a mutable Tensor pointer for the value stored in `var`.
//
// A variable holding a LoDTensor is returned as a Tensor pointer; otherwise
// the variable is enforced to hold a plain Tensor. `var` is dereferenced
// unconditionally, so callers must pass a non-null pointer.
Tensor* GetTensorFromVar(Variable* var) {
  if (!var->IsType<LoDTensor>()) {
    // Not a LoDTensor: the only other accepted payload is a plain Tensor.
    PADDLE_ENFORCE(var->IsType<Tensor>(),
                   "The Input must be LoDTensor or Tensor.");
    return var->GetMutable<Tensor>();
  }
  return var->GetMutable<LoDTensor>();
}
|
|
|
|
std::string OperatorBase::Input(const std::string& name) const {
|
|
auto& ins = Inputs(name);
|
|
PADDLE_ENFORCE_LE(ins.size(), 1UL,
|
|
"Op %s input %s should contain only one variable", type_,
|
|
name);
|
|
return ins.empty() ? kEmptyVarName : ins[0];
|
|
}
|
|
|
|
const std::vector<std::string>& OperatorBase::Inputs(
|
|
const std::string& name) const {
|
|
auto it = inputs_.find(name);
|
|
PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
|
|
name);
|
|
return it->second;
|
|
}
|
|
|
|
std::string OperatorBase::Output(const std::string& name) const {
|
|
auto& outs = Outputs(name);
|
|
PADDLE_ENFORCE_LE(outs.size(), 1UL,
|
|
"Op %s output %s should contain only one variable", type_,
|
|
name);
|
|
return outs.empty() ? kEmptyVarName : outs[0];
|
|
}
|
|
|
|
const std::vector<std::string>& OperatorBase::Outputs(
|
|
const std::string& name) const {
|
|
auto it = outputs_.find(name);
|
|
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output called %s",
|
|
type_, name);
|
|
return it->second;
|
|
}
|
|
|
|
// Writes `slots` to `os` as `slot[var, var], slot[var]`: each entry is its
// key followed by its comma-separated variable names in brackets, and entries
// are themselves separated by ", ".
template <typename SlotMap>
static void AppendSlotsForDebug(std::stringstream& os, const SlotMap& slots) {
  bool first_slot = true;
  for (const auto& slot : slots) {
    if (!first_slot) {
      os << ", ";
    }
    first_slot = false;

    os << slot.first << "[";
    bool first_var = true;
    for (const auto& var_name : slot.second) {
      if (!first_var) {
        os << ", ";
      }
      first_var = false;
      os << var_name;
    }
    os << "]";
  }
}

// Builds a human-readable description of this operator in the form
//   Op(<type>), inputs:{<slot>[<vars>], ...}, outputs:{<slot>[<vars>], ...}.
std::string OperatorBase::DebugString() const {
  std::stringstream ss;
  ss << "Op(" << type_ << "), inputs:{";
  AppendSlotsForDebug(ss, inputs_);
  ss << "}, outputs:{";
  AppendSlotsForDebug(ss, outputs_);
  ss << "}.";
  return ss.str();
}
|
|
|
|
void OperatorBase::Rename(const std::string& old_name,
|
|
const std::string& new_name) {
|
|
for (auto& input : inputs_) {
|
|
std::replace(input.second.begin(), input.second.end(), old_name, new_name);
|
|
}
|
|
for (auto& output : outputs_) {
|
|
std::replace(output.second.begin(), output.second.end(), old_name,
|
|
new_name);
|
|
}
|
|
}
|
|
|
|
// Constructs an operator of the given `type` from its input/output slot maps
// and attributes.
//
// After the members are copied, temporary output names are expanded first and
// only then is the input/output completeness check run; the order of these
// two calls is preserved from the original implementation.
OperatorBase::OperatorBase(const std::string& type,
                           const VariableNameMap& inputs,
                           const VariableNameMap& outputs,
                           const AttributeMap& attrs)
    : type_(type),
      inputs_(inputs),
      outputs_(outputs),
      attrs_(attrs) {
  GenerateTemporaryNames();
  CheckAllInputOutputSet();
}
|
|
|
|
std::vector<std::string> OperatorBase::InputVars() const {
|
|
std::vector<std::string> ret_val;
|
|
for (auto& o : outputs_) {
|
|
ret_val.reserve(ret_val.size() + o.second.size());
|
|
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
|
|
}
|
|
return ret_val;
|
|
}
|
|
|
|
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
|
|
std::vector<std::string> ret_val;
|
|
if (has_intermediate) {
|
|
// push all outputs into ret_val
|
|
for (auto& o : outputs_) {
|
|
ret_val.reserve(ret_val.size() + o.second.size());
|
|
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
|
|
}
|
|
return ret_val;
|
|
}
|
|
auto& info = OpInfoMap::Instance().Get(Type());
|
|
|
|
// get all OpProto::Var for outputs
|
|
for (auto& o : info.Proto().outputs()) {
|
|
// ignore all intermediate output
|
|
if (o.intermediate()) continue;
|
|
auto out = outputs_.find(o.name());
|
|
if (out != outputs_.end()) {
|
|
ret_val.reserve(ret_val.size() + out->second.size());
|
|
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
|
|
}
|
|
}
|
|
return ret_val;
|
|
}
|
|
|
|
void OperatorBase::CheckAllInputOutputSet() const {
|
|
auto& info_map = OpInfoMap::Instance();
|
|
auto* op_info = info_map.GetNullable(Type());
|
|
if (op_info == nullptr || op_info->proto_ == nullptr) return;
|
|
|
|
for (auto& in : op_info->Proto().inputs()) {
|
|
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
|
|
"Type %s's input %s is not set", Type(), in.name());
|
|
}
|
|
|
|
for (auto& out : op_info->Proto().outputs()) {
|
|
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
|
|
"Type %s's output %s is not set", Type(), out.name());
|
|
}
|
|
}
|
|
|
|
void OperatorBase::GenerateTemporaryNames() {
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
for (auto& output : outputs_) {
|
|
for (auto& output_name : output.second) {
|
|
if (output_name == kTempVarName) {
|
|
output_name += type_;
|
|
output_name += "@";
|
|
output_name += std::to_string(gUniqId.fetch_add(1));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
|
|
auto* var = InputVar(name);
|
|
return var == nullptr ? nullptr : GetTensorFromVar(var);
|
|
}
|
|
|
|
template <>
|
|
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
|
|
const std::string& name) const {
|
|
auto names = op().Inputs(name);
|
|
std::vector<const Tensor*> res;
|
|
res.reserve(names.size());
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
[&](const std::string& sub_name) {
|
|
auto var = scope_.FindVar(sub_name);
|
|
return var == nullptr ? nullptr : GetTensorFromVar(var);
|
|
});
|
|
return res;
|
|
}
|
|
|
|
template <>
|
|
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const {
|
|
auto var = OutputVar(name);
|
|
return var == nullptr ? nullptr : var->GetMutable<LoDTensor>();
|
|
}
|
|
|
|
template <>
|
|
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
|
|
const std::string& name) const {
|
|
auto names = op().Outputs(name);
|
|
std::vector<Tensor*> res;
|
|
res.reserve(names.size());
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
[&](const std::string& sub_name) {
|
|
auto var = scope_.FindVar(sub_name);
|
|
return var == nullptr ? nullptr
|
|
: var->GetMutable<LoDTensor>();
|
|
});
|
|
return res;
|
|
}
|
|
|
|
} // namespace framework
|
|
} // namespace paddle
|