/ upstreamproxy / relay.ts
relay.ts
  1  /* eslint-disable eslint-plugin-n/no-unsupported-features/node-builtins */
  2  /**
  3   * CONNECT-over-WebSocket relay for CCR upstreamproxy.
  4   *
  5   * Listens on localhost TCP, accepts HTTP CONNECT from curl/gh/kubectl/etc,
  6   * and tunnels bytes over WebSocket to the CCR upstreamproxy endpoint.
  7   * The CCR server-side terminates the tunnel, MITMs TLS, injects org-configured
  8   * credentials (e.g. DD-API-KEY), and forwards to the real upstream.
  9   *
 10   * WHY WebSocket and not raw CONNECT: CCR ingress is GKE L7 with path-prefix
 11   * routing; there's no connect_matcher in cdk-constructs. The session-ingress
 12   * tunnel (sessions/tunnel/v1alpha/tunnel.proto) already uses this pattern.
 13   *
 14   * Protocol: bytes are wrapped in UpstreamProxyChunk protobuf messages
 15   * (`message UpstreamProxyChunk { bytes data = 1; }`) for compatibility with
 16   * gateway.NewWebSocketStreamAdapter on the server side.
 17   */
 18  
 19  import { createServer, type Socket as NodeSocket } from 'node:net'
 20  import { logForDebugging } from '../utils/debug.js'
 21  import { getWebSocketTLSOptions } from '../utils/mtls.js'
 22  import { getWebSocketProxyAgent, getWebSocketProxyUrl } from '../utils/proxy.js'
 23  
// The CCR container runs behind an egress gateway — direct outbound is
// blocked, so the WS upgrade must go through the same HTTP CONNECT proxy
// everything else uses. undici's globalThis.WebSocket does not consult
// the global dispatcher for the upgrade, so under Node we use the ws package
// with an explicit agent (same pattern as SessionsWebSocket). Bun's native
// WebSocket takes a proxy URL directly. Preloaded in startNodeRelay so
// openTunnel stays synchronous and the CONNECT state machine doesn't race.
type WSCtor = typeof import('ws').default
// Module-level cache of the `ws` constructor. startNodeRelay assigns it;
// openTunnel uses it when set and otherwise falls back to
// globalThis.WebSocket (the Bun path).
let nodeWSCtor: WSCtor | undefined
 33  
// Intersection of the WebSocket surface this file touches (openTunnel plus
// the keepalive/forward/cleanup helpers). Both undici's globalThis.WebSocket
// and the ws package satisfy this via property-style onX handlers.
// NOTE(review): openTunnel casts the ws-package instance to this type
// (`as unknown as WebSocketLike`), so the compiler does not check the ws side
// of that claim.
type WebSocketLike = Pick<
  WebSocket,
  | 'onopen'
  | 'onmessage'
  | 'onerror'
  | 'onclose'
  | 'send'
  | 'close'
  | 'readyState'
  | 'binaryType'
>
 48  
// Largest payload carried in a single UpstreamProxyChunk. forwardToWs splits
// bigger client reads into pieces of at most this many bytes.
// Envoy per-request buffer cap. Week-1 Datadog payloads won't hit this, but
// design for it so git-push doesn't need a relay rewrite.
const MAX_CHUNK_BYTES = 512 * 1024

