|
|
|
@ -182,8 +182,6 @@ static const Tensor* GetTensorFromVar(const Variable* var) {
|
|
|
|
|
const Tensor* t = nullptr;
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
t = &(var->Get<LoDTensor>());
|
|
|
|
|
} else if (var->IsType<Tensor>()) {
|
|
|
|
|
t = &(var->Get<Tensor>());
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
t = &(var->Get<SelectedRows>().value());
|
|
|
|
|
} else {
|
|
|
|
@ -197,8 +195,6 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
|
|
|
|
|
Tensor* t = nullptr;
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
t = var->GetMutable<LoDTensor>();
|
|
|
|
|
} else if (var->IsType<Tensor>()) {
|
|
|
|
|
t = var->GetMutable<Tensor>();
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
t = var->GetMutable<SelectedRows>()->mutable_value();
|
|
|
|
|
} else {
|
|
|
|
@ -362,8 +358,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
Variable* var = scope_.FindVar(name);
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
return var->Get<LoDTensor>().dims();
|
|
|
|
|
} else if (var->IsType<Tensor>()) {
|
|
|
|
|
return var->Get<Tensor>().dims();
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
return var->Get<SelectedRows>().GetCompleteDims();
|
|
|
|
|
} else {
|
|
|
|
@ -376,8 +370,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
Variable* var = scope_.FindVar(name);
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
var->GetMutable<LoDTensor>()->Resize(dim);
|
|
|
|
|
} else if (var->IsType<Tensor>()) {
|
|
|
|
|
var->GetMutable<Tensor>()->Resize(dim);
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
var->GetMutable<SelectedRows>()->set_height(dim[0]);
|
|
|
|
|
} else {
|
|
|
|
|