package org.matrix.socket.pool;

import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Subscription pool that tracks which WebSocket sessions are subscribed to which keys.
 *
 * <p>Two indexes are kept in sync:
 * <ul>
 *   <li>{@code socketMap}: subscription key -> sessions subscribed to that key</li>
 *   <li>{@code socketKeyMap}: session id -> subscription keys held by that session</li>
 * </ul>
 *
 * <p>Per-key and per-session lists are only mutated inside {@link ConcurrentHashMap#compute}
 * / {@link ConcurrentHashMap#computeIfPresent}. This means "append to a list" and "drop a list
 * that became empty" cannot interleave for the same key. The lists are
 * {@link CopyOnWriteArrayList}s, so a list returned by {@link #get(String)} can be iterated
 * while other threads subscribe or unsubscribe.
 *
 * @author huangxiahao
 */
@Slf4j
public class MonitorSocketPool {

    /**
     * Subscription key -> sessions subscribed to that key.
     * A key is removed as soon as its last session unsubscribes.
     */
    private final ConcurrentHashMap<String, List<WebSocketSession>> socketMap = new ConcurrentHashMap<>();

    /**
     * Session id -> subscription keys held by that session (each key recorded at most once).
     */
    private final ConcurrentHashMap<String, List<String>> socketKeyMap = new ConcurrentHashMap<>();

    /**
     * Subscribes {@code clientSocket} to {@code key}.
     *
     * <p>Does nothing if either argument is {@code null}. If the session is already subscribed
     * to {@code key}, the repeat is logged and otherwise ignored.
     *
     * @param key          subscription key
     * @param clientSocket session to subscribe
     * @throws IOException not thrown by this implementation; declared for source compatibility
     *                     with existing callers
     */
    public void add(String key, WebSocketSession clientSocket) throws IOException {
        if (clientSocket == null || key == null) {
            return;
        }
        List<String> subscribedKeys = socketKeyMap.get(clientSocket.getId());
        if (subscribedKeys != null && subscribedKeys.contains(key)) {
            log.info("用户重复订阅 批次：{}", key);
            // Stop here: recording the key again in socketKeyMap would make remove()
            // process the same key twice.
            return;
        }
        // Atomic create-or-append, so a concurrent remove() cannot drop the list
        // between our lookup and our append.
        socketMap.compute(key, (k, sessions) -> {
            List<WebSocketSession> target = sessions != null ? sessions : new CopyOnWriteArrayList<>();
            if (!target.contains(clientSocket)) {
                target.add(clientSocket);
            }
            return target;
        });
        addKeyMap(clientSocket.getId(), key);
    }

    /**
     * Records that the session {@code socketId} is subscribed to {@code key}.
     *
     * <p>Does nothing if either argument is {@code null}. Recording the same pair twice has no
     * additional effect.
     *
     * @param socketId session id
     * @param key      subscription key
     */
    public void addKeyMap(String socketId, String key) {
        if (socketId == null || key == null) {
            return;
        }
        socketKeyMap.compute(socketId, (id, keys) -> {
            List<String> target = keys != null ? keys : new CopyOnWriteArrayList<>();
            if (!target.contains(key)) {
                target.add(key);
            }
            return target;
        });
    }

    /**
     * Unsubscribes {@code clientSocket} from every key it holds and forgets the session.
     *
     * <p>Any key left with no sessions is removed from the pool. Does nothing if
     * {@code clientSocket} is {@code null} or holds no subscriptions.
     *
     * @param clientSocket session to remove
     */
    public void remove(WebSocketSession clientSocket) {
        if (clientSocket == null) {
            return;
        }
        List<String> keys = socketKeyMap.remove(clientSocket.getId());
        if (keys == null) {
            return;
        }
        for (String key : keys) {
            // Returning null from the remapping function removes the entry atomically,
            // and a key that is already gone is simply skipped (no NPE).
            socketMap.computeIfPresent(key, (k, sessions) -> {
                sessions.remove(clientSocket);
                return sessions.isEmpty() ? null : sessions;
            });
        }
    }

    /**
     * Returns the sessions currently subscribed to {@code key}.
     *
     * @param key subscription key
     * @return the pool's live list of subscribed sessions, or {@code null} if no session is
     *         subscribed to {@code key}
     */
    public List<WebSocketSession> get(String key) {
        return socketMap.get(key);
    }
}