#
# superimp.py
#
# Copyright (C) 2005  Dr. Stephane Gagne
# the full copyright notice is found in the LICENSE file in this directory
#
# This program is free software; 
# you can redistribute it and/or modify it under the terms of the 
# GNU General Public License as published by the Free Software Foundation; 
# either version 2 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, 
# but WITHOUT ANY WARRANTY; without even the implied warranty of 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License 
# along with this program; if not, write to the Free Software Foundation, Inc., 
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
# Contact:  nmr@rsvs.ulaval.ca
#
# --------------------------------------------------------------------
# author:  Leigh Willard
# lab:     Stephane Gagne, Laval Universite
# date:    jan 2004
# --------------------------------------------------------------------
#



from pymol import cmd
from Tkinter import *
from tkFileDialog import *
from rmsd import *
from Crmsd import *
import string, common



class Impose:
    """ class to hold all superimpose data, and to draw the superimpose gui

    Every user choice is stored in a Tk variable created against ``parent``;
    draw() builds the widgets inside ``nmrframe`` and wires the buttons to
    start_superimpose() and start_best_range().
    """

    def __init__(self, parent, nmrframe):
        self.nmrframe = nmrframe
        # choices offered by the "superimpose on which atoms" pull-down
        self.atomlist = ['C O N CA', 'C N CA', '-H -O', '-O']
        self.atom1 = StringVar(parent)       # atoms to fit on
        self.obj1 = StringVar(parent)        # name of the reference structure
        self.resrange = StringVar(parent)    # residue range; "all" or "" = no limit
        self.resrange.set('all')
        self.atom1.set('ca')                 # replaced by atomlist[0] in draw()
        self.bbonly = IntVar(parent)         # 1 = show backbone lines only
        self.bbonly.set(0)
        self.resopt = IntVar(parent)         # "pairwise fit to mean" checkbox
        self.resopt.set(1)
        self.xrmsd = DoubleVar(parent)       # RMSD at which the best-range search stops
        self.xrmsd.set(1.0)
        self.bestrange = StringVar(parent)   # result shown by gui_best_range()
        self.bbdef_list = ['C N CA', 'C O N CA']
        self.bbdef = StringVar(parent)       # backbone definition in use
        self.bbdef.set(self.bbdef_list[0])


    def toggle_bb(self):
        """ toggle backbone only drawing

        Called by the "Display backbone only" checkbutton.  When it is
        checked, all lines are hidden and only N, C, CA and O are redrawn.
        When it is cleared, lines are shown for everything again.
        """
        if self.bbonly.get() == 0:
            cmd.show("lines")
        if self.bbonly.get() == 1:
            cmd.hide("lines")
            cmd.show("lines", "n. n+c+ca+o")


    def draw(self, nmrframe):
        """ build the superimpose panel inside self.nmrframe

        The ``nmrframe`` argument is accepted for compatibility; the frame
        given to __init__ is the one that is used.
        """

        supF = Frame(self.nmrframe)
        supF.pack()

        Label(supF, text="").pack()
        Label(supF, text="SUPERIMPOSE STRUCTURES:", fg="midnightblue").pack()
        Label(supF, text="").pack()
        Checkbutton(supF, text="Display backbone only", variable=self.bbonly, command=self.toggle_bb).pack()
        Label(supF, text="").pack()
        middleF = Frame(supF)
        middleF.pack()
        # left column: fit / rmsd options; right column: best-range search
        fr1 = Frame(middleF, bd=2, relief="sunken")
        fr2 = Frame(middleF, bd=2, relief="sunken")
        fr1.pack(side=LEFT,fill=BOTH, expand=YES)
        fr2.pack(fill=BOTH, expand=YES, padx=10)

        Label(fr1, text="").pack()
        Label(fr1, text="FIT/RMSD", fg="darkgreen").pack()
        Label(fr1, text="").pack()

        Label(fr1, text="").pack()
        # entry box of all current objects.  this is the object
        # that the others will be superimposed to.
        self.objChoice(fr1, self.obj1, "superimpose on which structure:" )
        Label(fr1, text="").pack()

        # pull-down of the atom sets (self.atomlist) used for the fit
        self.anyChoice(fr1, self.atom1, self.atomlist, 10, "superimpose on which atoms:" )
        Label(fr1, text="").pack()

        fr3 = Frame(fr1)
        fr3.pack()
        Label(fr3, text="Residue range:").pack(side=LEFT)
        resE = Entry(fr3, width=15, textvariable=self.resrange)
        resE.pack()
        Label(fr1, text="").pack()

        self.anyChoice(fr1, self.bbdef, self.bbdef_list, 10, "backbone definition: " )

        Label(fr1, text="").pack()


        fr4 = Frame(fr1, bd=2, relief="sunken")
        fr4.pack()
        Label(fr4, text="Options for calculating per residue RMSD's: ").pack()

        Checkbutton(fr4, text="pairwise fit to mean", variable=self.resopt).pack()

        Label(supF, text="").pack()
        Label(supF, text="").pack()
        bt_fit = Button(fr1, text="Do Fit", bg="green", command=self.start_superimpose)
        bt_fit.pack(side=BOTTOM)


        # find best superimpose range
        self.gui_best_range(fr2)


    def gui_best_range(self, fr):
        """ draw the gui for calculating the best range """

        Label(fr, text="").pack()
        Label(fr, text="Calculate Best Range ", fg="darkgreen" ).pack()
        Label(fr, text="for Superimposition", fg="darkgreen" ).pack()
        Label(fr, text="").pack()

        fr_a = Frame(fr)
        fr_a.pack()
        Label(fr_a, text="Stop at which RMSD").pack(padx=10)
        Label(fr_a, text="(to mean)").pack(padx=10)
        resE = Entry(fr_a, width=5, textvariable=self.xrmsd)
        resE.pack()

        Label(fr, text="").pack()

        fr_b = Frame(fr)
        fr_b.pack()
        Label(fr_b, text="Best Range: ").pack(padx=10)
        # read-only: filled in by start_best_range()
        rangeE = Entry(fr_b, width=20, textvariable=self.bestrange,
            state=DISABLED)
        rangeE.pack()

        Label(fr, text="").pack()


        bt = Button(fr, text="Go", bg="green", command=self.start_best_range)
        bt.pack(side=BOTTOM)


    def anyChoice(self, parent, myvar, objlist, wdth, label):
        """ draw an entry whose click pops up a menu of the fixed choices
        in ``objlist``; ``myvar`` is initialised to the first choice """

        def do_popup(event):
            mymenu.tk_popup(*parent.winfo_pointerxy())
            parent.focus_set()

        Label(parent, text=label).pack()
        atomE = Entry(parent, textvariable=myvar, width=wdth)
        atomE.pack()
        myvar.set(objlist[0])
        mymenu = Menu(tearoff=0)
        for i in objlist:
            # bind the loop value as a default so each item sets its own label
            mymenu.add_command(label=i,
                command=lambda arg1=myvar, arg2=i: arg1.set(arg2))
        atomE.bind("<Button>", do_popup)


    def objChoice(self, parent, myvar, label):
        """ draw an entry whose click pops up a menu of loaded PyMOL objects """

        def do_popup(event):
            # re-read the object list on every click so that newly loaded
            # structures appear in the menu
            objs = cmd.get_names('objects',1)
            if not objs:
                print("*** ERROR:  no structures loaded.")
                return
            # only replace the current choice when it no longer names a
            # loaded object; otherwise dismissing the menu would silently
            # switch the reference structure
            if myvar.get() not in objs:
                myvar.set(objs[0])
            mymenu = Menu(tearoff=0)
            for i in objs:
                mymenu.add_command(label=i,
                    command=lambda arg1=myvar, arg2=i: arg1.set(arg2))
            mymenu.tk_popup(*parent.winfo_pointerxy())
            parent.focus_set()

        objs = cmd.get_names('objects',1)
        if objs:
            myvar.set(objs[0])
        else:
            myvar.set("")
        Label(parent, text=label).pack()
        objE = Entry(parent, width=10, textvariable=myvar)
        objE.pack()
        objE.bind("<Button>", do_popup)


    def _fit_selections(self):
        """ return (reference selection, per-object selection suffix)

        The suffix is appended to each object's name.  An empty or "all"
        residue range means no residue restriction.
        """
        name_clause = " and name " + self.atom1.get()
        resrange = self.resrange.get()
        if resrange in ("", "all"):
            resi = ""
        else:
            resi = " and resi " + resrange
        return self.obj1.get() + name_clause + resi, name_clause + resi


    def new_rmsd(self):
        """ print cmd.rms() of every loaded object against the reference """
        objs = cmd.get_names('objects',1)
        sel1, sel2a = self._fit_selections()
        ref = self.obj1.get()

        for i in objs:
            if i == ref: continue
            rmsd = cmd.rms(sel1, i + sel2a)
            print("rmsd %s + %s = %f" % (ref, i, rmsd))


    def current_rmsd(self):
        """ print a table of current (cmd.rms_cur) and projected (cmd.rms)
        RMSDs of every loaded object against the reference, plus their sums """
        rmsd1T = 0
        rmsd2T = 0
        objs = cmd.get_names('objects',1)
        sel1, sel2a = self._fit_selections()
        ref = self.obj1.get()
        rule = "-----------------|-----------------|---------|-----------"

        print(rule)
        print(" BASE OBJECT     | OBJECT          | CURRENT | PROJECTED ")
        print(rule)
        for i in objs:
            if i == ref: continue
            sel2 = i + sel2a
            rmsd1 = cmd.rms_cur(sel1, sel2)
            rmsd2 = cmd.rms(sel1, sel2)
            rmsd1T += rmsd1
            rmsd2T += rmsd2
            print(" %-15s | %-15s |  %6.3f |  %6.3f" % (ref, i, rmsd1, rmsd2))
        print(rule)
        print(" TOTALS          |                 | %7.2f | %7.2f" % (rmsd1T, rmsd2T))
        print(rule)


    def start_superimpose(self):
        """ "Do Fit" callback: superimpose all objects on the chosen one """

        objs = cmd.get_names('objects',1)
        if len(objs) == 0:
            print("*** ERROR: no structures loaded.")
            return

        # kept as attributes because other code may read the last choices
        self.ref = self.obj1.get()
        self.atoms = self.atom1.get()
        self.range = self.resrange.get()
        self.bbatoms = self.bbdef.get().split()

        superimpose(self.ref, self.atoms, self.range, self.bbatoms, self.resopt.get(), 0)



    def build_array(self, l):
        """ a function to turn a python list into a C array """
        a = new_charArray(len(l))
        for i, item in enumerate(l):
            charArray_setitem(a, i, item)
        return a

    def build_str(self, l):
        """ turn a python list of strings into one string for the C code;
        every item is preceded by a single space, e.g. ["a", "b"] -> " a b" """
        return "".join([" " + item for item in l])



    def start_best_range(self):
        """ "Go" callback: run the C best-range search and show the result """

        objs = cmd.get_names('objects',1)
        if len(objs) == 0:
            print("*** ERROR: no structures loaded.")
            return

        self.ref = self.obj1.get()
        self.atoms = self.atom1.get()
        self.bbatoms = self.bbdef.get().split()
        self.range = self.resrange.get()
        self.maxrmsd = self.xrmsd.get()
        objsC = self.build_str(objs)

        rmsdL = C_do_best_range(self.ref, self.atoms, objsC, self.range, common.workingdir, self.maxrmsd)
        print("RANGE = %s " % (rmsdL,))

        self.bestrange.set(rmsdL)


