package org.springframework.ai.ollama;

import com.fasterxml.jackson.core.type.TypeReference;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import java.time.Duration;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

/**
 * {@link ChatModel} implementation that talks to an Ollama server through {@link OllamaApi}.
 *
 * <p>Supports blocking ({@link #call(Prompt)}) and streaming ({@link #stream(Prompt)}) chat
 * requests, records a Micrometer observation per request, and delegates tool resolution and
 * execution to a {@link ToolCallingManager}. Depending on the {@link PullModelStrategy} in the
 * supplied {@link ModelManagementOptions}, the default model is pulled during construction.
 *
 * <p>Instances are usually created through {@link #builder()}.
 */
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {

    private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class);

    // Keys used in ChatResponseMetadata produced by from(...).
    private static final String DONE = "done";
    private static final String METADATA_PROMPT_EVAL_COUNT = "prompt-eval-count";
    private static final String METADATA_EVAL_COUNT = "eval-count";
    private static final String METADATA_CREATED_AT = "created-at";
    private static final String METADATA_TOTAL_DURATION = "total-duration";
    private static final String METADATA_LOAD_DURATION = "load-duration";
    private static final String METADATA_PROMPT_EVAL_DURATION = "prompt-eval-duration";
    private static final String METADATA_EVAL_DURATION = "eval-duration";

    /** Reactor context key under which the streaming path reads and writes the parent observation. */
    private static final String OBSERVATION_CONTEXT_KEY = "micrometer.observation";

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    private final OllamaApi chatApi;
    private final OllamaOptions defaultOptions;
    private final ObservationRegistry observationRegistry;
    private final OllamaModelManager modelManager;
    private final ToolCallingManager toolCallingManager;
    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Fluent builder for {@link OllamaChatModel}.
     *
     * <p>Defaults: options with model {@code OllamaModel.MISTRAL}, a no-op observation registry,
     * and {@code ModelManagementOptions.defaults()}.
     */
    public static final class Builder {
        private OllamaApi ollamaApi;
        private ToolCallingManager toolCallingManager;
        private FunctionCallbackResolver functionCallbackResolver;
        private List<FunctionCallback> toolFunctionCallbacks;
        private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
        private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

        private Builder() {
        }

        /** Sets the low-level Ollama client (required). */
        public Builder ollamaApi(OllamaApi ollamaApi) {
            this.ollamaApi = ollamaApi;
            return this;
        }

        /** Sets the options merged into every request. */
        public Builder defaultOptions(OllamaOptions ollamaOptions) {
            this.defaultOptions = ollamaOptions;
            return this;
        }

        /** Sets the tool calling manager; mutually exclusive with the deprecated legacy settings. */
        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        /**
         * @deprecated use {@link #toolCallingManager(ToolCallingManager)} instead.
         */
        @Deprecated
        public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
            this.functionCallbackResolver = functionCallbackResolver;
            return this;
        }

        /**
         * @deprecated use {@link #toolCallingManager(ToolCallingManager)} instead.
         */
        @Deprecated
        public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
            this.toolFunctionCallbacks = toolFunctionCallbacks;
            return this;
        }

        /** Sets the registry used to record chat observations. */
        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /** Sets the model management options (including the pull strategy). */
        public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
            this.modelManagementOptions = modelManagementOptions;
            return this;
        }

        /**
         * Builds the chat model.
         *
         * @return a new {@link OllamaChatModel}
         * @throws IllegalArgumentException if a {@code toolCallingManager} is combined with either
         * deprecated legacy setting, or if a required argument is missing
         */
        public OllamaChatModel build() {
            if (this.toolCallingManager != null) {
                Assert.isNull(this.functionCallbackResolver, "functionCallbackResolver must not be set when toolCallingManager is set");
                Assert.isNull(this.toolFunctionCallbacks, "toolFunctionCallbacks must not be set when toolCallingManager is set");
                return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
                        this.observationRegistry, this.modelManagementOptions);
            }
            // Either legacy setting selects the legacy path, so callbacks registered without a
            // resolver are honored instead of being silently ignored.
            if (this.functionCallbackResolver != null || this.toolFunctionCallbacks != null) {
                return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
                        Objects.requireNonNullElse(this.toolFunctionCallbacks, List.of()), this.observationRegistry,
                        this.modelManagementOptions);
            }
            return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
                    this.observationRegistry, this.modelManagementOptions);
        }
    }

    /**
     * Legacy constructor that wraps the given resolver and callbacks in a {@link LegacyToolCallingManager}.
     *
     * @deprecated use {@link #builder()} or the constructor accepting a {@link ToolCallingManager}.
     */
    @Deprecated
    public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions ollamaOptions,
            @Nullable FunctionCallbackResolver functionCallbackResolver, @Nullable List<FunctionCallback> list,
            ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
        // A null callback list is treated as "no callbacks".
        this(ollamaApi, ollamaOptions,
                new LegacyToolCallingManager(functionCallbackResolver, Objects.requireNonNullElse(list, List.of())),
                observationRegistry, modelManagementOptions);
        logger.warn("This constructor is deprecated and will be removed in the next milestone. Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
    }

    /**
     * Creates the chat model and, unless the pull strategy is {@code null} or
     * {@link PullModelStrategy#NEVER}, pulls the default model.
     *
     * @param ollamaApi client used for chat and model management requests
     * @param ollamaOptions default options merged into every request
     * @param toolCallingManager resolves tool definitions and executes tool calls
     * @param observationRegistry registry for chat observations
     * @param modelManagementOptions model management settings (pull strategy, ...)
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions ollamaOptions, ToolCallingManager toolCallingManager,
            ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
        super((FunctionCallbackResolver) null, OllamaOptions.builder().build(), List.of());
        Assert.notNull(ollamaApi, "ollamaApi must not be null");
        Assert.notNull(ollamaOptions, "defaultOptions must not be null");
        Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
        this.chatApi = ollamaApi;
        this.defaultOptions = ollamaOptions;
        this.toolCallingManager = toolCallingManager;
        this.observationRegistry = observationRegistry;
        this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
        initializeModel(ollamaOptions.getModel(), modelManagementOptions.pullModelStrategy());
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds response metadata for an Ollama response, accumulating token counts and durations
     * from a previous response (e.g. the response that triggered a tool call round-trip).
     *
     * @param chatResponse the Ollama response, must not be {@code null}
     * @param chatResponse2 the previous response whose metadata is added on top, may be {@code null}
     * @return the combined metadata
     */
    static ChatResponseMetadata from(OllamaApi.ChatResponse chatResponse, @Nullable ChatResponse chatResponse2) {
        Assert.notNull(chatResponse, "OllamaApi.ChatResponse must not be null");

        DefaultUsage currentUsage = getDefaultUsage(chatResponse);
        int promptTokens = nullToZero(currentUsage.getPromptTokens());
        int completionTokens = nullToZero(currentUsage.getCompletionTokens());
        int totalTokens = nullToZero(currentUsage.getTotalTokens());

        Duration evalDuration = chatResponse.getEvalDuration();
        Duration promptEvalDuration = chatResponse.getPromptEvalDuration();
        Duration loadDuration = chatResponse.getLoadDuration();
        Duration totalDuration = chatResponse.getTotalDuration();

        if (chatResponse2 != null && chatResponse2.getMetadata() != null) {
            ChatResponseMetadata previous = chatResponse2.getMetadata();
            // Null-safe sums: either side may lack a duration (e.g. intermediate streaming chunks).
            evalDuration = sumDurations(evalDuration, previous.get(METADATA_EVAL_DURATION));
            promptEvalDuration = sumDurations(promptEvalDuration, previous.get(METADATA_PROMPT_EVAL_DURATION));
            loadDuration = sumDurations(loadDuration, previous.get(METADATA_LOAD_DURATION));
            totalDuration = sumDurations(totalDuration, previous.get(METADATA_TOTAL_DURATION));

            var previousUsage = previous.getUsage();
            if (previousUsage != null) {
                promptTokens += nullToZero(previousUsage.getPromptTokens());
                completionTokens += nullToZero(previousUsage.getCompletionTokens());
                totalTokens += nullToZero(previousUsage.getTotalTokens());
            }
        }

        DefaultUsage aggregatedUsage = new DefaultUsage(promptTokens, completionTokens, totalTokens);
        return ChatResponseMetadata.builder()
            .usage(aggregatedUsage)
            .model(chatResponse.model())
            .keyValue(METADATA_CREATED_AT, chatResponse.createdAt())
            .keyValue(METADATA_EVAL_DURATION, evalDuration)
            .keyValue(METADATA_EVAL_COUNT, aggregatedUsage.getCompletionTokens())
            .keyValue(METADATA_LOAD_DURATION, loadDuration)
            .keyValue(METADATA_PROMPT_EVAL_DURATION, promptEvalDuration)
            .keyValue(METADATA_PROMPT_EVAL_COUNT, aggregatedUsage.getPromptTokens())
            .keyValue(METADATA_TOTAL_DURATION, totalDuration)
            .keyValue(DONE, chatResponse.done())
            .build();
    }

    /** Adds two possibly-null durations; returns the non-null one, or {@code null} if both are null. */
    @Nullable
    private static Duration sumDurations(@Nullable Duration current, @Nullable Duration previous) {
        if (current == null) {
            return previous;
        }
        return (previous == null) ? current : current.plus(previous);
    }

    /** Unboxes a possibly-null token count, treating {@code null} as zero. */
    private static int nullToZero(@Nullable Integer value) {
        return (value == null) ? 0 : value;
    }

    /** Usage from the response's eval counts, defaulting missing counts to zero. */
    private static DefaultUsage getDefaultUsage(OllamaApi.ChatResponse chatResponse) {
        return new DefaultUsage(Optional.ofNullable(chatResponse.promptEvalCount()).orElse(0),
                Optional.ofNullable(chatResponse.evalCount()).orElse(0));
    }

    /**
     * Sends the prompt (merged with the default options) as a blocking chat request. When internal
     * tool execution is enabled, requested tools are executed and the conversation continues
     * until no more tool calls are returned or a tool asks to return directly.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        return internalCall(buildRequestPrompt(prompt), null);
    }

    /**
     * Performs one observed blocking round-trip and recurses for tool calls.
     *
     * @param prompt the request prompt, already carrying {@link OllamaOptions}
     * @param previousChatResponse response of the previous round-trip whose usage is accumulated,
     * or {@code null} for the first call
     */
    private ChatResponse internalCall(Prompt prompt, @Nullable ChatResponse previousChatResponse) {
        OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);
        ChatModelObservationContext observationContext = newObservationContext(prompt);

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ChatResponse chatResponse = toChatResponse(this.chatApi.chat(request), previousChatResponse);
                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (response == null || !ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions())
                || !response.hasToolCalls()) {
            return response;
        }

        ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
        if (toolExecutionResult.returnDirect()) {
            // Hand the tool results straight back to the caller instead of re-querying the model.
            return ChatResponse.builder()
                .from(response)
                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                .build();
        }
        return internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
    }

    /**
     * Streams the prompt (merged with the default options). Tool calls are handled like in
     * {@link #call(Prompt)}, continuing the stream with a follow-up request when needed.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        return internalStream(buildRequestPrompt(prompt), null);
    }

    /**
     * Performs one observed streaming round-trip and recurses for tool calls.
     *
     * @param prompt the request prompt, already carrying {@link OllamaOptions}
     * @param previousChatResponse response that triggered this round-trip, or {@code null}
     */
    private Flux<ChatResponse> internalStream(Prompt prompt, @Nullable ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true);
            ChatModelObservationContext observationContext = newObservationContext(prompt);

            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            Observation parentObservation = contextView.getOrDefault(OBSERVATION_CONTEXT_KEY, null);
            observation.parentObservation(parentObservation).start();

            Flux<ChatResponse> chatResponseFlux = this.chatApi.streamingChat(request)
                .map(chunk -> toChatResponse(chunk, previousChatResponse))
                .flatMap(chatResponse -> {
                    if (!ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions())
                            || !chatResponse.hasToolCalls()) {
                        return Flux.just(chatResponse);
                    }
                    // NOTE(review): tool execution is synchronous inside the reactive pipeline.
                    ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt,
                            chatResponse);
                    if (toolExecutionResult.returnDirect()) {
                        return Flux.just(ChatResponse.builder()
                            .from(chatResponse)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build());
                    }
                    return internalStream(
                            new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), chatResponse);
                })
                .doOnError(observation::error)
                .doFinally(signalType -> observation.stop())
                .contextWrite(context -> context.put(OBSERVATION_CONTEXT_KEY, observation));

            // The aggregated response is recorded on the observation context once the stream completes.
            return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
        });
    }

    /**
     * Converts a single Ollama response (or stream chunk) into a {@link ChatResponse} with one generation.
     * A missing message yields empty content and no tool calls.
     */
    private static ChatResponse toChatResponse(OllamaApi.ChatResponse ollamaResponse,
            @Nullable ChatResponse previousChatResponse) {
        OllamaApi.Message message = ollamaResponse.message();
        String content = (message != null) ? message.content() : "";
        AssistantMessage assistantMessage = new AssistantMessage(content, Map.of(), toAssistantToolCalls(message));

        ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
        // A finish reason is attached only when both eval counts are present
        // (presumably the final chunk of a stream — confirm against the Ollama API).
        if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
            generationMetadata = ChatGenerationMetadata.builder().finishReason(ollamaResponse.doneReason()).build();
        }
        return new ChatResponse(List.of(new Generation(assistantMessage, generationMetadata)),
                from(ollamaResponse, previousChatResponse));
    }

    /** Maps Ollama tool calls to Spring AI tool calls (empty id, type "function", JSON arguments). */
    private static List<AssistantMessage.ToolCall> toAssistantToolCalls(@Nullable OllamaApi.Message message) {
        if (message == null || message.toolCalls() == null) {
            return List.of();
        }
        return message.toolCalls()
            .stream()
            .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
                    ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
            .toList();
    }

    /** Creates the observation context for a request built from the given prompt. */
    private static ChatModelObservationContext newObservationContext(Prompt prompt) {
        return ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OllamaApi.PROVIDER_NAME)
            .requestOptions(prompt.getOptions())
            .build();
    }

    /**
     * Merges the prompt's runtime options with the default options into an {@link OllamaOptions}
     * instance and returns a new prompt carrying it.
     *
     * @throws IllegalArgumentException if the merged options have no model name, or tool callbacks are invalid
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        ChatOptions promptOptions = prompt.getOptions();
        OllamaOptions runtimeOptions = null;
        if (promptOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                    OllamaOptions.class);
        }
        else if (promptOptions instanceof FunctionCallingOptions functionCallingOptions) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
                    OllamaOptions.class);
        }
        else if (promptOptions != null) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(promptOptions, ChatOptions.class, OllamaOptions.class);
        }

        OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                OllamaOptions.class);
        // The tool-related options are merged explicitly rather than via ModelOptionsUtils.merge.
        if (runtimeOptions != null) {
            requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(
                    runtimeOptions.isInternalToolExecutionEnabled(), this.defaultOptions.isInternalToolExecutionEnabled()));
            requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
                    this.defaultOptions.getToolNames()));
            requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
                    this.defaultOptions.getToolCallbacks()));
            requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                    this.defaultOptions.getToolContext()));
        }
        else {
            requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled());
            requestOptions.setToolNames(this.defaultOptions.getToolNames());
            requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            requestOptions.setToolContext(this.defaultOptions.getToolContext());
        }

        if (!StringUtils.hasText(requestOptions.getModel())) {
            throw new IllegalArgumentException("model cannot be null or empty");
        }
        ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), requestOptions);
    }

    /**
     * Converts a prompt produced by {@link #buildRequestPrompt(Prompt)} into an Ollama chat request.
     *
     * @param prompt prompt whose options are an {@link OllamaOptions} instance
     * @param z whether the request should be streamed
     * @throws IllegalArgumentException for unsupported message types or media data
     */
    OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean z) {
        List<OllamaApi.Message> messages = prompt.getInstructions().stream().map(message -> {
            if (message instanceof UserMessage userMessage) {
                OllamaApi.Message.Builder messageBuilder = OllamaApi.Message.builder(OllamaApi.Message.Role.USER)
                    .content(userMessage.getText());
                if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                    messageBuilder.images(
                            userMessage.getMedia().stream().map(media -> fromMediaData(media.getData())).toList());
                }
                return List.of(messageBuilder.build());
            }
            if (message instanceof SystemMessage systemMessage) {
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.SYSTEM)
                    .content(systemMessage.getText())
                    .build());
            }
            if (message instanceof AssistantMessage assistantMessage) {
                List<OllamaApi.Message.ToolCall> toolCalls = null;
                if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                    toolCalls = assistantMessage.getToolCalls()
                        .stream()
                        .map(toolCall -> new OllamaApi.Message.ToolCall(new OllamaApi.Message.ToolCallFunction(
                                toolCall.name(),
                                JsonParser.fromJson(toolCall.arguments(), new TypeReference<Map<String, Object>>() {
                                }))))
                        .toList();
                }
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
                    .content(assistantMessage.getText())
                    .toolCalls(toolCalls)
                    .build());
            }
            if (message instanceof ToolResponseMessage toolResponseMessage) {
                // Each tool response becomes its own TOOL message.
                return toolResponseMessage.getResponses()
                    .stream()
                    .map(toolResponse -> OllamaApi.Message.builder(OllamaApi.Message.Role.TOOL)
                        .content(toolResponse.responseData())
                        .build())
                    .toList();
            }
            throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
        }).flatMap(List::stream).toList();

        OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions();
        OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
            .stream(z)
            .messages(messages)
            .options(requestOptions);
        if (requestOptions.getFormat() != null) {
            requestBuilder.format(requestOptions.getFormat());
        }
        if (requestOptions.getKeepAlive() != null) {
            requestBuilder.keepAlive(requestOptions.getKeepAlive());
        }
        List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
        if (!CollectionUtils.isEmpty(toolDefinitions)) {
            requestBuilder.tools(getTools(toolDefinitions));
        }
        return requestBuilder.build();
    }

    /**
     * Encodes media data for Ollama: byte arrays are Base64-encoded, strings are passed through.
     *
     * @throws IllegalArgumentException for {@code null} or any other data type
     */
    private String fromMediaData(Object obj) {
        if (obj instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        if (obj instanceof String text) {
            return text;
        }
        String typeName = (obj != null) ? obj.getClass().getSimpleName() : "null";
        throw new IllegalArgumentException("Unsupported media data type: " + typeName);
    }

    /** Converts tool definitions into Ollama request tool descriptors. */
    private List<OllamaApi.ChatRequest.Tool> getTools(List<ToolDefinition> list) {
        return list.stream()
            .map(toolDefinition -> new OllamaApi.ChatRequest.Tool(new OllamaApi.ChatRequest.Tool.Function(
                    toolDefinition.name(), toolDefinition.description(), toolDefinition.inputSchema())))
            .toList();
    }

    /** Returns a copy of the default options. */
    @Override
    public ChatOptions getDefaultOptions() {
        return OllamaOptions.fromOptions(this.defaultOptions);
    }

    /** Pulls the model unless the strategy is {@code null} or {@link PullModelStrategy#NEVER}. */
    private void initializeModel(String str, PullModelStrategy pullModelStrategy) {
        if (pullModelStrategy == null || PullModelStrategy.NEVER.equals(pullModelStrategy)) {
            return;
        }
        this.modelManager.pullModel(str, pullModelStrategy);
    }

    /**
     * Replaces the observation convention used for chat observations.
     *
     * @throws IllegalArgumentException if the convention is {@code null}
     */
    public void setObservationConvention(ChatModelObservationConvention chatModelObservationConvention) {
        Assert.notNull(chatModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = chatModelObservationConvention;
    }
}
