/*-
 * Copyright (c) 2006, 2011 Allan Saddi <allan@saddi.com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id$
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <libgen.h>
#include <signal.h>
#include <fcntl.h>
#include <errno.h>
#include <limits.h>
#include <netdb.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <pthread.h>

#include "ajp.h"
#include "wsgi.h"
#include "version.h"

/* Some arbitrary constants */
/* Upper bound on the number of listening sockets createSocket() keeps */
#define MAX_SOCKETS 10
/* Default child-process limit in forking (-F) mode; -n overrides it */
#define DEFAULT_MAX_PROCESSES 16
/* In forking mode the main loop's select() times out after this many
   seconds, so exited children are reaped even when no connections arrive */
#define REAP_INTERVAL 2

/* Tweakables from glue */
/* NOTE(review): defined outside this file (presumably the WSGI glue).
   -Q clears unquoteURL; main() copies the forking flag into multiprocess. */
extern int unquoteURL;
extern int multiprocess;

/* Misc UI stuff */
static char *progName;		/* basename of argv[0]; prefix for all messages */
static int verbosity = 0;	/* bumped by each -v; nonzero enables connection/exit logging */
static int forking = 0;		/* nonzero (-F): one child process per connection instead of threads */

/* Socket handling */
/* Listening sockets filled in by createSocket(); nfds is the highest of
   them plus one, as select() expects. */
static int socks[MAX_SOCKETS];
static int sockCount, nfds;

/* Threading support */
/* _masterThread is the main thread, which connection threads relay
   signals to.  maxThreads comes from -n (INT_MAX, i.e. effectively
   unlimited, when not given).  threadCount is the number of live
   connection threads and is only touched while holding threadCountMutex. */
static pthread_t _masterThread;
static int maxThreads = INT_MAX, threadCount;
static pthread_mutex_t threadCountMutex = PTHREAD_MUTEX_INITIALIZER;

/* Forking support */
/* _masterProcess is the pid die() forwards signals to from children.
   processCount counts children forked by procHandoff() that main() has
   not yet reaped with waitpid(). */
static pid_t _masterProcess;
static int maxProcesses = DEFAULT_MAX_PROCESSES, processCount;
/* Used to hold data about each connection */
/* Heap-allocated by main() for every accept().  Ownership passes to the
   handoff function: it is freed there if the connection is rejected (or,
   in forking mode, in the parent after fork()), otherwise by the
   connection handler when it finishes. */
struct clientData {
  int sock;			/* accepted client socket */
  union {			/* peer address as filled in by accept() */
    struct sockaddr_in in;
    struct sockaddr_in6 in6;
  } addr;
};

/* Pipe shared between parent and children to detect if parent dies */
/* Created only in forking mode.  Each child closes every descriptor except
   stdio, its client socket and pipefds[0], then select()s on pipefds[0].
   Nothing in this file writes to pipefds[1], so the read end is only
   expected to become readable (EOF) once the parent's write end is gone. */
static int pipefds[2]; /* pipefds[0] is child, pipefds[1] is parent */

/* Signal handling */
/* Values of mainLoopState: keep running, shut down, or exit for reload */
#define MAIN_LOOP_OK 0
#define MAIN_LOOP_DIE 1
#define MAIN_LOOP_RELOAD 2

/* Written by the die() signal handler; polled by the accept loop in main()
   and by both per-connection service loops. */
static volatile sig_atomic_t mainLoopState = MAIN_LOOP_OK;

/*
 * perror() replacement that prefixes the program name and, optionally,
 * the name of the failing operation:
 *   "<progName>: <subModule>: <strerror(errno)>"   or
 *   "<progName>: <strerror(errno)>"                when subModule is NULL.
 */
static void
myperror(const char *subModule)
{
  /* Capture the message first so nothing below can disturb errno */
  const char *reason = strerror(errno);

  if (subModule != NULL)
    fprintf(stderr, "%s: %s: %s\n", progName, subModule, reason);
  else
    fprintf(stderr, "%s: %s\n", progName, reason);
}

