|
|
|
@ -354,18 +354,18 @@ void OperatorBase::GenerateTemporaryNames() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool VarIsTensor(const Variable* var) {
|
|
|
|
|
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
|
|
|
|
|
static bool VarIsTensor(const Variable& var) {
|
|
|
|
|
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const Tensor* GetTensorFromVar(Variable* var) {
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
return var->GetMutable<LoDTensor>();
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
return var->GetMutable<SelectedRows>()->mutable_value();
|
|
|
|
|
const Tensor* GetTensorFromVar(const Variable& var) {
|
|
|
|
|
if (var.IsType<LoDTensor>()) {
|
|
|
|
|
return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
|
|
|
|
|
} else if (var.IsType<SelectedRows>()) {
|
|
|
|
|
return &(var.Get<SelectedRows>().value());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
|
|
|
|
|
var->Type().name());
|
|
|
|
|
var.Type().name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
|
|
|
|
|
template <>
|
|
|
|
|
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
|
|
|
|
|
auto* var = InputVar(name);
|
|
|
|
|
return var == nullptr ? nullptr
|
|
|
|
|
: GetTensorFromVar(const_cast<Variable*>(var));
|
|
|
|
|
return var == nullptr ? nullptr : GetTensorFromVar(*var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -428,7 +427,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) {
|
|
|
|
|
auto var = scope_.FindVar(sub_name);
|
|
|
|
|
return var == nullptr ? nullptr : GetTensorFromVar(var);
|
|
|
|
|
return var == nullptr ? nullptr : GetTensorFromVar(*var);
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
@ -770,8 +769,10 @@ void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
|
for (auto& var_name : inplace_vars) {
|
|
|
|
|
VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
|
|
|
|
|
auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
|
|
|
|
|
auto* transformed_tensor =
|
|
|
|
|
GetTensorFromVar(transfer_scope.FindVar(var_name));
|
|
|
|
|
auto* var = transfer_scope.FindVar(var_name);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
|
|
|
|
|
var_name);
|
|
|
|
|
auto* transformed_tensor = GetTensorFromVar(*var);
|
|
|
|
|
original_tensor->ShareDataWith(*transformed_tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -784,11 +785,11 @@ Scope* OperatorWithKernel::TryTransferData(
|
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
|
auto* var = scope.FindVar(var_name);
|
|
|
|
|
// Only tensor can be tranfer to another device.
|
|
|
|
|
if (var == nullptr || !VarIsTensor(var)) {
|
|
|
|
|
if (var == nullptr || !VarIsTensor(*var)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* tensor_in = GetTensorFromVar(var);
|
|
|
|
|
auto* tensor_in = GetTensorFromVar(*var);
|
|
|
|
|
if (!tensor_in->IsInitialized()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|