# chgsupport - functionality to support cHg
#
# Copyright 2011 Yuya Nishihara <yuya@tcha.org>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

"""functionality to support cHg

hg serve --cmdserver=unix
    communicate with the command server via a Unix domain socket

This extends the channels and commands of the command server:

's'-channel
    handles util.system() request

'P'-channel
    handles getpass() request

'chdir' command
    change current directory

'getpager' command
    checks if pager is enabled and which pager should be executed

'getpid' command
    get process id of the server

'setchguiattr' command
    set initial values of ui attributes that are not configured yet

'setenv' command
    replace os.environ completely
"""

import errno, os, re, struct, sys, tempfile, traceback, SocketServer
from mercurial import cmdutil, commands, commandserver, dispatch, encoding, \
                      error, extensions, i18n, scmutil, util
from mercurial.node import hex
from mercurial.i18n import _

# Space-separated Mercurial versions this extension declares it was tested
# with (presumably consulted by Mercurial when reporting extension problems).
testedwith = '3.0 3.1'

class channeledinput(commandserver.channeledinput):
    """Input channel that can additionally request a password."""
    def readpass(self):
        """Ask the client for a password over the 'P' channel.

        Trailing newline characters are stripped from the reply.
        """
        reply = self._read(self.maxchunksize, 'P')
        return reply.rstrip('\n')

class channeledoutput(commandserver.channeledoutput):
    """Output channel whose flush() is forwarded to the client."""
    def flush(self):
        # a zero-length frame on this channel is the flush() request
        header = struct.pack('>cI', self.channel, 0)
        self.out.write(header)
        self.out.flush()

class channeledsystem(object):
    """Request to execute system() in the following format:

    command length (unsigned int),
    cmd

    and waits:

    exitcode length (unsigned int),
    exitcode (int)

    The request header is written as struct '>cI', i.e. the channel
    character precedes the command length. A reply whose length is not 4
    is treated as a failure (exit code 255); its payload is drained so
    that subsequent messages are read from the correct position.
    """
    def __init__(self, in_, out, channel):
        self.in_ = in_
        self.out = out
        self.channel = channel

    # copied from mercurial/util.py
    def __call__(self, cmd, environ=None, cwd=None, onerr=None, errprefix=None,
                 out=None):
        """Ask the client to run cmd and return its exit code.

        If the exit code is non-zero and onerr is given, an error message
        is built from the command's basename and exit status. onerr.warn()
        is tried first; on AttributeError, onerr(errmsg) is raised instead
        (e.g. when onerr is an exception class such as util.Abort).
        """
        # XXX no support for environ, cwd and out
        # (environ defaults to None rather than a shared mutable dict)

        origcmd = cmd
        cmd = util.quotecommand(cmd)

        self.out.write(struct.pack('>cI', self.channel, len(cmd)))
        self.out.write(cmd)
        self.out.flush()

        length = struct.unpack('>I', self.in_.read(4))[0]
        if length != 4:
            # malformed reply: consume its payload so the next message is
            # not misread, then report a generic failure
            if length:
                self.in_.read(length)
            rc = 255
        else:
            rc = struct.unpack('>i', self.in_.read(4))[0]

        if rc and onerr:
            words = origcmd.split(None, 1)
            if words:
                prog = os.path.basename(words[0])
            else:
                prog = origcmd  # blank command; avoid IndexError
            errmsg = '%s %s' % (prog, util.explainexit(rc)[0])
            if errprefix:
                errmsg = '%s: %s' % (errprefix, errmsg)
            try:
                onerr.warn(errmsg + '\n')
            except AttributeError:
                raise onerr(errmsg)
        return rc

# (environment variable names, modules to reload when any of them changes)
_envmodstoreload = [
    # recreate gettext 't' by reloading i18n
    (['LANG', 'LANGUAGE', 'LC_MESSAGES'], [encoding, i18n]),
    (['HGENCODING', 'HGENCODINGMODE', 'HGENCODINGAMBIGUOUS'], [encoding]),
]

def _listmodstoreload(newenv):
    """Return the set of modules to reload if os.environ becomes newenv.

    A group of modules from _envmodstoreload is selected when any of its
    associated environment variables differs between os.environ and newenv.
    """
    modules = set()
    for keys, mods in _envmodstoreload:
        differs = util.any(os.environ.get(key) != newenv.get(key)
                           for key in keys)
        if differs:
            modules.update(mods)
    return modules

