/ src / utils / mcpWebSocketTransport.ts
mcpWebSocketTransport.ts
  1  import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
  2  import {
  3    type JSONRPCMessage,
  4    JSONRPCMessageSchema,
  5  } from '@modelcontextprotocol/sdk/types.js'
  6  import type WsWebSocket from 'ws'
  7  import { logForDiagnosticsNoPII } from './diagLogs.js'
  8  import { toError } from './errors.js'
  9  import { jsonParse, jsonStringify } from './slowOperations.js'
 10  
 11  // WebSocket readyState constants (same for both native and ws)
 12  const WS_CONNECTING = 0
 13  const WS_OPEN = 1
 14  
 15  // Minimal interface shared by globalThis.WebSocket and ws.WebSocket
 16  type WebSocketLike = {
 17    readonly readyState: number
 18    close(): void
 19    send(data: string): void
 20  }
 21  
 22  export class WebSocketTransport implements Transport {
 23    private started = false
 24    private opened: Promise<void>
 25    private isBun = typeof Bun !== 'undefined'
 26  
 27    constructor(private ws: WebSocketLike) {
 28      this.opened = new Promise((resolve, reject) => {
 29        if (this.ws.readyState === WS_OPEN) {
 30          resolve()
 31        } else if (this.isBun) {
 32          const nws = this.ws as unknown as globalThis.WebSocket
 33          const onOpen = () => {
 34            nws.removeEventListener('open', onOpen)
 35            nws.removeEventListener('error', onError)
 36            resolve()
 37          }
 38          const onError = (event: Event) => {
 39            nws.removeEventListener('open', onOpen)
 40            nws.removeEventListener('error', onError)
 41            logForDiagnosticsNoPII('error', 'mcp_websocket_connect_fail')
 42            reject(event)
 43          }
 44          nws.addEventListener('open', onOpen)
 45          nws.addEventListener('error', onError)
 46        } else {
 47          const nws = this.ws as unknown as WsWebSocket
 48          nws.on('open', () => {
 49            resolve()
 50          })
 51          nws.on('error', error => {
 52            logForDiagnosticsNoPII('error', 'mcp_websocket_connect_fail')
 53            reject(error)
 54          })
 55        }
 56      })
 57  
 58      // Attach persistent event handlers
 59      if (this.isBun) {
 60        const nws = this.ws as unknown as globalThis.WebSocket
 61        nws.addEventListener('message', this.onBunMessage)
 62        nws.addEventListener('error', this.onBunError)
 63        nws.addEventListener('close', this.onBunClose)
 64      } else {
 65        const nws = this.ws as unknown as WsWebSocket
 66        nws.on('message', this.onNodeMessage)
 67        nws.on('error', this.onNodeError)
 68        nws.on('close', this.onNodeClose)
 69      }
 70    }
 71  
 72    onclose?: () => void
 73    onerror?: (error: Error) => void
 74    onmessage?: (message: JSONRPCMessage) => void
 75  
 76    // Bun (native WebSocket) event handlers
 77    private onBunMessage = (event: MessageEvent) => {
 78      try {
 79        const data =
 80          typeof event.data === 'string' ? event.data : String(event.data)
 81        const messageObj = jsonParse(data)
 82        const message = JSONRPCMessageSchema.parse(messageObj)
 83        this.onmessage?.(message)
 84      } catch (error) {
 85        this.handleError(error)
 86      }
 87    }
 88  
 89    private onBunError = () => {
 90      this.handleError(new Error('WebSocket error'))
 91    }
 92  
 93    private onBunClose = () => {
 94      this.handleCloseCleanup()
 95    }
 96  
 97    // Node (ws package) event handlers
 98    private onNodeMessage = (data: Buffer) => {
 99      try {
100        const messageObj = jsonParse(data.toString('utf-8'))
101        const message = JSONRPCMessageSchema.parse(messageObj)
102        this.onmessage?.(message)
103      } catch (error) {
104        this.handleError(error)
105      }
106    }
107  
108    private onNodeError = (error: unknown) => {
109      this.handleError(error)
110    }
111  
112    private onNodeClose = () => {
113      this.handleCloseCleanup()
114    }
115  
116    // Shared error handler
117    private handleError(error: unknown): void {
118      logForDiagnosticsNoPII('error', 'mcp_websocket_message_fail')
119      this.onerror?.(toError(error))
120    }
121  
122    // Shared close handler with listener cleanup
123    private handleCloseCleanup(): void {
124      this.onclose?.()
125      // Clean up listeners after close
126      if (this.isBun) {
127        const nws = this.ws as unknown as globalThis.WebSocket
128        nws.removeEventListener('message', this.onBunMessage)
129        nws.removeEventListener('error', this.onBunError)
130        nws.removeEventListener('close', this.onBunClose)
131      } else {
132        const nws = this.ws as unknown as WsWebSocket
133        nws.off('message', this.onNodeMessage)
134        nws.off('error', this.onNodeError)
135        nws.off('close', this.onNodeClose)
136      }
137    }
138  
139    /**
140     * Starts listening for messages on the WebSocket.
141     */
142    async start(): Promise<void> {
143      if (this.started) {
144        throw new Error('Start can only be called once per transport.')
145      }
146      await this.opened
147      if (this.ws.readyState !== WS_OPEN) {
148        logForDiagnosticsNoPII('error', 'mcp_websocket_start_not_opened')
149        throw new Error('WebSocket is not open. Cannot start transport.')
150      }
151      this.started = true
152      // Unlike stdio, WebSocket connections are typically already established when the transport is created.
153      // No explicit connection action needed here, just attaching listeners.
154    }
155  
156    /**
157     * Closes the WebSocket connection.
158     */
159    async close(): Promise<void> {
160      if (
161        this.ws.readyState === WS_OPEN ||
162        this.ws.readyState === WS_CONNECTING
163      ) {
164        this.ws.close()
165      }
166      // Ensure listeners are removed even if close was called externally or connection was already closed
167      this.handleCloseCleanup()
168    }
169  
170    /**
171     * Sends a JSON-RPC message over the WebSocket connection.
172     */
173    async send(message: JSONRPCMessage): Promise<void> {
174      if (this.ws.readyState !== WS_OPEN) {
175        logForDiagnosticsNoPII('error', 'mcp_websocket_send_not_opened')
176        throw new Error('WebSocket is not open. Cannot send message.')
177      }
178      const json = jsonStringify(message)
179  
180      try {
181        if (this.isBun) {
182          // Native WebSocket.send() is synchronous (no callback)
183          this.ws.send(json)
184        } else {
185          await new Promise<void>((resolve, reject) => {
186            ;(this.ws as unknown as WsWebSocket).send(json, error => {
187              if (error) {
188                reject(error)
189              } else {
190                resolve()
191              }
192            })
193          })
194        }
195      } catch (error) {
196        this.handleError(error)
197        throw error
198      }
199    }
200  }