假设我们有关键词abcd、ce,待匹配的文本为abcefg。
普通Trie树的匹配流程是:从第一个字符a开始匹配,进入关键词abcd所在的分支;匹配到第四个字符e时,该分支中c节点的子节点里没有字符e,于是只能把当前节点重置为root,再从第二个字符b重新开始匹配。这样一来,b和c这两个字符就会在Trie树中被重复匹配一次。
Fail指针的作用就是让匹配不走回头路,直接继续匹配e及后续字符。我们已经在文本中匹配了abc这条支线,如果同时能知道Trie树中是否存在以bc或c开头的关键词支线呢?如果知道,就不必再从root节点重新匹配b和c,而是直接定位到bc或c支线的c节点上继续匹配e。因此,abc支线中c节点的Fail指针指向bc支线的c节点(优先取最长的后缀);若bc支线不存在,则指向c支线的c节点;若两者都不存在,则指向root节点。匹配失败时,就从c的Fail节点继续匹配e。在本例中,Trie里没有bc支线,但有关键词ce的c节点,所以c的Fail指针指向ce中的c,随后成功匹配e,命中关键词ce。
/// <summary>
/// A single state of the Aho-Corasick trie.
/// </summary>
public class AcTrieNode
{
    /// <summary>
    /// Outgoing edges of this state, keyed by the next input character.
    /// </summary>
    public Dictionary<char, AcTrieNode> Children { get; set; } = new Dictionary<char, AcTrieNode>();

    /// <summary>
    /// Failure link: the state to continue from when no child matches the next character.
    /// Starts out as <c>null</c> until the automaton assigns it.
    /// </summary>
    public AcTrieNode? Fail { get; set; }

    /// <summary>
    /// Keywords recognised on reaching this state: the keyword ending here (if any) plus
    /// the outputs inherited along the failure link.
    /// </summary>
    public List<string> Output { get; set; } = new List<string>();
}
/// <summary>
/// Aho-Corasick automaton: a keyword trie plus failure links. It finds every keyword
/// occurrence in a text in one left-to-right pass and wraps the matched characters in
/// highlight markup.
/// </summary>
/// <remarks>
/// Usage: call <see cref="Insert"/> for each keyword, then <see cref="BuildFailurePointers"/>,
/// then <see cref="Search"/>. Failure links and inherited outputs are only computed by
/// <see cref="BuildFailurePointers"/>, so re-run it after inserting more keywords.
/// Re-running it, or inserting a keyword twice, does not cause double counting.
/// </remarks>
public class AcAutomaton
{
    private readonly AcTrieNode _root;

    public AcAutomaton()
    {
        _root = new AcTrieNode();
    }

    /// <summary>
    /// Adds a keyword to the trie. Inserting an already-present keyword is a no-op.
    /// </summary>
    /// <param name="word">The keyword to add.</param>
    /// <exception cref="ArgumentNullException"><paramref name="word"/> is null.</exception>
    public void Insert(string word)
    {
        if (word is null)
        {
            throw new ArgumentNullException(nameof(word));
        }

        var current = _root;
        foreach (var c in word)
        {
            if (!current.Children.TryGetValue(c, out var next))
            {
                next = new AcTrieNode();
                current.Children[c] = next;
            }
            current = next;
        }

        // The keyword ends at this node; guard against duplicate inserts being counted twice.
        AddOutput(current, word);
    }

    /// <summary>
    /// Assigns a failure link to every node (the root's link stays null), level by level.
    /// Each node also inherits the Output of its failure target.
    /// </summary>
    public void BuildFailurePointers()
    {
        var queue = new Queue<AcTrieNode>();

        // Depth-1 nodes: the only proper suffix of a one-character path is the empty path (root).
        foreach (var node in _root.Children.Values)
        {
            node.Fail = _root;
            queue.Enqueue(node);
        }

        // Breadth-first order ensures a node's Fail and Output are final before any deeper
        // node inherits from it.
        while (queue.Count > 0)
        {
            var current = queue.Dequeue();
            foreach (var entry in current.Children)
            {
                var c = entry.Key;
                var child = entry.Value;

                // Walk the parent's failure chain until some node has an edge labelled c.
                // TryGetValue leaves target null when it fails, so target is null iff none was found.
                var fail = current.Fail;
                AcTrieNode? target = null;
                while (fail != null && !fail.Children.TryGetValue(c, out target))
                {
                    fail = fail.Fail;
                }

                if (target is null)
                {
                    child.Fail = _root;
                }
                else
                {
                    child.Fail = target;
                    // Every keyword ending at target is a suffix of child's path, so it also ends at child.
                    // Add only missing entries so a rebuild does not duplicate them.
                    foreach (var keyword in target.Output)
                    {
                        AddOutput(child, keyword);
                    }
                }

                queue.Enqueue(child);
            }
        }
    }