/*
 * Formats an IPv4/IPv6 socket address as "host:port" into dst, writing at
 * most size bytes (always NUL-terminated when size > 0), and returns dst.
 * IPv6 hosts are printed without brackets.  Any other address family
 * yields "[unknown address family]".  When dst is NULL or size is 0,
 * nothing is written and an empty string literal is returned instead.
 */
static const char *
sockaddr_to_string(struct sockaddr *sa, char *dst, size_t size)
{
  char addr[INET6_ADDRSTRLEN];
  const void *src;
  in_port_t port;

  /* Guard: with size == 0, "size - 1" would wrap and index far out of
     bounds. */
  if (dst == NULL || size == 0)
    return "";

  switch (sa->sa_family) {
  case AF_INET:
    src = &((struct sockaddr_in *)sa)->sin_addr;
    port = ntohs(((struct sockaddr_in *)sa)->sin_port);
    break;
  case AF_INET6:
    src = &((struct sockaddr_in6 *)sa)->sin6_addr;
    port = ntohs(((struct sockaddr_in6 *)sa)->sin6_port);
    break;
  default:
    /* snprintf always terminates, unlike strncpy */
    snprintf(dst, size, "%s", "[unknown address family]");
    return dst;
  }

  /* On failure inet_ntop() does not promise a valid string in addr, so
     never print it unchecked. */
  if (inet_ntop(sa->sa_family, src, addr, sizeof(addr)) == NULL)
    snprintf(addr, sizeof(addr), "%s", "[unprintable address]");

  snprintf(dst, size, "%s:%u", addr, (unsigned int)port);
  return dst;
}

/*
 * Thread entry point for one AJP connection (thread mode).  Runs an AJP
 * context over the client socket until input processing fails or
 * mainLoopState leaves MAIN_LOOP_OK.  Frees the clientData it is given,
 * decrements threadCount, and, if a shutdown/reload is pending, re-sends
 * the matching signal to the main thread in case that thread missed it.
 */
static void *
threadClientHandler(void *clientData)
{
  struct clientData *client = clientData;
  AJPContext *context;
  char peer[100];

  if (verbosity)
    fprintf(stderr, "%s: Connection opened (%s)\n", progName,
	    sockaddr_to_string((struct sockaddr *)&client->addr, peer,
			       sizeof(peer)));

  context = ajpCreateContext(client->sock,
			     (ajp_handler_t)wsgiHandler,
			     (ajp_body_handler_t)wsgiBodyHandlerNC);
  if (context != NULL) {
    /* select() wants the highest descriptor plus one */
    const int maxFd = client->sock + 1;

    while (mainLoopState == MAIN_LOOP_OK) {
      fd_set watch;
      int n;

      FD_ZERO(&watch);
      FD_SET(client->sock, &watch);

      n = select(maxFd, &watch, NULL, NULL, NULL);
      if (n == -1 && errno != EINTR) {
	myperror("select");
	exit(2);
      }

      /* EINTR just loops back to re-check mainLoopState */
      if (n > 0 && FD_ISSET(client->sock, &watch) &&
	  ajpProcessInput(context) != AJP_OK)
	break;
    }

    ajpDestroyContext(context);
  }

  if (verbosity)
    fprintf(stderr, "%s: Connection closed (%s)\n", progName, peer);

  free(client);

  pthread_mutex_lock(&threadCountMutex);
  threadCount--;
  pthread_mutex_unlock(&threadCountMutex);

  /* Relay signal to main thread, just in case */
  if (mainLoopState != MAIN_LOOP_OK) {
    int relay = (mainLoopState == MAIN_LOOP_RELOAD) ? SIGHUP : SIGINT;

    pthread_kill(_masterThread, relay);
  }

  return NULL;
}

/*
 * Services one AJP connection inside a forked child (forking mode).
 * Besides the client socket it also watches `parent`, the read end of the
 * parent-death pipe, and stops when that becomes readable.  The loop also
 * ends when input processing fails or mainLoopState leaves MAIN_LOOP_OK.
 * Frees the clientData it is given.
 */
