|
|
|
@ -113,22 +113,27 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
|
|
|
|
|
const DataType &fill_type = fill_value->type();
|
|
|
|
|
const DataType &input_type = input->type();
|
|
|
|
|
const TensorShape &input_shape = input->shape();
|
|
|
|
|
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
|
|
|
|
|
"Types do not match");
|
|
|
|
|
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> out;
|
|
|
|
|
|
|
|
|
|
const DataType &to = input->type();
|
|
|
|
|
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
|
|
|
|
|
std::shared_ptr<Tensor> out, fill_output;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> fill_output;
|
|
|
|
|
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
|
|
|
|
|
if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
|
|
|
|
|
std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type));
|
|
|
|
|
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
|
|
|
|
|
} else {
|
|
|
|
|
fill_output = fill_value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type));
|
|
|
|
|
|
|
|
|
|
switch (input->type().value()) {
|
|
|
|
|
switch (input_type.value()) {
|
|
|
|
|
case DataType::DE_BOOL: {
|
|
|
|
|
bool value = 0;
|
|
|
|
|
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
|
|
|
@ -206,10 +211,10 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
|
|
|
|
|
std::string_view fill_string_view;
|
|
|
|
|
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
|
|
|
|
|
std::string fill_string = std::string(fill_string_view);
|
|
|
|
|
for (int i = 0; i < input->shape().NumOfElements(); i++) {
|
|
|
|
|
for (int i = 0; i < input_shape.NumOfElements(); i++) {
|
|
|
|
|
strings.emplace_back(fill_string);
|
|
|
|
|
}
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape));
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case DataType::DE_UNKNOWN: {
|
|
|
|
|