|
|
|
@ -277,9 +277,9 @@ class InferShapeContext {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
|
|
|
|
|
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
std::vector<Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
@ -336,12 +336,19 @@ class InferShapeContext {
|
|
|
|
|
return &var->Get<Tensor>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ShareLoD(const std::string& in, const std::string& out) const {
|
|
|
|
|
PADDLE_ENFORCE(InputVar(in)->IsType<LoDTensor>(),
|
|
|
|
|
"The Input(%s) must be LoDTensor.", in);
|
|
|
|
|
PADDLE_ENFORCE(OutputVar(out)->IsType<LoDTensor>(),
|
|
|
|
|
"The Output(%s) must be LoDTensor.", out);
|
|
|
|
|
Output<LoDTensor>(out)->set_lod(Input<LoDTensor>(in)->lod());
|
|
|
|
|
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
|
|
|
|
|
size_t j = 0) const {
|
|
|
|
|
PADDLE_ENFORCE_LT(i, InputSize(in));
|
|
|
|
|
PADDLE_ENFORCE_LT(j, OutputSize(out));
|
|
|
|
|
auto* in_var = MultiInputVar(in)[i];
|
|
|
|
|
auto* out_var = MultiOutputVar(out)[j];
|
|
|
|
|
PADDLE_ENFORCE(in_var->IsType<LoDTensor>(),
|
|
|
|
|
"The %d-th input of Input(%s) must be LoDTensor.", in);
|
|
|
|
|
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
|
|
|
|
|
"The %d-th output of Output(%s) must be LoDTensor.", out);
|
|
|
|
|
auto in_tensor = in_var->Get<LoDTensor>();
|
|
|
|
|
auto* out_tensor = out_var->GetMutable<LoDTensor>();
|
|
|
|
|
out_tensor->set_lod(in_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -388,38 +395,10 @@ class ExecutionContext : public InferShapeContext {
|
|
|
|
|
return device_context_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// redefine Output function,
|
|
|
|
|
// use Variable::Get instead of Variable::GetMutable
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* Output(const std::string& name) const {
|
|
|
|
|
auto var = OutputVar(name);
|
|
|
|
|
return var == nullptr ? nullptr : const_cast<T*>(&var->Get<T>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// redefine MultiOutput function.
|
|
|
|
|
// use Variable::Get instead of Variable::GetMutable
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<T*> MultiOutput(const std::string& name) const {
|
|
|
|
|
auto names = op().Outputs(name);
|
|
|
|
|
std::vector<T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(
|
|
|
|
|
names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) { return Output<T>(sub_name); });
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
class OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
/**
|
|
|
|
|