/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.service.base;

import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.dto.schema.ColumnDTO;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.dto.schema.TableDTO;
import com.alibaba.cloud.ai.enums.BizDataSourceTypeEnum;
import com.alibaba.cloud.ai.request.SearchRequest;
import com.alibaba.cloud.ai.service.base.BaseVectorStoreService;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.Document;

/**
 * Base class for services that assemble a {@link SchemaDTO} (tables, columns and foreign keys)
 * from table and column documents retrieved through a {@link BaseVectorStoreService}.
 *
 * <p>Subclasses implement {@link #addTableDocument} and {@link #addColumnsDocument}. They are
 * invoked for tables/columns that are referenced by a foreign key but were not part of the
 * original retrieval.
 */
public abstract class BaseSchemaService {
    /** Vector type passed to the vector store when searching table documents. */
    private static final String TABLE_VECTOR_TYPE = "table";
    /** Vector type passed to the vector store when searching column documents. */
    private static final String COLUMN_VECTOR_TYPE = "column";
    /** Maximum number of ranked columns kept by {@link #buildSchemaFromDocuments}. */
    private static final int MAX_WEIGHTED_COLUMNS = 100;
    /** topK applied to every request built by the {@code handleDocumentQuery} helpers. */
    private static final int DEFAULT_TOP_K = 10;
    /** Separator between entries of the {@code foreignKey} metadata value (U+3001, ideographic comma). */
    private static final String FOREIGN_KEY_SEPARATOR = "\u3001";
    /** Captures the path segment following {@code :<port>/} in a JDBC URL (the database name). */
    private static final Pattern DATABASE_NAME_PATTERN = Pattern.compile(":\\d+/([^/?&]+)");

    protected final DbConfig dbConfig;
    protected final Gson gson;
    protected final BaseVectorStoreService vectorStoreService;

    public BaseSchemaService(DbConfig dbConfig, Gson gson, BaseVectorStoreService vectorStoreService) {
        this.dbConfig = dbConfig;
        this.gson = gson;
        this.vectorStoreService = vectorStoreService;
    }

    /**
     * Builds a schema for {@code query} without scoping retrieval to an agent.
     *
     * @param query    text used to search table documents
     * @param keywords one column search is run per keyword
     * @return the assembled schema
     */
    public SchemaDTO mixRag(String query, List<String> keywords) {
        return this.mixRagForAgent(null, query, keywords);
    }

    /**
     * Builds a schema by combining a table search for {@code query} with one column search per
     * keyword.
     *
     * @param agentId  agent whose documents are searched, or {@code null} for the non-agent search
     * @param query    text used to search table documents
     * @param keywords one column search is run per keyword
     * @return the assembled schema; its name is set when it can be derived from {@link DbConfig}
     */
    public SchemaDTO mixRagForAgent(String agentId, String query, List<String> keywords) {
        SchemaDTO schemaDTO = new SchemaDTO();
        this.extractDatabaseName(schemaDTO);
        List<Document> tableDocuments = this.getTableDocuments(query, agentId);
        List<List<Document>> columnDocumentList = this.getColumnDocumentsByKeywords(keywords, agentId);
        this.buildSchemaFromDocuments(columnDocumentList, tableDocuments, schemaDTO);
        return schemaDTO;
    }

