Skip to content

Commit 7d35c2c

Browse files
committed
ZipObservable fix
JAVA-3687
1 parent b74bd0f commit 7d35c2c

File tree

2 files changed

+81
-35
lines changed

2 files changed

+81
-35
lines changed

driver-scala/src/main/scala/org/mongodb/scala/internal/ZipObservable.scala

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,40 @@
1717
package org.mongodb.scala.internal
1818

1919
import java.util.concurrent.ConcurrentLinkedQueue
20+
import java.util.concurrent.atomic.AtomicLong
2021

2122
import org.mongodb.scala.{ Observable, Observer, Subscription }
2223

23-
private[scala] case class ZipObservable[T, U](
24-
observable1: Observable[T],
25-
observable2: Observable[U]
26-
) extends Observable[(T, U)] {
24+
private[scala] case class ZipObservable[L, R](
25+
leftObservable: Observable[L],
26+
rightObservable: Observable[R]
27+
) extends Observable[(L, R)] {
2728

28-
def subscribe(observer: Observer[_ >: (T, U)]): Unit = {
29+
def subscribe(observer: Observer[_ >: (L, R)]): Unit = {
2930
val helper = SubscriptionHelper(observer)
30-
observable1.subscribe(SubscriptionCheckingObserver(helper.createFirstObserver))
31-
observable2.subscribe(SubscriptionCheckingObserver(helper.createSecondObserver))
31+
leftObservable.subscribe(SubscriptionCheckingObserver(helper.createLeftObserver))
32+
rightObservable.subscribe(SubscriptionCheckingObserver(helper.createRightObserver))
3233
}
3334

34-
case class SubscriptionHelper(observer: Observer[_ >: (T, U)]) {
35-
private val thisQueue: ConcurrentLinkedQueue[(Long, T)] = new ConcurrentLinkedQueue[(Long, T)]()
36-
private val thatQueue: ConcurrentLinkedQueue[(Long, U)] = new ConcurrentLinkedQueue[(Long, U)]()
35+
case class SubscriptionHelper(observer: Observer[_ >: (L, R)]) {
36+
private val leftQueue: ConcurrentLinkedQueue[(Long, L)] = new ConcurrentLinkedQueue[(Long, L)]()
37+
private val rightQueue: ConcurrentLinkedQueue[(Long, R)] = new ConcurrentLinkedQueue[(Long, R)]()
3738

39+
private val leftCounter: AtomicLong = new AtomicLong()
40+
private val rightCounter: AtomicLong = new AtomicLong()
41+
@volatile private var completedLeft: Boolean = false
42+
@volatile private var completedRight: Boolean = false
3843
@volatile private var terminated: Boolean = false
39-
@volatile private var observable1Subscription: Option[Subscription] = None
40-
@volatile private var observable2Subscription: Option[Subscription] = None
44+
@volatile private var leftSubscription: Option[Subscription] = None
45+
@volatile private var rightSubscription: Option[Subscription] = None
4146

42-
def createFirstObserver: Observer[T] = createSubObserver[T](thisQueue, observer, firstSub = true)
43-
44-
def createSecondObserver: Observer[U] = createSubObserver[U](thatQueue, observer, firstSub = false)
47+
def createLeftObserver: Observer[L] = createSubObserver[L](leftQueue, observer, isLeftSub = true)
48+
def createRightObserver: Observer[R] = createSubObserver[R](rightQueue, observer, isLeftSub = false)
4549

4650
private def createSubObserver[A](
4751
queue: ConcurrentLinkedQueue[(Long, A)],
48-
observer: Observer[_ >: (T, U)],
49-
firstSub: Boolean
52+
observer: Observer[_ >: (L, R)],
53+
isLeftSub: Boolean
5054
): Observer[A] = {
5155
new Observer[A] {
5256
@volatile private var counter: Long = 0
@@ -56,38 +60,61 @@ private[scala] case class ZipObservable[T, U](
5660
}
5761

5862
override def onSubscribe(subscription: Subscription): Unit = {
59-
if (firstSub) {
60-
observable1Subscription = Some(subscription)
63+
if (isLeftSub) {
64+
leftSubscription = Some(subscription)
6165
} else {
62-
observable2Subscription = Some(subscription)
66+
rightSubscription = Some(subscription)
6367
}
6468

65-
if (observable1Subscription.nonEmpty && observable2Subscription.nonEmpty) {
69+
if (leftSubscription.nonEmpty && rightSubscription.nonEmpty) {
6670
observer.onSubscribe(jointSubscription)
6771
}
6872
}
6973

7074
override def onComplete(): Unit = {
71-
if (!firstSub) {
72-
terminated = true
73-
observer.onComplete()
74-
}
75+
markCompleted(isLeftSub)
76+
processNext(observer)
7577
}
7678

7779
override def onNext(tResult: A): Unit = {
80+
if (isLeftSub) leftCounter.incrementAndGet() else rightCounter.incrementAndGet()
7881
counter += 1
7982
queue.add((counter, tResult))
80-
if (!firstSub) processNext(observer)
83+
processNext(observer)
8184
}
8285
}
8386
}
8487

85-
private def processNext(observer: Observer[_ >: (T, U)]): Unit = {
86-
(thisQueue.peek, thatQueue.peek) match {
87-
case ((k1: Long, _), (k2: Long, _)) if k1 == k2 => observer.onNext((thisQueue.poll()._2, thatQueue.poll()._2))
88+
private def markCompleted(isLeftSub: Boolean): Unit = synchronized {
89+
if (isLeftSub) {
90+
completedLeft = true
91+
} else {
92+
completedRight = true
93+
}
94+
}
95+
96+
private def completed(): Unit = synchronized {
97+
if (!terminated) {
98+
terminated = true
99+
leftSubscription.foreach(_.unsubscribe())
100+
rightSubscription.foreach(_.unsubscribe())
101+
observer.onComplete()
102+
}
103+
}
104+
105+
private def processNext(observer: Observer[_ >: (L, R)]): Unit = synchronized {
106+
(leftQueue.peek, rightQueue.peek) match {
107+
case ((k1: Long, _), (k2: Long, _)) if k1 == k2 =>
108+
observer.onNext((leftQueue.poll()._2, rightQueue.poll()._2))
109+
processNext(observer)
88110
case _ =>
89-
if (!terminated && !jointSubscription.isUnsubscribed) jointSubscription.request(1) // Uneven queues request more data
90-
// from downstream so to honor the original request for data.
111+
if (!terminated && !jointSubscription.isUnsubscribed) {
112+
if (completedLeft && rightCounter.get() >= leftCounter.get()) {
113+
completed()
114+
} else if (completedRight && leftCounter.get() >= rightCounter.get()) {
115+
completed()
116+
}
117+
}
91118
}
92119
}
93120

@@ -96,14 +123,14 @@ private[scala] case class ZipObservable[T, U](
96123
override def isUnsubscribed: Boolean = !subscribed
97124

98125
override def request(n: Long): Unit = {
99-
observable1Subscription.foreach(_.request(n))
100-
observable2Subscription.foreach(_.request(n))
126+
leftSubscription.foreach(_.request(n))
127+
rightSubscription.foreach(_.request(n))
101128
}
102129

103130
override def unsubscribe(): Unit = {
104131
subscribed = false
105-
observable1Subscription.foreach(_.unsubscribe())
106-
observable2Subscription.foreach(_.unsubscribe())
132+
leftSubscription.foreach(_.unsubscribe())
133+
rightSubscription.foreach(_.unsubscribe())
107134
}
108135
}
109136
}

driver-scala/src/test/scala/org/mongodb/scala/internal/ObservableImplementationSpec.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,18 @@ class ObservableImplementationSpec extends BaseSpec with TableDrivenPropertyChec
167167
observer.completed should equal(true)
168168
}
169169
}
170+
171+
forAll(zippedObservablesWithEmptyObservable) { (observable: Observable[(Int, Int)]) =>
172+
{
173+
val observer = TestObserver[(Int, Int)]()
174+
observable.subscribe(observer)
175+
176+
observer.subscription.foreach(_.request(100))
177+
178+
observer.results should equal(List())
179+
observer.completed should equal(true)
180+
}
181+
}
170182
}
171183

172184
it should "error if requested amount is less than 1" in {
@@ -258,6 +270,13 @@ class ObservableImplementationSpec extends BaseSpec with TableDrivenPropertyChec
258270
ZipObservable[Int, Int](TestObservable[Int](), TestObservable[Int](1 to 50))
259271
)
260272

273+
private def zippedObservablesWithEmptyObservable =
274+
Table[Observable[(Int, Int)]](
275+
"observable",
276+
ZipObservable[Int, Int](TestObservable[Int](1 to 50), TestObservable[Int](List())),
277+
ZipObservable[Int, Int](TestObservable[Int](List()), TestObservable[Int](1 to 50))
278+
)
279+
261280
private def overRequestingObservables =
262281
Table(
263282
("observable", "observer", "expected"),

0 commit comments

Comments
 (0)