static void
procClientHandler(struct clientData *cd, int parent)
{
  AJPContext *session;
  char peerName[100];

  if (verbosity)
    fprintf(stderr, "%s: Connection opened (%s)\n", progName,
	    sockaddr_to_string((struct sockaddr *)&cd->addr, peerName,
			       sizeof(peerName)));

  session = ajpCreateContext(cd->sock,
			     (ajp_handler_t)wsgiHandler,
			     (ajp_body_handler_t)wsgiBodyHandlerNC);
  if (session != NULL) {
    /* select() wants the highest of the watched descriptors plus one */
    const int width = (cd->sock > parent ? cd->sock : parent) + 1;

    while (mainLoopState == MAIN_LOOP_OK) {
      fd_set rset;
      int got;

      FD_ZERO(&rset);
      FD_SET(cd->sock, &rset);
      FD_SET(parent, &rset);

      got = select(width, &rset, NULL, NULL, NULL);
      if (got == -1 && errno != EINTR) {
	myperror("select");
	exit(2);
      }
      if (got <= 0)
	continue;		/* interrupted: re-check mainLoopState */

      /* Client input is handled before the parent check */
      if (FD_ISSET(cd->sock, &rset) && ajpProcessInput(session) != AJP_OK)
	break;

      /* The pipe is only expected to turn readable once the parent is gone */
      if (FD_ISSET(parent, &rset))
	break;
    }

    ajpDestroyContext(session);
  }

  if (verbosity)
    fprintf(stderr, "%s: Connection closed (%s)\n", progName, peerName);

  free(cd);
}

/*
 * Connection handoff for threads: starts a detached thread running
 * threadClientHandler() for cd and counts it in threadCount.  On success
 * ownership of cd passes to that thread.  If the connection is refused
 * (thread limit reached, descriptor unusable with select()) or the thread
 * cannot be created, the socket is closed and cd freed here.
 */
static void
threadHandoff(struct clientData *cd)
{
  /* threadClientHandler() select()s on cd->sock, and FD_SET() on a
     descriptor >= FD_SETSIZE is undefined behaviour (it writes past the
     fd_set).  Thread mode has no connection cap by default, so descriptor
     numbers can grow that high; refuse such connections just like those
     over the maxThreads limit. */
  if (cd->sock >= FD_SETSIZE) {
    close(cd->sock);
    free(cd);
    return;
  }

  pthread_mutex_lock(&threadCountMutex);
  if (threadCount >= maxThreads) {
    close(cd->sock);
    free(cd);
  }
  else {
    pthread_t t;
    int err = pthread_create(&t, NULL, threadClientHandler, cd);

    if (err != 0) {
      /* pthread_create() returns the error number instead of setting errno */
      errno = err;
      myperror("pthread_create");
      close(cd->sock);
      free(cd);
    }
    else {
      pthread_detach(t);
      /* Safe to count after creation: the new thread's decrement needs
	 threadCountMutex, which we still hold */
      threadCount++;
    }
  }
  pthread_mutex_unlock(&threadCountMutex);
}

/*
 * Sets FD_CLOEXEC on fd so it is not inherited across exec*().  Other
 * descriptor flags are preserved.  Prints an error and exits with status 2
 * if fcntl() fails.
 */
static void
setCloseOnExec(int fd)
{
  int flags;

  if ((flags = fcntl(fd, F_GETFD)) == -1) {
    perror("fcntl: F_GETFD");
    exit(2);
  }
  if (flags & FD_CLOEXEC)
    return;			/* already set; skip the redundant syscall */
  if (fcntl(fd, F_SETFD, flags | FD_CLOEXEC) == -1) {
    perror("fcntl: F_SETFD");
    exit(2);
  }
}

/*
 * Connection handoff for processes: forks a child to service cd when fewer
 * than maxProcesses children are outstanding; otherwise the connection is
 * dropped.  The child strips its descriptor table down to stdio, the
 * client socket and the parent-death pipe, runs procClientHandler() and
 * exits without returning.  In the parent (and when fork() fails) the
 * accepted socket is closed and cd freed.
 */
