|
|
|
@ -79,31 +79,28 @@ class OperatorBase {
|
|
|
|
|
|
|
|
|
|
virtual ~OperatorBase() {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline const T& Attr(const std::string& name) const {
|
|
|
|
|
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
|
|
|
|
|
name);
|
|
|
|
|
return boost::get<T>(attrs_.at(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// if scope is not null, also show dimensions of arguments
|
|
|
|
|
virtual std::string DebugStringEx(const Scope* scope) const;
|
|
|
|
|
|
|
|
|
|
std::string DebugString() const { return DebugStringEx(nullptr); }
|
|
|
|
|
|
|
|
|
|
/// Net will call this interface function to Run an op.
|
|
|
|
|
/// Executor will call this interface function to Run an op.
|
|
|
|
|
// The implementation should be written at RunImpl
|
|
|
|
|
void Run(const Scope& scope, const platform::Place& place);
|
|
|
|
|
|
|
|
|
|
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
|
|
|
|
|
virtual void Stop() {}
|
|
|
|
|
|
|
|
|
|
virtual bool IsNetOp() const { return false; }
|
|
|
|
|
/// if scope is not null, also show dimensions of arguments
|
|
|
|
|
virtual std::string DebugStringEx(const Scope* scope) const;
|
|
|
|
|
std::string DebugString() const { return DebugStringEx(nullptr); }
|
|
|
|
|
|
|
|
|
|
virtual bool SupportGPU() const { return false; }
|
|
|
|
|
|
|
|
|
|
/// rename inputs outputs name
|
|
|
|
|
void Rename(const std::string& old_name, const std::string& new_name);
|
|
|
|
|
const std::string& Type() const { return type_; }
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline const T& Attr(const std::string& name) const {
|
|
|
|
|
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
|
|
|
|
|
name);
|
|
|
|
|
return boost::get<T>(attrs_.at(name));
|
|
|
|
|
}
|
|
|
|
|
const AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
|
|
|
|
|
|
const VariableNameMap& Inputs() const { return inputs_; }
|
|
|
|
|
const VariableNameMap& Outputs() const { return outputs_; }
|
|
|
|
@ -112,7 +109,7 @@ class OperatorBase {
|
|
|
|
|
std::string Input(const std::string& name) const;
|
|
|
|
|
//! Get a input which has multiple variables.
|
|
|
|
|
const std::vector<std::string>& Inputs(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
//! Get all inputs variable names
|
|
|
|
|
std::vector<std::string> InputVars() const;
|
|
|
|
|
|
|
|
|
|
//! Get a output with argument's name described in `op_proto`
|
|
|
|
@ -120,13 +117,9 @@ class OperatorBase {
|
|
|
|
|
//! Get an output which has multiple variables.
|
|
|
|
|
//! TODO add a vector_view to prevent memory copy.
|
|
|
|
|
const std::vector<std::string>& Outputs(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
//! Get all outputs variable names
|
|
|
|
|
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
|
|
|
|
|
|
|
|
|
|
const std::string& Type() const { return type_; }
|
|
|
|
|
void SetType(const std::string& type) { type_ = type; }
|
|
|
|
|
const AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
|
|
|
|
|
|
// Return a new operator instance, which is as same as this.
|
|
|
|
|
// Use unique_ptr to prevent caller forget to delete this pointer.
|
|
|
|
|
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
|
|
|
|
@ -278,20 +271,6 @@ class ExecutionContext {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
|
|
|
|
|
size_t j = 0) const {
|
|
|
|
|
PADDLE_ENFORCE_LT(i, InputSize(in));
|
|
|
|
|
PADDLE_ENFORCE_LT(j, OutputSize(out));
|
|
|
|
|
auto* in_var = MultiInputVar(in)[i];
|
|
|
|
|
auto* out_var = MultiOutputVar(out)[j];
|
|
|
|
|
if (!in_var->IsType<LoDTensor>()) return;
|
|
|
|
|
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
|
|
|
|
|
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
|
|
|
|
|
auto in_tensor = in_var->Get<LoDTensor>();
|
|
|
|
|
auto* out_tensor = out_var->GetMutable<LoDTensor>();
|
|
|
|
|
out_tensor->set_lod(in_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::Place GetPlace() const { return device_context_.GetPlace(); }
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContextType>
|
|
|
|
|