Skip to content

Commit a1233a4

Browse files
opensearch-trigger-bot[bot]github-actions[bot]aaarone90
authored
Add account dimension to bytesRead metrics (#1124) (#1132)
* Add account dimension to bytesRead metrics
* Fixing Scala formatting issues
* Adding more Unit tests
* Changing accountId to static variable, as recommended in the comments
* Fixing formatting issues
* Fixing scala formatting issues on MetricsSparkListenerTest.scala
* Fixing scala formatting

---------

(cherry picked from commit f236ec7)

Signed-off-by: Aaron Alvarez <aaarone@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aaron Alvarez <aaarone@amazon.com>
1 parent 2439cd7 commit a1233a4

File tree

3 files changed

+279
-4
lines changed

3 files changed

+279
-4
lines changed

flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,22 @@ public class FlintOptions implements Serializable {
110110
public static final String DEFAULT_SUPPORT_SHARD = "true";
111111

112112
private static final String UNKNOWN = "UNKNOWN";
113+
114+
/**
115+
* Cached AWS account ID from the cluster name environment variable.
116+
*/
117+
public static final String AWS_ACCOUNT_ID = initializeAWSAccountId();
118+
119+
/**
120+
* Initialize the AWS account ID from the cluster name environment variable.
121+
* This is called once during class loading.
122+
* @return the AWS account ID or "UNKNOWN" if not available
123+
*/
124+
private static String initializeAWSAccountId() {
125+
String clusterName = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN);
126+
String[] parts = clusterName.split(":");
127+
return parts.length == 2 ? parts[0] : UNKNOWN;
128+
}
113129

114130
public static final String BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED = "write.bulk.rate_limit_per_node.enabled";
115131
public static final String DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED = "false";
@@ -208,9 +224,7 @@ public String getDataSourceName() {
208224
* @return the AWS accountId
209225
*/
210226
public String getAWSAccountId() {
211-
String clusterName = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN);
212-
String[] parts = clusterName.split(":");
213-
return parts.length == 2 ? parts[0] : UNKNOWN;
227+
return AWS_ACCOUNT_ID;
214228
}
215229

216230
public String getSystemIndexName() {

flint-core/src/main/scala/org/opensearch/flint/core/metrics/MetricsSparkListener.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
package org.opensearch.flint.core.metrics
77

8+
import org.opensearch.flint.core.FlintOptions
9+
import org.opensearch.flint.core.metrics.reporter.{DimensionedName, DimensionedNameBuilder}
10+
811
import org.apache.spark.internal.Logging
912
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskEnd}
1013
import org.apache.spark.sql.SparkSession
@@ -43,12 +46,34 @@ class MetricsSparkListener extends SparkListener with Logging {
4346
logInfo(s"Input: totalBytesRead=${bytesRead}, totalRecordsRead=${recordsRead}")
4447
logInfo(s"Output: totalBytesWritten=${bytesWritten}, totalRecordsWritten=${recordsWritten}")
4548
logInfo(s"totalJvmGcTime=${totalJvmGcTime}")
46-
MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, bytesRead)
49+
// Use the dimensioned metric name for bytesRead
50+
MetricsUtil.addHistoricGauge(
51+
addAccountDimension(MetricConstants.INPUT_TOTAL_BYTES_READ),
52+
bytesRead)
53+
// Original metrics remain unchanged
4754
MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_RECORDS_READ, recordsRead)
4855
MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, bytesWritten)
4956
MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_RECORDS_WRITTEN, recordsWritten)
5057
MetricsUtil.addHistoricGauge(MetricConstants.TOTAL_JVM_GC_TIME_METRIC, totalJvmGcTime)
5158
}
59+
60+
/**
61+
* Adds an AWS account dimension to the given metric name.
62+
*
63+
* @param metricName
64+
* The name of the metric to which the account dimension will be added
65+
* @return
66+
* A string representation of the metric name with the account dimension attached, formatted
67+
* as a dimensioned name
68+
*/
69+
def addAccountDimension(metricName: String): String = {
  // The account ID is read from the FlintOptions class-level constant, so no FlintOptions
  // instance needs to be created here.
  val withAccount = DimensionedName
    .withName(metricName)
    .withDimension("accountId", FlintOptions.AWS_ACCOUNT_ID)
  withAccount.build().toString()
}
5277
}
5378

