|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
@ -40,10 +39,11 @@ class Tracer {
|
|
|
|
|
|
|
|
|
|
virtual ~Tracer() {}
|
|
|
|
|
|
|
|
|
|
void Trace(OpBase* op,
|
|
|
|
|
const std::map<std::string, std::vector<VarBase*>>& inputs,
|
|
|
|
|
const std::map<std::string, std::vector<VarBase*>>& outputs,
|
|
|
|
|
framework::BlockDesc* block, const bool stop_gradient = false);
|
|
|
|
|
void Trace(
|
|
|
|
|
OpBase* op,
|
|
|
|
|
const std::map<std::string, std::vector<VarBase*>>& inputs, // NOLINT
|
|
|
|
|
const std::map<std::string, std::vector<VarBase*>>& outputs, // NOLINT
|
|
|
|
|
framework::BlockDesc* block, const bool stop_gradient = false);
|
|
|
|
|
|
|
|
|
|
std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
|
|
|
|
|
bool stop_gradient = false);
|
|
|
|
|