|
|
|
@ -36,11 +36,13 @@ namespace detail {
|
|
|
|
|
|
|
|
|
|
class VariableResponse {
|
|
|
|
|
public:
|
|
|
|
|
VariableResponse(const framework::Scope* scope,
|
|
|
|
|
VariableResponse(bool use_local_scope, const framework::Scope* scope,
|
|
|
|
|
const platform::DeviceContext* dev_ctx)
|
|
|
|
|
: scope_(scope), dev_ctx_(dev_ctx) {}
|
|
|
|
|
: use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) {
|
|
|
|
|
local_scope_ = &scope->NewScope();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual ~VariableResponse() {}
|
|
|
|
|
virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); }
|
|
|
|
|
|
|
|
|
|
// return:
|
|
|
|
|
// 0:ok.
|
|
|
|
@ -54,11 +56,25 @@ class VariableResponse {
|
|
|
|
|
// other: number of error field.
|
|
|
|
|
int Parse(const ::grpc::ByteBuffer& byte_buffer);
|
|
|
|
|
|
|
|
|
|
const framework::Scope& GetLocalScope() const { return *local_scope_; }
|
|
|
|
|
|
|
|
|
|
inline std::string Varname() { return meta_.varname(); }
|
|
|
|
|
inline std::string OutVarname() { return meta_.out_varname(); }
|
|
|
|
|
|
|
|
|
|
// should call parse first.
|
|
|
|
|
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
|
|
|
|
|
framework::Variable* GetVar() {
|
|
|
|
|
return local_scope_->FindVar(meta_.varname());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Variable* InitVar() {
|
|
|
|
|
if (use_local_scope_) {
|
|
|
|
|
bool has_var = (scope_->FindVar(meta_.varname()) != nullptr);
|
|
|
|
|
PADDLE_ENFORCE(has_var);
|
|
|
|
|
return local_scope_->Var(meta_.varname());
|
|
|
|
|
} else {
|
|
|
|
|
return scope_->FindVar(meta_.varname());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
|
|
|
|
@ -73,7 +89,9 @@ class VariableResponse {
|
|
|
|
|
const framework::DDim& dims, int length);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool use_local_scope_ = false;
|
|
|
|
|
const framework::Scope* scope_;
|
|
|
|
|
framework::Scope* local_scope_ = nullptr;
|
|
|
|
|
const platform::DeviceContext* dev_ctx_;
|
|
|
|
|
// only Skeleton
|
|
|
|
|
sendrecv::VariableMessage meta_;
|
|
|
|
|