Conversation

thorkous
Contributor

Description

Added an updateStatement call in JobOperator. For warm pool jobs, we already persist query status in WarmpoolJob.scala (line 70), but we were not persisting query status for non-warm-pool jobs.

Related Issues

List any issues this PR will resolve, e.g. Resolves [...].

Check List

  • [NA] Updated documentation (docs/ppl-lang/README.md)
  • [NA] Implemented unit tests
  • [NA] Implemented tests for combination with other commands
  • [NA] New added source code should include a copyright header
  • [x] Commits are signed per the DCO using --signoff
  • [NA] Add backport 0.x label if it is a stable change which won't break existing feature

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

@@ -81,6 +81,7 @@ case class JobOperator(

val statementExecutionManager =
instantiateStatementExecutionManager(commandContext, resultIndex, osClient)
statementExecutionManager.updateStatement(statement)
Contributor

Can we enclose this in the isWarmpoolEnabled flag?

Contributor Author

Updated. Thanks for pointing this out.
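For context, a minimal sketch of the guarded call, assuming an isWarmpoolEnabled flag as suggested (hypothetical shape, not the exact merged diff):

// Warm-pool jobs already persist query status in WarmpoolJob.scala,
// so the new update only needs to run for non-warm-pool jobs.
if (!isWarmpoolEnabled) {
  statementExecutionManager.updateStatement(statement)
}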

@thorkous thorkous force-pushed the main branch 5 times, most recently from 1ca7533 to 8905681 on June 12, 2025 at 07:27
Collaborator

@noCharger noCharger left a comment

Please add UT and IT to verify this change

@thorkous
Contributor Author

thorkous commented Jun 25, 2025

Please add UT and IT to verify this change

I tried adding an IT for non-warm-pool jobs, but I am unable to test whether statementExecutionManager.updateStatement(statement) is called for non-warm-pool jobs, since it is a transient state. I get a success state for the non-warm-pool job, but that doesn't exercise the line I added.

It's not possible to write a JobOperator unit test without making code changes in JobOperator, as there are multiple objects that are initialised inside JobOperator and cannot be mocked. Do let me know if you have an approach I can take for writing UTs. PowerMockito could do this, but we haven't imported PowerMockito into the OpenSearch-Spark code.

@thorkous thorkous requested a review from noCharger June 26, 2025 04:06
@noCharger
Collaborator

Please add UT and IT to verify this change

I tried adding an IT for non-warm-pool jobs, but I am unable to test whether statementExecutionManager.updateStatement(statement) is called for non-warm-pool jobs, since it is a transient state. I get a success state for the non-warm-pool job, but that doesn't exercise the line I added.

It's not possible to write a JobOperator unit test without making code changes in JobOperator, as there are multiple objects that are initialised inside JobOperator and cannot be mocked. Do let me know if you have an approach I can take for writing UTs. PowerMockito could do this, but we haven't imported PowerMockito into the OpenSearch-Spark code.

For UT, you can create a JobOperator instance and override instantiateStatementExecutionManager to return a mock, then verify that updateStatement was called.

For IT, you can verify the statement change by fetching it from the index itself.
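For example, a minimal sketch of the UT idea, assuming instantiateStatementExecutionManager is non-final and takes the (commandContext, resultIndex, osClient) arguments seen in the diff, and reusing the suite's fixtures:

// Hedged sketch: anonymous subclass of JobOperator so start() uses the mock manager.
val jobOperator = new JobOperator(
  applicationId,
  jobId,
  mockSparkSession,
  flintStatement,
  dataSource,
  resultIndex,
  jobType,
  streamingRunningCount,
  statementRunningCount) {
  override def instantiateStatementExecutionManager(
      commandContext: CommandContext,
      resultIndex: String,
      osClient: OSClient): StatementExecutionManager =
    mockStatementExecutionManager
}

jobOperator.start()
verify(mockStatementExecutionManager, atLeastOnce()).updateStatement(flintStatement)

And a hedged sketch of the IT check, where the request index name and the "state" field are assumptions about the index layout:

// Hypothetical check: read the statement document back and assert the persisted state.
val response = osClient.getDoc(requestIndex, flintStatement.statementId)
assert(response.getSourceAsMap.get("state") == "success")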

@thorkous
Contributor Author

thorkous commented Jul 9, 2025

Please add UT and IT to verify this change

I tried adding an IT for non-warm-pool jobs, but I am unable to test whether statementExecutionManager.updateStatement(statement) is called for non-warm-pool jobs, since it is a transient state. I get a success state for the non-warm-pool job, but that doesn't exercise the line I added.
It's not possible to write a JobOperator unit test without making code changes in JobOperator, as there are multiple objects that are initialised inside JobOperator and cannot be mocked. Do let me know if you have an approach I can take for writing UTs. PowerMockito could do this, but we haven't imported PowerMockito into the OpenSearch-Spark code.

For UT, you can create a JobOperator instance and override instantiateStatementExecutionManager to return a mock, then verify that updateStatement was called.

For IT, you can verify the statement change by fetching it from the index itself.

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.metrics.MetricsSparkListener
import org.scalatest.BeforeAndAfterEach
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
import org.apache.spark.sql.types._

class JobOperatorTest
    extends SparkFunSuite
    with MockitoSugar
    with Matchers
    with BeforeAndAfterEach {

  private var mockStatementExecutionManager: StatementExecutionManager = _
  private var mockOSClient: OSClient = _
  private var mockSessionManager: SessionManager = _
  private var flintStatement: FlintStatement = _
  private var mockSparkSession: SparkSession = _
  private var mockSparkContext: SparkContext = _
  private var metricsSparkListener: MetricsSparkListener = _
  private var mockStreamingQueryManager: StreamingQueryManager = _

  private val applicationId = "test-app-id"
  private val jobId = "test-job-id"
  private val dataSource = "test-datasource"
  private val resultIndex = "test-result-index"
  private val jobType = "batch"
  private val streamingRunningCount = new AtomicInteger(0)
  private val statementRunningCount = new AtomicInteger(0)
  private var mockConf: RuntimeConfig = _

  override def beforeEach(): Unit = {
    super.beforeEach()

    // Create Spark session
    mockSparkSession = mock[SparkSession]
    mockConf = mock[RuntimeConfig]
    mockSparkContext = mock[SparkContext]
    metricsSparkListener = mock[MetricsSparkListener]
    mockStreamingQueryManager = mock[StreamingQueryManager]

    // Create mocks
    mockStatementExecutionManager = mock[StatementExecutionManager]
    mockOSClient = mock[OSClient]
    mockSessionManager = mock[SessionManager]

    // Create test FlintStatement
    flintStatement = new FlintStatement(
      "RUNNING",
      "SELECT 1",
      "test-statement-id",
      "test-query-id",
      "sql",
      System.currentTimeMillis()
    )

    // Create JobOperator instance

  }

  test(
    "start should call instantiateStatementExecutionManager and execute statement successfully"
  ) {

    try {
      when(mockSparkSession.conf).thenReturn(mockConf)
      when(mockConf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false"))
        .thenReturn("true")
      when(mockConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""))
        .thenReturn("")
      // Add missing REQUEST_INDEX configuration that SessionManagerImpl requires
      when(mockConf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
        .thenReturn("test-session-index")
      when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
      doNothing().when(mockSparkContext).addSparkListener(any())
      when(mockSparkSession.streams).thenReturn(mockStreamingQueryManager)
      when(mockStreamingQueryManager.active).thenReturn(
        Array.empty[StreamingQuery]
      )
      val spark = SparkSession
        .builder()
        .appName("EmptyDataFrame")
        .master("local[*]")
        .getOrCreate()

      // Create a small DataFrame matching the schema
      val schema = StructType(
        Array(
          StructField("id", IntegerType, true),
          StructField("name", StringType, true),
          StructField("age", IntegerType, true)
        )
      )

      val data = Seq(Row(1, "Alice", 30))
      val df =
        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
      when(mockStatementExecutionManager.executeStatement(flintStatement))
        .thenReturn(df)

      val jobOperator = JobOperator(
        applicationId,
        jobId,
        mockSparkSession,
        flintStatement,
        dataSource,
        resultIndex,
        jobType,
        streamingRunningCount,
        statementRunningCount
      )
      jobOperator.start()
    } catch {
      case e: Exception =>
        println(s"job failed: ${e.getMessage}")
        e.printStackTrace()
    }

    verify(mockStatementExecutionManager, times(2)).updateStatement(
      flintStatement
    )
  }
}

This is the code change I made, and I am getting the error below:

25/07/09 18:36:48 ERROR SingleStatementExecutionManager: Failed to verify existing mapping: Failed to get OpenSearch index mapping for test-result-index
java.lang.IllegalStateException: Failed to get OpenSearch index mapping for test-result-index

Do let me know how I can resolve this.
Thanks

@noCharger
Collaborator

java.lang.IllegalStateException: Failed to get OpenSearch index mapping for test-result-index

You need to mock the OpenSearch client call to return a fake doc.
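For instance, a hedged stub; the exact JSON shape the mapping check expects is an assumption:

// Hypothetical stub: getIndexMetadata is assumed to return the mapping as a JSON string.
when(mockOSClient.getIndexMetadata(resultIndex))
  .thenReturn("""{"properties": {}}""")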

@thorkous
Contributor Author

java.lang.IllegalStateException: Failed to get OpenSearch index mapping for test-result-index

You need to mock the OpenSearch client call to return a fake doc.

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.opensearch.action.get.{GetRequest, GetResponse}
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.metrics.MetricsSparkListener
import org.scalatest.BeforeAndAfterEach
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
import org.apache.spark.sql.types._

class JobOperatorTest
    extends SparkFunSuite
    with MockitoSugar
    with Matchers
    with BeforeAndAfterEach {

  private var mockStatementExecutionManager: StatementExecutionManager = _
  private var mockOSClient: OSClient = _
  private var mockSessionManager: SessionManager = _
  private var flintStatement: FlintStatement = _
  private var mockSparkSession: SparkSession = _
  private var mockSparkContext: SparkContext = _
  private var metricsSparkListener: MetricsSparkListener = _
  private var mockStreamingQueryManager: StreamingQueryManager = _

  private val applicationId = "test-app-id"
  private val jobId = "test-job-id"
  private val dataSource = "test-datasource"
  private val resultIndex = "test-result-index"
  private val jobType = "batch"
  private val streamingRunningCount = new AtomicInteger(0)
  private val statementRunningCount = new AtomicInteger(0)
  private var mockConf: RuntimeConfig = _

  override def beforeEach(): Unit = {
    super.beforeEach()

    // Create Spark session
    mockSparkSession = mock[SparkSession]
    mockConf = mock[RuntimeConfig]
    mockSparkContext = mock[SparkContext]
    metricsSparkListener = mock[MetricsSparkListener]
    mockStreamingQueryManager = mock[StreamingQueryManager]

    // Create mocks
    mockStatementExecutionManager = mock[StatementExecutionManager]
    mockOSClient = mock[OSClient]
    mockSessionManager = mock[SessionManager]

    // Create test FlintStatement
    flintStatement = new FlintStatement(
      "RUNNING",
      "SELECT 1",
      "test-statement-id",
      "test-query-id",
      "sql",
      System.currentTimeMillis()
    )

    // Create JobOperator instance

  }

  test(
    "start should call instantiateStatementExecutionManager and execute statement successfully"
  ) {

    try {
      when(mockSparkSession.conf).thenReturn(mockConf)
      when(mockConf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false"))
        .thenReturn("true")
      when(mockConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""))
        .thenReturn("")
      // Add missing REQUEST_INDEX configuration that SessionManagerImpl requires
      when(mockConf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
        .thenReturn("test-session-index")
      when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
      doNothing().when(mockSparkContext).addSparkListener(any())
      when(mockSparkSession.streams).thenReturn(mockStreamingQueryManager)
      when(mockStreamingQueryManager.active).thenReturn(
        Array.empty[StreamingQuery]
      )

      when(mockStatementExecutionManager.prepareStatementExecution())
        .thenReturn(null)
      val spark = SparkSession
        .builder()
        .appName("EmptyDataFrame")
        .master("local[*]")
        .getOrCreate()

      // Create a small DataFrame matching the schema
      val schema = StructType(
        Array(
          StructField("id", IntegerType, true),
          StructField("name", StringType, true),
          StructField("age", IntegerType, true)
        )
      )

      val data = Seq(Row(1, "Alice", 30))
      val df =
        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
      when(mockStatementExecutionManager.executeStatement(flintStatement))
        .thenReturn(df)
      val getResponse = mock[GetResponse]
      when(mockOSClient.getIndexMetadata(resultIndex)).thenReturn("check")
      when(mockOSClient.getDoc(any(), any())).thenReturn(getResponse)

      val jobOperator = JobOperator(
        applicationId,
        jobId,
        mockSparkSession,
        flintStatement,
        dataSource,
        resultIndex,
        jobType,
        streamingRunningCount,
        statementRunningCount
      )
      jobOperator.start()
    } catch {
      case e: Exception =>
        println(s"job failed: ${e.getMessage}")
        e.printStackTrace()
    }

    verify(mockStatementExecutionManager, times(2)).updateStatement(
      flintStatement
    )
  }
}

Even after mocking the OSClient method, I am getting the same error:

25/07/14 23:29:29 ERROR SingleStatementExecutionManager: Failed to verify existing mapping: Failed to get OpenSearch index mapping for test-result-index
java.lang.IllegalStateException: Failed to get OpenSearch index mapping for test-result-index
        at org.apache.spark.sql.OSClient.$anonfun$getIndexMetadata$1(OSClient.scala:52)
        at org.apache.spark.sql.OSClient.using(OSClient.scala:108)
        at org.apache.spark.sql.OSClient.getIndexMetadata(OSClient.scala:43)
        at org.apache.spark.sql.FlintJobExecutor.checkAndCreateIndex(FlintJobExecutor.scala:377)
        at org.apache.spark.sql.FlintJobExecutor.checkAndCreateIndex$(FlintJobExecutor.scala:375)
        at org.apache.spark.sql.SingleStatementExecutionManager.checkAndCreateIndex(SingleStatementExecutionManagerImpl.scala:19)
        at org.apache.spark.sql.SingleStatementExecutionManager.prepareStatementExecution(SingleStatementExecutionManagerImpl.scala:28)
        at org.apache.spark.sql.JobOperator.$anonfun$start$1(JobOperator.scala:94)
        at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
        at scala.util.Success.$anonfun$map$1(Try.scala:255)
        at scala.util.Success.map(Try.scala:213)
        at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
        at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
        at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
        at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
        at java.base/java.lang.Thread.run(Thread.java:829)

@thorkous
Contributor Author

Hi @noCharger,

Is there any reason we didn't have unit tests for JobOperator before this? Since I am creating the JobOperator unit test, I am facing multiple issues mocking osClient. Since you have written the majority of JobOperator, did you face any such issues? If not, why didn't we have unit tests for JobOperator before this?

@noCharger
Collaborator

Hi noCharger,

Is there any reason that we didn't have any unit tests for JobOperator before this? Since I am creating the JobOperator unit test, I am facing multiple issues in mocking osClient. Since you have written the majority of JobOperator, did you face any such issues? If not, then why didn't we have any unit tests for job operator before this?

You can refer to https://github.com/opensearch-project/opensearch-spark/blob/9aad67dc9e1f899f04509d3fe2dc709652ad4b92/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala and https://github.com/opensearch-project/opensearch-spark/blob/9aad67dc9e1f899f04509d3fe2dc709652ad4b92/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala

@thorkous
Contributor Author

Hi @noCharger,
Is there any reason we didn't have unit tests for JobOperator before this? Since I am creating the JobOperator unit test, I am facing multiple issues mocking osClient. Since you have written the majority of JobOperator, did you face any such issues? If not, why didn't we have unit tests for JobOperator before this?

You can refer to https://github.com/opensearch-project/opensearch-spark/blob/9aad67dc9e1f899f04509d3fe2dc709652ad4b92/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala and https://github.com/opensearch-project/opensearch-spark/blob/9aad67dc9e1f899f04509d3fe2dc709652ad4b92/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala

Hi @noCharger,

In the above example, sessionManager is passed as a variable, and hence we are able to use the osClient that is passed in sessionManager; but in our case sessionManager is not passed, and hence we are unable to mock osClient.

References:

  1. https://stackoverflow.com/questions/55329877/mocking-new-instance-creation-inside-testing-class-using-mockito
  2. https://groups.google.com/g/mockito/c/7wYX4_2m6NU

@noCharger
Collaborator

but in our case sessionManager is not passed, and hence we are unable to mock osClient.

According to the stack trace, the invocation is from SingleStatementExecutionManager, which accepts osClient as an input. You can pass the mock.

@thorkous
Contributor Author

but in our case sessionManager is not passed, and hence we are unable to mock osClient.

According to the stack trace, the invocation is from SingleStatementExecutionManager, which accepts osClient as an input. You can pass the mock.

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql

import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY

import java.util.concurrent.atomic.AtomicInteger
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.opensearch.action.get.{GetRequest, GetResponse}
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.metrics.MetricsSparkListener
import org.scalatest.BeforeAndAfterEach
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
import org.apache.spark.sql.types._

import scala.concurrent.duration.{Duration, MINUTES}

class JobOperatorTest
    extends SparkFunSuite
    with MockitoSugar
    with Matchers
    with BeforeAndAfterEach {

  private var mockStatementExecutionManager: StatementExecutionManager = _
  private var mockOSClient: OSClient = _
  private var mockSessionManager: SessionManager = _
  private var flintStatement: FlintStatement = _
  private var mockSparkSession: SparkSession = _
  private var mockSparkContext: SparkContext = _
  private var metricsSparkListener: MetricsSparkListener = _
  private var mockStreamingQueryManager: StreamingQueryManager = _
  private val INTERACTIVE_JOB_TYPE = "interactive"

  private val applicationId = "test-app-id"
  private val jobId = "test-job-id"
  private val dataSource = "test-datasource"
  private val resultIndex = "test-result-index"
  private val jobType = "batch"
  private val streamingRunningCount = new AtomicInteger(0)
  private val statementRunningCount = new AtomicInteger(0)
  private var mockConf: RuntimeConfig = _

  override def beforeEach(): Unit = {
    super.beforeEach()

    // Create Spark session
    mockSparkSession = mock[SparkSession]
    mockConf = mock[RuntimeConfig]
    mockSparkContext = mock[SparkContext]
    metricsSparkListener = mock[MetricsSparkListener]
    mockStreamingQueryManager = mock[StreamingQueryManager]

    // Create mocks
    mockStatementExecutionManager = mock[StatementExecutionManager]
    mockOSClient = mock[OSClient]
    mockSessionManager = mock[SessionManager]

    // Create test FlintStatement
    flintStatement = new FlintStatement(
      "RUNNING",
      "SELECT 1",
      "test-statement-id",
      "test-query-id",
      "sql",
      System.currentTimeMillis()
    )

    // Create JobOperator instance

  }

  test(
    "start should call instantiateStatementExecutionManager and execute statement successfully"
  ) {

    try {

      when(mockSparkSession.conf).thenReturn(mockConf)
      when(mockConf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false"))
        .thenReturn("true")
      when(mockConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""))
        .thenReturn("")
      // Add missing REQUEST_INDEX configuration that SessionManagerImpl requires
      val sessionId = "someSessionId"
      val commandContext = CommandContext(
        applicationId,
        jobId,
        mockSparkSession,
        dataSource,
        INTERACTIVE_JOB_TYPE,
        sessionId,
        mockSessionManager,
        Duration(10, MINUTES),
        60,
        60,
        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
      )

      val singleStatementExecutionManager = new SingleStatementExecutionManager(
        commandContext,
        resultIndex,
        mockOSClient
      )

      when(mockConf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
        .thenReturn("test-session-index")
      when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
      doNothing().when(mockSparkContext).addSparkListener(any())
      when(mockSparkSession.streams).thenReturn(mockStreamingQueryManager)
      when(mockStreamingQueryManager.active).thenReturn(
        Array.empty[StreamingQuery]
      )

      when(mockStatementExecutionManager.prepareStatementExecution())
        .thenReturn(null)
      when(singleStatementExecutionManager.prepareStatementExecution())
        .thenReturn(null)
      val spark = SparkSession
        .builder()
        .appName("EmptyDataFrame")
        .master("local[*]")
        .getOrCreate()

      // Create a small DataFrame matching the schema
      val schema = StructType(
        Array(
          StructField("id", IntegerType, true),
          StructField("name", StringType, true),
          StructField("age", IntegerType, true)
        )
      )

      val data = Seq(Row(1, "Alice", 30))
      val df =
        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
      when(mockStatementExecutionManager.executeStatement(flintStatement))
        .thenReturn(df)
      val getResponse = mock[GetResponse]
      when(mockOSClient.getIndexMetadata(resultIndex)).thenReturn("check")
      when(mockOSClient.getDoc(any(), any())).thenReturn(getResponse)

      val jobOperator = JobOperator(
        applicationId,
        jobId,
        mockSparkSession,
        flintStatement,
        dataSource,
        resultIndex,
        jobType,
        streamingRunningCount,
        statementRunningCount
      )
      jobOperator.start()
    } catch {
      case e: Exception =>
        println(s"job failed: ${e.getMessage}")
        e.printStackTrace()
    }

    verify(mockStatementExecutionManager, times(2)).updateStatement(
      flintStatement
    )
  }
}

Even after constructing a SingleStatementExecutionManager directly, I am getting the same error.

The SingleStatementExecutionManager is created inside the JobOperator, and the osClient it uses is also instantiated within the JobOperator. This means that even if I mock osClient, the mock won't be used, since mocking only works when the object is either passed into the function under test or injected into the class externally.
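
To illustrate the point, a self-contained sketch with hypothetical names (not the real classes):

trait Client {
  def getIndexMetadata(index: String): String
}

class InternalWiring {
  // Dependency constructed inside the class: a test cannot swap in a mock for this field.
  private val client: Client = new Client {
    override def getIndexMetadata(index: String): String =
      throw new IllegalStateException(s"Failed to get OpenSearch index mapping for $index")
  }
  def prepare(index: String): String = client.getIndexMetadata(index)
}

class InjectedWiring(client: Client) {
  // Dependency passed in: a test can supply mock[Client] or a hand-written stub.
  def prepare(index: String): String = client.getIndexMetadata(index)
}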

@noCharger
Collaborator

noCharger commented Jul 29, 2025

Even after constructing a SingleStatementExecutionManager directly, I am getting the same error. [...] The SingleStatementExecutionManager is created inside the JobOperator, and the osClient it uses is also instantiated within the JobOperator.

This is because the SingleStatementExecutionManager you created is never used in your test case. As noted in my previous comment, you need to override instantiateStatementExecutionManager: #1217 (comment)

Signed-off-by: Koustubh <thorkous@amazon.com>
@thorkous
Contributor Author

thorkous commented Aug 1, 2025

Updated the test.

mockQueryResultWriter
}

override def start(): Unit = {
Collaborator

Why override start()?

Contributor Author

There was no other way to mock the query result writer and the DataFrame, as they are initialised inside JobOperator only, and hence I had to do this.

Comment on lines +90 to +93
override def writeDataFrameToOpensearch(
resultData: DataFrame,
resultIndex: String,
osClient: OSClient): Unit = {}
Collaborator

Is it necessary to override this instead of mocking the result?

  (jobOperator.writeDataFrameToOpensearch _)
    .expects(*, *, *)
    .returning(())

Contributor Author

We can't mock JobOperator, because it is the class under test. writeDataFrameToOpensearch internally calls osClient.doesIndexExist(resultIndex). I tried adding mockOSClient and mocking this step, but the mock is not being picked up.

@noCharger noCharger added backport 0.x Backport to 0.x branch (stable branch) backport 0.7 labels Aug 4, 2025
@noCharger noCharger merged commit 57dccb8 into opensearch-project:main Aug 4, 2025
5 of 7 checks passed
opensearch-trigger-bot bot pushed a commit that referenced this pull request Aug 4, 2025
Signed-off-by: Koustubh <thorkous@amazon.com>
Co-authored-by: Koustubh <thorkous@amazon.com>
(cherry picked from commit 57dccb8)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
opensearch-trigger-bot bot pushed a commit that referenced this pull request Aug 4, 2025
Signed-off-by: Koustubh <thorkous@amazon.com>
Co-authored-by: Koustubh <thorkous@amazon.com>
(cherry picked from commit 57dccb8)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>