Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ public static void main(String[] args) throws Exception {
.builder()
.mcpEndpoint(MCP_ENDPOINT)
.keepAliveInterval(Duration.ofSeconds(30))
.securityValidator(
DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
.securityValidator(DefaultServerTransportSecurityValidator.builder()
.allowedOrigin("http://localhost:*")
.allowedHost("localhost:*")
.build())
.build();

// Build server with all conformance test features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

/**
* Default implementation of {@link ServerTransportSecurityValidator} that validates the
* Origin header against a list of allowed origins.
* Origin and Host headers against lists of allowed values.
*
* <p>
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*").
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*" for
* origins, "example.com:*" for hosts).
*
* @author Daniel Garnier-Moiroux
* @see ServerTransportSecurityValidator
Expand All @@ -25,32 +26,55 @@ public class DefaultServerTransportSecurityValidator implements ServerTransportS

private static final String ORIGIN_HEADER = "Origin";

private static final String HOST_HEADER = "Host";

private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
"Invalid Origin header");

private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421,
"Invalid Host header");

private final List<String> allowedOrigins;

private final List<String> allowedHosts;

/**
* Creates a new validator with the specified allowed origins.
* Creates a new validator with the specified allowed origins and hosts.
* @param allowedOrigins List of allowed origin patterns. Supports exact matches
* (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
* @param allowedHosts List of allowed host patterns. Supports exact matches (e.g.,
* "example.com:8080") and wildcard ports (e.g., "example.com:*")
*/
public DefaultServerTransportSecurityValidator(List<String> allowedOrigins) {
public DefaultServerTransportSecurityValidator(List<String> allowedOrigins, List<String> allowedHosts) {
Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
Assert.notNull(allowedHosts, "allowedHosts must not be null");
this.allowedOrigins = allowedOrigins;
this.allowedHosts = allowedHosts;
}

@Override
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
boolean missingHost = true;
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
List<String> values = entry.getValue();
if (values != null && !values.isEmpty()) {
validateOrigin(values.get(0));
if (values == null || values.isEmpty()) {
throw INVALID_ORIGIN;
}
validateOrigin(values.get(0));
}
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
missingHost = false;
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw INVALID_HOST;
}
break;
validateHost(values.get(0));
}
}
if (!allowedHosts.isEmpty() && missingHost) {
throw INVALID_HOST;
}
}

/**
Expand Down Expand Up @@ -82,6 +106,37 @@ else if (allowed.endsWith(":*")) {
throw INVALID_ORIGIN;
}

/**
* Validates a single host value against the allowed hosts.
* @param host The host header value, or null if not present
* @throws ServerTransportSecurityException if the host is not allowed
*/
private void validateHost(String host) throws ServerTransportSecurityException {
if (allowedHosts.isEmpty()) {
return;
}

// Host is required
if (host == null || host.isBlank()) {
throw INVALID_HOST;
}

for (String allowed : allowedHosts) {
if (allowed.equals(host)) {
return;
}
else if (allowed.endsWith(":*")) {
// Wildcard port pattern: "example.com:*"
String baseHost = allowed.substring(0, allowed.length() - 2);
if (host.equals(baseHost) || host.startsWith(baseHost + ":")) {
return;
}
}
}

throw INVALID_HOST;
}

/**
* Creates a new builder for constructing a DefaultServerTransportSecurityValidator.
* @return A new builder instance
Expand All @@ -97,6 +152,8 @@ public static class Builder {

private final List<String> allowedOrigins = new ArrayList<>();

private final List<String> allowedHosts = new ArrayList<>();

/**
* Adds an allowed origin pattern.
* @param origin The origin to allow (e.g., "http://localhost:8080" or
Expand All @@ -119,12 +176,33 @@ public Builder allowedOrigins(List<String> origins) {
return this;
}

/**
* Adds an allowed host pattern.
* @param host The host to allow (e.g., "localhost:8080" or "example.com:*")
* @return this builder instance
*/
public Builder allowedHost(String host) {
this.allowedHosts.add(host);
return this;
}

/**
* Adds multiple allowed host patterns.
* @param hosts The hosts to allow
* @return this builder instance
*/
public Builder allowedHosts(List<String> hosts) {
Assert.notNull(hosts, "hosts must not be null");
this.allowedHosts.addAll(hosts);
return this;
}

/**
* Builds the validator instance.
* @return A new DefaultServerTransportSecurityValidator
*/
public DefaultServerTransportSecurityValidator build() {
return new DefaultServerTransportSecurityValidator(allowedOrigins);
return new DefaultServerTransportSecurityValidator(allowedOrigins, allowedHosts);
}

}
Expand Down
Loading