#!/usr/bin/env python

import sys, os, signal, re, errno, socket, time, string, random, copy
import tempfile
import shutil
import argparse
import subprocess, shlex
from contextlib import contextmanager

# Import oprofile from the `oprofile` directory that sits next to this script's directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir + '/oprofile')))
import oprofile
import profiles

# termcolor is optional: when it is not installed, fall back to a stand-in
# with the same (text, color) call shape that returns the text unchanged.
try: from termcolor import colored
except ImportError:
    colored = lambda str, color: str


def random_str(length = 5, alphabet = string.ascii_letters + string.digits):
    """Return a random string of `length` characters drawn from `alphabet`.

    length   -- number of characters to generate (0 gives the empty string).
    alphabet -- sequence of candidate characters; defaults to ASCII letters
                and digits so the result is safe to embed in file names
                regardless of the current locale.

    Note: this uses the `random` module, so the result is not suitable for
    anything security-sensitive.
    """
    # The caller-supplied alphabet is honoured (it used to be overwritten).
    return ''.join(random.choice(alphabet) for _ in range(length))

def sh_escape_within_single_quotes(str):
    """Escape `str` so it can be placed between single quotes in a shell command.

    Every single quote is replaced by the sequence '\'' (close the quoted
    string, emit an escaped quote, reopen the quoted string). The surrounding
    quotes themselves are NOT added; that is left to the caller.
    """
    close_escape_reopen = "'\\''"
    pieces = str.split("'")
    return close_escape_reopen.join(pieces)

class SafeContextManagerHelper(object):
    """Context manager driven by a generator that yields exactly once.

    Entering the with-block runs the generator up to its `yield` and hands the
    yielded value to the `as` target. Leaving the block -- normally or because
    of an exception -- simply resumes the generator so the code after the
    `yield` runs as cleanup. The exception (if any) is never thrown into the
    generator and is never suppressed.
    """
    def __init__(self, coroutine):
        # coroutine: an already-created generator object (not the function).
        self.coroutine = coroutine
    def __enter__(self):
        # next() builtin rather than the .next() method: identical on
        # Python 2.6+, and the only spelling that exists on Python 3.
        return next(self.coroutine)
    def __exit__(self, t, v, tb):
        try:
            next(self.coroutine)
        except StopIteration:
            # Generator finished after its single yield, as expected.
            # Returning None lets any exception from the with-body propagate.
            return None
        # The generator yielded a second time: it is not a valid manager.
        raise RuntimeError("safe_contextmanager used with an incorrectly implemented function")

def safe_contextmanager(fun):
    """Decorator: turn a single-yield generator function into a context manager factory.

    contextlib.contextmanager re-raises an exception from the with-body inside
    the generator, which means the code after the `yield` is skipped unless it
    is wrapped in try/finally. The manager produced here does not do that: the
    generator is just resumed on exit, so the code after the `yield` acts as
    cleanup without any extra ceremony.
    """
    def make_manager(*args, **kwargs):
        generator = fun(*args, **kwargs)
        return SafeContextManagerHelper(generator)
    return make_manager

@safe_contextmanager
def bracket(start, end):
    """Context manager that calls start() on entry and end() on exit.

    Both arguments are zero-argument callables. Nothing is bound by `as`.
    """
    start()
    yield
    end()

@safe_contextmanager
def directory(new_dir):
    """Context manager that runs the with-body inside `new_dir`.

    The directory (including missing parents) is created if necessary, the
    process changes into it for the duration of the block, and the previous
    working directory is restored afterwards.

    Raises OSError if the directory cannot be created for any reason other
    than already existing, or if changing into it fails.
    """
    try:
        os.makedirs(new_dir)
    except OSError as e:
        # An already-existing directory is fine; anything else is a real error.
        if e.errno != errno.EEXIST: raise
    old_dir = os.getcwd()
    os.chdir(new_dir)
    yield
    os.chdir(old_dir)


