/*-
 * Copyright (c) 2006 Allan Saddi <allan@saddi.com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id$
 */

#include <Python.h>

#include <assert.h>
#include <ctype.h>

#include "wsgi-int.h"

/* Value of the PY_VERSION macro from the Python headers used at build time. */
const char *wsgiPyVersion = PY_VERSION;
/* The interpreter's sys.stderr object; assigned by wsgiInit() and released
   by wsgiCleanup(). */
PyObject *wsgiStderr;
/* Script name handed to wsgiInit() and its cached length. Both describe an
   empty string until wsgiInit() has run. */
const char *wsgiScriptName = "";
int wsgiScriptNameLen = 0;

/* The application callable looked up by wsgiInit() and invoked once per
   request by wsgiHandler(). */
static PyObject *pApp;
/* Thread state returned by PyEval_SaveThread() at the end of wsgiInit();
   wsgiCleanup() restores it before finalizing the interpreter. */
static PyThreadState *_main;

/*
 * Convert a single hexadecimal digit character to its numeric value (0-15).
 * Callers are expected to pass a valid hex digit; any other input
 * yields -1.
 */
static inline int
toxdigit(int c)
{
  int digit = -1;

  if (c >= 'a' && c <= 'f')
    digit = 10 + (c - 'a');
  else if (c >= 'A' && c <= 'F')
    digit = 10 + (c - 'A');
  else if (c >= '0' && c <= '9')
    digit = c - '0';

  return digit;
}

/*
 * Unquote a percent-escaped path.
 *
 * Every "%XY" sequence where X and Y are hexadecimal digits is replaced
 * by the byte with that value; every other character, including a '%'
 * that is not followed by two hex digits, is copied through unchanged.
 *
 * Returns a newly allocated NUL-terminated string obtained from
 * PyMem_Malloc() (the caller releases it with PyMem_Free()), or NULL if
 * the allocation fails. The output is never longer than the input.
 */
const char *
wsgiUnquote(const char *s)
{
  size_t len = strlen(s);
  char *result, *t;

  if ((result = PyMem_Malloc(len + 1)) == NULL)
    return NULL;

  t = result;
  while (*s) {
    /* The <ctype.h> functions require a value representable as unsigned
       char; passing a negative plain char (bytes >= 0x80 on platforms
       where char is signed) is undefined behaviour. isxdigit() is false
       for '\0', so the short-circuit also keeps us from reading past the
       terminator. */
    if (*s == '%' &&
        isxdigit((unsigned char)s[1]) && isxdigit((unsigned char)s[2])) {
      *(t++) = (char)((toxdigit((unsigned char)s[1]) << 4) |
                      toxdigit((unsigned char)s[2]));
      s += 3;
    }
    else
      *(t++) = *(s++);
  }
  *t = '\0';

  return result;
}

/*
 * Store a C string in the request's environ dictionary under key.
 * Returns 0 on success, -1 if the string object could not be created or
 * the dictionary insertion failed.
 */
int
wsgiPutEnv(Request *self, const char *key, const char *value)
{
  PyObject *str = PyString_FromString(value);
  int failed;

  if (str == NULL)
    return -1;

  failed = PyDict_SetItemString(self->environ, key, str);
  /* Drop our own reference now that the insertion has been attempted. */
  Py_DECREF(str);

  return failed ? -1 : 0;
}

/*
 * Emit a canned "500 Internal Error" HTML response on the given context.
 * The body is only sent if the headers went out successfully.
 */
static void
sendResponse500(void *ctxt)
{
  /* One header, stored as a flat name/value pair. */
  static const char *errorHeaders[] = {
    "Content-Type", "text/html; charset=iso-8859-1",
  };
  static const char *errorPage =
    "<!DOCTYPE HTML PUBLIC \"-//IETF//DTD HTML 2.0//EN\">\n"
    "<html><head>\n"
    "<title>500 Internal Error</title>\n"
    "</head><body>\n"
    "<h1>Internal Error</h1>\n"
    "<p>The server encountered an unexpected condition which\n"
    "prevented it from fulfilling the request.</p>\n"
    "</body></html>\n";

  if (wsgiSendHeaders(ctxt, 500, "Internal Error", 1, errorHeaders) != 0)
    return;

  wsgiSendBody(ctxt, (uint8_t *)errorPage, strlen(errorPage));
}