def _fixdefaultencoding():
    """Rewrite the defaults of the 'encoding' and 'encodingmode' global
    options in commands.globalopts from the current encoding module."""
    replacements = {
        'encoding': encoding.encoding,
        'encodingmode': encoding.encodingmode,
    }
    opts = commands.globalopts
    for pos, opt in enumerate(opts):
        default = replacements.get(opt[1])  # opt[1] is the option name
        if default is None:
            continue
        # entries are tuples: rebuild with the new default at index 2
        opts[pos] = opt[:2] + (default,) + opt[3:]

# '$' followed by letters/underscores, e.g. '$HOME' ('$1' does not match)
_envvarre = re.compile(r'\$[A-Za-z_]+')

def _clearenvaliases(cmdtable):
    """Remove stale command aliases referencing env vars; variable expansion
    is done at dispatch.addaliases()

    Shell aliases (definitions starting with '!') are kept.
    """
    stale = []
    for name, entry in cmdtable.items():
        cmddef = entry[0]
        if not isinstance(cmddef, dispatch.cmdalias):
            continue
        if cmddef.definition.startswith('!'):
            continue  # shell alias
        if _envvarre.search(cmddef.definition):
            stale.append(name)
    # delete after scanning so the table is not mutated while iterating
    for name in stale:
        del cmdtable[name]

class chgcmdserver(commandserver.server):
    """Command server with the chg additions described in the module
    docstring: the 'chdir', 'getpager', 'getpid', 'setchguiattr' and
    'setenv' commands, a 'runcommand' that redirects sys.stderr to the
    error channel, and the 's' channel used for system() requests.
    """
    def __init__(self, ui, repo, fin, fout):
        cui = _cmdui(ui)
        try:
            super(chgcmdserver, self).__init__(cui, repo, fin, fout)
        except TypeError:
            # hg<3.2 (a0e81aa94125)
            super(chgcmdserver, self).__init__(cui, repo, mode='pipe')
            try:
                self.cerr = channeledoutput(fout, 'e')
                self.cout = channeledoutput(fout, 'o')
                self.cresult = channeledoutput(fout, 'r')
            except TypeError:
                # hg<3.2 (8cc5e673cac0)
                self.cerr = channeledoutput(fout, fout, 'e')
                self.cout = channeledoutput(fout, fout, 'o')
                self.cresult = channeledoutput(fout, fout, 'r')
            self.cin = channeledinput(fin, fout, 'I')
            self.client = fin

        self._sui = ui  # ui for server output
        self.csystem = channeledsystem(fin, fout, 's')
        self.ui.csystem = self.csystem

    def _readchgrequest(self):
        """Read the payload of a chg request from the client.

        The payload is preceded by its length as a big-endian unsigned
        32-bit int. Returns '' without further reading when the length is 0.
        """
        length = struct.unpack('>I', self._read(4))[0]
        if not length:
            return ''
        return self._read(length)

    def chdir(self):
        """Change current directory

        Note that the behavior of --cwd option is bit different from this.
        It does not affect --config parameter.
        """
        path = self._readchgrequest()
        if not path:
            return
        self._sui.debug('chdir to %r\n' % path)
        os.chdir(path)

    def _chgpagercmd(self, args):
        """Return the pager command to use for the command line args, or
        None if no pager should run (pager extension not loaded, no pager
        configured, unparsable command line, or command not paged)."""
        try:
            pagermod = extensions.find('pager')
        except KeyError:
            return None

        pagercmd = self.ui.config('pager', 'pager', os.environ.get('PAGER'))
        if not pagercmd:
            return None

        try:
            cmd, _func, args, options, _cmdoptions = dispatch._parse(self.ui,
                                                                     args)
        except (error.AmbiguousCommand, error.CommandError,
                error.UnknownCommand):
            return None

        # duplicated from hgext/pager.py
        attend = self.ui.configlist('pager', 'attend', pagermod.attended)
        auto = options['pager'] == 'auto'
        always = util.parsebool(options['pager'])
        if (always or auto and
            (cmd in attend or
             (cmd not in self.ui.configlist('pager', 'ignore')
              and not attend))):
            return pagercmd
        return None

    def getpager(self):
        """Read cmdargs and write pager command to r-channel if enabled

        The request payload is the command-line arguments separated by '\\0'.
        If pager isn't enabled, this writes '\\0' because channeledoutput
        does not allow to write empty data.
        """
        data = self._readchgrequest()
        if data:
            args = data.split('\0')
        else:
            args = []
        self.cresult.write(self._chgpagercmd(args) or '\0')

    def getpid(self):
        """Write pid of the server"""
        self.cresult.write(struct.pack('>i', os.getpid()))

    def runcommand(self):
        # reset time-stamp so that "progress.delay" can take effect
        progbar = getattr(self.ui, '_progbar', None)
        if progbar:
            self._sui.debug('reset progbar\n')
            progbar.resetstate()
        # swap stderr so that progress output can be redirected to client
        olderr = sys.stderr
        sys.stderr = self.cerr
        try:
            super(chgcmdserver, self).runcommand()
        finally:
            sys.stderr = olderr

    def setchguiattr(self):
        """Set ui attributes if not configured

        The payload consists of lines of the form 'section.name: value';
        each item is applied only when it is currently unset.
        """
        data = self._readchgrequest()
        if not data:
            return
        for line in data.splitlines():
            key, value = line.split(': ', 1)
            section, name = key.split('.', 1)
            if self.ui.config(section, name) is None:
                self._sui.debug('set ui %r: %r\n' % (key, value))
                self.ui.setconfig(section, name, value)

        # reload config so that ui.plain() takes effect
        for f in scmutil.rcpath():
            self.ui.readconfig(f, trust=True)

        # skip initial uisetup() by 'mode=chgauto'
        if self.ui.config('color', 'mode') == 'chgauto':
            self.ui.setconfig('color', 'mode', 'auto')

    def setenv(self):
        """Clear and update os.environ

        The payload is 'KEY=VALUE' items separated by '\\0'; ValueError is
        raised if an item has no '='.
        Note that not all variables can make an effect on the running process.
        """
        s = self._readchgrequest()
        if not s:
            return
        try:
            newenv = dict(l.split('=', 1) for l in s.split('\0'))
        except ValueError:
            raise ValueError('unexpected value in setenv request')

        modstoreload = _listmodstoreload(newenv)

        if self._sui.debugflag:
            for k in sorted(set(os.environ.keys() + newenv.keys())):
                ov, nv = os.environ.get(k), newenv.get(k)
                if ov != nv:
                    self._sui.debug('change env %r: %r -> %r\n' % (k, ov, nv))
        os.environ.clear()
        os.environ.update(newenv)

        for mod in modstoreload:
            self._sui.debug('reload %s module\n' % mod.__name__)
            reload(mod)
        if encoding in modstoreload:
            _fixdefaultencoding()

        _clearenvaliases(commands.table)

    capabilities = commandserver.server.capabilities.copy()
    capabilities.update({'chdir': chdir,
                         'getpager': getpager,
                         'getpid': getpid,
                         'runcommand': runcommand,
                         'setchguiattr': setchguiattr,
                         'setenv': setenv})

