|
|
|
@ -327,13 +327,13 @@ class InferShapeContext {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Tensor* GetTensorFromVar(const Variable* var) const {
|
|
|
|
|
const Tensor* GetTensorFromVar(const Variable* var) const {
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
|
|
|
|
|
return &var->Get<LoDTensor>();
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(var->IsType<Tensor>(),
|
|
|
|
|
"The Input(%s) must be LoDTensor or Tensor.");
|
|
|
|
|
return const_cast<Tensor*>(&var->Get<Tensor>());
|
|
|
|
|
return &var->Get<Tensor>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -341,6 +341,13 @@ class InferShapeContext {
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct EigenDeviceConverter;
|
|
|
|
|
|
|
|
|
@ -397,6 +404,13 @@ class ExecutionContext : public InferShapeContext {
|
|
|
|
|
const platform::DeviceContext* device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
class OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
/**
|
|
|
|
|