    /**
     * Fills {@code schemaDTO} with tables, their ranked columns and the raw foreign-key entries.
     *
     * <p>Side effects: {@code columnDocumentList} is re-scored and filtered in place (see
     * {@link #processColumnWeights}), and {@code tableDocuments} is passed to
     * {@link #addTableDocument} for tables referenced by foreign keys but missing from the list.
     *
     * @param columnDocumentList one ranked list of column documents per keyword (must be mutable)
     * @param tableDocuments     retrieved table documents (passed to {@link #addTableDocument})
     * @param schemaDTO          target whose table list and foreign keys are set
     */
    public void buildSchemaFromDocuments(List<List<Document>> columnDocumentList, List<Document> tableDocuments, SchemaDTO schemaDTO) {
        this.processColumnWeights(columnDocumentList, tableDocuments);
        Map<String, Document> weightedColumns = this.selectWeightedColumns(columnDocumentList, MAX_WEIGHTED_COLUMNS);
        // Foreign keys are taken from the originally retrieved tables only.
        Set<String> foreignKeySet = this.extractForeignKeyRelations(tableDocuments);
        // Expand BEFORE building the table list so that tables pulled in through foreign keys,
        // and the columns fetched for them, actually end up in the schema.
        this.expandTableDocumentsWithForeignKeys(tableDocuments, foreignKeySet, TABLE_VECTOR_TYPE);
        this.expandColumnDocumentsWithForeignKeys(weightedColumns, foreignKeySet, COLUMN_VECTOR_TYPE);
        List<TableDTO> tableList = this.buildTableListFromDocuments(tableDocuments);
        this.attachColumnsToTables(weightedColumns, tableList);
        schemaDTO.setTable(tableList);
        Set<String> foreignKeys = tableDocuments.stream()
                .flatMap(doc -> Arrays.stream(foreignKeyOf(doc).split(FOREIGN_KEY_SEPARATOR)))
                .filter(StringUtils::isNotBlank)
                .collect(Collectors.toCollection(LinkedHashSet::new));
        schemaDTO.setForeignKeys(List.of(new ArrayList<>(foreignKeys)));
    }

    /** Searches table documents for {@code query} without agent scoping. */
    public List<Document> getTableDocuments(String query) {
        return this.getTableDocuments(query, null);
    }

    /**
     * Searches table documents for {@code query}.
     *
     * @param agentId agent to scope the search to, or {@code null} for the non-agent search
     */
    public List<Document> getTableDocuments(String query, String agentId) {
        if (agentId != null) {
            return this.vectorStoreService.getDocumentsForAgent(agentId, query, TABLE_VECTOR_TYPE);
        }
        return this.vectorStoreService.getDocuments(query, TABLE_VECTOR_TYPE);
    }

    /** Searches table documents for {@code query} scoped to {@code agentId}. */
    public List<Document> getTableDocumentsForAgent(String agentId, String query) {
        return this.vectorStoreService.getDocumentsForAgent(agentId, query, TABLE_VECTOR_TYPE);
    }

    /** Runs one column search per keyword without agent scoping. */
    public List<List<Document>> getColumnDocumentsByKeywords(List<String> keywords) {
        return this.getColumnDocumentsByKeywords(keywords, null);
    }

    /**
     * Runs one column search per keyword, in keyword order.
     *
     * @param keywords search terms; {@code null} or empty yields an empty list
     * @param agentId  agent to scope the searches to, or {@code null} for the non-agent search
     * @return a mutable outer list (required by {@link #processColumnWeights}, which replaces its elements)
     */
    public List<List<Document>> getColumnDocumentsByKeywords(List<String> keywords, String agentId) {
        if (agentId != null) {
            return this.getColumnDocumentsByKeywordsForAgent(agentId, keywords);
        }
        return searchPerKeyword(keywords, kw -> this.vectorStoreService.getDocuments(kw, COLUMN_VECTOR_TYPE));
    }

    /**
     * Runs one agent-scoped column search per keyword, in keyword order.
     *
     * @return a mutable outer list; empty when {@code keywords} is {@code null} or empty
     */
    public List<List<Document>> getColumnDocumentsByKeywordsForAgent(String agentId, List<String> keywords) {
        return searchPerKeyword(keywords, kw -> this.vectorStoreService.getDocumentsForAgent(agentId, kw, COLUMN_VECTOR_TYPE));
    }