// Period of the application-level keepalive (an empty chunk) that each open
// tunnel sends. Sidecar idle timeout is 50s; ping well inside that.
const PING_INTERVAL_MS = 30_000
 55  
 56  /**
 57   * Encode an UpstreamProxyChunk protobuf message by hand.
 58   *
 59   * For `message UpstreamProxyChunk { bytes data = 1; }` the wire format is:
 60   *   tag = (field_number << 3) | wire_type = (1 << 3) | 2 = 0x0a
 61   *   followed by varint length, followed by the bytes.
 62   *
 63   * protobufjs would be the general answer; for a single-field bytes message
 64   * the hand encoding is 10 lines and avoids a runtime dep in the hot path.
 65   */
 66  export function encodeChunk(data: Uint8Array): Uint8Array {
 67    const len = data.length
 68    // varint encoding of length — most chunks fit in 1–3 length bytes
 69    const varint: number[] = []
 70    let n = len
 71    while (n > 0x7f) {
 72      varint.push((n & 0x7f) | 0x80)
 73      n >>>= 7
 74    }
 75    varint.push(n)
 76    const out = new Uint8Array(1 + varint.length + len)
 77    out[0] = 0x0a
 78    out.set(varint, 1)
 79    out.set(data, 1 + varint.length)
 80    return out
 81  }
 82  
 83  /**
 84   * Decode an UpstreamProxyChunk. Returns the data field, or null if malformed.
 85   * Tolerates the server sending a zero-length chunk (keepalive semantics).
 86   */
 87  export function decodeChunk(buf: Uint8Array): Uint8Array | null {
 88    if (buf.length === 0) return new Uint8Array(0)
 89    if (buf[0] !== 0x0a) return null
 90    let len = 0
 91    let shift = 0
 92    let i = 1
 93    while (i < buf.length) {
 94      const b = buf[i]!
 95      len |= (b & 0x7f) << shift
 96      i++
 97      if ((b & 0x80) === 0) break
 98      shift += 7
 99      if (shift > 28) return null
100    }
101    if (i + len > buf.length) return null
102    return buf.subarray(i, i + len)
103  }
104  
/**
 * Handle returned by startUpstreamProxyRelay.
 * `port` is the ephemeral port bound on 127.0.0.1; `stop` shuts the relay's
 * listener down.
 */
export type UpstreamProxyRelay = {
  port: number
  stop: () => void
}
109  
/** Per-client-connection state shared by the Bun and Node relay paths. */
type ConnState = {
  // WS tunnel for this connection: assigned by openTunnel once the CONNECT
  // head has been parsed, cleared again by cleanupConn.
  ws?: WebSocketLike
  // Client bytes accumulated until the CONNECT request head (CRLF CRLF) is
  // complete; reset to empty once the head has been consumed.
  connectBuf: Buffer
  // Keepalive timer started in ws.onopen; cleared by cleanupConn.
  pinger?: ReturnType<typeof setInterval>
  // Bytes that arrived after the CONNECT header but before ws.onopen fired.
  // TCP can coalesce CONNECT + ClientHello into one packet, and the socket's
  // data callback can fire again while the WS handshake is still in flight.
  // Both cases would silently drop bytes without this buffer.
  pending: Buffer[]
  // True once ws.onopen has run; until then handleData queues into `pending`.
  wsOpen: boolean
  // Set when the first non-empty payload from the server is written to the
  // client. That payload is expected to be the server's
  // "200 Connection Established", after which the tunnel carries TLS. From
  // then on, writing a plaintext 502 would corrupt the client's stream, so
  // ws.onerror just closes instead.
  established: boolean
  // Set by the WS onerror/onclose handlers. WS onerror is always followed by
  // onclose; without a guard the second handler would sock.end() an
  // already-ended socket. First caller wins.
  closed: boolean
}
128  
/**
 * Minimal socket abstraction so the CONNECT parser and WS tunnel plumbing
 * are runtime-agnostic. Implementations handle write backpressure internally:
 * Bun's sock.write() does partial writes and needs explicit tail-queueing;
 * Node's net.Socket buffers unconditionally and never drops bytes.
 */
type ClientSocket = {
  // Send bytes (strings are written as UTF-8) to the client. Returns nothing,
  // so callers cannot react to partial writes; implementations must not lose
  // data.
  write: (data: Uint8Array | string) => void
  // End the client connection.
  end: () => void
}
139  
/**
 * Build the initial state for a freshly accepted client connection: an empty
 * CONNECT accumulator, nothing pending, no WebSocket yet, and every lifecycle
 * flag cleared. Each call returns a new object with its own buffer and array.
 */