class Worker(object):
    """Base class for everything dbench starts and stops during a run.

    A worker is described on the command line as
    `{id}name[ssh:path]:param` (see __repr__ for the exact rendering) and
    usually wraps one subprocess. Subclasses customise behaviour through
    internal_init(), internal_start() and internal_stop().

    Each worker writes into its own subdirectory (named after its id) of the
    current run directory: `status.txt` for bookkeeping and `output.txt` for
    the combined stdout/stderr of its subprocess.
    """
    running = False       # True between start() and stop().
    finished = False      # True once the worker needs no further runs.
    process = None        # subprocess.Popen of the current run, if any.
    out_file = None       # Open handle on output.txt while the process runs.
    status_file = None    # Open handle on status.txt between start() and stop().
    ssh = None            # `[user@]host` when the worker runs remotely.
    ssh_unique_path = None  # Remote scratch directory; set by run_process().
    argtype = 'Worker'    # Label used by __repr__.

    def __init__(self, dbench, name, id, ssh, path, raw, param):
        """Store the parsed worker specification and run internal_init().

        dbench -- the owning driver (used for warn/die/log, port, hosts).
        name   -- worker type name as given on the command line.
        id     -- unique identifier; also the name of the output directory.
        ssh    -- `[user@]host` to run on, or None for local execution.
        path   -- explicit executable path, or None/'' to use the default.
        raw    -- if true, `param` replaces the default arguments entirely.
        param  -- extra argument string (shell-style quoting), or None.
        """
        self.dbench = dbench
        self.name = name
        self.id = id
        self.ssh = ssh
        self.path = path
        self.raw = raw
        self.param = param
        self.internal_init()

    def __repr__(self):
        """Render the worker in (roughly) the syntax it was specified in."""
        if self.id and self.id != self.name:    id_str = '{%s}' % self.id # XXX
        else:                                   id_str = ''
        if self.ssh:   ssh_str = '%s:' % self.ssh
        else:          ssh_str = ''
        if self.path or self.ssh: path_str = '[%s%s]' % (ssh_str, (self.path or ''))
        else:                     path_str = ''
        if self.raw:   raw_str = ':'
        else:          raw_str = ''
        if self.param: param_str = ':%s%s' % (raw_str, self.param)
        else:          param_str = ''
        return '%s(%s%s%s%s)' % (self.argtype, id_str, self.name, path_str, param_str)

    # Note: start() and stop() are called on the same object once per run
    # (instead of a new object being created each time).

    def start(self):
        """Begin one run: open status.txt in the worker directory and call internal_start()."""
        self.finished = False
        with directory(self.id):
            # Remember the absolute directory; stop() runs from elsewhere.
            self.worker_current_dir = os.getcwd()
            self.status_file = open('status.txt', 'w')
            self.status_file.write('worker: %r\n' % (self,))
            self.status_file.write('start: %d\n' % int(time.time()))
            self.internal_start()
        self.running = True

    def stop(self):
        """End the current run: stop the process, fetch remote files, record the exit status.

        Warns (through dbench) if the process is still alive afterwards or
        exited with a positive status; in the latter case its output is
        copied into the log.
        """
        assert self.running
        self.internal_stop()

        if self.ssh: self.copy_ssh() # XXX: Will this ever need to be overridden?

        self.status_file.write('stop: %d\n' % int(time.time()))
        if self.process:
            # Poll once so the logged status and the tested status agree.
            status = self.process.poll()
            self.status_file.write('exitstatus: %s\n' % status)
            if status is None:
                self.dbench.warn('%r not terminated despite calling internal_stop()' % self)
            elif status > 0:
                # TODO: If the status is negative, that means the process was
                # killed. Usually that's OK, because we kill it in
                # internal_stop() (e.g. monitors); we want to report this if it
                # was killed by another process, though.
                self.dbench.warn('%r exited with positive status %d; output:' % (self, status))
                with open(os.path.join(self.worker_current_dir, 'output.txt'), 'r') as f:
                    # TODO: Need a better multi-line log.
                    self.dbench.log(f.read())
        self.status_file.close()
        self.status_file = None

        self.running = False

    def internal_init(self): pass # Override this if needed.
    def internal_start(self): # Override this.
        pass
    def internal_stop(self): # Override this, usually.
        self.stop_process()
        self.finished = True

    def run_process(self, executable, args):
        """Launch `executable` with `args`, sending stdout and stderr to output.txt.

        Must be called with the worker directory as the current directory
        (start() arranges this). An explicit `path` in the worker spec
        overrides `executable`. With SSH, the command is wrapped in
        `ssh HOST -- dbench-internal-run REMOTE_DIR ...`.

        Calls dbench.die() if the process cannot be started.
        """
        # It might be better to redirect stdout and stderr to two different
        # files, but then we'd lose the order of the messages.
        self.out_file = open('output.txt', 'w')
        if self.path: executable = self.path

        cmdline = [executable] + args

        if self.ssh:
            self.ssh_unique_path = '/tmp/dbench-%s-%d-%s' % (self.dbench.hosts[0], os.getpid(), random_str())
            cmdline = ['ssh', self.ssh, '--', 'dbench-internal-run', self.ssh_unique_path] + cmdline

        self.status_file.write('cmdline: %r\n' % (cmdline,))

        try:
            # TODO: Possibly set cmdline[0] to basename and pass in executable=.
            # setsid puts the child in its own session/process group.
            self.process = subprocess.Popen(cmdline, preexec_fn=os.setsid,
                                            stdout=self.out_file, stderr=self.out_file)
        except OSError as e:
            self.dbench.die("Can't execute `%s': %s" % (executable, e.strerror))

    def copy_ssh(self):
        """Copy the remote scratch directory into the worker directory, then delete it remotely.

        Does nothing for local workers, or when no remote process was ever
        started (run_process() was not called), since there is nothing to
        fetch in that case. Failures are reported as warnings only.
        """
        if not self.ssh:
            return
        if not self.ssh_unique_path:
            return
        # Output is discarded via os.devnull: subprocess.call() never reads a
        # PIPE, so a chatty scp/ssh could fill the pipe and block forever.
        with open(os.devnull, 'w') as devnull:
            # XXX: Doing this synchronously will probably cause the rest of the monitors to terminate too late, which might cause some trouble.
            cmdline = ['scp', '-pr', '%s:%s/*' % (self.ssh, self.ssh_unique_path), self.worker_current_dir]
            ret = subprocess.call(cmdline, stdout=devnull, stderr=devnull)
            if ret != 0:
                self.dbench.warn('Failed to copy SSH files from %s:%s.' % (self.ssh, self.ssh_unique_path))
            # Note: the path still needs escaping because the remote shell parses it.
            ret = subprocess.call(['ssh', self.ssh, 'rm', '-rf', "'%s'" % sh_escape_within_single_quotes(self.ssh_unique_path)],
                                  stdout=devnull, stderr=devnull)
            if ret != 0:
                self.dbench.warn('Failed to delete temporary directory %s:%s.' % (self.ssh, self.ssh_unique_path))

    MAX_WAIT_TIME = 5      # Seconds bounded_wait() waits before giving up.
    WAIT_INTERVAL = 0.05   # Seconds between polls in bounded_wait().
    def bounded_wait(self):
        """Wait up to MAX_WAIT_TIME seconds for the process to exit.

        Returns the exit status, or None if the process is still running.
        """
        bound = time.time() + self.MAX_WAIT_TIME
        while self.process.poll() is None and time.time() < bound:
            time.sleep(self.WAIT_INTERVAL)
        return self.process.poll()

    def stop_process(self):
        """Terminate the subprocess (SIGTERM, then SIGKILL on timeout) and close output.txt."""
        if self.process and self.process.poll() is None:
            # TODO: Kill more intelligently, as in test_common.
            self.process.terminate()
            if self.bounded_wait() is None:
                self.dbench.warn('%s subprocess not responding to SIGTERM after %d seconds; sending SIGKILL.' % (self, self.MAX_WAIT_TIME))
                self.process.kill()
        if self.out_file and not self.out_file.closed:
            self.out_file.close()

    def wait(self):
        """Block until the subprocess exits on its own, then stop() the worker."""
        assert self.process
        self.process.wait()
        self.stop()

    binary_paths = []  # Candidate executable locations; overridden by subclasses.
    def find_binary(self, default):
        """Pick the executable to run.

        Returns `default` unchanged when an explicit path or SSH host is set
        (the path wins in run_process(); remote files cannot be checked).
        Otherwise returns the first entry of binary_paths that exists, tried
        relative to the current directory and then relative to the dbench
        script; falls back to `default`, which is then resolved via $PATH.
        """
        # TODO: This can be much more intelligent.
        if self.path is not None or self.ssh is not None:
            return default
        for candidate in self.binary_paths:
            if os.path.exists(candidate): # Relative to current directory.
                return candidate
            rel_to_script = os.path.abspath(os.path.join(self.dbench.script_path, candidate))
            if os.path.exists(rel_to_script): # Relative to dbench executable.
                return rel_to_script
        return default # Will look in $PATH

    def compute_args(self, prefix=None, default=None):
        """Build the argument list for the subprocess from the worker's `param`.

        prefix  -- arguments that precede user-supplied ones (default: none).
        default -- arguments used when no param was given (default: prefix).

        Returns only the user's arguments in raw mode (`::param`), prefix plus
        the user's arguments when a param was given, and `default` otherwise.
        """
        # A fresh list per call: a shared default list would be handed back to
        # callers and could be mutated by them.
        if prefix is None: prefix = []
        args = shlex.split(self.param or '')

        if default is None: default = prefix

        if self.raw: return args
        elif self.param: return prefix + args
        else: return default