/*
 * Release every Python object the request holds a reference to.
 * Each field is set to NULL before its old value is released, so the
 * request never exposes a pointer to an object that is being torn down.
 * Safe to call more than once.
 */
static void
Request_clear(Request *self)
{
  PyObject **slots[5];
  int i;

  /* Release order: result, headers, status, input, environ. */
  slots[0] = &self->result;
  slots[1] = &self->headers;
  slots[2] = &self->status;
  slots[3] = &self->input;
  slots[4] = &self->environ;

  for (i = 0; i < 5; i++) {
    PyObject *held = *slots[i];

    *slots[i] = NULL;
    Py_XDECREF(held);
  }
}

/* Destructor: drop all held references, then free the object itself via
   its type's tp_free. */
static void
Request_dealloc(Request *self)
{
  PyObject *asObject = (PyObject *)self;

  Request_clear(self);
  asObject->ob_type->tp_free(asObject);
}

/*
 * Allocator: create a Request with an empty environ dictionary and all
 * other object fields cleared. Returns NULL if either the object or the
 * dictionary cannot be allocated.
 */
static PyObject *
Request_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
  Request *req = (Request *)type->tp_alloc(type, 0);

  if (req == NULL)
    return NULL;

  req->environ = PyDict_New();
  if (req->environ == NULL) {
    Py_DECREF(req);
    return NULL;
  }

  req->input = NULL;
  req->status = NULL;
  req->headers = NULL;
  req->result = NULL;
  req->headers_sent = 0;

  return (PyObject *)req;
}

/*
 * Constructor. Accepts the context CObject as its sole argument.
 *
 * Stores the unwrapped context pointer, creates the InputStream for this
 * request (passing the request itself and the content length reported by
 * wsgiGetContentLength()), and finally fills in the environ dictionary
 * through wsgiPopulateEnviron().
 *
 * Returns -1 with an exception set if argument parsing, argument building
 * or InputStream construction fails; otherwise returns whatever
 * wsgiPopulateEnviron() returns.
 */
static int
Request_init(Request *self, PyObject *args, PyObject *kwds)
{
  PyObject *context_obj, *args2;
  PyObject *input, *old_input;

  if (!PyArg_ParseTuple(args, "O!", &PyCObject_Type, &context_obj))
    return -1;

  self->context = PyCObject_AsVoidPtr(context_obj);

  if ((args2 = Py_BuildValue("(Oi)", self,
                             wsgiGetContentLength(self->context))) == NULL)
    return -1;

  input = PyObject_CallObject((PyObject *)&InputStream_Type, args2);
  Py_DECREF(args2);
  if (input == NULL)
    return -1;

  /* tp_init may run more than once on the same object (an explicit
     __init__ call from Python). Install the new stream first, then
     release the previous one so that its reference is not leaked. */
  old_input = self->input;
  self->input = input;
  if (old_input != NULL) {
    Py_DECREF(old_input);
  }

  return wsgiPopulateEnviron(self);
}

/*
 * start_response(status, headers[, exc_info]) callable implementation.
 *
 * status must be a string and headers a list. Without exc_info, calling
 * this twice raises AssertionError. With exc_info, the stored status and
 * headers are replaced unless the headers have already been sent, in
 * which case the supplied exception is re-raised.
 *
 * Returns the request's bound write() callable, or NULL with an
 * exception set.
 */
