#!/usr/bin/python

import getopt, sys, os

# Installation paths, filled in by find_paths(); empty until it has run.
#   training_basedir - current working directory at the time find_paths() ran
#   sphinxbinpath    - directory in which the "bw" binary is looked up
#   sphinxpath       - directory whose "scripts" and "etc" subdirectories are used
training_basedir = sphinxbinpath = sphinxpath = ""

def find_paths():
    """Locate the SphinxTrain installation relative to this script.

    Sets three module globals, all normalised to forward slashes:

    * ``sphinxbinpath`` - directory expected to contain ``bw`` / ``bw.exe``
    * ``sphinxpath`` - directory whose ``scripts`` and ``etc`` subdirectories
      are used by the other commands
    * ``training_basedir`` - the current working directory

    Candidate locations are probed in a fixed order and a later match
    overrides an earlier one.  The resolved paths are printed.

    Exits with status 1 (``SystemExit``) when neither ``bw`` nor ``bw.exe``
    exists in the selected binary directory.
    """
    global training_basedir
    global sphinxbinpath
    global sphinxpath

    # Everything is resolved relative to the real location of this script,
    # so a symlinked launcher still finds its installation.
    currentpath = os.path.dirname(os.path.realpath(__file__))

    # Binaries: ../libexec/sphinxtrain is the default location.
    sphinxbinpath = os.path.realpath(currentpath + "/../libexec/sphinxtrain")

    # "bw" is only the marker that binaries live under ../lib/sphinxtrain;
    # the binary directory is its parent.  Pointing at ".../bw" itself would
    # make the final "<sphinxbinpath>/bw" check below fail for this layout.
    if os.path.exists(currentpath + "/../lib/sphinxtrain/bw"):
        sphinxbinpath = os.path.realpath(currentpath + "/../lib/sphinxtrain")

    if os.path.exists(currentpath + "/../bin/Release"):
        sphinxbinpath = os.path.realpath(currentpath + "/../bin/Release")

    # Scripts and templates: ../lib/sphinxtrain is the default; lib64 or the
    # parent directory win when they contain the scripts/00.verify marker.
    sphinxpath = os.path.realpath(currentpath + "/../lib/sphinxtrain")

    if os.path.exists(currentpath + "/../lib64/sphinxtrain/scripts/00.verify"):
        sphinxpath = os.path.realpath(currentpath + "/../lib64/sphinxtrain")

    if os.path.exists(currentpath + "/../scripts/00.verify"):
        sphinxpath = os.path.realpath(currentpath + "/..")

    has_bw = (os.path.exists(sphinxbinpath + "/bw")
              or os.path.exists(sphinxbinpath + "/bw.exe"))
    if not has_bw:
        print("Failed to find sphinxtrain binaries. Check your installation")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. python -S).
        sys.exit(1)

    # The Perl scripts want forward slashes.
    training_basedir = os.getcwd().replace('\\', '/')
    sphinxpath = sphinxpath.replace('\\', '/')
    sphinxbinpath = sphinxbinpath.replace('\\', '/')

    print("Sphinxtrain path: " + sphinxpath)
    print("Sphinxtrain binaries path: " + sphinxbinpath)
    
def setup(task):
    if not os.path.exists("etc"):
        os.mkdir("etc")

    print "Setting up the database " + task

    out_cfg = open("etc/sphinx_train.cfg", "w")
    for line in open(sphinxpath + "/etc/sphinx_train.cfg", "r"):
        line = line.replace("___DB_NAME___", task)
        line = line.replace("___BASE_DIR___", training_basedir)
        line = line.replace("___SPHINXTRAIN_DIR___", sphinxpath)
        line = line.replace("___SPHINXTRAIN_BIN_DIR___", sphinxbinpath)
        out_cfg.write(line)
    out_cfg.close()

    out_cfg = open("etc/feat.params", "w")
    for line in open(sphinxpath + "/etc/feat.params", "r"):
        out_cfg.write(line)
    out_cfg.close()

# Training scripts, relative to <sphinxpath>/scripts, in the order in which
# a full run executes them.  The stage name used by the -s / -f options is
# the directory component with its numeric prefix removed (e.g. "ci_hmm").
steps = [
    "000.comp_feat/slave_feat.pl",
    "00.verify/verify_all.pl",
    "0000.g2p_train/g2p_train.pl",
    "01.lda_train/slave_lda.pl",
    "02.mllt_train/slave_mllt.pl",
    "05.vector_quantize/slave.VQ.pl",
    "10.falign_ci_hmm/slave_convg.pl",
    "11.force_align/slave_align.pl",
    "12.vtln_align/slave_align.pl",
    "20.ci_hmm/slave_convg.pl",
    "30.cd_hmm_untied/slave_convg.pl",
    "40.buildtrees/slave.treebuilder.pl",
    "45.prunetree/slave.state-tying.pl",
    "50.cd_hmm_tied/slave_convg.pl",
    "60.lattice_generation/slave_genlat.pl",
    "61.lattice_pruning/slave_prune.pl",
    "62.lattice_conversion/slave_conv.pl",
    "65.mmie_train/slave_convg.pl",
    "90.deleted_interpolation/deleted_interpolation.pl",
    "decode/slave.pl",
]