class Cmd(Worker):
    """Worker that runs its parameter as a shell command via `bash -c`."""
    def internal_start(self):
        bash_args = ['-c', self.param]
        self.run_process('bash', bash_args)

class Sleep(Worker):
    """Worker that just sleeps for `param` seconds when started.

    The sleep happens synchronously inside internal_start(); no subprocess is
    created. A missing, non-numeric or negative parameter produces a warning
    and a one-second sleep instead.
    """
    def internal_start(self):
        try:
            length = float(self.param)
            if length < 0: raise ValueError
        except (TypeError, ValueError):
            # TypeError covers a missing parameter (float(None)), which would
            # otherwise crash instead of falling back to the default.
            self.dbench.warn("Invalid sleep value `%s'. Defaulting to 1." % self.param)
            length = 1.0
        time.sleep(length)


class Server(Worker):
    """Base class for server workers; adds a generic "wait until the port accepts connections" helper."""
    argtype = 'Server'

    # NOTE(review): these override Worker's values, so bounded_wait() (used
    # when stopping the process) also waits up to an hour -- confirm intended.
    MAX_WAIT_TIME = 3600
    WAIT_INTERVAL = 0.5
    def wait_for_server_ready(self):
        """Block until a TCP connection to localhost:<port> succeeds.

        Retries every WAIT_INTERVAL seconds for at most MAX_WAIT_TIME seconds,
        then calls dbench.die().
        """
        print("Waiting: %d\n" % time.time())
        bound = time.time() + self.MAX_WAIT_TIME
        while time.time() < bound:
            s = socket.socket()
            try:
                s.connect(('localhost', self.dbench.port)) # XXX: Change this when/if we support running the server on a different host.
            except socket.error:
                # Couldn't connect; release the socket and retry.
                s.close()
                time.sleep(self.WAIT_INTERVAL)
                continue
            # Connected to the server successfully.
            try:
                s.shutdown(socket.SHUT_RDWR)
            finally:
                s.close()
            # XXX: For some reason the RethinkDB server seems to introduce odd
            # delays if enough clients connect to it with some particular
            # workload immediately after startup. This should be properly
            # fixed, but until then, this works well enough.
            print("Done waiting: %d\n" % time.time())
            return
        self.dbench.die('Unable to connect to server (localhost:%d)' % (self.dbench.port)) # XXX: This too.

