#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains a class UFCCell that represents the properties of a reference cell in an easily accessible way.
"""

# Copyright (C) 2008 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2008-08-12
# Last changed: 2008-08-12

import swiginac
from swiginac import matrix, exp, sqrt

import SyFi

from sfc.symbolic_utils import cross, inner, symbols
from sfc.common.output import sfc_debug

# TODO: this code could do with some tests!

class UFCCell:
    """Geometric and topological description of a UFC reference cell.

    Built from a SyFi reference polygon (ReferenceLine, ReferenceTriangle,
    ReferenceTetrahedron, ReferenceRectangle or ReferenceBox).  After
    construction the instance exposes:

      shape, facet_shape    -- shape names of the cell and of its facets
      nsd                   -- number of space dimensions
      num_vertices, num_edges, num_faces, num_facets, num_entities
      facet_vertices        -- local vertex indices of each facet
      vertices              -- vertex coordinates as nsd x 1 swiginac matrices
      facet_polygons        -- SyFi polygons of the facets (the vertices in 1D)
      facet_n               -- outward normal vector of each facet
      facet_equations       -- implicit equations vanishing on each facet
      edge_equations, face_equations -- same for edges/faces, where defined
    """

    def __init__(self, polygon):
        """Build the cell description for the SyFi reference polygon *polygon*.

        The polygon is recognised either by its SyFi type or by the name
        returned from ``polygon.str()``.

        Raises RuntimeError if the polygon is not a supported reference cell.
        """
        sfc_debug("Entering UFCCell.__init__")

        assert isinstance(polygon, SyFi.Polygon)
        self.polygon = polygon
        name = polygon.str()
        x, y, z = symbols("xyz")

        if isinstance(polygon, SyFi.ReferenceLine) or name == "ReferenceLine":
            self._init_interval(polygon, x)
        elif isinstance(polygon, SyFi.ReferenceTriangle) or name == "ReferenceTriangle":
            self._init_triangle(polygon, x, y)
        elif isinstance(polygon, SyFi.ReferenceTetrahedron) or name == "ReferenceTetrahedron":
            self._init_tetrahedron(polygon, x, y, z)
        elif isinstance(polygon, SyFi.ReferenceRectangle) or name == "ReferenceRectangle":
            self._init_quadrilateral(polygon, x, y)
        elif isinstance(polygon, SyFi.ReferenceBox) or name == "ReferenceBox":
            self._init_hexahedron(polygon, x, y, z)
        else:
            raise RuntimeError("Unknown polygon type %s." % name)

        sfc_debug("Leaving UFCCell.__init__")

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    def _reference_vertices(self, polygon):
        "Return the polygon's vertices as nsd x 1 matrices (nsd and num_vertices must be set)."
        return [matrix(self.nsd, 1, polygon.vertex(i)) for i in range(self.num_vertices)]

    def _outward_normals_2d(self, normalize):
        """Return one outward normal per facet of a 2D cell.

        Assumes self.facet_vertices traverses the cell boundary
        counterclockwise.  If *normalize* is true, each normal is scaled to
        unit length on the reference cell.
        """
        normals = []
        for fromvert, tovert in self.facet_vertices:
            # Tangent along the counterclockwise traversal.  With the interior
            # on the left, rotating it clockwise by 90 degrees, (t1, -t0),
            # points out of the cell.  The previous code used from - to,
            # which pointed inward, unlike the 1D/3D cells.
            t = self.vertices[tovert] - self.vertices[fromvert]
            n = matrix(2, 1, [t[1], -t[0]])
            if normalize:
                # TODO: Is this approach efficient? FIXME: Scaling to unit normal
                # on reference domain, but this isn't necessarily unit length
                # on the global domain...
                n = n / sqrt(inner(n, n))
            normals.append(n)
        return normals

    @staticmethod
    def _edge_equation(f1, f2):
        """Combine two face equations into one equation for their common edge.

        Returns f1*f1 + f2*f2.  For real coordinates this is zero exactly
        when both f1 and f2 are zero.  (The former exp(f1)*exp(f2) - 1 also
        vanished whenever f1 == -f2.)
        """
        return f1*f1 + f2*f2

    def _init_interval(self, polygon, x):
        "Set up the reference interval; its facets are its two end points."
        self.shape        = "interval"
        self.facet_shape  = "point"

        # ... dimensions:
        self.nsd          = 1
        self.num_vertices = 2
        self.num_edges    = 1
        self.num_faces    = 0
        self.num_facets   = self.num_vertices
        self.num_entities = (self.num_vertices, self.num_edges)

        # ... connectivity:
        self.facet_vertices = [
                               (0,),
                               (1,)
                              ]

        # ... geometry:
        self.vertices        = self._reference_vertices(polygon)
        self.facet_polygons  = self.vertices

        # ... normal vector: -1 at facet 0 (x == 0), +1 at facet 1 (x == 1):
        n = matrix(1, 1, [1])
        self.facet_n = [-n, +n] # TODO: is this what we want to do for the interval "normal"?

        # ... implicit equations:
        self.facet_equations = [x, x-1]

    def _init_triangle(self, polygon, x, y):
        "Set up the reference triangle; its facets are its three edges."
        self.shape        = "triangle"
        self.facet_shape  = "interval"

        # ... dimensions:
        self.nsd          = 2
        self.num_vertices = 3
        self.num_edges    = 3
        self.num_faces    = 1
        self.num_facets   = self.num_edges
        self.num_entities = (self.num_vertices, self.num_edges, self.num_faces)

        # ... connectivity:
        # facet vertices in counterclockwise direction around reference cell:
        self.facet_vertices = [
                               (1, 2),
                               (2, 0),
                               (0, 1)
                              ]

        # ... geometry:
        self.vertices        = self._reference_vertices(polygon)
        self.facet_polygons  = [polygon.line(i) for i in range(self.num_facets)]

        # ... outward unit normal vectors (unit length on the reference cell):
        self.facet_n = self._outward_normals_2d(normalize=True)

        # ... implicit equations (facet i lies on edge_equations[i] == 0):
        self.edge_equations  = [x+y-1, x, y]
        self.facet_equations = self.edge_equations

    def _init_tetrahedron(self, polygon, x, y, z):
        "Set up the reference tetrahedron; its facets are its four triangular faces."
        self.shape        = "tetrahedron"
        self.facet_shape  = "triangle"

        # ... dimensions:
        self.nsd          = 3
        self.num_vertices = 4
        self.num_edges    = 6
        self.num_faces    = 4
        self.num_facets   = self.num_faces
        self.num_entities = (self.num_vertices, self.num_edges, self.num_faces, 1)

        # ... connectivity:
        # facet vertices in counterclockwise direction around reference cell
        # (seen from outside; some tuples are reordered from the sorted form
        # in the trailing comment to get that orientation):
        self.facet_vertices = [
                               (1, 2, 3),
                               (0, 3, 2), # (0, 2, 3)
                               (0, 1, 3),
                               (0, 2, 1)  # (0, 1, 2)
                              ]

        # ... geometry:
        self.vertices        = self._reference_vertices(polygon)
        self.facet_polygons  = [polygon.triangle(i) for i in range(self.num_facets)]

        # ... outward unit normal vectors: cross product of two edges of each
        # counterclockwise-ordered facet, scaled to unit length.
        self.facet_n = []
        for vert in self.facet_vertices:
            t01 = self.vertices[vert[1]] - self.vertices[vert[0]]
            t02 = self.vertices[vert[2]] - self.vertices[vert[0]]
            n = cross(t01, t02)
            n = n / sqrt(inner(n,n)) # TODO: is this approach efficient? FIXME: Is this correct? Scaling to unit normal on reference domain...
            self.facet_n.append(n)

        # ... implicit equations:
        self.face_equations  = [ (x+y+z-1), x, y, z]
        self.facet_equations = self.face_equations

        # Each edge equation is cross(p - a, b - a) for the edge's end points
        # a, b, which vanishes iff p lies on the line through a and b.  In a
        # tetrahedron no two edges are parallel.
        # WARNING FIXME: this edge code is not verified!
        p = matrix(3, 1, [x, y, z])
        v0, v1, v2, v3 = self.vertices
        self.edge_equations  = [ cross( (p - v3), (v2 - v3) ),
                                 cross( (p - v3), (v1 - v3) ),
                                 cross( (p - v2), (v1 - v2) ),
                                 cross( (p - v3), (v0 - v3) ),
                                 cross( (p - v2), (v0 - v2) ),
                                 cross( (p - v1), (v0 - v1) ) ]

    def _init_quadrilateral(self, polygon, x, y):
        "Set up the reference quadrilateral; its facets are its four edges."
        self.shape        = "quadrilateral"
        self.facet_shape  = "interval"

        # ... dimensions:
        self.nsd          = 2
        self.num_vertices = 4
        self.num_edges    = 4
        self.num_faces    = 1
        self.num_facets   = self.num_edges
        self.num_entities = (self.num_vertices, self.num_edges, self.num_faces)

        # ... connectivity:
        # facet vertices in counterclockwise direction around reference cell:
        self.facet_vertices = [
                               (2, 3),
                               (1, 2),
                               (3, 0),
                               (0, 1)
                              ]

        # ... geometry:
        self.vertices        = self._reference_vertices(polygon)
        self.facet_polygons  = [polygon.line(i) for i in range(self.num_facets)]

        # ... outward normal vectors (not rescaled, as before):
        self.facet_n = self._outward_normals_2d(normalize=False)

        # ... implicit equations:
        # NOTE(review): with vertices (0,0), (1,0), (1,1), (0,1), facet 0 = (2, 3)
        # lies on y == 1 and facet 1 = (1, 2) on x == 1, so the first two
        # entries look swapped relative to facet_vertices.  Confirm against
        # SyFi's ReferenceRectangle vertex and line(i) ordering before changing.
        self.edge_equations  = [x-1, y-1, x, y]
        self.facet_equations = self.edge_equations

    def _init_hexahedron(self, polygon, x, y, z):
        "Set up the reference hexahedron; its facets are its six quadrilateral faces."
        self.shape        = "hexahedron"
        self.facet_shape  = "quadrilateral"

        # ... dimensions:
        self.nsd          = 3
        self.num_vertices = 8
        self.num_edges    = 12
        self.num_faces    = 6
        self.num_facets   = self.num_faces
        self.num_entities = (self.num_vertices, self.num_edges, self.num_faces, 1)

        # ... connectivity:
        # facet vertices in counterclockwise direction around reference cell:
        self.facet_vertices = [
                               (4, 5, 6, 7),
                               (2, 3, 7, 6), # (2, 3, 6, 7)
                               (1, 2, 6, 5), # (1, 2, 5, 6)
                               (0, 4, 7, 3), # (0, 3, 4, 7)
                               (0, 1, 5, 4), # (0, 1, 4, 5)
                               (0, 3, 2, 1)  # (0, 1, 2, 3)
                              ]

        # Current Lagrange basis function order in SyFi:
        #0 [[0, 0, 0], 0]
        #1 [[0, 0, 1], 0]
        #2 [[1, 0, 0], 0]
        #3 [[1, 0, 1], 0]
        #4 [[0, 1, 0], 0]
        #5 [[0, 1, 1], 0]
        #6 [[1, 1, 0], 0]
        #7 [[1, 1, 1], 0]
        # UFC vertex order for hexes requires reordering: TODO: don't need to do this?
        #0->0 [[0, 0, 0], 0]
        #2->1 [[1, 0, 0], 0]
        #6->2 [[1, 1, 0], 0]
        #4->3 [[0, 1, 0], 0]
        #1->4 [[0, 0, 1], 0]
        #3->5 [[1, 0, 1], 0]
        #7->6 [[1, 1, 1], 0]
        #5->7 [[0, 1, 1], 0]

        # ... geometry:
        self.vertices        = self._reference_vertices(polygon)
        self.facet_polygons  = [polygon.rectangle(i) for i in range(self.num_facets)]

        # ... outward normal vectors (not rescaled):
        self.facet_n = []
        for vert in self.facet_vertices:
            # counterclockwise ordering seen from outside of reference cell;
            # the normal is the cross product of the facet diagonals.
            t02 = self.vertices[vert[0]] - self.vertices[vert[2]]
            t13 = self.vertices[vert[1]] - self.vertices[vert[3]]
            n = cross(t02, t13)
            #n = n / swiginac.sqrt(inner(n,n)) # TODO: is this approach efficient? FIXME: Is this correct? Scaling to unit normal on reference domain...
            self.facet_n.append(n)

        # ... implicit equations:
        # all faces coincide with a cartesian plane
        self.face_equations  = [z-1, y-1, x-1, x, y, z]
        self.facet_equations = self.face_equations

        # Points on an edge satisfy the equations of two faces.  Each pair
        # below indexes face_equations for edge 0..11.
        f = self.face_equations
        edge_faces = [(0, 1), (0, 2), (0, 3), (0, 4),
                      (1, 3), (1, 2), (1, 5), (2, 4),
                      (2, 5), (3, 4), (3, 5), (4, 5)]
        self.edge_equations = [UFCCell._edge_equation(f[a], f[b]) for a, b in edge_faces]

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def find_entity(self, xi):
        """Find which cell entity the coordinate xi lies on.

        Entities are tried in order of increasing dimension (vertices, then
        edges, then faces, up to dimension nsd-1).  The first match is
        returned as a (dimension, index) pair.  If xi lies on none of them,
        (nsd, 0), the cell itself, is returned.
        """
        for dim in range(self.nsd):
            for index in range(self.num_entities[dim]):
                if self.entity_check(dim, index, xi):
                    return (dim, index)
        return (self.nsd, 0)

    def entity_check(self, d, i, p):
        """Return True if point *p* lies on entity number *i* of dimension *d*.

        d is -1 for facets, 0 for vertices, 1 for edges and 2 for faces.
        p may be a swiginac matrix, a single swiginac expression, or a list
        of coordinate values.

        Vertices are matched by exact coordinate equality.  Other entities
        are matched by substituting p into the entity's implicit equation
        and checking that inner(eq, eq) expands to zero.

        Raises ValueError for any other value of d.
        """
        # Fail fast: previously an unsupported d left `eq` unbound.
        if d not in (-1, 0, 1, 2):
            raise ValueError("Invalid entity dimension %r; expected -1, 0, 1 or 2." % (d,))

        # make p into a list of coordinate values
        if isinstance(p, swiginac.matrix):
            p = [p[k] for k in range(len(p))]
        elif isinstance(p, swiginac.basic):
            p = [p]

        # check if we match a vertex exactly
        if d == 0:
            return bool(self.vertices[i] == matrix(self.nsd, 1, p))

        # get implicit equation for this entity
        if d == -1:
            eq = self.facet_equations[i]
        elif d == 1:
            eq = self.edge_equations[i]
        else:
            eq = self.face_equations[i]

        # check if implicit equation is zero in this point, which means p is on the entity
        coords = symbols(["x", "y", "z"])
        for j, value in enumerate(p):
            eq = eq.subs(coords[j] == value)
        return inner(eq, eq).expand().is_zero()

    def facet_check(self, i, p):
        "Return True if p lies on facet i."
        return self.entity_check(-1, i, p)

    def vertex_check(self, i, p):
        "Return True if p coincides with vertex i."
        return self.entity_check(0, i, p)

    def edge_check(self, i, p):
        "Return True if p lies on edge i."
        return self.entity_check(1, i, p)

    def face_check(self, i, p):
        "Return True if p lies on face i."
        return self.entity_check(2, i, p)

    # ------------------------------------------------------------------
    # Comparison and printing
    # ------------------------------------------------------------------

    def __eq__(self, other):
        "Cells compare equal when they have the same shape."
        other_shape = getattr(other, "shape", None)
        if other_shape is None:
            # Not a cell-like object: let Python fall back to its default
            # instead of raising AttributeError.
            return NotImplemented
        return self.shape == other_shape

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Keep hashing consistent with __eq__: equal shapes hash equally.
        return hash(self.shape)

    def __str__(self):
        lines = [
            "Cell",
            "  shape:            %s" % self.shape,
            "  nsd:              %d" % self.nsd,
            "  num_vertices:     %d" % self.num_vertices,
            "  num_edges:        %d" % self.num_edges,
            "  num_faces:        %d" % self.num_faces,
            "  num_facets:       %d" % self.num_facets,
            "  vertices:         %s" % str(self.vertices),
            "  facet_equations:  %s" % str(self.facet_equations),
        ]
        return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Demo: build and print the cell description of several reference polygons,
    # each preceded by an empty line.
    for make_polygon in (SyFi.ReferenceLine,
                         SyFi.ReferenceTriangle,
                         SyFi.ReferenceTetrahedron):
        print("")
        print(UFCCell(make_polygon()))
