# Copyright (C) 2004 Scott W. Dunlop <swdunlop at users.sourceforge.net>
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

import sqlite
import time
import os
import sys

from exceptions import Exception, StopIteration
from types import StringType
from node import Node
from threading import Lock, Thread

class DatabaseError( sqlite.DatabaseError ):
    """Base class for the errors defined in this module.

    It derives from the driver's sqlite.DatabaseError, so code that catches
    the driver exception also catches these.
    """

class ExecutionError( DatabaseError ):
    """A DatabaseError that remembers which operation was being executed.

    `operation` is the value handed to the executor. That is either a SQL
    string or a ( sql, params... ) sequence. Any remaining positional
    arguments are passed through to DatabaseError unchanged.
    """

    def __init__( self, operation, *args ):
        self.operation = operation
        DatabaseError.__init__( self, *args )

    def getOperation( self ):
        """Return the operation that was executing when the error occurred."""
        return self.operation

    def __str__( self ):
        baseMessage = DatabaseError.__str__( self )
        failedOperation = self.getOperation()
        return "%s, error occurred while executing %r" % (
            baseMessage, failedOperation
        )

class TransactionError( ExecutionError ):
    """ExecutionError subtype intended for failed write transactions.

    NOTE(review): Database.transact as written re-raises the underlying
    exception rather than this type.
    """

class QueryError( ExecutionError ):
    """ExecutionError raised by Database.query. It carries the failing operation."""

