"Templates for generating DOLFIN wrappers"

__author__ = "Martin Alnes (martinal@simula.no) and Anders Logg (logg@simula.no)"
__date__ = "2008-11-06"
__copyright__ = "Copyright (C) 2008-2009 Martin Alnes"
__license__  = "GNU GPL version 3 or any later version"

# Last changed: 2010-01-18

import sys, re
import ufl

from functionspace import generate_functionspace_class
from form import generate_form_class

# Opening comment of the generated wrapper code; generate_dolfin_code emits it
# immediately before the caller-supplied header.
comment = "// DOLFIN wrappers"

# C++ standard library #include lines emitted into the generated wrapper code.
stl_includes = """\
// Standard library includes
#include <string>"""

# DOLFIN #include lines emitted into the generated wrapper code, after the
# standard library includes and a blank line.
dolfin_includes = """\
// DOLFIN includes
#include <dolfin/common/NoDeleter.h>
#include <dolfin/fem/FiniteElement.h>
#include <dolfin/fem/DofMap.h>
#include <dolfin/fem/Form.h>
#include <dolfin/function/FunctionSpace.h>
#include <dolfin/function/GenericFunction.h>
#include <dolfin/function/CoefficientAssigner.h>"""

class UFCElementName:
    "Encapsulation of the names related to a generated UFC finite element."
    def __init__(self, name, ufc_finite_element_classnames, ufc_dof_map_classnames):
        """Arguments:

        @param name:
            Name associated with the element.
        @param ufc_finite_element_classnames:
            List of names of ufc::finite_element subclasses.
        @param ufc_dof_map_classnames:
            List of names of ufc::dof_map subclasses; must have the same
            length as ufc_finite_element_classnames.
        """
        n_elements = len(ufc_finite_element_classnames)
        assert n_elements == len(ufc_dof_map_classnames)

        self.ufc_dof_map_classnames        = ufc_dof_map_classnames
        self.ufc_finite_element_classnames = ufc_finite_element_classnames
        self.name                          = name

    def __str__(self):
        "Return a multi-line, human-readable summary of the stored names."
        rows = ["UFCFiniteElementNames instance:",
                "name:                      %s" % self.name,
                "finite_element_classnames: %s" % str(self.ufc_finite_element_classnames),
                "ufc_dof_map_classnames:    %s" % str(self.ufc_dof_map_classnames)]
        return "\n".join(rows) + "\n"

class UFCFormNames:
    "Encapsulation of the names related to a generated UFC form."
    def __init__(self, name, coefficient_names, ufc_form_classname, ufc_finite_element_classnames, ufc_dof_map_classnames):
        """Arguments:

        @param name:
            Name of form (e.g. 'a', 'L', 'M').
        @param coefficient_names:
            List of names of form coefficients (e.g. 'f', 'g').
        @param ufc_form_classname:
            Name of ufc::form subclass.
        @param ufc_finite_element_classnames:
            List of names of ufc::finite_element subclasses (length rank + num_coefficients).
        @param ufc_dof_map_classnames:
            List of names of ufc::dof_map subclasses (length rank + num_coefficients).

        The rank is derived as the number of element class names minus the
        number of coefficients: argument elements come first, followed by
        one element per coefficient.
        """
        n_dof_maps = len(ufc_dof_map_classnames)
        assert len(coefficient_names) <= n_dof_maps
        assert len(ufc_finite_element_classnames) == n_dof_maps

        self.name                          = name
        self.coefficient_names             = coefficient_names
        self.ufc_form_classname            = ufc_form_classname
        self.ufc_finite_element_classnames = ufc_finite_element_classnames
        self.ufc_dof_map_classnames        = ufc_dof_map_classnames

        # Whatever is not a coefficient element is an argument element
        self.num_coefficients = len(coefficient_names)
        self.rank             = len(ufc_finite_element_classnames) - self.num_coefficients

    def __str__(self):
        "Return a multi-line, human-readable summary of the stored names."
        rows = ["UFCFormNames instance:",
                "rank:                      %d" % self.rank,
                "num_coefficients:          %d" % self.num_coefficients,
                "name:                      %s" % self.name,
                "coefficient_names:         %s" % str(self.coefficient_names),
                "ufc_form_classname:        %s" % str(self.ufc_form_classname),
                "finite_element_classnames: %s" % str(self.ufc_finite_element_classnames),
                "ufc_dof_map_classnames:    %s" % str(self.ufc_dof_map_classnames)]
        return "\n".join(rows) + "\n"