    /// <summary>
    /// Scans <paramref name="text"/> and returns it with matched characters wrapped in
    /// highlight tags. Also returns a map from each matched keyword to its number of
    /// occurrences; overlapping occurrences are each counted.
    /// </summary>
    /// <param name="text">The text to scan.</param>
    /// <returns>The highlighted text and the keyword occurrence counts.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    public (string highlightedText, Dictionary<string, int> matchedKeywords) Search(string text)
    {
        if (text is null)
        {
            throw new ArgumentNullException(nameof(text));
        }

        var node = _root;
        var length = text.Length;
        var highlight = new bool[length];
        var keywordMatches = new Dictionary<string, int>();
        for (var i = 0; i < length; i++)
        {
            var c = text[i];

            // Follow failure links until some state has an edge for c. next stays null when even
            // the root has no such edge (the chain ends at root.Fail == null).
            AcTrieNode? state = node;
            AcTrieNode? next = null;
            while (state != null && !state.Children.TryGetValue(c, out next))
            {
                state = state.Fail;
            }

            if (next is null)
            {
                // No keyword continues with c: restart from the root for the next character.
                node = _root;
                continue;
            }

            node = next;
            MarkKeyword(highlight, keywordMatches, i, node.Output);
        }

        var highlightedText = Highlight(text, highlight);
        return (highlightedText, keywordMatches);
    }

    /// <summary>
    /// Appends <paramref name="keyword"/> to the node's Output unless it is already present.
    /// </summary>
    private static void AddOutput(AcTrieNode node, string keyword)
    {
        if (!node.Output.Contains(keyword))
        {
            node.Output.Add(keyword);
        }
    }

    /// <summary>
    /// Records every keyword in <paramref name="output"/> as ending at <paramref name="index"/>.
    /// </summary>
    private void MarkKeyword(bool[] highlight, Dictionary<string, int> keywordMatches, int index, List<string> output)
    {
        foreach (var keyword in output)
        {
            MarkKeyword(highlight, keywordMatches, index, keyword);
        }
    }

    /// <summary>
    /// Increments the occurrence count of <paramref name="keyword"/>. Flags the characters it
    /// covers, which end at <paramref name="index"/>, for highlighting.
    /// </summary>
    private void MarkKeyword(bool[] highlight, Dictionary<string, int> keywordMatches, int index, string keyword)
    {
        keywordMatches.TryGetValue(keyword, out var count);
        keywordMatches[keyword] = count + 1;

        var startPos = index - keyword.Length + 1;
        for (var k = startPos; k <= index; k++)
        {
            highlight[k] = true;
        }
    }

    /// <summary>
    /// Opening tag inserted before each highlighted run.
    /// </summary>
    private const string HighlightPrefix = "<b>";

    /// <summary>
    /// Closing tag inserted after each highlighted run.
    /// </summary>
    private const string HighlightSuffix = "</b>";

    /// <summary>
    /// Wraps each maximal run of flagged characters in <see cref="HighlightPrefix"/> /
    /// <see cref="HighlightSuffix"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the text itself is not HTML-escaped. If it can contain untrusted markup,
    /// escape it before display.
    /// </remarks>
    private string Highlight(string text, bool[] highlight)
    {
        var length = text.Length;
        var sb = new StringBuilder();
        var inHighlight = false;
        for (var i = 0; i < length; i++)
        {
            if (highlight[i] && !inHighlight)
            {
                sb.Append(HighlightPrefix);
                inHighlight = true;
            }
            else if (!highlight[i] && inHighlight)
            {
                sb.Append(HighlightSuffix);
                inHighlight = false;
            }
            sb.Append(text[i]);
        }

        if (inHighlight)
        {
            sb.Append(HighlightSuffix);
        }
        return sb.ToString();
    }
}
/**
* AC TireNode
*
*/
public class AcTrieNode {
private Map<Character, AcTrieNode> children = new HashMap<>();
private AcTrieNode fail;
private List<String> outPut = new ArrayList<>();
public Map<Character, AcTrieNode> getChildren() {
return children;
}
public void putIfAbsent(char c, AcTrieNode node) {
this.children.putIfAbsent(c, node);
}
public AcTrieNode getChildren(char c) {
return children.get(c);
}
public boolean hasChild(char c) {
return children.containsKey(c);
}
public AcTrieNode getFail() {
return fail;
}
public void setFail(AcTrieNode fail) {
this.fail = fail;
}
public List<String> getOutPut() {
return outPut;
}
public void addOutput(String keyword) {
this.outPut.add(keyword);
}
public void addAll(List<String> output) {
this.outPut.addAll(output);
}
}
/**
 * Aho-Corasick automaton: a keyword trie plus failure links. It finds every keyword
 * occurrence in a text in a single left-to-right pass.
 *
 * <p>{@link #buildFailPointer()} must run after the last {@link #insert(String)};
 * {@link #init(List)} does both. Re-running it, or inserting a keyword twice, does not
 * duplicate output entries.
 */
public class AcAutomaton {
    private final AcTrieNode root;

    public AcAutomaton() {
        this.root = new AcTrieNode();
    }

    public AcAutomaton(List<String> keywords) {
        this();
        init(keywords);
    }

