/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include "./ms_service.grpc.pb.h" using grpc::Channel; using grpc::ClientContext; using grpc::Status; using ms_serving::MSService; using ms_serving::PredictReply; using ms_serving::PredictRequest; using ms_serving::Tensor; using ms_serving::TensorShape; class MSClient { public: explicit MSClient(std::shared_ptr channel) : stub_(MSService::NewStub(channel)) {} ~MSClient() = default; std::string Predict() { // Data we are sending to the server. PredictRequest request; Tensor data; TensorShape shape; shape.add_dims(2); shape.add_dims(2); *data.mutable_tensor_shape() = shape; data.set_tensor_type(ms_serving::MS_FLOAT32); std::vector input_data{1, 2, 3, 4}; data.set_data(input_data.data(), input_data.size() * sizeof(float)); *request.add_data() = data; *request.add_data() = data; std::cout << "intput tensor size is " << request.data_size() << std::endl; // Container for the data we expect from the server. PredictReply reply; // Context for the client. It could be used to convey extra information to // the server and/or tweak certain RPC behaviors. ClientContext context; // The actual RPC. Status status = stub_->Predict(&context, request, &reply); std::cout << "Compute [[1, 2], [3, 4]] + [[1, 2], [3, 4]]" << std::endl; // Act upon its status. if (status.ok()) { std::cout << "Add result is"; for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) { std::cout << " " << (reinterpret_cast(reply.mutable_result(0)->mutable_data()->data()))[i]; } std::cout << std::endl; return "RPC OK"; } else { std::cout << status.error_code() << ": " << status.error_message() << std::endl; return "RPC failed"; } } private: std::unique_ptr stub_; }; int main(int argc, char **argv) { // Instantiate the client. It requires a channel, out of which the actual RPCs // are created. This channel models a connection to an endpoint specified by // the argument "--target=" which is the only expected argument. // We indicate that the channel isn't authenticated (use of // InsecureChannelCredentials()). std::string target_str; std::string arg_target_str("--target"); if (argc > 1) { // parse target std::string arg_val = argv[1]; size_t start_pos = arg_val.find(arg_target_str); if (start_pos != std::string::npos) { start_pos += arg_target_str.size(); if (start_pos < arg_val.size() && arg_val[start_pos] == '=') { target_str = arg_val.substr(start_pos + 1); } else { std::cout << "The only correct argument syntax is --target=" << std::endl; return 0; } } else { target_str = "localhost:5500"; } } else { target_str = "localhost:5500"; } MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); std::string reply = client.Predict(); std::cout << "client received: " << reply << std::endl; return 0; }