kafka SslPrincipalMapper 源码

  • 2022-10-20
  • 浏览 (425)

kafka SslPrincipalMapper 代码

文件路径:/clients/src/main/java/org/apache/kafka/common/security/ssl/SslPrincipalMapper.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.common.security.ssl;

import java.io.IOException;
import java.util.List;
import java.util.ArrayList;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.DEFAULT_SSL_PRINCIPAL_MAPPING_RULES;

/**
 * Maps a peer's distinguished name (DN) to a principal name using an ordered, comma-separated list
 * of rules. Each rule is one of:
 * <ul>
 *   <li>{@code DEFAULT} - returns the distinguished name unchanged;</li>
 *   <li>{@code RULE:pattern/replacement/} optionally followed by {@code L} or {@code U} - if the
 *       whole name matches {@code pattern}, the replacement is applied and the result is optionally
 *       lower-cased ({@code L}) or upper-cased ({@code U}).</li>
 * </ul>
 * Empty entries are ignored. Rules are tried in order and the first one that matches wins.
 */
public class SslPrincipalMapper {

    /*
     * Grammar of a single rule. Capturing groups (as seen by RULE_PARSER):
     *   1 - the literal DEFAULT
     *   2 - the RULE pattern: backslash-escaped characters or any character other than '\' and '/'
     *   4 - the RULE replacement, with the same escaping as the pattern
     *   6 - the optional case flag, "L" (lower-case) or "U" (upper-case)
     *   7 - lazy catch-all; a rule that only matches here is either empty or invalid
     * Groups 3 and 5 are the repeated inner alternatives of groups 2 and 4.
     */
    private static final String RULE_PATTERN = "(DEFAULT)|RULE:((\\\\.|[^\\\\/])*)/((\\\\.|[^\\\\/])*)/([LU]?).*?|(.*?)";
    // Finds one rule at a time in a comma-separated list, tolerating whitespace around each rule.
    private static final Pattern RULE_SPLITTER = Pattern.compile("\\s*(" + RULE_PATTERN + ")\\s*(,\\s*|$)");
    private static final Pattern RULE_PARSER = Pattern.compile(RULE_PATTERN);

    private final List<Rule> rules;

    /**
     * Creates a mapper from a comma-separated list of rules.
     *
     * @param sslPrincipalMappingRules the rules to apply in order; {@code null} selects
     *                                 {@code DEFAULT_SSL_PRINCIPAL_MAPPING_RULES}
     * @throws IllegalArgumentException if any rule cannot be parsed
     */
    public SslPrincipalMapper(String sslPrincipalMappingRules) {
        this.rules = parseRules(splitRules(sslPrincipalMappingRules));
    }

    /**
     * Static factory equivalent to {@link #SslPrincipalMapper(String)}.
     *
     * @param sslPrincipalMappingRules the rules to apply in order, or {@code null} for the defaults
     * @return a new mapper for the given rules
     * @throws IllegalArgumentException if any rule cannot be parsed
     */
    public static SslPrincipalMapper fromRules(String sslPrincipalMappingRules) {
        return new SslPrincipalMapper(sslPrincipalMappingRules);
    }

    // Splits the raw rules string into individual, still unparsed, rule strings.
    private static List<String> splitRules(String sslPrincipalMappingRules) {
        if (sslPrincipalMappingRules == null) {
            sslPrincipalMappingRules = DEFAULT_SSL_PRINCIPAL_MAPPING_RULES;
        }

        List<String> result = new ArrayList<>();
        Matcher matcher = RULE_SPLITTER.matcher(sslPrincipalMappingRules.trim());
        while (matcher.find()) {
            result.add(matcher.group(1));
        }

        return result;
    }

    // Parses each rule string. Empty rules are dropped; anything else that is not DEFAULT or a
    // well-formed RULE (including trailing text after a valid prefix) is rejected.
    private static List<Rule> parseRules(List<String> rules) {
        List<Rule> result = new ArrayList<>();
        for (String rule : rules) {
            Matcher matcher = RULE_PARSER.matcher(rule);
            if (!matcher.lookingAt()) {
                throw new IllegalArgumentException("Invalid rule: " + rule);
            }
            if (rule.length() != matcher.end()) {
                throw new IllegalArgumentException("Invalid rule: `" + rule + "`, unmatched substring: `" + rule.substring(matcher.end()) + "`");
            }

            // empty rules are ignored: they only match the catch-all group 7
            if (matcher.group(1) != null) {
                result.add(new Rule());
            } else if (matcher.group(2) != null) {
                result.add(new Rule(matcher.group(2),
                                    matcher.group(4),
                                    "L".equals(matcher.group(6)),
                                    "U".equals(matcher.group(6))));
            }
        }

        return result;
    }

    /**
     * Maps a distinguished name to a principal name using the first rule that matches it.
     *
     * @param distinguishedName the name to map
     * @return the principal name produced by the first matching rule
     * @throws NoMatchingRule if no rule matches {@code distinguishedName}
     * @throws IOException declared for compatibility; only {@link NoMatchingRule} is thrown here
     */
    public String getName(String distinguishedName) throws IOException {
        for (Rule r : rules) {
            String principalName = r.apply(distinguishedName);
            if (principalName != null) {
                return principalName;
            }
        }
        throw new NoMatchingRule("No rules apply to " + distinguishedName + ", rules " + rules);
    }

