Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions src/main/java/com/perimeterx/http/RequestWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
import java.util.*;

/**
* Reading HttpServletRequest is limited to one time only
Expand All @@ -25,6 +24,12 @@ public RequestWrapper(HttpServletRequest request) {
this.customHeaders = new HashMap<>();
}

// Add a custom header to the request
public void addHeader(String name, String value) {
this.customHeaders.put(name, value);
}

// Modify body methods to read from the cached body
@Override
public ServletInputStream getInputStream() throws IOException {
final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(getBody().getBytes());
Expand All @@ -36,18 +41,56 @@ public BufferedReader getReader() throws IOException {
return new BufferedReader(new InputStreamReader(this.getInputStream()));
}

// Modify header methods to include custom headers
@Override
public String getHeader(String name) {
String headerValue = customHeaders.get(name);

if (headerValue != null) {
return headerValue;
}
return ((HttpServletRequest) getRequest()).getHeader(name);
return super.getHeader(name);
}

public void addHeader(String name, String value) {
this.customHeaders.put(name, value);
@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> headerNames = super.getHeaderNames();
List<String> list = Collections.list(headerNames);
for (String customHeaderName : customHeaders.keySet()) {
if (!list.contains(customHeaderName)) {
list.add(customHeaderName);
}
}
return Collections.enumeration(list);
}

@Override
public Enumeration<String> getHeaders(String name) {
String headerValue = customHeaders.get(name);
if (headerValue != null) {
List<String> list = new ArrayList<>();
list.add(headerValue);
return Collections.enumeration(list);
}
return super.getHeaders(name);
}

@Override
public int getIntHeader(String name) {
final String headerValue = getHeader(name);
if (headerValue != null) {
return Integer.parseInt(headerValue);
}
return -1;
}

@Override
public long getDateHeader(String name) {
final String headerValue = getHeader(name);
if (headerValue != null) {
return Long.parseLong(headerValue);
}
return -1L;
}

public synchronized String getBody() throws IOException {
Expand Down
93 changes: 93 additions & 0 deletions src/test/java/com/perimeterx/api/RequestWrapperTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import java.io.BufferedReader;
import java.io.IOException;
import java.util.Collections;
import java.util.List;

import static org.testng.Assert.*;

Expand Down Expand Up @@ -64,4 +66,95 @@ public void testSpecialCharacters() throws IOException {
RequestWrapper requestWrapper = new RequestWrapper(req);
assertEquals(requestWrapper.getBody(), new String(bytes));
}

@Test
public void testGetHeader() {
MockHttpServletRequest req = new MockHttpServletRequest();
req.addHeader("header1", "value1");
RequestWrapper requestWrapper = new RequestWrapper(req);
requestWrapper.addHeader("header2", "value2");

assertEquals(requestWrapper.getHeader("header1"), "value1");
assertEquals(requestWrapper.getHeader("header2"), "value2");
}

@Test
public void testGetHeaderNames() {
MockHttpServletRequest req = new MockHttpServletRequest();
req.addHeader("header1", "value1");
RequestWrapper requestWrapper = new RequestWrapper(req);
requestWrapper.addHeader("header2", "value2");

boolean foundHeader1 = false;
boolean foundHeader2 = false;
for (String headerName : Collections.list(requestWrapper.getHeaderNames())) {
if (headerName.equals("header1")) {
foundHeader1 = true;
}
if (headerName.equals("header2")) {
foundHeader2 = true;
}
}
assertTrue(foundHeader1);
assertTrue(foundHeader2);
}

@Test
public void testGetHeaders() {
MockHttpServletRequest req = new MockHttpServletRequest();
req.addHeader("header1", "value1");
RequestWrapper requestWrapper = new RequestWrapper(req);
requestWrapper.addHeader("header2", "value2");

List<String> header1Values = Collections.list(requestWrapper.getHeaders("header1"));
assertEquals(header1Values.size(), 1);
for (String headerValue : header1Values) {
assertEquals(headerValue, "value1");
}

List<String> header2Values = Collections.list(requestWrapper.getHeaders("header2"));
assertEquals(header2Values.size(), 1);
for (String headerValue : header2Values) {
assertEquals(headerValue, "value2");
}
}

@Test
public void testGetIntHeader() {
MockHttpServletRequest req = new MockHttpServletRequest();
req.addHeader("intHeader", "123");
RequestWrapper requestWrapper = new RequestWrapper(req);
requestWrapper.addHeader("customIntHeader", "456");
requestWrapper.addHeader("stringHeader", "stringValue");

assertEquals(requestWrapper.getIntHeader("intHeader"), 123);
assertEquals(requestWrapper.getIntHeader("customIntHeader"), 456);
assertEquals(requestWrapper.getIntHeader("nonExistentHeader"), -1);
try {
requestWrapper.getIntHeader("stringHeader");
fail("Expected NumberFormatException");
} catch (NumberFormatException e) {
// Expected exception
}
}

@Test
public void testGetDateHeader() {
MockHttpServletRequest req = new MockHttpServletRequest();
long now = System.currentTimeMillis();
req.addHeader("dateHeader", Long.toString(now));
RequestWrapper requestWrapper = new RequestWrapper(req);
requestWrapper.addHeader("customDateHeader", Long.toString(now + 1000));
requestWrapper.addHeader("stringHeader", "stringValue");

assertEquals(requestWrapper.getDateHeader("dateHeader"), now);
assertEquals(requestWrapper.getDateHeader("customDateHeader"), now + 1000);
assertEquals(requestWrapper.getDateHeader("nonExistentHeader"), -1);
try {
requestWrapper.getDateHeader("stringHeader");
fail("Expected NumberFormatException");
} catch (NumberFormatException e) {
// Expected exception
}
}
}
39 changes: 39 additions & 0 deletions src/test/java/com/perimeterx/models/PXContextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.perimeterx.api.providers.HostnameProvider;
import com.perimeterx.api.providers.IPProvider;
import com.perimeterx.api.providers.RemoteAddressIPProvider;
import com.perimeterx.http.RequestWrapper;
import com.perimeterx.models.configuration.PXConfiguration;
import com.perimeterx.models.risk.CustomParameters;
import org.mockito.Mockito;
Expand All @@ -14,6 +15,7 @@
import org.testng.annotations.Test;

import javax.servlet.http.HttpServletRequest;
import java.util.Collections;

/**
* Test {@link PXContext}
Expand Down Expand Up @@ -53,4 +55,41 @@ public void customParamsTest() {

Mockito.verify(spyTestCustomParamProvider).buildCustomParameters(pxConfig, context);
}

@Test
public void allRequestHeadersShouldBeInPXContext() {
CustomParameters customParameters = new CustomParameters();
customParameters.setCustomParam1("number1");
TestCustomParamProvider spyTestCustomParamProvider = Mockito.spy(new TestCustomParamProvider(customParameters));
PXConfiguration pxConfig = PXConfiguration.builder()
.appId("APP_ID")
.authToken("AUTH_123")
.cookieKey("COOKIE_123")
.customParametersProvider(spyTestCustomParamProvider)
.build();
((MockHttpServletRequest) request).addHeader("TEST-BYPASS", "0");
PXContext context = new PXContext(request, this.ipProvider, this.hostnameProvider, pxConfig);
Assert.assertEquals(context.getHeaders().size(), Collections.list(request.getHeaderNames()).size());
}

@Test
public void allRequestWrapperHeadersShouldBeInPXContext() {
CustomParameters customParameters = new CustomParameters();
customParameters.setCustomParam1("number1");
TestCustomParamProvider spyTestCustomParamProvider = Mockito.spy(new TestCustomParamProvider(customParameters));
PXConfiguration pxConfig = PXConfiguration.builder()
.appId("APP_ID")
.authToken("AUTH_123")
.cookieKey("COOKIE_123")
.customParametersProvider(spyTestCustomParamProvider)
.build();
((MockHttpServletRequest) request).addHeader("TEST-BYPASS", "0");
RequestWrapper requestWrapper = new RequestWrapper(request);
requestWrapper.addHeader("client-ip", "127.0.0.1");
requestWrapper.addHeader("accept", "application/json");
requestWrapper.addHeader("content-type", "application/json");

PXContext context = new PXContext(requestWrapper, this.ipProvider, this.hostnameProvider, pxConfig);
Assert.assertEquals(context.getHeaders().size(), Collections.list(request.getHeaderNames()).size() + 3);
}
}