/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.security.util.crypto;

import at.favre.lib.crypto.bcrypt.Radix64Encoder;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.security.util.crypto.AESKeyedCipherProvider;
import org.apache.nifi.security.util.crypto.CipherUtility;
import org.apache.nifi.security.util.crypto.KeyDerivationBcryptSecureHasher;
import org.apache.nifi.security.util.crypto.KeyedCipherProvider;
import org.apache.nifi.security.util.crypto.RandomIVPBECipherProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BcryptCipherProvider
extends RandomIVPBECipherProvider {
    private static final Logger logger = LoggerFactory.getLogger(BcryptCipherProvider.class);
    private final int workFactor;
    private static final int DEFAULT_WORK_FACTOR = 12;
    private static final int DEFAULT_SALT_LENGTH = 16;
    private static final Pattern BCRYPT_SALT_FORMAT = Pattern.compile("^\\$\\d\\w\\$\\d{2}\\$[\\w\\/\\.]{22}");
    private static final String BCRYPT_SALT_FORMAT_MSG = "The salt must be of the format $2a$10$gUVbkVzp79H8YaCOsCVZNu. To generate a salt, use BcryptCipherProvider#generateSalt()";

    public BcryptCipherProvider() {
        this(12);
    }

    public BcryptCipherProvider(int workFactor) {
        this.workFactor = workFactor;
        if (workFactor < 12) {
            logger.warn("The provided work factor {} is below the recommended minimum {}", (Object)workFactor, (Object)12);
        }
    }

    @Override
    Logger getLogger() {
        return logger;
    }

    @Override
    public Cipher getCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode) throws Exception {
        return this.createCipherAndHandleExceptions(encryptionMethod, password, salt, iv, keyLength, encryptMode, false);
    }

    private Cipher createCipherAndHandleExceptions(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode, boolean useLegacyKeyDerivation) {
        try {
            return this.getInitializedCipher(encryptionMethod, password, salt, iv, keyLength, encryptMode, useLegacyKeyDerivation);
        }
        catch (IllegalArgumentException e) {
            throw e;
        }
        catch (Exception e) {
            throw new ProcessException("Error initializing the cipher", (Throwable)e);
        }
    }

    @Override
    public Cipher getCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, int keyLength, boolean encryptMode) throws Exception {
        return this.getCipher(encryptionMethod, password, salt, new byte[0], keyLength, encryptMode);
    }

    public Cipher getLegacyDecryptCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength) {
        return this.createCipherAndHandleExceptions(encryptionMethod, password, salt, iv, keyLength, false, true);
    }

    protected Cipher getInitializedCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode, boolean useLegacyKeyDerivation) throws Exception {
        if (encryptionMethod == null) {
            throw new IllegalArgumentException("The encryption method must be specified");
        }
        if (!encryptionMethod.isCompatibleWithStrongKDFs()) {
            throw new IllegalArgumentException(encryptionMethod.name() + " is not compatible with Bcrypt");
        }
        if (StringUtils.isEmpty((CharSequence)password)) {
            throw new IllegalArgumentException("Encryption with an empty password is not supported");
        }
        String algorithm = encryptionMethod.getAlgorithm();
        String provider = encryptionMethod.getProvider();
        String cipherName = CipherUtility.parseCipherFromAlgorithm(algorithm);
        if (!CipherUtility.isValidKeyLength(keyLength, cipherName)) {
            throw new IllegalArgumentException(keyLength + " is not a valid key length for " + cipherName);
        }
        if (salt == null || salt.length == 0) {
            throw new IllegalArgumentException(BCRYPT_SALT_FORMAT_MSG);
        }
        String saltString = new String(salt, StandardCharsets.UTF_8);
        byte[] rawSalt = new byte[16];
        int workFactor = 0;
        if (!BcryptCipherProvider.isBcryptFormattedSalt(saltString)) {
            throw new IllegalArgumentException(BCRYPT_SALT_FORMAT_MSG);
        }
        workFactor = this.parseSalt(saltString, rawSalt);
        try {
            SecretKey tempKey = this.deriveKey(password, keyLength, algorithm, rawSalt, workFactor, useLegacyKeyDerivation);
            AESKeyedCipherProvider keyedCipherProvider = new AESKeyedCipherProvider();
            return ((KeyedCipherProvider)keyedCipherProvider).getCipher(encryptionMethod, tempKey, iv, encryptMode);
        }
        catch (IllegalArgumentException e) {
            if (e.getMessage().contains("salt must be exactly")) {
                throw new IllegalArgumentException(BCRYPT_SALT_FORMAT_MSG, e);
            }
            if (e.getMessage().contains("The salt length")) {
                throw new IllegalArgumentException("The raw salt must be greater than or equal to 16 bytes", e);
            }
            logger.error("Encountered an error generating the Bcrypt hash", (Throwable)e);
            throw e;
        }
    }

    private SecretKey deriveKey(String password, int keyLength, String algorithm, byte[] rawSalt, int workFactor, boolean useLegacyKeyDerivation) {
        int derivedKeyLength = keyLength / 8;
        KeyDerivationBcryptSecureHasher secureHasher = new KeyDerivationBcryptSecureHasher(derivedKeyLength, workFactor, useLegacyKeyDerivation);
        byte[] derivedKey = secureHasher.hashRaw(password.getBytes(StandardCharsets.UTF_8), rawSalt);
        return new SecretKeySpec(derivedKey, algorithm);
    }

    public static boolean isBcryptFormattedSalt(String salt) {
        if (salt == null || salt.length() == 0) {
            throw new IllegalArgumentException("The salt cannot be empty. To generate a salt, use BcryptCipherProvider#generateSalt()");
        }
        Matcher matcher = BCRYPT_SALT_FORMAT.matcher(salt);
        return matcher.find();
    }

    private int parseSalt(String bcryptSalt, byte[] rawSalt) {
        if (StringUtils.isEmpty((CharSequence)bcryptSalt)) {
            throw new IllegalArgumentException("Cannot parse empty salt");
        }
        byte[] salt = BcryptCipherProvider.extractRawSalt(bcryptSalt);
        if (rawSalt.length < salt.length) {
            byte[] tempBytes = new byte[salt.length];
            System.arraycopy(rawSalt, 0, tempBytes, 0, rawSalt.length);
            rawSalt = tempBytes;
        }
        System.arraycopy(salt, 0, rawSalt, 0, salt.length);
        String[] saltComponents = bcryptSalt.split("\\$");
        return Integer.parseInt(saltComponents[2]);
    }

    public static String formatSaltForBcrypt(byte[] salt, int workFactor) {
        String rawSaltString = new String(salt, StandardCharsets.UTF_8);
        if (BcryptCipherProvider.isBcryptFormattedSalt(rawSaltString)) {
            return rawSaltString;
        }
        String saltString = "$2a$" + StringUtils.leftPad((String)String.valueOf(workFactor), (int)2, (String)"0") + "$" + new String(new Radix64Encoder.Default().encode(salt), StandardCharsets.UTF_8);
        return saltString;
    }

    @Override
    public byte[] generateSalt() {
        byte[] salt = new byte[16];
        SecureRandom sr = new SecureRandom();
        sr.nextBytes(salt);
        String saltString = BcryptCipherProvider.formatSaltForBcrypt(salt, this.getWorkFactor());
        return saltString.getBytes(StandardCharsets.UTF_8);
    }

    public static byte[] extractRawSalt(String fullSalt) {
        String[] saltComponents = fullSalt.split("\\$");
        if (saltComponents.length < 4) {
            throw new IllegalArgumentException("Could not parse salt");
        }
        return new Radix64Encoder.Default().decode(saltComponents[3].getBytes(StandardCharsets.UTF_8));
    }

    @Override
    public int getDefaultSaltLength() {
        return 16;
    }

    protected int getWorkFactor() {
        return this.workFactor;
    }
}

