From 4f7755eeff2ad51127d4ee1ae938d973102f9860 Mon Sep 17 00:00:00 2001 From: Ori Gold Date: Mon, 3 Nov 2025 13:55:14 -0800 Subject: [PATCH 1/2] fix: request wrapper header methods adjusted to include custom headers --- .../com/perimeterx/http/RequestWrapper.java | 53 ++++++++++- .../perimeterx/api/RequestWrapperTest.java | 93 +++++++++++++++++++ 2 files changed, 141 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/perimeterx/http/RequestWrapper.java b/src/main/java/com/perimeterx/http/RequestWrapper.java index 84e29844..d9730c27 100644 --- a/src/main/java/com/perimeterx/http/RequestWrapper.java +++ b/src/main/java/com/perimeterx/http/RequestWrapper.java @@ -7,8 +7,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStreamReader; -import java.util.HashMap; -import java.util.Map; +import java.util.*; /** * Reading HttpServletRequest is limited to one time only @@ -25,6 +24,12 @@ public RequestWrapper(HttpServletRequest request) { this.customHeaders = new HashMap<>(); } + // Add a custom header to the request + public void addHeader(String name, String value) { + this.customHeaders.put(name, value); + } + + // Modify body methods to read from the cached body @Override public ServletInputStream getInputStream() throws IOException { final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(getBody().getBytes()); @@ -36,6 +41,7 @@ public BufferedReader getReader() throws IOException { return new BufferedReader(new InputStreamReader(this.getInputStream())); } + // Modify header methods to include custom headers @Override public String getHeader(String name) { String headerValue = customHeaders.get(name); @@ -43,11 +49,48 @@ public String getHeader(String name) { if (headerValue != null) { return headerValue; } - return ((HttpServletRequest) getRequest()).getHeader(name); + return super.getHeader(name); } - public void addHeader(String name, String value) { - this.customHeaders.put(name, value); + @Override + public Enumeration getHeaderNames() { + Enumeration headerNames = super.getHeaderNames(); + List list = Collections.list(headerNames); + for (String customHeaderName : customHeaders.keySet()) { + if (!list.contains(customHeaderName)) { + list.add(customHeaderName); + } + } + return Collections.enumeration(list); + } + + @Override + public Enumeration getHeaders(String name) { + String headerValue = customHeaders.get(name); + if (headerValue != null) { + List list = new ArrayList<>(); + list.add(headerValue); + return Collections.enumeration(list); + } + return super.getHeaders(name); + } + + @Override + public int getIntHeader(String name) { + final String headerValue = getHeader(name); + if (headerValue != null) { + return Integer.parseInt(headerValue); + } + return -1; + } + + @Override + public long getDateHeader(String name) { + final String headerValue = getHeader(name); + if (headerValue != null) { + return Long.parseLong(headerValue); + } + return -1L; } public synchronized String getBody() throws IOException { diff --git a/src/test/java/com/perimeterx/api/RequestWrapperTest.java b/src/test/java/com/perimeterx/api/RequestWrapperTest.java index 4f2a122b..1fea8485 100644 --- a/src/test/java/com/perimeterx/api/RequestWrapperTest.java +++ b/src/test/java/com/perimeterx/api/RequestWrapperTest.java @@ -6,6 +6,8 @@ import java.io.BufferedReader; import java.io.IOException; +import java.util.Collections; +import java.util.List; import static org.testng.Assert.*; @@ -64,4 +66,95 @@ public void testSpecialCharacters() throws IOException { RequestWrapper requestWrapper = new RequestWrapper(req); assertEquals(requestWrapper.getBody(), new String(bytes)); } + + @Test + public void testGetHeader() { + MockHttpServletRequest req = new MockHttpServletRequest(); + req.addHeader("header1", "value1"); + RequestWrapper requestWrapper = new RequestWrapper(req); + requestWrapper.addHeader("header2", "value2"); + + assertEquals(requestWrapper.getHeader("header1"), "value1"); + assertEquals(requestWrapper.getHeader("header2"), "value2"); + } + + @Test + public void testGetHeaderNames() { + MockHttpServletRequest req = new MockHttpServletRequest(); + req.addHeader("header1", "value1"); + RequestWrapper requestWrapper = new RequestWrapper(req); + requestWrapper.addHeader("header2", "value2"); + + boolean foundHeader1 = false; + boolean foundHeader2 = false; + for (String headerName : Collections.list(requestWrapper.getHeaderNames())) { + if (headerName.equals("header1")) { + foundHeader1 = true; + } + if (headerName.equals("header2")) { + foundHeader2 = true; + } + } + assertTrue(foundHeader1); + assertTrue(foundHeader2); + } + + @Test + public void testGetHeaders() { + MockHttpServletRequest req = new MockHttpServletRequest(); + req.addHeader("header1", "value1"); + RequestWrapper requestWrapper = new RequestWrapper(req); + requestWrapper.addHeader("header2", "value2"); + + List header1Values = Collections.list(requestWrapper.getHeaders("header1")); + assertEquals(header1Values.size(), 1); + for (String headerValue : header1Values) { + assertEquals(headerValue, "value1"); + } + + List header2Values = Collections.list(requestWrapper.getHeaders("header2")); + assertEquals(header2Values.size(), 1); + for (String headerValue : header2Values) { + assertEquals(headerValue, "value2"); + } + } + + @Test + public void testGetIntHeader() { + MockHttpServletRequest req = new MockHttpServletRequest(); + req.addHeader("intHeader", "123"); + RequestWrapper requestWrapper = new RequestWrapper(req); + requestWrapper.addHeader("customIntHeader", "456"); + requestWrapper.addHeader("stringHeader", "stringValue"); + + assertEquals(requestWrapper.getIntHeader("intHeader"), 123); + assertEquals(requestWrapper.getIntHeader("customIntHeader"), 456); + assertEquals(requestWrapper.getIntHeader("nonExistentHeader"), -1); + try { + requestWrapper.getIntHeader("stringHeader"); + fail("Expected NumberFormatException"); + } catch (NumberFormatException e) { + // Expected exception + } + } + + @Test + public void testGetDateHeader() { + MockHttpServletRequest req = new MockHttpServletRequest(); + long now = System.currentTimeMillis(); + req.addHeader("dateHeader", Long.toString(now)); + RequestWrapper requestWrapper = new RequestWrapper(req); + requestWrapper.addHeader("customDateHeader", Long.toString(now + 1000)); + requestWrapper.addHeader("stringHeader", "stringValue"); + + assertEquals(requestWrapper.getDateHeader("dateHeader"), now); + assertEquals(requestWrapper.getDateHeader("customDateHeader"), now + 1000); + assertEquals(requestWrapper.getDateHeader("nonExistentHeader"), -1); + try { + requestWrapper.getDateHeader("stringHeader"); + fail("Expected NumberFormatException"); + } catch (NumberFormatException e) { + // Expected exception + } + } } From 5217486f8035949245eeef5d45525ec92b4b2e6f Mon Sep 17 00:00:00 2001 From: Ori Gold Date: Mon, 3 Nov 2025 14:39:48 -0800 Subject: [PATCH 2/2] test: add context unit tests to confirm request wrapper headers are included --- .../com/perimeterx/models/PXContextTest.java | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/test/java/com/perimeterx/models/PXContextTest.java b/src/test/java/com/perimeterx/models/PXContextTest.java index 36eef547..fe2705c1 100644 --- a/src/test/java/com/perimeterx/models/PXContextTest.java +++ b/src/test/java/com/perimeterx/models/PXContextTest.java @@ -5,6 +5,7 @@ import com.perimeterx.api.providers.HostnameProvider; import com.perimeterx.api.providers.IPProvider; import com.perimeterx.api.providers.RemoteAddressIPProvider; +import com.perimeterx.http.RequestWrapper; import com.perimeterx.models.configuration.PXConfiguration; import com.perimeterx.models.risk.CustomParameters; import org.mockito.Mockito; @@ -14,6 +15,7 @@ import org.testng.annotations.Test; import javax.servlet.http.HttpServletRequest; +import java.util.Collections; /** * Test {@link PXContext} @@ -53,4 +55,41 @@ public void customParamsTest() { Mockito.verify(spyTestCustomParamProvider).buildCustomParameters(pxConfig, context); } + + @Test + public void allRequestHeadersShouldBeInPXContext() { + CustomParameters customParameters = new CustomParameters(); + customParameters.setCustomParam1("number1"); + TestCustomParamProvider spyTestCustomParamProvider = Mockito.spy(new TestCustomParamProvider(customParameters)); + PXConfiguration pxConfig = PXConfiguration.builder() + .appId("APP_ID") + .authToken("AUTH_123") + .cookieKey("COOKIE_123") + .customParametersProvider(spyTestCustomParamProvider) + .build(); + ((MockHttpServletRequest) request).addHeader("TEST-BYPASS", "0"); + PXContext context = new PXContext(request, this.ipProvider, this.hostnameProvider, pxConfig); + Assert.assertEquals(context.getHeaders().size(), Collections.list(request.getHeaderNames()).size()); + } + + @Test + public void allRequestWrapperHeadersShouldBeInPXContext() { + CustomParameters customParameters = new CustomParameters(); + customParameters.setCustomParam1("number1"); + TestCustomParamProvider spyTestCustomParamProvider = Mockito.spy(new TestCustomParamProvider(customParameters)); + PXConfiguration pxConfig = PXConfiguration.builder() + .appId("APP_ID") + .authToken("AUTH_123") + .cookieKey("COOKIE_123") + .customParametersProvider(spyTestCustomParamProvider) + .build(); + ((MockHttpServletRequest) request).addHeader("TEST-BYPASS", "0"); + RequestWrapper requestWrapper = new RequestWrapper(request); + requestWrapper.addHeader("client-ip", "127.0.0.1"); + requestWrapper.addHeader("accept", "application/json"); + requestWrapper.addHeader("content-type", "application/json"); + + PXContext context = new PXContext(requestWrapper, this.ipProvider, this.hostnameProvider, pxConfig); + Assert.assertEquals(context.getHeaders().size(), Collections.list(request.getHeaderNames()).size() + 3); + } }