    @Override
    public String toString() {
        return "SslPrincipalMapper(rules = " + rules + ")";
    }

    /** Thrown by {@link #getName(String)} when none of the configured rules match. */
    public static class NoMatchingRule extends IOException {
        NoMatchingRule(String msg) {
            super(msg);
        }
    }

    private static class Rule {
        // Matches "$" followed by digits, i.e. a numbered group reference in a replacement string.
        private static final Pattern BACK_REFERENCE_PATTERN = Pattern.compile("\\$(\\d+)");

        private final boolean isDefault;
        private final Pattern pattern;
        private final String replacement;
        private final boolean toLowerCase;
        private final boolean toUpperCase;

        // Creates the DEFAULT rule, which passes the distinguished name through unchanged.
        Rule() {
            isDefault = true;
            pattern = null;
            replacement = null;
            toLowerCase = false;
            toUpperCase = false;
        }

        Rule(String pattern, String replacement, boolean toLowerCase, boolean toUpperCase) {
            isDefault = false;
            this.pattern = pattern == null ? null : Pattern.compile(pattern);
            this.replacement = replacement;
            this.toLowerCase = toLowerCase;
            this.toUpperCase = toUpperCase;
        }

        /**
         * Applies this rule to a distinguished name.
         *
         * @param distinguishedName the name to map
         * @return the mapped name, or {@code null} if this rule's pattern does not match the whole name
         */
        String apply(String distinguishedName) {
            if (isDefault) {
                return distinguishedName;
            }

            final Matcher m = pattern.matcher(distinguishedName);
            if (!m.matches()) {
                return null;
            }

            // Matcher.replaceAll resets the matcher and substitutes every match found scanning from
            // the start - the same result as String.replaceAll with this regex, without recompiling it.
            String result = m.replaceAll(escapeLiteralBackReferences(replacement, m.groupCount()));
            if (toLowerCase) {
                result = result.toLowerCase(Locale.ENGLISH);
            } else if (toUpperCase) {
                result = result.toUpperCase(Locale.ENGLISH);
            }
            return result;
        }

        /**
         * Escapes numbered back references in {@code unescaped} that refer to a group the pattern
         * does not have, so that {@link Matcher#replaceAll(String)} emits them literally instead of
         * throwing. For example, with 3 capturing groups the replacement {@code $1@$4} becomes
         * {@code $1@\$4}, yielding group 1 followed by a literal "@$4".
         *
         * <p>Java resolves a reference such as {@code $123} by always taking the first digit and
         * then appending further digits only while they still form a legal group number. A
         * reference is therefore invalid exactly when its first digit exceeds the group count.
         * References starting with {@code $0} are always valid (group 0 always exists), and a
         * dollar already escaped with a backslash is not a reference; both are left unchanged.
         *
         * <p>Originally taken from Apache NiFi
         * ({@code org.apache.nifi.authorization.util.IdentityMappingUtil}). This version builds
         * the result in a single pass, so several invalid references are all escaped at the right
         * positions. It also escapes references when the pattern has no capturing groups, and it
         * does not parse digit runs as an int, so long runs cannot overflow.
         *
         * @param unescaped          the replacement string as written in the rule
         * @param numCapturingGroups the number of capturing groups in the rule's pattern
         * @return the replacement with every invalid back reference escaped
         */
        private static String escapeLiteralBackReferences(final String unescaped, final int numCapturingGroups) {
            final Matcher backRefMatcher = BACK_REFERENCE_PATTERN.matcher(unescaped);
            StringBuilder escaped = null; // allocated only once an escape is actually needed
            int copiedUpTo = 0;
            while (backRefMatcher.find()) {
                final int dollarIndex = backRefMatcher.start();
                final int firstDigit = backRefMatcher.group(1).charAt(0) - '0';
                if (firstDigit == 0 || isEscaped(unescaped, dollarIndex)) {
                    continue;
                }
                if (firstDigit > numCapturingGroups) {
                    if (escaped == null) {
                        escaped = new StringBuilder(unescaped.length() + 4);
                    }
                    // Copy from the ORIGINAL string so earlier insertions never shift offsets.
                    escaped.append(unescaped, copiedUpTo, dollarIndex).append('\\');
                    copiedUpTo = dollarIndex;
                }
            }
            if (escaped == null) {
                return unescaped;
            }
            escaped.append(unescaped, copiedUpTo, unescaped.length());
            return escaped.toString();
        }

        // True if the character at index is preceded by an odd number of consecutive backslashes,
        // i.e. it is already escaped in replacement-string syntax.
        private static boolean isEscaped(final String s, final int index) {
            int backslashes = 0;
            for (int i = index - 1; i >= 0 && s.charAt(i) == '\\'; i--) {
                backslashes++;
            }
            return (backslashes & 1) == 1;
        }

        @Override
        public String toString() {
            StringBuilder buf = new StringBuilder();
            if (isDefault) {
                buf.append("DEFAULT");
            } else {
                buf.append("RULE:");
                if (pattern != null) {
                    buf.append(pattern);
                }
                if (replacement != null) {
                    buf.append("/");
                    buf.append(replacement);
                }
                if (toLowerCase) {
                    buf.append("/L");
                } else if (toUpperCase) {
                    buf.append("/U");
                }
            }
            return buf.toString();
        }

    }
}

相关信息

kafka 源码目录

相关文章

kafka DefaultSslEngineFactory 源码

kafka SslFactory 源码

0  赞