class Client(Worker):
    """Base class for client workers; only changes the label shown by repr()."""
    argtype = 'Client'
class Insert(Worker):
    """Base class for insert workers; only changes the label shown by repr()."""
    argtype = 'Insert'
class Monitor(Worker):
    """Base class for monitor workers; only changes the label shown by repr()."""
    argtype = 'Monitor'


class MemcachedLikeServer(Server):
    """Common base for the memcached-style servers below; adds nothing of its own."""


class RethinkDB(MemcachedLikeServer):
    """RethinkDB server worker, started with `-p <port>` plus any user arguments."""
    binary_paths = ['../../../../../build/release/rethinkdb', '../../build/release/rethinkdb']

    def internal_start(self):
        args = self.compute_args(['-p', str(self.dbench.port)])
        self.run_process(self.find_binary('rethinkdb'), args)
        self.wait_for_server_ready()

    MAX_WAIT_TIME = 3600
    WAIT_INTERVAL = 0.5

    # Replication setups require us to wait until the nodes are in sync before sending queries
    def wait_for_server_ready(self):
        """Block until the server answers a `get` for a missing key with END.

        Polls every WAIT_INTERVAL seconds for at most MAX_WAIT_TIME seconds,
        then calls dbench.die(). Socket errors are treated as "not ready yet".
        """
        bound = time.time() + self.MAX_WAIT_TIME
        print("Waiting: %d\n" % time.time())
        while time.time() < bound:
            # Reset per attempt so `finally` never sees an unbound name or a
            # socket left over from the previous iteration.
            s = None
            try:
                s = socket.socket()
                s.connect(('localhost', self.dbench.port))
                s.send('get test\r\n\r\n')
                res = s.makefile().readline()
                if res == 'END\r\n':
                    print("Waiting done: %d\n" % time.time())
                    return
            except socket.error:
                pass
            finally:
                if s is not None:
                    s.close()
            time.sleep(self.WAIT_INTERVAL)
        self.dbench.die('Unable to connect to server (localhost:%d)' % (self.dbench.port))

class Memcached(MemcachedLikeServer):
    """memcached server worker, started with `-p <port>` plus any user arguments."""
    binary_paths = []

    def internal_start(self):
        port_args = ['-p', str(self.dbench.port)]
        binary = self.find_binary('memcached')
        self.run_process(binary, self.compute_args(port_args))
        self.wait_for_server_ready()

class MemcacheDB(MemcachedLikeServer):
    """memcachedb server worker.

    Default arguments: the port, a `memcachedb_data` home directory (-H) and
    -N, which enables DB_TXN_NOSYNC for a large performance gain (off by
    default in memcachedb).

    Other memcachedb options that can be supplied through the worker param:
      -t <num>      number of threads to use (default 4)
      -v            verbose (print errors/warnings while in event loop)
      -m <num>      in-memory cache size of BerkeleyDB in megabytes (default 64MB)
      -B <db_type>  type of database, 'btree' or 'hash' (default 'btree')
    """
    binary_paths = []

    def internal_start(self):
        default_args = ['-p', str(self.dbench.port)]
        default_args += ['-H', 'memcachedb_data']
        default_args += ['-N']
        binary = self.find_binary('memcachedb')
        self.run_process(binary, self.compute_args(default_args))
        self.wait_for_server_ready()


