webRequestExt.ts
import type { KeysOfUnion, SetReturnType, UnknownRecord } from 'type-fest';
import type { WebRequest } from 'wxt/browser';

import { concatUint8Arrays } from '@/helpers/concatUint8Arrays';
import { isObject } from '@/helpers/isObject';
import type { GraphqlRequest, GraphqlResponse } from '@/types/graphql';
import type { MaybePromise } from '@/types/util';
import { Logger } from '@/utils/logger';

import type { Context } from './ctx';

/** Details object handed to `webRequest.onBeforeRequest` listeners. */
export type RequestDetails = WebRequest.OnBeforeRequestDetailsType;

/**
 * Extra info options a caller may request. `'blocking'` is excluded because
 * `rewriteCompleteResponse` always adds it itself.
 */
type ExtraInfoSpec = Exclude<WebRequest.OnBeforeRequestOptions, 'blocking'>[];

/** Values shared by every hook that runs for a single intercepted request. */
interface OnBeforeRequestCtx {
  ctx: Context;
  details: RequestDetails;
  /** Logger whose prefix identifies the request (see `rewriteCompleteResponse`). */
  logger: Logger;
}

/** Options accepted by {@link WebRequestExt.rewriteCompleteResponse}. */
interface RewriteCompleteResponseOptions<CtxAddition extends object> {
  /** Handler name; used in the logger prefix when the request has an `originUrl`. */
  name: string;
  /** Request filter; its `urls` are resolved against the context origin. */
  filter: WebRequest.RequestFilter;
  extraInfoSpec?: ExtraInfoSpec;
  /**
   * Optional hook that runs first. A truthy return value is returned to the
   * browser as the blocking response and the response body is not rewritten.
   */
  onBeforeRequest?: (
    ctx: OnBeforeRequestCtx & Partial<CtxAddition>,
  ) => MaybePromise<WebRequest.BlockingResponse | undefined>;
  /**
   * Optional predicate. Return `[true]` to rewrite the response, or
   * `[false, reason, ...logArgs]` to skip it; the reason and log args are logged.
   */
  isHandleRequest?: (
    ctx: OnBeforeRequestCtx & Partial<CtxAddition>,
  ) => MaybePromise<[isHandle: false, reason: string, ...logArgs: unknown[]] | [isHandle: true]>;
  /**
   * Receives the complete response body. A truthy return value replaces the
   * body; otherwise the original bytes are written back.
   */
  rewrite: (
    ctx: CompleteResponseRewriterCtx & Partial<CtxAddition>,
  ) => MaybePromise<Uint8Array | undefined>;
}

/** Request context extended with the fully buffered response body. */
interface CompleteResponseRewriterCtx extends OnBeforeRequestCtx {
  data: Uint8Array;
}

/** Options accepted by {@link WebRequestExt.rewriteCompleteJsonResponse}. */
interface RewriteCompleteJsonResponseOptions<T extends object, CtxAddition extends object>
  extends Omit<RewriteCompleteResponseOptions<CtxAddition>, 'rewrite'> {
  /**
   * Receives the parsed JSON body. Any return value other than `undefined` is
   * serialized and replaces the body; `undefined` leaves the response untouched.
   */
  rewrite: (
    ctx: CompleteJsonResponseRewriterCtx<T> & Partial<CtxAddition>,
  ) => MaybePromise<T | void>;
}

/** Request context extended with the parsed JSON response body. */
interface CompleteJsonResponseRewriterCtx<T extends object> extends OnBeforeRequestCtx {
  data: T;
}

/** Options accepted by {@link WebRequestExt.rewriteCompleteGraphql}. */
interface RewriteCompleteGraphqlOptions<T extends object, CtxAddition extends object>
  extends RewriteCompleteJsonResponseOptions<T, CtxAddition> {
  /**
   * When non-empty, only requests whose `query` text matches one of these
   * field names are rewritten.
   */
  ifQueriesFields?: KeysOfUnion<T>[];
}

/**
 * Registers blocking `webRequest.onBeforeRequest` listeners that buffer a
 * response through `browser.webRequest.filterResponseData` and let callers
 * rewrite the complete body. Disposing the instance removes every listener it
 * registered.
 */
export class WebRequestExt implements Disposable {
  readonly #ctx: Context;
  readonly #unregisterHandlers: (() => void)[] = [];
  readonly #utf8Decoder = new TextDecoder('utf-8');
  // Shared encoder so a new one is not allocated for every rewritten response.
  readonly #utf8Encoder = new TextEncoder();

