#! /usr/bin/env python

# pipeline to align reads to contigs implemented with ruffus (http://ruffus.org.uk)
from ruffus import *
import sys
import os.path
import subprocess
import getopt

def getBasename(inFile):
    """Return the file name of inFile with its directory and last extension removed.

    Only the final extension is stripped, so "reads.tar.gz" yields "reads.tar".
    """
    leafName = os.path.basename(inFile)
    return os.path.splitext(leafName)[0]

# Run a shell command
def runCmd(cmd):
    """Echo cmd to stdout, run it through the shell and return its exit status.

    cmd is passed to the shell unmodified (shell=True), so pipes and
    redirections inside the string are honoured. The call blocks until the
    command has finished.
    """
    print(cmd)
    # The child writes straight to the inherited descriptor, bypassing Python's
    # buffer. Flush first so the echoed command precedes the child's output
    # when stdout is redirected to a file or pipe.
    sys.stdout.flush()
    return subprocess.Popen(cmd, shell=True).wait()

def _align(ref, reads, out):
    """Run `bwa aln` for one reads file against ref, redirecting its output to out.

    The thread count comes from the module-level `threads` setting. The exit
    status of the command is not returned.
    """
    # NOTE(review): "-o 0" presumably disables gap opens - confirm against the
    # bwa aln documentation.
    runCmd("bwa aln -o 0 -t %d %s %s > %s" % (threads, ref, reads, out))

# Process commands
def runBWAIndex(inFile):
    """Run `bwa index` on inFile."""
    indexCmd = "bwa index " + inFile
    runCmd(indexCmd)

def runBWAAln(inFiles, contigsFile, outFile):
    """Align a pair of reads files to contigsFile and write the result to outFile.

    Each reads file is aligned with `bwa aln` into a temporary .bwsai file in
    the current working directory. The two alignments are then combined with
    `bwa sampe`, piped through `samtools view -Sb` into outFile, and the
    temporary files are deleted.

    inFiles     -- sequence of exactly two distinct reads file paths
    contigsFile -- path of the reference passed to bwa
    outFile     -- path of the file written by samtools view

    Prints an error and exits with status 1 when both reads paths are equal.
    """
    assert len(inFiles) == 2
    if inFiles[0] == inFiles[1]:
        print('Error: the two reads files are the same')
        sys.exit(1)

    contigsBasename = getBasename(contigsFile)
    tempSAI = [getBasename(x) + "." + contigsBasename + ".bwsai" for x in inFiles]
    if tempSAI[0] == tempSAI[1]:
        # Different paths can share a basename (a/r.fq and b/r.fq, or r.fq and
        # r.fastq). Without distinct names the second alignment would overwrite
        # the first and sampe would pair a file with itself.
        tempSAI = ["%s.%d.%s.bwsai" % (getBasename(x), idx + 1, contigsBasename)
                   for idx, x in enumerate(inFiles)]

    for readsFile, saiFile in zip(inFiles, tempSAI):
        _align(contigsFile, readsFile, saiFile)

    samCmd = "bwa sampe -s %s %s %s %s %s | samtools view -Sb -o %s -" % (contigsFile, tempSAI[0], tempSAI[1], inFiles[0], inFiles[1], outFile)
    runCmd(samCmd)

    # The intermediate alignments are only needed by sampe
    for saiFile in tempSAI:
        os.unlink(saiFile)

def runMerge(inFiles, outFile):
    """Run `samtools merge`, combining the files named in inFiles into outFile."""
    mergeArgs = {"out": outFile, "ins": " ".join(inFiles)}
    runCmd("samtools merge %(out)s %(ins)s" % mergeArgs)

def runSort(inFile, outFile):
    """Run `samtools sort` on inFile.

    samtools is given outFile with its last extension removed as the second
    argument, not outFile itself.
    """
    outPrefix = os.path.splitext(outFile)[0]
    # NOTE(review): "<in> <out prefix>" looks like the legacy (pre-1.0)
    # samtools sort calling convention - confirm against the installed version.
    runCmd("samtools sort %s %s" % (inFile, outPrefix))

    # The `samtools index` step on outFile is currently disabled.

def runSplitReads(inFiles, outFiles):
    """Split one reads file into two by running sga-deinterleave.pl.

    inFiles  -- sequence holding exactly one reads file path
    outFiles -- sequence holding exactly two output paths

    The helper script is looked up in the directory containing this script.
    """
    assert len(inFiles) == 1 and len(outFiles) == 2
    # __file__ can be a bare file name (e.g. "python sga-align"), in which case
    # dirname() is "" and the old "%s/..." form pointed at the filesystem root.
    scriptDir = os.path.dirname(os.path.abspath(__file__))
    splitScript = os.path.join(scriptDir, "sga-deinterleave.pl")
    cmd = "%s %s %s %s" % (splitScript, inFiles[0], outFiles[0], outFiles[1])
    runCmd(cmd)