static PyObject *
Request_start_response(Request *self, PyObject *args)
{
  PyObject *newStatus, *newHeaders, *exc_info = NULL;
  PyObject *previous;
  int haveExcInfo;

  if (!PyArg_ParseTuple(args, "SO!|O:start_response", &newStatus,
                        &PyList_Type, &newHeaders,
                        &exc_info))
    return NULL;

  haveExcInfo = (exc_info != NULL && exc_info != Py_None);

  if (!haveExcInfo) {
    /* A plain repeated call is an application error. */
    if (self->status != NULL || self->headers != NULL) {
      PyErr_SetString(PyExc_AssertionError, "headers already set");
      return NULL;
    }
  }
  else if (self->headers_sent) {
    /* Too late to replace the headers: propagate the exception the
       application handed us. */
    PyObject *excType, *excValue, *excTb;

    if (!PyArg_ParseTuple(exc_info, "OOO", &excType, &excValue, &excTb))
      return NULL;

    /* PyErr_Restore() takes ownership of one reference to each object,
       while the parsed values are borrowed. */
    Py_INCREF(excType);
    Py_INCREF(excValue);
    Py_INCREF(excTb);
    PyErr_Restore(excType, excValue, excTb);
    return NULL;
  }

  /* TODO validation of status and headers */

  /* Take the new reference before dropping the old one in each case. */
  previous = self->status;
  Py_INCREF(newStatus);
  self->status = newStatus;
  Py_XDECREF(previous);

  previous = self->headers;
  Py_INCREF(newHeaders);
  self->headers = newHeaders;
  Py_XDECREF(previous);

  return PyObject_GetAttrString((PyObject *)self, "write");
}

/*
 * Sends the status line and headers stored on the request.
 * Assumes self->status and self->headers are valid (non-NULL).
 *
 * The header list is flattened into a name/value array of C strings
 * borrowed from the Python objects in self->headers. If the application
 * supplied no Content-Length and the result is a sequence of exactly one
 * item, a Content-Length header is derived from that item's size.
 *
 * Returns 0 on success, -1 on failure.
 */
static int
_wsgiSendHeaders(Request *self)
{
  const char *status, *statusMsg;
  int statusCode;
  int headerCount;
  const char **headers;
  int i, j;
  int found;
  char lenBuf[20];

  if ((status = PyString_AsString(self->status)) == NULL)
    return -1;

  statusCode = strtol(status, NULL, 10);

  /* The reason phrase normally starts at offset 4 ("NNN reason"). The
     status string is not validated by start_response(), so it may be
     shorter than that: stop at the terminator rather than indexing past
     the end of the buffer. A short status yields an empty reason. */
  for (i = 0; i < 4 && status[i] != '\0'; i++)
    ;
  statusMsg = &status[i];

  headerCount = PyList_Size(self->headers);
  if (PyErr_Occurred())
    return -1;

  /* NB: 1 extra header for Content-Length */
  if ((headers = PyMem_Malloc(sizeof(*headers) *
                              (headerCount * 2 + 2))) == NULL)
    return -1;

  for (i = 0, j = 0, found = 0; i < headerCount; i++) {
    PyObject *item;
    const char *header, *value;

    /* Borrowed reference; the list keeps the item alive. */
    if ((item = PyList_GetItem(self->headers, i)) == NULL)
      goto bad;

    if (!PyArg_ParseTuple(item, "ss", &header, &value))
      goto bad;

    headers[j++] = header;
    headers[j++] = value;

    if (!found && !strcasecmp(header, "Content-Length"))
      found = 1;
  }

  /* See if we can deduce Content-Length if it wasn't given */
  if (!found && self->result != NULL) {
    /* Does result have a size? */
    int resultLen = PySequence_Size(self->result);
    if (!PyErr_Occurred()) {
      /* Does it have a length of 1? */
      if (resultLen == 1) {
        /* Content-Length is the length of the first & only string */
        PyObject *item = PySequence_GetItem(self->result, 0);
        if (item != NULL) {
          int len = PyString_Size(item);
          Py_DECREF(item);

          if (!PyErr_Occurred()) {
            snprintf(lenBuf, sizeof(lenBuf), "%d", len);
            headerCount++;
            headers[j++] = "Content-Length";
            headers[j++] = lenBuf;
          }
        }
      }
    }
    /* Don't care about any errors due to this */
    PyErr_Clear();
  }

  if (wsgiSendHeaders(self->context, statusCode, statusMsg,
                      headerCount, headers))
    goto bad;

  PyMem_Free(headers);
  return 0;

 bad:
  PyMem_Free(headers);
  return -1;
}