  constructor(ctx: Context) {
    this.#ctx = ctx;
  }

  /**
   * Intercepts requests matching `filter`, buffers the whole response body and
   * passes it to `rewrite`.
   *
   * Per request, in order:
   * 1. `onBeforeRequest` (if given) may short-circuit with a blocking response.
   * 2. `isHandleRequest` (if given) may opt the request out.
   * 3. The response is buffered; on stream end `rewrite` is called (only when
   *    at least one chunk was received) and the result is written back.
   *
   * Errors thrown by any hook are logged and never propagate: a failing
   * `onBeforeRequest` / `isHandleRequest` is treated as "continue", and a
   * failing `rewrite` leaves the response body untouched.
   */
  rewriteCompleteResponse = <CtxAddition extends object = Record<never, never>>({
    name,
    filter,
    extraInfoSpec,
    onBeforeRequest,
    isHandleRequest,
    rewrite,
  }: RewriteCompleteResponseOptions<CtxAddition>) => {
    this.onBeforeRequest(
      async details => {
        const { requestId } = details;

        let loggerPrefix = requestId;

        if ('originUrl' in details && details.originUrl) {
          let url = details.originUrl;
          try {
            const originUrl = new URL(details.originUrl);

            // Same-origin URLs are shortened to their path for readability.
            if (originUrl.origin === this.#ctx.origin) {
              url = details.originUrl.slice(originUrl.origin.length);
            }
          } catch {
            // noop
          }
          loggerPrefix = `${url} | ${name} | ${loggerPrefix}`;
        }

        const logger = Logger.create(loggerPrefix);
        logger.debug('gonna rewrite request', details);

        const requestCtx: OnBeforeRequestCtx = { ctx: this.#ctx, details, logger };

        try {
          if (typeof onBeforeRequest === 'function') {
            const res = await onBeforeRequest(requestCtx as typeof requestCtx & CtxAddition);

            if (res) {
              logger.debug('`onBeforeRequest` returned', res);
              return res;
            }
          }
        } catch (err) {
          logger.error('`onBeforeRequest` failed', err);
        }

        try {
          if (typeof isHandleRequest === 'function') {
            const [isHandle, ...ignoreReason] = await isHandleRequest(
              requestCtx as typeof requestCtx & CtxAddition,
            );

            if (!isHandle) {
              logger.debug('ignoring request due to `isHandleRequest`:', ...ignoreReason);
              return;
            }
          }
        } catch (err) {
          logger.error('`isHandleRequest` failed', err);
        }

        // Named `responseFilter` to avoid shadowing the `filter` option above.
        const responseFilter = browser.webRequest.filterResponseData(requestId);
        const buffers: Uint8Array[] = [];

        // cSpell:ignore ondata
        responseFilter.ondata = event => {
          buffers.push(new Uint8Array(event.data));
        };

        responseFilter.onstop = async () => {
          const buf = concatUint8Arrays(buffers);
          let res: Uint8Array = buf;

          try {
            if (buffers.length) {
              const processedRes = await rewrite(
                Object.assign<
                  typeof requestCtx,
                  Omit<CompleteResponseRewriterCtx, keyof typeof requestCtx>
                >(requestCtx, {
                  data: buf,
                }) as CompleteResponseRewriterCtx & CtxAddition,
              );

              if (processedRes) {
                res = processedRes;
              }
            }
          } catch (err) {
            // A throwing `rewrite` must not skip `write`/`close` below, or the
            // filter would never be closed; fall back to the original bytes.
            res = buf;
            logger.error('`rewrite` failed, leaving response untouched...', err);
          }

          logger.debug('writing', res === buf ? 'unmodified' : 'modified', 'response');
          responseFilter.write(res);
          responseFilter.close();
        };

        return {};
      },
      filter,
      ['blocking', ...(extraInfoSpec ?? [])],
    );
  };

  /**
   * Like {@link rewriteCompleteResponse}, but decodes the body as UTF-8 JSON
   * before calling `rewrite` and re-encodes whatever `rewrite` returns.
   * If parsing or `rewrite` throws, the error is logged and the response is
   * left untouched.
   */
  rewriteCompleteJsonResponse = <
    T extends object,
    CtxAddition extends object = Record<never, never>,
  >({
    rewrite,
    ...options
  }: RewriteCompleteJsonResponseOptions<T, CtxAddition>) => {
    this.rewriteCompleteResponse({
      ...options,
      rewrite: async ctx => {
        const responseStr = this.#utf8Decoder.decode(ctx.data);

        try {
          const res = JSON.parse(responseStr) as T;

          ctx.logger.debug('rewriting', res);

          const processedRes = await rewrite({ ...ctx, data: res });

          if (typeof processedRes !== 'undefined') {
            return this.#utf8Encoder.encode(JSON.stringify(processedRes));
          }
        } catch (err) {
          ctx.logger.error('rewriting failed, leaving response untouched...', err);
        }
      },
    });
  };

  /**
   * Like {@link rewriteCompleteJsonResponse}, specialised for GraphQL:
   * - always requests `requestBody` so the query can be inspected;
   * - only handles `POST` requests that carry a raw body;
   * - when `ifQueriesFields` is non-empty, only handles requests whose `query`
   *   text matches one of the listed field names (if the body cannot be parsed
   *   the request is handled anyway);
   * - skips responses that are not objects, contain `errors`, or whose `data`
   *   is not an object;
   * - passes `response.data` to `rewrite`; a defined return value replaces
   *   `data`, and the whole response object is then returned for serialization.
   */
  rewriteCompleteGraphql = <T extends object, CtxAddition extends object = Record<never, never>>({
    ifQueriesFields,
    extraInfoSpec,
    isHandleRequest,
    rewrite,
    ...options
  }: RewriteCompleteGraphqlOptions<T, CtxAddition>) => {
    // Built once per registration rather than once per request.
    const fieldsRegex = ifQueriesFields?.length
      ? new RegExp(`[:{}\\s](${ifQueriesFields.join('|')})[@({\\s]`)
      : undefined;

    this.rewriteCompleteJsonResponse<GraphqlResponse<T>>({
      ...options,
      extraInfoSpec: ['requestBody', ...(extraInfoSpec ?? [])],
      isHandleRequest: async ctx => {
        // The caller's predicate runs first and can veto the request.
        if (typeof isHandleRequest === 'function') {
          const res = await isHandleRequest(ctx);

          if (!res[0]) {
            return res;
          }
        }

        if (ctx.details.method !== 'POST') {
          return [false, 'expected POST method, got', ctx.details.method];
        }

        const requestBody = ctx.details.requestBody;

        if (!requestBody?.raw?.[0]?.bytes) {
          return [false, 'expected raw body', requestBody];
        }

        if (!fieldsRegex) {
          return [true];
        }

        try {
          // Only the first raw chunk of the request body is inspected.
          const { query } = JSON.parse(
            this.#utf8Decoder.decode(requestBody.raw[0].bytes as AllowSharedBufferSource),
          ) as GraphqlRequest<UnknownRecord>;

          const isIncludesFields = fieldsRegex.test(query);

          return isIncludesFields
            ? [true]
            : [false, 'selection set misses', ifQueriesFields, query];
        } catch (err) {
          ctx.logger.error('failed to check requestBody', err);
        }

        // The body could not be checked: handle the request anyway.
        return [true];
      },
      rewrite: async ctx => {
        const res = ctx.data;

        if (!isObject(res) || 'errors' in res || !isObject(res.data)) {
          return;
        }

        const newData = await rewrite({ ...ctx, data: res.data } as typeof ctx & {
          data: (typeof res)['data'];
        } & CtxAddition);

        if (typeof newData !== 'undefined') {
          res.data = newData;
        }

        return res;
      },
    });
  };

  /**
   * Adds `handler` as a `webRequest.onBeforeRequest` listener and remembers how
   * to remove it on dispose. Every URL pattern in `filter.urls` is resolved
   * against the context origin before registration.
   */
  private onBeforeRequest = (
    handler: (details: RequestDetails) => MaybePromise<void | WebRequest.BlockingResponse>,
    filter: WebRequest.RequestFilter,
    extraInfoSpec?: WebRequest.OnBeforeRequestOptions[],
  ): void => {
    // `onBeforeRequest` is ok with `Promise<void>`
    const callback = handler as SetReturnType<typeof handler, WebRequest.BlockingResponseOrPromise>;

    browser.webRequest.onBeforeRequest.addListener(
      callback,
      {
        ...filter,
        urls: filter.urls.map(url => new URL(url, this.#ctx.origin).toString()),
      },
      extraInfoSpec,
    );
    this.#unregisterHandlers.push(() =>
      browser.webRequest.onBeforeRequest.removeListener(callback),
    );
  };

  /**
   * Removes every listener registered through this instance. The handler list
   * is drained first, so disposing more than once is a no-op.
   */
  [Symbol.dispose]() {
    for (const unregister of this.#unregisterHandlers.splice(0)) {
      unregister();
    }
  }
}