#
def usage():
    """Print the command-line synopsis and the list of accepted options to stdout."""
    helpLines = (
        'usage: sga-align [options] <contigs file> <input files>',
        'align reads to contigs',
        'Options:',
        '       --name=STR          Use STR as the basename for the output files.',
        '       -t,--threads=N      Use N threads when running bwa.',
        # --help is accepted by the option parser, so list it as well
        '       --help              Print this usage message and exit.',
    )
    for helpLine in helpLines:
        print(helpLine)


# Variables
threads = 1
# Basename for the output files; set by --name. Left as None until parsed so
# a missing option is reported instead of surfacing later as a NameError.
projectName = None

# Parse the command line: options may be mixed with positional arguments
try:
    opts, args = getopt.gnu_getopt(sys.argv[1:], 't:', ["name=",
                                                        "threads=",
                                                        "help"])
except getopt.GetoptError as err:
    print(str(err))
    usage()
    sys.exit(2)

for (oflag, oarg) in opts:
    if oflag == "--name":
        projectName = oarg
    elif oflag == "--threads" or oflag == '-t':
        try:
            threads = int(oarg)
        except ValueError:
            print('Error, the number of threads must be an integer: %s' % oarg)
            usage()
            sys.exit(1)
    elif oflag == '--help':
        # Asking for help is not a failure
        usage()
        sys.exit(0)
    else:
        print('Unrecognized argument %s' % oflag)
        usage()
        sys.exit(1)

# Every error below exits non-zero so calling scripts can detect the failure
if len(args) < 2:
    print('Error, a contigs file and reads file must be provided')
    usage()
    sys.exit(1)

if projectName is None:
    print('Error, an output name must be provided with --name')
    usage()
    sys.exit(1)

bPEMode = True
contigFile = args[0]
readFiles = args[1:]

# The pipeline handles one reads file (split into two) or two reads files;
# the alignment step asserts that it receives exactly two.
if len(readFiles) > 2:
    print('Error, at most two reads files can be specified')
    sys.exit(1)

print(contigFile)
print(readFiles)

# Index the contigs with bwa. The task is declared with the contigs file as
# its input and "<contigs file>.bwt" as its output.
indexInFiles = contigFile
indexOutFiles = ["%s.bwt" % contigFile]
@files(indexInFiles, indexOutFiles)
def indexContigs(input, output):
    # output is not used here; only the input path is handed to bwa index
    runBWAIndex(input)

# Prepare the reads for alignment.
# In paired-end mode a single reads file is split into two files named after
# the project; otherwise the reads files are passed through untouched.
prepareInFiles = readFiles
if bPEMode and len(readFiles) == 1:
    prepareOutFiles = [projectName + ".r1.fastq", projectName + ".r2.fastq"]
    prepareAction = "split"
else:
    prepareOutFiles = readFiles
    prepareAction = "nothing"

@follows(indexContigs)
@files(prepareInFiles, prepareOutFiles, prepareAction)
def prepareReads(input, output, action):
    # Any action other than "split" leaves the reads as they are
    if action != "split":
        return
    runSplitReads(input, output)

# Align the prepared reads to the contigs.
# The task's inputs are the two reads files followed by the contigs file;
# its output is "<project name>.bam".
alignInFiles = list(prepareOutFiles)
alignOutFiles = projectName + ".bam"
alignInFiles.append(contigFile)
@follows(prepareReads)
@files(alignInFiles, alignOutFiles)
def alignReads(input, output):
    # The contigs file is the last entry. Take slices instead of popping so
    # the parameter list handed in by ruffus is left unmodified.
    contigsPath = input[-1]
    readPaths = list(input[:-1])
    assert len(readPaths) == 2
    runBWAAln(readPaths, contigsPath, output)

# Sort the aligned output into "<project name>.refsort.bam"
sortInFile = alignOutFiles
sortOutFile = "%s.refsort.bam" % projectName
@follows(alignReads)
@files(sortInFile, sortOutFile)
def sortBAM(input, output):
    # Thin wrapper: the work is done by runSort
    runSort(input, output)

# Ask ruffus to run the pipeline with sortBAM as the target task; the
# @follows decorators above chain it to alignReads, prepareReads and
# indexContigs.
pipeline_run([sortBAM])
