/*
 * Decompiled with CFR 0.152.
 */
package io.github.lnyocly.ai4j.platform.openai.chat;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.lnyocly.ai4j.config.OpenAiConfig;
import io.github.lnyocly.ai4j.listener.SseListener;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletion;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.StreamOptions;
import io.github.lnyocly.ai4j.platform.openai.tool.Tool;
import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall;
import io.github.lnyocly.ai4j.platform.openai.usage.Usage;
import io.github.lnyocly.ai4j.service.Configuration;
import io.github.lnyocly.ai4j.service.IChatService;
import io.github.lnyocly.ai4j.utils.ToolUtil;
import io.github.lnyocly.ai4j.utils.ValidateUtil;
import java.util.ArrayList;
import java.util.List;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link IChatService} implementation that sends chat-completion requests to an OpenAI-compatible
 * endpoint using OkHttp, in both blocking and server-sent-events (streaming) modes.
 *
 * <p>Both entry points run an automatic tool-calling loop:
 * <ul>
 *   <li>When the model finishes with {@code tool_calls}, each requested function is executed
 *       through {@link ToolUtil#invoke}.</li>
 *   <li>The assistant message and every tool result are appended to the conversation, and the
 *       request is sent again.</li>
 * </ul>
 * The supplied {@link ChatCompletion} is mutated in place: stream flags, stream options, tools,
 * parallel-tool-call flag and the message list.
 */
public class OpenAiChatService
implements IChatService {
    private static final Logger log = LoggerFactory.getLogger(OpenAiChatService.class);
    /** Shared mapper; Jackson's ObjectMapper is thread-safe once configured and costly to build. */
    private static final ObjectMapper MAPPER = new ObjectMapper();
    /** Content type used for every request body. */
    private static final MediaType JSON = MediaType.parse("application/json; charset=utf-8");
    /** Finish reason signalling that the model wants one or more tools to be invoked. */
    private static final String FINISH_REASON_TOOL_CALLS = "tool_calls";

    private final OpenAiConfig openAiConfig;
    private final OkHttpClient okHttpClient;
    private final EventSource.Factory factory;

    /**
     * Creates the service from the shared configuration.
     *
     * @param configuration source of the OpenAI settings, the HTTP client and the SSE factory
     */
    public OpenAiChatService(Configuration configuration) {
        this.openAiConfig = configuration.getOpenAiConfig();
        this.okHttpClient = configuration.getOkHttpClient();
        this.factory = configuration.createRequestFactory();
    }

    /**
     * Performs a blocking chat completion, automatically resolving tool calls until the model
     * finishes for another reason.
     *
     * @param baseUrl        API host; {@code null} or empty falls back to the configured host
     * @param apiKey         bearer token; {@code null} or empty falls back to the configured key
     * @param chatCompletion request to send; mutated in place (stream disabled, tools set,
     *                       tool-round messages appended)
     * @return the final response, with usage summed over every round-trip, or {@code null} if
     *         the HTTP call was unsuccessful or returned no body
     * @throws Exception on I/O, (de)serialization or tool-invocation failure
     */
    @Override
    public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, ChatCompletion chatCompletion) throws Exception {
        String host = resolveBaseUrl(baseUrl);
        String key = resolveApiKey(apiKey);
        chatCompletion.setStream(false);
        chatCompletion.setStreamOptions(null);
        configureTools(chatCompletion);

        Usage allUsage = new Usage();
        while (true) {
            Request request = buildRequest(host, key, chatCompletion);
            // The Response holds a pooled connection and must be closed on every path.
            try (Response response = this.okHttpClient.newCall(request).execute()) {
                if (!response.isSuccessful() || response.body() == null) {
                    log.error("Chat completion request failed: HTTP {} {}", response.code(), response.message());
                    return null;
                }
                ChatCompletionResponse chatCompletionResponse =
                        MAPPER.readValue(response.body().string(), ChatCompletionResponse.class);
                addUsage(allUsage, chatCompletionResponse.getUsage());

                List<Choice> choices = chatCompletionResponse.getChoices();
                Choice choice = (choices == null || choices.isEmpty()) ? null : choices.get(0);
                ChatMessage message = (choice == null) ? null : choice.getMessage();
                List<ToolCall> toolCalls = (message == null) ? null : message.getToolCalls();
                boolean wantsTools = choice != null
                        && FINISH_REASON_TOOL_CALLS.equals(choice.getFinishReason())
                        && toolCalls != null
                        && !toolCalls.isEmpty();
                if (!wantsTools) {
                    // Final answer: report usage accumulated across all tool round-trips.
                    chatCompletionResponse.setUsage(allUsage);
                    return chatCompletionResponse;
                }
                appendToolResults(chatCompletion, message, toolCalls);
            }
        }
    }

    /**
     * Performs a blocking chat completion against the configured host and key.
     *
     * @param chatCompletion request to send; mutated in place
     * @return the final response, or {@code null} on an unsuccessful HTTP call
     * @throws Exception on I/O, (de)serialization or tool-invocation failure
     */
    @Override
    public ChatCompletionResponse chatCompletion(ChatCompletion chatCompletion) throws Exception {
        return this.chatCompletion(null, null, chatCompletion);
    }

    /**
     * Performs a streaming chat completion, delivering events to {@code eventSourceListener}.
     * Tool calls are resolved automatically, and a new stream is opened for each tool round.
     *
     * @param baseUrl             API host; {@code null} or empty falls back to the configured host
     * @param apiKey              bearer token; {@code null} or empty falls back to the configured key
     * @param chatCompletion      request to send; mutated in place (stream enabled, stream options
     *                            defaulted, tools set, tool-round messages appended)
     * @param eventSourceListener receives the SSE events; its finish reason and collected tool
     *                            calls drive the tool loop
     * @throws Exception on serialization or tool-invocation failure, or if interrupted while
     *                   waiting for a stream to finish
     */
    @Override
    public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion chatCompletion, SseListener eventSourceListener) throws Exception {
        String host = resolveBaseUrl(baseUrl);
        String key = resolveApiKey(apiKey);
        chatCompletion.setStream(true);
        if (chatCompletion.getStreamOptions() == null) {
            chatCompletion.setStreamOptions(new StreamOptions(true));
        }
        configureTools(chatCompletion);

        while (true) {
            Request request = buildRequest(host, key, chatCompletion);
            this.factory.newEventSource(request, eventSourceListener);
            // NOTE(review): the same latch is awaited on every round; this assumes SseListener
            // counts it down when a stream ends and re-arms it for the next one — TODO confirm.
            eventSourceListener.getCountDownLatch().await();

            List<ToolCall> toolCalls = eventSourceListener.getToolCalls();
            boolean wantsTools = FINISH_REASON_TOOL_CALLS.equals(eventSourceListener.getFinishReason())
                    && toolCalls != null
                    && !toolCalls.isEmpty();
            if (!wantsTools) {
                // Also stops when tool_calls is reported without any collected calls, instead of
                // re-sending the identical request indefinitely.
                return;
            }
            appendToolResults(chatCompletion, ChatMessage.withAssistant(toolCalls), toolCalls);
            // Give the listener a fresh list; the assistant message above keeps the old one.
            eventSourceListener.setToolCalls(new ArrayList<ToolCall>());
            eventSourceListener.setToolCall(null);
        }
    }

    /**
     * Performs a streaming chat completion against the configured host and key.
     *
     * @param chatCompletion      request to send; mutated in place
     * @param eventSourceListener receives the SSE events
     * @throws Exception on serialization or tool-invocation failure, or on interruption
     */
    @Override
    public void chatCompletionStream(ChatCompletion chatCompletion, SseListener eventSourceListener) throws Exception {
        this.chatCompletionStream(null, null, chatCompletion, eventSourceListener);
    }

    /** Returns {@code baseUrl}, or the configured API host when it is null or empty. */
    private String resolveBaseUrl(String baseUrl) {
        return (baseUrl == null || baseUrl.isEmpty()) ? this.openAiConfig.getApiHost() : baseUrl;
    }

    /** Returns {@code apiKey}, or the configured API key when it is null or empty. */
    private String resolveApiKey(String apiKey) {
        return (apiKey == null || apiKey.isEmpty()) ? this.openAiConfig.getApiKey() : apiKey;
    }

    /**
     * Populates {@code tools} from the requested function names. Clears
     * {@code parallelToolCalls} when there are no functions, or when no tools could be resolved.
     */
    private static void configureTools(ChatCompletion chatCompletion) {
        if (chatCompletion.getFunctions() == null || chatCompletion.getFunctions().isEmpty()) {
            chatCompletion.setParallelToolCalls(null);
            return;
        }
        List<Tool> tools = ToolUtil.getAllFunctionTools(chatCompletion.getFunctions());
        chatCompletion.setTools(tools);
        if (tools == null) {
            chatCompletion.setParallelToolCalls(null);
        }
    }

    /**
     * Serializes {@code chatCompletion} and builds an authenticated POST to the chat-completions URL.
     * Declared {@code throws Exception} to mirror the {@link IChatService} contract.
     */
    private Request buildRequest(String host, String key, ChatCompletion chatCompletion) throws Exception {
        String body = MAPPER.writeValueAsString(chatCompletion);
        return new Request.Builder()
                .header("Authorization", "Bearer " + key)
                .url(ValidateUtil.concatUrl(host, this.openAiConfig.getChatCompletionUrl()))
                .post(RequestBody.create(JSON, body))
                .build();
    }

    /** Adds the token counts of {@code delta} (if present) to {@code total}. */
    private static void addUsage(Usage total, Usage delta) {
        if (delta == null) {
            return;
        }
        total.setCompletionTokens(total.getCompletionTokens() + delta.getCompletionTokens());
        total.setTotalTokens(total.getTotalTokens() + delta.getTotalTokens());
        total.setPromptTokens(total.getPromptTokens() + delta.getPromptTokens());
    }

    /**
     * Appends the assistant message, plus one tool message per invoked call, to a copy of the
     * conversation. The copy then replaces the request's message list.
     */
    private static void appendToolResults(ChatCompletion chatCompletion, ChatMessage assistantMessage,
                                          List<ToolCall> toolCalls) throws Exception {
        List<ChatMessage> messages = new ArrayList<ChatMessage>(chatCompletion.getMessages());
        messages.add(assistantMessage);
        for (ToolCall toolCall : toolCalls) {
            String functionName = toolCall.getFunction().getName();
            String arguments = toolCall.getFunction().getArguments();
            String functionResponse = ToolUtil.invoke(functionName, arguments);
            messages.add(ChatMessage.withTool(functionResponse, toolCall.getId()));
        }
        chatCompletion.setMessages(messages);
    }
}

