diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java
index 411c8ecc5..3d162a5de 100644
--- a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java
+++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java
@@ -67,8 +67,10 @@ public static void main(String[] args) throws Exception {
.builder()
.mcpEndpoint(MCP_ENDPOINT)
.keepAliveInterval(Duration.ofSeconds(30))
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
// Build server with all conformance test features
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java
index 5321aada7..db1b4f75e 100644
--- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java
@@ -12,10 +12,11 @@
/**
* Default implementation of {@link ServerTransportSecurityValidator} that validates the
- * Origin header against a list of allowed origins.
+ * Origin and Host headers against lists of allowed values.
*
*
- * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*").
+ * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*" for
+ * origins, "example.com:*" for hosts).
*
* @author Daniel Garnier-Moiroux
* @see ServerTransportSecurityValidator
@@ -25,32 +26,55 @@ public class DefaultServerTransportSecurityValidator implements ServerTransportS
private static final String ORIGIN_HEADER = "Origin";
+ private static final String HOST_HEADER = "Host";
+
private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
"Invalid Origin header");
+ private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421,
+ "Invalid Host header");
+
private final List allowedOrigins;
+ private final List allowedHosts;
+
/**
- * Creates a new validator with the specified allowed origins.
+ * Creates a new validator with the specified allowed origins and hosts.
* @param allowedOrigins List of allowed origin patterns. Supports exact matches
* (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
+ * @param allowedHosts List of allowed host patterns. Supports exact matches (e.g.,
+ * "example.com:8080") and wildcard ports (e.g., "example.com:*")
*/
- public DefaultServerTransportSecurityValidator(List allowedOrigins) {
+ public DefaultServerTransportSecurityValidator(List allowedOrigins, List allowedHosts) {
Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
+ Assert.notNull(allowedHosts, "allowedHosts must not be null");
this.allowedOrigins = allowedOrigins;
+ this.allowedHosts = allowedHosts;
}
@Override
public void validateHeaders(Map> headers) throws ServerTransportSecurityException {
+ boolean missingHost = true;
for (Map.Entry> entry : headers.entrySet()) {
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
List values = entry.getValue();
- if (values != null && !values.isEmpty()) {
- validateOrigin(values.get(0));
+ if (values == null || values.isEmpty()) {
+ throw INVALID_ORIGIN;
+ }
+ validateOrigin(values.get(0));
+ }
+ else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
+ missingHost = false;
+ List values = entry.getValue();
+ if (values == null || values.isEmpty()) {
+ throw INVALID_HOST;
}
- break;
+ validateHost(values.get(0));
}
}
+ if (!allowedHosts.isEmpty() && missingHost) {
+ throw INVALID_HOST;
+ }
}
/**
@@ -82,6 +106,37 @@ else if (allowed.endsWith(":*")) {
throw INVALID_ORIGIN;
}
+ /**
+ * Validates a single host value against the allowed hosts.
+ * @param host The host header value, or null if not present
+ * @throws ServerTransportSecurityException if the host is not allowed
+ */
+ private void validateHost(String host) throws ServerTransportSecurityException {
+ if (allowedHosts.isEmpty()) {
+ return;
+ }
+
+ // Host is required
+ if (host == null || host.isBlank()) {
+ throw INVALID_HOST;
+ }
+
+ for (String allowed : allowedHosts) {
+ if (allowed.equals(host)) {
+ return;
+ }
+ else if (allowed.endsWith(":*")) {
+ // Wildcard port pattern: "example.com:*"
+ String baseHost = allowed.substring(0, allowed.length() - 2);
+ if (host.equals(baseHost) || host.startsWith(baseHost + ":")) {
+ return;
+ }
+ }
+ }
+
+ throw INVALID_HOST;
+ }
+
/**
* Creates a new builder for constructing a DefaultServerTransportSecurityValidator.
* @return A new builder instance
@@ -97,6 +152,8 @@ public static class Builder {
private final List allowedOrigins = new ArrayList<>();
+ private final List allowedHosts = new ArrayList<>();
+
/**
* Adds an allowed origin pattern.
* @param origin The origin to allow (e.g., "http://localhost:8080" or
@@ -119,12 +176,33 @@ public Builder allowedOrigins(List origins) {
return this;
}
+ /**
+ * Adds an allowed host pattern.
+ * @param host The host to allow (e.g., "localhost:8080" or "example.com:*")
+ * @return this builder instance
+ */
+ public Builder allowedHost(String host) {
+ this.allowedHosts.add(host);
+ return this;
+ }
+
+ /**
+ * Adds multiple allowed host patterns.
+ * @param hosts The hosts to allow
+ * @return this builder instance
+ */
+ public Builder allowedHosts(List hosts) {
+ Assert.notNull(hosts, "hosts must not be null");
+ this.allowedHosts.addAll(hosts);
+ return this;
+ }
+
/**
* Builds the validator instance.
* @return A new DefaultServerTransportSecurityValidator
*/
public DefaultServerTransportSecurityValidator build() {
- return new DefaultServerTransportSecurityValidator(allowedOrigins);
+ return new DefaultServerTransportSecurityValidator(allowedOrigins, allowedHosts);
}
}
diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java
index 7e1593e1b..d4cf8582d 100644
--- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java
+++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java
@@ -22,6 +22,9 @@ class DefaultServerTransportSecurityValidatorTests {
private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
"Invalid Origin header");
+ private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421,
+ "Invalid Host header");
+
private final DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
.allowedOrigin("http://localhost:8080")
.build();
@@ -31,161 +34,370 @@ void builder() {
assertThatCode(() -> DefaultServerTransportSecurityValidator.builder().build()).doesNotThrowAnyException();
assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedOrigins(null).build())
.isInstanceOf(IllegalArgumentException.class);
+ assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedHosts(null).build())
+ .isInstanceOf(IllegalArgumentException.class);
}
- @Test
- void originHeaderMissing() {
- assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException();
- }
+ @Nested
+ class OriginHeader {
- @Test
- void originHeaderListEmpty() {
- assertThatCode(() -> validator.validateHeaders(Map.of("Origin", List.of()))).doesNotThrowAnyException();
- }
+ @Test
+ void originHeaderMissing() {
+ assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException();
+ }
- @Test
- void caseInsensitive() {
- var headers = Map.of("origin", List.of("http://localhost:8080"));
+ @Test
+ void originHeaderListEmpty() {
+ assertThatThrownBy(() -> validator.validateHeaders(Map.of("Origin", List.of()))).isEqualTo(INVALID_ORIGIN);
+ }
- assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
- }
+ @Test
+ void caseInsensitive() {
+ var headers = Map.of("origin", List.of("http://localhost:8080"));
- @Test
- void exactMatch() {
- var headers = originHeader("http://localhost:8080");
+ assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
- assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
- }
+ @Test
+ void exactMatch() {
+ var headers = originHeader("http://localhost:8080");
- @Test
- void differentPort() {
+ assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
- var headers = originHeader("http://localhost:3000");
+ @Test
+ void differentPort() {
- assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
- }
+ var headers = originHeader("http://localhost:3000");
- @Test
- void differentHost() {
+ assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
- var headers = originHeader("http://example.com:8080");
+ @Test
+ void differentHost() {
- assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
- }
+ var headers = originHeader("http://example.com:8080");
- @Test
- void differentScheme() {
+ assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Test
+ void differentScheme() {
+
+ var headers = originHeader("https://localhost:8080");
+
+ assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Nested
+ class WildcardPort {
+
+ private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator
+ .builder()
+ .allowedOrigin("http://localhost:*")
+ .build();
+
+ @Test
+ void anyPortWithWildcard() {
+ var headers = originHeader("http://localhost:3000");
+
+ assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void noPortWithWildcard() {
+ var headers = originHeader("http://localhost");
+
+ assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentPortWithWildcard() {
+ var headers = originHeader("http://localhost:8080");
+
+ assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentHostWithWildcard() {
+ var headers = originHeader("http://example.com:3000");
+
+ assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Test
+ void differentSchemeWithWildcard() {
+ var headers = originHeader("https://localhost:3000");
+
+ assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ }
+
+ @Nested
+ class MultipleOrigins {
+
+ DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator
+ .builder()
+ .allowedOrigin("http://localhost:8080")
+ .allowedOrigin("http://example.com:3000")
+ .allowedOrigin("http://myapp.example.com:*")
+ .build();
+
+ @Test
+ void matchingOneOfMultiple() {
+ var headers = originHeader("http://example.com:3000");
+
+ assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void matchingWildcardInMultiple() {
+ var headers = originHeader("http://myapp.example.com:9999");
+
+ assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void notMatchingAny() {
+ var headers = originHeader("http://malicious.example.com:1234");
- var headers = originHeader("https://localhost:8080");
+ assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ }
+
+ @Nested
+ class BuilderTests {
+
+ @Test
+ void shouldAddMultipleOriginsWithAllowedOriginsMethod() {
+ DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*"))
+ .build();
+
+ var headers = originHeader("http://example.com:3000");
+
+ assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void shouldCombineAllowedOriginMethods() {
+ DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:8080")
+ .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000"))
+ .build();
+
+ assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000")))
+ .doesNotThrowAnyException();
+ }
+
+ }
- assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
}
@Nested
- class WildcardPort {
+ class HostHeader {
- private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator
+ private final DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator
.builder()
- .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:8080")
.build();
@Test
- void anyPortWithWildcard() {
- var headers = originHeader("http://localhost:3000");
+ void notConfigured() {
+ assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException();
+ }
- assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ @Test
+ void missing() {
+ assertThatThrownBy(() -> hostValidator.validateHeaders(new HashMap<>())).isEqualTo(INVALID_HOST);
}
@Test
- void noPortWithWildcard() {
- var headers = originHeader("http://localhost");
+ void listEmpty() {
+ assertThatThrownBy(() -> hostValidator.validateHeaders(Map.of("Host", List.of()))).isEqualTo(INVALID_HOST);
+ }
- assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ @Test
+ void caseInsensitive() {
+ var headers = Map.of("host", List.of("localhost:8080"));
+
+ assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException();
}
@Test
- void differentPortWithWildcard() {
- var headers = originHeader("http://localhost:8080");
+ void exactMatch() {
+ var headers = hostHeader("localhost:8080");
- assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException();
}
@Test
- void differentHostWithWildcard() {
- var headers = originHeader("http://example.com:3000");
+ void differentPort() {
+ var headers = hostHeader("localhost:3000");
- assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST);
}
@Test
- void differentSchemeWithWildcard() {
- var headers = originHeader("https://localhost:3000");
+ void differentHost() {
+ var headers = hostHeader("example.com:8080");
+
+ assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST);
+ }
+
+ @Nested
+ class HostWildcardPort {
+
+ private final DefaultServerTransportSecurityValidator wildcardHostValidator = DefaultServerTransportSecurityValidator
+ .builder()
+ .allowedHost("localhost:*")
+ .build();
+
+ @Test
+ void anyPort() {
+ var headers = hostHeader("localhost:3000");
+
+ assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void noPort() {
+ var headers = hostHeader("localhost");
+
+ assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentHost() {
+ var headers = hostHeader("example.com:3000");
+
+ assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST);
+ }
+
+ }
+
+ @Nested
+ class MultipleHosts {
+
+ DefaultServerTransportSecurityValidator multipleHostsValidator = DefaultServerTransportSecurityValidator
+ .builder()
+ .allowedHost("example.com:3000")
+ .allowedHost("myapp.example.com:*")
+ .build();
+
+ @Test
+ void exactMatch() {
+ var headers = hostHeader("example.com:3000");
+
+ assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void wildcard() {
+ var headers = hostHeader("myapp.example.com:9999");
+
+ assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentHost() {
+ var headers = hostHeader("malicious.example.com:3000");
+
+ assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST);
+ }
+
+ @Test
+ void differentPort() {
+ var headers = hostHeader("localhost:8080");
+
+ assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST);
+ }
+
+ }
+
+ @Nested
+ class HostBuilderTests {
+
+ @Test
+ void multipleHosts() {
+ DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedHosts(List.of("localhost:8080", "example.com:*"))
+ .build();
+
+ assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:3000")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080")))
+ .doesNotThrowAnyException();
+ }
+
+ @Test
+ void combined() {
+ DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedHost("localhost:8080")
+ .allowedHosts(List.of("example.com:*", "test.com:3000"))
+ .build();
+
+ assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:9999")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(hostHeader("test.com:3000"))).doesNotThrowAnyException();
+ }
- assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
}
}
@Nested
- class MultipleOrigins {
+ class CombinedOriginAndHostValidation {
- DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator
+ private final DefaultServerTransportSecurityValidator combinedValidator = DefaultServerTransportSecurityValidator
.builder()
- .allowedOrigin("http://localhost:8080")
- .allowedOrigin("http://example.com:3000")
- .allowedOrigin("http://myapp.com:*")
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
.build();
@Test
- void matchingOneOfMultiple() {
- var headers = originHeader("http://example.com:3000");
+ void bothValid() {
+ var header = headers("http://localhost:8080", "localhost:8080");
- assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException();
}
@Test
- void matchingWildcardInMultiple() {
- var headers = originHeader("http://myapp.com:9999");
+ void originValidHostInvalid() {
+ var header = headers("http://localhost:8080", "malicious.example.com:8080");
- assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST);
}
@Test
- void notMatchingAny() {
- var headers = originHeader("http://malicious.example.com:1234");
+ void originInvalidHostValid() {
+ var header = headers("http://malicious.example.com:8080", "localhost:8080");
- assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_ORIGIN);
}
- }
-
- @Nested
- class BuilderTests {
-
@Test
- void shouldAddMultipleOriginsWithAllowedOriginsMethod() {
- DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
- .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*"))
- .build();
-
- var headers = originHeader("http://example.com:3000");
+ void originMissingHostValid() {
+ // Origin missing is OK (same-origin request)
+ var header = headers(null, "localhost:8080");
- assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException();
}
@Test
- void shouldCombineAllowedOriginMethods() {
- DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
- .allowedOrigin("http://localhost:8080")
- .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000"))
- .build();
+ void originValidHostMissing() {
+ // Host missing is NOT OK when allowedHosts is configured
+ var header = headers("http://localhost:8080", null);
- assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080")))
- .doesNotThrowAnyException();
- assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999")))
- .doesNotThrowAnyException();
- assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000")))
- .doesNotThrowAnyException();
+ assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST);
}
}
@@ -194,4 +406,19 @@ private static Map> originHeader(String origin) {
return Map.of("Origin", List.of(origin));
}
+ private static Map> hostHeader(String host) {
+ return Map.of("Host", List.of(host));
+ }
+
+ private static Map> headers(String origin, String host) {
+ var map = new HashMap>();
+ if (origin != null) {
+ map.put("Origin", List.of(origin));
+ }
+ if (host != null) {
+ map.put("Host", List.of(host));
+ }
+ return map;
+ }
+
}
diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java
index e9e64c0d0..9f5b9b30d 100644
--- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java
+++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java
@@ -48,6 +48,8 @@ class ServerTransportSecurityIntegrationTests {
private static final String DISALLOWED_ORIGIN = "https://malicious.example.com";
+ private static final String DISALLOWED_HOST = "malicious.example.com:8080";
+
@Parameter
private static Transport transport;
@@ -73,6 +75,7 @@ static void afterAll() {
@BeforeEach
void setUp() {
+ requestCustomizer.reset();
mcpClient = transport.createMcpClient(baseUrl, requestCustomizer);
}
@@ -115,6 +118,29 @@ void messageOriginNotAllowed() {
assertThatThrownBy(() -> mcpClient.listTools());
}
+ @Test
+ void hostAllowed() {
+ // Host header is set by default by HttpClient to the request URI host
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void connectHostNotAllowed() {
+ requestCustomizer.setHostHeader(DISALLOWED_HOST);
+ assertThatThrownBy(() -> mcpClient.initialize());
+ }
+
+ @Test
+ void messageHostNotAllowed() {
+ mcpClient.initialize();
+ requestCustomizer.setHostHeader(DISALLOWED_HOST);
+ assertThatThrownBy(() -> mcpClient.listTools());
+ }
+
// ----------------------------------------------------
// Tomcat management
// ----------------------------------------------------
@@ -182,8 +208,10 @@ static class Sse implements Transport {
public Sse() {
transport = HttpServletSseServerTransportProvider.builder()
.messageEndpoint("/mcp/message")
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
McpServer.sync(transport)
.serverInfo("test-server", "1.0.0")
@@ -213,8 +241,10 @@ static class StreamableHttp implements Transport {
public StreamableHttp() {
transport = HttpServletStreamableServerTransportProvider.builder()
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
McpServer.sync(transport)
.serverInfo("test-server", "1.0.0")
@@ -245,8 +275,10 @@ static class Stateless implements Transport {
public Stateless() {
transport = HttpServletStatelessServerTransport.builder()
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
McpServer.sync(transport)
.serverInfo("test-server", "1.0.0")
@@ -275,18 +307,33 @@ static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer
private String originHeader = null;
+ private String hostHeader = null;
+
@Override
public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body,
McpTransportContext context) {
if (originHeader != null) {
builder.header("Origin", originHeader);
}
+ if (hostHeader != null) {
+ // HttpClient normally sets Host automatically, but we can override it
+ builder.header("Host", hostHeader);
+ }
}
public void setOriginHeader(String originHeader) {
this.originHeader = originHeader;
}
+ public void setHostHeader(String hostHeader) {
+ this.hostHeader = hostHeader;
+ }
+
+ public void reset() {
+ this.originHeader = null;
+ this.hostHeader = null;
+ }
+
}
}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java
index 06e1286d2..1a331bed5 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java
@@ -57,6 +57,8 @@ public class WebFluxServerTransportSecurityIntegrationTests {
private static final String DISALLOWED_ORIGIN = "https://malicious.example.com";
+ private static final String DISALLOWED_HOST = "malicious.example.com:8080";
+
@Parameter
private static Transport transport;
@@ -78,7 +80,7 @@ static void afterAll() {
private McpSyncClient mcpClient;
- private final TestOriginHeaderExchangeFilterFunction exchangeFilterFunction = new TestOriginHeaderExchangeFilterFunction();
+ private final TestHeaderExchangeFilterFunction exchangeFilterFunction = new TestHeaderExchangeFilterFunction();
@BeforeEach
void setUp() {
@@ -124,6 +126,29 @@ void messageOriginNotAllowed() {
assertThatThrownBy(() -> mcpClient.listTools());
}
+ @Test
+ void hostAllowed() {
+ // Host header is set by default by WebClient to the request URI host
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void connectHostNotAllowed() {
+ exchangeFilterFunction.setHostHeader(DISALLOWED_HOST);
+ assertThatThrownBy(() -> mcpClient.initialize());
+ }
+
+ @Test
+ void messageHostNotAllowed() {
+ mcpClient.initialize();
+ exchangeFilterFunction.setHostHeader(DISALLOWED_HOST);
+ assertThatThrownBy(() -> mcpClient.listTools());
+ }
+
// ----------------------------------------------------
// Server management
// ----------------------------------------------------
@@ -164,7 +189,7 @@ static Stream transports() {
*/
interface Transport {
- McpSyncClient createMcpClient(String baseUrl, TestOriginHeaderExchangeFilterFunction customizer);
+ McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction customizer);
RouterFunction> routerFunction();
@@ -180,8 +205,10 @@ static class Sse implements Transport {
public Sse() {
transportProvider = WebFluxSseServerTransportProvider.builder()
.messageEndpoint("/mcp/message")
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
McpServer.sync(transportProvider)
.serverInfo("test-server", "1.0.0")
@@ -190,8 +217,7 @@ public Sse() {
}
@Override
- public McpSyncClient createMcpClient(String baseUrl,
- TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) {
+ public McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction exchangeFilterFunction) {
var transport = WebFluxSseClientTransport
.builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction))
.jsonMapper(McpJsonMapper.getDefault())
@@ -212,8 +238,10 @@ static class StreamableHttp implements Transport {
public StreamableHttp() {
transportProvider = WebFluxStreamableServerTransportProvider.builder()
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
McpServer.sync(transportProvider)
.serverInfo("test-server", "1.0.0")
@@ -222,8 +250,7 @@ public StreamableHttp() {
}
@Override
- public McpSyncClient createMcpClient(String baseUrl,
- TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) {
+ public McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction exchangeFilterFunction) {
var transport = WebClientStreamableHttpTransport
.builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction))
.jsonMapper(McpJsonMapper.getDefault())
@@ -245,8 +272,10 @@ static class Stateless implements Transport {
public Stateless() {
transportProvider = WebFluxStatelessServerTransport.builder()
- .securityValidator(
- DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .securityValidator(DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build())
.build();
McpServer.sync(transportProvider)
.serverInfo("test-server", "1.0.0")
@@ -255,8 +284,7 @@ public Stateless() {
}
@Override
- public McpSyncClient createMcpClient(String baseUrl,
- TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) {
+ public McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction exchangeFilterFunction) {
var transport = WebClientStreamableHttpTransport
.builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction))
.jsonMapper(McpJsonMapper.getDefault())
@@ -272,18 +300,30 @@ public RouterFunction> routerFunction() {
}
- static class TestOriginHeaderExchangeFilterFunction implements ExchangeFilterFunction {
+ static class TestHeaderExchangeFilterFunction implements ExchangeFilterFunction {
private String origin = null;
+ private String host = null;
+
public void setOriginHeader(String origin) {
this.origin = origin;
}
+ public void setHostHeader(String host) {
+ this.host = host;
+ }
+
@Override
public Mono filter(ClientRequest request, ExchangeFunction next) {
- var updatedRequest = ClientRequest.from(request).header("origin", this.origin).build();
- return next.exchange(updatedRequest);
+ var builder = ClientRequest.from(request);
+ if (this.origin != null) {
+ builder.header("Origin", this.origin);
+ }
+ if (this.host != null) {
+ builder.header("Host", this.host);
+ }
+ return next.exchange(builder.build());
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java
index 9615547d3..23b0e7e37 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java
@@ -62,6 +62,8 @@ public class ServerTransportSecurityIntegrationTests {
private static final String DISALLOWED_ORIGIN = "https://malicious.example.com";
+ private static final String DISALLOWED_HOST = "malicious.example.com:8080";
+
@Parameter
private static Class> configClass;
@@ -89,6 +91,7 @@ static void afterAll() {
void setUp() {
mcpClient = tomcatServer.appContext().getBean(McpSyncClient.class);
requestCustomizer = tomcatServer.appContext().getBean(TestRequestCustomizer.class);
+ requestCustomizer.reset();
}
@AfterEach
@@ -130,6 +133,29 @@ void messageOriginNotAllowed() {
assertThatThrownBy(() -> mcpClient.listTools());
}
+ @Test
+ void hostAllowed() {
+ // Host header is set by default by HttpClient to the request URI host
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void connectHostNotAllowed() {
+ requestCustomizer.setHostHeader(DISALLOWED_HOST);
+ assertThatThrownBy(() -> mcpClient.initialize());
+ }
+
+ @Test
+ void messageHostNotAllowed() {
+ mcpClient.initialize();
+ requestCustomizer.setHostHeader(DISALLOWED_HOST);
+ assertThatThrownBy(() -> mcpClient.listTools());
+ }
+
// ----------------------------------------------------
// Tomcat management
// ----------------------------------------------------
@@ -194,7 +220,10 @@ TestRequestCustomizer requestCustomizer() {
@Bean
DefaultServerTransportSecurityValidator validator() {
- return DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build();
+ return DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:*")
+ .allowedHost("localhost:*")
+ .build();
}
}
@@ -317,18 +346,32 @@ static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer
private String originHeader = null;
+ private String hostHeader = null;
+
@Override
public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body,
McpTransportContext context) {
if (originHeader != null) {
builder.header("Origin", originHeader);
}
+ if (hostHeader != null) {
+ builder.header("Host", hostHeader);
+ }
}
public void setOriginHeader(String originHeader) {
this.originHeader = originHeader;
}
+ public void setHostHeader(String hostHeader) {
+ this.hostHeader = hostHeader;
+ }
+
+ public void reset() {
+ this.originHeader = null;
+ this.hostHeader = null;
+ }
+
}
}