function newConnState(): ConnState {
  const initial: ConnState = {
    connectBuf: Buffer.alloc(0),
    pending: [],
    wsOpen: false,
    established: false,
    closed: false,
  }
  return initial
}
149  
/**
 * Start the CONNECT-over-WebSocket relay on an ephemeral 127.0.0.1 port.
 *
 * Uses Bun.listen when the `Bun` global exists and Node's net server
 * otherwise — the CCR container runs the CLI under Node, not Bun.
 *
 * @param opts.wsUrl     upstreamproxy WebSocket endpoint
 * @param opts.sessionId combined with `token` into the Basic credential sent
 *                       as Proxy-Authorization inside the tunnel
 * @param opts.token     used for that Basic credential and, separately, as
 *                       the Bearer token on the WS upgrade request
 * @returns the bound port and a stop function
 */
export async function startUpstreamProxyRelay(opts: {
  wsUrl: string
  sessionId: string
  token: string
}): Promise<UpstreamProxyRelay> {
  const { wsUrl, sessionId, token } = opts

  // Rides inside the tunneled CONNECT as Proxy-Authorization.
  const basicCredentials = Buffer.from(`${sessionId}:${token}`).toString(
    'base64',
  )
  const authHeader = `Basic ${basicCredentials}`
  // The WS upgrade itself is auth-gated (proto authn: PRIVATE_API): the
  // gateway wants the session-ingress JWT on the upgrade request, independent
  // of the Proxy-Authorization carried inside the tunnel.
  const wsAuthHeader = `Bearer ${token}`

  let relay: UpstreamProxyRelay
  if (typeof Bun !== 'undefined') {
    relay = startBunRelay(wsUrl, authHeader, wsAuthHeader)
  } else {
    relay = await startNodeRelay(wsUrl, authHeader, wsAuthHeader)
  }

  logForDebugging(`[upstreamproxy] relay listening on 127.0.0.1:${relay.port}`)
  return relay
}
175  
/**
 * Bun implementation of the relay: Bun.listen on an ephemeral 127.0.0.1 port,
 * feeding each client socket through the shared handleData/openTunnel
 * plumbing. The returned stop() calls server.stop(true).
 */
function startBunRelay(
  wsUrl: string,
  authHeader: string,
  wsAuthHeader: string,
): UpstreamProxyRelay {
  // Bun's sock.write() reports how many bytes it actually handed to the
  // kernel, and anything it did not take is not retained. Each socket
  // therefore carries a FIFO (`writeBuf`) of unsent tails, flushed from the
  // drain callback. Once anything is queued, later writes must queue behind
  // it so bytes stay in order.
  type BunState = ConnState & { writeBuf: Uint8Array[] }

  // eslint-disable-next-line custom-rules/require-bun-typeof-guard -- caller dispatches on typeof Bun
  const server = Bun.listen<BunState>({
    hostname: '127.0.0.1',
    port: 0,
    socket: {
      open(sock) {
        sock.data = { ...newConnState(), writeBuf: [] }
      },
      data(sock, data) {
        const state = sock.data
        // Adapter giving handleData "writes never drop bytes" semantics.
        const client: ClientSocket = {
          write(payload) {
            const bytes =
              typeof payload === 'string'
                ? Buffer.from(payload, 'utf8')
                : payload
            if (state.writeBuf.length === 0) {
              const accepted = sock.write(bytes)
              if (accepted < bytes.length) {
                state.writeBuf.push(bytes.subarray(accepted))
              }
            } else {
              state.writeBuf.push(bytes)
            }
          },
          end() {
            sock.end()
          },
        }
        handleData(client, state, data, wsUrl, authHeader, wsAuthHeader)
      },
      drain(sock) {
        const queue = sock.data.writeBuf
        while (queue.length > 0) {
          const head = queue[0]!
          const accepted = sock.write(head)
          if (accepted < head.length) {
            // Kernel buffer is full again: keep the unsent remainder at the
            // front and wait for the next drain.
            queue[0] = head.subarray(accepted)
            break
          }
          queue.shift()
        }
      },
      close(sock) {
        cleanupConn(sock.data)
      },
      error(sock, err) {
        logForDebugging(`[upstreamproxy] client socket error: ${err.message}`)
        cleanupConn(sock.data)
      },
    },
  })

  return {
    port: server.port,
    stop: () => server.stop(true),
  }
}
242  
// Exported so tests can exercise the Node path directly — the test runner is
// Bun, so the runtime dispatch in startUpstreamProxyRelay always picks Bun.
/**
 * Node implementation of the relay: a net.Server on an ephemeral 127.0.0.1
 * port, feeding each client socket through the shared handleData/openTunnel
 * plumbing.
 *
 * Loads the `ws` package up front (stored in nodeWSCtor) so openTunnel can
 * stay synchronous.
 *
 * stop() closes the listener AND destroys every live client socket, matching
 * Bun's `server.stop(true)`. net.Server#close() by itself only stops
 * accepting new connections, which would leave in-flight tunnels running
 * after stop() — their WebSockets and keepalive intervals included.
 * Destroying a socket fires its 'close' event, which runs cleanupConn.
 *
 * Rejects if the listener cannot bind.
 */
