参考 ReentrantLock,自实现非公平锁
1. 源码
/**
 * A minimal reentrant, non-fair mutual-exclusion lock modelled on
 * {@code ReentrantLock}.
 *
 * <p>{@code state} is the hold count: 0 means free, n &gt; 0 means the owner
 * thread holds the lock n times. Threads that fail to acquire in {@link #lock()}
 * append a node to a linked wait queue behind a thread-less dummy head and park.
 * A full release unparks the node right after the head; the woken thread takes
 * itself out of the queue only once it actually owns the lock, so a waiter that
 * loses the race to a barging thread is still queued and is woken again by the
 * next release.
 *
 * <p>{@link #lockInterruptibly()} and {@link #tryLock(long, java.util.concurrent.TimeUnit)}
 * poll instead of queueing, so an abandoned wait never leaves a node behind.
 * Conditions are not supported.
 */
public class QueuedSynchronizer implements Lock {
    /** Pause between attempts for the polling (interruptible / timed) acquire paths. */
    private static final long POLL_INTERVAL_NANOS = 1_000_000L;

    /**
     * Thread currently holding the lock, or null when it is free. Only ever
     * written by the thread that holds the lock.
     */
    private transient Thread exclusiveOwnerThread;
    /** Hold count: 0 = free, n &gt; 0 = held n times by the owner. */
    private volatile int state;
    /** Dummy head of the wait queue; created lazily on first contention. */
    private transient volatile Node head;
    /** Last node of the wait queue. */
    private transient volatile Node tail;
    /** Whether to print trace output to stdout. */
    private transient volatile boolean isPrintLog = false;

    /** Creates a lock with trace output disabled. */
    public QueuedSynchronizer() {
    }

    /**
     * Creates a lock.
     *
     * @param isPrintLog true to print trace output for every lock/unlock step
     */
    public QueuedSynchronizer(boolean isPrintLog) {
        this.isPrintLog = isPrintLog;
    }

    /**
     * Acquires the lock, blocking until it is available. Reentrant: the owner
     * may call this again and must call {@link #unlock()} once per acquisition.
     */
    @Override
    public void lock() {
        // Non-fair fast path: whoever wins the CAS on state gets the lock.
        if (compareAndSetState(0, 1)) {
            setExclusiveOwnerThread(Thread.currentThread());
            log(" lock,cas success");
        } else {
            // Retry (handles reentrancy), then queue and block.
            acquire(1);
        }
    }

    /**
     * Acquires the lock unless the current thread is interrupted. Implemented by
     * polling rather than queueing.
     *
     * @throws InterruptedException if the thread is interrupted before or while waiting
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        pollAcquire(1, false, 0L);
    }

    /**
     * Acquires the lock only if it is free (or already held by the caller) at
     * the time of the call.
     *
     * @return true if the lock was acquired
     */
    @Override
    public boolean tryLock() {
        return tryAcquire(1);
    }

    /**
     * Acquires the lock if it becomes available within the given time.
     * Implemented by polling rather than queueing.
     *
     * @param time maximum time to wait; a non-positive value means a single attempt
     * @param unit unit of {@code time}
     * @return true if the lock was acquired, false if the time elapsed first
     * @throws InterruptedException if the thread is interrupted before or while waiting
     */
    @Override
    public boolean tryLock(long time, java.util.concurrent.TimeUnit unit) throws InterruptedException {
        return pollAcquire(1, true, unit.toNanos(time));
    }

    /**
     * Releases one hold on the lock.
     *
     * @throws IllegalMonitorStateException if the current thread does not hold the lock
     */
    @Override
    public void unlock() {
        release(1);
    }

