Move GetTensor to tensor_util (#15011)
* refine tensor test=develop * refine tensor test=develop * fix device_context log test=developrevert-15207-remove_op_handle_lock_and_fix_var
parent
bc16bcda49
commit
b9fb03cf54
@ -1,42 +0,0 @@
|
|||||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
#include "paddle/fluid/framework/tensor.h"
|
|
||||||
#include "paddle/fluid/platform/temporary_allocator.h"
|
|
||||||
namespace paddle {
|
|
||||||
namespace platform {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
paddle::framework::Tensor GetTensor(
|
|
||||||
memory::allocation::AllocationPtr temp_allocation_ptr,
|
|
||||||
const framework::DDim &dim) {
|
|
||||||
auto &deleter = temp_allocation_ptr.get_deleter();
|
|
||||||
auto *allocation_ptr = temp_allocation_ptr.release();
|
|
||||||
auto shared_allocation =
|
|
||||||
std::shared_ptr<memory::allocation::Allocation>(allocation_ptr, deleter);
|
|
||||||
|
|
||||||
PADDLE_ENFORCE(dynamic_cast<TemporaryAllocation *>(allocation_ptr) != nullptr,
|
|
||||||
"The AllocationPtr must be TemporaryAllocation.");
|
|
||||||
PADDLE_ENFORCE_EQ(allocation_ptr->size(),
|
|
||||||
framework::product(dim) * sizeof(T));
|
|
||||||
|
|
||||||
paddle::framework::Tensor temp_tensor(std::type_index(typeid(T)));
|
|
||||||
temp_tensor.Resize(dim);
|
|
||||||
temp_tensor.ResetHolder(std::move(shared_allocation));
|
|
||||||
return temp_tensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace platform
|
|
||||||
} // namespace paddle
|
|
Loading…
Reference in new issue