export async function startNodeRelay(
  wsUrl: string,
  authHeader: string,
  wsAuthHeader: string,
): Promise<UpstreamProxyRelay> {
  nodeWSCtor = (await import('ws')).default
  // Live client sockets, so stop() can tear down in-flight tunnels.
  const sockets = new Set<NodeSocket>()

  const server = createServer(sock => {
    const st = newConnState()
    sockets.add(sock)
    // Node's sock.write() buffers internally — a false return signals
    // backpressure but the bytes are already queued, so no tail-tracking
    // needed for correctness. Week-1 payloads won't stress the buffer.
    const adapter: ClientSocket = {
      write: payload => {
        sock.write(typeof payload === 'string' ? payload : Buffer.from(payload))
      },
      end: () => sock.end(),
    }
    sock.on('data', data =>
      handleData(adapter, st, data, wsUrl, authHeader, wsAuthHeader),
    )
    sock.on('close', () => {
      sockets.delete(sock)
      cleanupConn(st)
    })
    sock.on('error', err => {
      logForDebugging(`[upstreamproxy] client socket error: ${err.message}`)
      cleanupConn(st)
    })
  })

  return new Promise((resolve, reject) => {
    server.once('error', reject)
    server.listen(0, '127.0.0.1', () => {
      // Bound successfully. Swap the one-shot reject for a logger: left in
      // place, the first later server error would be handed to an already
      // settled promise and vanish without a trace.
      server.off('error', reject)
      server.on('error', err => {
        logForDebugging(`[upstreamproxy] relay server error: ${err.message}`)
      })
      const addr = server.address()
      if (addr === null || typeof addr === 'string') {
        server.close()
        reject(new Error('upstreamproxy: server has no TCP address'))
        return
      }
      resolve({
        port: addr.port,
        stop: () => {
          server.close()
          // destroy() emits 'close' asynchronously, so iterating while the
          // close handlers delete from the set is safe.
          for (const s of sockets) s.destroy()
          sockets.clear()
        },
      })
    })
  })
}
290  
/**
 * Shared per-connection data handler for the Bun and Node relays.
 *
 * Phase 1 (no WS yet): accumulate client bytes until the CONNECT request head
 * is complete (CRLF CRLF), validate the request line, stash any bytes that
 * followed the head in `st.pending`, and open the WS tunnel.
 * Phase 2 (WS exists): queue bytes in `st.pending` until ws.onopen has run,
 * then forward them to the WS in chunks.
 *
 * Once `st.closed` is set, incoming bytes are dropped. The WS error/close
 * handlers set it and then call cleanupConn(), which clears `st.ws`.
 * sock.end() only half-closes, so the client may keep sending. Without this
 * guard, such late bytes would fall into Phase 1 and be parsed as a brand-new
 * CONNECT request. A rejected request (400/405) sets `st.closed` for the same
 * reason: further bytes would otherwise re-trigger the rejection and write to
 * an already-ended socket.
 */