class Database( object ):
    """SQLite-backed store for a wiki.

    It holds nodes, links between nodes, per-line change history,
    configuration values, user credentials and cached HTML.

    Every statement is an "operation": either a SQL string, or a tuple/list
    ( sql, param1, param2, ... ). In the tuple form, the remaining items are
    passed to the driver as the parameters for the %s placeholders.
    """

    __slots__ = ( 'wiki', 'databaseConnection', 'transactionLock' )

    def __init__( self, wiki ):
        self.wiki = wiki
        self.databaseConnection = None
        # The connection is shared with the DatabaseSweeper thread. A
        # BEGIN ... COMMIT block belongs to the connection, not to a cursor,
        # so this lock keeps two transact() calls from interleaving.
        self.transactionLock = Lock()

    def getWiki( self ):
        """Return the wiki object this database serves."""
        return self.wiki

    def connect( self, dbPath, initIfMissing = True ):
        """Open the database at dbPath, creating it if allowed.

        If the file is missing and initIfMissing is true, the file is
        created and initTables() builds the schema. Every successful connect
        starts a DatabaseSweeper thread.

        Returns True once connected. Returns False if the file is missing
        and initIfMissing is false.

        Raises DatabaseError if the file exists but is not both readable
        and writable.
        """
        if os.access( dbPath, os.F_OK ):
            if not os.access( dbPath, os.R_OK | os.W_OK ):
                raise DatabaseError( "Database file cannot be accessed." )
            self._openConnection( dbPath )
        elif initIfMissing:
            self._openConnection( dbPath )
            self.initTables( )
        else:
            return False

        DatabaseSweeper( self ).start()
        return True

    def _openConnection( self, dbPath ):
        """Open an autocommit connection to dbPath.

        The driver's command log is routed to the wiki via DatabaseLogger.
        """
        self.databaseConnection = sqlite.connect(
            dbPath,
            autocommit=1,
            command_logfile=DatabaseLogger( self.getWiki() )
        )

    def executeOnCursor1( self, databaseCursor, operation ):
        """Execute one operation on databaseCursor and return the cursor.

        A tuple or list is split into ( sql, params... ). Anything else is
        executed as a bare SQL statement. The check tests for the sequence
        types rather than for str, so a unicode SQL string is not mistaken
        for a parameter tuple.
        """
        if isinstance( operation, ( tuple, list ) ):
            databaseCursor.execute( operation[0], operation[1:] )
        else:
            databaseCursor.execute( operation )
        return databaseCursor

    def executeOnCursor( self, databaseCursor, *operations ):
        """Execute each operation in order on databaseCursor; return the cursor."""
        for operation in operations:
            self.executeOnCursor1( databaseCursor, operation )
        return databaseCursor

    def query( self, *operations ):
        """Run operations on a fresh cursor and yield the resulting rows.

        This is a generator, so nothing executes until the first row is
        requested. A driver error during execution is re-raised as
        QueryError, carrying the operation that failed.
        """
        databaseCursor = self.databaseConnection.cursor()
        for operation in operations:
            try:
                self.executeOnCursor1( databaseCursor, operation )
            except sqlite.DatabaseError:
                raise QueryError(
                    operation, "Query did not complete successfully"
                )

        while True:
            nextrec = databaseCursor.fetchone( )
            if nextrec is None:
                return
            yield nextrec

    def query1( self, *operations ):
        """Return the first row produced by operations, or None if there is none."""
        results = self.query( *operations )
        try:
            return results.next()
        except StopIteration:
            return None

    def transact( self, *operations ):
        """Execute operations atomically inside BEGIN ... COMMIT.

        If any operation, or the COMMIT itself, raises, the transaction is
        rolled back. The original exception is then re-raised with its
        traceback. Calls are serialized through transactionLock.
        """
        self.transactionLock.acquire()
        try:
            databaseCursor = self.databaseConnection.cursor()
            # BEGIN stays outside the rollback guard: if it fails, there is
            # no transaction of ours to roll back.
            databaseCursor.execute( "BEGIN;" )
            try:
                self.executeOnCursor( databaseCursor, *operations )
                databaseCursor.execute( "COMMIT;" )
            except Exception:
                databaseCursor.execute( "ROLLBACK;" )
                raise
        finally:
            self.transactionLock.release()

    def initTables( self ):
        """Create the schema's tables and indexes, then record version 1.1."""
        self.transact(
            "CREATE TABLE CloudNodes ( key, content, creation, creator, modification, modifier );",
            "CREATE TABLE CloudReferences ( src, dest );",
            "CREATE TABLE CloudChanges ( key, time, idx, op, line, user );",
            "CREATE TABLE CloudConfig ( key, value );",
            "CREATE TABLE CloudAuth( username, password );",
            "CREATE TABLE CloudHtml( key, html );",
            "CREATE UNIQUE INDEX CloudNodes_key ON CloudNodes( key );",
            "CREATE INDEX CloudNodes_creation ON CloudNodes( creation );",
            "CREATE INDEX CloudNodes_modification ON CloudNodes( modification );",
            "CREATE INDEX CloudReferences_src ON CloudReferences( src );",
            "CREATE INDEX CloudReferences_dest ON CloudReferences( dest );",
            "CREATE INDEX CloudChanges_key ON CloudChanges( key );",
            "CREATE INDEX CloudChanges_time ON CloudChanges( time );",
            "CREATE INDEX CloudChanges_idx ON CloudChanges( idx );",
            "CREATE UNIQUE INDEX CloudConfig_key ON CloudConfig( key );",
            "CREATE UNIQUE INDEX CloudAuth_username ON CloudAuth( username );",
            "CREATE UNIQUE INDEX CloudHtml_key ON CloudHtml( key );"
        )

        self.setVersion( 1,1 )

    def getNodeRecord( self, key ):
        """Return the node's row, or None if it does not exist.

        The row is ( key, content, creation, creator, modification,
        modifier ).
        """
        return self.query1((
            "SELECT key, content, creation, creator, modification, modifier FROM CloudNodes WHERE key == %s;", key
        ))

    def fetchNode( self, key ):
        """Return a Node for key.

        If no row exists, the Node is built from the key alone.
        """
        record = self.getNodeRecord( key )
        if record is None:
            return Node( self.getWiki(), key )
        else:
            return Node( self.getWiki(), *record )

    def fetchAllNodes( self ):
        """Yield a Node for every stored node."""
        for record in self.query((
            "SELECT key, content, creation, creator, modification, modifier FROM CloudNodes;"
        )):
            yield Node( self.getWiki(), *record )

    def fetchHtml( self, key ):
        """Return the cached HTML for key.

        Returns None if nothing is cached or the cached value is empty.
        """
        res = self.query1((
            "SELECT html FROM CloudHtml WHERE key == %s;", key
        ))

        if res:
            return res[0] or None
        else:
            return None

    def storeHtml( self, key, html ):
        """Cache html for key, replacing any entry already cached.

        A plain INSERT would violate the unique CloudHtml_key index when a
        row exists. That happens, for example, when an empty string was
        cached: fetchHtml reports it as None and the caller stores again.
        """
        self.transact((
            "INSERT OR REPLACE INTO CloudHtml( key, html )VALUES( %s, %s );",
            key, html
        ))

    def fetchReferencesTo( self, key ):
        """Yield the key of every node that links to key."""
        for record in self.query((
            "SELECT src FROM CloudReferences WHERE dest == %s;", key
        )):
            yield record[0]

    def _referenceInserts( self, key, links ):
        """Build INSERT operations recording that node key links to each entry in links."""
        return [
            ( "INSERT INTO CloudReferences( src, dest ) VALUES ( %s, %s );",
              key, link )
            for link in links
        ]

    def nodeContentChanged( self, node ):
        """Persist an edit to an existing node in one transaction.

        This drops the node's cached HTML, rewrites its content and
        modification fields, and rebuilds its outgoing references.
        """
        key = node.getKey()
        content = node.getContent()
        modification = node.getModification()
        modifier = node.getModifier()

        operations = [
            ( "DELETE FROM CloudHtml WHERE key == %s;", key ),
            ( "DELETE FROM CloudReferences WHERE src==%s;", key ),
            ( "UPDATE CloudNodes "
              "SET content=%s, modification=%s, modifier=%s "
              "WHERE key=%s;",

              content, modification, modifier,
              key ),
        ]
        operations.extend( self._referenceInserts( key, node.getContentLinks() ) )
        self.transact( *operations )

    def nodeContentCreated( self, node ):
        """Insert a new node and its outgoing references in one transaction.

        The cached HTML of every node that already links to the new key is
        also dropped, so those pages get re-rendered. The node's creator and
        creation time double as its initial modifier and modification.
        """
        key = node.getKey()
        content = node.getContent()
        creation = node.getCreation()
        creator = node.getCreator()

        # Materialize before writing so the read cursor is finished.
        referrers = list( self.fetchReferencesTo( key ) )

        operations = [
            ( "INSERT INTO CloudNodes "
              "( key, content, creator, creation, modifier, modification ) "
              "VALUES ( %s, %s, %s, %s, %s, %s );",

              key, content, creator, creation, creator, creation ),
        ]
        operations.extend([
            ( "DELETE FROM CloudHtml WHERE key == %s;", ref )
            for ref in referrers
        ])
        operations.extend( self._referenceInserts( key, node.getContentLinks() ) )
        self.transact( *operations )

    def removeNode( self, key ):
        """Delete a node, its change history and its outgoing references.

        Cached HTML is dropped for every node linking to key, because their
        links now point at a missing node. The node's own cached HTML is
        dropped too, so a later recreation does not serve stale output.
        """
        referrers = list( self.fetchReferencesTo( key ) )

        operations = [
            ( "DELETE FROM CloudHtml WHERE key == %s;", ref )
            for ref in referrers
        ]
        operations.extend([
            ("DELETE FROM CloudHtml WHERE key == %s;", key),
            ("DELETE FROM CloudNodes WHERE key == %s;", key),
            ("DELETE FROM CloudChanges WHERE key == %s;", key),
            ("DELETE FROM CloudReferences WHERE src == %s;", key)
        ])
        self.transact( *operations )

    def nodeExists( self, key ):
        """Return True if a node row exists for key."""
        return self.getNodeRecord( key ) is not None

    def addChange( self, key, idx, op, line, time, user = '' ):
        """Record one change entry for node key."""
        self.transact((
            "INSERT INTO CloudChanges( key, time, idx, op, line, user )"
            " VALUES ( %s, %s, %s, %s, %s, %s );",

             key, time, idx, op, line, user
        ))

    def getChanges( self, key, windowSecs ):
        """Return a generator of the changes to key in the last windowSecs seconds.

        Each row is ( time, idx, op, line, user ), ordered by time and then
        by idx.
        """
        cutoff = self.getServerTime() - windowSecs

        return self.query((
            "SELECT time, idx, op, line, user FROM CloudChanges "
            "WHERE key == %s AND time >= %s "
            "ORDER BY time ASC, idx ASC",

            key, cutoff
        ))

    def getNodeKeys( self ):
        """Yield the key of every stored node."""
        for record in self.query((
            "SELECT key FROM CloudNodes;"
        )):
            yield record[0]

    def getServerTime( self ):
        """Return the current time.time() truncated to whole seconds."""
        return int(time.time())

    def getConfig( self, key, default = None ):
        """Return the config value stored under key, or default if absent or NULL."""
        record = self.query1((
            "SELECT value FROM CloudConfig WHERE key == %s;",
            key
        ))

        if record:
            return record[0]
        else:
            return default

    def delConfig( self, key ):
        """Delete the config entry for key, if any."""
        self.transact((
            "DELETE FROM CloudConfig WHERE key == %s;",
            key
        ))

    def setConfig( self, key, value ):
        """Store value under key, updating an existing entry or inserting one."""
        if self.getConfig( key ) is not None:
            self.transact((
                "UPDATE CloudConfig SET value = %s WHERE key == %s;",
                value, key
            ))
        else:
            self.transact((
                "INSERT INTO CloudConfig( key, value ) VALUES ( %s, %s );",
                key, value
            ))

    def getConfigKeys( self ):
        """Yield every config key."""
        for record in self.query( "SELECT key FROM CloudConfig;" ):
            yield record[0]

    def getAddedNodesSince( self, interval ):
        """Yield nodes created in the last interval seconds.

        Each row is ( key, creation, creator ), oldest first.
        """
        cutoff = self.getServerTime() - interval

        for record in self.query((
            "SELECT key, creation, creator FROM CloudNodes WHERE creation >= %s ORDER BY creation;",

            cutoff
        )):
            yield record

    def getChangedNodesSince( self, interval ):
        """Yield nodes modified in the last interval seconds.

        Each row is ( key, modification, modifier ), oldest first.
        """
        cutoff = self.getServerTime() - interval

        for record in self.query((
            "SELECT key, modification, modifier FROM CloudNodes WHERE modification >= %s ORDER BY modification;",

            cutoff
        )):
            yield record

    def getPassword( self, username, default = None ):
        """Return the stored password value for username, or default."""
        record = self.query1((
            "SELECT password FROM CloudAuth WHERE username == %s;",
            username
        ))

        if record:
            return record[0]
        else:
            return default

    def delPassword( self, username ):
        """Delete the credentials row for username, if any."""
        self.transact((
            "DELETE FROM CloudAuth WHERE username == %s;",
            username
        ))

    def setPassword( self, username, value ):
        """Store value as username's password.

        The value is stored exactly as given; this layer does no hashing.
        """
        if self.getPassword( username ) is not None:
            self.transact((
                "UPDATE CloudAuth SET password = %s WHERE username == %s;",
                value, username
            ))
        else:
            self.transact((
                "INSERT INTO CloudAuth( username, password ) VALUES ( %s, %s );",
                username, value
            ))

    def getUsernames( self ):
        """Yield every username that has stored credentials."""
        for record in self.query( "SELECT username FROM CloudAuth;" ):
            yield record[0]

    def sweepChanges( self, maxAgeSecs = 604800 ):
        """Delete change records older than maxAgeSecs.

        The age is measured from getServerTime(). The default, 604800
        seconds, is seven days.

        This is called from the DatabaseSweeper thread. transact() holds
        transactionLock, so this write cannot interleave with another
        thread's transaction.
        NOTE(review): plain query() reads are not covered by that lock.
        """
        self.transact((
            "DELETE FROM CloudChanges WHERE time < %s;",
            self.getServerTime() - maxAgeSecs
        ))

    def getVersion( self ):
        """Return the stored schema version as ( major, minor ).

        Missing values default to major 1 and minor 0.
        """
        major = int( self.getConfig( "major-version", "" ) or "1" )
        minor = int( self.getConfig( "minor-version", "" ) or "0" )

        return major, minor

    def setVersion( self, major, minor ):
        """Store the schema version in config and return ( major, minor )."""
        self.setConfig( "major-version", str( major ) )
        self.setConfig( "minor-version", str( minor ) )

        return major, minor
        
class DatabaseLogger( object ):
    """File-like sink passed to sqlite.connect as command_logfile.

    Each written message is forwarded to the wiki's logData(), but only
    while the wiki has a logFile configured (i.e. it is not None).
    """
    __slots__ = ( 'wiki', )

    def __init__( self, wiki ):
        self.wiki = wiki

    def write( self, message ):
        targetWiki = self.wiki
        if targetWiki.logFile is None:
            return
        targetWiki.logData( message )

class DatabaseSweeper(Thread):
    """Daemon thread that periodically purges old change records.

    Every `interval` seconds (86400, one day, by default) it calls
    database.sweepChanges(). If a sweep raises, the traceback is printed to
    stderr and the thread keeps running. Otherwise a single failure would
    end the thread and stop all future sweeps.
    """
    __slots__ = ( 'database', 'interval' )

    def __init__(self, database, interval = 86400):
        self.database = database
        self.interval = interval
        Thread.__init__(self)
        # Daemon, so a sleeping sweeper never keeps the process alive.
        self.setDaemon(1)

    def run(self):
        import traceback
        while 1:
            time.sleep( self.interval )
            try:
                self.database.sweepChanges()
            except Exception:
                traceback.print_exc()

