# GNU Enterprise Common - Interbase/Firebird DB Driver - Schema Introspection
#
# Copyright 2001-2005 Free Software Foundation
#
# This file is part of GNU Enterprise
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: Introspection.py 7007 2005-02-11 16:18:59Z reinhard $

__all__ = ['Introspection']

import string
import re
import kinterbasdb

from gnue.common.datasources import GIntrospection

# =============================================================================
# This class implements schema introspection for Interbase / Firebird
# =============================================================================

class Introspection (GIntrospection.Introspection):
  """
  Schema introspection for Interbase / Firebird databases.

  Relations are read from RDB$RELATIONS, their fields from
  RDB$RELATION_FIELDS / RDB$FIELDS / RDB$TYPES and their indices from
  RDB$INDICES / RDB$INDEX_SEGMENTS / RDB$RELATION_CONSTRAINTS.
  """

  # list of the types of Schema objects this driver provides
  types = [('table', _('Tables'), 1),
           ('view' , _('Views') , 1)]

  # a default source containing 'NOW()' denotes the current timestamp
  _NOW = re.compile (r"'(NOW\s*\(\)\s*)'", re.IGNORECASE)

  # rdb$default_source holds the complete clause including the DEFAULT
  # keyword (e.g. "DEFAULT 0"); this strips the keyword regardless of case or
  # the amount of whitespace that follows it
  _DEFAULT = re.compile (r"^\s*DEFAULT\s+", re.IGNORECASE)


  # ---------------------------------------------------------------------------
  # Find a schema element by name and/or type
  # ---------------------------------------------------------------------------

  def find (self, name = None, type = None):
    """
    This function searches the schema for an element by name and/or type. If no
    name and no type is given, all elements will be retrieved.

    @param name: look for an element with this name
    @param type: look for an element with this type ('table' or 'view'; any
        other value does not restrict the search)
    @return: A sequence of schema instances, one per element found, or None if
        no element could be found.
    """

    gDebug (9, "Looking for '%s' of type '%s'" % (name, type))

    # user relations normally have a system flag of 0, but databases created
    # by older tools may leave it NULL for user relations as well
    cond = ["(rdb$system_flag = 0 OR rdb$system_flag IS NULL)"]

    if name is not None:
      cond.append (u"rdb$relation_name = '%s'" % self.__identifier (name))

    # views are the relations carrying a view source
    if type == 'table':
      cond.append (u"rdb$view_source IS NULL")
    elif type == 'view':
      cond.append (u"rdb$view_source IS NOT NULL")

    cmd = u"SELECT rdb$relation_name, rdb$view_source FROM RDB$RELATIONS " \
           "WHERE %s ORDER BY rdb$relation_name" % " AND ".join (cond)

    result = []
    cursor = self._connection.makecursor (cmd)

    try:
      for rs in cursor.fetchall ():
        # system table columns are fixed-width CHAR, hence the strip
        relname = rs [0].strip ()
        indices = self.__getIndices (relname)

        attrs = {'id'        : relname,
                 'name'      : relname,
                 'type'      : rs [1] is None and 'table' or 'view',
                 'primarykey': None,
                 'indices'   : indices}

        if indices is not None:
          for index in indices.values ():
            if index ['primary']:
              attrs ['primarykey'] = index ['fields']

        result.append ( \
          GIntrospection.Schema (attrs, getChildSchema = self._getChildSchema))

    finally:
      cursor.close ()

    return result or None


  # ---------------------------------------------------------------------------
  # Get all fields of a relation / view
  # ---------------------------------------------------------------------------

  def _getChildSchema (self, parent):
    """
    This function returns a list of all child elements (fields) of a given
    parent relation.

    @param parent: schema object instance whose child elements should be
        fetched.
    @return: sequence of schema instances, one per element found
    """

    cmd = u"SELECT rf.rdb$field_name, tp.rdb$type_name, rf.rdb$null_flag, " \
             "rf.rdb$default_source, fs.rdb$field_length, " \
             "fs.rdb$field_scale, fs.rdb$field_precision " \
           "FROM rdb$relation_fields rf, rdb$fields fs, rdb$types tp " \
           "WHERE rf.rdb$relation_name = '%s' AND " \
             "fs.rdb$field_name = rf.rdb$field_source AND " \
             "tp.rdb$type = fs.rdb$field_type AND " \
             "tp.rdb$field_name = 'RDB$FIELD_TYPE' " \
           "ORDER BY rf.rdb$field_name" % self.__identifier (parent.name)

    result = []
    cursor = self._connection.makecursor (cmd)

    try:
      for rs in cursor.fetchall ():
        fieldname  = rs [0].strip ()
        nativetype = rs [1].strip ()

        # rdb$null_flag is 1 for NOT NULL columns; NULL (or 0) means nullable
        attrs = {'id'        : "%s.%s" % (parent.name, fieldname),
                 'name'      : fieldname,
                 'type'      : 'field',
                 'nativetype': nativetype,
                 'required'  : bool (rs [2])}

        if nativetype in ['DATE', 'TIME', 'TIMESTAMP']:
          attrs ['datatype'] = 'date'

        elif nativetype in ['DOUBLE', 'FLOAT', 'INT64', 'LONG', 'QUAD',
                            'SHORT']:
          attrs ['datatype']  = 'number'
          attrs ['length']    = rs [6]
          # the scale is stored as a non-positive number; it may be NULL
          attrs ['precision'] = abs (rs [5] or 0)

        else:
          attrs ['datatype'] = 'text'
          attrs ['length']   = rs [4]

        if rs [3] is not None:
          default = rs [3]
          if self._NOW.search (default) is not None:
            attrs ['defaulttype'] = 'timestamp'
          else:
            attrs ['defaulttype'] = 'constant'
            attrs ['defaultval']  = self._DEFAULT.sub ('', default, 1).rstrip ()

        result.append (GIntrospection.Schema (attrs))

    finally:
      cursor.close ()

    return result


  # ---------------------------------------------------------------------------
  # Get all indices of a given relation
  # ---------------------------------------------------------------------------

  def __getIndices (self, relname):
    """
    This function creates a dictionary with all indices of a given relation
    where the keys are the indexnames and the values are dictionaries
    describing the indices. Such a dictionary has the keys 'unique', 'primary'
    and 'fields', where 'unique' specifies whether the index is unique or not
    and 'primary' specifies whether the index is the primary key or not.
    'fields' holds a sequence with all field names building the index.
    Foreign key indices are not included. The primary key index is keyed by
    its constraint name.

    @param relname: name of the relation to fetch indices for
    @return: dictionary with indices or None if no indices were found
    """

    result = {}

    cmd = u"SELECT i.rdb$index_name, i.rdb$unique_flag, s.rdb$field_name " \
           "FROM rdb$indices i, rdb$index_segments s " \
           "WHERE i.rdb$index_name = s.rdb$index_name " \
           "  AND i.rdb$relation_name = '%s' " \
           "  AND i.rdb$foreign_key IS NULL " \
           "ORDER BY i.rdb$index_name, s.rdb$field_position" \
          % self.__identifier (relname)

    cursor = self._connection.makecursor (cmd)

    try:
      # one row per index segment, ordered by position within the index
      for rs in cursor.fetchall ():
        indexName = rs [0].strip ()

        if indexName not in result:
          result [indexName] = {'unique' : bool (rs [1]),
                                'primary': False,
                                'fields' : []}
        result [indexName] ['fields'].append (rs [2].strip ())

    finally:
      cursor.close ()

    # if there's a primary key, update the proper index-entry and replace the
    # indexname by its constraint name
    cmd = u"SELECT rdb$index_name, rdb$constraint_name " \
           "FROM rdb$relation_constraints " \
           "WHERE rdb$constraint_type = 'PRIMARY KEY' " \
              "AND rdb$relation_name = '%s'" % self.__identifier (relname)

    cursor = self._connection.makecursor (cmd)

    try:
      rs = cursor.fetchone ()
      if rs:
        ixName = rs [0].strip ()
        coName = rs [1].strip ()
        if ixName in result:
          result [ixName] ['primary'] = True
          # the index may already be named after the constraint; renaming it
          # onto itself and then deleting the old key would drop the entry
          if coName != ixName:
            result [coName] = result [ixName]
            del result [ixName]

    finally:
      cursor.close ()

    return result or None


  # ---------------------------------------------------------------------------
  # Prepare an identifier for matching against rdb$-values
  # ---------------------------------------------------------------------------

  def __identifier (self, name):
    """
    Prepare an identifier for matching against rdb$-values inside a
    single-quoted SQL string literal.

    Unquoted identifiers are stored upper-cased in the system tables, so the
    name is upper-cased. Embedded single quotes are doubled so the name cannot
    terminate the surrounding literal.

    @param name: identifier as given by the caller
    @return: the identifier ready to be embedded in a quoted SQL literal
    """
    return name.upper ().replace ("'", "''")