function handleData(
  sock: ClientSocket,
  st: ConnState,
  data: Buffer,
  wsUrl: string,
  authHeader: string,
  wsAuthHeader: string,
): void {
  if (st.closed) return

  // Phase 1: accumulate until we've seen the full CONNECT request
  // (terminated by CRLF CRLF). curl/gh send this in one packet, but
  // don't assume that.
  if (!st.ws) {
    st.connectBuf = Buffer.concat([st.connectBuf, data])
    const headerEnd = st.connectBuf.indexOf('\r\n\r\n')
    if (headerEnd === -1) {
      // Guard against a client that never sends CRLFCRLF.
      if (st.connectBuf.length > 8192) {
        st.closed = true
        st.connectBuf = Buffer.alloc(0)
        sock.write('HTTP/1.1 400 Bad Request\r\n\r\n')
        sock.end()
      }
      return
    }
    const reqHead = st.connectBuf.subarray(0, headerEnd).toString('utf8')
    const firstLine = reqHead.split('\r\n')[0] ?? ''
    if (!/^CONNECT\s+(\S+)\s+HTTP\/1\.[01]$/i.test(firstLine)) {
      // Only CONNECT tunnels are supported.
      st.closed = true
      st.connectBuf = Buffer.alloc(0)
      sock.write('HTTP/1.1 405 Method Not Allowed\r\n\r\n')
      sock.end()
      return
    }
    // Stash any bytes that arrived after the CONNECT header so
    // openTunnel can flush them once the WS is open.
    const trailing = st.connectBuf.subarray(headerEnd + 4)
    if (trailing.length > 0) {
      st.pending.push(Buffer.from(trailing))
    }
    st.connectBuf = Buffer.alloc(0)
    openTunnel(sock, st, firstLine, wsUrl, authHeader, wsAuthHeader)
    return
  }
  // Phase 2: WS exists. If it isn't OPEN yet, buffer; ws.onopen will
  // flush. Once open, pump client bytes to WS in chunks.
  if (!st.wsOpen) {
    st.pending.push(Buffer.from(data))
    return
  }
  forwardToWs(st.ws, data)
}
343  
/**
 * Open the WebSocket tunnel for one accepted CONNECT request and wire its
 * events to the client socket.
 *
 * - onopen: send the CONNECT line plus Proxy-Authorization as the first
 *   chunk, flush bytes buffered in `st.pending`, and start the keepalive timer.
 * - onmessage: decode each UpstreamProxyChunk and write its payload to the
 *   client. The first non-empty payload marks the tunnel as established.
 * - onerror / onclose: end the client socket, preceded by a plaintext 502
 *   only if nothing has been relayed yet, and release resources.
 *
 * The connection counts as torn down once `st.closed` is set, or once
 * cleanupConn() has detached this WS by clearing `st.ws` (e.g. the client
 * socket closed first). After that point, late WS events must not touch the
 * client socket, which is already ended or gone. onopen must also not start
 * a keepalive interval that nothing would ever clear.
 */