#---------------------------------------------------------------------------
def format_range(myrange):
    """ take a range in the form of a list, and make it
        more human readable

    Runs of consecutive integers become "start-end" and are joined with
    commas, e.g. [1, 2, 3, 7, 8, 10] -> "1-3,7-8,10-10".  A single residue
    keeps the "n-n" form.  The input is expected to be ascending integers
    (they are formatted with %d).  An empty input gives "".
    """

    runs = []
    start = last = None
    for value in myrange:
        if last is not None and value == last + 1:
            last = value
            continue

        # start of new chain: flush the previous run, if any
        if start is not None:
            runs.append("%d-%d" % (start, last))
        start = value
        last = value

    if start is not None:
        runs.append("%d-%d" % (start, last))
    newrange = ",".join(runs)
    print("newrange =  %s" % newrange)
    return newrange


#---------------------------------------------------------------------------
def global_mean_rmsd(myrmsd, sup, range, atoms):
    """ calculate the global rmsd's from each object to the mean,
    and to each other

    Coordinates for every object in myrmsd.objs are taken once, using
    ``range`` and ``atoms``.  Then every unordered pair (i, j) is scored with
    sup.set(i, j) / sup.get_rms2().  Each score is printed, stored
    symmetrically in myrmsd.rmsdD[i][j] and myrmsd.rmsdD[j][i], and appended
    to the returned list in pair order.
    """

    tmp_cdx = dict()
    for name in myrmsd.objs:
        tmp_cdx[name] = myrmsd.get_coords(name, range, atoms)

    names = list(myrmsd.objs)
    allrmsd = list()
    print("")
    for pos, i in enumerate(names):
        # pair i only with objects after it, so each pair is scored once
        for j in names[pos + 1:]:
            sup.set(tmp_cdx[i], tmp_cdx[j])
            result = sup.get_rms2()
            print("rmsd  %s  <->  %s  =  %.2f" % (i, j, result))
            myrmsd.rmsdD[i][j] = result
            myrmsd.rmsdD[j][i] = result
            allrmsd.append(result)

    return allrmsd


