ThreadLocal 源码分析

TOP 带着问题看源码

  1. ThreadLocal 是怎么保证不同线程内部的变量隔离的
  2. 你说了ThreadLocalMap,那它是如何解决Hash冲突的
  3. ThreadLocal 什么情况下会内存泄漏

1. 基本介绍

我们知道解决共享变量不安全的一种方式,就是利用每个线程的私有变量来操作,避免共享变量导致的线程不安全问题。

ThreadLocal 就是提供一个局部变量,不会遇到并发问题。

2. 成员变量 & 核心类分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// 计算hash值
private final int threadLocalHashCode = nextHashCode();
// 使用原子类记录hash值
private static AtomicInteger nextHashCode =
new AtomicInteger();
// 魔数,更好的分散数据
private static final int HASH_INCREMENT = 0x61c88647;

// Thread.class
// 每个线程类都会有一个 ThreadLocalMap
ThreadLocal.ThreadLocalMap threadLocals = null;

// java.lang.ThreadLocal.ThreadLocalMap
// 初始化容量
private static final int INITIAL_CAPACITY = 16;
// 存储数据数组
private Entry[] table;
// 元素个数
private int size = 0;
// 扩容的阈值
private int threshold;
// 构造方法,初始化值
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

回到问题 TOP 1 ,可以知道是不同线程都有自己的 ThreadLocalMap ,也就天然做到了隔离

2.1 ThreadLocal 内存泄漏的原因

通过对变量和核心类的分析,相信对 ThreadLocal 的一个结构有了大致的了解,接下来我们先来看下 ThreadLocal 的 map 是怎么定义的。

1
2
3
4
5
6
7
8
9
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
// 注意,这里 k 作为弱引用
// 原因是如果是强引用,我们如果把 threadlocal 置为 null 不再使用,但是其在线程中的 threadlocalmap还在,导致无法gc
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

回到问题 TOP 3 ,由此可以分析,在 gc 的情况下,k 会出现为 null,也就会出现 value 还在但是无法拿到的情况(内存泄漏)。

实际上在 ThreadLocal 中这个问题并不是想象的那么可怕,其核心方法基本都会对这种无效的数据进行清理。

3. 核心方法分析

3.1 set 数据

核心逻辑就是:

  1. 把数据放到当前线程的 ThreadLocalMap 的 value
  2. 如果当前 key 的位置已经有了就覆盖
  3. 如果当前位置的元素与当前 key 不相等,就插入下一个可以放置元素的地方
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取当前线程的 ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null)
// 这里this是指的调用这个方法的threadlocal对象
// 调用 set 方法
map.set(this, value);
else
// 还没有就创建一个map
createMap(t, value);
}

private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// 根据 hash 计算位置
int i = key.threadLocalHashCode & (len-1);
// 循环,如果第一次没有获取到 key 相同的,就循环下一个位置
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
// 获取 key
ThreadLocal<?> k = e.get();
// key 相同,直接覆盖
if (k == key) {
e.value = value;
return;
}
// key 为空就调用 replaceStaleEntry,见下文
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 走到这里说明目标位置是空的,构建元素放到存储数组中
tab[i] = new Entry(key, value);
// 增加元素个数
int sz = ++size;
// 如果不再有无用元素,并且容量超过了阈值,就扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

回到问题 TOP 2 ,可以知道其处理 hash 冲突采用的是开放寻址法,位置已被占就会找下一个。在数据量较少的场景,这个是很合适的。

3.2 get 数据

核心逻辑就是:

  1. 把当前线程的 ThreadLocalMap 的 value 取出来
  2. 如果按照 hash 查找到的 key 不一样,说明出现 hash 冲突了,调用 getEntryAfterMiss()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public T get() {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取当前线程的 ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
// 调用 getEntry 根据 key 查找到 value
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

private Entry getEntry(ThreadLocal<?> key) {
// 根据hash值确定位置,获取元素
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// 如果 key 不相同或者值为null,调用 getEntryAfterMiss
return getEntryAfterMiss(key, i, e);
}

3.3 辅助方法

3.3.1 getEntryAfterMiss

该方法主要是用来处理 hash 冲突的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
// 死循环
while (e != null) {
ThreadLocal<?> k = e.get();
// 当 key 也相同时候就返回
if (k == key)
return e;
// 如果 key 为 null 调用 expungeStaleEntry
if (k == null)
expungeStaleEntry(i);
else
// 获取下一个位置
i = nextIndex(i, len);
// 获取新的位置元素
e = tab[i];
}
return null;
}

3.3.2 expungeStaleEntry

核心逻辑就是:

  1. 清理无效的 entity
  2. 往后继续搜索和清理,直到 tab[i] == null 退出
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 删除无效元素
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
Entry e;
int i;
// 循环,直到 tab[i] == null 退出
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 如果再次发现 key 为 null 的都删掉
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// 处理 rehash 情况
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

3.3.3 cleanSomeSlots

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
// 获取下一个索引值
i = nextIndex(i, len);
Entry e = tab[i];
// 如果这个索引对应的 value 不为空 并且 key 是空的
if (e != null && e.get() == null) {
// 重置n为哈希表大小
n = len;
removed = true;
// 清理无效的 entity
// expungeStaleEntry我们前面也分析了,会往后找到所有无效的 entity
i = expungeStaleEntry(i);
}
// 每次搜索范围减少一半
} while ( (n >>>= 1) != 0);
return removed;
}

4. 总结

可以看到,ThreadLocal 在完成基本功能之外,做了很多辅助操作来避免内存泄漏,这种严谨的做法也值得我们在工作中来做。