class Membase(MemcachedLikeServer):
    """Membase server worker, run through `membase-wrapper` on the fixed port 11211."""
    binary_paths = []

    def internal_start(self):
        if self.dbench.port != 11211:
            self.dbench.die('Membase requires port 11211; please run explicitly with that port.')
        # XXX: membase-wrapper requires that you specify -m and -d. Say this here.
        args = self.compute_args([])
        self.run_process(self.find_binary('membase-wrapper'), args)
        self.wait_for_server_ready()

    MAX_WAIT_TIME = 3600
    WAIT_INTERVAL = 0.5
    # When membase is first run, before the configuration step is complete,
    # clients that connect are automatically directed to a "null bucket", and
    # remain connected to it for the lifetime of the connection. We want to
    # wait until we can connect to a real bucket; we can't do that here.
    def wait_for_server_ready(self):
        """Block until a `set` on a fresh connection is answered with STORED.

        A new connection is made for every attempt. Responses other than the
        known "not ready" errors produce a warning. Polls every WAIT_INTERVAL
        seconds for at most MAX_WAIT_TIME seconds, then calls dbench.die().
        """
        reasonable_responses = ['SERVER_ERROR unauthorized, null bucket\r\n',
                                'SERVER_ERROR a2b not_my_vbucket\r\n']

        bound = time.time() + self.MAX_WAIT_TIME
        print("Waiting: %d\n" % time.time())
        while time.time() < bound:
            # Reset per attempt so `finally` never sees an unbound name or a
            # socket left over from the previous iteration.
            s = None
            try:
                s = socket.socket()
                s.connect(('localhost', self.dbench.port)) # XXX: Change this when/if we support running the server on a different host.
                s.send('set test 0 0 0\r\n\r\n')
                res = s.makefile().readline()
                if res == 'STORED\r\n':
                    print("Waiting done: %d\n" % time.time())
                    return
                elif res not in reasonable_responses:
                    self.dbench.warn('Unknown response from membase server: %s' % res)
            except socket.error:
                pass
            finally:
                if s is not None:
                    s.close()
            time.sleep(self.WAIT_INTERVAL)
        self.dbench.die('Unable to connect to server (localhost:%d)' % (self.dbench.port)) # XXX: This too.

    # XXX: This shouldn't be necessary when we're smarter about killing children.
    def stop_process(self):
        """Give the server time to exit by itself, then signal its whole process group.

        Sleeps 200 s and then a further 600 s while the process is still
        alive before sending SIGTERM (and, on timeout, SIGKILL) to the group.
        """
        if self.process and self.process.poll() is None:
            time.sleep(200)
        if self.process and self.process.poll() is None:
            time.sleep(600)
        if self.process and self.process.poll() is None:
            # The child was started with setsid, so its pid is also its
            # process-group id.
            os.killpg(self.process.pid, signal.SIGTERM)
            if self.bounded_wait() is None:
                self.dbench.warn('%s sub-process-group not responding to SIGTERM after %d seconds; sending SIGKILL.' % (self, self.MAX_WAIT_TIME))
                os.killpg(self.process.pid, signal.SIGKILL)
        if self.out_file and not self.out_file.closed:
            self.out_file.close()


class MySQL(Server):
    """MySQL server worker using the installation under /usr/local/mysql."""
    def internal_start(self):
        port_flag = '--port=%d' % self.dbench.port
        self.run_process('/usr/local/mysql/libexec/mysqld', self.compute_args([port_flag]))
        self.wait_for_server_ready()
        # Fixed extra delay after the port starts accepting connections.
        time.sleep(10)

    def internal_stop(self):
        # Ask the server to shut down, give it time, then make sure the
        # process is really gone.
        shutdown_cmd = "/usr/local/mysql/bin/mysqladmin shutdown -u root --port=%d" % self.dbench.port
        os.system(shutdown_cmd)
        time.sleep(15)
        self.stop_process()
        self.finished = True

class Stress(Client):
    """Stress-client worker talking the memcached protocol over sockets.

    By default it is pointed at every configured host on the benchmark port
    and told to write latency and QPS samples to latency.txt / qps.txt.
    """
    binary_paths = ['../../../../stress-client/stress', '../stress-client/stress']

    LATENCY_FILE = 'latency.txt'
    QPS_FILE = 'qps.txt'

    def internal_start(self):
        # One `-s sockmemcached,HOST:PORT` pair per host.
        default_args = []
        for host in self.dbench.hosts:
            default_args += ['-s', 'sockmemcached,%s:%d' % (host, self.dbench.port)]
        default_args += ['-l', self.LATENCY_FILE]
        default_args += ['-q', self.QPS_FILE]
        default_args += ['--client-suffix']

        binary = self.find_binary('stress')
        self.run_process(binary, self.compute_args(default_args))

class Stressinsert(Insert):
    """Stress-client worker used for the insert phase.

    Builds the same default command line as the regular stress client: every
    configured host on the benchmark port, with latency and QPS samples
    written to latency.txt / qps.txt.
    """
    binary_paths = ['../../../../stress-client/stress', '../stress-client/stress']

    LATENCY_FILE = 'latency.txt'
    QPS_FILE = 'qps.txt'

    def internal_start(self):
        # One `-s sockmemcached,HOST:PORT` pair per host.
        default_args = []
        for host in self.dbench.hosts:
            default_args += ['-s', 'sockmemcached,%s:%d' % (host, self.dbench.port)]
        default_args += ['-l', self.LATENCY_FILE]
        default_args += ['-q', self.QPS_FILE]
        default_args += ['--client-suffix']

        binary = self.find_binary('stress')
        self.run_process(binary, self.compute_args(default_args))