    /**
     * Inserts all keywords into the trie, then builds the Fail pointers.
     *
     * @param keywords keywords to add
     */
    public void init(List<String> keywords) {
        for (String keyword : keywords) {
            insert(keyword);
        }
        buildFailPointer();
    }

    /**
     * Adds a keyword to the trie. Call {@link #buildFailPointer()} afterwards.
     *
     * @param word keyword to add
     */
    public void insert(String word) {
        AcTrieNode node = root;
        for (char c : word.toCharArray()) {
            // Allocate a new node only when the edge is actually missing.
            node = node.getChildren().computeIfAbsent(c, k -> new AcTrieNode());
        }
        addOutputIfAbsent(node, word);
    }

    /**
     * Builds Fail pointers breadth-first. Each node also inherits the output list of its
     * failure target.
     */
    public void buildFailPointer() {
        Queue<AcTrieNode> queue = new LinkedList<>();
        // Depth-1 nodes always fail back to the root.
        for (Map.Entry<Character, AcTrieNode> entry : root.getChildren().entrySet()) {
            entry.getValue().setFail(root);
            queue.offer(entry.getValue());
        }
        // BFS ensures a node's Fail and outputs are final before deeper nodes inherit them.
        while (!queue.isEmpty()) {
            AcTrieNode node = queue.poll();
            for (Map.Entry<Character, AcTrieNode> entry : node.getChildren().entrySet()) {
                char c = entry.getKey();
                AcTrieNode childNode = entry.getValue();
                // Walk the parent's failure chain until a node has an edge labelled c.
                AcTrieNode failNode = node.getFail();
                while (failNode != null && !failNode.hasChild(c)) {
                    failNode = failNode.getFail();
                }
                if (failNode == null) {
                    childNode.setFail(root);
                } else {
                    AcTrieNode childFailNode = failNode.getChildren(c);
                    childNode.setFail(childFailNode);
                    // Keywords ending at the fail target are suffixes of this path, so they also end here.
                    for (String keyword : childFailNode.getOutPut()) {
                        addOutputIfAbsent(childNode, keyword);
                    }
                }
                queue.offer(childNode);
            }
        }
    }

    /**
     * Scans the text and marks every character covered by a matched keyword.
     *
     * @param text text to scan
     * @return the highlighted text and the set of matched keywords
     */
    public HighlighterResult highlight(String text) {
        AcTrieNode node = root;
        int textLength = text.length();
        boolean[] highlight = new boolean[textLength];
        Set<String> matchedKeywords = new HashSet<>();
        for (int i = 0; i < textLength; i++) {
            char c = text.charAt(i);
            while (node != null && !node.hasChild(c)) {
                // No edge for c: switch branches via the Fail pointer, up to and past the root.
                node = node.getFail();
            }
            if (node == null) {
                // Even the root has no edge for c: restart from the root on the next character.
                node = root;
                continue;
            }
            node = node.getChildren(c);
            markHighlight(node, i, highlight, matchedKeywords);
        }
        return new HighlighterResult(text, highlight, matchedKeywords);
    }

    /**
     * Records each keyword in the node's output list and flags the characters it covers,
     * which end at {@code index}.
     */
    private void markHighlight(AcTrieNode node, int index, boolean[] highlight, Set<String> matchedKeywords) {
        for (String s : node.getOutPut()) {
            matchedKeywords.add(s);
            int startPos = index - s.length() + 1;
            for (int k = startPos; k <= index; k++) {
                highlight[k] = true;
            }
        }
    }

    /** Appends {@code keyword} to the node's output list unless it is already present. */
    private static void addOutputIfAbsent(AcTrieNode node, String keyword) {
        if (!node.getOutPut().contains(keyword)) {
            node.addOutput(keyword);
        }
    }
}
/**
* 高亮返回对象
*
*/
public class HighlighterResult {
private String highlighterText;
private Set<String> matchedKeywords;
public HighlighterResult(String text, boolean[] highlight, Set<String> matchedKeywords) {
this.highlighterText = doHighlight(text, highlight);
this.matchedKeywords = matchedKeywords;
}
public String getHighlighterText() {
return highlighterText;
}
public Set<String> getMatchedKeywords() {
return matchedKeywords;
}
/**
* 高亮前缀
*/
private static final String HIGHLIGHTER_PREFIX = "<span style=\"color:red\">";
/**
* 高亮后缀
*/
private static final String HIGHLIGHTER_SUFFIX = "</span>";
/**
* 根据标识高亮文本
*
*/
private String doHighlight(String text,boolean[] highlight) {
StringBuilder sb = new StringBuilder();
boolean inHighlight = false;
for (int i = 0; i < text.length(); i++) {
if (highlight[i] && !inHighlight) {
sb.append(HIGHLIGHTER_PREFIX);
inHighlight = true;
} else if (!highlight[i] && inHighlight) {
sb.append(HIGHLIGHTER_SUFFIX);
}
sb.append(text.charAt(i));
}
if (inHighlight) {
sb.append(HIGHLIGHTER_SUFFIX);
}
return sb.toString();
}
}