/* Send a chunk of body data. A zero-length chunk is a no-op that
   succeeds. Returns 0 on success, -1 if the send fails. */
static inline int
wsgiWrite(Request *self, const char *data, int len)
{
  if (!len)
    return 0;

  return wsgiSendBody(self->context, (uint8_t *)data, len) ? -1 : 0;
}

/*
 * write(data) callable implementation.
 *
 * Raises AssertionError if start_response() has not stored a status or
 * headers yet. Sends the headers on first use, then the data.
 * Returns None, or NULL with an exception set.
 */
static PyObject *
Request_write(Request *self, PyObject *args)
{
  const char *bytes;
  int nbytes;

  /* Neither status nor headers recorded: start_response() never ran. */
  if (!(self->status != NULL || self->headers != NULL)) {
    PyErr_SetString(PyExc_AssertionError, "write() before start_response()");
    return NULL;
  }

  if (!PyArg_ParseTuple(args, "s#:write", &bytes, &nbytes))
    return NULL;

  /* Headers go out lazily, just before the first piece of body data. */
  if (!self->headers_sent) {
    if (_wsgiSendHeaders(self) != 0)
      return NULL;
    self->headers_sent = 1;
  }

  if (wsgiWrite(self, bytes, nbytes) != 0)
    return NULL;

  Py_INCREF(Py_None);
  return Py_None;
}

/*
 * Send a wrapped file using wsgiSendFile.
 *
 * Returns 0 when the file was sent, 1 when the wrapped object has no
 * fileno attribute (the caller then falls back to iterating over it), and
 * -1 on error.
 */
static int
wsgiSendFileWrapper(Request *self, FileWrapper *wrapper)
{
  PyObject *source = (PyObject *)wrapper->filelike;
  PyObject *method, *noArgs, *fdObj;
  int fd;

  /* file-like must have fileno */
  if (!PyObject_HasAttrString(source, "fileno"))
    return 1;

  method = PyObject_GetAttrString(source, "fileno");
  if (method == NULL)
    return -1;

  noArgs = PyTuple_New(0);
  if (noArgs == NULL) {
    Py_DECREF(method);
    return -1;
  }

  /* Call filelike.fileno() */
  fdObj = PyObject_CallObject(method, noArgs);
  Py_DECREF(noArgs);
  Py_DECREF(method);
  if (fdObj == NULL)
    return -1;

  fd = PyInt_AsLong(fdObj);
  Py_DECREF(fdObj);
  /* The conversion reports failure through the exception state. */
  if (PyErr_Occurred())
    return -1;

  /* Send headers if necessary */
  if (!self->headers_sent) {
    if (_wsgiSendHeaders(self) != 0)
      return -1;
    self->headers_sent = 1;
  }

  return wsgiSendFile(self->context, fd) ? -1 : 0;
}

/*
 * Send the application's response.
 *
 * A FileWrapper result is first offered to wsgiSendFileWrapper();
 * otherwise (or if the wrapped object has no fileno) the result is
 * iterated and each non-empty string item is written out. Headers are
 * sent just before the first non-empty chunk, or at the end if nothing
 * was written.
 *
 * Returns 0 on success, -1 on failure.
 */