class MySQLStress(Stress):
    """Stress-client worker that targets MySQL instead of memcached."""
    def internal_start(self):
        user, passwd, db = 'teapot', 'teapot', 'bench' # XXX: This shouldn't be hard-wired.
        # One `-s mysql,USER/PASS@HOST:PORT+DB` pair per host.
        default_args = []
        for host in self.dbench.hosts:
            default_args += ['-s', 'mysql,%s/%s@%s:%d+%s' % (user, passwd, host, self.dbench.port, db)]
        default_args += ['-l', self.LATENCY_FILE]
        default_args += ['-q', self.QPS_FILE]

        binary = self.find_binary('stress')
        self.run_process(binary, self.compute_args(default_args))

class StressFree(Stress): # Temporary.
    """Stress client that is restarted whenever it exits with status 255."""
    # A hacky version of wait() that restarts if the client has an error.
    def wait(self):
        """Wait for the client to exit, restarting it while it exits with 255, then stop().

        Any non-zero exit status is reported as a warning. Implemented as a
        loop so that a long series of restarts cannot exhaust the recursion
        limit.
        """
        assert self.process
        while True:
            status = self.process.wait()
            if status != 0:
                self.dbench.warn('%r exited with status %d' % (self, status))
            if status != 255:
                break
            self.dbench.warn('Assuming stress client failure is temporary; restarting.')
            # internal_start() opens output.txt again; close the previous
            # handle first so it is not leaked.
            if self.out_file and not self.out_file.closed:
                self.out_file.close()
            with directory(self.worker_current_dir):
                self.status_file.write('restart: %d\n' % int(time.time()))
                self.internal_start()
        self.stop()


class Stat(Monitor):
    """Monitor that runs the command named by the subclass attribute `statcmd`.

    With no worker param the command is given the single argument '1';
    otherwise the user's arguments are used.
    """
    def internal_start(self):
        stat_args = self.compute_args(prefix=[], default=['1'])
        self.run_process(self.statcmd, stat_args)

class VMStat(Stat):
    """Monitor running `vmstat`."""
    statcmd = 'vmstat'
class IOStat(Stat):
    """Monitor running `iostat`."""
    statcmd = 'iostat'
class IFStat(Stat):
    """Monitor running `ifstat` with a single header and timestamps."""
    statcmd = 'ifstat'

    def internal_start(self):
        # -n: display header only once; -t: timestamps.
        fixed_flags = ['-n', '-t']
        stat_args = self.compute_args(prefix=fixed_flags, default=fixed_flags + ['1'])
        self.run_process(self.statcmd, stat_args)

class RDBStat(Stat):
    """Monitor running `rdbstat` against the benchmark port (local only)."""
    statcmd = 'rdbstat'

    def internal_start(self):
        if self.ssh:
            # XXX: Figure out what to do regarding hosts.
            self.dbench.die("Can't run rdbstat over SSH yet.")
        port_flags = ['-p', str(self.dbench.port)]
        stat_args = self.compute_args(prefix=port_flags, default=port_flags + ['1'])
        self.run_process(self.statcmd, stat_args)

class OProfile(Monitor):
    """Monitor that profiles one entry of profiles.small_packet_profiles per run.

    The worker keeps reporting itself as unfinished until every profile in
    its private copy of the list has been used.
    """
    # XXX: sudo must allow the user to run opcontrol without a password.

    def internal_init(self):
        if self.ssh:
            self.dbench.warn("Can't run OProfile over SSH.")
            self.ssh = None
        if self.param:
            self.dbench.warn("OProfile doesn't support parameters yet.")

        self.op = oprofile.OProfile()
        # Private copy: profiles are consumed (popped) as runs go by.
        self.profiles = copy.deepcopy(profiles.small_packet_profiles)

    def internal_start(self):
        next_profile = self.profiles.pop(0)
        self.cur_profile = next_profile
        self.op.start(next_profile.events)

    def internal_stop(self):
        # The profiler must be stopped from inside the worker's directory.
        # Maybe .stop() should do this instead?
        with directory(self.worker_current_dir):
            self.op.stop()
        remaining = self.profiles
        self.finished = (remaining == [])

# Maps the worker name used on the command line to the class implementing it.
workers = dict(
    # Generic helpers.
    cmd=Cmd,
    sleep=Sleep,

    # Memcached-protocol servers.
    rethinkdb=RethinkDB,
    memcachedb=MemcacheDB,
    membase=Membase,
    memcached=Memcached,

    # SQL servers.
    mysql=MySQL,

    # Load generators.
    stress=Stress,
    stressinsert=Stressinsert,
    mysqlstress=MySQLStress,
    stressfree=StressFree,

    # Monitors.
    oprofile=OProfile,
    vmstat=VMStat,
    iostat=IOStat,
    ifstat=IFStat,
    rdbstat=RDBStat,
)

