From 155e434a008d86fd45093d51734006368895a1f1 Mon Sep 17 00:00:00 2001 From: Lucas Coratger <73360179+coratgerl@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:21:48 +0100 Subject: [PATCH] feat(wobe-graphql-apollo): add security options --- .../wobe-graphql-apollo/src/index.test.ts | 322 ++++++++++++++++++ packages/wobe-graphql-apollo/src/index.ts | 259 +++++++++++++- 2 files changed, 577 insertions(+), 4 deletions(-) diff --git a/packages/wobe-graphql-apollo/src/index.test.ts b/packages/wobe-graphql-apollo/src/index.test.ts index 508ce2c..93a2ff8 100644 --- a/packages/wobe-graphql-apollo/src/index.test.ts +++ b/packages/wobe-graphql-apollo/src/index.test.ts @@ -3,6 +3,8 @@ import { Wobe } from 'wobe' import getPort from 'get-port' import { WobeGraphqlApolloPlugin } from '.' +const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) + describe('Wobe GraphQL Apollo plugin', () => { it('should reject GET requests by default', async () => { const port = await getPort() @@ -179,6 +181,326 @@ describe('Wobe GraphQL Apollo plugin', () => { wobe.stop() }) + it('should block queries that exceed max depth', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + maxDepth: 2, + options: { + typeDefs: `#graphql + type Query { hello: Hello } + type Hello { nested: Nested } + type Nested { value: String } + `, + resolvers: { + Query: { + hello: () => ({ nested: { value: 'ok' } }), + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query TooDeep { hello { nested { value } } } + `, + }), + }) + + const body = await res.json() + expect(body.data).toBeUndefined() + expect(body.errors?.[0]?.message).toContain('max depth') + + wobe.stop() + }) + + it('should block queries that exceed max cost', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + maxCost: 2, + options: { + typeDefs: `#graphql + type Query { a: String b: String c: String } + `, + resolvers: { + Query: { + a: () => 'a', + b: () => 'b', + c: () => 'c', + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query TooExpensive { a b c } + `, + }), + }) + + const body = await res.json() + expect(body.data).toBeUndefined() + expect(body.errors?.[0]?.message).toContain('too expensive') + + wobe.stop() + }) + + it('should reject multiple operations by default', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + options: { + typeDefs: `#graphql + type Query { a: String b: String } + `, + resolvers: { + Query: { + a: () => 'a', + b: () => 'b', + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query One { a } + query Two { b } + `, + }), + }) + + const body = await res.json() + expect(body.data).toBeUndefined() + expect(body.errors?.[0]?.message).toMatch( + /Multiple operations|Could not determine/i, + ) + + wobe.stop() + }) + + it('should allow only whitelisted operation names when provided', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + allowedOperationNames: ['AllowedOp'], + allowMultipleOperations: false, + options: { + typeDefs: `#graphql + type Query { a: String } + `, + resolvers: { + Query: { + a: () => 'a', + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query NotAllowed { a } + `, + }), + }) + + const body = await res.json() + expect(body.data).toBeUndefined() + expect(body.errors?.[0]?.message).toContain('not allowed') + + wobe.stop() + }) + + it('should reject requests that exceed maxRequestSizeBytes', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + maxRequestSizeBytes: 10, + options: { + typeDefs: `#graphql + type Query { a: String } + `, + resolvers: { + Query: { + a: () => 'a', + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query LargePayload { a } + `, + }), + }) + + expect(res.status).toBe(413) + wobe.stop() + }) + + it('should timeout when resolver exceeds timeoutMs', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + timeoutMs: 10, + options: { + typeDefs: `#graphql + type Query { slow: String } + `, + resolvers: { + Query: { + slow: async () => { + await sleep(50) + return 'slow' + }, + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query Slow { slow } + `, + }), + }) + + expect(res.status).toBe(504) + wobe.stop() + }) + + it('should allow rateLimiter to block requests', async () => { + const port = await getPort() + const wobe = new Wobe() + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + rateLimiter: async () => + new Response('Too Many Requests', { status: 429 }), + options: { + typeDefs: `#graphql + type Query { hello: String } + `, + resolvers: { + Query: { + hello: () => 'Hello', + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query Test { hello } + `, + }), + }) + + expect(res.status).toBe(429) + wobe.stop() + }) + + it('should call onRequestResolved hook with timing info', async () => { + const port = await getPort() + const wobe = new Wobe() + let called = false + let status: number | undefined + + await wobe.usePlugin( + await WobeGraphqlApolloPlugin({ + onRequestResolved: (input) => { + called = true + status = input.status + }, + options: { + typeDefs: `#graphql + type Query { hello: String } + `, + resolvers: { + Query: { + hello: () => 'Hello', + }, + }, + }, + }), + ) + + wobe.listen(port) + + const res = await fetch(`http://127.0.0.1:${port}/graphql`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: ` + query Hook { hello } + `, + }), + }) + + expect(res.status).toBe(200) + expect(called).toBe(true) + expect(status).toBe(200) + + wobe.stop() + }) + it('should have custom wobe context in graphql context with record', async () => { const port = await getPort() diff --git a/packages/wobe-graphql-apollo/src/index.ts b/packages/wobe-graphql-apollo/src/index.ts index 21affc0..16dda04 100644 --- a/packages/wobe-graphql-apollo/src/index.ts +++ b/packages/wobe-graphql-apollo/src/index.ts @@ -1,5 +1,6 @@ import { ApolloServer, type ApolloServerOptions } from '@apollo/server' import { ApolloServerPluginLandingPageLocalDefault } from '@apollo/server/plugin/landingPage/default' +import { GraphQLError, type ValidationRule } from 'graphql' import type { Wobe, MaybePromise, @@ -14,6 +15,184 @@ export type GraphQLApolloContext = const getQueryString = (url: string) => url.slice(url.indexOf('?', 11) + 1) +const createDepthLimitRule = (maxDepth: number): ValidationRule => { + return (context) => { + const checkDepth = (depth: number) => { + if (depth > maxDepth) { + context.reportError( + new GraphQLError( + `Query is too deep: ${depth} > max depth ${maxDepth}`, + ), + ) + } + } + + const traverse = ( + selectionSet: any, + depth: number, + visitedFragments: Set, + ) => { + checkDepth(depth) + + for (const selection of selectionSet.selections || []) { + if (selection.selectionSet) { + traverse( + selection.selectionSet, + depth + 1, + visitedFragments, + ) + continue + } + + if (selection.kind === 'FragmentSpread') { + const name = selection.name.value + if (visitedFragments.has(name)) continue + visitedFragments.add(name) + const fragment = context.getFragment(name) + if (fragment) { + traverse( + fragment.selectionSet, + depth + 1, + visitedFragments, + ) + } + } + } + } + + return { + OperationDefinition(node) { + traverse(node.selectionSet, 1, new Set()) + }, + } + } +} + +const createCostLimitRule = (maxCost: number): ValidationRule => { + return (context) => { + let totalCost = 0 + + const countSelections = ( + selectionSet: any, + visitedFragments: Set, + ): number => { + let cost = 0 + for (const selection of selectionSet.selections || []) { + cost += 1 + if (selection.selectionSet) { + cost += countSelections( + selection.selectionSet, + visitedFragments, + ) + continue + } + if (selection.kind === 'FragmentSpread') { + const name = selection.name.value + if (visitedFragments.has(name)) continue + visitedFragments.add(name) + const fragment = context.getFragment(name) + if (fragment) { + cost += countSelections( + fragment.selectionSet, + visitedFragments, + ) + } + } + } + return cost + } + + return { + OperationDefinition(node) { + const visitedFragments = new Set() + totalCost += countSelections( + node.selectionSet, + visitedFragments, + ) + }, + Document: { + leave() { + if (totalCost > maxCost) { + context.reportError( + new GraphQLError( + `Query is too expensive: ${totalCost} > max cost ${maxCost}`, + ), + ) + } + }, + }, + } + } +} + +const createOperationConstraintsRule = ({ + allowedOperationNames, + allowMultipleOperations, +}: { + allowedOperationNames?: string[] + allowMultipleOperations: boolean +}): ValidationRule => { + return (context) => { + const seenOperations: string[] = [] + + return { + OperationDefinition(node) { + const name = node.name?.value + if (name) { + seenOperations.push(name) + if ( + allowedOperationNames && + allowedOperationNames.length > 0 && + !allowedOperationNames.includes(name) + ) { + context.reportError( + new GraphQLError( + `Operation "${name}" is not allowed in this endpoint.`, + ), + ) + } + } + }, + Document: { + leave() { + if (!allowMultipleOperations && seenOperations.length > 1) { + context.reportError( + new GraphQLError( + 'Multiple operations are not allowed in this endpoint.', + ), + ) + } + }, + }, + } + } +} + +const resolveWithTimeout = async ( + resolve: () => Promise, + timeoutMs: number | undefined, +) => { + if (!timeoutMs || timeoutMs <= 0) return resolve() + + let timeoutId: ReturnType | undefined + + const timeoutPromise = new Promise((resolveTimeout) => { + timeoutId = setTimeout( + () => + resolveTimeout( + new Response('Request Timeout', { status: 504 }), + ), + timeoutMs, + ) + }) + + const response = await Promise.race([resolve(), timeoutPromise]) + + if (timeoutId) clearTimeout(timeoutId) + + return response +} + export interface GraphQLApolloPluginOptions { graphqlMiddleware?: ( resolve: () => Promise, @@ -22,6 +201,19 @@ export interface GraphQLApolloPluginOptions { allowGetRequests?: boolean isProduction?: boolean allowIntrospection?: boolean + maxDepth?: number + maxCost?: number + maxRequestSizeBytes?: number + timeoutMs?: number + allowedOperationNames?: string[] + allowMultipleOperations?: boolean + onRequestResolved?: (input: { + operationName?: string | null + success: boolean + status: number + durationMs: number + }) => void + rateLimiter?: (context: Context) => MaybePromise } export const WobeGraphqlApolloPlugin = async ({ @@ -32,6 +224,14 @@ export const WobeGraphqlApolloPlugin = async ({ isProduction, allowGetRequests = false, allowIntrospection, + maxDepth, + maxCost, + maxRequestSizeBytes, + timeoutMs, + allowedOperationNames, + allowMultipleOperations = false, + onRequestResolved, + rateLimiter, }: { options: ApolloServerOptions graphqlEndpoint?: string @@ -40,11 +240,26 @@ export const WobeGraphqlApolloPlugin = async ({ } & GraphQLApolloPluginOptions): Promise => { const introspection = options.introspection ?? - (allowIntrospection === true ? true : isProduction ? false : true) + (allowIntrospection === true ? true : !isProduction) + + const validationRules: ValidationRule[] = [ + ...(options.validationRules || []), + ...(maxDepth ? [createDepthLimitRule(maxDepth)] : []), + ...(maxCost ? [createCostLimitRule(maxCost)] : []), + ...(!allowMultipleOperations || (allowedOperationNames?.length || 0) > 0 + ? [ + createOperationConstraintsRule({ + allowedOperationNames, + allowMultipleOperations, + }), + ] + : []), + ] const server = new ApolloServer({ ...options, introspection, + validationRules, plugins: [ ...(options?.plugins || []), ...(isProduction @@ -62,13 +277,39 @@ export const WobeGraphqlApolloPlugin = async ({ return (wobe: Wobe) => { const getResponse = async (context: Context) => { const fetchEndpoint = async (request: Request) => { + let requestBody: any + + if (maxRequestSizeBytes) { + const contentLength = request.headers.get('content-length') + if ( + contentLength && + Number(contentLength) > maxRequestSizeBytes + ) { + return new Response('Request Entity Too Large', { + status: 413, + }) + } + } + + if (rateLimiter) { + const rateLimitResult = await rateLimiter(context) + if (rateLimitResult instanceof Response) + return rateLimitResult + } + + const start = performance.now() + + if (request.method !== 'GET') { + requestBody = await request.json() + } + const res = await server.executeHTTPGraphQLRequest({ httpGraphQLRequest: { method: request.method, body: request.method === 'GET' ? request.body - : await request.json(), + : requestBody, // @ts-expect-error headers: request.headers, search: getQueryString(request.url), @@ -100,16 +341,26 @@ export const WobeGraphqlApolloPlugin = async ({ headers: res.headers, }) + onRequestResolved?.({ + operationName: requestBody?.operationName, + success: response.ok, + status: response.status, + durationMs: performance.now() - start, + }) + return response } return new Response() } - if (!graphqlMiddleware) return fetchEndpoint(context.request) + const resolve = async () => fetchEndpoint(context.request) + + if (!graphqlMiddleware) + return resolveWithTimeout(resolve, timeoutMs) return graphqlMiddleware(async () => { - const response = await fetchEndpoint(context.request) + const response = await resolveWithTimeout(resolve, timeoutMs) return response }, context.res)