今天分享一个LeetCode题,题号是677,标题是键值映射,题目标签是字典树。
实现一个 MapSum 类里的两个方法,insert 和 sum。
对于方法 insert,你将得到一对(字符串,整数)的键值对。字符串表示键,整数表示值。如果键已经存在,那么原来的键值对将被替代成新的键值对。
对于方法 sum,你将得到一个表示前缀的字符串,你需要返回所有以该前缀开头的键的值的总和。
示例 1:
输入: insert("apple", 3), 输出: Null
输入: sum("ap"), 输出: 3
输入: insert("app", 2), 输出: Null
输入: sum("ap"), 输出: 5
此题没有说明字符串包含哪些字符,不然可以直接用直接寻址表,为保守起见,暂时先用哈希表存储字符。
我们首先设定字典树的数据结构应该如何实现的,如下图所示:
字典树
从根节点可以看到,每一个与根节点相邻的节点都包含一个字符,可以将字符设定成索引,按照字符索引去查找,时间复杂度降为O(1)。因为每一个节点都需要一个整数值,用来求和一棵子树的所有的值,可以设计成下面的代码:
class Node {
int value; // 整数值
TreeMap<Character,Node> next; // 查找时间复杂度为O(1)
}
TreeMap是关于红黑树的哈希表,用来存储字符的键和下一个节点的对象。
接着插入字符串,获取字符串里的字符可以通过charAt方法获取,或者直接转化为字符数组,遍历整个字符串的字符。
然后求和前缀子树的时候,先判断前缀字符串是否存在,如果不存在,则直接返回为0;如果存在,则进行深度优先遍历将所有的值累加起来,具体代码的执行动画如下:
视频大小:1.90M
import java.util.TreeMap;
class MapSum {
private class Node {
int value;
TreeMap<Character, Node> next;
Node(int value) {
this.value = value;
next = new TreeMap<>();
}
Node() {
this(0);
}
}
private Node root;
/**
* Initialize your data structure here.
*/
public MapSum() {
root = new Node();
}
public void insert(String key, int val) {
Node cur = root;
for (int i = 0; i < key.length(); i++) {
char c = key.charAt(i);
if (cur.next.get(c) == null)
cur.next.put(c, new Node());
cur = cur.next.get(c);
}
cur.value = val;
}
public int sum(String prefix) {
// 先判断prefix是否存在
Node cur = root;
for (int i = 0; i < prefix.length(); i++) {
char c = prefix.charAt(i);
if (cur.next.get(c) == null)
return 0;
cur = cur.next.get(c);
}
// 递归 将遍历的值相加起来
return sum(cur);
}
private int sum(Node node) {
int val = node.value;
for (char c : node.next.keySet()) {
val += sum(node.next.get(c));
}
return val;
}
}
/**
* Your MapSum object will be instantiated and called as such:
* MapSum obj = new MapSum();
* obj.insert(key,val);
* int param_2 = obj.sum(prefix);
*/
如果字符串里的字符都是小写字母,可以直接用直接寻址表保存26个字母,代码如下:
class MapSum {
private class Node {
int value;
Node[] next;
Node() {
value = 0;
next = new Node[26]; // 26个字符
}
}
private Node root;
/**
* Initialize your data structure here.
*/
public MapSum() {
root = new Node();
}
public void insert(String key, int val) {
Node cur = root;
for (char c : key.toCharArray()) {
if (cur.next[c - 'a'] == null)
cur.next[c - 'a'] = new Node();
cur = cur.next[c - 'a'];
}
cur.value = val;
}
public int sum(String prefix) {
// 先判断prefix是否存在
Node cur = root;
for (char c : prefix.toCharArray()) {
if (cur.next[c - 'a'] == null)
return 0;
cur = cur.next[c - 'a'];
}
// 递归 将遍历的值相加起来
return sum(cur);
}
private int sum(Node node) {
int val = node.value;
for (int i = 0; i < 26; i++) {
if (node.next[i] != null)
val += sum(node.next[i]);
}
return val;
}
}