static void
procHandoff(struct clientData *cd)
{
  pid_t child;
  int fd, limit;

  if (processCount < maxProcesses) {
    child = fork();

    if (child == 0) {
      /* Child: drop every inherited descriptor we don't need */
      limit = getdtablesize();
      for (fd = STDERR_FILENO + 1; fd < limit; fd++)
	if (fd != cd->sock && fd != pipefds[0])
	  close(fd);

      /* Keep these two from leaking into anything the app exec()s */
      setCloseOnExec(cd->sock);
      setCloseOnExec(pipefds[0]);

      procClientHandler(cd, pipefds[0]);
      exit(0);
    }

    if (child == -1)
      myperror("fork");
    else
      processCount++;
  }

  /* Over the limit, fork failed, or we're the parent: the socket is no
     longer ours to service */
  close(cd->sock);
  free(cd);
}

/*
 * Opens one SOCK_STREAM socket for ai with SO_REUSEADDR, binds it and
 * starts listening.  Returns the descriptor, or -1 after reporting the
 * failing step (the partially set-up socket is closed).
 */
static int
openListener(const struct addrinfo *ai)
{
  int on = 1;
  int s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);

  if (s == -1) {
    myperror("socket");
    return -1;
  }

  if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)))
    myperror("setsockopt");
  else if (bind(s, ai->ai_addr, ai->ai_addrlen))
    myperror("bind");
  else if (listen(s, SOMAXCONN))
    myperror("listen");
  else
    return s;

  close(s);
  return -1;
}

/*
 * Creates the listening socket(s): resolves ifname/port (passive, any
 * address family) and keeps every address that can be bound, up to
 * MAX_SOCKETS.  Fills socks[]/sockCount and sets nfds to the highest
 * socket + 1 for select().  Returns 0 if at least one socket is
 * listening, -1 otherwise.
 */
static int
createSocket(const char *ifname, int port)
{
  struct addrinfo hints, *list, *ai;
  char service[10];
  int rc;

  snprintf(service, sizeof(service), "%d", port);

  memset(&hints, 0, sizeof(hints));
  hints.ai_family = PF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_PASSIVE;

  rc = getaddrinfo(ifname, service, &hints, &list);
  if (rc != 0) {
    fprintf(stderr, "%s: getaddrinfo: %s\n", progName,
	    gai_strerror(rc));
    return -1;
  }

  sockCount = 0;
  nfds = -1;
  for (ai = list; ai != NULL; ai = ai->ai_next) {
    int s = openListener(ai);

    if (s == -1)
      continue;

    socks[sockCount++] = s;
    if (s > nfds)
      nfds = s;

    if (sockCount == MAX_SOCKETS) {
      fprintf (stderr, "%s: Maximum number of sockets bound\n",
	       progName);
      break;
    }
  }
  freeaddrinfo(list);

  /* select() takes the highest descriptor plus one */
  nfds++;

  return (sockCount > 0) ? 0 : -1;
}

/*
 * Signal handler for SIGHUP/SIGINT/SIGTERM (installed in main()).
 * SIGHUP requests a reload, any other signal a shutdown.  Both are
 * recorded in mainLoopState, which the accept loop and the connection
 * loops poll.  In forking mode, a child that catches the signal forwards
 * it to the master process.
 */
static void
die(int sig)
{
  /* kill() may set errno.  Preserve it so the code this handler
     interrupted (e.g. an errno check right after select()/accept()) is
     not misled. */
  int savedErrno = errno;

  mainLoopState = sig == SIGHUP ? MAIN_LOOP_RELOAD : MAIN_LOOP_DIE;

  /* Relay signal to parent process, in case we're backgrounded */
  if (forking && getpid() != _masterProcess)
    kill(_masterProcess, sig);

  errno = savedErrno;
}

/*
 * Displays command-line usage on stderr and exits with status 1.
 * Continuation lines are padded to the width of progName so the option
 * groups line up under the program name.
 */
