// Command main implements a minimal, broadcast-only WebSocket server on top
// of net/http connection hijacking. A Hub fans messages out to every
// registered Client; each Client owns one hijacked connection.
package main

import (
	"bufio"
	"crypto/sha1"
	"encoding/base64"
	"encoding/binary"
	"io"
	"net"
	"net/http"
	"strings"
	"time"
)

// WebSocket opcodes used by this server (RFC 6455, section 5.2).
const (
	opcodeText  byte = 0x1
	opcodeClose byte = 0x8
	opcodePing  byte = 0x9
	opcodePong  byte = 0xA
)

const (
	// maxControlPayload is the largest payload a control frame may carry
	// (RFC 6455, section 5.5).
	maxControlPayload = 125

	// pingInterval is how often writePump pings an otherwise idle client.
	pingInterval = 30 * time.Second
)

// Hub manages websocket clients and broadcasts messages to them.
//
// All fields are owned by the goroutine running Run; other goroutines talk to
// the hub exclusively through the three channels.
type Hub struct {
	register   chan *Client
	unregister chan *Client
	broadcast  chan []byte
	clients    map[*Client]bool
}

// Client represents a single websocket client connection.
type Client struct {
	hub  *Hub
	conn net.Conn
	// send queues outbound text messages. It is closed by the hub (and only
	// by the hub) when the client is dropped.
	send chan []byte
	// br, when non-nil, is the buffered reader handed back by Hijack. It may
	// already hold bytes the peer sent right after the handshake, so readPump
	// must read through it instead of wrapping conn again. When nil, readPump
	// wraps conn itself.
	br *bufio.Reader
}

// NewHub constructs a new Hub instance. Run must be started (in its own
// goroutine) before clients are registered.
func NewHub() *Hub {
	return &Hub{
		register:   make(chan *Client),
		unregister: make(chan *Client),
		broadcast:  make(chan []byte, 1024),
		clients:    make(map[*Client]bool),
	}
}

// Run starts the hub event loop. It never returns, so call it in a dedicated
// goroutine.
func (h *Hub) Run() {
	for {
		select {
		case c := <-h.register:
			h.clients[c] = true
		case c := <-h.unregister:
			h.drop(c)
		case msg := <-h.broadcast:
			for c := range h.clients {
				select {
				case c.send <- msg:
				default:
					// The client's queue is full: it is slow or dead, so
					// drop it rather than stall every other client.
					// Deleting during range is safe for Go maps.
					h.drop(c)
				}
			}
		}
	}
}

// drop removes c from the hub, closes its send queue and its connection.
// It is a no-op for clients that are not (or no longer) registered, which
// makes a late unregister after a broadcast-time drop harmless. It must only
// be called from the Run goroutine.
func (h *Hub) drop(c *Client) {
	if _, ok := h.clients[c]; !ok {
		return
	}
	delete(h.clients, c)
	close(c.send)
	_ = c.conn.Close() // best-effort: the peer may already be gone
}

// computeAccept computes the Sec-WebSocket-Accept header value for the given
// Sec-WebSocket-Key: base64(SHA-1(key + GUID)), per RFC 6455 section 4.2.2.
func computeAccept(key string) string {
	const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	sum := sha1.Sum([]byte(key + magic))
	return base64.StdEncoding.EncodeToString(sum[:])
}

// headerContainsToken reports whether any of the header values contains
// token as a comma-separated element, compared case-insensitively.
func headerContainsToken(values []string, token string) bool {
	for _, v := range values {
		for _, part := range strings.Split(v, ",") {
			if strings.EqualFold(strings.TrimSpace(part), token) {
				return true
			}
		}
	}
	return false
}

// ServeWS upgrades the HTTP request to a WebSocket connection and registers a
// client with the hub.
//
// It answers 400 when the request is not a WebSocket upgrade or lacks a
// Sec-WebSocket-Key, and 500 when the connection cannot be hijacked. On
// success the connection is owned by the client's read and write pumps.
func (h *Hub) ServeWS(w http.ResponseWriter, r *http.Request) {
	if !headerContainsToken(r.Header.Values("Connection"), "upgrade") ||
		!strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
		http.Error(w, "upgrade required", http.StatusBadRequest)
		return
	}

	key := r.Header.Get("Sec-WebSocket-Key")
	if key == "" {
		http.Error(w, "missing Sec-WebSocket-Key", http.StatusBadRequest)
		return
	}
	accept := computeAccept(key)

	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "websocket not supported", http.StatusInternalServerError)
		return
	}
	conn, buf, err := hj.Hijack()
	if err != nil {
		// The connection was not taken over, so the client still deserves
		// an HTTP answer.
		http.Error(w, "websocket upgrade failed", http.StatusInternalServerError)
		return
	}

	// A hijacked connection keeps whatever read/write deadlines the
	// http.Server configured; clear them or a long-lived socket would be
	// cut off once they expire.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = conn.Close()
		return
	}

	// Write the upgrade response.
	response := "HTTP/1.1 101 Switching Protocols\r\n" +
		"Upgrade: websocket\r\n" +
		"Connection: Upgrade\r\n" +
		"Sec-WebSocket-Accept: " + accept + "\r\n" +
		"\r\n"
	if _, err := buf.WriteString(response); err != nil {
		_ = conn.Close()
		return
	}
	if err := buf.Flush(); err != nil {
		_ = conn.Close()
		return
	}

	client := &Client{
		hub:  h,
		conn: conn,
		send: make(chan []byte, 256),
		// Keep the hijacked reader: it may already contain frame bytes.
		br: buf.Reader,
	}
	h.register <- client

	go client.writePump()
	go client.readPump()
}

