17
17
package org .mongodb .scala .internal
18
18
19
19
import java .util .concurrent .ConcurrentLinkedQueue
20
+ import java .util .concurrent .atomic .AtomicLong
20
21
21
22
import org .mongodb .scala .{ Observable , Observer , Subscription }
22
23
23
- private [scala] case class ZipObservable [T , U ](
24
- observable1 : Observable [T ],
25
- observable2 : Observable [U ]
26
- ) extends Observable [(T , U )] {
24
+ private [scala] case class ZipObservable [L , R ](
25
+ leftObservable : Observable [L ],
26
+ rightObservable : Observable [R ]
27
+ ) extends Observable [(L , R )] {
27
28
28
- def subscribe (observer : Observer [_ >: (T , U )]): Unit = {
29
+ def subscribe (observer : Observer [_ >: (L , R )]): Unit = {
29
30
val helper = SubscriptionHelper (observer)
30
- observable1 .subscribe(SubscriptionCheckingObserver (helper.createFirstObserver ))
31
- observable2 .subscribe(SubscriptionCheckingObserver (helper.createSecondObserver ))
31
+ leftObservable .subscribe(SubscriptionCheckingObserver (helper.createLeftObserver ))
32
+ rightObservable .subscribe(SubscriptionCheckingObserver (helper.createRightObserver ))
32
33
}
33
34
34
- case class SubscriptionHelper (observer : Observer [_ >: (T , U )]) {
35
- private val thisQueue : ConcurrentLinkedQueue [(Long , T )] = new ConcurrentLinkedQueue [(Long , T )]()
36
- private val thatQueue : ConcurrentLinkedQueue [(Long , U )] = new ConcurrentLinkedQueue [(Long , U )]()
35
+ case class SubscriptionHelper (observer : Observer [_ >: (L , R )]) {
36
+ private val leftQueue : ConcurrentLinkedQueue [(Long , L )] = new ConcurrentLinkedQueue [(Long , L )]()
37
+ private val rightQueue : ConcurrentLinkedQueue [(Long , R )] = new ConcurrentLinkedQueue [(Long , R )]()
37
38
39
+ private val leftCounter : AtomicLong = new AtomicLong ()
40
+ private val rightCounter : AtomicLong = new AtomicLong ()
41
+ @ volatile private var completedLeft : Boolean = false
42
+ @ volatile private var completedRight : Boolean = false
38
43
@ volatile private var terminated : Boolean = false
39
- @ volatile private var observable1Subscription : Option [Subscription ] = None
40
- @ volatile private var observable2Subscription : Option [Subscription ] = None
44
+ @ volatile private var leftSubscription : Option [Subscription ] = None
45
+ @ volatile private var rightSubscription : Option [Subscription ] = None
41
46
42
- def createFirstObserver : Observer [T ] = createSubObserver[T ](thisQueue, observer, firstSub = true )
43
-
44
- def createSecondObserver : Observer [U ] = createSubObserver[U ](thatQueue, observer, firstSub = false )
47
+ def createLeftObserver : Observer [L ] = createSubObserver[L ](leftQueue, observer, isLeftSub = true )
48
+ def createRightObserver : Observer [R ] = createSubObserver[R ](rightQueue, observer, isLeftSub = false )
45
49
46
50
private def createSubObserver [A ](
47
51
queue : ConcurrentLinkedQueue [(Long , A )],
48
- observer : Observer [_ >: (T , U )],
49
- firstSub : Boolean
52
+ observer : Observer [_ >: (L , R )],
53
+ isLeftSub : Boolean
50
54
): Observer [A ] = {
51
55
new Observer [A ] {
52
56
@ volatile private var counter : Long = 0
@@ -56,38 +60,61 @@ private[scala] case class ZipObservable[T, U](
56
60
}
57
61
58
62
override def onSubscribe (subscription : Subscription ): Unit = {
59
- if (firstSub ) {
60
- observable1Subscription = Some (subscription)
63
+ if (isLeftSub ) {
64
+ leftSubscription = Some (subscription)
61
65
} else {
62
- observable2Subscription = Some (subscription)
66
+ rightSubscription = Some (subscription)
63
67
}
64
68
65
- if (observable1Subscription .nonEmpty && observable2Subscription .nonEmpty) {
69
+ if (leftSubscription .nonEmpty && rightSubscription .nonEmpty) {
66
70
observer.onSubscribe(jointSubscription)
67
71
}
68
72
}
69
73
70
74
override def onComplete (): Unit = {
71
- if (! firstSub) {
72
- terminated = true
73
- observer.onComplete()
74
- }
75
+ markCompleted(isLeftSub)
76
+ processNext(observer)
75
77
}
76
78
77
79
override def onNext (tResult : A ): Unit = {
80
+ if (isLeftSub) leftCounter.incrementAndGet() else rightCounter.incrementAndGet()
78
81
counter += 1
79
82
queue.add((counter, tResult))
80
- if ( ! firstSub) processNext(observer)
83
+ processNext(observer)
81
84
}
82
85
}
83
86
}
84
87
85
- private def processNext (observer : Observer [_ >: (T , U )]): Unit = {
86
- (thisQueue.peek, thatQueue.peek) match {
87
- case ((k1 : Long , _), (k2 : Long , _)) if k1 == k2 => observer.onNext((thisQueue.poll()._2, thatQueue.poll()._2))
88
+ private def markCompleted (isLeftSub : Boolean ): Unit = synchronized {
89
+ if (isLeftSub) {
90
+ completedLeft = true
91
+ } else {
92
+ completedRight = true
93
+ }
94
+ }
95
+
96
+ private def completed (): Unit = synchronized {
97
+ if (! terminated) {
98
+ terminated = true
99
+ leftSubscription.foreach(_.unsubscribe())
100
+ rightSubscription.foreach(_.unsubscribe())
101
+ observer.onComplete()
102
+ }
103
+ }
104
+
105
+ private def processNext (observer : Observer [_ >: (L , R )]): Unit = synchronized {
106
+ (leftQueue.peek, rightQueue.peek) match {
107
+ case ((k1 : Long , _), (k2 : Long , _)) if k1 == k2 =>
108
+ observer.onNext((leftQueue.poll()._2, rightQueue.poll()._2))
109
+ processNext(observer)
88
110
case _ =>
89
- if (! terminated && ! jointSubscription.isUnsubscribed) jointSubscription.request(1 ) // Uneven queues request more data
90
- // from downstream so to honor the original request for data.
111
+ if (! terminated && ! jointSubscription.isUnsubscribed) {
112
+ if (completedLeft && rightCounter.get() >= leftCounter.get()) {
113
+ completed()
114
+ } else if (completedRight && leftCounter.get() >= rightCounter.get()) {
115
+ completed()
116
+ }
117
+ }
91
118
}
92
119
}
93
120
@@ -96,14 +123,14 @@ private[scala] case class ZipObservable[T, U](
96
123
override def isUnsubscribed : Boolean = ! subscribed
97
124
98
125
override def request (n : Long ): Unit = {
99
- observable1Subscription .foreach(_.request(n))
100
- observable2Subscription .foreach(_.request(n))
126
+ leftSubscription .foreach(_.request(n))
127
+ rightSubscription .foreach(_.request(n))
101
128
}
102
129
103
130
override def unsubscribe (): Unit = {
104
131
subscribed = false
105
- observable1Subscription .foreach(_.unsubscribe())
106
- observable2Subscription .foreach(_.unsubscribe())
132
+ leftSubscription .foreach(_.unsubscribe())
133
+ rightSubscription .foreach(_.unsubscribe())
107
134
}
108
135
}
109
136
}
0 commit comments