/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.run;

import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.dependency.srl.SRLabeler;
import com.googlecode.clearnlp.engine.EngineSetter;
import com.googlecode.clearnlp.feature.xml.SRLFtrXml;
import com.googlecode.clearnlp.reader.SRLReader;
import com.googlecode.clearnlp.run.AbstractRun;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.pair.Pair;
import java.io.FileInputStream;
import java.util.Set;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/**
 * Command-line trainer for the semantic role labeler ({@link SRLabeler}).
 *
 * <p>Collects down/up path sets from the training files, trains an initial labeler, then retrains
 * it {@code -n} more times, each round seeded with the models of the previous round.
 */
public class SRLTrain
extends AbstractRun {
    @Option(name="-i", usage="the directory containing training files (input; required)", required=true, metaVar="<directory>")
    protected String s_trainDir;
    @Option(name="-c", usage="the configuration file (input; required)", required=true, metaVar="<filename>")
    protected String s_configXml;
    @Option(name="-f", usage="the feature file (input; required)", required=true, metaVar="<filename>")
    protected String s_featureXml;
    @Option(name="-m", usage="the model file (output; required)", required=true, metaVar="<filename>")
    protected String s_modelFile;
    @Option(name="-n", usage="the bootstrapping level (default: 2)", required=false, metaVar="<integer>")
    protected int n_boot = 2;
    @Option(name="-sb", usage="if set, save all intermediate bootstrapping models", required=false, metaVar="<boolean>")
    protected boolean b_saveAllModels = false;

    public SRLTrain() {
    }

    /**
     * Parses {@code args} and immediately runs training with the parsed options.
     * Any exception raised during training is printed to standard error and not rethrown.
     *
     * @param args command-line arguments
     */
    public SRLTrain(String[] args) {
        this.initArgs(args);
        try {
            this.run(this.s_configXml, this.s_featureXml, this.s_trainDir, this.s_modelFile, this.n_boot);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains an initial labeler, then retrains it {@code nBoot} times using the previous round's
     * models.
     *
     * <p>If {@code b_saveAllModels} is set, the labeler of every round {@code k} (0..nBoot) is
     * written to {@code modelFile + "." + k}, and nothing is written to {@code modelFile} itself.
     * Otherwise only the final labeler is written, to {@code modelFile}.
     */
    private void run(String configXml, String featureXml, String trainDir, String modelFile, int nBoot) throws Exception {
        Element eConfig = readConfigElement(configXml);
        SRLReader reader = (SRLReader)this.getReader(eConfig).o1;
        SRLFtrXml xml = readFeatureXml(featureXml);
        String[] trainFiles = UTFile.getSortedFileList(trainDir);

        Pair<Set<String>, Set<String>> p = this.getDownUpSets(reader, xml, trainFiles, -1);
        Set<String> sDown = p.o1;
        Set<String> sUp = p.o2;

        // Round 0: no previous models to bootstrap from.
        SRLabeler labeler = this.getTrainedLabeler(eConfig, reader, xml, trainFiles, null, sDown, sUp, -1);
        if (this.b_saveAllModels) {
            EngineSetter.setSRLabeler(modelFile + "." + 0, featureXml, labeler);
        }
        for (int boot = 1; boot <= nBoot; ++boot) {
            labeler = this.getTrainedLabeler(eConfig, reader, xml, trainFiles, labeler.getModels(), sDown, sUp, -1);
            if (this.b_saveAllModels) {
                EngineSetter.setSRLabeler(modelFile + "." + boot, featureXml, labeler);
            }
        }
        if (!this.b_saveAllModels) {
            EngineSetter.setSRLabeler(modelFile, featureXml, labeler);
        }
    }

    /** Reads the document element of the configuration XML file, closing the file afterwards. */
    private static Element readConfigElement(String configXml) throws Exception {
        FileInputStream in = new FileInputStream(configXml);
        try {
            return UTXml.getDocumentElement(in);
        }
        finally {
            in.close();
        }
    }

    /** Loads the SRL feature XML file, closing the file afterwards. */
    private static SRLFtrXml readFeatureXml(String featureXml) throws Exception {
        FileInputStream in = new FileInputStream(featureXml);
        try {
            return new SRLFtrXml(in);
        }
        finally {
            in.close();
        }
    }

    /**
     * Labels every training file except the one at index {@code devId}, then builds the down-path
     * and up-path sets using the cutoffs from {@code xml}. Pass {@code -1} as {@code devId} to
     * use all files.
     *
     * @param reader     reader used to iterate trees in each training file
     * @param xml        feature configuration supplying the down/up cutoffs
     * @param trainFiles training file paths
     * @param devId      index of a file to skip, or -1 to skip none
     * @return a pair of (down-path set, up-path set)
     */
    public Pair<Set<String>, Set<String>> getDownUpSets(SRLReader reader, SRLFtrXml xml, String[] trainFiles, int devId) {
        SRLabeler labeler = new SRLabeler();
        System.out.println("Collecting lexica:");
        labelTrainFiles(reader, labeler, trainFiles, devId);

        Set<String> sDown = labeler.getDownSet(xml.getDownCutoff());
        Set<String> sUp = labeler.getUpSet(xml.getUpCutoff());
        System.out.printf("- down-paths: size = %d, cutoff = %d\n", sDown.size(), xml.getDownCutoff());
        System.out.printf("- up-paths  : size = %d, cutoff = %d\n", sUp.size(), xml.getUpCutoff());
        return new Pair<Set<String>, Set<String>>(sDown, sUp);
    }

    /**
     * Collects training instances from every training file except the one at index
     * {@code devId}, trains one model per train space, and returns a labeler built from them.
     *
     * @param eConfig    configuration element; its first {@code train} child configures training
     * @param reader     reader used to iterate trees in each training file
     * @param xml        feature configuration
     * @param trainFiles training file paths
     * @param models     models from a previous round, or {@code null} for the first round
     * @param sDown      down-path set
     * @param sUp        up-path set
     * @param devId      index of a file to skip, or -1 to skip none
     * @return a labeler wrapping the newly trained models
     */
    public SRLabeler getTrainedLabeler(Element eConfig, SRLReader reader, SRLFtrXml xml, String[] trainFiles, StringModel[] models, Set<String> sDown, Set<String> sUp, int devId) throws Exception {
        StringTrainSpace[] spaces = new StringTrainSpace[2];
        // NOTE(review): every space uses cutoff index 0; confirm whether index i was intended.
        for (int i = 0; i < spaces.length; ++i) {
            spaces[i] = new StringTrainSpace(false, xml.getLabelCutoff(0), xml.getFeatureCutoff(0));
        }
        SRLabeler labeler = models == null ? new SRLabeler(xml, spaces, sDown, sUp) : new SRLabeler(xml, models, spaces, sDown, sUp);
        System.out.println("Collecting training instances:");
        labelTrainFiles(reader, labeler, trainFiles, devId);

        // The train element is the same for every space; look it up once.
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, "train");
        StringModel[] trained = new StringModel[spaces.length];
        for (int i = 0; i < trained.length; ++i) {
            trained[i] = (StringModel)this.getModel(eTrain, spaces[i], i);
        }
        return new SRLabeler(xml, trained, sDown, sUp);
    }

    /**
     * Runs {@code labeler} over every tree of each training file except the one at index
     * {@code devId}.
     *
     * <p>Prints one dot per processed file and a newline at the end. The reader is closed after
     * each file even if labeling throws.
     */
    private static void labelTrainFiles(SRLReader reader, SRLabeler labeler, String[] trainFiles, int devId) {
        for (int i = 0; i < trainFiles.length; ++i) {
            if (i == devId) {
                continue;
            }
            reader.open(UTInput.createBufferedFileReader(trainFiles[i]));
            try {
                DEPTree tree;
                while ((tree = reader.next()) != null) {
                    labeler.label(tree);
                }
                System.out.print(".");
            }
            finally {
                reader.close();
            }
        }
        System.out.println();
    }

    public static void main(String[] args) {
        new SRLTrain(args);
    }
}

