/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.tokenize;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.transform.tokenize.TokenizerPost;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Tokenizer post-processing step that emits, for every document, one output row per
 * distinct token together with the number of times that token occurs in the document.
 *
 * <p>Each appended row consists of the document's key values, followed by the token
 * text, followed by the occurrence count. At most {@code maxTokens} rows are emitted
 * per document.
 */
public class TokenizerPostCount implements TokenizerPost {
    private static final long serialVersionUID = 6382000606237705019L;
    private final Params params;
    private final int numIdCols;
    private final int maxTokens;
    private final boolean wideFormat;

    /**
     * Creates the count post-processor.
     *
     * @param params     optional JSON configuration; may be {@code null}. Only the
     *                   boolean key {@code "sort_alpha"} is read.
     * @param numIdCols  number of identifier columns that precede the token column in
     *                   the output schema
     * @param maxTokens  maximum number of rows (distinct tokens) emitted per document
     * @param wideFormat whether wide output was requested; {@link #getOutSchema()}
     *                   rejects this setting
     * @throws JSONException if reading the configuration fails
     */
    public TokenizerPostCount(JSONObject params, int numIdCols, int maxTokens, boolean wideFormat) throws JSONException {
        this.params = new Params(params);
        this.numIdCols = numIdCols;
        this.maxTokens = maxTokens;
        this.wideFormat = wideFormat;
    }

    /**
     * Appends one row per distinct token of every document to {@code out}.
     *
     * <p>Tokens are emitted in order of first occurrence within the document, or in
     * natural string order when {@code sort_alpha} is enabled. Emission for a document
     * stops once {@code maxTokens} rows have been appended for it.
     *
     * @param tl  documents with their keys and tokens
     * @param out frame the rows are appended to
     * @return the same {@code out} instance
     */
    @Override
    public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> tl, FrameBlock out) {
        for (Tokenizer.DocumentToTokens docToToken : tl) {
            Map<String, Long> tokenCounts = countTokens(docToToken.tokens);
            int numTokens = 0;
            for (Map.Entry<String, Long> entry : tokenCounts.entrySet()) {
                if (numTokens >= this.maxTokens) {
                    // Per-document cap reached; move on to the next document.
                    break;
                }
                out.appendRow(buildRow(docToToken.keys, entry.getKey(), entry.getValue()));
                ++numTokens;
            }
        }
        return out;
    }

    /**
     * Counts the occurrences of each distinct token text in a single pass. The iteration
     * order of the returned map is the order in which rows are emitted.
     */
    private Map<String, Long> countTokens(List<Tokenizer.Token> tokens) {
        if (this.params.sort_alpha) {
            // TreeMap iterates its keys in natural String order.
            return tokens.stream().collect(
                Collectors.groupingBy(token -> token.textToken, TreeMap::new, Collectors.counting()));
        }
        // LinkedHashMap keeps keys in the order they were first encountered.
        return tokens.stream().collect(
            Collectors.groupingBy(token -> token.textToken, LinkedHashMap::new, Collectors.counting()));
    }

    /** Builds one output row: the document keys, then the token text, then its count. */
    private static Object[] buildRow(List<Object> keys, String token, Long count) {
        List<Object> cells = new ArrayList<>(keys.size() + 2);
        cells.addAll(keys);
        cells.add(token);
        cells.add(count);
        return cells.toArray();
    }

    /**
     * Returns the output schema: {@code numIdCols + 2} columns, all {@code STRING}
     * except the last one (the count), which is {@code INT64}.
     *
     * @return the value types of the output columns
     * @throws IllegalArgumentException if wide format was requested
     */
    @Override
    public Types.ValueType[] getOutSchema() {
        if (this.wideFormat) {
            throw new IllegalArgumentException("Wide Format is not supported for Count Representation.");
        }
        Types.ValueType[] schema = UtilFunctions.nCopies(this.numIdCols + 2, Types.ValueType.STRING);
        schema[this.numIdCols + 1] = Types.ValueType.INT64;
        return schema;
    }

    /**
     * Returns the number of output rows for the given number of input rows:
     * {@code inRows} when wide format is set, otherwise {@code inRows * maxTokens}.
     *
     * <p>NOTE(review): since each document yields at most {@code maxTokens} rows, the
     * non-wide value looks like an upper bound rather than an exact count; confirm
     * against callers.
     *
     * @param inRows number of input rows
     * @return the computed number of output rows
     */
    @Override
    public long getNumRows(long inRows) {
        if (this.wideFormat) {
            return inRows;
        }
        return inRows * (long)this.maxTokens;
    }

    /**
     * Returns the number of output columns, i.e. the length of {@link #getOutSchema()}.
     *
     * @return the number of output columns
     * @throws IllegalArgumentException if wide format was requested
     */
    @Override
    public long getNumCols() {
        return this.getOutSchema().length;
    }

    /** Configuration of the count representation, read from the tokenizer JSON spec. */
    static class Params implements Serializable {
        private static final long serialVersionUID = 5121697674346781880L;

        /** Whether tokens are emitted in natural string order; defaults to {@code false}. */
        public boolean sort_alpha = false;

        /**
         * Reads the configuration; a {@code null} object or a missing key keeps the default.
         *
         * @param json configuration object, may be {@code null}
         * @throws JSONException if the {@code "sort_alpha"} value cannot be read as a boolean
         */
        public Params(JSONObject json) throws JSONException {
            if (json != null && json.has("sort_alpha")) {
                this.sort_alpha = json.getBoolean("sort_alpha");
            }
        }
    }
}