    /**
     * Not supported by this lock.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public java.util.concurrent.locks.Condition newCondition() {
        throw new UnsupportedOperationException("conditions are not supported");
    }

    /** Slow path of {@link #lock()}: retry once, then enqueue and park until acquired. */
    private void acquire(int arg) {
        if (!tryAcquire(arg)) {
            log(" lock,try acquire fail");
            Node node = addWaiter();
            log(" lock,addWaiter success");
            boolean interrupted = false;
            for (;;) {
                Node pred = node.prev;
                // Only the node right behind the head competes. It is re-checked
                // after every wake-up, because park() may return spuriously and a
                // barging thread may have taken the lock first.
                if (pred == head && tryAcquire(arg)) {
                    // We own the lock now: our node becomes the new dummy head.
                    head = node;
                    node.thread = null;
                    node.prev = null;
                    pred.next = null; // let the old head be collected
                    break;
                }
                LockSupport.park(this);
                if (Thread.interrupted()) {
                    // lock() is not interruptible; remember and re-assert below.
                    interrupted = true;
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
        log(" lock,try acquire success");
    }

    /** Polling acquire used by the interruptible and timed entry points. */
    private boolean pollAcquire(int arg, boolean timed, long nanosTimeout) throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
        final long deadline = timed ? System.nanoTime() + nanosTimeout : 0L;
        for (;;) {
            if (tryAcquire(arg)) {
                return true;
            }
            long wait = POLL_INTERVAL_NANOS;
            if (timed) {
                long remaining = deadline - System.nanoTime();
                if (remaining <= 0L) {
                    return false;
                }
                wait = Math.min(wait, remaining);
            }
            LockSupport.parkNanos(this, wait);
            if (Thread.interrupted()) {
                throw new InterruptedException();
            }
        }
    }

    /** Drops {@code arg} holds and, once the lock is completely free, wakes the first waiter. */
    private void release(int arg) {
        if (!tryRelease(arg)) {
            // Still held reentrantly by this thread: nobody to wake yet.
            return;
        }
        if (isPrintLog) {
            System.out.println(Thread.currentThread().getName() + " unlock,try release success,state:" + state);
        }
        Node h = head;
        if (h == null) {
            log(" unlock,head is null");
            return;
        }
        Node next = h.next;
        if (next != null) {
            log(" unlock,next != null");
            // The successor stays queued; it dequeues itself after acquiring.
            // unpark(null) is a no-op, which covers a successor that has just
            // become the head and cleared its thread.
            LockSupport.unpark(next.thread);
        } else {
            log(" unlock,next == null");
        }
    }

    /**
     * Decrements the hold count.
     *
     * @return true if the lock is now completely free
     * @throws IllegalMonitorStateException if the current thread is not the owner
     */
    private boolean tryRelease(int arg) {
        if (Thread.currentThread() != getExclusiveOwnerThread()) {
            throw new IllegalMonitorStateException("current thread does not hold the lock");
        }
        int remaining = getState() - arg;
        boolean free = remaining == 0;
        if (free) {
            // Clear the owner before publishing state == 0.
            setExclusiveOwnerThread(null);
        }
        // Only the owner reaches this point, so a plain volatile write is enough.
        setState(remaining);
        return free;
    }

    /** Appends a node for the current thread to the wait queue, creating the dummy head if needed. */
    private Node addWaiter() {
        Node node = new Node(Thread.currentThread());
        for (;;) {
            Node last = tail;
            if (last == null) {
                // Exactly one thread may install the dummy head, so expect null.
                if (compareAndSetHead(null, new Node())) {
                    tail = head;
                }
            } else {
                // Set prev before publishing the node, and CAS against the very
                // tail we captured so the links always match the CAS result.
                node.prev = last;
                if (compareAndSetTail(last, node)) {
                    last.next = node;
                    return node;
                }
            }
        }
    }

    /**
     * Single non-blocking attempt to take {@code arg} holds.
     *
     * @return true if the lock was free and is now owned by the caller, or if
     *         the caller already owned it and its hold count was increased
     */
    private boolean tryAcquire(int arg) {
        Thread current = Thread.currentThread();
        int held = getState();
        if (held == 0) {
            if (compareAndSetState(0, arg)) {
                setExclusiveOwnerThread(current);
                return true;
            }
        } else if (current == getExclusiveOwnerThread()) {
            int nextState = held + arg;
            if (nextState < 0) {
                throw new Error("nextState is illegal");
            }
            // Only the owner gets here, so no CAS is required.
            setState(nextState);
            return true;
        }
        return false;
    }

    /** Prints a trace line prefixed with the current thread name when logging is enabled. */
    private void log(String message) {
        if (isPrintLog) {
            System.out.println(Thread.currentThread().getName() + message);
        }
    }

    /** @return the thread holding the lock, or null if it is free */
    public Thread getExclusiveOwnerThread() {
        return exclusiveOwnerThread;
    }

    public void setExclusiveOwnerThread(Thread exclusiveOwnerThread) {
        this.exclusiveOwnerThread = exclusiveOwnerThread;
    }

    /** @return the current hold count (0 when free) */
    public final int getState() {
        return state;
    }

    public void setState(int state) {
        this.state = state;
    }

    /** @return the dummy head of the wait queue, or null if there has been no contention yet */
    public Node getHead() {
        return head;
    }

    public void setHead(Node head) {
        this.head = head;
    }

    public Node getTail() {
        return tail;
    }

    public void setTail(Node tail) {
        this.tail = tail;
    }

    public void setPrintLog(boolean printLog) {
        isPrintLog = printLog;
    }

    /**
     * Wait-queue node. The links and the thread are volatile because they are
     * written by one thread and read by others. The head node carries no thread.
     */
    static class Node {
        volatile Node next;
        volatile Node prev;
        volatile Thread thread;

        Node() {
        }

        Node(Thread thread) {
            this.thread = thread;
        }

        public Node getNext() {
            return next;
        }

        public void setNext(Node next) {
            this.next = next;
        }

        public Node getPrev() {
            return prev;
        }

        public void setPrev(Node prev) {
            this.prev = prev;
        }

        /** @return the waiting thread, or null for the head node */
        public Thread getThread() {
            return thread;
        }

        public void setThread(Thread thread) {
            this.thread = thread;
        }
    }

    // Unsafe and CAS support
    private static final Unsafe unsafe = getUnsafeInstance();
    private static final long stateOffset;
    private static final long headOffset;
    private static final long tailOffset;

    /** Obtains the {@code Unsafe} singleton reflectively from its private static field. */
    private static Unsafe getUnsafeInstance() {
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            return (Unsafe) theUnsafe.get(null);
        } catch (Exception e) {
            throw new Error("get unsafe error", e);
        }
    }

    static {
        try {
            stateOffset = unsafe.objectFieldOffset(QueuedSynchronizer.class.getDeclaredField("state"));
            headOffset = unsafe.objectFieldOffset(QueuedSynchronizer.class.getDeclaredField("head"));
            tailOffset = unsafe.objectFieldOffset(QueuedSynchronizer.class.getDeclaredField("tail"));
        } catch (NoSuchFieldException e) {
            throw new Error(e);
        }
    }

    private final boolean compareAndSetState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

    private final boolean compareAndSetHead(Node expect, Node update) {
        return unsafe.compareAndSwapObject(this, headOffset, expect, update);
    }

    private final boolean compareAndSetTail(Node expect, Node update) {
        return unsafe.compareAndSwapObject(this, tailOffset, expect, update);
    }
}
2. 测试
/**
 * Manual stress driver for {@code QueuedSynchronizer}: many threads increment a
 * shared counter under the lock and print the value they observe.
 */
public class QueuedSynchronizerTest {
    private static final QueuedSynchronizer lock = new QueuedSynchronizer(true);
    /** Shared counter; only modified while holding {@code lock}. */
    private volatile int i;

