/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.service.impl;

import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.GraphInitData;
import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.PersistentConfig;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.serializer.StateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plain_text.PlainTextStateSerializer;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import com.alibaba.cloud.ai.param.GraphStreamParam;
import com.alibaba.cloud.ai.service.GraphService;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

/**
 * {@link GraphService} that discovers graphs in the Spring application context and streams their
 * execution snapshots as server-sent events.
 *
 * <p>Graphs are collected lazily from every {@link StateGraph} bean and from the
 * {@code stateGraph} of every {@link CompiledGraph} bean. Graphs compiled for streaming are cached
 * per graph name and {@link PersistentConfig} so that a later request with {@code resume} set can
 * continue from a checkpoint of the same compiled instance.
 */
@Service
public class GraphServiceImpl
implements GraphService,
ApplicationContextAware {
    private static final Logger log = LoggerFactory.getLogger(GraphServiceImpl.class);
    private final ObjectMapper objectMapper;
    // graph name -> (session/thread config -> compiled graph). Entries are never evicted here.
    private final Map<String, Map<PersistentConfig, CompiledGraph>> graphCache = new ConcurrentHashMap<>();
    // Built lazily by getStateGraphs(); volatile so other threads only ever see a fully built map.
    private volatile Map<String, StateGraph> stateGraphMap;
    private ApplicationContext applicationContext;

    public GraphServiceImpl(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    /**
     * Returns all graphs known to the application context, keyed by graph name.
     *
     * <p>{@link StateGraph} beans are registered first, then the underlying state graph of each
     * {@link CompiledGraph} bean (which overwrites a same-named entry). A graph without a name is
     * keyed by the {@code hashCode()} of its bean. The map is built on first call and reused.
     *
     * @return the graph registry (never {@code null})
     */
    @Override
    public Map<String, StateGraph> getStateGraphs() {
        Map<String, StateGraph> graphs = this.stateGraphMap;
        if (graphs == null) {
            // Fill a local map and publish it only once complete, so concurrent callers never
            // observe a half-populated registry. A racing duplicate build yields the same content.
            Map<String, StateGraph> discovered = new ConcurrentHashMap<>();
            this.applicationContext.getBeansOfType(StateGraph.class)
                    .forEach((beanName, graph) -> discovered.put(registryKey(graph.getName(), graph), graph));
            this.applicationContext.getBeansOfType(CompiledGraph.class)
                    .forEach((beanName, compiled) -> discovered.put(
                            registryKey(compiled.stateGraph.getName(), compiled), compiled.stateGraph));
            this.stateGraphMap = discovered;
            graphs = discovered;
        }
        return graphs;
    }

    /** Registry key for a graph: its name, or the bean's hash code when the graph is unnamed. */
    private static String registryKey(String graphName, Object bean) {
        return graphName != null ? graphName : String.valueOf(bean.hashCode());
    }

    /**
     * Looks up a graph by name in the registry.
     *
     * @throws IllegalArgumentException if no graph is registered under {@code name}
     */
    private StateGraph requireStateGraph(String name) {
        StateGraph stateGraph = name == null ? null : getStateGraphs().get(name);
        if (stateGraph == null) {
            throw new IllegalArgumentException("Unknown graph: " + name);
        }
        return stateGraph;
    }

    /**
     * Compiles the named graph and renders it as a Mermaid diagram.
     *
     * @param name registered graph name; also used as the diagram title and as the single
     *             required STRING input argument reported in the result
     * @return the graph name, Mermaid content and input-argument metadata
     * @throws GraphStateException if the graph fails to compile
     * @throws IllegalArgumentException if no graph is registered under {@code name}
     */
    @Override
    public GraphInitData getPrintableGraphData(String name) throws GraphStateException {
        ArrayList<GraphInitData.ArgumentMetadata> inputArgs = new ArrayList<>();
        inputArgs.add(new GraphInitData.ArgumentMetadata(name, GraphInitData.ArgumentMetadata.ArgumentType.STRING, true));
        CompiledGraph compiledGraph = requireStateGraph(name).compile();
        GraphRepresentation graph = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID, name, false);
        return new GraphInitData(name, graph.content(), inputArgs);
    }

    /**
     * Serializes one node output as the two-element JSON array {@code [ "<threadId>",\n<output>\n]}.
     *
     * @return the JSON text, or {@code ""} if serialization fails (the error is logged)
     */
    private String serializeOutput(String threadId, NodeOutput output) {
        try {
            // Encode the client-supplied thread id as a proper JSON string so quotes or
            // backslashes in it cannot corrupt the payload; String.valueOf keeps the old "null".
            String threadIdJson = this.objectMapper.writeValueAsString(String.valueOf(threadId));
            String outputJson = this.objectMapper.writeValueAsString(output);
            return "[ " + threadIdJson + ",\n" + outputJson + "\n]";
        }
        catch (IOException e) {
            log.error("error serializing state", e);
            return "";
        }
    }

    /**
     * Runs (or resumes) the named graph and streams every state snapshot as a server-sent event.
     *
     * <p>For a new run, the graph is compiled with an in-memory checkpoint saver and cached for the
     * {@code (sessionId, thread)} pair unless a compiled instance is already cached. For a resume,
     * the cached instance is required: its state at {@code param.getCheckpoint()} is updated with
     * the request data as node {@code param.getNode()} and streaming continues from there.
     *
     * @param name        registered graph name
     * @param param       thread id, session id, resume flag and (for resume) checkpoint id and node
     * @param inputStream request body: a JSON object of inputs, or, when resuming a graph whose
     *                    state serializer is a {@link PlainTextStateSerializer}, that serializer's
     *                    format (decoded as UTF-8)
     * @return a flux of events whose data is produced by {@link #serializeOutput}; it completes when
     *         the graph finishes and errors if execution fails
     * @throws IllegalArgumentException if no graph is registered under {@code name}
     * @throws IllegalStateException    if a resume is requested but no compiled graph is cached
     *                                  for the session
     * @throws Exception                on input parsing or graph errors
     */
    @Override
    public Flux<ServerSentEvent<String>> stream(String name, GraphStreamParam param, InputStream inputStream) throws Exception {
        String threadId = param.getThread();
        boolean resume = param.isResume();
        PersistentConfig persistentConfig = new PersistentConfig(param.getSessionId(), threadId);
        StateGraph stateGraph = requireStateGraph(name);
        // The per-graph cache is absent until the first run of this graph has been compiled.
        Map<PersistentConfig, CompiledGraph> sessionGraphs = this.graphCache.get(name);
        CompiledGraph compiledGraph = sessionGraphs == null ? null : sessionGraphs.get(persistentConfig);
        Map<String, Object> dataMap = readInputData(stateGraph, resume, inputStream);
        AsyncGenerator<? extends NodeOutput> generator;
        if (resume) {
            log.trace("RESUME REQUEST PREPARE");
            if (compiledGraph == null) {
                throw new IllegalStateException("Missing CompiledGraph in session!");
            }
            String checkpointId = param.getCheckpoint();
            String node = param.getNode();
            RunnableConfig config = RunnableConfig.builder().threadId(threadId).checkPointId(checkpointId).build();
            StateSnapshot stateSnapshot = compiledGraph.getState(config);
            config = stateSnapshot.config();
            log.trace("RESUME UPDATE STATE FORM {} USING CONFIG {}\n{}", node, config, dataMap);
            config = compiledGraph.updateState(config, dataMap, node);
            log.trace("RESUME REQUEST STREAM {}", config);
            generator = compiledGraph.streamSnapshots(null, config);
        } else {
            log.trace("dataMap: {}", dataMap);
            if (compiledGraph == null) {
                compiledGraph = stateGraph.compile(this.compileConfig(persistentConfig));
                this.graphCache.computeIfAbsent(name, k -> new ConcurrentHashMap<>()).put(persistentConfig, compiledGraph);
            }
            generator = compiledGraph.streamSnapshots(dataMap, this.runnableConfig(persistentConfig));
        }
        return toEventFlux(threadId, generator);
    }

    /**
     * Parses the request body into graph input data.
     *
     * <p>Only when resuming a graph whose serializer is a {@link PlainTextStateSerializer} is the
     * body read with that serializer; otherwise it is parsed as a JSON object.
     */
    private Map<String, Object> readInputData(StateGraph stateGraph, boolean resume, InputStream inputStream) throws Exception {
        if (resume) {
            StateSerializer stateSerializer = stateGraph.getStateSerializer();
            if (stateSerializer instanceof PlainTextStateSerializer) {
                PlainTextStateSerializer textSerializer = (PlainTextStateSerializer) stateSerializer;
                // Not closed here: the caller owns inputStream.
                Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8);
                return textSerializer.read(reader).data();
            }
        }
        return this.objectMapper.readValue(inputStream, new TypeReference<Map<String, Object>>() {});
    }

    /**
     * Bridges the asynchronous snapshot generator to a Reactor {@link Flux}.
     *
     * <p>Uses a unicast sink with an unbounded buffer, so snapshots produced before the client
     * subscribes are retained.
     */
    private Flux<ServerSentEvent<String>> toEventFlux(String threadId, AsyncGenerator<? extends NodeOutput> generator) {
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        generator.forEachAsync(output -> {
            try {
                sink.tryEmitNext(ServerSentEvent.builder(this.serializeOutput(threadId, output)).build());
                // Deliberate 1-second pause after each snapshot, kept from the original code.
                // NOTE(review): presumably pacing for a step-by-step UI; confirm it is still wanted.
                TimeUnit.SECONDS.sleep(1L);
            }
            catch (InterruptedException e) {
                // Restore the interrupt flag before propagating, so the executor thread sees it.
                Thread.currentThread().interrupt();
                throw new CompletionException(e);
            }
        }).thenAccept(ignored -> sink.tryEmitComplete()).exceptionally(error -> {
            log.error("Error streaming", error);
            sink.tryEmitError(error);
            return null;
        });
        return sink.asFlux();
    }

    /**
     * Compile configuration used for new runs: registers a fresh {@link MemorySaver} under
     * {@code "memory"}. {@code config} is currently unused.
     */
    private CompileConfig compileConfig(PersistentConfig config) {
        return CompileConfig.builder()
                .saverConfig(SaverConfig.builder().register("memory", (BaseCheckpointSaver) new MemorySaver()).build())
                .build();
    }

    /** Run configuration for a new stream, carrying only the thread id from {@code config}. */
    RunnableConfig runnableConfig(PersistentConfig config) {
        return RunnableConfig.builder().threadId(config.threadId()).build();
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }
}

