/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.relationextractor.ae.features;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
import org.apache.ctakes.relationextractor.data.analysis.Utils;
import org.apache.ctakes.typesystem.type.syntax.WordToken;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;

/**
 * Builds relation features from a word-vector table.
 *
 * <p>For a pair of arguments the extractor emits, in this order:
 * <ol>
 *   <li>{@code arg1_dim_N}: the components of the vector looked up for the last word of
 *       the first argument;</li>
 *   <li>{@code arg2_dim_N}: the same for the second argument;</li>
 *   <li>{@code arg_cos_sim}: the cosine similarity of those two vectors;</li>
 *   <li>{@code average_dim_N}: the component-wise mean of the vectors of the
 *       {@link WordToken}s located between the two arguments. These features are
 *       omitted entirely when there is no token between the arguments.</li>
 * </ol>
 *
 * <p>Words are lower-cased before lookup; a word that has no entry in the table is
 * represented by the vector stored under the key {@code "oov"}.
 */
public class EmbeddingFeatureExtractor
implements RelationFeaturesExtractor<IdentifiedAnnotation, IdentifiedAnnotation> {
    /** Key of the fallback vector used for words that are not in the table. */
    private static final String OOV_KEY = "oov";

    /** Length of the {@code "oov"} vector; every looked-up vector is read up to this index. */
    private final int numberOfDimensions;
    private final Map<String, List<Double>> wordVectors;

    /**
     * Creates an extractor over the given word-vector table.
     *
     * @param wordVectors mapping from word to vector; must contain an {@code "oov"} entry,
     *        whose size determines the number of dimensions used for every feature
     * @throws NullPointerException if {@code wordVectors} is {@code null}
     * @throws IllegalArgumentException if the table has no {@code "oov"} vector
     */
    public EmbeddingFeatureExtractor(Map<String, List<Double>> wordVectors) {
        if (wordVectors == null) {
            throw new NullPointerException("wordVectors must not be null");
        }
        List<Double> oovVector = wordVectors.get(OOV_KEY);
        if (oovVector == null) {
            throw new IllegalArgumentException(
                "wordVectors must contain a vector for the out-of-vocabulary key \"" + OOV_KEY + "\"");
        }
        this.wordVectors = wordVectors;
        this.numberOfDimensions = oovVector.size();
    }

    /**
     * Extracts the embedding features for one candidate relation.
     *
     * @param jCas the CAS holding both arguments
     * @param arg1 first relation argument
     * @param arg2 second relation argument
     * @return the features in the order described in the class documentation
     */
    @Override
    public List<Feature> extract(JCas jCas, IdentifiedAnnotation arg1, IdentifiedAnnotation arg2) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<Feature>();

        List<Double> arg1Vector = this.lookupVector(Utils.getLastWord(jCas, arg1));
        List<Double> arg2Vector = this.lookupVector(Utils.getLastWord(jCas, arg2));

        // Feature names are built by concatenation rather than String.format("%d") so
        // that they do not depend on the JVM's default locale.
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            features.add(new Feature("arg1_dim_" + dim, arg1Vector.get(dim)));
        }
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            features.add(new Feature("arg2_dim_" + dim, arg2Vector.get(dim)));
        }

        double similarity = this.computeCosineSimilarity(arg1Vector, arg2Vector);
        features.add(new Feature("arg_cos_sim", Double.valueOf(similarity)));

        List<WordToken> wordsBetweenArgs = JCasUtil.selectBetween(jCas, WordToken.class, arg1, arg2);
        if (wordsBetweenArgs.isEmpty()) {
            // Nothing to average: the average_dim_* features are left out in this case.
            return features;
        }

        // Accumulate in a primitive array to avoid allocating a boxed list per token.
        double[] sum = new double[this.numberOfDimensions];
        for (WordToken wordToken : wordsBetweenArgs) {
            List<Double> wordVector = this.lookupVector(wordToken.getCoveredText());
            for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
                sum[dim] += wordVector.get(dim);
            }
        }
        double tokenCount = wordsBetweenArgs.size();
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            features.add(new Feature("average_dim_" + dim, Double.valueOf(sum[dim] / tokenCount)));
        }
        return features;
    }

    /**
     * Computes the cosine similarity of two vectors over the first
     * {@code numberOfDimensions} components.
     *
     * @param vector1 first vector
     * @param vector2 second vector
     * @return the cosine similarity, or {@code 0.0} when the product of the two norms is
     *         zero (the similarity is undefined there and would otherwise be {@code NaN})
     */
    public double computeCosineSimilarity(List<Double> vector1, List<Double> vector2) {
        double dotProduct = 0.0;
        double norm1 = 0.0;
        double norm2 = 0.0;
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            double a = vector1.get(dim);
            double b = vector2.get(dim);
            dotProduct += a * b;
            norm1 += a * a;
            norm2 += b * b;
        }
        double denominator = Math.sqrt(norm1) * Math.sqrt(norm2);
        if (denominator == 0.0) {
            // Avoid 0/0 = NaN leaking into the feature vector.
            return 0.0;
        }
        return dotProduct / denominator;
    }

    /**
     * Adds two vectors component-wise over the first {@code numberOfDimensions} components.
     *
     * @param vector1 first operand
     * @param vector2 second operand
     * @return a newly allocated list holding the sums; neither argument is modified
     */
    public List<Double> addVectors(List<Double> vector1, List<Double> vector2) {
        List<Double> sum = new ArrayList<Double>(this.numberOfDimensions);
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            sum.add(vector1.get(dim) + vector2.get(dim));
        }
        return sum;
    }

    /**
     * Returns the vector stored for the lower-cased form of {@code word}, falling back to
     * the {@code "oov"} vector when there is none. Lower-casing uses {@link Locale#ROOT}
     * so that the lookup key does not vary with the JVM's default locale.
     */
    private List<Double> lookupVector(String word) {
        List<Double> vector = this.wordVectors.get(word.toLowerCase(Locale.ROOT));
        return vector != null ? vector : this.wordVectors.get(OOV_KEY);
    }
}