class DBench(object):
    """Benchmark driver.

    Parses the command line into a server, a client, optional monitors and an
    optional insert worker, then repeats runs (each in its own numbered
    subdirectory of the output directory) until every worker reports that it
    is finished. Log messages are buffered and written to bench_log.txt at
    the end.
    """

    # Worker specification: {id}name[user@host:path]:param, where a doubled
    # colon before the param (`::param`) selects raw mode.
    worker_re = re.compile(r'''^
                               (\{
                                  (?P<id>\w+)
                               \})?
                               (?P<name>[^:\[]+)
                               (\[
                                  ((?P<ssh>(\w+@)?[a-zA-Z0-9\-.]+):)?
                                  (?P<path>[^\]:]*)
                               \])?
                               (:
                                 (?P<raw>:)?
                                 (?P<param>.*)
                               )?
                               $''', re.X)

    # XXX: Ugh.
    # Class-level list: shared by all DBench instances in the process.
    existing_ids = []
    def unique_id(self, initial_id, iter=None):
        """Return `initial_id`, or `initial_id` + the first free integer suffix, and reserve it."""
        id = initial_id + str(iter or '')
        if id in self.existing_ids:
            return self.unique_id(initial_id, (iter or 0) + 1)
        else:
            self.existing_ids.append(id)
            return id

    def parse_worker(self, arg):
        """argparse `type` callback: turn a worker specification string into a Worker.

        Raises argparse.ArgumentTypeError for malformed specs or unknown
        worker names.
        """
        match = self.worker_re.match(arg)
        if not match:
            raise argparse.ArgumentTypeError("Invalid argument: %s" % arg)
        groups = match.groupdict()
        name = groups['name']
        id = self.unique_id(groups['id'] or name)
        ssh = groups['ssh']
        path = groups['path']
        raw = groups['raw'] == ':'
        param = groups['param']
        try:
            klass = workers[name]
        except KeyError:
            raise argparse.ArgumentTypeError("Unknown worker: %s" % name)

        return klass(dbench=self, name=name, id=id, ssh=ssh, path=path, raw=raw, param=param)

    def get_port(self, default=None):
        """Return `default` if given, otherwise a port that was free a moment ago."""
        if default: return default
        # Otherwise, find unused port: bind to port 0 and see what we got.
        so = socket.socket()
        try:
            so.bind(('localhost', 0))
            port = so.getsockname()[1]
        finally:
            so.close()
        time.sleep(0.5) # XXX: To make sure the socket is really closed. Surely there's a better way to do this.
        return port

    def parse_args(self):
        """Parse sys.argv and store the resulting workers and settings on self."""
        parser = argparse.ArgumentParser(description='Run benchmarks and/or profile.')
        parser.add_argument('server', help='Server XXX', type=self.parse_worker)
        parser.add_argument('client', help='Client XXX', type=self.parse_worker)
        parser.add_argument('monitors', help='Monitors XXX', metavar='monitor',
                            nargs='*', type=self.parse_worker)
        parser.add_argument('-i', '--insert', help='XXX', type=self.parse_worker)
        parser.add_argument('-f', '--force', help='Delete output directory before starting',
                            action='store_true')
        parser.add_argument('-P', '--prime-caches', help='Run client twice to prime caches',
                            action='store_true')
        parser.add_argument('-d', '--output-directory', help='Directory to output benchmarks to.',
                            type=str, default='./bench_output')
        parser.add_argument('-p', '--port', help='Server port (if not specified, will find unused port automatically).', type=int)
        default_host = socket.gethostname()
        parser.add_argument('-H', '--hosts', help='Comma-separated list of hostnames that will be given as an argument to workers run over SSH (default: %s). Note: All hostnames must point to the same machine for now; multiple hostnames should only be used for benchmarking over multiple network interfaces.' % default_host,
                            type=lambda hosts: hosts.split(','), default=[default_host])
        # TODO: use (/usr/bin/)time?
        # TODO: time-format
        # TODO: Minimum #runs.
        self.args = parser.parse_args()
        self.port = self.get_port(self.args.port)
        self.hosts = self.args.hosts
        self.monitors = self.args.monitors
        self.server = self.args.server
        self.client = self.args.client
        self.insert = self.args.insert

    def warn(self, str):
        """Log a warning in red on stderr."""
        self.log('Warning: %s' % str, sys.stderr, 'red')

    dying = False
    def die(self, str):
        """Log a fatal error, stop every running worker and exit with status 1.

        A second call while the first is still cleaning up exits immediately.
        """
        if self.dying:
            self.log('Yo dawg error: Died while dying. Exiting immediately.', sys.stderr, 'red')
            sys.exit(1)
        self.dying = True
        self.log('Fatal error: %s' % str, sys.stderr, 'red')
        self.emergency_cleanup()
        sys.exit(1)

    def info(self, str):
        """Log an informational message in green on stdout."""
        self.log(str, color='green')

    def emergency_cleanup(self):
        """Stop every worker that is currently running, including the insert worker."""
        candidates = [self.client, self.server] + self.monitors
        # The insert worker is optional, and it may be the one running when
        # things go wrong during the insert phase.
        insert = getattr(self, 'insert', None)
        if insert is not None:
            candidates.insert(0, insert)
        for w in candidates:
            if w.running: w.stop()

    def sig_handler(self, sig, frame):
        """Signal handler: treat the signal as a fatal error."""
        self.die("Got signal %d." % sig)

    # We store the log messages in an array instead of just writing them to a
    # file because sometimes the file wasn't created yet, or there was an issue
    # making it. XXX: This seems needlessly complicated.
    log_messages = []
    log_depth = 0
    def log(self, str, term=sys.stdout, color=None):
        """Write `str`, indented by the current log depth, to `term` and to the buffered log.

        The colour (if any) is applied to the terminal output only.
        """
        str = ' ' * self.log_depth + str
        if color: term.write(colored(str, color) + '\n')
        else:     term.write(str + '\n')
        self.log_messages.append(str)

    # XXX: This isn't a good way of doing it.
    @safe_contextmanager
    def deeper_log(self, delta = 1):
        """Context manager that indents log output by `delta` extra columns."""
        self.log_depth += delta
        yield
        self.log_depth -= delta

    @safe_contextmanager
    def timed_subsection(self, name):
        """Context manager that logs start/finish timestamps around an indented section."""
        self.info("Starting %s (at %d)" % (name, int(time.time())))
        with self.deeper_log(1):
            yield
        self.info("Finished %s (at %d)" % (name, int(time.time())))

    def setup(self):
        """Parse arguments, install signal handlers and enter a fresh output directory.

        Dies if the output directory already exists and --force was not given.
        """
        self.info('DBench starting with arguments: %s' % sys.argv[1:])
        self.info('Start time: %d' % int(time.time()))
        self.parse_args()
        self.info('(Port: %d)' % self.port)
        self.script_path = os.path.abspath(os.path.dirname(__file__)) # XXX: This is a hack.
        signal.signal(signal.SIGINT, self.sig_handler)
        signal.signal(signal.SIGTERM, self.sig_handler)
        if self.args.force and os.path.exists(self.args.output_directory):
            self.info("(Deleting old output `%s')" % self.args.output_directory)
            shutil.rmtree(self.args.output_directory)

        try: os.makedirs(self.args.output_directory)
        except OSError as e:
            if e.errno != errno.EEXIST: raise
            self.die('Directory %s exists (use --force to delete)' % self.args.output_directory)
        os.chdir(self.args.output_directory)

    def finish(self):
        """Log the end time and dump all buffered log messages to bench_log.txt."""
        self.info('End time: %d' % int(time.time()))
        with open('bench_log.txt', 'w') as f:
            for msg in self.log_messages:
                f.write(msg + '\n')

    def start_workers(self):
        """Run the optional insert and cache-priming phases, then start monitors and client."""
        # Make sure the insert client finishes first
        if self.insert:
            self.info('%s started for insert' % self.insert)
            self.info('Waiting for client...')
            self.insert.start()
            self.insert.wait()
            self.info('Client completed')
            time.sleep(180) # Wait a bit for the database to be able to flush stuff...

        if self.args.prime_caches:
            # Run client twice to prime caches of the server
            self.client.start()
            self.info('%s started for cache priming' % self.client)
            self.info('Waiting for client...')
            self.client.wait()
            self.info('Client completed')

        for w in self.monitors + [self.client]:
            w.start()
            self.info('%s started' % w)

    def stop_workers(self):
        """Stop the monitors and then the server; the client must already have stopped."""
        if self.client.running: self.die('Client not terminated')
        # XXX: We stop the monitors before the server because they sometimes misbehave otherwise.
        for w in self.monitors + [self.server]:
            w.stop()
            self.info('%s stopped' % w)


    def wait(self):
        """Block until the client has completed."""
        self.info('Waiting for client...')
        self.client.wait()
        self.info('Client completed')

    @safe_contextmanager
    def setup_run(self):
        """Context manager for one run: numbered directory, timed log section, run counter."""
        with directory(str(self.run_id)):
            with self.timed_subsection('run %d' % self.run_id):
                yield
        self.run_id += 1

    def one_run(self):
        """Perform a single run: start the server, start the workers, wait, stop everything."""
        with self.setup_run():
            print("Starting server")
            self.server.start()
            self.info('%s started' % self.server)
            with bracket(self.start_workers, self.stop_workers):
                self.wait()


    def need_to_run(self):
        """True while the server, the client or any monitor is not yet finished."""
        return any(not m.finished for m in [self.server, self.client] + self.monitors)

    run_id = 1

    def run(self):
        """Entry point: set up, repeat runs while any worker needs one, then write the log."""
        with bracket(self.setup, self.finish):
            while self.need_to_run():
                self.one_run()

# Run the benchmark when executed as a script, unless the MOCKBENCH
# environment variable is set to a non-empty value.
if __name__ == '__main__' and not os.getenv("MOCKBENCH", False):
    DBench().run()