def generate_dolfin_classes(prefix, names, common_space):
    """Generate DOLFIN wrapper classes for a single element or for forms.

    @param prefix:
        String, prefix for generated class names.
    @param names:
        A UFCElementName instance (wrap one finite element); anything else is
        treated as a list of UFCFormNames instances (wrap forms).
    @param common_space:
        Tuple (form_index, space_index) of common function space if any, otherwise None.
    """
    if isinstance(names, UFCElementName):
        return _generate_dolfin_element_classes(prefix, names, common_space)
    return _generate_dolfin_classes(prefix, names, common_space)

def _form_classname(prefix, form_name):
    "Return the C++ class name of the Form wrapper generated for the form called form_name."
    return "%sForm_%s" % (prefix, form_name)

def _basespace_classnames(rank):
    """Return the in-class typedef names of the argument spaces of a form of the given rank.

    Spaces are named FunctionSpace_0, FunctionSpace_1, ...; a linear form
    calls its single space TestSpace, and a bilinear form calls its two
    spaces TestSpace and TrialSpace.
    """
    names = ["FunctionSpace_%d" % i for i in range(rank)]
    if rank == 1:
        names[0] = "TestSpace"
    elif rank == 2:
        names[0] = "TestSpace"
        names[1] = "TrialSpace"
    return names

def _generate_dolfin_classes(prefix, form_names, common_space):
    """Generate code for all dolfin wrapper classes.

    @param prefix:
        String, prefix for all form names.
    @param form_names:
        List of UFCFormNames instances for each form to wrap.
    @param common_space:
        Tuple (form_index, space_index) of common function space if any, otherwise None.
        space_index indexes the argument spaces of form_names[form_index].
    @return:
        The generated code blocks joined by blank lines.
    """

    # Collection of code blocks
    blocks = []

    # Naming forms by rank (BilinearForm etc) only makes sense when no two
    # forms share a rank. Unlike min()/max(), this is also safe for an empty list.
    ranks = [fn.rank for fn in form_names]
    name_forms_by_rank = len(set(ranks)) == len(ranks)

    class_typedefs = []

    # Build sorted list of unique coefficient names in forms
    # NB! This will break down if multiple forms define
    # different coefficients using the same name, e.g.
    # if coefficient_names are auto-generated like "w%d".
    coefficient_names = set()
    for fn in form_names:
        coefficient_names.update(fn.coefficient_names)
    coefficient_names = sorted(coefficient_names)

    # Match ufc classnames with argument names (independently
    # of SFC or FFC naming convention outside this function).
    # Coefficient elements follow the fn.rank argument elements.
    # NB! This will break down if multiple forms define
    # different coefficients using the same name, e.g.
    # if coefficient_names are auto-generated like "w%d".
    ufc_finite_element_classnames = {}
    ufc_dof_map_classnames = {}
    for fn in form_names:
        for i, name in enumerate(fn.coefficient_names):
            ufc_finite_element_classnames[name] = fn.ufc_finite_element_classnames[fn.rank + i]
            ufc_dof_map_classnames[name] = fn.ufc_dof_map_classnames[fn.rank + i]

    # Generate FunctionSpace subclasses for coefficients independently of forms
    global_coefficientspace_classnames = {}
    for name in coefficient_names:
        spacename = "%sCoefficientSpace_%s" % (prefix, name)
        global_coefficientspace_classnames[name] = spacename
        blocks.append(generate_functionspace_class(spacename,
                                                   ufc_finite_element_classnames[name],
                                                   ufc_dof_map_classnames[name]))

    # Handle each form
    for fn in form_names:

        # Form class name
        classname = _form_classname(prefix, fn.name)
        if name_forms_by_rank and fn.rank <= 3:
            suffix = {0: "Functional", 1: "LinearForm", 2: "BilinearForm", 3: "TrilinearForm"}[fn.rank]
            class_typedefs.append((classname, "%s%s" % (prefix, suffix)))

        # Class names of generated classes, named by form and numbering
        functionspace_classnames = ["%s_FunctionSpace_%d" % (classname, i) for i in range(fn.rank + fn.num_coefficients)]
        coefficient_classnames   = ["%s_Coefficient_%s" % (classname, n) for n in fn.coefficient_names]

        # Class names for typedefs in class namespace
        coefficientspace_classnames = ["CoefficientSpace_%s" % n for n in fn.coefficient_names]
        basespace_classnames = _basespace_classnames(fn.rank)

        # Generate FunctionSpace subclasses for the argument spaces
        for i in range(fn.rank):
            blocks.append(generate_functionspace_class(functionspace_classnames[i],
                                                       fn.ufc_finite_element_classnames[i],
                                                       fn.ufc_dof_map_classnames[i]))

        # Generate typedefs to the shared FunctionSpace subclasses for coefficients
        for i, coefficient_name in enumerate(fn.coefficient_names):
            global_name = global_coefficientspace_classnames[coefficient_name]
            local_name = functionspace_classnames[fn.rank + i]
            blocks.append("typedef %s %s;" % (global_name, local_name))

        # Generate Form subclass
        formclass = generate_form_class(classname, fn.ufc_form_classname,
                                        functionspace_classnames, basespace_classnames, coefficientspace_classnames,
                                        coefficient_classnames, fn.coefficient_names)
        blocks.append(formclass)

    # Handle common function space
    if common_space is not None:
        form_index, space_index = common_space
        fn = form_names[form_index]
        # Look up the space names of the form that owns the common space,
        # not those of whichever form the loop above happened to handle last.
        spaceclassname = _basespace_classnames(fn.rank)[space_index]
        spaceclassname = "%s::%s" % (_form_classname(prefix, fn.name), spaceclassname)
        class_typedefs.append((spaceclassname, "FunctionSpace"))

    # Add class typedefs for optional FooBilinearForm naming
    if class_typedefs:
        code = "// Class typedefs\n"
        code += "\n".join("typedef %s %s;" % (a, b) for (a, b) in class_typedefs)
        blocks.append(code)

    # Join blocks together
    return "\n\n".join(blocks)