static int
wsgiSendResponse(Request *self, PyObject *result)
{
  PyObject *it;
  PyObject *chunk;

  /* Check if it's a FileWrapper */
  if (result->ob_type == &FileWrapper_Type) {
    int rc = wsgiSendFileWrapper(self, (FileWrapper *)result);

    if (rc < 0)
      return -1;
    if (rc == 0)
      return 0;
    /* rc > 0: not sendable by descriptor, iterate instead */
  }

  it = PyObject_GetIter(result);
  if (it == NULL)
    return -1;

  while ((chunk = PyIter_Next(it)) != NULL) {
    /* Assume failure; cleared only when the chunk was handled fully. */
    int stop = 1;
    const char *bytes;
    int nbytes = PyString_Size(chunk);

    if (!PyErr_Occurred()) {
      if (nbytes == 0) {
        /* Empty chunks are skipped without triggering the headers. */
        stop = 0;
      }
      else if ((bytes = PyString_AsString(chunk)) != NULL) {
        int headersOk = 1;

        /* Send headers if necessary */
        if (!self->headers_sent) {
          if (_wsgiSendHeaders(self))
            headersOk = 0;
          else
            self->headers_sent = 1;
        }

        if (headersOk && !wsgiWrite(self, bytes, nbytes))
          stop = 0;
      }
    }

    Py_DECREF(chunk);
    if (stop)
      break;
  }

  Py_DECREF(it);

  /* Covers both iteration errors and failures inside the loop. */
  if (PyErr_Occurred())
    return -1;

  /* Send headers if they haven't been sent at this point */
  if (!self->headers_sent) {
    if (_wsgiSendHeaders(self))
      return -1;
    self->headers_sent = 1;
  }

  return 0;
}

/*
 * Ensure the application iterable's close() method is called, if it has
 * one. The exception state on entry is preserved: it is fetched first
 * and restored afterwards, which also discards any error raised while
 * looking up or calling close().
 */
static void
wsgiCallClose(PyObject *result)
{
  PyObject *savedType, *savedValue, *savedTb;
  PyObject *closer = NULL;

  /* Save exception state */
  PyErr_Fetch(&savedType, &savedValue, &savedTb);

  if (PyObject_HasAttrString(result, "close"))
    closer = PyObject_GetAttrString(result, "close");

  if (closer != NULL) {
    PyObject *noArgs = PyTuple_New(0);

    if (noArgs != NULL) {
      PyObject *ret = PyObject_CallObject(closer, noArgs);

      Py_DECREF(noArgs);
      Py_XDECREF(ret);
    }
    Py_DECREF(closer);
  }

  /* Restore exception state */
  PyErr_Restore(savedType, savedValue, savedTb);
}

/* Methods exposed on Request instances. Both take positional arguments
   only (METH_VARARGS). */
static PyMethodDef Request_methods[] = {
  { "start_response", (PyCFunction)Request_start_response, METH_VARARGS,
    "WSGI start_response callable" },
  { "write", (PyCFunction)Request_write, METH_VARARGS,
    "WSGI write callable" },
  { NULL } /* sentinel */
};

/*
 * Type object for _wsgisup.Request.
 *
 * Only tp_dealloc, tp_methods, tp_init and tp_new are customized. The
 * flags are Py_TPFLAGS_DEFAULT and tp_traverse/tp_clear are left unset,
 * so instances do not take part in cyclic garbage collection;
 * wsgiHandler() calls Request_clear() explicitly at the end of each
 * request instead.
 */