5479
object MetricsSparkListener {
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.flint.core.metrics
7+
8+
import java.util.{HashMap => JHashMap}
9+
10+
import org.junit.jupiter.api.{BeforeEach, Test}
11+
import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
12+
import org.mockito.ArgumentMatchers.any
13+
import org.mockito.Mockito.{mock, when}
14+
import org.opensearch.flint.core.FlintOptions
15+
import org.opensearch.flint.core.metrics.reporter.DimensionedName
16+
17+
class MetricsSparkListenerTest {

  /**
   * Listener whose account dimension comes from the constructor argument instead of
   * FlintOptions.AWS_ACCOUNT_ID, so the encoded metric name can be checked for a chosen account
   * ID. Note that it replaces the production addAccountDimension, so tests using it only verify
   * the DimensionedName encoding, not MetricsSparkListener itself.
   */
  class TestableMetricsSparkListener(accountId: String) extends MetricsSparkListener {
    override def addAccountDimension(metricName: String): String = {
      DimensionedName
        .withName(metricName)
        .withDimension("accountId", accountId)
        .build()
        .toString()
    }
  }

  /** Listener that exposes FlintOptions.AWS_ACCOUNT_ID through a public accessor. */
  class GetClusterAccountIdTestableListener extends MetricsSparkListener {
    def publicGetClusterAccountId(): String = FlintOptions.AWS_ACCOUNT_ID
  }

  private var metricsSparkListener: MetricsSparkListener = _

  @BeforeEach
  def setup(): Unit = {
    // Production listener, used by the tests that exercise the real addAccountDimension
    metricsSparkListener = new MetricsSparkListener()
  }

  /**
   * Segments of the FLINT_CLUSTER_NAME environment variable split on ":", or None when the
   * variable is not set in this JVM. The environment cannot be changed from inside the test, so
   * the environment-dependent tests below only run when their scenario is actually in effect.
   */
  private def clusterNameSegments: Option[Array[String]] =
    sys.env.get("FLINT_CLUSTER_NAME").map(_.split(":"))

  /** Aborts (skips) the current test with the given reason unless the condition holds. */
  private def assumeScenario(condition: Boolean, reason: String): Unit =
    org.junit.jupiter.api.Assumptions.assumeTrue(condition, reason)

  /**
   * Asserts that an encoded dimensioned name carries the metric name and exactly one dimension,
   * "accountId", with the expected value, both in its string form and after decoding.
   */
  private def assertAccountDimension(
      encoded: String,
      metricName: String,
      expectedAccountId: String): Unit = {
    assertTrue(
      encoded.contains(metricName),
      s"Result should contain the original metric name: $encoded")
    assertTrue(
      encoded.contains("accountId##" + expectedAccountId),
      s"Result should contain the account ID dimension: $encoded")

    // Decode the result to verify the structure
    val dimensionedName = DimensionedName.decode(encoded)
    assertEquals(metricName, dimensionedName.getName())
    assertEquals(1, dimensionedName.getDimensions().size())

    val dimension = dimensionedName.getDimensions().iterator().next()
    assertEquals("accountId", dimension.getName())
    assertEquals(expectedAccountId, dimension.getValue())
  }

  /** Encodes a metric name with an injected account ID and checks the resulting dimension. */
  private def assertInjectedAccountDimension(metricName: String, accountId: String): Unit = {
    val testListener = new TestableMetricsSparkListener(accountId)
    assertAccountDimension(testListener.addAccountDimension(metricName), metricName, accountId)
  }