#----------------------------------------------------------------------
def calc_avg_to_mean(sup, myrmsd, range, atoms):
    """ calculate the global rmsd's to the mean.  return the average

    Each object in myrmsd.objs except "mean" is scored against the "mean"
    object with sup.set(obj, mean) / sup.get_rms2(), using the coordinates
    selected by ``range`` and ``atoms``.  The scores are averaged with
    MLab.mean.
    NOTE(review): MLab is not imported in this file; presumably it arrives
    through one of the star imports -- confirm.
    """

    coords = {}
    for name in myrmsd.objs:
        coords[name] = myrmsd.get_coords(name, range, atoms)

    to_mean = []
    for name in myrmsd.objs:
        if name != "mean":
            sup.set(coords[name], coords["mean"])
            to_mean.append(sup.get_rms2())

    return MLab.mean(to_mean)


#---------------------------------------------------------------------------
def leigh_sup(sup, myrmsd, ref, range, atoms):
    """ superimpose every object onto ``ref`` and store the transformation

    For each object in myrmsd.objs other than ``ref``, fit its coordinates
    onto those of ``ref`` with ``sup``.  Build the 16-value TTT list
    (3x3 rotation rows, then the translation) in the same layout that
    superimpose() uses, and hand it to myrmsd.update_pdb().  Unlike
    superimpose(), nothing is moved in PyMOL here.

    NOTE(review): the ``range`` and ``atoms`` arguments are accepted but
    not used; the fit always uses myrmsd.range and myrmsd.atoms.
    """

    # get coordinate list for reference structure
    cdx1 = myrmsd.get_coords(ref, myrmsd.range, myrmsd.atoms)

    # for each onscreen object, superimpose it to 'ref'
    for i in myrmsd.objs:
        if i == ref: continue

        # get coordinate list for object i
        cdx2 = myrmsd.get_coords(i, myrmsd.range, myrmsd.atoms)

        # do the superimposition
        sup.set(cdx1, cdx2)
        sup.run()
        rot, tran = sup.get_rotran()

        # create the TTT matrix (list) that pymol wants; the rotation has
        # only three rows, so the last row carries the translation
        TTT = [rot[0, 0], rot[0, 1], rot[0, 2], 0,
               rot[1, 0], rot[1, 1], rot[1, 2], 0,
               rot[2, 0], rot[2, 1], rot[2, 2], 0,
               tran[0], tran[1], tran[2], 0]

        # update our pdb values (used globally throughout pynmr)
        myrmsd.update_pdb(i, TTT)