def _cmdui(ui):
    """Build a ui of a chg-specific subclass of ui's class, constructed as
    chgui(ui).

    The subclass runs the editor through the client-side csystem channel,
    reads passwords via fin.readpass(), and honors the internal
    'ui._termwidth', 'ui._plain' and 'ui._plainexcept' settings. The
    server's own verbose/debug/quiet/traceback options and
    progress.assume-tty are set to None on the returned ui.
    """
    class chgui(ui.__class__):
        def __init__(self, src=None):
            super(chgui, self).__init__(src)
            if src:
                self.csystem = getattr(src, 'csystem', None)
            else:
                self.csystem = None  # should be assigned later
            # progbar is a singleton; we want to make it see the last config
            progbar = getattr(self, '_progbar', None)
            if progbar:
                progbar.ui = self

        def termwidth(self):
            # a positive ui._termwidth overrides the inherited detection
            n = self.configint('ui', '_termwidth')
            if n and n > 0:
                return n
            else:
                return super(chgui, self).termwidth()

        # copied from mercurial/ui.py
        def edit(self, text, user, extra=None, editform=None):
            """Let the user edit text in the editor run via csystem and
            return the edited text. Raises util.Abort if the editor fails.
            """
            if extra is None:
                extra = {}  # avoid a shared mutable default argument
            (fd, name) = tempfile.mkstemp(prefix="hg-editor-", suffix=".txt",
                                          text=True)
            try:
                f = os.fdopen(fd, "w")
                try:
                    f.write(text)
                finally:
                    f.close()

                environ = {'HGUSER': user}
                if 'transplant_source' in extra:
                    environ.update(
                        {'HGREVISION': hex(extra['transplant_source'])})
                for label in ('source', 'rebase_source'):
                    if label in extra:
                        environ.update({'HGREVISION': extra[label]})
                        break
                if editform:
                    environ.update({'HGEDITFORM': editform})

                editor = self.geteditor()

                # NOTE: channeledsystem currently ignores environ and out
                self.csystem("%s \"%s\"" % (editor, name),
                             environ=environ,
                             onerr=util.Abort, errprefix=_("edit failed"),
                             out=self.fout)

                f = open(name)
                try:
                    t = f.read()
                finally:
                    f.close()
            finally:
                os.unlink(name)

            return t

        # copied from mercurial/ui.py
        def getpass(self, prompt=None, default=None):
            if not self.interactive():
                return default
            try:
                self.fout.write(prompt or _('password: '))
                self.fout.flush()
                return self.fin.readpass()
            except EOFError:
                raise util.Abort(_('response expected'))

        # copied from mercurial/ui.py
        def plain(self, feature=None):
            plain = self.configbool('ui', '_plain')
            plainexcept = self.configlist('ui', '_plainexcept')
            if not plain and not plainexcept:
                return False
            if feature and plainexcept:
                return feature not in plainexcept
            return True

    ui = chgui(ui)
    # don't retain command-line opts of the service
    for opt in ('verbose', 'debug', 'quiet', 'traceback'):
        ui.setconfig('ui', opt, None)
    ui.setconfig('progress', 'assume-tty', None)
    return ui

