|
|
|
@ -212,9 +212,9 @@ class InferShapeContext {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
|
|
|
|
|
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
std::vector<Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
@ -271,6 +271,20 @@ class InferShapeContext {
|
|
|
|
|
return &var->Get<Tensor>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
|
|
|
|
|
size_t j = 0) const {
|
|
|
|
|
PADDLE_ENFORCE_LT(i, InputSize(in));
|
|
|
|
|
PADDLE_ENFORCE_LT(j, OutputSize(out));
|
|
|
|
|
auto* in_var = MultiInputVar(in)[i];
|
|
|
|
|
auto* out_var = MultiOutputVar(out)[j];
|
|
|
|
|
if (!in_var->IsType<LoDTensor>()) return;
|
|
|
|
|
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
|
|
|
|
|
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
|
|
|
|
|
auto in_tensor = in_var->Get<LoDTensor>();
|
|
|
|
|
auto* out_tensor = out_var->GetMutable<LoDTensor>();
|
|
|
|
|
out_tensor->set_lod(in_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const OperatorBase& op_;
|
|
|
|
|
const Scope& scope_;
|
|
|
|
@ -283,6 +297,13 @@ template <>
|
|
|
|
|
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
// Declares an explicit specialization of InferShapeContext::Output for T =
// Tensor. Only the declaration appears here; the definition is not visible in
// this section.
template <>
|
|
|
|
|
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
// Declares an explicit specialization of InferShapeContext::MultiOutput for
// T = Tensor, returning one Tensor* per entry. Only the declaration appears
// here; the definition is not visible in this section.
template <>
|
|
|
|
|
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
// Forward declaration of the EigenDeviceConverter class template; no primary
// definition is given at this point.
// NOTE(review): presumably specialized per device/place type further down --
// confirm against the rest of the file.
template <typename T>
|
|
|
|
|
struct EigenDeviceConverter;
|
|
|
|
|
|
|
|
|
@ -315,38 +336,10 @@ class ExecutionContext : public InferShapeContext {
|
|
|
|
|
return device_context_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// redefine Output function,
|
|
|
|
|
// use Variable::Get instead of Variable::GetMutable
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* Output(const std::string& name) const {
|
|
|
|
|
auto var = OutputVar(name);
|
|
|
|
|
return var == nullptr ? nullptr : const_cast<T*>(&var->Get<T>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// redefine MultiOutput function.
|
|
|
|
|
// use Variable::Get instead of Variable::GetMutable
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<T*> MultiOutput(const std::string& name) const {
|
|
|
|
|
auto names = op().Outputs(name);
|
|
|
|
|
std::vector<T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(
|
|
|
|
|
names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) { return Output<T>(sub_name); });
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Declares an explicit specialization of ExecutionContext::Output for T =
// Tensor. Only the declaration appears here; the definition is not visible in
// this section.
template <>
|
|
|
|
|
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
// Declares an explicit specialization of ExecutionContext::MultiOutput for
// T = Tensor, returning one Tensor* per entry. Only the declaration appears
// here; the definition is not visible in this section.
template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
class OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
/**
|
|
|
|
|