/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file rpc_server.cc * \brief RPC Server implementation. */ #include #if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) #include #include #include #endif #include #include #include #include #include #include #include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" #include "../../src/support/socket.h" #include "rpc_env.h" #include "rpc_server.h" #include "rpc_tracker_client.h" #if defined(_WIN32) #include "win32_process.h" #endif using namespace std::chrono; namespace tvm { namespace runtime { /*! * \brief wait the child process end. * \param status status value */ #if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) static pid_t waitPidEintr(int* status) { pid_t pid = 0; while ((pid = waitpid(-1, status, 0)) == -1) { if (errno == EINTR) { continue; } else { perror("waitpid"); abort(); } } return pid; } #endif #ifdef __ANDROID__ static std::string getNextString(std::stringstream* iss) { std::string str = iss->str(); size_t start = iss->tellg(); size_t len = str.size(); // Skip leading spaces. while (start < len && isspace(str[start])) start++; size_t end = start; while (end < len && !isspace(str[end])) end++; iss->seekg(end); return str.substr(start, end - start); } #endif /*! * \brief RPCServer RPC Server class. * * \param host The hostname of the server, Default=0.0.0.0 * * \param port_search_start The low end of the search range for an * available port for the RPC, Default=9090 * * \param port_search_end The high search the search range for an * available port for the RPC, Default=9099 * * \param tracker The address of RPC tracker in host:port format * (e.g. "10.77.1.234:9190") * * \param key The key used to identify the device type in tracker. * * \param custom_addr Custom IP Address to Report to RPC Tracker. */ class RPCServer { public: /*! * \brief Constructor. */ RPCServer(std::string host, int port_search_start, int port_search_end, std::string tracker_addr, std::string key, std::string custom_addr, std::string work_dir) : host_(std::move(host)), port_search_start_(port_search_start), my_port_(0), port_search_end_(port_search_end), tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), custom_addr_(std::move(custom_addr)), work_dir_(std::move(work_dir)) {} /*! * \brief Destructor. */ ~RPCServer() { try { // Free the resources tracker_sock_.Close(); listen_sock_.Close(); } catch (...) { } } /*! * \brief Start Creates the RPC listen process and execution. */ void Start() { listen_sock_.Create(); my_port_ = listen_sock_.TryBindHost(host_, port_search_start_, port_search_end_); LOG(INFO) << "bind to " << host_ << ":" << my_port_; listen_sock_.Listen(1); std::future proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this)); proc.get(); // Close the listen socket listen_sock_.Close(); } private: /*! * \brief ListenLoopProc The listen process. */ void ListenLoopProc() { TrackerClient tracker(tracker_addr_, key_, custom_addr_, my_port_); while (true) { support::TCPSocket conn; support::SockAddr addr("0.0.0.0", 0); std::string opts; try { // step 1: setup tracker and report to tracker tracker.TryConnect(); // step 2: wait for in-coming connections AcceptConnection(&tracker, &conn, &addr, &opts); } catch (const char* msg) { LOG(WARNING) << "Socket exception: " << msg; // close tracker resource tracker.Close(); continue; } catch (const std::exception& e) { // close tracker resource tracker.Close(); LOG(WARNING) << "Exception standard: " << e.what(); continue; } int timeout = GetTimeOutFromOpts(opts); #if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); if (timer_pid == 0) { // Timer process sleep(timeout); _exit(0); } const pid_t worker_pid = fork(); if (worker_pid == 0) { // Worker process ServerLoopProc(conn, addr, work_dir_); _exit(0); } int status = 0; const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { kill(worker_pid, SIGTERM); } else if (finished_first == worker_pid) { kill(timer_pid, SIGTERM); } else { LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; } int status_second = 0; waitPidEintr(&status_second); // Logging. if (finished_first == timer_pid) { LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout << "), Process status = " << status_second; } else if (finished_first == worker_pid) { LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; } } else { auto pid = fork(); if (pid == 0) { ServerLoopProc(conn, addr, work_dir_); _exit(0); } // Wait for the result int status = 0; wait(&status); LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; } #elif defined(WIN32) auto start_time = high_resolution_clock::now(); try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { } auto dur = high_resolution_clock::now() - start_time; LOG(INFO) << "Serve Time " << duration_cast(dur).count() << "ms"; #else LOG(WARNING) << "Unknown platform. It is not known how to bring up the subprocess." << " RPC will be launched in the main thread."; ServerLoopProc(conn, addr, work_dir_); #endif // close from our side. LOG(INFO) << "Socket Connection Closed"; conn.Close(); } } /*! * \brief AcceptConnection Accepts the RPC Server connection. * \param tracker Tracker details. * \param conn_sock New connection information. * \param addr New connection address information. * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, support::SockAddr* addr, std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; // Report resource to tracker and get key tracker->ReportResourceAndGetKey(my_port_, &matchkey); while (true) { tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey); support::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; ICHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; continue; } int keylen = 0; ICHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); const char* CLIENT_HEADER = "client:"; const char* SERVER_HEADER = "server:"; std::string expect_header = CLIENT_HEADER + matchkey; std::string server_key = SERVER_HEADER + key_; if (size_t(keylen) < expect_header.length()) { conn.Close(); LOG(INFO) << "Wrong client header length"; continue; } ICHECK_NE(keylen, 0); std::string remote_key; remote_key.resize(keylen); ICHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); std::stringstream ssin(remote_key); std::string arg0; #ifndef __ANDROID__ ssin >> arg0; #else arg0 = getNextString(&ssin); #endif if (arg0 != expect_header) { code = kRPCMismatch; ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); conn.Close(); LOG(WARNING) << "Mismatch key from" << addr->AsString(); continue; } else { code = kRPCSuccess; ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); keylen = int(server_key.length()); ICHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); ICHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); #ifndef __ANDROID__ ssin >> *opts; #else *opts = getNextString(&ssin); #endif *conn_sock = conn; return; } } } /*! * \brief ServerLoopProc The Server loop process. * \param sock The socket information * \param addr The socket address information */ static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr, std::string work_dir) { // Server loop const auto env = RPCEnv(work_dir); RPCServerLoop(int(sock.sockfd)); LOG(INFO) << "Finish serving " << addr.AsString(); env.CleanUp(); } /*! * \brief GetTimeOutFromOpts Parse and get the timeout option. * \param opts The option string */ int GetTimeOutFromOpts(const std::string& opts) const { const std::string option = "-timeout="; size_t pos = opts.rfind(option); if (pos != std::string::npos) { const std::string cmd = opts.substr(pos + option.size()); ICHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } return 0; } std::string host_; int port_search_start_; int my_port_; int port_search_end_; std::string tracker_addr_; std::string key_; std::string custom_addr_; std::string work_dir_; support::TCPSocket listen_sock_; support::TCPSocket tracker_sock_; }; #if defined(WIN32) /*! * \brief ServerLoopFromChild The Server loop process. * \param socket The socket information */ void ServerLoopFromChild(SOCKET socket) { // Server loop tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); sock.Close(); env.CleanUp(); } #endif /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9099 * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 * Default="" \param key The key used to identify the device type in tracker. Default="" \param * custom_addr Custom IP Address to Report to RPC Tracker. Default="" \param silent Whether run in * silent mode. Default=True */ void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, std::string key, std::string custom_addr, std::string work_dir, bool silent) { if (silent) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr), std::move(work_dir)); rpc.Start(); } TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); }); } // namespace runtime } // namespace tvm