  /**
   * The production addAccountDimension must tag the metric with FlintOptions.AWS_ACCOUNT_ID.
   * This is the only addAccountDimension test that runs the real MetricsSparkListener method.
   */
  @Test
  def testAddAccountDimensionUsesFlintOptionsAccountId(): Unit = {
    val metricName = "test.metric"
    val result = metricsSparkListener.addAccountDimension(metricName)
    assertAccountDimension(result, metricName, FlintOptions.AWS_ACCOUNT_ID)
  }

  /** A numeric account ID is encoded as the single "accountId" dimension. */
  @Test
  def testAddAccountDimension(): Unit = {
    assertInjectedAccountDimension("test.metric", "123456789012")
  }

  /** Metric names containing dots and dashes survive the encode/decode round trip. */
  @Test
  def testAddAccountDimensionWithComplexMetricName(): Unit = {
    assertInjectedAccountDimension("test.metric.with.dots-and-dashes", "123456789012")
  }

  /**
   * Simulates a missing FLINT_CLUSTER_NAME by injecting "UNKNOWN" as the account ID; the
   * environment itself is not modified.
   */
  @Test
  def testAddAccountDimensionWithMissingEnvVar(): Unit = {
    assertInjectedAccountDimension("test.metric", "UNKNOWN")
  }

  /**
   * Simulates a malformed FLINT_CLUSTER_NAME by injecting "UNKNOWN" as the account ID; the
   * environment itself is not modified.
   */
  @Test
  def testAddAccountDimensionWithInvalidClusterNameFormat(): Unit = {
    assertInjectedAccountDimension("test.metric", "UNKNOWN")
  }

  /**
   * The instance getter FlintOptions.getAWSAccountId must agree with the static
   * FlintOptions.AWS_ACCOUNT_ID. A Mockito mock is used only to obtain an instance; the real
   * getter implementation is invoked.
   */
  @Test
  def testGetAWSAccountIdReturnsCachedAccountId(): Unit = {
    val flintOptions = mock(classOf[FlintOptions])
    when(flintOptions.getAWSAccountId()).thenCallRealMethod()
    assertEquals(FlintOptions.AWS_ACCOUNT_ID, flintOptions.getAWSAccountId())
  }

  /**
   * When FLINT_CLUSTER_NAME is set to a well-formed "account:cluster" value, the account segment
   * is what the listener sees. Skipped when the variable is unset or malformed in this JVM.
   */
  @Test
  def testGetClusterAccountId(): Unit = {
    val accountFromEnv = clusterNameSegments.collect {
      case Array(account, _) if account.trim.nonEmpty => account
    }
    assumeScenario(
      accountFromEnv.isDefined,
      "FLINT_CLUSTER_NAME is not set to a well-formed value in this JVM")

    val result = new GetClusterAccountIdTestableListener().publicGetClusterAccountId()
    accountFromEnv.foreach { expectedAccountId =>
      assertEquals(expectedAccountId, result, "Should return the expected account ID")
    }
  }

  /**
   * When FLINT_CLUSTER_NAME is not set, the account ID must be "UNKNOWN". Skipped when the
   * variable is set in this JVM.
   */
  @Test
  def testGetClusterAccountIdWithMissingEnvVar(): Unit = {
    assumeScenario(clusterNameSegments.isEmpty, "FLINT_CLUSTER_NAME is set in this JVM")

    val result = new GetClusterAccountIdTestableListener().publicGetClusterAccountId()
    assertEquals("UNKNOWN", result, "Should return UNKNOWN when env var is missing")
  }

  /**
   * When FLINT_CLUSTER_NAME does not split into exactly two ":"-separated segments, the account
   * ID must be "UNKNOWN". Skipped unless the variable is set to such a value in this JVM.
   */
  @Test
  def testGetClusterAccountIdWithInvalidClusterNameFormat(): Unit = {
    assumeScenario(
      clusterNameSegments.exists(_.length != 2),
      "FLINT_CLUSTER_NAME is not set to a malformed value in this JVM")

    val result = new GetClusterAccountIdTestableListener().publicGetClusterAccountId()
    assertEquals("UNKNOWN", result, "Should return UNKNOWN for invalid format")
  }
}

0 commit comments

Comments
 (0)