    /** Increments the counter under the lock and prints the new value. */
    public void testLock() {
        lock.lock();
        try {
            try {
                // Hold the lock briefly so other threads actually have to wait.
                Thread.sleep(1);
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers instead of swallowing it.
                Thread.currentThread().interrupt();
            }
            i++;
            System.out.println(Thread.currentThread().getName() + " ,i=" + i);
        } finally {
            lock.unlock();
        }
    }

    public static void main(String[] args) {
        QueuedSynchronizerTest instance = new QueuedSynchronizerTest();
        // Start 2000 threads that each increment the counter once.
        IntStream.range(0, 2000).forEach(e -> new Thread(instance::testLock).start());
        // Uncomment to periodically print the nodes of the wait queue.
        // printThread(instance);
    }

    /** Starts a background thread that dumps the wait queue every 10 seconds. */
    private static void printThread(QueuedSynchronizerTest instance) {
        Thread monitor = new Thread(() -> {
            while (true) {
                try {
                    Thread.sleep(10000);
                } catch (InterruptedException e) {
                    // Interruption is the only way to stop the monitor; honour it.
                    Thread.currentThread().interrupt();
                    return;
                }
                System.out.println("----print start----");
                instance.printQueue();
                System.out.println("----print end----");
            }
        });
        // Daemon, so the monitor never keeps the JVM alive after the workers finish.
        monitor.setDaemon(true);
        monitor.start();
    }

    /** Prints the head node and then every node behind it, exactly once each. */
    private void printQueue() {
        QueuedSynchronizer.Node point = lock.getHead();
        if (point == null) {
            return;
        }
        System.out.println("head:" + threadName(point));
        for (point = point.getNext(); point != null; point = point.getNext()) {
            System.out.println(threadName(point));
        }
    }

    /** Returns the node's thread name, tolerating nodes that carry no thread. */
    private static String threadName(QueuedSynchronizer.Node node) {
        Thread thread = node.getThread();
        return thread == null ? "(no thread)" : thread.getName();
    }
}