def _disablepager():
    """Replace the pager extension's _runpager() with a no-op, if the
    extension is loaded. The pager choice is instead reported to the
    client through the 'getpager' command."""
    def _nopager(orig, *args):
        # accepts both _runpager(p) and _runpager(ui, p) call forms
        return None

    try:
        pagermod = extensions.find('pager')
    except KeyError:
        return
    extensions.wrapfunction(pagermod, '_runpager', _nopager)

def _serve(orig, ui, repo, **opts):
    """Wrapper of 'hg serve' handling --cmdserver=unix.

    Other modes are delegated to the original command unchanged.
    """
    # a missing/empty mode is also != 'unix', so one comparison suffices
    if opts['cmdserver'] != 'unix':
        return orig(ui, repo, **opts)

    _disablepager()

    if not util.safehasattr(commandserver, 'unixservice'):
        # hg<3.2 (840be5ca03e1): run the local unix service
        service = unixservice(ui, repo, opts)
        cmdutil.service(opts, initfn=service.init, runfn=service.run)
        return None
    return orig(ui, repo, **opts)

if not util.safehasattr(commandserver, 'unixservice'):
    # hg<3.2 (840be5ca03e1)
    # commandserver.unixservice is missing here, so define local versions;
    # _serve() instantiates unixservice in that case.

    # copied from mercurial/commandserver.py
    class _requesthandler(SocketServer.StreamRequestHandler):
        """Serve one accepted connection with a chgcmdserver."""
        def handle(self):
            # ui/repo are class attributes of the server class created in
            # unixservice.init()
            ui = self.server.ui
            repo = self.server.repo
            sv = chgcmdserver(ui, repo, self.rfile, self.wfile)
            try:
                # nesting lets the outer handler also see exceptions that the
                # inner handlers re-raise (e.g. a non-EPIPE IOError)
                try:
                    sv.serve()
                # handle exceptions that may be raised by command server. most
                # of known exceptions are caught by dispatch.
                except util.Abort, inst:
                    ui.warn(_('abort: %s\n') % inst)
                except IOError, inst:
                    # EPIPE is ignored; other I/O errors propagate
                    if inst.errno != errno.EPIPE:
                        raise
                except KeyboardInterrupt:
                    pass
            except: # re-raises
                # also write traceback to error channel. otherwise client cannot
                # see it because it is written to server's stderr by default.
                traceback.print_exc(file=sv.cerr)
                raise

    class unixservice(object):
        """
        Listens on unix domain socket and forks server per connection
        """
        def __init__(self, ui, repo, opts):
            """Store the service settings; opts['address'] is the socket
            path. Raises util.Abort if SocketServer lacks UnixStreamServer
            or no address was given."""
            self.ui = ui
            self.repo = repo
            self.address = opts['address']
            if not util.safehasattr(SocketServer, 'UnixStreamServer'):
                raise util.Abort(_('unsupported platform'))
            if not self.address:
                raise util.Abort(_('no socket path specified with --address'))

        def init(self):
            # server class exposing ui/repo as class attributes so that
            # _requesthandler can reach them through self.server
            class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
                ui = self.ui
                repo = self.repo
            self.server = cls(self.address, _requesthandler)
            self.ui.status(_('listening at %s\n') % self.address)
            self.ui.flush()  # avoid buffering of status message

        def run(self):
            try:
                self.server.serve_forever()
            finally:
                # remove the socket file once the serve loop exits
                os.unlink(self.address)

def uisetup(ui):
    """Wrap 'hg serve' and replace mercurial.commandserver's channel and
    server classes with the chg variants defined in this module."""
    extensions.wrapcommand(commands.table, 'serve', _serve)
    replacements = (('channeledoutput', channeledoutput),
                    ('channeledinput', channeledinput),
                    ('server', chgcmdserver))
    for attr, cls in replacements:
        setattr(commandserver, attr, cls)
