#!/usr/bin/python

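# Sample dhcpd.conf declarations of the kinds this script prints
# (host, subnet and shared-network blocks):
#
#host fantasia {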
#  dhcp-client-identifier
#  hardware ethernet 08:00:07:26:c0:a5;
#  fixed-address fantasia.fugue.com;
#}

#subnet 1.2.3.0 netmask 255.255.255.0 {
#  option routers 1.2.3.4;
#  range 1.2.3.100 1.2.3.200;
#  option domain-name "foo.bar.example.com";
#}

#shared-network "foo" {
#}

import sys
from socket import inet_aton, inet_ntoa

from ldaptor.protocols.ldap import ldapclient, distinguishedname, ldapconnector, ldapsyntax
from ldaptor.protocols import pureber, pureldap
from ldaptor import usage, ldapfilter, config
from twisted.internet import protocol, reactor, defer


def my_aton_octets(ip):
    """Return the IPv4 address string *ip* as an unsigned 32-bit integer.

    Parsing is delegated to socket.inet_aton; the four packed bytes are
    folded big-endian (first octet ends up in the top 8 bits).
    """
    value = 0
    for octet in bytearray(inet_aton(ip)):
        # value << 8 has zero low bits, so OR-ing the octet equals adding it.
        value = (value << 8) | octet
    return value

def my_aton_numbits(num):
    """Return a 32-bit netmask whose top *num* bits are set.

    num: integer prefix length. Values <= 0 give 0; values >= 32 give
    0xFFFFFFFF (all bits set).
    """
    if num <= 0:
        return 0
    width = min(num, 32)
    return (0xFFFFFFFF << (32 - width)) & 0xFFFFFFFF

def my_aton(ip):
    """Convert *ip* to a 32-bit integer.

    If int(ip) succeeds, *ip* is treated as a prefix length and turned
    into the matching netmask (my_aton_numbits); otherwise it is parsed
    as an address string (my_aton_octets).
    """
    try:
        prefixLength = int(ip)
    except ValueError:
        return my_aton_octets(ip)
    return my_aton_numbits(prefixLength)

def my_ntoa(n):
    """Format the low 32 bits of integer *n* as a dotted-quad string.

    Inverse of my_aton_octets: the most significant byte becomes the
    first octet.
    """
    packed = bytearray((n >> shift) & 0xFF for shift in (24, 16, 8, 0))
    return inet_ntoa(bytes(packed))

class HostIPAddress:
    """One IP address belonging to a Host; this is what gets placed in a Net."""

    def __init__(self, host, ipAddress):
        self.host = host            # owning Host object
        self.ipAddress = ipAddress  # address string as read from LDAP

    def printDHCP(self, domain, prefix=''):
        """Print a dhcpd ``host`` declaration for this address.

        The host is named "<host.name>.<domain>"; one ``hardware ethernet``
        line is emitted per MAC address. Every output line is prefixed
        with *prefix*.
        """
        lines = ['# %s' % self.host.dn,
                 'host %s.%s {' % (self.host.name, domain)]
        lines.extend('\thardware ethernet %s;' % mac
                     for mac in self.host.macAddresses)
        lines.append('\tfixed-address %s;' % self.ipAddress)
        lines.append('}')
        print('\n'.join(prefix + line for line in lines))

    def __repr__(self):
        # The host is shown by id() only, to avoid recursing into
        # Host.__repr__, which itself lists its HostIPAddress objects.
        return '%s(host=%s, ipAddress=%r)' % (self.__class__.__name__,
                                             id(self.host),
                                             self.ipAddress)

class Host:
    """A host entry read from LDAP: its DN, name, IP and MAC addresses."""

    def __init__(self, dn, name, ipAddresses, macAddresses=()):
        self.dn = dn
        self.name = name
        # Wrap each address so it can be assigned independently to the Net
        # that contains it.
        self.ipAddresses = [HostIPAddress(self, ip) for ip in ipAddresses]
        self.macAddresses = macAddresses

    def __repr__(self):
        # Fields are separated by ", " (the original omitted the separator
        # before macAddresses, producing "...]macAddresses=...").
        return '%s(dn=%r, name=%r, ipAddresses=%r, macAddresses=%r)' % (
            self.__class__.__name__,
            self.dn,
            self.name,
            self.ipAddresses,
            self.macAddresses,
        )

class Net:
    """An IP subnet read from LDAP, plus the host addresses placed in it.

    ``address`` and ``mask`` are stored as given; ``mask`` may be anything
    my_aton accepts (an address string or an integer prefix length).
    """

    def __init__(self, dn, name, address, mask,
                 routers=(),
                 dhcpRanges=(),
                 winsServers=(),
                 domainNameServers=(),
                 ):
        self.dn = dn
        self.name = name
        self.address = address
        self.mask = mask
        self.routers = routers
        self.dhcpRanges = dhcpRanges
        self.winsServers = winsServers
        self.domainNameServers = domainNameServers
        # HostIPAddress objects assigned to this net via addHost().
        self.hosts = []

    def isInNet(self, ipAddress):
        """Return 1 if *ipAddress* & mask equals the net address, else 0.

        The net address itself is not masked, so it must already have its
        host bits cleared.
        """
        netNumber = my_aton(self.address)
        netmask = my_aton(self.mask)
        candidate = my_aton(ipAddress)
        return 1 if candidate & netmask == netNumber else 0

    def addHost(self, host):
        """Attach a HostIPAddress; it must lie inside this net."""
        assert self.isInNet(host.ipAddress)
        self.hosts.append(host)

    def printDHCP(self, domain, prefix=''):
        """Print a dhcpd ``subnet`` block, then one ``host`` block per Host.

        The domain-name option is "<name>.<domain>". Each Host is printed at
        most once, for the first of its addresses in ``self.hosts``.
        """
        dottedMask = my_ntoa(my_aton(self.mask))
        lines = [
            '# %s' % self.dn,
            'subnet %s netmask %s {' % (self.address, dottedMask),
            '\toption domain-name "%s.%s";' % (self.name, domain),
        ]
        if self.routers:
            lines.append('\toption routers %s;' % (', '.join(self.routers)))
        lines.extend('\trange %s;' % rangeSpec for rangeSpec in self.dhcpRanges)
        if self.winsServers:
            lines.append('\toption netbios-name-servers %s;'
                         % (', '.join(self.winsServers)))
        if self.domainNameServers:
            lines.append('\toption domain-name-servers %s;'
                         % (', '.join(self.domainNameServers)))
        lines.append('}')

        print('\n'.join(prefix + line for line in lines))

        alreadyPrinted = set()
        for hostIP in self.hosts:
            owner = hostIP.host
            if owner in alreadyPrinted:
                continue
            alreadyPrinted.add(owner)
            hostIP.printDHCP(self.name + '.' + domain, prefix=prefix)

    def __repr__(self):
        return '%s(dn=%r, name=%r, address=%r, mask=%r)' % (
            self.__class__.__name__,
            self.dn,
            self.name,
            self.address,
            self.mask,
        )

class SharedNet:
    """A dhcpd ``shared-network`` grouping several Net objects under a name."""

    def __init__(self, name):
        self.name = name
        self.nets = []

    def addNet(self, net):
        """Add *net*; members are printed in insertion order."""
        self.nets.append(net)

    def printDHCP(self, domain):
        """Print a ``shared-network`` block around every member net.

        Members are indented with one tab; a blank line follows the block.
        """
        print('shared-network "%s" {' % self.name)
        for member in self.nets:
            member.printDHCP(domain, prefix='\t')
        # Closing brace followed by an empty line.
        print('}\n')

def haveHosts(hosts, nets, sharedNets, dnsDomain):
    """Place every host address into its net, then print the whole config.

    hosts: Host objects; each of their ipAddresses goes to the first net
        whose isInNet() accepts it. Plain *nets* are tried before members of
        *sharedNets*. Addresses matching no net are reported on stderr and
        dropped.
    nets: Net objects not belonging to a shared network.
    sharedNets: dict mapping shared-network name -> SharedNet.
    dnsDomain: domain passed to every printDHCP call.

    Shared networks are printed first, then the plain nets.
    """
    # The candidate list is loop-invariant; build it once instead of
    # re-flattening the shared networks for every single address.
    candidates = list(nets)
    for sharedNet in sharedNets.values():
        candidates.extend(sharedNet.nets)

    for host in hosts:
        for hostIP in host.ipAddresses:
            for net in candidates:
                if net.isInNet(hostIP.ipAddress):
                    net.addHost(hostIP)
                    break
            else:
                # Report the address itself, not the HostIPAddress repr.
                sys.stderr.write("IP address %s is in no net, discarding.\n"
                                 % hostIP.ipAddress)

    for sharedNet in sharedNets.values():
        sharedNet.printDHCP(dnsDomain)
    for net in nets:
        net.printDHCP(dnsDomain)

def only(e, attr):
    """Return the single value of attribute *attr* on LDAP entry *e*.

    Raises:
        RuntimeError: if the attribute has zero values or more than one
            value (the message distinguishes the two cases).
        Whatever e[attr] raises when the attribute is absent.
    """
    values = e[attr]
    if len(values) != 1:
        if values:
            problem = 'has multiple values'
        else:
            # The original message claimed "multiple values" here too.
            problem = 'has no values'
        raise RuntimeError("object %s attribute %r %s: %s"
                           % (e.dn, attr, problem, values))
    return next(iter(values))

def _cbGetHosts(entries):
    entries = []
    for e in entries:
        cn = only(e, 'cn')
	self.entries.append(Host(str(e.dn),
				 str(cn),
				 map(str, e['ipHostNumber']),
				 map(str, e.get('macAddress', ()))))
    return entries

def getHosts(e, filter):
    """Search below entry *e* for hosts; results go through _cbGetHosts.

    Only entries having both ``cn`` and ``ipHostNumber`` match. When
    *filter* (a parsed LDAP filter object) is given, it is ANDed in.
    Returns the search Deferred.
    """
    required = pureldap.LDAPFilter_and(value=(
        pureldap.LDAPFilter_present('cn'),
        pureldap.LDAPFilter_present('ipHostNumber'),
    ))
    if filter:
        query = pureldap.LDAPFilter_and(value=(filter, required))
    else:
        query = required

    wanted = ['cn', 'ipHostNumber', 'macAddress']
    deferred = e.search(filterObject=query, attributes=wanted)
    deferred.addCallback(_cbGetHosts)
    return deferred

def haveNets(data, e, baseDN, filt, dnsDomain):
    """Callback run once nets are loaded: fetch hosts, then call haveHosts.

    data: the (nets, sharedNets) pair produced by the nets search.
    baseDN: accepted for the callback signature but not used here.
    """
    nets, sharedNets = data
    hostsLoaded = getHosts(e, filt)
    hostsLoaded.addCallback(haveHosts, nets, sharedNets, dnsDomain)
    return hostsLoaded

def _cbGetNets(entries):
    sharedNetworks = {}
    entries = []

    for e in entries:
        cn=only(e, 'cn')
        ipNetworkNumber=only(e, 'ipNetworkNumber')
        ipNetmaskNumber=only(args, 'ipNetmaskNumber')
        net = Net(objectName, cn,
                  ipNetworkNumber, ipNetmaskNumber,
                  routers=args.get('router', ()),
                  dhcpRanges=args.get('dhcpRange', ()),
                  winsServers=args.get('winsServer', ()),
                  domainNameServers=args.get('domainNameServer', ()),
                  )
        if args.has_key('sharedNetworkName'):
            name = only(e, 'sharedNetworkName')
            if not sharedNetworks.has_key(name):
                sharedNetworks[name]=SharedNet(name)
            sharedNetworks[name].addNet(net)
        else:
            entries.append(net)

    return (entries, sharedNetworks)

def getNets(e, filter):
    """Search below entry *e* for networks; results go through _cbGetNets.

    Only entries having ``cn``, ``ipNetworkNumber`` and ``ipNetmaskNumber``
    match. When *filter* (a parsed LDAP filter object) is given, it is
    ANDed in. Returns the search Deferred.
    """
    required = pureldap.LDAPFilter_and(value=(
        pureldap.LDAPFilter_present('cn'),
        pureldap.LDAPFilter_present('ipNetworkNumber'),
        pureldap.LDAPFilter_present('ipNetmaskNumber'),
    ))
    if filter:
        query = pureldap.LDAPFilter_and(value=(filter, required))
    else:
        query = required

    wanted = [
        'cn',
        'ipNetworkNumber',
        'ipNetmaskNumber',
        'router',
        'dhcpRange',
        'winsServer',
        'domainNameServer',
        'sharedNetworkName',
    ]
    deferred = e.search(filterObject=query, attributes=wanted)
    deferred.addCallback(_cbGetNets)
    return deferred

def search(client, baseDN, filter, dnsDomain):
    """Run the export against a connected LDAP *client*, rooted at *baseDN*.

    Nets are fetched first; haveNets then fetches hosts and prints.
    """
    root = ldapsyntax.LDAPEntry(client=client, dn=baseDN)
    netsLoaded = getNets(root, filter)
    netsLoaded.addCallback(haveNets, root, baseDN, filter, dnsDomain)
    return netsLoaded


# Process exit code handed to sys.exit() by main(); error() sets it to 1.
exitStatus = 0


def error(fail):
    """Errback: print the failure's message to stderr and mark the run failed."""
    global exitStatus
    sys.stderr.write('fail: %s\n' % fail.getErrorMessage())
    exitStatus = 1

def main(cfg, filter_text, dnsDomain):
    """Connect anonymously to LDAP, print the dhcpd config, and exit.

    cfg: LDAP configuration supplying the base DN and service-location
        overrides.
    filter_text: optional extra LDAP filter string (None for no filter).
    dnsDomain: domain used when naming subnets and hosts.

    Exits with status 1 if no base DN is configured, otherwise with
    exitStatus once the reactor stops.
    """
    try:
        baseDN = cfg.getBaseDN()
    except config.MissingBaseDNError as e:
        sys.stderr.write("%s: %s.\n" % (sys.argv[0], e))
        sys.exit(1)

    from twisted.python import log
    log.startLogging(sys.stderr, setStdout=0)

    filt = None
    if filter_text is not None:
        filt = ldapfilter.parseFilter(filter_text)

    creator = ldapconnector.LDAPClientCreator(reactor,
                                              ldapclient.LDAPClient)
    connected = creator.connectAnonymously(
        dn=baseDN,
        overrides=cfg.getServiceLocationOverrides())
    connected.addCallback(search, baseDN, filt, dnsDomain)
    connected.addErrback(error)
    # Stop the reactor whether the export succeeded or failed.
    connected.addBoth(lambda _: reactor.stop())

    reactor.run()
    sys.exit(exitStatus)

class MyOptions(usage.Options,
                usage.Options_service_location,
                usage.Options_base_optional):
    """LDAPtor dhcpd config file exporter"""

    # --dns-domain: domain appended to net names in the generated config.
    optParameters = (
        ('dns-domain', None, 'example.com', "DNS domain name"),
    )

    def parseArgs(self, filter=None):
        # Optional single positional argument: an LDAP filter string,
        # stored as opts['filter'] (None when omitted).
        self.opts.update({'filter': filter})

if __name__ == "__main__":
    import sys

    # Parse the command line; usage errors print a one-line message and
    # exit with status 1.
    try:
        opts = MyOptions()
        opts.parseOptions()
    except usage.UsageError as ue:
        sys.stderr.write('%s: %s\n' % (sys.argv[0], ue))
        sys.exit(1)

    cfg = config.LDAPConfig(baseDN=opts['base'],
                            serviceLocationOverrides=opts['service-location'])
    filterText = opts['filter']
    dnsDomainName = opts['dns-domain']
    main(cfg, filterText, dnsDomainName)
