/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.config;

import com.alibaba.cloud.ai.config.CodeExecutorProperties;
import com.alibaba.cloud.ai.connector.accessor.Accessor;
import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.dispatcher.PlanExecutorDispatcher;
import com.alibaba.cloud.ai.dispatcher.PythonExecutorDispatcher;
import com.alibaba.cloud.ai.dispatcher.QueryRewriteDispatcher;
import com.alibaba.cloud.ai.dispatcher.SQLExecutorDispatcher;
import com.alibaba.cloud.ai.dispatcher.SemanticConsistenceDispatcher;
import com.alibaba.cloud.ai.dispatcher.SqlGenerateDispatcher;
import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.alibaba.cloud.ai.node.KeywordExtractNode;
import com.alibaba.cloud.ai.node.PlanExecutorNode;
import com.alibaba.cloud.ai.node.PlannerNode;
import com.alibaba.cloud.ai.node.PythonAnalyzeNode;
import com.alibaba.cloud.ai.node.PythonExecuteNode;
import com.alibaba.cloud.ai.node.PythonGenerateNode;
import com.alibaba.cloud.ai.node.QueryRewriteNode;
import com.alibaba.cloud.ai.node.ReportGeneratorNode;
import com.alibaba.cloud.ai.node.SchemaRecallNode;
import com.alibaba.cloud.ai.node.SemanticConsistencyNode;
import com.alibaba.cloud.ai.node.SqlExecuteNode;
import com.alibaba.cloud.ai.node.SqlGenerateNode;
import com.alibaba.cloud.ai.node.TableRelationNode;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.UserPromptConfigService;
import com.alibaba.cloud.ai.service.base.BaseNl2SqlService;
import com.alibaba.cloud.ai.service.base.BaseSchemaService;
import com.alibaba.cloud.ai.service.business.BusinessKnowledgeRecallService;
import com.alibaba.cloud.ai.service.code.CodePoolExecutorService;
import com.alibaba.cloud.ai.service.semantic.SemanticModelRecallService;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class Nl2sqlConfiguration {
    private static final Logger logger = LoggerFactory.getLogger(Nl2sqlConfiguration.class);

    // Node identifiers. Each one is used both to register a node and as an edge / route target,
    // so keeping them in one place prevents a typo from silently creating a dangling edge.
    private static final String QUERY_REWRITE_NODE = "QUERY_REWRITE_NODE";
    private static final String KEYWORD_EXTRACT_NODE = "KEYWORD_EXTRACT_NODE";
    private static final String SCHEMA_RECALL_NODE = "SCHEMA_RECALL_NODE";
    private static final String TABLE_RELATION_NODE = "TABLE_RELATION_NODE";
    private static final String SQL_GENERATE_NODE = "SQL_GENERATE_NODE";
    private static final String PLANNER_NODE = "PLANNER_NODE";
    private static final String PLAN_EXECUTOR_NODE = "PLAN_EXECUTOR_NODE";
    private static final String SQL_EXECUTE_NODE = "SQL_EXECUTE_NODE";
    private static final String PYTHON_GENERATE_NODE = "PYTHON_GENERATE_NODE";
    private static final String PYTHON_EXECUTE_NODE = "PYTHON_EXECUTE_NODE";
    private static final String PYTHON_ANALYZE_NODE = "PYTHON_ANALYZE_NODE";
    private static final String REPORT_GENERATOR_NODE = "REPORT_GENERATOR_NODE";
    private static final String SEMANTIC_CONSISTENCY_NODE = "SEMANTIC_CONSISTENCY_NODE";

    /**
     * Every overall-state key registered with the graph. Each key is bound to its own
     * {@link ReplaceStrategy} instance. Keys must be unique; the decompiled original
     * registered {@code "agentId"} twice.
     */
    private static final List<String> STATE_KEYS = List.of(
            "input",
            "agentId",
            "BUSINESS_KNOWLEDGE",
            "SEMANTIC_MODEL",
            "QUERY_REWRITE_NODE_OUTPUT",
            "KEYWORD_EXTRACT_NODE_OUTPUT",
            "EVIDENCES",
            "TABLE_DOCUMENTS_FOR_SCHEMA",
            "COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT",
            "SQL_VALIDATE_NODE_OUTPUT",
            "SQL_VALIDATE_EXCEPTION_OUTPUT",
            "TABLE_RELATION_OUTPUT",
            "SQL_GENERATE_SCHEMA_MISSING_ADVICE",
            "SQL_GENERATE_OUTPUT",
            "SQL_GENERATE_COUNT",
            "SEMANTIC_CONSISTENCY_NODE_OUTPUT",
            "SEMANTIC_CONSISTENCY_NODE_RECOMMEND_OUTPUT",
            "PLANNER_NODE_OUTPUT",
            "PLAN_CURRENT_STEP",
            "PLAN_NEXT_NODE",
            "PLAN_VALIDATION_STATUS",
            "PLAN_VALIDATION_ERROR",
            "PLAN_REPAIR_COUNT",
            "SQL_EXECUTE_NODE_OUTPUT",
            "SQL_EXECUTE_NODE_EXCEPTION_OUTPUT",
            "SQL_RESULT_LIST_MEMORY",
            "PYTHON_IS_SUCCESS",
            "PYTHON_TRIES_COUNT",
            "PYTHON_EXECUTE_NODE_OUTPUT",
            "PYTHON_GENERATE_NODE_OUTPUT",
            "PYTHON_ANALYSIS_NODE_OUTPUT",
            "result");

    private final BaseNl2SqlService nl2SqlService;
    private final BaseSchemaService schemaService;
    private final Accessor dbAccessor;
    // Injected and stored, but not read anywhere in this class.
    private final DbConfig dbConfig;
    private final CodeExecutorProperties codeExecutorProperties;
    private final CodePoolExecutorService codePoolExecutor;
    private final SemanticModelRecallService semanticModelRecallService;
    private final BusinessKnowledgeRecallService businessKnowledgeRecallService;
    private final UserPromptConfigService promptConfigService;
    private final DatasourceService datasourceService;

    /**
     * Creates the configuration with all collaborators the graph nodes need.
     *
     * @param nl2SqlService the {@code nl2SqlServiceImpl} bean, used by the rewrite, keyword, table-relation,
     *                      SQL-generation and semantic-consistency nodes
     * @param schemaService the {@code schemaServiceImpl} bean, used by schema recall and table relation
     * @param dbAccessor the {@code mysqlAccessor} bean, used by the SQL execute node
     * @param dbConfig database configuration (stored only)
     * @param codeExecutorProperties settings passed to the Python generate node
     * @param codePoolExecutor executor passed to the Python execute node
     * @param semanticModelRecallService passed to the table-relation node
     * @param businessKnowledgeRecallService passed to the table-relation node
     * @param promptConfigService passed to the report generator node
     * @param datasourceService passed to the SQL execute node
     */
    public Nl2sqlConfiguration(@Qualifier(value="nl2SqlServiceImpl") BaseNl2SqlService nl2SqlService, @Qualifier(value="schemaServiceImpl") BaseSchemaService schemaService, @Qualifier(value="mysqlAccessor") Accessor dbAccessor, DbConfig dbConfig, CodeExecutorProperties codeExecutorProperties, CodePoolExecutorService codePoolExecutor, SemanticModelRecallService semanticModelRecallService, BusinessKnowledgeRecallService businessKnowledgeRecallService, UserPromptConfigService promptConfigService, DatasourceService datasourceService) {
        this.nl2SqlService = nl2SqlService;
        this.schemaService = schemaService;
        this.dbAccessor = dbAccessor;
        this.dbConfig = dbConfig;
        this.codeExecutorProperties = codeExecutorProperties;
        this.codePoolExecutor = codePoolExecutor;
        this.semanticModelRecallService = semanticModelRecallService;
        this.businessKnowledgeRecallService = businessKnowledgeRecallService;
        this.promptConfigService = promptConfigService;
        this.datasourceService = datasourceService;
    }

    /**
     * Builds the {@code nl2sqlGraph} workflow and logs its PlantUML rendering.
     *
     * <p>Flow: query rewrite, then keyword extraction, schema recall, table relation, planner and
     * plan executor. The plan executor routes to the planner, SQL execution, Python generation,
     * the report generator, or the end. Conditional routes are chosen by the dispatcher classes.
     *
     * @param chatClientBuilder builder handed to every LLM-backed node
     * @return the configured, uncompiled state graph
     * @throws GraphStateException if the graph framework rejects a node or edge definition
     */
    @Bean
    public StateGraph nl2sqlGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
        // The factory is invoked by the framework; it must hand back a fresh map each time.
        KeyStrategyFactory keyStrategyFactory = () -> newStrategyMap(STATE_KEYS, ReplaceStrategy::new);
        StateGraph stateGraph = this.addNodes(new StateGraph("nl2sqlGraph", keyStrategyFactory), chatClientBuilder);
        addEdges(stateGraph);

        GraphRepresentation graphRepresentation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "workflow graph");
        // One log event keeps the diagram from being interleaved with other threads' output.
        logger.info("\n\n{}\n\n", graphRepresentation.content());
        return stateGraph;
    }

    /** Registers every workflow node on {@code graph} and returns the result of the registration chain. */
    private StateGraph addNodes(StateGraph graph, ChatClient.Builder chatClientBuilder) throws GraphStateException {
        return graph
                .addNode(QUERY_REWRITE_NODE, AsyncNodeAction.node_async(new QueryRewriteNode(this.nl2SqlService)))
                .addNode(KEYWORD_EXTRACT_NODE, AsyncNodeAction.node_async(new KeywordExtractNode(this.nl2SqlService)))
                .addNode(SCHEMA_RECALL_NODE, AsyncNodeAction.node_async(new SchemaRecallNode(this.schemaService)))
                .addNode(TABLE_RELATION_NODE, AsyncNodeAction.node_async(new TableRelationNode(this.schemaService,
                        this.nl2SqlService, this.businessKnowledgeRecallService, this.semanticModelRecallService)))
                .addNode(SQL_GENERATE_NODE, AsyncNodeAction.node_async(new SqlGenerateNode(chatClientBuilder, this.nl2SqlService)))
                .addNode(PLANNER_NODE, AsyncNodeAction.node_async(new PlannerNode(chatClientBuilder)))
                .addNode(PLAN_EXECUTOR_NODE, AsyncNodeAction.node_async(new PlanExecutorNode()))
                .addNode(SQL_EXECUTE_NODE, AsyncNodeAction.node_async(new SqlExecuteNode(this.dbAccessor, this.datasourceService)))
                .addNode(PYTHON_GENERATE_NODE, AsyncNodeAction.node_async(new PythonGenerateNode(this.codeExecutorProperties, chatClientBuilder)))
                .addNode(PYTHON_EXECUTE_NODE, AsyncNodeAction.node_async(new PythonExecuteNode(this.codePoolExecutor)))
                .addNode(PYTHON_ANALYZE_NODE, AsyncNodeAction.node_async(new PythonAnalyzeNode(chatClientBuilder)))
                .addNode(REPORT_GENERATOR_NODE, AsyncNodeAction.node_async(new ReportGeneratorNode(chatClientBuilder, this.promptConfigService)))
                .addNode(SEMANTIC_CONSISTENCY_NODE, AsyncNodeAction.node_async(new SemanticConsistencyNode(this.nl2SqlService)));
    }

    /** Wires the static and conditional edges between the nodes registered by {@link #addNodes}. */
    private static void addEdges(StateGraph graph) throws GraphStateException {
        graph.addEdge(StateGraph.START, QUERY_REWRITE_NODE)
                .addConditionalEdges(QUERY_REWRITE_NODE, AsyncEdgeAction.edge_async(new QueryRewriteDispatcher()),
                        Map.of(KEYWORD_EXTRACT_NODE, KEYWORD_EXTRACT_NODE,
                                StateGraph.END, StateGraph.END))
                .addEdge(KEYWORD_EXTRACT_NODE, SCHEMA_RECALL_NODE)
                .addEdge(SCHEMA_RECALL_NODE, TABLE_RELATION_NODE)
                .addEdge(TABLE_RELATION_NODE, PLANNER_NODE)
                .addEdge(PLANNER_NODE, PLAN_EXECUTOR_NODE)
                // Python branch: generate -> execute, then analyze, retry generation, or stop.
                .addEdge(PYTHON_GENERATE_NODE, PYTHON_EXECUTE_NODE)
                .addConditionalEdges(PYTHON_EXECUTE_NODE, AsyncEdgeAction.edge_async(new PythonExecutorDispatcher()),
                        Map.of(PYTHON_ANALYZE_NODE, PYTHON_ANALYZE_NODE,
                                StateGraph.END, StateGraph.END,
                                PYTHON_GENERATE_NODE, PYTHON_GENERATE_NODE))
                .addEdge(PYTHON_ANALYZE_NODE, PLAN_EXECUTOR_NODE)
                .addConditionalEdges(PLAN_EXECUTOR_NODE, AsyncEdgeAction.edge_async(new PlanExecutorDispatcher()),
                        Map.of(PLANNER_NODE, PLANNER_NODE,
                                SQL_EXECUTE_NODE, SQL_EXECUTE_NODE,
                                PYTHON_GENERATE_NODE, PYTHON_GENERATE_NODE,
                                REPORT_GENERATOR_NODE, REPORT_GENERATOR_NODE,
                                StateGraph.END, StateGraph.END))
                .addEdge(REPORT_GENERATOR_NODE, StateGraph.END)
                // SQL branch: execute <-> generate loop, with a semantic-consistency check after execution.
                .addConditionalEdges(SQL_EXECUTE_NODE, AsyncEdgeAction.edge_async(new SQLExecutorDispatcher()),
                        Map.of(SQL_GENERATE_NODE, SQL_GENERATE_NODE,
                                SEMANTIC_CONSISTENCY_NODE, SEMANTIC_CONSISTENCY_NODE))
                .addConditionalEdges(SQL_GENERATE_NODE, AsyncEdgeAction.edge_async(new SqlGenerateDispatcher()),
                        Map.of(KEYWORD_EXTRACT_NODE, KEYWORD_EXTRACT_NODE,
                                StateGraph.END, StateGraph.END,
                                SQL_EXECUTE_NODE, SQL_EXECUTE_NODE))
                .addConditionalEdges(SEMANTIC_CONSISTENCY_NODE, AsyncEdgeAction.edge_async(new SemanticConsistenceDispatcher()),
                        Map.of(SQL_GENERATE_NODE, SQL_GENERATE_NODE,
                                PLAN_EXECUTOR_NODE, PLAN_EXECUTOR_NODE));
    }

    /**
     * Creates a new mutable map that binds each of {@code keys} to its own strategy instance.
     *
     * <p>The value type {@code V} is inferred from the call site. The result can therefore be returned
     * directly from a {@link KeyStrategyFactory} lambda. The decompiled
     * {@code HashMap<String, ReplaceStrategy>} could not be returned there, because generic map types
     * are invariant in their value type.
     *
     * @param keys state keys to register; a repeated key keeps only its last strategy
     * @param strategySupplier creates one strategy per key
     * @return a fresh {@link HashMap} containing one entry per distinct key
     */
    private static <V> Map<String, V> newStrategyMap(List<String> keys, Supplier<? extends V> strategySupplier) {
        Map<String, V> strategies = new HashMap<>();
        for (String key : keys) {
            strategies.put(key, strategySupplier.get());
        }
        return strategies;
    }
}

