Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ case class JobOperator(
val statementExecutionManager =
instantiateStatementExecutionManager(commandContext, resultIndex, osClient)

if (!isWarmpoolEnabled) {
statementExecutionManager.updateStatement(statement)
}

val readWriteBytesSparkListener = new MetricsSparkListener()
sparkSession.sparkContext.addSparkListener(readWriteBytesSparkListener)

Expand Down Expand Up @@ -293,7 +297,7 @@ case class JobOperator(
}
}

private def instantiateStatementExecutionManager(
protected def instantiateStatementExecutionManager(
commandContext: CommandContext,
resultIndex: String,
osClient: OSClient): StatementExecutionManager = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{doNothing, times, verify, when}
import org.opensearch.flint.common.model.FlintStatement
import org.scalatest.BeforeAndAfterEach
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.streaming.StreamingQueryManager

/**
 * Unit tests for [[JobOperator]].
 *
 * Verifies the statement-lifecycle contract for non-warmpool jobs: the
 * operator must call `statementExecutionManager.updateStatement` once before
 * execution (because warmpool is disabled) and once again after the statement
 * completes — two calls in total.
 */
class JobOperatorTest
    extends SparkFunSuite
    with MockitoSugar
    with Matchers
    with BeforeAndAfterEach {

  private val jobId = "testJobId"
  private val resultIndex = "resultIndex"
  private val jobType = "interactive"
  private val applicationId = "testApplicationId"
  private val dataSource = "testDataSource"

  override def beforeEach(): Unit = {
    super.beforeEach()
  }

  test("verify if statementExecutionManager is calling update during non-warmpool jobs") {
    val mockFlintStatement = mock[FlintStatement]
    when(mockFlintStatement.queryId).thenReturn("test-query-id")
    when(mockFlintStatement.query).thenReturn("SELECT 1")
    val streamingRunningCount = new AtomicInteger(1)
    val statementRunningCount = new AtomicInteger(1)
    val mockSparkSession = mock[SparkSession]
    val mockDataFrame = mock[DataFrame]
    val mockSparkConf = mock[RuntimeConfig]

    val mockStatementExecutionManager = mock[StatementExecutionManager]
    val mockOSClient = mock[OSClient]
    val mockSparkContext = mock[SparkContext]
    val mockStreamingQueryManager = mock[StreamingQueryManager]

    // Warmpool disabled: JobOperator must issue the pre-execution update.
    when(mockSparkSession.sql("SELECT 1")).thenReturn(mockDataFrame)
    when(mockSparkSession.conf).thenReturn(mockSparkConf)
    when(mockSparkConf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false"))
      .thenReturn("false")
    when(mockSparkConf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
      .thenReturn(resultIndex)
    when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
    doNothing().when(mockSparkContext).addSparkListener(any())
    when(mockSparkSession.streams).thenReturn(mockStreamingQueryManager)
    doNothing().when(mockStreamingQueryManager).addListener(any())
    doNothing()
      .when(mockStatementExecutionManager)
      .updateStatement(any[FlintStatement])
    when(mockStatementExecutionManager.prepareStatementExecution())
      .thenReturn(Right(()))
    when(mockStatementExecutionManager.executeStatement(any[FlintStatement]))
      .thenReturn(mockDataFrame)

    when(mockOSClient.doesIndexExist(resultIndex)).thenReturn(true)

    val jobOperator = new JobOperator(
      applicationId,
      jobId,
      mockSparkSession,
      mockFlintStatement,
      dataSource,
      resultIndex,
      jobType,
      streamingRunningCount,
      statementRunningCount) {
      override protected def instantiateStatementExecutionManager(
          commandContext: CommandContext,
          resultIndex: String,
          osClient: OSClient): StatementExecutionManager = {
        mockStatementExecutionManager
      }

      // Overridden rather than mocked: JobOperator itself is the object under
      // test (so it cannot be a mock), and writeDataFrameToOpensearch
      // internally calls osClient.doesIndexExist, which did not pick up the
      // mockOSClient stubbing when left in place — TODO confirm and remove
      // this override if the OSClient injection point improves.
      override def writeDataFrameToOpensearch(
          resultData: DataFrame,
          resultIndex: String,
          osClient: OSClient): Unit = {}

      override def instantiateQueryResultWriter(
          spark: SparkSession,
          commandContext: CommandContext): QueryResultWriter = {
        val mockQueryResultWriter = mock[QueryResultWriter]
        when(
          mockQueryResultWriter
            .processDataFrame(any[DataFrame], any[FlintStatement], any[Long]))
          .thenAnswer(invocation => invocation.getArgument[DataFrame](0))
        doNothing()
          .when(mockQueryResultWriter)
          .writeDataFrame(any[DataFrame], any[FlintStatement])
        mockQueryResultWriter
      }

      // start() is overridden because QueryResultWriter and the result
      // DataFrame are created inside JobOperator and cannot be substituted
      // from the outside. The body mirrors the production statement
      // lifecycle so the update-call count matches real behavior.
      override def start(): Unit = {
        try {
          // First update: only issued when warmpool is disabled.
          if (!isWarmpoolEnabled) {
            mockStatementExecutionManager.updateStatement(mockFlintStatement)
          }

          mockStatementExecutionManager.prepareStatementExecution() match {
            case Right(_) =>
              val data = mockStatementExecutionManager.executeStatement(mockFlintStatement)
              val queryResultWriter =
                instantiateQueryResultWriter(mockSparkSession, null)
              queryResultWriter.writeDataFrame(data, mockFlintStatement)
            case Left(err) => // preparation failed: skip execution, still finalize below
          }

          // Second update: issued after the statement completes.
          mockFlintStatement.complete()
          mockStatementExecutionManager.updateStatement(mockFlintStatement)
        } catch {
          case e: Exception =>
            mockFlintStatement.fail()
            mockStatementExecutionManager.updateStatement(mockFlintStatement)
        }
      }
    }

    // No surrounding try/catch: any setup or verification failure must
    // propagate and fail the test (the previous catch-and-print swallowed
    // all exceptions, letting a broken test pass silently).
    jobOperator.start()
    verify(mockStatementExecutionManager, times(2)).updateStatement(mockFlintStatement)
  }
}
Loading