/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.run;

import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPParser;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.engine.EngineSetter;
import com.googlecode.clearnlp.feature.xml.DEPFtrXml;
import com.googlecode.clearnlp.reader.DEPReader;
import com.googlecode.clearnlp.run.AbstractRun;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/**
 * Command-line trainer for the dependency parser ({@link DEPParser}).
 * <p>
 * Training runs in rounds. Round 0 trains a model without a prior model. Each of the
 * following {@code -n} bootstrapping rounds regenerates training instances with the
 * previous round's model and trains a new model from them. The final model is written with
 * {@link EngineSetter#setDEPParser}. With {@code -sb}, every round's model is written to
 * {@code modelFile + "." + round} instead.
 */
public class DEPTrain
extends AbstractRun {
    /** Lexicon name for punctuation (not referenced within this class). */
    protected final String LEXICON_PUNCTUATION = "punctuation";
    @Option(name="-i", usage="input directory containing training files (required)", required=true, metaVar="<directory>")
    protected String s_trainDir;
    @Option(name="-c", usage="configuration file (required)", required=true, metaVar="<filename>")
    protected String s_configXml;
    @Option(name="-f", usage="feature template file (required)", required=true, metaVar="<filename>")
    protected String s_featureXml;
    @Option(name="-m", usage="model file (output; required)", required=true, metaVar="<filename>")
    protected String s_modelFile;
    @Option(name="-n", usage="bootstrapping level (default: 2)", required=false, metaVar="<integer>")
    protected int n_boot = 2;
    @Option(name="-sb", usage="if set, save all intermediate bootstrapping models", required=false, metaVar="<boolean>")
    protected boolean b_saveAllModels = false;

    /** Creates a trainer without reading arguments or running training. */
    public DEPTrain() {
    }

    /**
     * Initializes the option fields from {@code args} (via {@code initArgs}) and immediately
     * runs training with them. Any exception raised by training is printed to standard error
     * and not rethrown.
     *
     * @param args command-line arguments matching the {@code @Option} fields
     */
    public DEPTrain(String[] args) {
        this.initArgs(args);
        try {
            this.run(this.s_configXml, this.s_featureXml, this.s_trainDir, this.s_modelFile, this.n_boot);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the parser over all files in {@code trainDir} and writes the resulting model(s).
     *
     * @param configXml  path to the configuration XML file
     * @param featureXml path to the feature-template XML file
     * @param trainDir   directory whose files (listed with pattern {@code ".*"}) are used for training
     * @param modelFile  output model path. When {@link #b_saveAllModels} is set, each round is
     *                   written to {@code modelFile + "." + round} and {@code modelFile} itself
     *                   is not written.
     * @param nBoot      number of bootstrapping rounds after the initial round
     * @throws Exception if reading input, training, or writing a model fails
     */
    public void run(String configXml, String featureXml, String trainDir, String modelFile, int nBoot) throws Exception {
        // Both XML inputs are parsed eagerly, so the streams can be closed right away.
        Element eConfig;
        FileInputStream configIn = new FileInputStream(configXml);
        try {
            eConfig = UTXml.getDocumentElement(configIn);
        }
        finally {
            configIn.close();
        }
        DEPFtrXml xml;
        FileInputStream featureIn = new FileInputStream(featureXml);
        try {
            xml = new DEPFtrXml(featureIn);
        }
        finally {
            featureIn.close();
        }
        String[] trainFiles = UTFile.getSortedFileListBySize(trainDir, ".*", true);
        Set<String> sPunc = this.getLexica(eConfig, xml, trainFiles, -1);
        int boot = 0;
        // Round 0: no previous model.
        DEPParser parser = this.getTrainedParser(eConfig, xml, sPunc, trainFiles, null, -1, boot);
        if (this.b_saveAllModels) {
            EngineSetter.setDEPParser(modelFile + "." + boot, featureXml, parser);
        }
        // Bootstrapping rounds: each is seeded with the previous round's model.
        for (boot = 1; boot <= nBoot; ++boot) {
            parser = this.getTrainedParser(eConfig, xml, sPunc, trainFiles, parser.getModel(), -1, boot);
            if (this.b_saveAllModels) {
                EngineSetter.setDEPParser(modelFile + "." + boot, featureXml, parser);
            }
        }
        if (!this.b_saveAllModels) {
            EngineSetter.setDEPParser(modelFile, featureXml, parser);
        }
    }

    /**
     * Collects the forms of punctuation tokens from every training file except index
     * {@code devId}.
     *
     * @param eConfig    configuration element used to build the reader
     * @param xml        feature template supplying the punctuation label and cutoff
     * @param trainFiles training file paths
     * @param devId      index of a file to skip, or {@code -1} to use all files
     * @return forms labeled with {@code xml.getPunctuationLabel()}, filtered by
     *         {@code Prob1DMap.toSet(xml.getPunctuationCutoff())}
     * @throws Exception if reading a training file fails
     */
    protected Set<String> getLexica(Element eConfig, DEPFtrXml xml, String[] trainFiles, int devId) throws Exception {
        DEPReader reader = (DEPReader)this.getReader((Element)eConfig).o1;
        Prob1DMap mPunct = new Prob1DMap();
        String punctLabel = xml.getPunctuationLabel();
        int size = trainFiles.length;
        System.out.println("Collecting lexica:");
        for (int i = 0; i < size; ++i) {
            if (i == devId) continue;
            reader.open(UTInput.createBufferedFileReader(trainFiles[i]));
            try {
                DEPTree tree;
                while ((tree = reader.next()) != null) {
                    this.collectLexica(tree, mPunct, punctLabel);
                }
                System.out.print(".");
            }
            finally {
                // Close even if reading or collecting fails part-way through a file.
                reader.close();
            }
        }
        System.out.println();
        return mPunct.toSet(xml.getPunctuationCutoff());
    }

    /**
     * Adds to {@code mPunct} the form of every node in {@code tree} labeled {@code punctLabel}.
     * Index 0 is skipped (presumably the artificial root node; confirm against DEPTree).
     */
    private void collectLexica(DEPTree tree, Prob1DMap mPunct, String punctLabel) {
        int size = tree.size();
        for (int i = 1; i < size; ++i) {
            DEPNode node = tree.get(i);
            if (!node.isLabel(punctLabel)) continue;
            mPunct.add(node.form);
        }
    }

    /**
     * Trains a parser on every training file except index {@code devId}.
     * <p>
     * Each file gets its own {@link StringTrainSpace}, filled by a {@link TrainTask} on a
     * fixed-size thread pool whose size comes from {@code getNumOfThreads}. The spaces are then
     * merged into the first one, and a model is trained from it with {@code getModel}.
     *
     * @param eConfig    configuration element (its {@code train} child configures threads and learning)
     * @param xml        feature template
     * @param sPunc      punctuation lexicon
     * @param trainFiles training file paths
     * @param model      previous round's model, passed to each worker's parser; {@code null} for
     *                   the initial round
     * @param devId      index of a file to skip, or {@code -1} to use all files
     * @param boot       bootstrapping round number (currently unused by this method)
     * @return a parser wrapping the newly trained model
     * @throws IllegalArgumentException if no training file remains after skipping {@code devId}
     * @throws java.util.concurrent.ExecutionException if any worker fails; its exception is the cause
     * @throws InterruptedException if interrupted while waiting for workers (the interrupt flag is restored)
     * @throws Exception if training the model fails
     */
    public DEPParser getTrainedParser(Element eConfig, DEPFtrXml xml, Set<String> sPunc, String[] trainFiles, StringModel model, int devId, int boot) throws Exception {
        int labelCutoff = xml.getLabelCutoff(0);
        int featureCutoff = xml.getFeatureCutoff(0);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, "train");
        int numThreads = this.getNumOfThreads(eTrain);
        List<StringTrainSpace> spaces = new ArrayList<StringTrainSpace>();
        List<Future<?>> futures = new ArrayList<Future<?>>();
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        System.out.println("Collecting training instances:");
        try {
            for (int i = 0; i < trainFiles.length; ++i) {
                if (devId == i) continue;
                StringTrainSpace space = new StringTrainSpace(false, labelCutoff, featureCutoff);
                spaces.add(space);
                futures.add(executor.submit(new TrainTask(eConfig, xml, sPunc, trainFiles[i], model, space)));
            }
            executor.shutdown();
            // Future.get() waits for each task and rethrows (wrapped in ExecutionException)
            // anything it threw, so a failing file cannot silently leave a partial space.
            for (Future<?> future : futures) {
                future.get();
            }
        }
        catch (InterruptedException e) {
            // Workers may still be writing to their spaces, so do not go on to train on them.
            Thread.currentThread().interrupt();
            throw e;
        }
        finally {
            // No-op once every task has finished. On failure it stops queued tasks, so the
            // pool's non-daemon threads cannot keep the JVM alive.
            executor.shutdownNow();
        }
        System.out.println();
        if (spaces.isEmpty()) {
            throw new IllegalArgumentException("No training files to train on (files: " + trainFiles.length + ", devId: " + devId + ")");
        }
        StringTrainSpace space = spaces.get(0);
        int size = spaces.size();
        if (size > 1) {
            System.out.println("Merging training instances:");
            for (int i = 1; i < size; ++i) {
                StringTrainSpace other = spaces.get(i);
                space.appendSpace(other);
                // Release the per-file space once its contents live in the merged one.
                other.clear();
                System.out.print(".");
            }
            System.out.println();
        }
        StringModel trained = (StringModel)this.getModel(eTrain, space, 0);
        return new DEPParser(xml, sPunc, trained);
    }

    /**
     * Prints labeled attachment (LAS), unlabeled attachment (UAS) and label (LS) scores as
     * percentages.
     *
     * @param counts {@code counts[0]} is the denominator; {@code counts[1]}, {@code counts[2]}
     *               and {@code counts[3]} are the numerators for LAS, UAS and LS respectively
     */
    protected void printScores(int[] counts) {
        System.out.printf("- LAS: %5.2f (%d/%d)\n", 100.0 * (double)counts[1] / (double)counts[0], counts[1], counts[0]);
        System.out.printf("- UAS: %5.2f (%d/%d)\n", 100.0 * (double)counts[2] / (double)counts[0], counts[2], counts[0]);
        System.out.printf("- LS : %5.2f (%d/%d)\n", 100.0 * (double)counts[3] / (double)counts[0], counts[3], counts[0]);
    }

    /** Command-line entry point; see the {@code @Option} fields for arguments. */
    public static void main(String[] args) {
        new DEPTrain(args);
    }

    /**
     * Worker that feeds one training file through a {@link DEPParser} bound to its own
     * {@link StringTrainSpace}. The file is opened only when the task runs and is closed when
     * it finishes, so at most one file per pool thread is open at a time.
     */
    private class TrainTask
    implements Runnable {
        private final DEPParser d_parser;
        private final DEPReader d_reader;
        private final String s_trainFile;

        public TrainTask(Element eConfig, DEPFtrXml xml, Set<String> sPunc, String trainFile, StringModel model, StringTrainSpace space) {
            this.d_parser = model == null ? new DEPParser(xml, sPunc, space) : new DEPParser(xml, sPunc, model, space);
            this.d_reader = (DEPReader)DEPTrain.this.getReader((Element)eConfig).o1;
            this.s_trainFile = trainFile;
        }

        @Override
        public void run() {
            this.d_reader.open(UTInput.createBufferedFileReader(this.s_trainFile));
            try {
                DEPTree tree;
                while ((tree = this.d_reader.next()) != null) {
                    this.d_parser.parse(tree);
                }
            }
            finally {
                // Close even if parsing throws; the exception still reaches Future.get().
                this.d_reader.close();
            }
            System.out.print(".");
        }
    }
}