    /** Applies {@code search} to each keyword and collects the results into a mutable list. */
    private static List<List<Document>> searchPerKeyword(List<String> keywords, Function<String, List<Document>> search) {
        if (keywords == null || keywords.isEmpty()) {
            return new ArrayList<>();
        }
        return keywords.stream().map(search).collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Calls {@link #addColumnsDocument} for every foreign-key entry not already a key of
     * {@code weightedColumns}.
     * NOTE(review): this compares "table.column" strings with document ids; confirm column ids use that form.
     */
    private void expandColumnDocumentsWithForeignKeys(Map<String, Document> weightedColumns, Set<String> foreignKeySet, String vectorType) {
        // Snapshot first: addColumnsDocument receives weightedColumns and may add to it.
        List<String> missingColumns = foreignKeySet.stream()
                .filter(key -> !weightedColumns.containsKey(key))
                .collect(Collectors.toList());
        for (String columnName : missingColumns) {
            this.addColumnsDocument(weightedColumns, columnName, vectorType);
        }
    }

    /**
     * Calls {@link #addTableDocument} for every table named on the left of a well-formed
     * {@code table.column} foreign-key entry that is not already among {@code tableDocuments}.
     */
    private void expandTableDocumentsWithForeignKeys(List<Document> tableDocuments, Set<String> foreignKeySet, String vectorType) {
        // Snapshot the known names first: addTableDocument receives tableDocuments and may add to it.
        Set<String> knownTables = tableDocuments.stream()
                .map(doc -> (String) doc.getMetadata().get("name"))
                .collect(Collectors.toSet());
        Set<String> missingTables = new LinkedHashSet<>();
        for (String key : foreignKeySet) {
            String[] parts = key.split("\\.");
            if (parts.length == 2 && !knownTables.contains(parts[0])) {
                missingTables.add(parts[0]);
            }
        }
        for (String tableName : missingTables) {
            this.addTableDocument(tableDocuments, tableName, vectorType);
        }
    }

    /**
     * Adds the document(s) for {@code tableName} to {@code tableDocuments}.
     *
     * @param tableDocuments list being expanded (mutable)
     * @param tableName      table referenced by a foreign key
     * @param vectorType     vector type to search with (currently always {@code "table"})
     */
    protected abstract void addTableDocument(List<Document> tableDocuments, String tableName, String vectorType);

    /**
     * Adds the document(s) for {@code columnName} to {@code columnDocuments}.
     *
     * @param columnDocuments map of selected columns keyed by document id (mutable)
     * @param columnName      foreign-key side, expected as {@code table.column}
     * @param vectorType      vector type to search with (currently always {@code "column"})
     */
    protected abstract void addColumnsDocument(Map<String, Document> columnDocuments, String columnName, String vectorType);

    /**
     * Selects up to {@code maxCount} distinct columns by taking rank 0 from every list, then
     * rank 1, and so on (round-robin), de-duplicating by document id.
     *
     * @return columns keyed by document id, in selection order; never more than {@code maxCount}
     */
    protected Map<String, Document> selectWeightedColumns(List<List<Document>> columnDocumentList, int maxCount) {
        Map<String, Document> result = new LinkedHashMap<>();
        for (int rank = 0; result.size() < maxCount; rank++) {
            boolean anyRemaining = false;
            for (List<Document> docs : columnDocumentList) {
                if (rank >= docs.size()) {
                    continue;
                }
                anyRemaining = true;
                Document doc = docs.get(rank);
                result.putIfAbsent(doc.getId(), doc);
                // Enforce the cap per insertion, not per round, so the limit is exact.
                if (result.size() >= maxCount) {
                    return result;
                }
            }
            if (!anyRemaining) {
                break;
            }
        }
        return result;
    }

    /**
     * Converts table documents into {@link TableDTO}s using the {@code name}, {@code description}
     * and (when non-blank) {@code primaryKey} metadata entries.
     */
    protected List<TableDTO> buildTableListFromDocuments(List<Document> documents) {
        List<TableDTO> tableList = new ArrayList<>(documents.size());
        for (Document doc : documents) {
            Map<String, Object> meta = doc.getMetadata();
            TableDTO dto = new TableDTO();
            dto.setName((String) meta.get("name"));
            dto.setDescription((String) meta.get("description"));
            String primaryKey = (String) meta.get("primaryKey");
            if (StringUtils.isNotBlank(primaryKey)) {
                dto.setPrimaryKeys(List.of(primaryKey));
            }
            tableList.add(dto);
        }
        return tableList;
    }

    /**
     * Re-ranks every column list in place. The steps are:
     * <ol>
     *   <li>Drop columns whose {@code tableName} does not match any table's {@code name}.</li>
     *   <li>Multiply each remaining column score by its table's score. The metadata {@code score}
     *       is used, falling back to {@link Document#getScore()}.</li>
     *   <li>Sort by score, descending. Columns without a score sort last.</li>
     * </ol>
     *
     * <p>Mutates the {@code score} metadata of matching column documents.
     *
     * @param columnDocuments lists to re-rank; the outer list must support {@code replaceAll}
     * @param tableDocuments  tables that columns are matched against
     */
    public void processColumnWeights(List<List<Document>> columnDocuments, List<Document> tableDocuments) {
        // Index once instead of scanning the tables twice per column; first occurrence wins,
        // matching the original findFirst() lookup.
        Map<Object, Document> tablesByName = new HashMap<>();
        for (Document table : tableDocuments) {
            Object name = table.getMetadata().get("name");
            if (name != null) {
                tablesByName.putIfAbsent(name, table);
            }
        }
        Comparator<Document> byScoreDesc = Comparator.comparing(
                (Document d) -> asDouble(d.getMetadata().get("score")),
                Comparator.nullsLast(Comparator.<Double>reverseOrder()));
        columnDocuments.replaceAll(docs -> docs.stream()
                .filter(column -> tablesByName.containsKey(column.getMetadata().get("tableName")))
                .map(column -> applyTableWeight(column, tablesByName.get(column.getMetadata().get("tableName"))))
                .sorted(byScoreDesc)
                .collect(Collectors.toCollection(ArrayList::new)));
    }

    /** Stores {@code columnScore * tableScore} as the column's {@code score} metadata when both are known. */
    private static Document applyTableWeight(Document column, Document table) {
        Double tableScore = Optional.ofNullable(asDouble(table.getMetadata().get("score"))).orElse(table.getScore());
        Double columnScore = Optional.ofNullable(asDouble(column.getMetadata().get("score"))).orElse(column.getScore());
        if (tableScore != null && columnScore != null) {
            column.getMetadata().put("score", columnScore * tableScore);
        }
        return column;
    }

    /** Returns {@code value} as a Double when it is a {@link Number}, otherwise {@code null}. */
    private static Double asDouble(Object value) {
        if (value instanceof Number) {
            return Double.valueOf(((Number) value).doubleValue());
        }
        return null;
    }

    /** Returns the {@code foreignKey} metadata of {@code doc} as a string, or "" when absent. */
    private static String foreignKeyOf(Document doc) {
        Object value = doc.getMetadata().get("foreignKey");
        return value == null ? "" : value.toString();
    }

    /**
     * Collects both sides of every {@code left=right} pair in the tables' {@code foreignKey}
     * metadata (pairs separated by U+3001). Sides are trimmed; malformed pairs are ignored.
     *
     * @return the referenced identifiers, in first-seen order
     */
    protected Set<String> extractForeignKeyRelations(List<Document> tableDocuments) {
        Set<String> result = new LinkedHashSet<>();
        for (Document doc : tableDocuments) {
            String foreignKeyStr = foreignKeyOf(doc);
            if (StringUtils.isBlank(foreignKeyStr)) {
                continue;
            }
            for (String pair : foreignKeyStr.split(FOREIGN_KEY_SEPARATOR)) {
                String[] parts = pair.split("=");
                if (parts.length == 2) {
                    result.add(parts[0].trim());
                    result.add(parts[1].trim());
                }
            }
        }
        return result;
    }

    /**
     * Converts each column document into a {@link ColumnDTO} and appends it to the first table
     * whose name equals the column's {@code tableName}. Columns with no matching table are skipped.
     * A non-blank {@code samples} entry is parsed with Gson as a JSON array of strings.
     */
    protected void attachColumnsToTables(Map<String, Document> weightedColumns, List<TableDTO> tableList) {
        if (weightedColumns == null || weightedColumns.isEmpty()) {
            return;
        }
        // First table with a given name wins, matching the original findFirst() lookup.
        Map<String, TableDTO> tablesByName = new HashMap<>();
        for (TableDTO table : tableList) {
            if (table.getName() != null) {
                tablesByName.putIfAbsent(table.getName(), table);
            }
        }
        TypeToken<List<String>> samplesToken = new TypeToken<List<String>>() {};
        for (Document columnDoc : weightedColumns.values()) {
            Map<String, Object> meta = columnDoc.getMetadata();
            TableDTO table = tablesByName.get((String) meta.get("tableName"));
            if (table == null) {
                continue;
            }
            ColumnDTO columnDTO = new ColumnDTO();
            columnDTO.setName((String) meta.get("name"));
            columnDTO.setDescription((String) meta.get("description"));
            columnDTO.setType((String) meta.get("type"));
            String samplesStr = (String) meta.get("samples");
            if (StringUtils.isNotBlank(samplesStr)) {
                List<String> samples = this.gson.fromJson(samplesStr, samplesToken.getType());
                columnDTO.setData(samples);
            }
            table.getColumn().add(columnDTO);
        }
    }

    /**
     * Searches tables for {@code tableName} and returns the metadata of the first hit whose
     * {@code name} equals it exactly.
     *
     * @return that metadata map, or {@code null} when {@code tableName} is null or nothing matches
     */
    protected Map<String, Object> getTableMetadata(String tableName) {
        if (tableName == null) {
            return null;
        }
        for (Document doc : this.getTableDocuments(tableName)) {
            Map<String, Object> metadata = doc.getMetadata();
            if (tableName.equals(metadata.get("name"))) {
                return metadata;
            }
        }
        return null;
    }

    /**
     * Sets the schema name from the configuration. For MySQL dialects it is the database segment
     * of the JDBC URL (left unset if the URL is null or does not match). For PostgreSQL dialects
     * it is the configured schema.
     */
    public void extractDatabaseName(SchemaDTO schemaDTO) {
        String dialectType = (String) this.dbConfig.getDialectType();
        if (BizDataSourceTypeEnum.isMysqlDialect(dialectType)) {
            String url = this.dbConfig.getUrl();
            if (url == null) {
                return;
            }
            Matcher matcher = DATABASE_NAME_PATTERN.matcher(url);
            if (matcher.find()) {
                schemaDTO.setName(matcher.group(1));
            }
        } else if (BizDataSourceTypeEnum.isPgDialect(dialectType)) {
            schemaDTO.setName(this.dbConfig.getSchema());
        }
    }

    /** Runs a search for {@code key} and appends every result to {@code targetList}. */
    protected void handleDocumentQuery(List<Document> targetList, String key, String vectorType, Function<String, SearchRequest> requestBuilder, Function<SearchRequest, List<Document>> searchFunc) {
        List<Document> docs = runSearch(key, vectorType, requestBuilder, searchFunc);
        if (CollectionUtils.isNotEmpty(docs)) {
            targetList.addAll(docs);
        }
    }

    /** Runs a search for {@code key} and adds results to {@code targetMap} by id, keeping existing entries. */
    protected void handleDocumentQuery(Map<String, Document> targetMap, String key, String vectorType, Function<String, SearchRequest> requestBuilder, Function<SearchRequest, List<Document>> searchFunc) {
        List<Document> docs = runSearch(key, vectorType, requestBuilder, searchFunc);
        if (CollectionUtils.isNotEmpty(docs)) {
            for (Document doc : docs) {
                targetMap.putIfAbsent(doc.getId(), doc);
            }
        }
    }

    /** Builds a request for {@code key}, sets the vector type and a topK of 10, and executes it. */
    private static List<Document> runSearch(String key, String vectorType, Function<String, SearchRequest> requestBuilder, Function<SearchRequest, List<Document>> searchFunc) {
        SearchRequest request = requestBuilder.apply(key);
        request.setVectorType(vectorType);
        request.setTopK(DEFAULT_TOP_K);
        return searchFunc.apply(request);
    }
}

