diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java index 411c8ecc5..3d162a5de 100644 --- a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java +++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java @@ -67,8 +67,10 @@ public static void main(String[] args) throws Exception { .builder() .mcpEndpoint(MCP_ENDPOINT) .keepAliveInterval(Duration.ofSeconds(30)) - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); // Build server with all conformance test features diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java index 5321aada7..db1b4f75e 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -12,10 +12,11 @@ /** * Default implementation of {@link ServerTransportSecurityValidator} that validates the - * Origin header against a list of allowed origins. + * Origin and Host headers against lists of allowed values. * *

- * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*"). + * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*" for + * origins, "example.com:*" for hosts). * * @author Daniel Garnier-Moiroux * @see ServerTransportSecurityValidator @@ -25,32 +26,55 @@ public class DefaultServerTransportSecurityValidator implements ServerTransportS private static final String ORIGIN_HEADER = "Origin"; + private static final String HOST_HEADER = "Host"; + private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403, "Invalid Origin header"); + private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421, + "Invalid Host header"); + private final List allowedOrigins; + private final List allowedHosts; + /** - * Creates a new validator with the specified allowed origins. + * Creates a new validator with the specified allowed origins and hosts. * @param allowedOrigins List of allowed origin patterns. Supports exact matches * (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*") + * @param allowedHosts List of allowed host patterns. Supports exact matches (e.g., + * "example.com:8080") and wildcard ports (e.g., "example.com:*") */ - public DefaultServerTransportSecurityValidator(List allowedOrigins) { + public DefaultServerTransportSecurityValidator(List allowedOrigins, List allowedHosts) { Assert.notNull(allowedOrigins, "allowedOrigins must not be null"); + Assert.notNull(allowedHosts, "allowedHosts must not be null"); this.allowedOrigins = allowedOrigins; + this.allowedHosts = allowedHosts; } @Override public void validateHeaders(Map> headers) throws ServerTransportSecurityException { + boolean missingHost = true; for (Map.Entry> entry : headers.entrySet()) { if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) { List values = entry.getValue(); - if (values != null && !values.isEmpty()) { - validateOrigin(values.get(0)); + if (values == null || values.isEmpty()) { + throw INVALID_ORIGIN; + } + validateOrigin(values.get(0)); + } + else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) { + missingHost = false; + List values = entry.getValue(); + if (values == null || values.isEmpty()) { + throw INVALID_HOST; } - break; + validateHost(values.get(0)); } } + if (!allowedHosts.isEmpty() && missingHost) { + throw INVALID_HOST; + } } /** @@ -82,6 +106,37 @@ else if (allowed.endsWith(":*")) { throw INVALID_ORIGIN; } + /** + * Validates a single host value against the allowed hosts. + * @param host The host header value, or null if not present + * @throws ServerTransportSecurityException if the host is not allowed + */ + private void validateHost(String host) throws ServerTransportSecurityException { + if (allowedHosts.isEmpty()) { + return; + } + + // Host is required + if (host == null || host.isBlank()) { + throw INVALID_HOST; + } + + for (String allowed : allowedHosts) { + if (allowed.equals(host)) { + return; + } + else if (allowed.endsWith(":*")) { + // Wildcard port pattern: "example.com:*" + String baseHost = allowed.substring(0, allowed.length() - 2); + if (host.equals(baseHost) || host.startsWith(baseHost + ":")) { + return; + } + } + } + + throw INVALID_HOST; + } + /** * Creates a new builder for constructing a DefaultServerTransportSecurityValidator. * @return A new builder instance @@ -97,6 +152,8 @@ public static class Builder { private final List allowedOrigins = new ArrayList<>(); + private final List allowedHosts = new ArrayList<>(); + /** * Adds an allowed origin pattern. * @param origin The origin to allow (e.g., "http://localhost:8080" or @@ -119,12 +176,33 @@ public Builder allowedOrigins(List origins) { return this; } + /** + * Adds an allowed host pattern. + * @param host The host to allow (e.g., "localhost:8080" or "example.com:*") + * @return this builder instance + */ + public Builder allowedHost(String host) { + this.allowedHosts.add(host); + return this; + } + + /** + * Adds multiple allowed host patterns. + * @param hosts The hosts to allow + * @return this builder instance + */ + public Builder allowedHosts(List hosts) { + Assert.notNull(hosts, "hosts must not be null"); + this.allowedHosts.addAll(hosts); + return this; + } + /** * Builds the validator instance. * @return A new DefaultServerTransportSecurityValidator */ public DefaultServerTransportSecurityValidator build() { - return new DefaultServerTransportSecurityValidator(allowedOrigins); + return new DefaultServerTransportSecurityValidator(allowedOrigins, allowedHosts); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java index 7e1593e1b..d4cf8582d 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java @@ -22,6 +22,9 @@ class DefaultServerTransportSecurityValidatorTests { private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403, "Invalid Origin header"); + private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421, + "Invalid Host header"); + private final DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() .allowedOrigin("http://localhost:8080") .build(); @@ -31,161 +34,370 @@ void builder() { assertThatCode(() -> DefaultServerTransportSecurityValidator.builder().build()).doesNotThrowAnyException(); assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedOrigins(null).build()) .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedHosts(null).build()) + .isInstanceOf(IllegalArgumentException.class); } - @Test - void originHeaderMissing() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); - } + @Nested + class OriginHeader { - @Test - void originHeaderListEmpty() { - assertThatCode(() -> validator.validateHeaders(Map.of("Origin", List.of()))).doesNotThrowAnyException(); - } + @Test + void originHeaderMissing() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } - @Test - void caseInsensitive() { - var headers = Map.of("origin", List.of("http://localhost:8080")); + @Test + void originHeaderListEmpty() { + assertThatThrownBy(() -> validator.validateHeaders(Map.of("Origin", List.of()))).isEqualTo(INVALID_ORIGIN); + } - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); - } + @Test + void caseInsensitive() { + var headers = Map.of("origin", List.of("http://localhost:8080")); - @Test - void exactMatch() { - var headers = originHeader("http://localhost:8080"); + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); - } + @Test + void exactMatch() { + var headers = originHeader("http://localhost:8080"); - @Test - void differentPort() { + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } - var headers = originHeader("http://localhost:3000"); + @Test + void differentPort() { - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); - } + var headers = originHeader("http://localhost:3000"); - @Test - void differentHost() { + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } - var headers = originHeader("http://example.com:8080"); + @Test + void differentHost() { - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); - } + var headers = originHeader("http://example.com:8080"); - @Test - void differentScheme() { + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentScheme() { + + var headers = originHeader("https://localhost:8080"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Nested + class WildcardPort { + + private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:*") + .build(); + + @Test + void anyPortWithWildcard() { + var headers = originHeader("http://localhost:3000"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void noPortWithWildcard() { + var headers = originHeader("http://localhost"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentPortWithWildcard() { + var headers = originHeader("http://localhost:8080"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHostWithWildcard() { + var headers = originHeader("http://example.com:3000"); + + assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentSchemeWithWildcard() { + var headers = originHeader("https://localhost:3000"); + + assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + } + + @Nested + class MultipleOrigins { + + DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("http://example.com:3000") + .allowedOrigin("http://myapp.example.com:*") + .build(); + + @Test + void matchingOneOfMultiple() { + var headers = originHeader("http://example.com:3000"); + + assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void matchingWildcardInMultiple() { + var headers = originHeader("http://myapp.example.com:9999"); + + assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void notMatchingAny() { + var headers = originHeader("http://malicious.example.com:1234"); - var headers = originHeader("https://localhost:8080"); + assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + } + + @Nested + class BuilderTests { + + @Test + void shouldAddMultipleOriginsWithAllowedOriginsMethod() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) + .build(); + + var headers = originHeader("http://example.com:3000"); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void shouldCombineAllowedOriginMethods() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) + .build(); + + assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) + .doesNotThrowAnyException(); + } + + } - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); } @Nested - class WildcardPort { + class HostHeader { - private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator + private final DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator .builder() - .allowedOrigin("http://localhost:*") + .allowedHost("localhost:8080") .build(); @Test - void anyPortWithWildcard() { - var headers = originHeader("http://localhost:3000"); + void notConfigured() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + @Test + void missing() { + assertThatThrownBy(() -> hostValidator.validateHeaders(new HashMap<>())).isEqualTo(INVALID_HOST); } @Test - void noPortWithWildcard() { - var headers = originHeader("http://localhost"); + void listEmpty() { + assertThatThrownBy(() -> hostValidator.validateHeaders(Map.of("Host", List.of()))).isEqualTo(INVALID_HOST); + } - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + @Test + void caseInsensitive() { + var headers = Map.of("host", List.of("localhost:8080")); + + assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); } @Test - void differentPortWithWildcard() { - var headers = originHeader("http://localhost:8080"); + void exactMatch() { + var headers = hostHeader("localhost:8080"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); } @Test - void differentHostWithWildcard() { - var headers = originHeader("http://example.com:3000"); + void differentPort() { + var headers = hostHeader("localhost:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); } @Test - void differentSchemeWithWildcard() { - var headers = originHeader("https://localhost:3000"); + void differentHost() { + var headers = hostHeader("example.com:8080"); + + assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Nested + class HostWildcardPort { + + private final DefaultServerTransportSecurityValidator wildcardHostValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedHost("localhost:*") + .build(); + + @Test + void anyPort() { + var headers = hostHeader("localhost:3000"); + + assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void noPort() { + var headers = hostHeader("localhost"); + + assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHost() { + var headers = hostHeader("example.com:3000"); + + assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + } + + @Nested + class MultipleHosts { + + DefaultServerTransportSecurityValidator multipleHostsValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedHost("example.com:3000") + .allowedHost("myapp.example.com:*") + .build(); + + @Test + void exactMatch() { + var headers = hostHeader("example.com:3000"); + + assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void wildcard() { + var headers = hostHeader("myapp.example.com:9999"); + + assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHost() { + var headers = hostHeader("malicious.example.com:3000"); + + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Test + void differentPort() { + var headers = hostHeader("localhost:8080"); + + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + } + + @Nested + class HostBuilderTests { + + @Test + void multipleHosts() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedHosts(List.of("localhost:8080", "example.com:*")) + .build(); + + assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:3000"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + .doesNotThrowAnyException(); + } + + @Test + void combined() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedHost("localhost:8080") + .allowedHosts(List.of("example.com:*", "test.com:3000")) + .build(); + + assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostHeader("test.com:3000"))).doesNotThrowAnyException(); + } - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); } } @Nested - class MultipleOrigins { + class CombinedOriginAndHostValidation { - DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator + private final DefaultServerTransportSecurityValidator combinedValidator = DefaultServerTransportSecurityValidator .builder() - .allowedOrigin("http://localhost:8080") - .allowedOrigin("http://example.com:3000") - .allowedOrigin("http://myapp.com:*") + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") .build(); @Test - void matchingOneOfMultiple() { - var headers = originHeader("http://example.com:3000"); + void bothValid() { + var header = headers("http://localhost:8080", "localhost:8080"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); } @Test - void matchingWildcardInMultiple() { - var headers = originHeader("http://myapp.com:9999"); + void originValidHostInvalid() { + var header = headers("http://localhost:8080", "malicious.example.com:8080"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); } @Test - void notMatchingAny() { - var headers = originHeader("http://malicious.example.com:1234"); + void originInvalidHostValid() { + var header = headers("http://malicious.example.com:8080", "localhost:8080"); - assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_ORIGIN); } - } - - @Nested - class BuilderTests { - @Test - void shouldAddMultipleOriginsWithAllowedOriginsMethod() { - DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() - .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) - .build(); - - var headers = originHeader("http://example.com:3000"); + void originMissingHostValid() { + // Origin missing is OK (same-origin request) + var header = headers(null, "localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); } @Test - void shouldCombineAllowedOriginMethods() { - DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() - .allowedOrigin("http://localhost:8080") - .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) - .build(); + void originValidHostMissing() { + // Host missing is NOT OK when allowedHosts is configured + var header = headers("http://localhost:8080", null); - assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) - .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) - .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) - .doesNotThrowAnyException(); + assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); } } @@ -194,4 +406,19 @@ private static Map> originHeader(String origin) { return Map.of("Origin", List.of(origin)); } + private static Map> hostHeader(String host) { + return Map.of("Host", List.of(host)); + } + + private static Map> headers(String origin, String host) { + var map = new HashMap>(); + if (origin != null) { + map.put("Origin", List.of(origin)); + } + if (host != null) { + map.put("Host", List.of(host)); + } + return map; + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java index e9e64c0d0..9f5b9b30d 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java @@ -48,6 +48,8 @@ class ServerTransportSecurityIntegrationTests { private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + private static final String DISALLOWED_HOST = "malicious.example.com:8080"; + @Parameter private static Transport transport; @@ -73,6 +75,7 @@ static void afterAll() { @BeforeEach void setUp() { + requestCustomizer.reset(); mcpClient = transport.createMcpClient(baseUrl, requestCustomizer); } @@ -115,6 +118,29 @@ void messageOriginNotAllowed() { assertThatThrownBy(() -> mcpClient.listTools()); } + @Test + void hostAllowed() { + // Host header is set by default by HttpClient to the request URI host + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectHostNotAllowed() { + requestCustomizer.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageHostNotAllowed() { + mcpClient.initialize(); + requestCustomizer.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.listTools()); + } + // ---------------------------------------------------- // Tomcat management // ---------------------------------------------------- @@ -182,8 +208,10 @@ static class Sse implements Transport { public Sse() { transport = HttpServletSseServerTransportProvider.builder() .messageEndpoint("/mcp/message") - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); McpServer.sync(transport) .serverInfo("test-server", "1.0.0") @@ -213,8 +241,10 @@ static class StreamableHttp implements Transport { public StreamableHttp() { transport = HttpServletStreamableServerTransportProvider.builder() - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); McpServer.sync(transport) .serverInfo("test-server", "1.0.0") @@ -245,8 +275,10 @@ static class Stateless implements Transport { public Stateless() { transport = HttpServletStatelessServerTransport.builder() - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); McpServer.sync(transport) .serverInfo("test-server", "1.0.0") @@ -275,18 +307,33 @@ static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer private String originHeader = null; + private String hostHeader = null; + @Override public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, McpTransportContext context) { if (originHeader != null) { builder.header("Origin", originHeader); } + if (hostHeader != null) { + // HttpClient normally sets Host automatically, but we can override it + builder.header("Host", hostHeader); + } } public void setOriginHeader(String originHeader) { this.originHeader = originHeader; } + public void setHostHeader(String hostHeader) { + this.hostHeader = hostHeader; + } + + public void reset() { + this.originHeader = null; + this.hostHeader = null; + } + } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java index 06e1286d2..1a331bed5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java @@ -57,6 +57,8 @@ public class WebFluxServerTransportSecurityIntegrationTests { private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + private static final String DISALLOWED_HOST = "malicious.example.com:8080"; + @Parameter private static Transport transport; @@ -78,7 +80,7 @@ static void afterAll() { private McpSyncClient mcpClient; - private final TestOriginHeaderExchangeFilterFunction exchangeFilterFunction = new TestOriginHeaderExchangeFilterFunction(); + private final TestHeaderExchangeFilterFunction exchangeFilterFunction = new TestHeaderExchangeFilterFunction(); @BeforeEach void setUp() { @@ -124,6 +126,29 @@ void messageOriginNotAllowed() { assertThatThrownBy(() -> mcpClient.listTools()); } + @Test + void hostAllowed() { + // Host header is set by default by WebClient to the request URI host + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectHostNotAllowed() { + exchangeFilterFunction.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageHostNotAllowed() { + mcpClient.initialize(); + exchangeFilterFunction.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.listTools()); + } + // ---------------------------------------------------- // Server management // ---------------------------------------------------- @@ -164,7 +189,7 @@ static Stream transports() { */ interface Transport { - McpSyncClient createMcpClient(String baseUrl, TestOriginHeaderExchangeFilterFunction customizer); + McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction customizer); RouterFunction routerFunction(); @@ -180,8 +205,10 @@ static class Sse implements Transport { public Sse() { transportProvider = WebFluxSseServerTransportProvider.builder() .messageEndpoint("/mcp/message") - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); McpServer.sync(transportProvider) .serverInfo("test-server", "1.0.0") @@ -190,8 +217,7 @@ public Sse() { } @Override - public McpSyncClient createMcpClient(String baseUrl, - TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) { + public McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction exchangeFilterFunction) { var transport = WebFluxSseClientTransport .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction)) .jsonMapper(McpJsonMapper.getDefault()) @@ -212,8 +238,10 @@ static class StreamableHttp implements Transport { public StreamableHttp() { transportProvider = WebFluxStreamableServerTransportProvider.builder() - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); McpServer.sync(transportProvider) .serverInfo("test-server", "1.0.0") @@ -222,8 +250,7 @@ public StreamableHttp() { } @Override - public McpSyncClient createMcpClient(String baseUrl, - TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) { + public McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction exchangeFilterFunction) { var transport = WebClientStreamableHttpTransport .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction)) .jsonMapper(McpJsonMapper.getDefault()) @@ -245,8 +272,10 @@ static class Stateless implements Transport { public Stateless() { transportProvider = WebFluxStatelessServerTransport.builder() - .securityValidator( - DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) .build(); McpServer.sync(transportProvider) .serverInfo("test-server", "1.0.0") @@ -255,8 +284,7 @@ public Stateless() { } @Override - public McpSyncClient createMcpClient(String baseUrl, - TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) { + public McpSyncClient createMcpClient(String baseUrl, TestHeaderExchangeFilterFunction exchangeFilterFunction) { var transport = WebClientStreamableHttpTransport .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction)) .jsonMapper(McpJsonMapper.getDefault()) @@ -272,18 +300,30 @@ public RouterFunction routerFunction() { } - static class TestOriginHeaderExchangeFilterFunction implements ExchangeFilterFunction { + static class TestHeaderExchangeFilterFunction implements ExchangeFilterFunction { private String origin = null; + private String host = null; + public void setOriginHeader(String origin) { this.origin = origin; } + public void setHostHeader(String host) { + this.host = host; + } + @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - var updatedRequest = ClientRequest.from(request).header("origin", this.origin).build(); - return next.exchange(updatedRequest); + var builder = ClientRequest.from(request); + if (this.origin != null) { + builder.header("Origin", this.origin); + } + if (this.host != null) { + builder.header("Host", this.host); + } + return next.exchange(builder.build()); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java index 9615547d3..23b0e7e37 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java @@ -62,6 +62,8 @@ public class ServerTransportSecurityIntegrationTests { private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + private static final String DISALLOWED_HOST = "malicious.example.com:8080"; + @Parameter private static Class configClass; @@ -89,6 +91,7 @@ static void afterAll() { void setUp() { mcpClient = tomcatServer.appContext().getBean(McpSyncClient.class); requestCustomizer = tomcatServer.appContext().getBean(TestRequestCustomizer.class); + requestCustomizer.reset(); } @AfterEach @@ -130,6 +133,29 @@ void messageOriginNotAllowed() { assertThatThrownBy(() -> mcpClient.listTools()); } + @Test + void hostAllowed() { + // Host header is set by default by HttpClient to the request URI host + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectHostNotAllowed() { + requestCustomizer.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageHostNotAllowed() { + mcpClient.initialize(); + requestCustomizer.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.listTools()); + } + // ---------------------------------------------------- // Tomcat management // ---------------------------------------------------- @@ -194,7 +220,10 @@ TestRequestCustomizer requestCustomizer() { @Bean DefaultServerTransportSecurityValidator validator() { - return DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build(); + return DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build(); } } @@ -317,18 +346,32 @@ static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer private String originHeader = null; + private String hostHeader = null; + @Override public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, McpTransportContext context) { if (originHeader != null) { builder.header("Origin", originHeader); } + if (hostHeader != null) { + builder.header("Host", hostHeader); + } } public void setOriginHeader(String originHeader) { this.originHeader = originHeader; } + public void setHostHeader(String hostHeader) { + this.hostHeader = hostHeader; + } + + public void reset() { + this.originHeader = null; + this.hostHeader = null; + } + } }