Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.deploy.yarn

import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream}
import java.net.URI
import java.nio.file.Paths
import java.nio.file.{Files, Paths}
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

Expand Down Expand Up @@ -760,59 +760,64 @@ class ClientSuite extends SparkFunSuite

test("YARN AM JavaOptions") {
Seq("client", "cluster").foreach { deployMode =>
withTempDir { stagingDir =>
val sparkConf = new SparkConfWithEnv(
Map("SPARK_HOME" -> System.getProperty("spark.test.home")))
.set(SUBMIT_DEPLOY_MODE, deployMode)
.set(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS, "-Dx=1 -Dy=2")
.set(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dz=3")
.set(AM_DEFAULT_JAVA_OPTIONS, "-Da=1 -Db=2")
.set(AM_JAVA_OPTIONS, "-Dc=3")

val client = createClient(sparkConf)
val appIdField = classOf[Client]
.getDeclaredField("org$apache$spark$deploy$yarn$Client$$appId")
appIdField.setAccessible(true)
// A dummy ApplicationId impl, only `toString` method will be called
// in Client.createContainerLaunchContext
appIdField.set(client, new ApplicationId {
override def getId: Int = 1
override def setId(i: Int): Unit = {}
override def getClusterTimestamp: Long = 1770077136288L
override def setClusterTimestamp(l: Long): Unit = {}
override def build(): Unit = {}
override def toString: String = "application_1770077136288_0001"
})
val stagingDirPathField = classOf[Client]
.getDeclaredField("org$apache$spark$deploy$yarn$Client$$stagingDirPath")
stagingDirPathField.setAccessible(true)
stagingDirPathField.set(client, new Path(stagingDir.getAbsolutePath))
val _createContainerLaunchContext =
PrivateMethod[ContainerLaunchContext](Symbol("createContainerLaunchContext"))
val containerLaunchContext = client invokePrivate _createContainerLaunchContext()

val commands = containerLaunchContext.getCommands.asScala
deployMode match {
case "client" =>
// In client mode, spark.yarn.am.defaultJavaOptions and spark.yarn.am.extraJavaOptions
// should be set in AM container command JAVA_OPTIONS
commands should contain("'-Da=1'")
commands should contain("'-Db=2'")
commands should contain("'-Dc=3'")
commands should not contain "'-Dx=1'"
commands should not contain "'-Dy=2'"
commands should not contain "'-Dz=3'"
case "cluster" =>
// In cluster mode, spark.driver.defaultJavaOptions and spark.driver.extraJavaOptions
// should be set in AM container command JAVA_OPTIONS
commands should not contain "'-Da=1'"
commands should not contain "'-Db=2'"
commands should not contain "'-Dc=3'"
commands should contain ("'-Dx=1'")
commands should contain ("'-Dy=2'")
commands should contain ("'-Dz=3'")
case m =>
fail(s"Unexpected deploy mode: $m")
withTempDir { sparkHome =>
// Create jars dir and RELEASE file to avoid IllegalStateException.
Files.createDirectory(Paths.get(sparkHome.getPath, "jars"))
Files.createFile(Paths.get(sparkHome.getPath, "RELEASE"))

withTempDir { stagingDir =>
val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> sparkHome.getAbsolutePath))
.set(SUBMIT_DEPLOY_MODE, deployMode)
.set(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS, "-Dx=1 -Dy=2")
.set(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dz=3")
.set(AM_DEFAULT_JAVA_OPTIONS, "-Da=1 -Db=2")
.set(AM_JAVA_OPTIONS, "-Dc=3")

val client = createClient(sparkConf)
val appIdField = classOf[Client]
.getDeclaredField("org$apache$spark$deploy$yarn$Client$$appId")
appIdField.setAccessible(true)
// A dummy ApplicationId impl, only `toString` method will be called
// in Client.createContainerLaunchContext
appIdField.set(client, new ApplicationId {
override def getId: Int = 1
override def setId(i: Int): Unit = {}
override def getClusterTimestamp: Long = 1770077136288L
override def setClusterTimestamp(l: Long): Unit = {}
override def build(): Unit = {}
override def toString: String = "application_1770077136288_0001"
})
val stagingDirPathField = classOf[Client]
.getDeclaredField("org$apache$spark$deploy$yarn$Client$$stagingDirPath")
stagingDirPathField.setAccessible(true)
stagingDirPathField.set(client, new Path(stagingDir.getAbsolutePath))
val _createContainerLaunchContext =
PrivateMethod[ContainerLaunchContext](Symbol("createContainerLaunchContext"))
val containerLaunchContext = client invokePrivate _createContainerLaunchContext()

val commands = containerLaunchContext.getCommands.asScala
deployMode match {
case "client" =>
// In client mode, spark.yarn.am.defaultJavaOptions and spark.yarn.am.extraJavaOptions
// should be set in AM container command JAVA_OPTIONS
commands should contain("'-Da=1'")
commands should contain("'-Db=2'")
commands should contain("'-Dc=3'")
commands should not contain "'-Dx=1'"
commands should not contain "'-Dy=2'"
commands should not contain "'-Dz=3'"
case "cluster" =>
// In cluster mode, spark.driver.defaultJavaOptions and spark.driver.extraJavaOptions
// should be set in AM container command JAVA_OPTIONS
commands should not contain "'-Da=1'"
commands should not contain "'-Db=2'"
commands should not contain "'-Dc=3'"
commands should contain ("'-Dx=1'")
commands should contain ("'-Dy=2'")
commands should contain ("'-Dz=3'")
case m =>
fail(s"Unexpected deploy mode: $m")
}
}
}
}
Expand Down