/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_server.cc
 * \brief RPC Server for TVM.
 */
// NOTE(review): the bare #include directives below carry no header name in this copy of
// the file; the names must be restored before the file can be built.
#include
#include
#include
#if defined(__linux__) || defined(__ANDROID__)
#include
#endif
#include
#include
#include
#include
#include "../../src/support/socket.h"
#include "../../src/support/utils.h"
#include "rpc_server.h"
#if defined(_WIN32)
#include "win32_process.h"
#endif

using namespace std;
using namespace tvm::runtime;
using namespace tvm::support;

/*!
 * \brief Help text for the command line interface.
 *
 * Logged at INFO level by main() when no command is given or the command is not
 * "server", and by ParseCmdArgs() right before it exits on an invalid option value.
 */
static const string kUsage =
    "Command line usage\n"
    " server - Start the server\n"
    "--host - The hostname of the server, Default=0.0.0.0\n"
    "--port - The port of the RPC, Default=9090\n"
    "--port-end - The end search port of the RPC, Default=9099\n"
    "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
    "--key - The key used to identify the device type in tracker. Default=\"\"\n"
    "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
    "--work-dir - Custom work directory. Default=\"\"\n"
    "--silent - Whether to run in silent mode. Default=False\n"
    "\n"
    " Example\n"
    " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
    " --tracker=127.0.0.1:9190 --key=rasp"
    "\n";

/*!
 * \brief RpcServerArgs.
* \arg host The hostname of the server, Default=0.0.0.0 * \arg port The port of the RPC, Default=9090 * \arg port_end The end search port of the RPC, Default=9099 * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \arg key The key used to identify the device type in tracker. Default="" * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \arg work_dir Custom work directory. Default="" * \arg silent Whether run in silent mode. Default=False */ struct RpcServerArgs { string host = "0.0.0.0"; int port = 9090; int port_end = 9099; string tracker; string key; string custom_addr; string work_dir; bool silent = false; #if defined(WIN32) std::string mmap_path; #endif }; /*! * \brief PrintArgs print the contents of RpcServerArgs * \param args RpcServerArgs structure */ void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "host = " << args.host; LOG(INFO) << "port = " << args.port; LOG(INFO) << "port_end = " << args.port_end; LOG(INFO) << "tracker = " << args.tracker; LOG(INFO) << "key = " << args.key; LOG(INFO) << "custom_addr = " << args.custom_addr; LOG(INFO) << "work_dir = " << args.work_dir; LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False")); } #if defined(__linux__) || defined(__ANDROID__) /*! * \brief CtrlCHandler, exits if Ctrl+C is pressed * \param s signal */ void CtrlCHandler(int s) { LOG(INFO) << "\nUser pressed Ctrl+C, Exiting"; exit(1); } /*! * \brief HandleCtrlC Register for handling Ctrl+C event. */ void HandleCtrlC() { // Ctrl+C handler struct sigaction sigIntHandler; sigIntHandler.sa_handler = CtrlCHandler; sigemptyset(&sigIntHandler.sa_mask); sigIntHandler.sa_flags = 0; sigaction(SIGINT, &sigIntHandler, nullptr); } #endif /*! * \brief GetCmdOption Parse and find the command option. * \param argc arg counter * \param argv arg values * \param option command line option to search for. * \param key whether the option itself is key * \return value corresponding to option. 
*/ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { string cmd; for (int i = 1; i < argc; ++i) { string arg = argv[i]; if (arg.find(option) == 0) { if (key) { cmd = argv[i]; return cmd; } // We assume "=" is the end of option. ICHECK_EQ(*option.rbegin(), '='); cmd = arg.substr(arg.find('=') + 1); return cmd; } } return cmd; } /*! * \brief ValidateTracker Check the tracker address format is correct and changes the format. * \param tracker The tracker input. * \return result of operation. */ bool ValidateTracker(string& tracker) { vector list = Split(tracker, ':'); if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { return false; } ostringstream ss; ss << "('" << list[0] << "', " << list[1] << ")"; tracker = ss.str(); return true; } /*! * \brief ParseCmdArgs parses the command line arguments. * \param argc arg counter * \param argv arg values * \param args the output structure which holds the parsed values */ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) { const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } const string host = GetCmdOption(argc, argv, "--host="); if (!host.empty()) { if (!ValidateIP(host)) { LOG(WARNING) << "Wrong host address format."; LOG(INFO) << kUsage; exit(1); } args.host = host; } const string port = GetCmdOption(argc, argv, "--port="); if (!port.empty()) { if (!IsNumber(port) || stoi(port) > 65535) { LOG(WARNING) << "Wrong port number."; LOG(INFO) << kUsage; exit(1); } args.port = stoi(port); } const string port_end = GetCmdOption(argc, argv, "--port_end="); if (!port_end.empty()) { if (!IsNumber(port_end) || stoi(port_end) > 65535) { LOG(WARNING) << "Wrong port_end number."; LOG(INFO) << kUsage; exit(1); } args.port_end = stoi(port_end); } string tracker = GetCmdOption(argc, argv, "--tracker="); if (!tracker.empty()) { if 
(!ValidateTracker(tracker)) { LOG(WARNING) << "Wrong tracker address format."; LOG(INFO) << kUsage; exit(1); } args.tracker = tracker; } const string key = GetCmdOption(argc, argv, "--key="); if (!key.empty()) { args.key = key; } const string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); if (!custom_addr.empty()) { if (!ValidateIP(custom_addr)) { LOG(WARNING) << "Wrong custom address format."; LOG(INFO) << kUsage; exit(1); } args.custom_addr = custom_addr; } #if defined(WIN32) const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); if (!mmap_path.empty()) { args.mmap_path = mmap_path; dmlc::InitLogging("--minloglevel=0"); } #endif const string work_dir = GetCmdOption(argc, argv, "--work-dir="); if (!work_dir.empty()) { args.work_dir = work_dir; } } /*! * \brief RpcServer Starts the RPC server. * \param argc arg counter * \param argv arg values * \return result of operation. */ int RpcServer(int argc, char* argv[]) { RpcServerArgs args; /* parse the command line args */ ParseCmdArgs(argc, argv, args); PrintArgs(args); LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; #if defined(__linux__) || defined(__ANDROID__) // Ctrl+C handler HandleCtrlC(); #endif #if defined(WIN32) if (!args.mmap_path.empty()) { int ret = 0; try { ChildProcSocketHandler(args.mmap_path); } catch (const std::exception&) { ret = -1; } return ret; } #endif RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr, args.work_dir, args.silent); return 0; } /*! * \brief main The main function. * \param argc arg counter * \param argv arg values * \return result of operation. */ int main(int argc, char* argv[]) { if (argc <= 1) { LOG(INFO) << kUsage; return 0; } // Runs WSAStartup on Win32, no-op on POSIX Socket::Startup(); #if defined(_WIN32) SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1"); #endif if (0 == strcmp(argv[1], "server")) { return RpcServer(argc, argv); } LOG(INFO) << kUsage; return 0; }