PyTypeObject Request_Type = {
  PyObject_HEAD_INIT(NULL)
  0,                         /*ob_size*/
  "_wsgisup.Request",        /*tp_name*/
  sizeof(Request),           /*tp_basicsize*/
  0,                         /*tp_itemsize*/
  (destructor)Request_dealloc, /*tp_dealloc*/
  0,                         /*tp_print*/
  0,                         /*tp_getattr*/
  0,                         /*tp_setattr*/
  0,                         /*tp_compare*/
  0,                         /*tp_repr*/
  0,                         /*tp_as_number*/
  0,                         /*tp_as_sequence*/
  0,                         /*tp_as_mapping*/
  0,                         /*tp_hash */
  0,                         /*tp_call*/
  0,                         /*tp_str*/
  0,                         /*tp_getattro*/
  0,                         /*tp_setattro*/
  0,                         /*tp_as_buffer*/
  Py_TPFLAGS_DEFAULT,        /*tp_flags*/
  "WSGI Request class",      /* tp_doc */
  0,		             /* tp_traverse */
  0,		             /* tp_clear */
  0,		             /* tp_richcompare */
  0,		             /* tp_weaklistoffset */
  0,		             /* tp_iter */
  0,		             /* tp_iternext */
  Request_methods,           /* tp_methods */
  0,                         /* tp_members */
  0,                         /* tp_getset */
  0,                         /* tp_base */
  0,                         /* tp_dict */
  0,                         /* tp_descr_get */
  0,                         /* tp_descr_set */
  0,                         /* tp_dictoffset */
  (initproc)Request_init,    /* tp_init */
  0,                         /* tp_alloc */
  Request_new,               /* tp_new */
};

/* Module-level function table. It holds only the terminating entry: the
   module exports types (added in init_wsgisup) but no functions. */
static PyMethodDef wsgisup_methods[] = {
  { NULL } /* sentinel */
};

#ifndef PyMODINIT_FUNC
#define PyMODINIT_FUNC void
#endif
PyMODINIT_FUNC
init_wsgisup(void)
{
  PyObject *m;

  if (PyType_Ready(&Request_Type) < 0)
    return;
  if (PyType_Ready(&InputStream_Type) < 0)
    return;
  if (PyType_Ready(&FileWrapper_Type) < 0)
    return;

  m = Py_InitModule3("_wsgisup", wsgisup_methods,
		     "WSGI C support module");

  if (m == NULL)
    return;

  Py_INCREF(&Request_Type);
  PyModule_AddObject(m, "Request", (PyObject *)&Request_Type);

  Py_INCREF(&InputStream_Type);
  PyModule_AddObject(m, "InputStream", (PyObject *)&InputStream_Type);

  Py_INCREF(&FileWrapper_Type);
  PyModule_AddObject(m, "FileWrapper", (PyObject *)&FileWrapper_Type);
}

/*
 * Shut the embedded interpreter down.
 *
 * Re-acquires the interpreter through the thread state saved by
 * wsgiInit(), releases the references held on the application callable
 * and on sys.stderr, then finalizes Python.
 *
 * NOTE(review): Py_DECREF() is used on pApp and wsgiStderr, so both are
 * assumed to be non-NULL here, i.e. wsgiInit() must have succeeded --
 * confirm against callers.
 */
void
wsgiCleanup(void)
{
  /* Must come first: the calls below need the interpreter held. */
  PyEval_RestoreThread(_main);
  Py_DECREF(pApp);
  Py_DECREF(wsgiStderr);
  Py_Finalize();
}

/* Import a module given its name as a C string. Returns a new reference,
   or NULL with an exception set. */
static PyObject *
wsgiImportNamed(const char *name)
{
  PyObject *pyName, *module;

  if ((pyName = PyString_FromString(name)) == NULL)
    return NULL;

  module = PyImport_Import(pyName);
  Py_DECREF(pyName);
  return module;
}

/*
 * Initialize the embedded interpreter and locate the WSGI application.
 *
 * moduleName/appName name the module to import and the attribute within
 * it that holds the application callable; scriptName is recorded in
 * wsgiScriptName/wsgiScriptNameLen; progName becomes sys.argv[0].
 *
 * On success the saved thread state is kept in _main, sys.stderr in
 * wsgiStderr and the application in pApp, and 0 is returned. On failure
 * any pending Python error is printed and -1 is returned.
 */
