/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_pipe_impl.cc
 * \brief Pipe-based RPC channel.
 */
// Linux (and Android) only for now, as Linux is the most common use case.
#if defined(__linux__) || defined(__ANDROID__)

#include <errno.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <tvm/runtime/registry.h>
#include <unistd.h>

#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "../../support/pipe.h"
#include "rpc_endpoint.h"
#include "rpc_local_session.h"

namespace tvm {
namespace runtime {

class PipeChannel final : public RPCChannel {
 public:
  explicit PipeChannel(int readfd, int writefd, pid_t child_pid)
      : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {}

  ~PipeChannel() { Close(); }

  size_t Send(const void* data, size_t size) final {
    ssize_t n = write(writefd_, data, size);
    if (n == -1) {
      LOG(FATAL) << "Pipe write error";
    }
    return static_cast<size_t>(n);
  }

  size_t Recv(void* data, size_t size) final {
    ssize_t n = read(readfd_, data, size);
    if (n == -1) {
      LOG(FATAL) << "Pipe read error";
    }
    return static_cast<size_t>(n);
  }

  void Close() {
    close(readfd_);
    close(writefd_);
    kill(child_pid_, SIGKILL);
  }

 private:
  int readfd_;
  int writefd_;
  pid_t child_pid_;
};

/*!
 * \brief Spawn `cmd` as a child process and open an RPC session to it over pipes.
 *
 * Two extra trailing arguments are appended to the child's command line: the
 * file descriptor it should read from and the one it should write to.
 *
 * \param cmd Command line to execute; cmd[0] is resolved with execvp (PATH lookup).
 * \return An RPC session module backed by the spawned process.
 */
Module CreatePipeClient(std::vector<std::string> cmd) {
  ICHECK(!cmd.empty()) << "CreatePipeClient requires a non-empty command";

  int parent2child[2];
  int child2parent[2];
  ICHECK_EQ(pipe(parent2child), 0);
  if (pipe(child2parent) != 0) {
    int err = errno;
    // Do not leak the first pipe when the second one cannot be created.
    close(parent2child[0]);
    close(parent2child[1]);
    LOG(FATAL) << "pipe() failed, errno=" << err;
  }

  int parent_read = child2parent[0];
  int parent_write = parent2child[1];
  int child_read = parent2child[0];
  int child_write = child2parent[1];

  // Build argv before fork(): allocating in the child of a multithreaded
  // process can deadlock on allocator locks held by other threads.
  // &str[0] (rather than dmlc::BeginPtr) keeps empty arguments as "" instead
  // of a nullptr that would terminate argv early.
  std::string sread_pipe = std::to_string(child_read);
  std::string swrite_pipe = std::to_string(child_write);
  std::vector<char*> argv;
  argv.reserve(cmd.size() + 3);
  for (auto& str : cmd) {
    argv.push_back(&str[0]);
  }
  argv.push_back(&sread_pipe[0]);
  argv.push_back(&swrite_pipe[0]);
  argv.push_back(nullptr);

  pid_t pid = fork();
  if (pid == -1) {
    int err = errno;
    close(parent_read);
    close(parent_write);
    close(child_read);
    close(child_write);
    LOG(FATAL) << "fork() failed, errno=" << err;
  }
  if (pid == 0) {
    // child process: drop the parent's ends, then replace the process image.
    close(parent_read);
    close(parent_write);
    execvp(argv[0], argv.data());
    // execvp only returns on failure. Exit right away so the child never
    // falls through into the parent code below (which would later
    // kill(0, SIGKILL) the whole process group).
    static const char kExecFailed[] = "CreatePipeClient: execvp failed\n";
    ssize_t ignored = write(STDERR_FILENO, kExecFailed, sizeof(kExecFailed) - 1);
    (void)ignored;
    _exit(127);
  }
  // parent process: drop the child's ends.
  close(child_read);
  close(child_write);

  auto endpt = RPCEndpoint::Create(
      std::unique_ptr<PipeChannel>(new PipeChannel(parent_read, parent_write, pid)), "pipe",
      "pipe");
  endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0));
  return CreateRPCSessionModule(CreateClientSession(endpt));
}

// Packed-function entry point: each positional argument is one token of the
// command line to spawn; the result is the RPC session module.
TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body([](TVMArgs args, TVMRetValue* rv) {
  std::vector<std::string> command;
  command.reserve(args.size());
  for (int idx = 0; idx < args.size(); ++idx) {
    command.emplace_back(args[idx].operator std::string());
  }
  *rv = CreatePipeClient(command);
});

}  // namespace runtime
}  // namespace tvm
#endif