// writeWSFrame writes a single, unmasked, final (FIN=1) WebSocket frame with
// the given opcode and payload to w.
//
// The header and payload are assembled into one buffer and handed to w in a
// single Write call. readPump (pongs) and writePump (messages, pings) write
// to the same connection from different goroutines; net.Conn serializes
// whole Write calls, so one call per frame keeps frames from interleaving.
func writeWSFrame(w io.Writer, opcode byte, payload []byte) error {
	n := len(payload)

	// Largest header: 1 byte FIN/opcode + 1 byte length marker + 8 bytes
	// extended length.
	frame := make([]byte, 0, 10+n)
	frame = append(frame, 0x80|(opcode&0x0F))

	switch {
	case n < 126:
		frame = append(frame, byte(n))
	case n <= 0xFFFF:
		frame = append(frame, 126, 0, 0)
		binary.BigEndian.PutUint16(frame[2:], uint16(n))
	default:
		frame = append(frame, 127, 0, 0, 0, 0, 0, 0, 0, 0)
		binary.BigEndian.PutUint64(frame[2:], uint64(n))
	}
	frame = append(frame, payload...)

	_, err := w.Write(frame)
	return err
}

// readPump handles control frames from the client and discards other incoming
// frames. This server is broadcast-only to clients.
//
// It answers pings with pongs, stops on a close frame, on any read error, and
// on frames that violate the framing rules (a 64-bit length with the top bit
// set, or a control frame larger than 125 bytes). On exit the client is
// unregistered from the hub, which closes the connection.
//
// NOTE(review): unmasked client frames are tolerated here although RFC 6455
// requires clients to mask; tighten if strictness is wanted.
func (c *Client) readPump() {
	defer func() { c.hub.unregister <- c }()

	reader := c.br
	if reader == nil {
		reader = bufio.NewReader(c.conn)
	}

	for {
		var head [2]byte
		if _, err := io.ReadFull(reader, head[:]); err != nil {
			return
		}
		opcode := head[0] & 0x0F
		masked := head[1]&0x80 != 0
		length := int64(head[1] & 0x7F)

		switch length {
		case 126:
			var ext [2]byte
			if _, err := io.ReadFull(reader, ext[:]); err != nil {
				return
			}
			length = int64(binary.BigEndian.Uint16(ext[:]))
		case 127:
			var ext [8]byte
			if _, err := io.ReadFull(reader, ext[:]); err != nil {
				return
			}
			length = int64(binary.BigEndian.Uint64(ext[:]))
			if length < 0 {
				// Top bit set: invalid per RFC 6455. A negative count
				// would also make io.CopyN discard nothing and silently
				// desynchronise the frame parser.
				return
			}
		}

		var maskKey [4]byte
		if masked {
			if _, err := io.ReadFull(reader, maskKey[:]); err != nil {
				return
			}
		}

		// Opcodes 0x8-0xF are control frames and are capped at 125 bytes.
		if opcode&0x8 != 0 && length > maxControlPayload {
			return
		}

		switch opcode {
		case opcodePing:
			payload := make([]byte, length)
			if _, err := io.ReadFull(reader, payload); err != nil {
				return
			}
			if masked {
				for i := range payload {
					payload[i] ^= maskKey[i%4]
				}
			}
			// Best-effort pong: a write failure will surface as a read
			// error on the next iteration.
			_ = writeWSFrame(c.conn, opcodePong, payload)

		case opcodeClose:
			// Drain the payload if any, then exit; the result no longer
			// matters because the connection is being torn down.
			_, _ = io.CopyN(io.Discard, reader, length)
			return

		default:
			// Data, pong and reserved frames carry nothing this server
			// needs; skip the payload.
			if _, err := io.CopyN(io.Discard, reader, length); err != nil {
				return
			}
		}
	}
}

// writePump sends queued messages to the client and pings periodically to
// keep the connection alive.
//
// It exits when the hub closes the send queue (after a best-effort close
// frame) or when any write fails, and always closes the connection on the
// way out, which in turn makes readPump stop and unregister the client.
func (c *Client) writePump() {
	ticker := time.NewTicker(pingInterval)
	defer func() {
		ticker.Stop()
		_ = c.conn.Close()
	}()

	for {
		select {
		case msg, ok := <-c.send:
			if !ok {
				// Queue closed by the hub: try to say goodbye. The hub
				// usually closed the socket already, so this may fail.
				_ = writeWSFrame(c.conn, opcodeClose, nil)
				return
			}
			if err := writeWSFrame(c.conn, opcodeText, msg); err != nil {
				return
			}
		case <-ticker.C:
			// Send a ping to keep connections alive behind proxies. A
			// failed ping means the connection is dead, so stop.
			if err := writeWSFrame(c.conn, opcodePing, []byte("ping")); err != nil {
				return
			}
		}
	}
}