diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 918cb790bdc9a..29e5cecb31799 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream} import java.net.URI -import java.nio.file.Paths +import java.nio.file.{Files, Paths} import java.util.Properties import java.util.concurrent.ConcurrentHashMap @@ -760,59 +760,64 @@ class ClientSuite extends SparkFunSuite test("YARN AM JavaOptions") { Seq("client", "cluster").foreach { deployMode => - withTempDir { stagingDir => - val sparkConf = new SparkConfWithEnv( - Map("SPARK_HOME" -> System.getProperty("spark.test.home"))) - .set(SUBMIT_DEPLOY_MODE, deployMode) - .set(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS, "-Dx=1 -Dy=2") - .set(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dz=3") - .set(AM_DEFAULT_JAVA_OPTIONS, "-Da=1 -Db=2") - .set(AM_JAVA_OPTIONS, "-Dc=3") - - val client = createClient(sparkConf) - val appIdField = classOf[Client] - .getDeclaredField("org$apache$spark$deploy$yarn$Client$$appId") - appIdField.setAccessible(true) - // A dummy ApplicationId impl, only `toString` method will be called - // in Client.createContainerLaunchContext - appIdField.set(client, new ApplicationId { - override def getId: Int = 1 - override def setId(i: Int): Unit = {} - override def getClusterTimestamp: Long = 1770077136288L - override def setClusterTimestamp(l: Long): Unit = {} - override def build(): Unit = {} - override def toString: String = "application_1770077136288_0001" - }) - val stagingDirPathField = classOf[Client] - .getDeclaredField("org$apache$spark$deploy$yarn$Client$$stagingDirPath") - stagingDirPathField.setAccessible(true) - stagingDirPathField.set(client, new Path(stagingDir.getAbsolutePath)) - val _createContainerLaunchContext = - PrivateMethod[ContainerLaunchContext](Symbol("createContainerLaunchContext")) - val containerLaunchContext = client invokePrivate _createContainerLaunchContext() - - val commands = containerLaunchContext.getCommands.asScala - deployMode match { - case "client" => - // In client mode, spark.yarn.am.defaultJavaOptions and spark.yarn.am.extraJavaOptions - // should be set in AM container command JAVA_OPTIONS - commands should contain("'-Da=1'") - commands should contain("'-Db=2'") - commands should contain("'-Dc=3'") - commands should not contain "'-Dx=1'" - commands should not contain "'-Dy=2'" - commands should not contain "'-Dz=3'" - case "cluster" => - // In cluster mode, spark.driver.defaultJavaOptions and spark.driver.extraJavaOptions - // should be set in AM container command JAVA_OPTIONS - commands should not contain "'-Da=1'" - commands should not contain "'-Db=2'" - commands should not contain "'-Dc=3'" - commands should contain ("'-Dx=1'") - commands should contain ("'-Dy=2'") - commands should contain ("'-Dz=3'") - case m => - fail(s"Unexpected deploy mode: $m") + withTempDir { sparkHome => + // Create jars dir and RELEASE file to avoid IllegalStateException. + Files.createDirectory(Paths.get(sparkHome.getPath, "jars")) + Files.createFile(Paths.get(sparkHome.getPath, "RELEASE")) + + withTempDir { stagingDir => + val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> sparkHome.getAbsolutePath)) + .set(SUBMIT_DEPLOY_MODE, deployMode) + .set(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS, "-Dx=1 -Dy=2") + .set(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dz=3") + .set(AM_DEFAULT_JAVA_OPTIONS, "-Da=1 -Db=2") + .set(AM_JAVA_OPTIONS, "-Dc=3") + + val client = createClient(sparkConf) + val appIdField = classOf[Client] + .getDeclaredField("org$apache$spark$deploy$yarn$Client$$appId") + appIdField.setAccessible(true) + // A dummy ApplicationId impl, only `toString` method will be called + // in Client.createContainerLaunchContext + appIdField.set(client, new ApplicationId { + override def getId: Int = 1 + override def setId(i: Int): Unit = {} + override def getClusterTimestamp: Long = 1770077136288L + override def setClusterTimestamp(l: Long): Unit = {} + override def build(): Unit = {} + override def toString: String = "application_1770077136288_0001" + }) + val stagingDirPathField = classOf[Client] + .getDeclaredField("org$apache$spark$deploy$yarn$Client$$stagingDirPath") + stagingDirPathField.setAccessible(true) + stagingDirPathField.set(client, new Path(stagingDir.getAbsolutePath)) + val _createContainerLaunchContext = + PrivateMethod[ContainerLaunchContext](Symbol("createContainerLaunchContext")) + val containerLaunchContext = client invokePrivate _createContainerLaunchContext() + + val commands = containerLaunchContext.getCommands.asScala + deployMode match { + case "client" => + // In client mode, spark.yarn.am.defaultJavaOptions and spark.yarn.am.extraJavaOptions + // should be set in AM container command JAVA_OPTIONS + commands should contain("'-Da=1'") + commands should contain("'-Db=2'") + commands should contain("'-Dc=3'") + commands should not contain "'-Dx=1'" + commands should not contain "'-Dy=2'" + commands should not contain "'-Dz=3'" + case "cluster" => + // In cluster mode, spark.driver.defaultJavaOptions and spark.driver.extraJavaOptions + // should be set in AM container command JAVA_OPTIONS + commands should not contain "'-Da=1'" + commands should not contain "'-Db=2'" + commands should not contain "'-Dc=3'" + commands should contain ("'-Dx=1'") + commands should contain ("'-Dy=2'") + commands should contain ("'-Dz=3'") + case m => + fail(s"Unexpected deploy mode: $m") + } } } }