avoid data transform for linspace OP (#27444)

revert-27520-disable_pr
wangchaochaohu 5 years ago committed by GitHub
parent a04524759e
commit 76fb95fe76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/linspace_op.h"
#include <string>
namespace paddle {
namespace operators {
@ -55,6 +56,12 @@ class LinspaceOp : public framework::OperatorWithKernel {
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return expected_kernel_type;
}
};
class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {

@ -1453,10 +1453,13 @@ def linspace(start, stop, num, dtype=None, name=None):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable):
with device_guard("cpu"):
tensor_start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable):
with device_guard("cpu"):
tensor_stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable):
with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode():
return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',

Loading…
Cancel
Save