function openTunnel(
  sock: ClientSocket,
  st: ConnState,
  connectLine: string,
  wsUrl: string,
  authHeader: string,
  wsAuthHeader: string,
): void {
  // core/websocket/stream.go picks JSON vs binary-proto from the upgrade
  // request's Content-Type header (defaults to JSON). Without application/proto
  // the server protojson.Unmarshals our hand-encoded binary chunks and fails
  // silently with EOF.
  const headers = {
    'Content-Type': 'application/proto',
    Authorization: wsAuthHeader,
  }
  let ws: WebSocketLike
  if (nodeWSCtor) {
    ws = new nodeWSCtor(wsUrl, {
      headers,
      agent: getWebSocketProxyAgent(wsUrl),
      ...getWebSocketTLSOptions(),
    }) as unknown as WebSocketLike
  } else {
    ws = new globalThis.WebSocket(wsUrl, {
      // @ts-expect-error — Bun extension; not in lib.dom WebSocket types
      headers,
      proxy: getWebSocketProxyUrl(wsUrl),
      tls: getWebSocketTLSOptions() || undefined,
    })
  }
  ws.binaryType = 'arraybuffer'
  st.ws = ws

  // True once this connection is finished from the client's point of view.
  const tornDown = (): boolean => st.closed || st.ws !== ws

  ws.onopen = () => {
    if (tornDown()) {
      // Nobody is listening any more; don't start a keepalive that would
      // never be cleared.
      try {
        ws.close()
      } catch {
        // already closing
      }
      return
    }
    // First chunk carries the CONNECT line plus Proxy-Authorization so the
    // server can auth the tunnel and know the target host:port. Server
    // responds with its own "HTTP/1.1 200" over the tunnel; we just pipe it.
    const head =
      `${connectLine}\r\n` + `Proxy-Authorization: ${authHeader}\r\n` + `\r\n`
    ws.send(encodeChunk(Buffer.from(head, 'utf8')))
    // Flush anything that arrived while the WS handshake was in flight —
    // trailing bytes from the CONNECT packet and any data() callbacks that
    // fired before onopen.
    st.wsOpen = true
    for (const buf of st.pending) {
      forwardToWs(ws, buf)
    }
    st.pending = []
    // Not all WS implementations expose ping(); empty chunk works as an
    // application-level keepalive the server can ignore.
    st.pinger = setInterval(sendKeepalive, PING_INTERVAL_MS, ws)
  }

  ws.onmessage = ev => {
    // Client side already finished; drop late frames.
    if (tornDown()) return
    const raw =
      ev.data instanceof ArrayBuffer
        ? new Uint8Array(ev.data)
        : new Uint8Array(Buffer.from(ev.data))
    const payload = decodeChunk(raw)
    if (payload && payload.length > 0) {
      st.established = true
      sock.write(payload)
    }
  }

  ws.onerror = ev => {
    const msg = 'message' in ev ? String(ev.message) : 'websocket error'
    logForDebugging(`[upstreamproxy] ws error: ${msg}`)
    if (tornDown()) return
    st.closed = true
    if (!st.established) {
      sock.write('HTTP/1.1 502 Bad Gateway\r\n\r\n')
    }
    sock.end()
    cleanupConn(st)
  }

  ws.onclose = () => {
    if (tornDown()) return
    st.closed = true
    sock.end()
    cleanupConn(st)
  }
}
429  
// WebSocket readyState OPEN. The value is fixed by the WHATWG WebSocket spec
// and the `ws` package uses the same numbering. It is compared as a local
// constant instead of the global `WebSocket.OPEN` so the Node relay, which
// uses the `ws` package, does not throw a ReferenceError on Node versions
// that lack a global WebSocket.
const WS_OPEN = 1

/**
 * Send an application-level keepalive (an empty UpstreamProxyChunk) if the WS
 * is open. Driven by the interval that ws.onopen starts.
 */
function sendKeepalive(ws: WebSocketLike): void {
  if (ws.readyState === WS_OPEN) {
    ws.send(encodeChunk(new Uint8Array(0)))
  }
}

/**
 * Forward client bytes over the WS as UpstreamProxyChunk frames carrying at
 * most MAX_CHUNK_BYTES of payload each. Drops the data if the WS is not open.
 */
function forwardToWs(ws: WebSocketLike, data: Buffer): void {
  if (ws.readyState !== WS_OPEN) return
  for (let off = 0; off < data.length; off += MAX_CHUNK_BYTES) {
    ws.send(encodeChunk(data.subarray(off, off + MAX_CHUNK_BYTES)))
  }
}

/**
 * Release per-connection resources: stop the keepalive timer, close the WS if
 * it is still CONNECTING or OPEN, and detach it from the state. Safe to call
 * more than once; a missing state is a no-op.
 *
 * Also marks the connection closed. This runs when the client socket closes
 * or errors, and the ws.close() below may then fire ws.onerror / ws.onclose.
 * Those handlers check `st.closed`; without it they would write a 502 to,
 * and/or end, a client socket that is already gone.
 */
function cleanupConn(st: ConnState | undefined): void {
  if (!st) return
  st.closed = true
  if (st.pinger) clearInterval(st.pinger)
  st.pinger = undefined
  // readyState 0 (CONNECTING) or 1 (OPEN)
  if (st.ws && st.ws.readyState <= WS_OPEN) {
    try {
      st.ws.close()
    } catch {
      // already closing
    }
  }
  st.ws = undefined
}