def _stage_name(step):
    """Return the stage name of a step path: "20.ci_hmm/x.pl" -> "ci_hmm"."""
    return step.split("/")[0].split(".")[-1]


def _exit_status(status):
    """Convert an os.system() result into a usable process exit code.

    On POSIX os.system() returns a wait()-encoded status (exit code 1 comes
    back as 256).  Passing that straight to sys.exit() would truncate it to
    its low byte and report success, so it is decoded here.  On other
    platforms the value is returned unchanged.
    """
    if os.name != "posix":
        return status
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        # Shell convention for "terminated by signal N".
        return 128 + os.WTERMSIG(status)
    return 1


def _run_step(step):
    """Run one training script; exit the process if it fails."""
    ret = os.system(sphinxpath + "/scripts/" + step)
    if ret != 0:
        # "or 1" keeps a failure from ever being reported as success.
        sys.exit(_exit_status(ret) or 1)


def run_stages(stages):
    """Run the selected stages, in the order they are listed.

    stages -- comma-separated stage names (e.g. "verify,ci_hmm").  Names
    that match no entry in ``steps`` are skipped.  The process exits with
    the failing script's exit code as soon as one script fails.
    """
    for stage in stages.split(","):
        for step in steps:
            if _stage_name(step) == stage:
                _run_step(step)


def run_from(stage):
    """Run the step named *stage* and every step after it in ``steps``.

    Nothing is run when *stage* matches no step.  The process exits with
    the failing script's exit code as soon as one script fails.
    """
    found = False
    for step in steps:
        if not found and _stage_name(step) == stage:
            found = True
        if found:
            _run_step(step)


def run():
    """Run every step in ``steps``; exit on the first failing script."""
    print("Running the training")
    for step in steps:
        _run_step(step)

def usage():
    """Print the command-line help text to standard output."""
    lines = [
        "",
        "Sphinxtrain processes the audio files and creates an acoustic model ",
        "for CMUSphinx toolkit. The data needs to have a certain layout ",
        "See the tutorial http://cmusphinx.sourceforge.net/wiki/tutorialam ",
        "for details",
        "",
        "Usage: sphinxtrain [options] <command>",
        "",
        "Commands:",
        "     -t <task> setup - copy configuration into database",
        "     [-s <stage1,stage2,stage3>] [-f <stage>] run - run the training or just selected stages",
    ]
    # A single parenthesised argument prints identically on Python 2 and 3.
    print("\n".join(lines))

def main():
    """Parse the command line and dispatch to ``setup`` or the run functions.

    Options:
      -h, --help             print usage and exit with status 0
      -t, --task <name>      database name (required by ``setup``)
      -s, --stages <a,b,c>   run only the listed stages
      -f, --from <stage>     run from the given stage onwards

    Exits with status -1 on an invalid option, a missing command, or a
    ``setup`` command without a task name.
    """
    try:
        # The trailing "=" marks a long option as taking a value; without it
        # "--task NAME" and "--task=NAME" cannot be used at all.
        opts, args = getopt.getopt(
            sys.argv[1:], "ht:s:f:", ["help", "task=", "stages=", "from="])
    except getopt.GetoptError as err:
        print(str(err))
        usage()
        sys.exit(-1)

    task = None
    stages = None
    from_stage = None

    for o, a in opts:
        if o in ("-h", "--help"):
            # Asking for help is not an error and must not start a command.
            usage()
            sys.exit(0)
        elif o in ("-t", "--task"):
            task = a
        elif o in ("-f", "--from"):
            from_stage = a
        elif o in ("-s", "--stages"):
            stages = a

    if not args:
        usage()
        sys.exit(-1)

    command = args[0]

    find_paths()

    if command == "setup":
        if task is None:
            print("No task name defined")
            sys.exit(-1)
        setup(task)
    elif command == "run":
        # -s takes precedence over -f when both are given.
        if stages is not None:
            run_stages(stages)
        elif from_stage is not None:
            run_from(from_stage)
        else:
            run()
    else:
        # NOTE(review): any unrecognised command starts a full training run.
        # Kept for backward compatibility - confirm this is intended.
        run()

# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()
