/ src / request / webRequestExt.ts
webRequestExt.ts
  1  import type { KeysOfUnion, SetReturnType, UnknownRecord } from 'type-fest';
  2  import type { WebRequest } from 'wxt/browser';
  3  
  4  import { concatUint8Arrays } from '@/helpers/concatUint8Arrays';
  5  import { isObject } from '@/helpers/isObject';
  6  import type { GraphqlRequest, GraphqlResponse } from '@/types/graphql';
  7  import type { MaybePromise } from '@/types/util';
  8  import { Logger } from '@/utils/logger';
  9  
 10  import type { Context } from './ctx';
 11  
 12  export type RequestDetails = WebRequest.OnBeforeRequestDetailsType;
 13  
 14  type ExtraInfoSpec = Exclude<WebRequest.OnBeforeRequestOptions, 'blocking'>[];
 15  
 16  interface OnBeforeRequestCtx {
 17  	ctx: Context;
 18  	details: RequestDetails;
 19  	logger: Logger;
 20  }
 21  
 22  interface RewriteCompleteResponseOptions<CtxAddition extends object> {
 23  	name: string;
 24  	filter: WebRequest.RequestFilter;
 25  	extraInfoSpec?: ExtraInfoSpec;
 26  	onBeforeRequest?: (
 27  		ctx: OnBeforeRequestCtx & Partial<CtxAddition>,
 28  	) => MaybePromise<WebRequest.BlockingResponse | undefined>;
 29  	isHandleRequest?: (
 30  		ctx: OnBeforeRequestCtx & Partial<CtxAddition>,
 31  	) => MaybePromise<[isHandle: false, reason: string, ...logArgs: unknown[]] | [isHandle: true]>;
 32  	rewrite: (
 33  		ctx: CompleteResponseRewriterCtx & Partial<CtxAddition>,
 34  	) => MaybePromise<Uint8Array | undefined>;
 35  }
 36  
 37  interface CompleteResponseRewriterCtx extends OnBeforeRequestCtx {
 38  	data: Uint8Array;
 39  }
 40  
 41  interface RewriteCompleteJsonResponseOptions<T extends object, CtxAddition extends object>
 42  	extends Omit<RewriteCompleteResponseOptions<CtxAddition>, 'rewrite'> {
 43  	rewrite: (
 44  		ctx: CompleteJsonResponseRewriterCtx<T> & Partial<CtxAddition>,
 45  	) => MaybePromise<T | void>;
 46  }
 47  
 48  interface CompleteJsonResponseRewriterCtx<T extends object> extends OnBeforeRequestCtx {
 49  	data: T;
 50  }
 51  
 52  interface RewriteCompleteGraphqlOptions<T extends object, CtxAddition extends object>
 53  	extends RewriteCompleteJsonResponseOptions<T, CtxAddition> {
 54  	ifQueriesFields?: KeysOfUnion<T>[];
 55  }
 56  
 57  export class WebRequestExt implements Disposable {
 58  	readonly #ctx: Context;
 59  	readonly #unregisterHandlers: (() => void)[] = [];
 60  	readonly #utf8Decoder = new TextDecoder('utf-8');
 61  
 62  	constructor(ctx: Context) {
 63  		this.#ctx = ctx;
 64  	}
 65  
 66  	rewriteCompleteResponse = <CtxAddition extends object = Record<never, never>>({
 67  		name,
 68  		filter,
 69  		extraInfoSpec,
 70  		onBeforeRequest,
 71  		isHandleRequest,
 72  		rewrite,
 73  	}: RewriteCompleteResponseOptions<CtxAddition>) => {
 74  		this.onBeforeRequest(
 75  			async details => {
 76  				const { requestId } = details;
 77  
 78  				let loggerPrefix = requestId;
 79  
 80  				if ('originUrl' in details && details.originUrl) {
 81  					let url = details.originUrl;
 82  					try {
 83  						const originUrl = new URL(details.originUrl);
 84  
 85  						if (originUrl.origin === this.#ctx.origin) {
 86  							url = details.originUrl.slice(originUrl.origin.length);
 87  						}
 88  					} catch {
 89  						// noop
 90  					}
 91  					loggerPrefix = `${url} | ${name} | ${loggerPrefix}`;
 92  				}
 93  
 94  				const logger = Logger.create(loggerPrefix);
 95  				logger.debug('gonna rewrite request', details);
 96  
 97  				const requestCtx: OnBeforeRequestCtx = { ctx: this.#ctx, details, logger };
 98  
 99  				try {
100  					let res: ReturnType<NonNullable<typeof onBeforeRequest>>;
101  
102  					if (
103  						typeof onBeforeRequest === 'function' &&
104  						(res = await onBeforeRequest(requestCtx as typeof requestCtx & CtxAddition))
105  					) {
106  						logger.debug('`onBeforeRequest` returned', res);
107  						return res;
108  					}
109  				} catch (err) {
110  					logger.error('`onBeforeRequest` failed', err);
111  				}
112  
113  				try {
114  					let isHandle = false,
115  						ignoreReason: unknown[] | undefined;
116  
117  					if (
118  						typeof isHandleRequest === 'function' &&
119  						([isHandle, ...ignoreReason] = await isHandleRequest(
120  							requestCtx as typeof requestCtx & CtxAddition,
121  						)) &&
122  						!isHandle
123  					) {
124  						logger.debug('ignoring request due to `isHandleRequest`:', ...ignoreReason);
125  						return;
126  					}
127  				} catch (err) {
128  					logger.error('`isHandleRequest` failed', err);
129  				}
130  
131  				const filter = browser.webRequest.filterResponseData(requestId);
132  				const buffers: Uint8Array[] = [];
133  
134  				// cSpell:ignore ondata
135  				filter.ondata = event => {
136  					buffers.push(new Uint8Array(event.data));
137  				};
138  
139  				filter.onstop = async () => {
140  					const buf = concatUint8Arrays(buffers);
141  					let res: Uint8Array = buf;
142  
143  					if (buffers.length) {
144  						const processedRes = await rewrite(
145  							Object.assign<
146  								typeof requestCtx,
147  								Omit<CompleteResponseRewriterCtx, keyof typeof requestCtx>
148  							>(requestCtx, {
149  								data: buf,
150  							}) as CompleteResponseRewriterCtx & CtxAddition,
151  						);
152  						processedRes && (res = processedRes);
153  					}
154  
155  					logger.debug('writing', res === buf ? 'unmodified' : 'modified', 'response');
156  					filter.write(res);
157  					filter.close();
158  				};
159  
160  				return {};
161  			},
162  			filter,
163  			['blocking', ...(extraInfoSpec || [])],
164  		);
165  	};
166  
167  	rewriteCompleteJsonResponse = <
168  		T extends object,
169  		CtxAddition extends object = Record<never, never>,
170  	>({
171  		rewrite,
172  		...options
173  	}: RewriteCompleteJsonResponseOptions<T, CtxAddition>) => {
174  		this.rewriteCompleteResponse({
175  			...options,
176  			rewrite: async ctx => {
177  				const responseStr = this.#utf8Decoder.decode(ctx.data);
178  
179  				try {
180  					const res = JSON.parse(responseStr) as T;
181  
182  					ctx.logger.debug('rewriting', res);
183  
184  					const processedRes = await rewrite({ ...ctx, data: res });
185  
186  					if (typeof processedRes !== 'undefined') {
187  						const encoder = new TextEncoder();
188  						return encoder.encode(JSON.stringify(processedRes));
189  					}
190  				} catch (err) {
191  					ctx.logger.error('rewriting failed, leaving response untouched...', err);
192  				}
193  			},
194  		});
195  	};
196  
197  	rewriteCompleteGraphql = <T extends object, CtxAddition extends object = Record<never, never>>({
198  		ifQueriesFields,
199  		extraInfoSpec,
200  		isHandleRequest,
201  		rewrite,
202  		...options
203  	}: RewriteCompleteGraphqlOptions<T, CtxAddition>) => {
204  		const fieldsRegex = ifQueriesFields?.length
205  			? new RegExp(`[:{}\\s](${ifQueriesFields.join('|')})[@({\\s]`)
206  			: undefined;
207  
208  		this.rewriteCompleteJsonResponse<GraphqlResponse<T>>({
209  			...options,
210  			extraInfoSpec: ['requestBody', ...(extraInfoSpec || [])],
211  			isHandleRequest: async ctx => {
212  				if (typeof isHandleRequest === 'function') {
213  					const res = await isHandleRequest(ctx);
214  
215  					if (!res[0]) {
216  						return res;
217  					}
218  				}
219  
220  				if (ctx.details.method !== 'POST') {
221  					return [false, 'expected POST method, got', ctx.details.method];
222  				}
223  
224  				const requestBody = ctx.details.requestBody;
225  
226  				if (!requestBody?.raw?.[0]?.bytes) {
227  					return [false, 'expected raw body', requestBody];
228  				}
229  
230  				if (!fieldsRegex) {
231  					return [true];
232  				}
233  
234  				try {
235  					const { query } = JSON.parse(
236  						this.#utf8Decoder.decode(requestBody.raw[0].bytes as AllowSharedBufferSource),
237  					) as GraphqlRequest<UnknownRecord>;
238  
239  					const isIncludesFields = fieldsRegex.test(query);
240  
241  					return isIncludesFields
242  						? [true]
243  						: [false, 'selection set misses', ifQueriesFields, query];
244  				} catch (err) {
245  					ctx.logger.error('failed to check requestBody', err);
246  				}
247  
248  				return [true];
249  			},
250  			rewrite: async ctx => {
251  				const res = ctx.data;
252  
253  				if (!isObject(res) || 'errors' in res || !isObject(res.data)) {
254  					return;
255  				}
256  
257  				const newData = await rewrite({ ...ctx, data: res.data } as typeof ctx & {
258  					data: (typeof res)['data'];
259  				} & CtxAddition);
260  
261  				if (typeof newData !== 'undefined') {
262  					res.data = newData;
263  				}
264  
265  				return res;
266  			},
267  		});
268  	};
269  
270  	private onBeforeRequest = (
271  		handler: (details: RequestDetails) => MaybePromise<void | WebRequest.BlockingResponse>,
272  		filter: WebRequest.RequestFilter,
273  		extraInfoSpec?: WebRequest.OnBeforeRequestOptions[],
274  	): void => {
275  		// `onBeforeRequest` is ok with `Promise<void>`
276  		const callback = handler as SetReturnType<typeof handler, WebRequest.BlockingResponseOrPromise>;
277  		browser.webRequest.onBeforeRequest.addListener(
278  			callback,
279  			{
280  				...filter,
281  				urls: filter.urls.map(url => new URL(url, this.#ctx.origin).toString()),
282  			},
283  			extraInfoSpec,
284  		);
285  		this.#unregisterHandlers.push(() =>
286  			browser.webRequest.onBeforeRequest.removeListener(callback),
287  		);
288  	};
289  
290  	[Symbol.dispose]() {
291  		for (const unregister of this.#unregisterHandlers) {
292  			unregister();
293  		}
294  	}
295  }