# -*- coding: utf-8 -*-
"""
This module contains a representation class for forms, storing information about all form aspects.
"""

# Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2008-08-13
# Last changed: 2009-04-19

import copy
import math

import swiginac
import SyFi

import ufl
from ufl.algorithms import extract_max_quadrature_element_degree,\
    estimate_quadrature_degree, estimate_total_polynomial_degree

from sfc.common import sfc_assert
from sfc.common.names import finite_element_classname, dof_map_classname, \
                           integral_classname, form_classname
from sfc.common.utilities import matrix_to_list, list_to_vector, index_string, dot_product
from sfc.geometry import UFCCell, geometry_mapping
from sfc.quadrature import find_quadrature_rule
from sfc.symbolic_utils import symbolic_matrix, symbolic_vector, symbol, symbols, cross, sub, inner
from sfc.representation.elementrepresentation import create_syfi_polygon

class FormRepresentation(object):
    """Symbolic representation of a UFL form, used by the SFC code generator.

    For a single preprocessed form this class collects:
      - class names for the generated form, finite element, dof map and
        integral classes,
      - symbolic geometry tokens (vertex coordinates, affine mapping G,
        its determinant and inverse, facet normals, facet scaling factors),
      - the quadrature degree and, for the "quadrature" integration method,
        the cell and per-facet quadrature rules,
      - symbols for coefficient dofs, plus helpers producing expressions and
        symbols for basis functions, coefficients and their derivatives.
    """

    def __init__(self, formdata, element_reps, options):
        """Build the representation.

        formdata     -- UFL form data (preprocessed form, rank, elements, cell, ...).
        element_reps -- mapping ufl.FiniteElement -> sfc ElementRepresentation.
        options      -- SFC options; options.code.integral.* is read here.
        """
        # UFL data about the form
        self.formdata = formdata

        # Mapping : ufl.FiniteElement -> sfc.ElementRepresentation
        self.element_reps = element_reps

        # SFC options
        self.options   = options

        # Fetch some common variables from formdata for shorter names in rest of code
        self.form             = self.formdata.preprocessed_form
        self.rank             = self.formdata.rank
        self.num_coefficients = self.formdata.num_coefficients
        self.domain           = self.formdata.cell.domain()

        # Some strings
        self.signature = repr(self.form)
        self._define_classnames()

        # Build some symbols and expressions.
        # NOTE: geometry tokens must be built before the quadrature rule is
        # defined, because the facet quadrature rules read self.cell.
        self._build_geometry_tokens()
        self._build_argument_tokens()

        self._define_quadrature_rule()

        # Build running count of all unique ufl elements, numbered in order
        # of first appearance in formdata.sub_elements.
        # mapping : ufl_element -> running count of unique elements
        self.element_number = {}
        for e in self.formdata.sub_elements:
            sfc_assert(e in self.element_reps, "Missing ElementRepresentation!")
            if e not in self.element_number:
                self.element_number[e] = len(self.element_number)

    def _define_quadrature_rule(self):
        """Set self.quad_order, self.quad_rule and self.facet_quad_rules.

        The degree is taken from options.code.integral.integration_order. If
        that is None, it falls back, in order, to the maximal quadrature
        element degree, the estimated quadrature degree and the estimated
        total polynomial degree of the form. Quadrature rules are built only
        when the integration method is "quadrature"; otherwise both rule
        attributes are None.
        """
        # Get integration order from options or elements
        d = self.options.code.integral.integration_order
        if d is None:
            d = extract_max_quadrature_element_degree(self.form)
        if d is None:
            d = estimate_quadrature_degree(self.form) # TODO: Option for which estimation method to choose!
        if d is None:
            d = estimate_total_polynomial_degree(self.form)
        sfc_assert(d is not None, "All quadrature degree estimation rules failed, this shouldn't happen!")
        self.quad_order = d

        # Determine quadrature rule from options or elements
        if self.options.code.integral.integration_method == "quadrature":
            self.quad_rule = find_quadrature_rule(self.domain, self.quad_order)
            self._build_facet_quadrature_rules()
        else:
            self.quad_rule = None
            self.facet_quad_rules = None

    def _build_facet_quadrature_rules(self):
        """Map a reference facet quadrature rule onto every facet of the reference cell.

        Stores one rule per facet in self.facet_quad_rules. Points are embedded
        in the cell's nsd-dimensional reference space through the facet's
        affine mapping. Weights are scaled for the diagonal facet (facet 0) of
        triangles and tetrahedra.
        """
        # pick a rule
        cell = self.cell
        reference_quad_rule = find_quadrature_rule(cell.facet_shape, self.quad_order, family="gauss")

        self.facet_quad_rules = []
        for facet in range(cell.num_facets):
            # facet polygon on the reference cell
            facet_polygon = cell.facet_polygons[facet]

            # TODO: That dimensions get right here is rather magical,
            #       I suspect that initSyFi plays a major role here...
            # affine mapping between reference facet and facet in the reference cell
            facet_G, facet_x0 = geometry_mapping(facet_polygon)

            # Scaling factor for diagonal facets, used to scale the quadrature
            # rule to a different reference domain. This is NOT the same as
            # the global/local scaling factor facet_D!
            D = 1.0
            if facet == 0:
                if cell.shape == "triangle":
                    D = math.sqrt(2)
                elif cell.shape == "tetrahedron":
                    D = math.sqrt(3)

            # scale weights by scaling factor
            weights = [ float(w)*D for w in reference_quad_rule.weights ]

            # apply affine mapping to quadrature points
            points = []
            for p in reference_quad_rule.points:
                # extend (n-1)D point p into nD point x:
                x = swiginac.matrix(len(facet_x0), 1, list(p)+[0.0])
                # apply mapping in nD space:
                newx = facet_G * x + facet_x0
                # store as a list:
                xlist = [float(newx[i].evalf()) for i in range(len(newx))]
                points.append(xlist)

            # Make a shallow copy of the reference quadrature rule and modify it.
            # Only attributes are rebound below, so the reference rule is untouched.
            facet_quad_rule = copy.copy(reference_quad_rule)

            if cell.shape != "interval":
                facet_quad_rule.nsd += 1
                assert facet_quad_rule.nsd == cell.nsd
            facet_quad_rule.comment += "\nMapped to polygon in higher dimension.\n"
            facet_quad_rule.weights = weights
            facet_quad_rule.points = points
            self.facet_quad_rules.append(facet_quad_rule)

    def _define_classnames(self):
        """Generate names for the form, element, dof map and integral classes.

        Integral class names are also indexed by measure domain id, separately
        for cell, exterior facet and interior facet integrals.
        """
        fd = self.formdata
        self.classname = form_classname(self.form, self.options)
        self.fe_names   = [finite_element_classname(e) for e in fd.elements]
        self.dm_names   = [dof_map_classname(e) for e in fd.elements]

        self.itg_names = dict((itg, integral_classname(itg, self.classname)) \
                           for itg in self.form.integrals())
        f = self.form
        self.citg_names = dict((itg.measure().domain_id(), self.itg_names[itg]) for itg in f.cell_integrals())
        self.eitg_names = dict((itg.measure().domain_id(), self.itg_names[itg]) for itg in f.exterior_facet_integrals())
        self.iitg_names = dict((itg.measure().domain_id(), self.itg_names[itg]) for itg in f.interior_facet_integrals())

    def _build_geometry_tokens(self):
        "Build tokens for variables derived from cell."

        self.polygon = create_syfi_polygon(self.domain)
        self.cell = UFCCell(self.polygon)

        cell = self.cell
        nsd = cell.nsd

        # vx[i][j] = component j of coordinate of vertex i in cell.
        # vx_sym holds local symbol names; vx_expr holds the matching
        # generated-code expressions c.coordinates[i][j].
        self.vx_sym = []
        self.vx_expr = []
        for i in range(cell.num_vertices):
            s = swiginac.matrix(nsd, 1, symbols("vx%d_%d" % (i,j) for j in range(nsd)))
            e = swiginac.matrix(nsd, 1, symbols("c.coordinates[%d][%d]" % (i,j) for j in range(nsd)))
            self.vx_sym.append(s)
            self.vx_expr.append(e)

        # Build a global polygon from the vertex coordinates:
        #   (using Triangle and Tetrahedron for the hypercube affine mappings
        #   is correct for rectangular and trapezoidal shapes, while using the polygons
        #   Box and Rectangle from SyFi would only support straight rectangular shapes)
        p = [matrix_to_list(vx) for vx in self.vx_sym]
        if cell.shape == "interval":
            polygon = SyFi.Line(*p)
        elif cell.shape == "triangle":
            polygon = SyFi.Triangle(*p)
        elif cell.shape == "tetrahedron":
            polygon = SyFi.Tetrahedron(*p)
        elif cell.shape == "quadrilateral":
            polygon = SyFi.Rectangle(*p) # SyFi.Triangle(p[0], p[1], p[3]) #SyFi.Rectangle(p[0], p[2])     # TODO: Better affine mapping?
        elif cell.shape == "hexahedron":
            polygon = SyFi.Box(*p)       # SyFi.Tetrahedron(p[0], p[1], p[3], p[4]) # SyFi.Box(p[0], p[6]) # TODO: Better affine mapping?
        else:
            # Previously this fell through to an opaque UnboundLocalError.
            sfc_assert(False, "Unsupported cell shape '%s'." % cell.shape)
        self.global_polygon = polygon
        #self.global_cell = UFCCell(polygon)

        # Geometry mapping
        self.G_sym  = symbolic_matrix(nsd, nsd, "G")              # FIXME: Make sure we don't mess up transpose here again...
        self.x0_sym = self.vx_sym[0] #symbolic_vector(nsd, "x0")  # FIXME: Make sure this is consistent.
        self.G_expr, self.x0_expr = geometry_mapping(self.global_polygon)     # TODO: Non-affine mappings?

        # Signed determinant of G
        self.detGtmp_sym  = symbol("detGtmp")
        self.detGtmp_expr = swiginac.determinant(self.G_sym)

        # Unsigned determinant of G
        self.detG_sym  = symbol("detG")
        self.detG_expr = swiginac.abs(self.detGtmp_sym)

        # Sign of det G, negative if cell is inverted.
        # NOTE: this must be derived from the *signed* determinant detGtmp;
        # detG is already |detGtmp|, so detG/|detG| would always be +1.
        self.detG_sign_sym = symbol("detG_sign")
        self.detG_sign_expr = self.detGtmp_sym / swiginac.abs(self.detGtmp_sym)

        # Inverse of geometry mapping (expression simplification using cofactor):
        # detGtmp_expr*inverse(G) is the adjugate of G, which is polynomial in the
        # entries of G, and dividing by the symbol detGtmp avoids repeating the
        # determinant expression in every entry.
        self.Ginv_sym  = symbolic_matrix(nsd, nsd, "Ginv")
        self.Ginv_expr = (self.detGtmp_expr*swiginac.inverse(self.G_sym)) / self.detGtmp_sym
        #self.Ginv_expr = swiginac.inverse(self.G_sym)

        # Local and global coordinates
        self.xi_sym = symbols(("x","y","z")[:nsd])
        # TODO: Change to xiN, but x,y,z is used as local coordinate symbols in rest of SyFi.
        #self.xi_sym  = symbols(("xi0", "xi1", "xi2")[:nsd])
        self.x_sym  = symbols(("x0","x1","x2")[:nsd])
        x_expr = [sum(self.G_sym[i,j]*self.xi_sym[j] for j in range(nsd)) for i in range(nsd)] # FIXME: Transpose?
        self.x_expr = swiginac.matrix(nsd, 1, x_expr)

        # Quadrature variables (coordinates are self.xi_sym?)
        self.quad_weight_sym = symbol("quad_weight")
        self.D_sym = symbol("D")

        # Facet normal expressions with sign scaling for inverted cells
        self.n_sym = symbolic_vector(nsd, "n")
        self.n_expr = [self.detG_sign_sym*n for n in self.cell.facet_n]

        # Facet mapping
        #self.facet_G_sym = FIXME
        #self.facet_G_expr[facet] = FIXME

        # facet_D_expr[facet] = ratio between global facet measure and
        # reference facet measure, built from G-mapped facet vertices.
        self.facet_D_sym = symbol("facet_D")
        self.facet_D_expr = []

        # --- Compute facet_D
        sqrt = swiginac.sqrt
        if cell.shape == "interval":
            self.facet_D_expr = [swiginac.numeric(1) for facet in range(cell.num_facets)]

        elif cell.shape == "triangle":
            for facet in range(cell.num_facets):
                facet_polygon = cell.facet_polygons[facet]
                # facet_polygon is a line
                v0 = self.G_sym * list_to_vector(facet_polygon.vertex(0))
                v1 = self.G_sym * list_to_vector(facet_polygon.vertex(1))
                v  = sub(v1, v0)
                D = sqrt( inner(v, v) ) # |v|
                if facet == 0: # diagonal facet
                    D = D / sqrt(2) # FIXME: Is this right?
                # TODO: Should in general use determinant of mapping instead of this linear approximation
                self.facet_D_expr.append(D)

        elif cell.shape == "tetrahedron":
            for facet in range(cell.num_facets):
                facet_polygon = cell.facet_polygons[facet]
                # facet_polygon is a triangle
                v0 = facet_polygon.vertex(0)
                v1 = facet_polygon.vertex(1)
                v2 = facet_polygon.vertex(2)
                # map local reference coordinates to get global cell coordinates
                v0 = self.G_sym * list_to_vector(v0)
                v1 = self.G_sym * list_to_vector(v1)
                v2 = self.G_sym * list_to_vector(v2)

                # A = |a x b|/2, D = A/(1/2) = |a x b|,
                # with A = global triangle area, and
                # D being A over the reference area 1/2
                c = cross(sub(v1, v0), sub(v2, v0))
                D = sqrt( inner(c, c) ) # |c|
                if facet == 0: # diagonal facet
                    D = D / sqrt(3) # FIXME: Is this right?
                # TODO: Should in general use determinant of mapping instead of this linear approximation
                self.facet_D_expr.append(D)

        elif cell.shape == "quadrilateral":
            for facet in range(cell.num_facets):
                facet_polygon = cell.facet_polygons[facet]
                # facet_polygon is a line
                v0 = self.G_sym * list_to_vector(facet_polygon.vertex(0))
                v1 = self.G_sym * list_to_vector(facet_polygon.vertex(1))
                v  = sub(v1, v0)
                D  = sqrt(inner(v,v))
                # TODO: Should in general use determinant of mapping instead of this linear approximation
                self.facet_D_expr.append(D)

        elif cell.shape == "hexahedron":
            for facet in range(cell.num_facets):
                facet_polygon = cell.facet_polygons[facet]
                # facet_polygon is a quadrilateral
                v0 = self.G_sym * list_to_vector(facet_polygon.vertex(0))
                v1 = self.G_sym * list_to_vector(facet_polygon.vertex(1))
                v2 = self.G_sym * list_to_vector(facet_polygon.vertex(2))
                v3 = self.G_sym * list_to_vector(facet_polygon.vertex(3))
                # compute midpoints
                m0 = (v0 + v1)/2
                m1 = (v1 + v2)/2
                m2 = (v2 + v3)/2
                m3 = (v3 + v0)/2
                # area is length of cross product of vectors between opposing midpoints
                c = cross(sub(m2, m0), sub(m1, m3))
                D = sqrt( inner(c, c) )
                # TODO: Should in general use determinant of mapping instead of this linear approximation
                self.facet_D_expr.append(D)

    def _build_argument_tokens(self):
        """Build dof symbols for each coefficient.

        self.w_dofs[i][j] is the symbol "w[i][j]" for dof j of coefficient i.
        The coefficient elements follow the rank argument elements in
        formdata.elements.
        """
        # --- Coefficient dofs
        self.w_dofs = []
        for i in range(self.num_coefficients):
            rep = self.element_reps[self.formdata.elements[self.rank + i]]
            wdofs = symbols("w[%d][%d]" % (i, j) for j in range(rep.local_dimension))
            self.w_dofs.append(wdofs)

    def __str__(self):
        """Return a human-readable dump of the geometry and argument tokens."""
        geometry_tokens = [
            ("self.cell",            self.cell),
            #("self.global_cell",     self.global_cell),
            ("self.vx_sym",          self.vx_sym),
            ("self.vx_expr",         self.vx_expr),
            ("self.xi_sym",          self.xi_sym),
            ("self.x_sym",           self.x_sym),
            ("self.x_expr",          self.x_expr),
            ("self.G_sym",           self.G_sym),
            ("self.G_expr",          self.G_expr),
            ("self.detGtmp_sym",     self.detGtmp_sym),
            ("self.detGtmp_expr",    self.detGtmp_expr),
            ("self.detG_sym",        self.detG_sym),
            ("self.detG_expr",       self.detG_expr),
            ("self.Ginv_sym",        self.Ginv_sym),
            ("self.Ginv_expr",       self.Ginv_expr),
            ("self.quad_weight_sym", self.quad_weight_sym),
            ]
        argument_tokens = [
            ("self.w_dofs",          self.w_dofs),
            ]

        def fmt(label, value):
            # The value is always passed as a single %-argument, so tuple-valued
            # tokens are printed whole instead of being unpacked by '%'.
            return "        %-20s = %s\n" % (label, value)

        lines = ["FormRepresentation:\n",
                 "\n", # TODO: Add more here
                 "    Geometry tokens:\n"]
        lines.extend(fmt(label, value) for label, value in geometry_tokens)
        lines.append("    Argument tokens:\n")
        lines.extend(fmt(label, value) for label, value in argument_tokens)
        return "".join(lines)

    #def element_representation(self, iarg):
    #    return self.element_reps[self.formdata.elements[iarg]]

    #def unique_element_number(self, iarg):
    #    return self.element_number[self.formdata.elements[iarg]]

    #===============================================================================
    #    def _build_argument_cache(self):
    #        use_symbols = self.integration_method == "quadrature"
    #
    #        self.v_cache = []
    #        self.Dv_cache = []
    #        for iarg in range(self.rank + self.num_coefficients):
    #            elm = self.formdata.elements[iarg]
    #            rep = self.element_reps[elm]
    #            self.v_cache.append([])
    #            self.Dv_cache.append([])
    #            for i in range(rep.local_dimension):
    #                for component in rep.value_components:
    #                    ve = self.v_expr(iarg, i, component)
    #                    vs = self.v_sym(iarg, i, component)
    #                    self.v_cache[iarg][i][component] = (vs, ve)
    #
    #                    for derivatives in [(d,) for d in range(self.cell.nsd)]:
    #                        Dve = self.Dv_expr(iarg, i, component, derivatives, use_symbols)
    #                        Dvs = self.Dv_sym(iarg, i, component, derivatives)
    #                        self.Dv_cache[iarg][i][component][derivatives] = (Dvs, Dve)
    #===============================================================================

    # --- Basis function and coefficient data for arguments

    def v_expr(self, iarg, i, component):
        """Return the expression for component `component` of basis function i
        in argument space iarg, as given by its ElementRepresentation."""
        elm = self.formdata.elements[iarg]
        #component, elm = elm.extract_component(component)

        rep = self.element_reps[elm]
        v = rep.basis_function(i, component)
        return v

    def v_sym(self, iarg, i, component, on_facet):
        """Return the generated-code symbol for a basis function value.

        The symbol is named "qt[iq].v_<num,i,subcomponent>", or
        "qt[facet][iq]...." when on_facet is true. Here num is the unique
        number of the sub element that holds the component. If the basis
        function evaluates to the numeric 0, that 0 is returned instead of a
        symbol.
        """
        elm = self.formdata.elements[iarg]

        e = self.v_expr(iarg, i, component)
        e2 = e.evalf()
        if e2.nops() == 0:
            if e2 == 0: # FIXME: Remove this line after fixing code generation, this leads to "double 0;" statements
                return e2

        scomponent, selm = elm.extract_component(component)
        num = self.element_number[selm]

        prefix = "qt[facet][iq]." if on_facet else "qt[iq]."
        suffix = index_string((num, i) + scomponent)
        name = "%sv_%s" % (prefix, suffix)
        s = symbol(name)
        return s

    def w_expr(self, iarg, component, use_symbols, on_facet):
        """Return sum_i w[iarg][i] * v_i for coefficient iarg (0-based among coefficients).

        The basis functions v_i are symbols (see v_sym) when use_symbols is
        true, and full expressions otherwise.
        """
        elm = self.formdata.elements[self.rank+iarg]
        rep = self.element_reps[elm]
        dim = rep.local_dimension
        if use_symbols:
            vs = [self.v_sym(iarg+self.rank, i, component, on_facet=on_facet) for i in range(dim)]
        else:
            vs = [self.v_expr(iarg+self.rank, i, component) for i in range(dim)]
        w = dot_product(self.w_dofs[iarg], vs)
        return w

    def w_sym(self, iarg, component):
        """Return the symbol "w_<iarg,component>" for a coefficient value."""
        name = "w_%s" % index_string((iarg,) + component)
        s = symbol(name)
        return s

    # --- Basis function and coefficient data for derivatives of arguments

    def ddxi(self, f, i):
        """Differentiate f with respect to local coordinate xi_i."""
        return swiginac.diff(f, self.xi_sym[i])

    def ddx(self, f, i):
        """Differentiate f with respect to global coordinate x_i via the chain rule with Ginv."""
        return sum(self.Ginv_sym[j,i]*self.ddxi(f, j) for j in range(self.cell.nsd)) # TODO: Verify this line (in particular transposed or not!)

    def dv_expr(self, iarg, i, component, d):
        """Return expression for dv/dxi, with v being a particular
        component of basis function i in argument space iarg,
        and d is a tuple (i.e. multiindex) of xi directions (local coordinates)."""
        sfc_assert(len(d) == 1, "Higher order derivatives not implemented.")
        d = tuple(sorted(d))
        v = self.v_expr(iarg, i, component)
        dv = v
        for k in d:
            dv = self.ddxi(dv, k)
        return dv

    def dv_sym(self, iarg, i, component, d, on_facet):
        """Return the generated-code symbol for a local derivative dv/dxi.

        The name is "qt[iq].dv_dxi_<iarg,i,component,d>", with prefix
        "qt[facet][iq]." when on_facet is true. A numeric 0 is returned
        instead when the derivative evaluates to 0.
        """
        sfc_assert(len(d) == 1, "Higher order derivatives not implemented.")
        d = tuple(sorted(d))

        e = self.dv_expr(iarg, i, component, d)
        e2 = e.evalf()
        if e2.nops() == 0:
            if e2 == 0: # FIXME: Remove this line after fixing code generation, this leads to "double 0;" statements
                return e2

        prefix = "qt[facet][iq]." if on_facet else "qt[iq]."
        suffix = index_string((iarg, i) + component + d)
        name = "%sdv_dxi_%s" % (prefix, suffix)
        return symbol(name)

    def Dv_expr(self, iarg, i, component, d, use_symbols, on_facet):
        """Return expression for dv/dx, with v being a particular
        component of basis function i in argument space iarg,
        and d is a tuple (i.e. multiindex) of x directions (global coordinates)."""
        sfc_assert(len(d) == 1, "Higher order derivatives not implemented.")
        d = tuple(sorted(d))
        # Get local derivatives in all directions
        if use_symbols:
            dv = [self.dv_sym(iarg, i, component, (j,), on_facet=on_facet) for j in range(self.cell.nsd)]
        else:
            dv = [self.dv_expr(iarg, i, component, (j,)) for j in range(self.cell.nsd)]
        # Apply mapping to dv
        j, = d
        Dv = sum(self.Ginv_sym[k, j]*dv[k] for k in range(self.cell.nsd)) # TODO: Verify this line (in particular transposed or not!)
        return Dv

    def Dv_sym(self, iarg, i, component, d, on_facet):
        """Return the symbol "Dv_<iarg,i,component,d>" for a global derivative dv/dx,
        or a numeric 0 when the (non-symbolic) derivative evaluates to 0."""
        sfc_assert(len(d) == 1, "Higher order derivatives not implemented.")
        d = tuple(sorted(d))

        e = self.Dv_expr(iarg, i, component, d, False, on_facet=on_facet)
        e2 = e.evalf()
        if e2.nops() == 0:
            if e2 == 0: # FIXME: Remove this line after fixing code generation, this leads to "double 0;" statements
                return e2

        name = "Dv_%s" % index_string((iarg, i) + component + d)
        return symbol(name)

    def Dw_expr(self, iarg, component, d, use_symbols, on_facet):
        """Return sum_i w[iarg][i] * dv_i/dx for coefficient iarg (0-based among coefficients)."""
        sfc_assert(len(d) == 1, "Higher order derivatives not implemented.")
        d = tuple(sorted(d))
        elm = self.formdata.elements[self.rank+iarg]
        rep = self.element_reps[elm]
        dim = rep.local_dimension
        if use_symbols:
            vs = [self.Dv_sym(self.rank+iarg, i, component, d, on_facet=on_facet) for i in range(dim)]
        else:
            vs = [self.Dv_expr(self.rank+iarg, i, component, d, use_symbols=False, on_facet=on_facet) for i in range(dim)]
        w = dot_product(self.w_dofs[iarg], vs)
        return w

    def Dw_sym(self, iarg, component, d):
        """Return the symbol "Dw_<iarg,component,d>" for a coefficient derivative."""
        sfc_assert(len(d) == 1, "Higher order derivatives not implemented.")
        d = tuple(sorted(d))
        name = "Dw_%s" % index_string((iarg,) + component + d)
        return symbol(name)