#---------------------------------------------------------------------------
def do_best_range(ref, atoms, range, bbdef, limit, pairfit=0):
    """ find the residue range that brings the ensemble under an RMSD limit

        ref     = the structure to superimpose the others to
        atoms   = the atoms to superimpose on ex. "c o n ca"
                  ex "-H"
        range   = starting residue range, ex. "20-88" (handed to RMSD)
        bbdef   = the backbone definition ie "N CA C"
        limit   = stop once the average RMSD to the mean is <= limit
        pairfit = accepted for compatibility; not used here

        Superimpose on the current range, rebuild the mean structure and
        compute the average RMSD to the mean.  While that is above
        ``limit``, drop the residue with the largest per-residue backbone
        RMSD and repeat.  The search stops early, with a warning, when no
        residue can be dropped.

        Returns the surviving residues as a string built by format_range().
    """
    from SVDSuperimposer import SVDSuperimposer


    sup = SVDSuperimposer()
    myrmsd = RMSD(ref, atoms, range, bbdef)

    # The ``range`` argument is the user's selection string; it can neither
    # be iterated residue by residue nor have residues removed from it.
    # Work on RMSD's parsed range instead, and on the same object: leigh_sup
    # superimposes on myrmsd.range, so every removal below must show there.
    # NOTE(review): assumes RMSD.range is a mutable list of residue numbers.
    myrange = myrmsd.range

    def bb_rmsd_per_res():
        """ mean backbone RMSD to the mean structure, keyed by residue """
        per_res = myrmsd.get_rmsd_per_res(sup, "mean", bbdef)
        rmsd_bb_res = dict()
        for res in myrmsd.resnums:
            rmsd_bb_res[res] = MLab.mean(per_res[res])
        return rmsd_bb_res

    leigh_sup(sup, myrmsd, ref, myrange, myrmsd.atoms)
    # calculate the mean structure
    myrmsd.calc_show_mean(showflag = 1)
    rmsd_bb_res = bb_rmsd_per_res()

    # the average rmsd to the mean is what we are trying to minimize;
    # loop until it is under the maximum value
    result = calc_avg_to_mean(sup, myrmsd, myrange, myrmsd.atoms)
    print("starting RMSD =  %s" % (result,))
    while result > limit:

        # find the residue with the highest backbone rmsd
        idx = None
        largest = 0
        for res in myrange:
            if rmsd_bb_res[res] > largest:
                idx = res
                largest = rmsd_bb_res[res]

        # nothing worth removing, or removing it would empty the range
        if idx is None or len(myrange) <= 1:
            print("WARNING: cannot reach RMSD %5.2f, stopping at %5.2f" % (limit, result))
            break

        myrange.remove(idx)
        # this takes the longest
        leigh_sup(sup, myrmsd, ref, myrange, myrmsd.atoms)
        # calculate the mean structure
        myrmsd.calc_show_mean(showflag = 1)

        result = calc_avg_to_mean(sup, myrmsd, myrange, myrmsd.atoms)
        print("eliminating res %3d with rmsd %5.2f NEW RMSD = %5.2f" % (idx, largest, result))

        rmsd_bb_res = bb_rmsd_per_res()


    # get the range into a readable form
    return format_range(myrange)



