/*
 * Decompiled with CFR 0.152.
 */
package io.github.lnyocly.ai4j.utils;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.ModelType;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for counting tokens with jtokkit encodings.
 *
 * <p>At class-initialization time every {@link ModelType} whose name the default
 * {@link EncodingRegistry} can resolve is cached in {@code modelMap}, keyed by model name.
 * The map is only written inside the static initializer and only read afterwards.
 */
public class TikTokensUtil {
    private static final Logger log = LoggerFactory.getLogger(TikTokensUtil.class);
    private static final Map<String, Encoding> modelMap = new HashMap<String, Encoding>();
    private static final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();

    /**
     * Counts the tokens of {@code content} using an explicitly chosen encoding.
     *
     * @param encodingType the encoding to look up in the registry
     * @param content      the text to tokenize; {@code null} or empty yields {@code 0}
     * @return the number of tokens in {@code content}
     */
    public static int tokens(EncodingType encodingType, String content) {
        // Same short-circuit as the model-name overload so both treat "no text" identically.
        if (StringUtils.isEmpty((CharSequence)content)) {
            return 0;
        }
        Encoding encoding = registry.getEncoding(encodingType);
        return encoding.countTokens(content);
    }

    /**
     * Counts the tokens of {@code content} using the encoding registered for a model.
     *
     * @param modelName the model name used as key into the encoding cache
     * @param content   the text to tokenize; {@code null} or empty yields {@code 0}
     * @return the number of tokens in {@code content}
     * @throws IllegalArgumentException if {@code content} is non-empty and no encoding is
     *                                  cached for {@code modelName}
     */
    public static int tokens(String modelName, String content) {
        if (StringUtils.isEmpty((CharSequence)content)) {
            return 0;
        }
        return requireEncoding(modelName).countTokens(content);
    }

    /**
     * Counts the tokens of a list of chat messages for the given model.
     *
     * <p>For each message the total includes a fixed per-message overhead plus the tokens of
     * its content text and role, and, when a name is present, the tokens of the name plus a
     * per-name adjustment. Model names starting with {@code "gpt-4"} use an overhead of 3 and
     * a name adjustment of 1; names starting with {@code "gpt-3.5"} use 4 and -1; any other
     * supported model uses 0 for both. A constant 3 is added once at the end
     * (presumably the reply-priming tokens of OpenAI's chat counting recipe).
     *
     * <p>Missing parts of a message (no content object, {@code null} text, {@code null} role)
     * count as zero tokens, and {@code null} list entries are skipped.
     *
     * @param modelName the model name used as key into the encoding cache
     * @param messages  the messages to count; must not be {@code null}
     * @return the estimated total token count for the conversation
     * @throws IllegalArgumentException if no encoding is cached for {@code modelName}
     */
    public static int tokens(String modelName, List<ChatMessage> messages) {
        Encoding encoding = requireEncoding(modelName);

        int tokensPerMessage = 0;
        int tokensPerName = 0;
        if (modelName.startsWith("gpt-4")) {
            tokensPerMessage = 3;
            tokensPerName = 1;
        } else if (modelName.startsWith("gpt-3.5")) {
            tokensPerMessage = 4;
            tokensPerName = -1;
        }

        int sum = 0;
        for (ChatMessage message : messages) {
            if (message == null) {
                continue;
            }
            sum += tokensPerMessage;
            // The content object may be absent on a message; treat that as empty text.
            String text = message.getContent() == null ? null : message.getContent().getText();
            sum += countOrZero(encoding, text);
            sum += countOrZero(encoding, message.getRole());
            if (StringUtils.isNotEmpty((CharSequence)message.getName())) {
                sum += encoding.countTokens(message.getName());
                sum += tokensPerName;
            }
        }
        return sum + 3;
    }

    /**
     * Looks up the cached encoding for a model name, failing with the same exception for
     * every caller when the model is unknown.
     */
    private static Encoding requireEncoding(String modelName) {
        Encoding encoding = modelMap.get(modelName);
        if (encoding == null) {
            throw new IllegalArgumentException("\u4e0d\u652f\u6301\u8ba1\u7b97Token\u7684\u6a21\u578b: " + modelName);
        }
        return encoding;
    }

    /** Counts the tokens of {@code text}, treating {@code null} or empty text as zero tokens. */
    private static int countOrZero(Encoding encoding, String text) {
        return StringUtils.isEmpty((CharSequence)text) ? 0 : encoding.countTokens(text);
    }

    static {
        // Cache an encoding for every model type the registry can resolve by name.
        for (ModelType model : ModelType.values()) {
            String name = model.getName();
            Optional<Encoding> encodingForModel = registry.getEncodingForModel(name);
            encodingForModel.ifPresent(encoding -> modelMap.put(name, encoding));
        }
    }
}

