fix unneeded call to typecast op for string

pull/2855/head
nhussain 5 years ago
parent 8b78bbc3ac
commit 1aca3f6404

@ -113,22 +113,27 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
}
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
const DataType &fill_type = fill_value->type();
const DataType &input_type = input->type();
const TensorShape &input_shape = input->shape();
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
"Types do not match");
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
std::shared_ptr<Tensor> out;
const DataType &to = input->type();
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
std::shared_ptr<Tensor> out, fill_output;
std::shared_ptr<Tensor> fill_output;
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type));
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
} else {
fill_output = fill_value;
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type));
switch (input->type().value()) {
switch (input_type.value()) {
case DataType::DE_BOOL: {
bool value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
@ -206,10 +211,10 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
std::string_view fill_string_view;
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
std::string fill_string = std::string(fill_string_view);
for (int i = 0; i < input->shape().NumOfElements(); i++) {
for (int i = 0; i < input_shape.NumOfElements(); i++) {
strings.emplace_back(fill_string);
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape));
break;
}
case DataType::DE_UNKNOWN: {

Loading…
Cancel
Save