/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.zhipuai;

import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingOptions;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuApiConstants;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

/**
 * Embedding model backed by the ZhiPu AI embeddings endpoint.
 *
 * <p>The ZhiPu embeddings API is called once per input text (see {@link #call(EmbeddingRequest)}).
 * Each call is wrapped in the configured {@link RetryTemplate}, and the whole request is
 * observed through the configured Micrometer {@link ObservationRegistry}.
 */
public class ZhiPuAiEmbeddingModel
extends AbstractEmbeddingModel {
    private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiEmbeddingModel.class);
    private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
    private final ZhiPuAiEmbeddingOptions defaultOptions;
    private final RetryTemplate retryTemplate;
    private final ZhiPuAiApi zhiPuAiApi;
    private final MetadataMode metadataMode;
    private final ObservationRegistry observationRegistry;
    private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model using {@link MetadataMode#EMBED}, the default ZhiPu embedding model,
     * the default retry template and a no-op observation registry.
     * @param zhiPuAiApi the low-level API client; must not be null
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi) {
        this(zhiPuAiApi, MetadataMode.EMBED);
    }

    /**
     * Creates a model with the default ZhiPu embedding model, the default retry template
     * and a no-op observation registry.
     * @param zhiPuAiApi the low-level API client; must not be null
     * @param metadataMode how document metadata is folded into the embedded text
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode) {
        this(zhiPuAiApi, metadataMode, ZhiPuAiEmbeddingOptions.builder().model(ZhiPuAiApi.DEFAULT_EMBEDDING_MODEL).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a model with the default retry template and a no-op observation registry.
     * @param zhiPuAiApi the low-level API client; must not be null
     * @param metadataMode how document metadata is folded into the embedded text
     * @param zhiPuAiEmbeddingOptions default options, overridable per request
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions) {
        this(zhiPuAiApi, metadataMode, zhiPuAiEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a model with a no-op observation registry.
     * @param zhiPuAiApi the low-level API client; must not be null
     * @param metadataMode how document metadata is folded into the embedded text
     * @param zhiPuAiEmbeddingOptions default options, overridable per request
     * @param retryTemplate retry policy applied to every API call
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions, RetryTemplate retryTemplate) {
        this(zhiPuAiApi, metadataMode, zhiPuAiEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
    }

    /**
     * Creates a fully configured model.
     * @param zhiPuAiApi the low-level API client; must not be null
     * @param metadataMode how document metadata is folded into the embedded text; must not be null
     * @param options default options, overridable per request; must not be null
     * @param retryTemplate retry policy applied to every API call; must not be null
     * @param observationRegistry registry used to observe each {@link #call(EmbeddingRequest)}; must not be null
     * @throws IllegalArgumentException if any argument is null
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        Assert.notNull(zhiPuAiApi, "ZhiPuAiApi must not be null");
        Assert.notNull(metadataMode, "metadataMode must not be null");
        Assert.notNull(options, "options must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.zhiPuAiApi = zhiPuAiApi;
        this.metadataMode = metadataMode;
        this.defaultOptions = options;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    /**
     * Embeds a single document, using this model's {@link MetadataMode} to build its text.
     * @param document the document to embed; must not be null
     * @return the embedding vector
     */
    @Override
    public float[] embed(Document document) {
        Assert.notNull(document, "Document must not be null");
        return this.embed(document.getFormattedContent(this.metadataMode));
    }

    /**
     * Embeds every instruction in the request.
     *
     * <p>ZhiPu has no batch endpoint here, so one API call is made per instruction. If a call
     * returns no data, an empty vector is recorded at that index so the results stay aligned
     * with the inputs.
     * @param request the embedding request; must contain at least one instruction
     * @return embeddings in input order, with accumulated token usage and the resolved model name
     */
    @Override
    public EmbeddingResponse call(EmbeddingRequest request) {
        Assert.notEmpty(request.getInstructions(), "At least one text is required!");
        if (request.getInstructions().size() != 1) {
            logger.warn("ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
        }
        ZhiPuAiEmbeddingOptions requestOptions = this.mergeOptions(request.getOptions(), this.defaultOptions);
        EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
            .embeddingRequest(request)
            .provider(ZhiPuApiConstants.PROVIDER_NAME)
            .requestOptions(requestOptions)
            .build();
        return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
            .observe(() -> {
                EmbeddingResponse embeddingResponse = this.embedEach(request.getInstructions(), requestOptions);
                observationContext.setResponse(embeddingResponse);
                return embeddingResponse;
            });
    }

    /** Calls the API once per input and assembles the combined, index-aligned response. */
    private EmbeddingResponse embedEach(List<String> inputs, ZhiPuAiEmbeddingOptions requestOptions) {
        List<Embedding> embeddings = new ArrayList<>(inputs.size());
        ZhiPuAiApi.Usage totalUsage = new ZhiPuAiApi.Usage(0, 0, 0);
        for (String inputContent : inputs) {
            ZhiPuAiApi.EmbeddingRequest<String> apiRequest = this.createEmbeddingRequest(inputContent, requestOptions);
            ZhiPuAiApi.EmbeddingList response = this.retryTemplate.execute(ctx -> this.zhiPuAiApi.embeddings(apiRequest).getBody());
            // The index is the position in the input list, so it must advance even when nothing came back.
            int index = embeddings.size();
            if (response == null || response.data() == null || response.data().isEmpty()) {
                logger.warn("No embeddings returned for input: {}", inputContent);
                embeddings.add(new Embedding(new float[0], index));
                continue;
            }
            totalUsage = addUsage(totalUsage, response.usage());
            embeddings.add(new Embedding(response.data().get(0).embedding(), index));
        }
        // Report the model actually sent: merged options fall back to the defaults when the
        // request itself did not name a model.
        String model = requestOptions.getModel() != null ? requestOptions.getModel() : "unknown";
        EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(model, this.getDefaultUsage(totalUsage));
        return new EmbeddingResponse(List.copyOf(embeddings), metadata);
    }

    /**
     * Returns a new usage holding the component-wise sum of {@code total} and {@code increment}.
     * Missing usage or token counts are treated as zero instead of failing on unboxing.
     */
    private static ZhiPuAiApi.Usage addUsage(ZhiPuAiApi.Usage total, @Nullable ZhiPuAiApi.Usage increment) {
        if (increment == null) {
            return total;
        }
        int completionTokens = orZero(total.completionTokens()) + orZero(increment.completionTokens());
        int promptTokens = orZero(total.promptTokens()) + orZero(increment.promptTokens());
        int totalTokens = orZero(total.totalTokens()) + orZero(increment.totalTokens());
        return new ZhiPuAiApi.Usage(completionTokens, promptTokens, totalTokens);
    }

    /** Null-safe unboxing helper: null becomes 0. */
    private static int orZero(@Nullable Integer value) {
        return value != null ? value : 0;
    }

    /** Converts the provider usage into the portable Spring AI usage, keeping the native object. */
    private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) {
        return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
    }

    /**
     * Overlays runtime options onto the defaults.
     *
     * <p>A non-null model or dimensions value in the runtime options wins over the default.
     * @return {@code defaultOptions} unchanged when no runtime options were supplied
     */
    private ZhiPuAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions, ZhiPuAiEmbeddingOptions defaultOptions) {
        ZhiPuAiEmbeddingOptions runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class, ZhiPuAiEmbeddingOptions.class);
        if (runtimeOptionsForProvider == null) {
            return defaultOptions;
        }
        return ZhiPuAiEmbeddingOptions.builder()
            .model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
            .dimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(), defaultOptions.getDimensions()))
            .build();
    }

    /** Builds the single-text API request carrying the resolved model and dimensions. */
    private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) {
        return new ZhiPuAiApi.EmbeddingRequest<String>(text, requestOptions.getModel(), requestOptions.getDimensions());
    }

    /**
     * Replaces the observation convention used for subsequent calls.
     * @param observationConvention the convention to use; must not be null
     * @throws IllegalArgumentException if {@code observationConvention} is null
     */
    public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }
}