#---------------------------------------------------------------------------
def superimpose(ref, atoms, range, bbdef, pairfit=0, progressflag=0):
    """ fit every loaded structure onto ``ref`` and report RMSD statistics

        ref          = the structure to superimpose the others to
        atoms        = the atoms to superimpose on ex. "name c+o+n+ca"
                       ex "not h."
        range        = residue range ie. "20-88"
        bbdef        = the backbone definition ie "N CA C"
        pairfit      = forwarded to RMSD.calc_print_res_rmsd
        progressflag = accepted for compatibility; not read in this function

        Each non-reference object is moved in PyMOL via
        cmd.transform_selection, and the RMSD object's stored coordinates
        are updated with the same matrix.  The function then prints the
        mean/stdev RMSD to the mean structure, the mean/stdev of all
        pairwise RMSDs, and the per-residue table.  Returns None.
    """
    from SVDSuperimposer import SVDSuperimposer

    fitter = SVDSuperimposer()
    data = RMSD(ref, atoms, range, bbdef)

    if not len(data.range):
        print("ERROR in Range.  Please fix it and continue.")
        return

    ref_cdx = data.get_coords(ref, data.range, data.atoms)
    if ref_cdx.shape[0] == 0:
        print("WARNING:  nothing to superimpose")
        return

    for name in data.objs:
        if name == ref:
            continue

        # fit this object's selected atoms onto the reference atoms
        mobile_cdx = data.get_coords(name, data.range, data.atoms)
        fitter.set(ref_cdx, mobile_cdx)
        fitter.run()
        rot, tran = fitter.get_rotran()

        # flatten into PyMOL's 16-value list: three rotation rows (each
        # padded with 0), then the translation; passed with transpose=1
        ttt = []
        for row in (0, 1, 2):
            ttt.extend([rot[row, 0], rot[row, 1], rot[row, 2], 0])
        ttt.extend([tran[0], tran[1], tran[2], 0])

        # move the object on screen, then keep pynmr's copy in step
        cmd.transform_selection(name, ttt, transpose=1)
        data.update_pdb(name, ttt)

    data.calc_show_mean()

    # RMSDs of all structures to each other, on the updated coordinates
    allrmsd = global_mean_rmsd(data, fitter, data.range, data.atoms)
    to_mean = data.rmsdD["mean"].values()

    print("")
    print("average rmsd to mean = %.3f stdev = %.3f" % (MLab.mean(to_mean),
        MLab.std(to_mean)))

    if len(data.rmsdD) > 1:
        print("average total rmsd = %.3f stdev = %.3f\n" % (MLab.mean(allrmsd), MLab.std(allrmsd)))

    data.calc_print_res_rmsd(ref, fitter, pairfit)


#---------------------------------------------------------------------------
def start_superImpose(parent, nmrframe):
    """ entry point called from the main pynmr module

    Creates an Impose controller whose Tk variables belong to ``parent``
    and draws its panel into ``nmrframe``.  The widgets' button callbacks
    are bound methods of that controller, which keeps it alive; from then on
    everything is driven by user input.
    """
    Impose(parent, nmrframe).draw(nmrframe)