int
wsgiInit(const char *moduleName, const char *appName, const char *scriptName,
         const char *progName)
{
  PyObject *module;
  char *argv[1];

  wsgiScriptName = scriptName;
  wsgiScriptNameLen = strlen(scriptName);

  PyEval_InitThreads();
  Py_Initialize();
  argv[0] = (char *)progName;
  PySys_SetArgv(1, argv); /* make sure sys.argv gets set */

  init_wsgisup();

  /* Import threading; the module object itself is not kept. */
  if ((module = wsgiImportNamed("threading")) == NULL)
    goto err;
  Py_DECREF(module);

  /* Keep a reference to sys.stderr for later use. */
  if ((module = wsgiImportNamed("sys")) == NULL)
    goto err;
  wsgiStderr = PyObject_GetAttrString(module, "stderr");
  Py_DECREF(module);
  if (wsgiStderr == NULL)
    goto err;

  /* Import the application module and fetch the application object. */
  if ((module = wsgiImportNamed(moduleName)) == NULL)
    goto err;
  pApp = PyObject_GetAttrString(module, (char *)appName);
  Py_DECREF(module);
  if (pApp == NULL || !PyCallable_Check(pApp)) {
    Py_XDECREF(pApp);
    goto err;
  }

  /* Release the interpreter; request handlers acquire it as needed. */
  _main = PyEval_SaveThread();
  return 0;

 err:
  PyErr_Print();
  return -1;
}

/*
 * Handle one request identified by the opaque context pointer ctxt.
 *
 * Acquires the GIL, builds a Request object around the context, calls the
 * application with (environ, start_response), sends the result and calls
 * the result's close() method. If a Python exception is pending at the
 * end it is printed and, provided no headers have gone out yet, a
 * 500 response is sent. Always returns 0.
 */
int
wsgiHandler(void *ctxt)
{
  PyGILState_STATE gstate;
  PyObject *ctxt_c_obj, *args, *start_resp;
  Request *req_obj = NULL;
  PyObject *result;

  gstate = PyGILState_Ensure();

  /* Create Request object, passing it the context as a CObject */
  ctxt_c_obj = PyCObject_FromVoidPtr(ctxt, NULL);
  if (ctxt_c_obj == NULL)
    goto out;

  /* The argument tuple takes its own reference to the CObject, so ours
     can be dropped whether or not the tuple was built. */
  args = Py_BuildValue("(O)", ctxt_c_obj);
  Py_DECREF(ctxt_c_obj);
  if (args == NULL)
    goto out;

  req_obj = (Request *)PyObject_CallObject((PyObject *)&Request_Type, args);
  Py_DECREF(args);
  if (req_obj == NULL)
    goto out;

  /* Associate the Request with the context; undone at "out" below. */
  wsgiSetRequestData(ctxt, req_obj);

  /* Prime input stream */
  if (wsgiPrimeInput(ctxt))
    goto out;

  /* Get start_response callable */
  start_resp = PyObject_GetAttrString((PyObject *)req_obj, "start_response");
  if (start_resp == NULL)
    goto out;

  /* Build arguments and call application object */
  args = Py_BuildValue("(OO)", req_obj->environ, start_resp);
  Py_DECREF(start_resp);
  if (args == NULL)
    goto out;

  result = PyObject_CallObject(pApp, args);
  Py_DECREF(args);
  if (result != NULL) {
    /* Handle the application response */
    req_obj->result = result;
    /* result now owned by req_obj */
    wsgiSendResponse(req_obj, result); /* ignore return */
    /* close() is called even if sending failed; wsgiCallClose() keeps
       any pending exception intact for the check below. */
    wsgiCallClose(result);
  }

 out:
  /* Failures are detected through the exception state, not through
     return codes. */
  if (PyErr_Occurred()) {
    PyErr_Print();

    /* Display HTTP 500 error, if possible */
    if (req_obj == NULL || !req_obj->headers_sent)
      sendResponse500(ctxt);
  }

  if (req_obj != NULL) {
    wsgiSetRequestData(ctxt, NULL);

    /* Don't rely on cyclic GC. Clear circular references NOW. */
    Request_clear(req_obj);

    Py_DECREF(req_obj);
  }

  PyGILState_Release(gstate);

  /* Always return success. */
  return 0;
}
