|
|
|
@ -20,24 +20,6 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void print_lod_tensor(const std::string& var_name,
|
|
|
|
|
const framework::LoDTensor& lod_tensor,
|
|
|
|
|
const std::string& print_info) {
|
|
|
|
|
auto inspect = lod_tensor.data<T>();
|
|
|
|
|
auto element_num = lod_tensor.numel();
|
|
|
|
|
|
|
|
|
|
std::ostringstream sstream;
|
|
|
|
|
sstream << print_info << "\t";
|
|
|
|
|
sstream << var_name << "\t";
|
|
|
|
|
sstream << inspect[0];
|
|
|
|
|
for (int j = 1; j < element_num; ++j) {
|
|
|
|
|
sstream << " " << inspect[j];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::cout << sstream.str() << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrintVar(framework::Scope* scope, const std::string& var_name,
|
|
|
|
|
const std::string& print_info) {
|
|
|
|
|
framework::Variable* var = scope->FindVar(var_name);
|
|
|
|
@ -52,26 +34,11 @@ void PrintVar(framework::Scope* scope, const std::string& var_name,
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::LoDTensor printed_tensor;
|
|
|
|
|
printed_tensor.set_lod(tensor->lod());
|
|
|
|
|
printed_tensor.Resize(tensor->dims());
|
|
|
|
|
if (platform::is_cpu_place(tensor->place())) {
|
|
|
|
|
printed_tensor.ShareDataWith(*tensor);
|
|
|
|
|
} else {
|
|
|
|
|
platform::CPUPlace place;
|
|
|
|
|
framework::TensorCopy(*tensor, place, &printed_tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define PrintLoDTensorCallback(cpp_type, proto_type) \
|
|
|
|
|
do { \
|
|
|
|
|
if (tensor->type() == proto_type) { \
|
|
|
|
|
print_lod_tensor<cpp_type>(var_name, printed_tensor, print_info); \
|
|
|
|
|
return; \
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
_ForEachDataType_(PrintLoDTensorCallback);
|
|
|
|
|
VLOG(1) << "PrintVar: unrecognized data type:" << printed_tensor.type();
|
|
|
|
|
std::ostringstream sstream;
|
|
|
|
|
sstream << print_info << "\t";
|
|
|
|
|
sstream << var_name << "\t";
|
|
|
|
|
sstream << *tensor << "\t";
|
|
|
|
|
std::cout << sstream.str() << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // end namespace platform
|
|
|
|
|