def _generate_dolfin_element_classes(prefix, name, common_space):
    """Generate wrapper code for a single finite element.

    Emits one FunctionSpace subclass, itself named "FunctionSpace", built from
    the first ufc::finite_element and ufc::dof_map class names of `name`.

    @param prefix:
        Not used; accepted for signature compatibility with _generate_dolfin_classes.
    @param name:
        UFCElementName instance describing the element.
    @param common_space:
        Not used; accepted for signature compatibility with _generate_dolfin_classes.
    @return:
        The generated FunctionSpace class code.
    """
    return generate_functionspace_class("FunctionSpace",
                                        name.ufc_finite_element_classnames[0],
                                        name.ufc_dof_map_classnames[0])

def generate_dolfin_code(prefix, header, form_names, common_space=None, add_guards=False):
    """Generate complete dolfin wrapper code with given generated names.

    The wrapper classes are placed inside a C++ namespace named `prefix`, so
    they are themselves generated without a class-name prefix.

    @param prefix:
        String; used as the C++ namespace and, upper-cased with an "_H"
        suffix, as the include-guard macro name.
    @param header:
        Code inserted near the top of the output, right after the module-level
        `comment` line and before the include lines.
    @param form_names:
        List of UFCFormNames instances for each form to wrap, or a single
        UFCElementName instance to wrap one finite element.
    @param common_space:
        Tuple (form_index, space_index) of common function space if any, otherwise None.
    @param add_guards:
        True iff include guards (#ifndef/#define ... #endif) should be added.
    @return:
        The generated code as a single string.
    """
    # Include-guard macro, e.g. prefix "Poisson" gives "POISSON_H"
    guard = ("%s_h" % prefix).upper()

    # Classes live in namespace `prefix`, hence the empty class-name prefix
    classes = generate_dolfin_classes("", form_names, common_space)

    sections = (comment,
                header,
                stl_includes,
                "",
                dolfin_includes,
                "\nnamespace %s\n{\n" % prefix,
                classes,
                "\n}")
    code = "\n".join(sections)

    if add_guards:
        code = "#ifndef %s\n#define %s\n\n%s\n#endif\n\n" % (guard, guard, code)

    return code
