import db_base

class sql_db(db_base.db_base):
    """Abstract base class for SQL database interfaces

    This base class contains methods that are common to all SQL databases
    """

    ###########################################################################
    # Utility methods

    def clear_cache(self):
        for val in self.cache.values():
            val.clear()

    def empty_table(self, table_name):
        self.execute("DELETE FROM %s" % table_name.lower())

    def drop_table(self, table_name):
        self.execute("DROP TABLE %s" % table_name.lower())
    
    ###########################################################################
    # Class registration methods

    create_fmt = "CREATE TABLE %s (%s)"

    def create_table(self, cls):
        """Create a table for class cls"""

        cls.table_name = cls.__name__.lower()

        # add the column names to the property objects
        self.create_column_names(cls)

        # Ask the properties to ask the database to create column with the
        # right signatures
        columns = []
        for attr in cls.persistent_values:
            columns.append(attr.create_column(self))
        for attr in cls.persistent_refs:
            columns.append(attr.create_column(self))

        cmnd = self.create_fmt % (cls.table_name, ', '.join(columns))
        self.execute(cmnd)

        cls.tid = self.get_tid(cls.table_name)

        # Create a cache for instances of this class
        self.cache[cls.tid] = {}
        self.tid_map[cls.tid] = cls
        
    def register_existing_class(self, cls):
        """If the table already exists, just do some bookkeeping"""
        cls.table_name = cls.__name__.lower()
        self.create_column_names(cls)
        cls.tid = self.get_tid(cls.table_name)
        self.cache[cls.tid] = {}
        self.tid_map[cls.tid] = cls

    def create_column_names(self, cls):
        """Add the column names to the property objects"""
        for prop in cls.persistent_values:
            prop.sql_name = prop.name.lower()
        for prop in cls.persistent_refs:
            prop.sql_tid_name = prop.name.lower()+'_tid'
            prop.sql_oid_name = prop.name.lower()+'_oid'

    ###########################################################################
    # Store and retrieve instances

    def commit(self, instance):
        """Commit an instance to the database.

        A new instance should be inserted into the database
        An existing instance should update the database
        """
        if instance.oid is None:
            self.insert(instance)
        else:
            self.update(instance)

    insert_fmt = "INSERT INTO %s (%s) VALUES (%s)"

    def insert(self, instance):
        """Insert a new instance into the database"""
        columns, values = self.get_columns_and_values(instance)
        cmnd = self.insert_fmt % (instance.table_name,
                                  ','.join(columns),
                                  ','.join(values))
        oid = self.execute(cmnd)
        instance.oid = oid
        self.cache[instance.__class__.tid][oid] = instance
        instance._dirty = False

    def update(self, instance):
        """Update the database for an existing instance (if necessary)"""
        # do nothing is the instance has not been modified
        if not instance._dirty:
            return
        columns, values = self.get_columns_and_values(instance)
        terms = ['%s=%s' % (col, val) for col, val in zip(columns, values)]
        cmnd = "UPDATE %s SET %s WHERE oid=%s" % (instance.table_name,
                                                  ','.join(terms),
                                                  instance.oid)
        self.execute(cmnd)
        instance._dirty = False

    def get_columns_and_values(self, instance):
        """Return a list of column names and row-values in SQL representation"""
        columns = []
        values = []
        for prop in instance.persistent_values:
            columns.append(prop.sql_name)
            values.append(str(prop._fastget(instance, self)))
        for prop in instance.persistent_refs:
            columns += [prop.sql_tid_name, prop.sql_oid_name] 
            values.extend(prop._fastget(instance, self))
        return columns, values

    discard_oid = 0 # discard oid on retrieve

    def retrieve_instance(self, cls, oid):
        """Retrieve an instance from the database"""
        instance = cls.make_old_instance(oid)
        self.cache[cls.tid][instance.oid] = instance

        cmnd = "SELECT * FROM %s WHERE oid=%s" % (cls.table_name, oid)
        data = iter(self.fetchall(cmnd))
        if self.discard_oid:
            oid = data.next()
        for prop in cls.persistent_values:
            prop._fastset(instance, data.next(), self)
        for prop in cls.persistent_refs:
            ref_tid, ref_oid = data.next(), data.next()
            if ref_tid is not None:
                prop._fastset(instance, (self.tid_map[ref_tid], ref_oid), self)

        return instance

    ###########################################################################
    # Evaluate queries

    def query_generator(self, cls, relation):
        """This method is called by relation.__iter__ and generates instances
        of cls satisfying relation.
        """

        # The bound_variables dict stores information needed to produce joins.
        #
        # The simple query A.a.b == 1 implies a join between A and the type
        # that A.a refers to. This information is stored in the
        # bound_variables dict as {('_x', 'a'): Y} where '_x' is an alias for
        # the free variable (A). Y is the ptype of A.a.
        #
        # This relation will result in one join clause _x.a_oid = _x_a.oid,
        # where _x_a is an alias for Y. This is the same alias used in the
        # WHERE clause: _x_a.b=1
        #
        # The query A.a.b.c == 1 implies two joins. One between A and X, the
        # type that A.a refers to, and one between X and Y, the type, that X.b
        # refers to. This produces a bound_variables dict {('_x', 'a'):X,
        # ('_x', 'a', 'b'):Y}, and a join clause (_x.a_oid = _x_a.oid AND
        # _x_a.b_oid = _x_a_b.oid)
        
        # we use a dict to avoid generaring duplicate joins
        bound_variables = {}
        relation.update_bound_variables(bound_variables)

        # we have filtered the join clause, use a list from now
        bound_variables = bound_variables.items()

        select_clause = relation.sql_repr(self)
        if not bound_variables:
            # If there are no bound variables, we don't have to generate
            # join clauses.
            cmnd = "SELECT _x.oid FROM %s _x WHERE %s" % (cls.table_name,
                                                          select_clause)
            oids = self.fetchall(cmnd)
            for oid in oids:
                yield cls(oid=oid)
        else:
            cmnd_format = "SELECT _x.oid FROM %s _x,%s WHERE %s AND %s"

            # for each bound variable we have to look at all of its subclasses too.
            # for two bound variables with n and m subclasses respectively, this will
            # result in n*m queries
            for current_vars in self.generate_bound_variables(bound_variables):

                # A string of comma-separated 'cls.tbl_name alias' pairs
                aliases = self.create_aliases(current_vars)

                # For each bound variable generate the join term, join with 'AND' 
                join_clause = self.create_join_clause(current_vars)

                # We have all necessary information, execute the query
                cmnd = cmnd_format % (cls.table_name, aliases, join_clause, select_clause)
                oids = self.fetchall(cmnd)
                for oid in oids:
                    yield cls(oid=oid)

    def generate_bound_variables(self, bound_variables):
        """Generate all permutations for the subclasses in the bound_variables

        bound_variables is a list of (alias, cls) pairs. Each cls has a list
        of suclasses, generated by cls.all_subclasses(). The generator yields
        lists of (alias, subcls) pairs, that are obtained by producing all
        permutations of the subclass lists.

        I.e.:

        bound_variables = [(alias1, cls1), ... , (aliasN, clsN)]

        for subcls1 in cls1.all_subclasses():
            ... # N nested loops
                for subclsN in clsN.all_subclasses():
                    yield [(alias1, subcls1), ... , (aliasN, subclsN)]
        """
        
        alias, cls = bound_variables[0]
        tail = bound_variables[1:]

        for subcls in cls.all_subclasses():
            if tail:
                for j in self.generate_bound_variables(tail):
                    k = [(alias, subcls)]
                    k.extend(j)
                    yield k
            else:
                yield [(alias, subcls)]

    def create_aliases(self, bound_variables):
        """Create terms for the FROM part"""
        aliases = []
        for alias, cls in bound_variables:
            aliases.append("%s %s" % (cls.table_name, '_'.join(alias)))
        return ','.join(aliases)

    def create_join_clause(self, bound_variables):
        """Create the join-part of the WHERE clause"""
        clauses = []
        fmt_map = {str:repr, int:str, long:str}
        for alias, cls in bound_variables:
            head, tail = '_'.join(alias[:-1]), alias[-1]
            clauses.append("%s.%s_oid = %s_%s.oid" % (head, tail, head, tail))
            clauses.append("%s.%s_tid = %s" % (head, tail, fmt_map[type(cls.tid)](cls.tid)))
        return " AND ".join(clauses)