static void
usage(void)
{
  /* "%*s" applied to "" emits exactly `width` spaces, so no scratch
     buffer (and no allocation that can fail) is needed */
  int width = (int)strlen(progName);

  fprintf(stderr,
	  "Usage: %s [-BFQVv] [-l <logFile>]\n"
	  "       %*s [-h <ifname>] [-p <port>]\n"
	  "       %*s [-n <maxConnections>]\n"
	  "       %*s <moduleName> <appName> [<scriptName>]\n",
	  progName, width, "", width, "", width, "");
  exit(1);
}

/*
 * Entry point.
 *
 * Startup sequence:
 *  1. Parse the command line.
 *  2. Optionally redirect stderr to a log file (-l) and/or daemonize (-B).
 *  3. Bind the listening socket(s) and initialize the WSGI layer.
 *  4. Accept connections until a signal changes mainLoopState.
 *
 * Each connection goes to a new thread by default, or to a forked child
 * with -F.
 *
 * Exit status:
 *  - 0 after SIGINT/SIGTERM
 *  - 3 after SIGHUP (reload requested)
 *  - 1 for usage errors and after -V
 *  - 2 for fatal runtime errors
 */
int
main(int argc, char *argv[])
{
  int ch;
  char *end;
  long val;
  char *scriptName = "", *ifname = "localhost";
  int port = 8009;
  int daemonize = 0;
  char *logName = NULL;
  void (*handoff)(struct clientData *);

  if ((progName = strdup(basename(argv[0]))) == NULL) {
    perror(NULL);
    exit(2);
  }

  /* Command-line parsing */
  while ((ch = getopt(argc, argv, "BFQVl:h:n:p:t:v")) != -1) {
    switch (ch) {
    case 'B':
      daemonize++;
      break;
    case 'F':
      forking++;
      break;
    case 'Q':
      unquoteURL = 0;
      break;
    case 'V':
      fprintf (stderr, "%s %s (Python %s)\n", progName, VERSION_STRING,
	       wsgiPyVersion);
      exit(1);
      break;
    case 'l':
      logName = optarg;
      break;
    case 'h':
      ifname = optarg;
      break;
    case 't':
      fprintf(stderr, "%s: -t option is deprecated, use -n instead\n", progName);
      /* FALLTHRU */
    case 'n':
      /* Parse into a long and range-check before narrowing to int, so
	 values that overflow long (ERANGE) or int are rejected instead of
	 silently wrapping into an arbitrary limit */
      errno = 0;
      val = strtol(optarg, &end, 10);
      if (end == optarg || *end || errno == ERANGE || val < 1 ||
	  val > INT_MAX) {
	fprintf(stderr, "%s: Bad maxConnections argument\n", progName);
	exit(1);
      }
      maxThreads = (int)val;
      maxProcesses = maxThreads;
      break;
    case 'p':
      /* Same treatment: e.g. 4294975305 must not wrap around to 8009 */
      errno = 0;
      val = strtol(optarg, &end, 10);
      if (end == optarg || *end || errno == ERANGE || val < 1 ||
	  val > 65535) {
	fprintf (stderr, "%s: Bad port argument\n", progName);
	exit(1);
      }
      port = (int)val;
      break;
    case 'v':
      verbosity++;
      break;
    default:
      usage();
    }
  }
  argc -= optind;
  argv += optind;

  if (argc < 2 || argc > 3)
    usage();

  if (argc > 2) {
    /* Validate scriptName: empty, or starts with '/' and doesn't end in one */
    scriptName = argv[2];
    if (scriptName[0]) {
      if (scriptName[0] != '/') {
	fprintf(stderr, "%s: scriptName must start with '/'\n", progName);
	exit(1);
      }
      if (scriptName[strlen(scriptName) - 1] == '/') {
	fprintf(stderr, "%s: scriptName must not end with '/'\n", progName);
	exit(1);
      }
    }
  }

  if (logName != NULL) {
    /* Set up log file: it replaces stderr */
    int log;

    if ((log = open(logName, O_CREAT|O_APPEND|O_WRONLY, 0666)) == -1) {
      myperror(logName);
      exit(2);
    }

    dup2(log, STDERR_FILENO);

    if (log != STDERR_FILENO)
      close(log);
  }

  if (daemonize) {
    /* Do this ourselves so we have more control: point stdin/stdout (and
       stderr, unless a log file took it over) at /dev/null, then ask
       daemon() to leave the cwd and our descriptors alone. */
    int fd = open("/dev/null", O_RDWR);

    if (fd == -1)
      myperror("/dev/null");
    else {
      dup2(fd, STDIN_FILENO);
      dup2(fd, STDOUT_FILENO);
      if (logName == NULL)
	dup2(fd, STDERR_FILENO);
      if (fd > STDERR_FILENO)
	close(fd);
    }

    if (daemon (1, 1)) {
      myperror("daemon");
      exit(2);
    }
  }

  /* Create listening sockets */
  if (createSocket(ifname, port))
    exit(2);

  /* Initialize Python/WSGI */
  if (wsgiInit(argv[0], argv[1], scriptName, progName))
    exit(2);

  /* Ensure Python interpreter is shut down on exit */
  atexit(wsgiCleanup);

  /* Do things a little differently depending on whether we use threads or
     processes */
  if (!forking) {
    _masterThread = pthread_self();
    handoff = threadHandoff;
  }
  else {
    /* Set up pipe for parent death detection */
    if (pipe(pipefds) == -1) {
      myperror("pipe");
      exit(2);
    }

    _masterProcess = getpid();
    handoff = procHandoff;
  }

  /* for wsgi.multithread & wsgi.multiprocess */
  multiprocess = forking;

  /* Install signal handlers */
  signal(SIGHUP, die);
  signal(SIGINT, die);
  signal(SIGTERM, die);

  /* Main loop */
  while (mainLoopState == MAIN_LOOP_OK) {
    fd_set readfds;
    int i;
    struct timeval timeout;
    int ready;
    int status;

    /* Wait on all listening sockets */
    FD_ZERO(&readfds);
    for (i = 0; i < sockCount; i++)
      FD_SET(socks[i], &readfds);

    /* Wait for activity; in forking mode wake up periodically so exited
       children are reaped even when no connections arrive */
    timeout.tv_sec = REAP_INTERVAL;
    timeout.tv_usec = 0;
    if ((ready = select(nfds, &readfds, NULL, NULL, forking ? &timeout : NULL))
	== -1 && errno != EINTR) {
      myperror("select");
      exit(2);
    }

    /* Reap children, if any */
    if (forking) {
      while (waitpid(-1, &status, WNOHANG) > 0) {
	if (processCount > 0) {
	  processCount--;
	}
      }
    }

    /* Scan for active sockets */
    if (ready > 0) {
      for (i = 0; i < sockCount; i++) {
	if (FD_ISSET(socks[i], &readfds)) {
	  struct clientData *cd;
	  socklen_t addrLen = sizeof(cd->addr);

	  if ((cd = malloc(sizeof(*cd))) == NULL) {
	    myperror("malloc");
	    exit(2);
	  }

	  /* Accept client connection... */
	  if ((cd->sock = accept(socks[i], (struct sockaddr *)&cd->addr,
				 &addrLen)) != -1) {
	    /* ...and hand it off (the handoff takes ownership of cd) */
	    handoff(cd);
	  }
	  else {
	    /* An interrupted accept() is not worth reporting, but cd must
	       be released either way or it leaks on every EINTR */
	    if (errno != EINTR)
	      myperror("accept");
	    free(cd);
	  }
	}
      }
    }
  }

  if (mainLoopState == MAIN_LOOP_RELOAD) {
    if (verbosity)
      fprintf(stderr, "%s: Exiting for reload...\n", progName);
    exit(3);
  }
  else {
    if (verbosity)
      fprintf(stderr, "%s: Exiting...\n", progName);
    exit(0);
  }
}
