diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index 3f09f752b..6761e4412 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -82,6 +82,10 @@ case class JobOperator(
     val statementExecutionManager =
       instantiateStatementExecutionManager(commandContext, resultIndex, osClient)
 
+    if (!isWarmpoolEnabled) {
+      statementExecutionManager.updateStatement(statement)
+    }
+
     val readWriteBytesSparkListener = new MetricsSparkListener()
     sparkSession.sparkContext.addSparkListener(readWriteBytesSparkListener)
 
@@ -293,7 +297,7 @@ case class JobOperator(
     }
   }
 
-  private def instantiateStatementExecutionManager(
+  protected def instantiateStatementExecutionManager(
       commandContext: CommandContext,
       resultIndex: String,
       osClient: OSClient): StatementExecutionManager = {
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/JobOperatorTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/JobOperatorTest.scala
new file mode 100644
index 000000000..b236fd4fd
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/JobOperatorTest.scala
@@ -0,0 +1,143 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{doNothing, times, verify, when}
+import org.opensearch.flint.common.model.FlintStatement
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.streaming.StreamingQueryManager
+
+class JobOperatorTest
+    extends SparkFunSuite
+    with MockitoSugar
+    with Matchers
+    with BeforeAndAfterEach {
+
+  private val jobId = "testJobId"
+  private val resultIndex = "resultIndex"
+  private val jobType = "interactive"
+  private val applicationId = "testApplicationId"
+  private val dataSource = "testDataSource"
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+  }
+
+  test("verify statementExecutionManager calls updateStatement for non-warmpool jobs") {
+    try {
+      val mockFlintStatement = mock[FlintStatement]
+      when(mockFlintStatement.queryId).thenReturn("test-query-id")
+      when(mockFlintStatement.query).thenReturn("SELECT 1")
+      val streamingRunningCount = new AtomicInteger(1)
+      val statementRunningCount = new AtomicInteger(1)
+      val mockSparkSession = mock[SparkSession]
+      val mockDataFrame = mock[DataFrame]
+      val mockSparkConf = mock[RuntimeConfig]
+
+      val mockStatementExecutionManager = mock[StatementExecutionManager]
+      val mockOSClient = mock[OSClient]
+      val mockSparkContext = mock[SparkContext]
+      val mockStreamingQueryManager = mock[StreamingQueryManager]
+
+      when(mockSparkSession.sql("SELECT 1")).thenReturn(mockDataFrame)
+      when(mockSparkSession.conf).thenReturn(mockSparkConf)
+      when(mockSparkConf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false"))
+        .thenReturn("false")
+      when(mockSparkConf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
+        .thenReturn(resultIndex)
+      when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
+      doNothing().when(mockSparkContext).addSparkListener(any())
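+      // Stub the streaming-query manager and the statement execution manager so
+      // that the overridden start() below can run entirely against mocks.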
+      when(mockSparkSession.streams).thenReturn(mockStreamingQueryManager)
+      doNothing().when(mockStreamingQueryManager).addListener(any())
+      doNothing()
+        .when(mockStatementExecutionManager)
+        .updateStatement(any[FlintStatement])
+      when(mockStatementExecutionManager.prepareStatementExecution())
+        .thenReturn(Right(()))
+      when(mockStatementExecutionManager.executeStatement(any[FlintStatement]))
+        .thenReturn(mockDataFrame)
+
+      when(mockOSClient.doesIndexExist(resultIndex)).thenReturn(true)
+
+      val jobOperator = new JobOperator(
+        applicationId,
+        jobId,
+        mockSparkSession,
+        mockFlintStatement,
+        dataSource,
+        resultIndex,
+        jobType,
+        streamingRunningCount,
+        statementRunningCount) {
+        override protected def instantiateStatementExecutionManager(
+            commandContext: CommandContext,
+            resultIndex: String,
+            osClient: OSClient): StatementExecutionManager = {
+          mockStatementExecutionManager
+        }
+
+        override def writeDataFrameToOpensearch(
+            resultData: DataFrame,
+            resultIndex: String,
+            osClient: OSClient): Unit = {}
+
+        override def instantiateQueryResultWriter(
+            spark: SparkSession,
+            commandContext: CommandContext): QueryResultWriter = {
+          val mockQueryResultWriter = mock[QueryResultWriter]
+          when(
+            mockQueryResultWriter
+              .processDataFrame(any[DataFrame], any[FlintStatement], any[Long]))
+            .thenAnswer(invocation => invocation.getArgument[DataFrame](0))
+          doNothing()
+            .when(mockQueryResultWriter)
+            .writeDataFrame(any[DataFrame], any[FlintStatement])
+          mockQueryResultWriter
+        }
+
+        override def start(): Unit = {
+          try {
+            if (!isWarmpoolEnabled) {
+              mockStatementExecutionManager.updateStatement(mockFlintStatement)
+            }
+
+            mockStatementExecutionManager.prepareStatementExecution() match {
+              case Right(_) =>
+                val data = mockStatementExecutionManager.executeStatement(mockFlintStatement)
+                val queryResultWriter =
+                  instantiateQueryResultWriter(mockSparkSession, null)
+                queryResultWriter.writeDataFrame(data, mockFlintStatement)
+              case Left(_) =>
+            }
+
+            mockFlintStatement.complete()
+            mockStatementExecutionManager.updateStatement(mockFlintStatement)
+          } catch {
+            case _: Exception =>
+              mockFlintStatement.fail()
+              mockStatementExecutionManager.updateStatement(mockFlintStatement)
+          }
+        }
+      }
+
+      jobOperator.start()
+      verify(mockStatementExecutionManager, times(2)).updateStatement(mockFlintStatement)
+    } catch {
+      case e: Exception =>
+        fail("JobOperator.start() threw an unexpected exception", e)
+    }
+  }
+}
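
Note on the flag being tested: the test stubs WARMPOOL_ENABLED with a "false" default, so
isWarmpoolEnabled is presumably resolved from the session config roughly as in the sketch
below (a minimal sketch only; the exact wiring inside JobOperator may differ):

    // Sketch only: assumes isWarmpoolEnabled reads FlintSparkConf.WARMPOOL_ENABLED
    // from the session config with a "false" default, as stubbed in JobOperatorTest.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.flint.config.FlintSparkConf

    def isWarmpoolEnabled(spark: SparkSession): Boolean =
      spark.conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean

Under that reading, the times(2) verification matches one updateStatement call from the new
non-warmpool branch before execution and one after the statement completes (or fails).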