Browse source

Write to Phoenix via OdpsOps

许家凯 4 years ago
parent
commit
863d0e998a
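
In short, the commit reads a MaxCompute (ODPS) table through OdpsOps, registers the Phoenix target table as a Spark SQL JDBC table with the new PhoenixUtil helper, and appends the rows into it. A condensed sketch of that flow, taken from the TestOps2Phoenix test added below (not a drop-in program; it assumes the ODPS and Phoenix environment that SparkUtils.InitEnv configures):

    import com.aliyun.odps.TableSchema
    import com.aliyun.odps.data.Record
    import com.winhc.bigdata.spark.utils.{PhoenixUtil, SparkUtils}
    import org.apache.spark.odps.OdpsOps
    import scala.collection.mutable

    val spark = SparkUtils.InitEnv("ops to phoenix", mutable.Map("spark.sql.catalogImplementation" -> "in-memory"))
    val odpsOps = new OdpsOps(spark.sparkContext)
    import spark.implicits._

    // 1. Read the ODPS source table as an RDD of tuples.
    val rdd = odpsOps.readTable("winhc_test_dev", "const_company_category_code",
      (r: Record, schema: TableSchema) => (r.getBigint(0), r.getString(1), r.getString(2), r.getString(3), r.getString(4)))

    // 2. Register the Phoenix table as a Spark SQL JDBC table (thin client over the query server).
    spark.sql(PhoenixUtil.getPhoenixTempView("tmp_table", "CONST_COMPANY_CATEGORY_CODE"))

    // 3. Append into Phoenix; column names are quoted to match the Phoenix column names exactly.
    rdd.toDF("\"id\"", "\"category_code\"", "\"category_str\"", "\"category_str_middle\"", "\"category_str_big\"")
      .write.mode("append")
      .insertInto("tmp_table")

    spark.stop()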

+ 1 - 0
.gitignore

@@ -23,3 +23,4 @@ metastore_db/
 derby.log
 log4j.properties
 dependency-reduced-pom.xml
+spark-warehouse

+ 72 - 10
pom.xml

@@ -17,6 +17,8 @@
         <maven.compiler.source>1.8</maven.compiler.source>
         <maven.compiler.target>1.8</maven.compiler.target>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <phoenix.version>5.0.0-HBase-2.0</phoenix.version>
+        <ali-phoenix.version>5.2.1-HBase-2.x</ali-phoenix.version>
     </properties>
 
 
@@ -104,6 +106,12 @@
             <groupId>com.aliyun.odps</groupId>
             <artifactId>cupid-sdk</artifactId>
             <version>${cupid.sdk.version}</version>
+            <exclusions>
+                <exclusion>
+                    <artifactId>odps-sdk-core</artifactId>
+                    <groupId>com.aliyun.odps</groupId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>com.aliyun.odps</groupId>
@@ -172,6 +180,10 @@
                     <groupId>org.apache.hbase</groupId>
                     <artifactId>hbase-client</artifactId>
                 </exclusion>
+                <exclusion>
+                    <groupId>org.apache.calcite.avatica</groupId>
+                    <artifactId>avatica</artifactId>
+                </exclusion>
             </exclusions>
         </dependency>
         <dependency>
@@ -191,7 +203,67 @@
             <artifactId>elasticsearch-spark-20_2.11</artifactId>
             <version>6.0.0</version>
         </dependency>
+        <!-- <dependency>
+             <groupId>com.aliyun.phoenix</groupId>
+             <artifactId>ali-phoenix-shaded-thin-client</artifactId>
+             <version>${ali-phoenix.version}</version>
+         </dependency>-->
+
+
+        <!-- https://mvnrepository.com/artifact/org.apache.phoenix/phoenix-spark -->
+        <dependency>
+            <groupId>org.apache.phoenix</groupId>
+            <artifactId>phoenix-spark</artifactId>
+            <version>${phoenix.version}</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-core_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-launcher_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-kvstore_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-network-common_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-network-shuffle_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-unsafe_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-tags_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-sql_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-sketch_2.11</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.spark</groupId>
+                    <artifactId>spark-catalyst_2.11</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
 
+        <dependency>
+            <artifactId>protobuf-java</artifactId>
+            <groupId>com.google.protobuf</groupId>
+            <version>3.3.0</version>
+        </dependency>
     </dependencies>
 
     <build>
@@ -215,16 +287,6 @@
                                 </excludes>
                                 <includes>
                                     <include>*:*</include>
-                                    <!--<include>com.aliyun.odps:*</include>
-                                    <include>org.mongodb.*:*</include>
-                                    <include>org.apache.hbase:*</include>
-                                    <include>org.elasticsearch:*</include>-->
-                                    <!--                                    <include>cn.hutool:*</include>-->
-                                    <!--                                    <include>com.aliyun.odps:*</include>-->
-                                    <!--                                    <include>org.mongodb.*:*</include>-->
-                                    <!--                                    <include>org.apache.hbase:*</include>-->
-                                    <!--                                    <include>com.aliyun.hbase:*</include>-->
-                                    <!--                                    <include>com.alibaba.hbase:*</include>-->
                                 </includes>
                             </artifactSet>
                             <filters>

+ 40 - 0
src/main/scala/com/winhc/bigdata/spark/test/TestOps2Phoenix.scala

@@ -0,0 +1,40 @@
+package com.winhc.bigdata.spark.test
+
+import com.aliyun.odps.TableSchema
+import com.aliyun.odps.data.Record
+import com.winhc.bigdata.spark.utils.{PhoenixUtil, SparkUtils}
+import org.apache.spark.odps.OdpsOps
+
+import scala.collection.mutable
+
+/**
+ * @Author: XuJiakai
+ * @Date: 2020/6/3 17:17
+ * @Description:
+ */
+object TestOps2Phoenix {
+  def main(args: Array[String]): Unit = {
+    val map = mutable.Map[String, String](
+      "spark.sql.catalogImplementation" -> "in-memory"
+    )
+
+    val spark = SparkUtils.InitEnv("test ops to phoenix", map)
+    val odpsOps = new OdpsOps(spark.sparkContext)
+    import spark.implicits._
+    import spark._
+
+    val rdd_2 = odpsOps.readTable(
+      "winhc_test_dev",
+      "const_company_category_code",
+      (r: Record, schema: TableSchema) => (r.getBigint(0), r.getString(1), r.getString(2), r.getString(3), r.getString(4))
+    )
+    rdd_2.foreach(println(_))
+    sql(PhoenixUtil.getPhoenixTempView("tmp_table", "CONST_COMPANY_CATEGORY_CODE"))
+    rdd_2.toDF("\"id\"", "\"category_code\"", "\"category_str\"", "\"category_str_middle\"", "\"category_str_big\"")
+      .write.mode("append")
+      .insertInto("tmp_table")
+
+    spark.stop()
+  }
+
+}

+ 9 - 21
src/main/scala/com/winhc/bigdata/spark/test/TestSpark2AliPhoenix.scala

@@ -1,8 +1,5 @@
 package com.winhc.bigdata.spark.test
 
-import java.util.Properties
-
-import com.winhc.bigdata.spark.test.newPhoenixTest.{DB_PHOENIX_DRIVER, DB_PHOENIX_FETCHSIZE, DB_PHOENIX_PASS, DB_PHOENIX_URL, DB_PHOENIX_USER, SQL_QUERY}
 import com.winhc.bigdata.spark.utils.SparkUtils
 
 import scala.collection.mutable
@@ -13,25 +10,19 @@ import scala.collection.mutable
  * @Description:
  */
 object TestSpark2AliPhoenix {
-  private val DB_PHOENIX_DRIVER = "org.apache.phoenix.queryserver.client.Driver"
-  private val DB_PHOENIX_URL = "jdbc:phoenix:thin:url=http://hb-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com:8765;serialization=PROTOBUF"
-  private val DB_PHOENIX_USER = ""
-  private val DB_PHOENIX_PASS = ""
-  private val DB_PHOENIX_FETCHSIZE = "10000"
-
   def main(args: Array[String]): Unit = {
     val map = mutable.Map[String, String](
-      "spark.hadoop.odps.spark.local.partition.amt" -> "100"
+      "spark.sql.catalogImplementation" -> args(0)
     )
     val sparkSession = SparkUtils.InitEnv("scala spark on Phoenix5.x test", map)
-
     val phoenixTableName = "COMPANY_BID_LIST"
     val sparkTableName = "test_spark"
 
     val driver = "org.apache.phoenix.queryserver.client.Driver"
     val url = SparkUtils.PhoenixUrl
 
-/*
+    //    sparkSession.sql("select * from const_company_category_code").show()
+
     sparkSession.sql(s"drop table if exists $sparkTableName")
     val createCmd = "CREATE TABLE " +
       sparkTableName +
@@ -46,15 +37,12 @@ object TestSpark2AliPhoenix {
     sparkSession.sql(createCmd)
     val querySql = "select * from " + sparkTableName + " limit 100"
     sparkSession.sql(querySql).show
-*/
-// JDBC connection properties
-    val SQL_QUERY = " ( SELECT ID,NCID,CID,TITLE FROM COMPANY_BID_LIST limit 10 )  events  "
-    val connProp = new Properties
-    connProp.put("driver", DB_PHOENIX_DRIVER)
-    connProp.put("user", DB_PHOENIX_USER)
-    connProp.put("password", DB_PHOENIX_PASS)
-    connProp.put("fetchsize", DB_PHOENIX_FETCHSIZE)
-    sparkSession.read.jdbc(DB_PHOENIX_URL, SQL_QUERY, connProp).show
 
+//        sparkSession.createDataFrame(Seq(("rowkey", "2", null))).toDF("k", "s", "time")
+//          .write
+//          .mode("append")
+//          .insertInto(sparkTableName)
+
+    sparkSession.stop()
   }
 }

+ 5 - 3
src/main/scala/com/winhc/bigdata/spark/test/TestSpark2Phoenix.scala

@@ -1,6 +1,8 @@
 package com.winhc.bigdata.spark.test
 
-//import org.apache.phoenix.spark._
+import com.winhc.bigdata.spark.utils.SparkUtils
+
+import org.apache.phoenix.spark._
 
 /**
  * @Author: XuJiakai
@@ -9,7 +11,7 @@ package com.winhc.bigdata.spark.test
  */
 object TestSpark2Phoenix {
   def main(args: Array[String]): Unit = {
- /*   val spark = SparkUtils.InitEnv("testSpark2Phoenix")
+    val spark = SparkUtils.InitEnv("testSpark2Phoenix")
 
     val df1 = spark.sqlContext.phoenixTableAsDataFrame("\"company_abnormal_info\"", Seq("rowkey", "ncid"), zkUrl = Some(SparkUtils.PhoenixOptions("\"company_abnormal_info\"")("zkUrl")))
     df1.show()
@@ -18,6 +20,6 @@ object TestSpark2Phoenix {
 
     df.saveToPhoenix(SparkUtils.PhoenixOptions("\"company_abnormal_info\""))
 
-    spark.stop()*/
+    spark.stop()
   }
 }

+ 64 - 0
src/main/scala/com/winhc/bigdata/spark/test/TestSpark2PhoenixJDBC.scala

@@ -0,0 +1,64 @@
+package com.winhc.bigdata.spark.test
+
+import java.util
+import java.util.Properties
+
+import com.winhc.bigdata.spark.utils.SparkUtils
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType, UserDefinedType, VarcharType}
+
+import scala.collection.mutable
+
+/**
+ * @Author: XuJiakai
+ * @Date: 2020/6/3 15:53
+ * @Description:
+ */
+object TestSpark2PhoenixJDBC {
+  private val DB_PHOENIX_DRIVER = "org.apache.phoenix.queryserver.client.Driver"
+  private val DB_PHOENIX_USER = ""
+  private val DB_PHOENIX_PASS = ""
+  private val DB_PHOENIX_FETCHSIZE = "10000"
+
+
+  def main(args: Array[String]): Unit = {
+    val map = mutable.Map[String, String](
+      "spark.hadoop.odps.spark.local.partition.amt" -> "100"
+    )
+    val sparkSession = SparkUtils.InitEnv("scala spark on Phoenix5.x test", map)
+    val DB_PHOENIX_URL = SparkUtils.PhoenixUrl
+
+    // JDBC connection properties
+    val SQL_QUERY = " ( SELECT ID,NCID,CID,TITLE FROM COMPANY_BID_LIST limit 10 )  events  "
+    val connProp = new Properties
+    connProp.put("driver", DB_PHOENIX_DRIVER)
+    connProp.put("user", DB_PHOENIX_USER)
+    connProp.put("password", DB_PHOENIX_PASS)
+    connProp.put("fetchsize", DB_PHOENIX_FETCHSIZE)
+    val pDf = sparkSession.read.jdbc(DB_PHOENIX_URL, SQL_QUERY, connProp)
+    val sc = pDf.schema
+    println(sc)
+    pDf.printSchema()
+    pDf.show()
+    import sparkSession.implicits._
+    import sparkSession._
+
+    var dt:DataType = VarcharType(255)
+//    dt = StringType
+    val schema = StructType(Array(
+      StructField("k", dt, nullable = false),
+      StructField("s", dt, nullable = true),
+      StructField("time", dt, nullable = true)
+    )
+    )
+    val dataList = new util.ArrayList[Row]()
+    dataList.add(Row("1", "2", "null"))
+    val df = createDataFrame(dataList, schema)
+
+    df.write
+      .mode("append")
+      .jdbc(DB_PHOENIX_URL, "TEST_P", connProp)
+
+    sparkSession.stop()
+  }
+}

+ 10 - 0
src/main/scala/com/winhc/bigdata/spark/utils/BaseUtil.scala

@@ -0,0 +1,10 @@
+package com.winhc.bigdata.spark.utils
+
+/**
+ * @Author: XuJiakai
+ * @Date: 2020/6/3 18:49
+ * @Description:
+ */
+object BaseUtil {
+  def isWindows: Boolean = System.getProperty("os.name").contains("Windows")
+}

+ 33 - 0
src/main/scala/com/winhc/bigdata/spark/utils/PhoenixUtil.scala

@@ -0,0 +1,33 @@
+package com.winhc.bigdata.spark.utils
+
+import com.winhc.bigdata.spark.utils.BaseUtil.isWindows
+
+/**
+ * @Author: XuJiakai
+ * @Date: 2020/6/3 18:09
+ * @Description:
+ */
+object PhoenixUtil {
+  def getPhoenixUrl: String = {
+    var queryServerAddress: String = null
+    if (isWindows) {
+      queryServerAddress = "http://hb-proxy-pub-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com:8765"
+    } else {
+      queryServerAddress = "http://hb-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com:8765"
+    }
+    val url = "jdbc:phoenix:thin:url=" + queryServerAddress + ";serialization=PROTOBUF"
+    url
+  }
+
+  def getPhoenixTempView(tempViewTableName: String, phoenixTableName: String, fetchsize: Int = 100): String =
+    s"""
+       |CREATE TABLE $tempViewTableName USING org.apache.spark.sql.jdbc
+       |OPTIONS (
+       |  'driver' 'org.apache.phoenix.queryserver.client.Driver',
+       |  'url' '${getPhoenixUrl}',
+       |  'dbtable' '$phoenixTableName',
+       |  'fetchsize' '$fetchsize'
+       |)
+       |""".stripMargin
+
+}
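
For reference, getPhoenixTempView emits a CREATE TABLE ... USING org.apache.spark.sql.jdbc statement pointing at the Phoenix Query Server thin driver, so the registered name can be queried or written like any other Spark SQL table. A minimal read-side sketch (the names "phoenix view demo", "phx_view" and "SOME_PHOENIX_TABLE" are placeholders, not objects from this repo):

    import com.winhc.bigdata.spark.utils.{PhoenixUtil, SparkUtils}
    import scala.collection.mutable

    val spark = SparkUtils.InitEnv("phoenix view demo", mutable.Map("spark.sql.catalogImplementation" -> "in-memory"))

    // getPhoenixUrl returns the public (local/Windows) or internal query-server endpoint.
    println(PhoenixUtil.getPhoenixUrl)

    // Register a Phoenix table under a Spark-side name and read it back.
    spark.sql(PhoenixUtil.getPhoenixTempView("phx_view", "SOME_PHOENIX_TABLE"))
    spark.sql("select * from phx_view limit 10").show()

    spark.stop()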

+ 7 - 14
src/main/scala/com/winhc/bigdata/spark/utils/SparkUtils.scala

@@ -1,7 +1,8 @@
 package com.winhc.bigdata.spark.utils
 
-import org.apache.hadoop.hbase.{HBaseConfiguration, HConstants}
+import com.winhc.bigdata.spark.utils.BaseUtil._
 import org.apache.hadoop.hbase.mapred.TableOutputFormat
+import org.apache.hadoop.hbase.{HBaseConfiguration, HConstants}
 import org.apache.hadoop.mapred.JobConf
 import org.apache.spark.sql.SparkSession
 
@@ -9,19 +10,10 @@ import scala.collection.mutable
 
 object SparkUtils {
 
-  def PhoenixUrl: String = {
-    var queryServerAddress: String = null
-    if (System.getProperty("os.name").contains("Windows")) {
-      queryServerAddress = "http://hb-proxy-pub-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com:8765"
-    } else {
-      queryServerAddress = "http://hb-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com:8765"
-    }
-    val url = "jdbc:phoenix:thin:url=" + queryServerAddress + ";serialization=PROTOBUF"
-    url
-  }
+
 
   def PhoenixOptions(tableName: String): Map[String, String] = {
-    if (System.getProperty("os.name").contains("Windows")) {
+    if (isWindows) {
       import com.alibaba.dcm.DnsCacheManipulator
       DnsCacheManipulator.setDnsCache("hb-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com", "47.101.251.157")
       Map("table" -> tableName, "zkUrl" -> "hb-proxy-pub-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com:2181")
@@ -33,7 +25,7 @@ object SparkUtils {
   def HBaseOutputJobConf(outputTable: String): JobConf = {
     val config = HBaseConfiguration.create()
     var zkAddress: String = null
-    if (System.getProperty("os.name").contains("Windows")) {
+    if (isWindows) {
       zkAddress = "hb-proxy-pub-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com"
       import com.alibaba.dcm.DnsCacheManipulator
       DnsCacheManipulator.setDnsCache("hb-uf6as8i6h85k02092-001.hbase.rds.aliyuncs.com", "47.101.251.157")
@@ -69,7 +61,7 @@ object SparkUtils {
       .config("spark.hadoop.odps.runtime.end.point", "http://service.cn.maxcompute.aliyun-inc.com/api")
       .config("spark.hadoop.odps.cupid.vectorization.enable", false)
 
-    if (System.getProperty("os.name").contains("Windows")) {
+    if (isWindows) {
       spark.master("local[*]")
     }
     if (config != null) {
@@ -79,4 +71,5 @@ object SparkUtils {
     }
     spark.getOrCreate()
   }
+
 }

+ 857 - 0
src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

@@ -0,0 +1,857 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+import scala.util.Try
+import scala.util.control.NonFatal
+
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.InputMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.NextIterator
+
+/**
+ * Util functions for JDBC tables.
+ */
+object JdbcUtils extends Logging {
+  /**
+   * Returns a factory for creating connections to the given JDBC URL.
+   *
+   * @param options - JDBC options that contains url, table and other information.
+   */
+  def createConnectionFactory(options: JDBCOptions): () => Connection = {
+    val driverClass: String = options.driverClass
+    () => {
+      DriverRegistry.register(driverClass)
+      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
+        case d if d.getClass.getCanonicalName == driverClass => d
+      }.getOrElse {
+        throw new IllegalStateException(
+          s"Did not find registered driver with class $driverClass")
+      }
+      driver.connect(options.url, options.asConnectionProperties)
+    }
+  }
+
+  /**
+   * Returns true if the table already exists in the JDBC database.
+   */
+  def tableExists(conn: Connection, options: JDBCOptions): Boolean = {
+    val dialect = JdbcDialects.get(options.url)
+
+    // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
+    // SQL database systems using JDBC meta data calls, considering "table" could also include
+    // the database name. Query used to find table exists can be overridden by the dialects.
+    Try {
+      val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
+      try {
+        statement.executeQuery()
+      } finally {
+        statement.close()
+      }
+    }.isSuccess
+  }
+
+  /**
+   * Drops a table from the JDBC database.
+   */
+  def dropTable(conn: Connection, table: String): Unit = {
+    val statement = conn.createStatement
+    try {
+      statement.executeUpdate(s"DROP TABLE $table")
+    } finally {
+      statement.close()
+    }
+  }
+
+  /**
+   * Truncates a table from the JDBC database without side effects.
+   */
+  def truncateTable(conn: Connection, options: JDBCOptions): Unit = {
+    val dialect = JdbcDialects.get(options.url)
+    val statement = conn.createStatement
+    try {
+      statement.executeUpdate(dialect.getTruncateQuery(options.table))
+    } finally {
+      statement.close()
+    }
+  }
+
+  def isCascadingTruncateTable(url: String): Option[Boolean] = {
+    JdbcDialects.get(url).isCascadingTruncateTable()
+  }
+
+  /**
+   * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
+   */
+  def getInsertStatement(
+                          table: String,
+                          rddSchema: StructType,
+                          tableSchema: Option[StructType],
+                          isCaseSensitive: Boolean,
+                          dialect: JdbcDialect): String = {
+    val columns = if (tableSchema.isEmpty) {
+      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+    } else {
+      val columnNameEquality = if (isCaseSensitive) {
+        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+      } else {
+        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+      }
+      // The generated insert statement needs to follow rddSchema's column sequence and
+      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
+      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
+      // RDD column names for user convenience.
+      val tableColumnNames = tableSchema.get.fieldNames
+      rddSchema.fields.map { col =>
+        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
+          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
+        }
+        dialect.quoteIdentifier(normalizedName)
+      }.mkString(",")
+    }
+    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
+    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
+  }
+
+  /**
+   * Retrieve standard jdbc types.
+   *
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
+
+  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
+      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+  }
+
+  /**
+   * Maps a JDBC type to a Catalyst type.  This function is called only when
+   * the JdbcDialect class corresponding to your database driver returns null.
+   *
+   * @param sqlType - A field of java.sql.Types
+   * @return The Catalyst type corresponding to sqlType.
+   */
+  private def getCatalystType(
+                               sqlType: Int,
+                               precision: Int,
+                               scale: Int,
+                               signed: Boolean): DataType = {
+    val answer = sqlType match {
+      // scalastyle:off
+      case java.sql.Types.ARRAY => null
+      case java.sql.Types.BIGINT => if (signed) {
+        LongType
+      } else {
+        DecimalType(20, 0)
+      }
+      case java.sql.Types.BINARY => BinaryType
+      case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
+      case java.sql.Types.BLOB => BinaryType
+      case java.sql.Types.BOOLEAN => BooleanType
+      case java.sql.Types.CHAR => StringType
+      case java.sql.Types.CLOB => StringType
+      case java.sql.Types.DATALINK => null
+      case java.sql.Types.DATE => DateType
+      case java.sql.Types.DECIMAL
+        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+      case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
+      case java.sql.Types.DISTINCT => null
+      case java.sql.Types.DOUBLE => DoubleType
+      case java.sql.Types.FLOAT => FloatType
+      case java.sql.Types.INTEGER => if (signed) {
+        IntegerType
+      } else {
+        LongType
+      }
+      case java.sql.Types.JAVA_OBJECT => null
+      case java.sql.Types.LONGNVARCHAR => StringType
+      case java.sql.Types.LONGVARBINARY => BinaryType
+      case java.sql.Types.LONGVARCHAR => StringType
+      case java.sql.Types.NCHAR => StringType
+      case java.sql.Types.NCLOB => StringType
+      case java.sql.Types.NULL => null
+      case java.sql.Types.NUMERIC
+        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+      case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
+      case java.sql.Types.NVARCHAR => StringType
+      case java.sql.Types.OTHER => null
+      case java.sql.Types.REAL => DoubleType
+      case java.sql.Types.REF => StringType
+      case java.sql.Types.REF_CURSOR => null
+      case java.sql.Types.ROWID => LongType
+      case java.sql.Types.SMALLINT => IntegerType
+      case java.sql.Types.SQLXML => StringType
+      case java.sql.Types.STRUCT => StringType
+      case java.sql.Types.TIME => TimestampType
+      case java.sql.Types.TIME_WITH_TIMEZONE
+      => null
+      case java.sql.Types.TIMESTAMP => TimestampType
+      case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
+      => null
+      case java.sql.Types.TINYINT => IntegerType
+      case java.sql.Types.VARBINARY => BinaryType
+      case java.sql.Types.VARCHAR => StringType
+      case _ =>
+        throw new SQLException("Unrecognized SQL type " + sqlType)
+      // scalastyle:on
+    }
+
+    if (answer == null) {
+      throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName)
+    }
+    answer
+  }
+
+  /**
+   * Returns the schema if the table already exists in the JDBC database.
+   */
+  def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = {
+    val dialect = JdbcDialects.get(options.url)
+
+    try {
+      val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
+      try {
+        Some(getSchema(statement.executeQuery(), dialect))
+      } catch {
+        case _: SQLException => None
+      } finally {
+        statement.close()
+      }
+    } catch {
+      case _: SQLException => None
+    }
+  }
+
+  /**
+   * Takes a [[ResultSet]] and returns its Catalyst schema.
+   *
+   * @param alwaysNullable If true, all the columns are nullable.
+   * @return A [[StructType]] giving the Catalyst schema.
+   * @throws SQLException if the schema contains an unsupported type.
+   */
+  def getSchema(
+                 resultSet: ResultSet,
+                 dialect: JdbcDialect,
+                 alwaysNullable: Boolean = false): StructType = {
+    val rsmd = resultSet.getMetaData
+    val ncols = rsmd.getColumnCount
+    val fields = new Array[StructField](ncols)
+    var i = 0
+    while (i < ncols) {
+      val columnName = rsmd.getColumnLabel(i + 1)
+      val dataType = rsmd.getColumnType(i + 1)
+      val typeName = rsmd.getColumnTypeName(i + 1)
+      val fieldSize = rsmd.getPrecision(i + 1)
+      val fieldScale = rsmd.getScale(i + 1)
+      val isSigned = {
+        try {
+          rsmd.isSigned(i + 1)
+        } catch {
+          // Workaround for HIVE-14684:
+          case e: SQLException if
+          e.getMessage == "Method not supported" &&
+            rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
+        }
+      }
+      val nullable = if (alwaysNullable) {
+        true
+      } else {
+        rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+      }
+      val metadata = new MetadataBuilder().putLong("scale", fieldScale)
+      val columnType =
+        dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+          getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+      fields(i) = StructField(columnName, columnType, nullable)
+      i = i + 1
+    }
+    new StructType(fields)
+  }
+
+  /**
+   * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
+   */
+  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+    val inputMetrics =
+      Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+    internalRows.map(encoder.fromRow)
+  }
+
+  private[spark] def resultSetToSparkInternalRows(
+                                                   resultSet: ResultSet,
+                                                   schema: StructType,
+                                                   inputMetrics: InputMetrics): Iterator[InternalRow] = {
+    new NextIterator[InternalRow] {
+      private[this] val rs = resultSet
+      private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+      private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
+
+      override protected def close(): Unit = {
+        try {
+          rs.close()
+        } catch {
+          case e: Exception => logWarning("Exception closing resultset", e)
+        }
+      }
+
+      override protected def getNext(): InternalRow = {
+        if (rs.next()) {
+          inputMetrics.incRecordsRead(1)
+          var i = 0
+          while (i < getters.length) {
+            getters(i).apply(rs, mutableRow, i)
+            if (rs.wasNull) mutableRow.setNullAt(i)
+            i = i + 1
+          }
+          mutableRow
+        } else {
+          finished = true
+          null.asInstanceOf[InternalRow]
+        }
+      }
+    }
+  }
+
+  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
+  // the row and also used for the value in `ResultSet`.
+  private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit
+
+  /**
+   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
+   * each value from `ResultSet` to each field of [[InternalRow]] correctly.
+   */
+  private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+
+  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+    case BooleanType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+    case DateType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+        val dateVal = rs.getDate(pos + 1)
+        if (dateVal != null) {
+          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+        } else {
+          row.update(pos, null)
+        }
+
+    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
+    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
+    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
+    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
+    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
+    // retrieve it, you will get wrong result 199.99.
+    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
+    case DecimalType.Fixed(p, s) =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val decimal =
+          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+        row.update(pos, decimal)
+
+    case DoubleType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setDouble(pos, rs.getDouble(pos + 1))
+
+    case FloatType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setFloat(pos, rs.getFloat(pos + 1))
+
+    case IntegerType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setInt(pos, rs.getInt(pos + 1))
+
+    case LongType if metadata.contains("binarylong") =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val bytes = rs.getBytes(pos + 1)
+        var ans = 0L
+        var j = 0
+        while (j < bytes.length) {
+          ans = 256 * ans + (255 & bytes(j))
+          j = j + 1
+        }
+        row.setLong(pos, ans)
+
+    case LongType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setLong(pos, rs.getLong(pos + 1))
+
+    case ShortType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setShort(pos, rs.getShort(pos + 1))
+
+    case StringType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+    case TimestampType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val t = rs.getTimestamp(pos + 1)
+        if (t != null) {
+          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+        } else {
+          row.update(pos, null)
+        }
+
+    case BinaryType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.update(pos, rs.getBytes(pos + 1))
+
+    case ArrayType(et, _) =>
+      val elementConversion = et match {
+        case TimestampType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+            }
+
+        case StringType =>
+          (array: Object) =>
+            // some underlying types are not String such as uuid, inet, cidr, etc.
+            array.asInstanceOf[Array[java.lang.Object]]
+              .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))
+
+        case DateType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Date]].map { date =>
+              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+            }
+
+        case dt: DecimalType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+              nullSafeConvert[java.math.BigDecimal](
+                decimal, d => Decimal(d, dt.precision, dt.scale))
+            }
+
+        case LongType if metadata.contains("binarylong") =>
+          throw new IllegalArgumentException(s"Unsupported array element " +
+            s"type ${dt.simpleString} based on binary")
+
+        case ArrayType(_, _) =>
+          throw new IllegalArgumentException("Nested arrays unsupported")
+
+        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+      }
+
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val array = nullSafeConvert[java.sql.Array](
+          input = rs.getArray(pos + 1),
+          array => new GenericArrayData(elementConversion.apply(array.getArray)))
+        row.update(pos, array)
+
+    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
+  }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
+
+  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
+  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
+  // in the SQL statement and also used for the value in `Row`.
+  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
+
+  private def makeSetter(
+                          conn: Connection,
+                          dialect: JdbcDialect,
+                          dataType: DataType): JDBCValueSetter = dataType match {
+    case IntegerType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getInt(pos))
+
+    case LongType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setLong(pos + 1, row.getLong(pos))
+
+    case DoubleType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDouble(pos + 1, row.getDouble(pos))
+
+    case FloatType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setFloat(pos + 1, row.getFloat(pos))
+
+    case ShortType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getShort(pos))
+
+    case ByteType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getByte(pos))
+
+    case BooleanType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBoolean(pos + 1, row.getBoolean(pos))
+
+    case StringType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setString(pos + 1, row.getString(pos))
+
+    case BinaryType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
+
+    case TimestampType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
+
+    case DateType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
+
+    case t: DecimalType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
+
+    case ArrayType(et, _) =>
+      // remove type length parameters from end of type name
+      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+        .toLowerCase(Locale.ROOT).split("\\(")(0)
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        val array = conn.createArrayOf(
+          typeName,
+          row.getSeq[AnyRef](pos).toArray)
+        stmt.setArray(pos + 1, array)
+
+    case _ =>
+      (_: PreparedStatement, _: Row, pos: Int) =>
+        throw new IllegalArgumentException(
+          s"Can't translate non-null value for field $pos")
+  }
+
+  /**
+   * Saves a partition of a DataFrame to the JDBC database.  This is done in
+   * a single database transaction (unless isolation level is "NONE")
+   * in order to avoid repeatedly inserting data as much as possible.
+   *
+   * It is still theoretically possible for rows in a DataFrame to be
+   * inserted into the database more than once if a stage somehow fails after
+   * the commit occurs but before the stage can return successfully.
+   *
+   * This is not a closure inside saveTable() because apparently cosmetic
+   * implementation changes elsewhere might easily render such a closure
+   * non-Serializable.  Instead, we explicitly close over all variables that
+   * are used.
+   */
+  def savePartition(
+                     getConnection: () => Connection,
+                     table: String,
+                     iterator: Iterator[Row],
+                     rddSchema: StructType,
+                     insertStmt: String,
+                     batchSize: Int,
+                     dialect: JdbcDialect,
+                     isolationLevel: Int): Iterator[Byte] = {
+    val conn = getConnection()
+    var committed = false
+
+    var finalIsolationLevel = Connection.TRANSACTION_NONE
+    if (isolationLevel != Connection.TRANSACTION_NONE) {
+      try {
+        val metadata = conn.getMetaData
+        if (metadata.supportsTransactions()) {
+          // Update to at least use the default isolation, if any transaction level
+          // has been chosen and transactions are supported
+          val defaultIsolation = metadata.getDefaultTransactionIsolation
+          finalIsolationLevel = defaultIsolation
+          if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
+            // Finally update to actually requested level if possible
+            finalIsolationLevel = isolationLevel
+          } else {
+            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
+              s"falling back to default isolation level $defaultIsolation")
+          }
+        } else {
+          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
+        }
+      } catch {
+        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
+      }
+    }
+    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
+
+    try {
+      if (supportsTransactions) {
+        conn.setAutoCommit(false) // Everything in the same db transaction.
+        conn.setTransactionIsolation(finalIsolationLevel)
+      }
+      val stmt = conn.prepareStatement(insertStmt.replace("INSERT", "UPSERT"))
+      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
+      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
+      val numFields = rddSchema.fields.length
+
+      try {
+        var rowCount = 0
+        while (iterator.hasNext) {
+          val row = iterator.next()
+          var i = 0
+          while (i < numFields) {
+            if (row.isNullAt(i)) {
+              stmt.setNull(i + 1, nullTypes(i))
+            } else {
+              setters(i).apply(stmt, row, i)
+            }
+            i = i + 1
+          }
+          stmt.addBatch()
+          rowCount += 1
+          if (rowCount % batchSize == 0) {
+            stmt.executeBatch()
+            rowCount = 0
+          }
+        }
+        if (rowCount > 0) {
+          stmt.executeBatch()
+        }
+      } finally {
+        stmt.close()
+      }
+      if (supportsTransactions) {
+        conn.commit()
+      }
+      committed = true
+      Iterator.empty
+    } catch {
+      case e: SQLException =>
+        val cause = e.getNextException
+        if (cause != null && e.getCause != cause) {
+          // If there is no cause already, set 'next exception' as cause. If cause is null,
+          // it *may* be because no cause was set yet
+          if (e.getCause == null) {
+            try {
+              e.initCause(cause)
+            } catch {
+              // Or it may be null because the cause *was* explicitly initialized, to *null*,
+              // in which case this fails. There is no other way to detect it.
+              // addSuppressed in this case as well.
+              case _: IllegalStateException => e.addSuppressed(cause)
+            }
+          } else {
+            e.addSuppressed(cause)
+          }
+        }
+        throw e
+    } finally {
+      if (!committed) {
+        // The stage must fail.  We got here through an exception path, so
+        // let the exception through unless rollback() or close() want to
+        // tell the user about another problem.
+        if (supportsTransactions) {
+          conn.rollback()
+        }
+        conn.close()
+      } else {
+        // The stage must succeed.  We cannot propagate any exception close() might throw.
+        try {
+          conn.close()
+        } catch {
+          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
+        }
+      }
+    }
+  }
+
+  /**
+   * Compute the schema string for this RDD.
+   */
+  def schemaString(
+                    df: DataFrame,
+                    url: String,
+                    createTableColumnTypes: Option[String] = None): String = {
+    val sb = new StringBuilder()
+    val dialect = JdbcDialects.get(url)
+    val userSpecifiedColTypesMap = createTableColumnTypes
+      .map(parseUserSpecifiedCreateTableColumnTypes(df, _))
+      .getOrElse(Map.empty[String, String])
+    df.schema.fields.foreach { field =>
+      val name = dialect.quoteIdentifier(field.name)
+      val typ = userSpecifiedColTypesMap
+        .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition)
+      val nullable = if (field.nullable) "" else "NOT NULL"
+      sb.append(s", $name $typ $nullable")
+    }
+    if (sb.length < 2) "" else sb.substring(2)
+  }
+
+  /**
+   * Parses the user specified createTableColumnTypes option value string specified in the same
+   * format as create table ddl column types, and returns Map of field name and the data type to
+   * use in-place of the default data type.
+   */
+  private def parseUserSpecifiedCreateTableColumnTypes(
+                                                        df: DataFrame,
+                                                        createTableColumnTypes: String): Map[String, String] = {
+    def typeName(f: StructField): String = {
+      // char/varchar gets translated to string type. Real data type specified by the user
+      // is available in the field metadata as HIVE_TYPE_STRING
+      if (f.metadata.contains(HIVE_TYPE_STRING)) {
+        f.metadata.getString(HIVE_TYPE_STRING)
+      } else {
+        f.dataType.catalogString
+      }
+    }
+
+    val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
+    val nameEquality = df.sparkSession.sessionState.conf.resolver
+
+    // checks duplicate columns in the user specified column types.
+    SchemaUtils.checkColumnNameDuplication(
+      userSchema.map(_.name), "in the createTableColumnTypes option value", nameEquality)
+
+    // checks if user specified column names exist in the DataFrame schema
+    userSchema.fieldNames.foreach { col =>
+      df.schema.find(f => nameEquality(f.name, col)).getOrElse {
+        throw new AnalysisException(
+          s"createTableColumnTypes option column $col not found in schema " +
+            df.schema.catalogString)
+      }
+    }
+
+    val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap
+    val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis
+    if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
+  }
+
+  /**
+   * Parses the user specified customSchema option value to DataFrame schema, and
+   * returns a schema that is replaced by the custom schema's dataType if column name is matched.
+   */
+  def getCustomSchema(
+                       tableSchema: StructType,
+                       customSchema: String,
+                       nameEquality: Resolver): StructType = {
+    if (null != customSchema && customSchema.nonEmpty) {
+      val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
+
+      SchemaUtils.checkColumnNameDuplication(
+        userSchema.map(_.name), "in the customSchema option value", nameEquality)
+
+      // This is resolved by names, use the custom field dataType to replace the default dataType.
+      val newSchema = tableSchema.map { col =>
+        userSchema.find(f => nameEquality(f.name, col.name)) match {
+          case Some(c) => col.copy(dataType = c.dataType)
+          case None => col
+        }
+      }
+      StructType(newSchema)
+    } else {
+      tableSchema
+    }
+  }
+
+  /**
+   * Saves the RDD to the database in a single transaction.
+   */
+  def saveTable(
+                 df: DataFrame,
+                 tableSchema: Option[StructType],
+                 isCaseSensitive: Boolean,
+                 options: JDBCOptions): Unit = {
+    val url = options.url
+    val table = options.table
+    val dialect = JdbcDialects.get(url)
+    val rddSchema = df.schema
+    val getConnection: () => Connection = createConnectionFactory(options)
+    val batchSize = options.batchSize
+    val isolationLevel = options.isolationLevel
+
+    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+    val repartitionedDF = options.numPartitions match {
+      case Some(n) if n <= 0 => throw new IllegalArgumentException(
+        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
+          "via JDBC. The minimum value is 1.")
+      case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
+      case _ => df
+    }
+    repartitionedDF.rdd.foreachPartition(iterator => savePartition(
+      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
+    )
+  }
+
+  /**
+   * Creates a table with a given schema.
+   */
+  def createTable(
+                   conn: Connection,
+                   df: DataFrame,
+                   options: JDBCOptions): Unit = {
+    val strSchema = schemaString(
+      df, options.url, options.createTableColumnTypes)
+    val table = options.table
+    val createTableOptions = options.createTableOptions
+    // Create the table if the table does not exist.
+    // To allow certain options to append when create a new table, which can be
+    // table_options or partition_options.
+    // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+    val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
+    val statement = conn.createStatement
+    try {
+      statement.executeUpdate(sql)
+    } finally {
+      statement.close()
+    }
+  }
+}
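
The file above appears to be the stock Spark 2.x JdbcUtils vendored into the project so that it shadows the built-in class on the classpath; the one functional change is in savePartition, where the prepared statement is built from insertStmt.replace("INSERT", "UPSERT"), so appends issue the UPSERT syntax Phoenix expects. A minimal sketch of a write that goes through this path, mirroring TestSpark2PhoenixJDBC above (the target table TEST_P and its k/s/time columns are taken from that test and assumed to exist):

    import java.util.Properties
    import com.winhc.bigdata.spark.utils.{PhoenixUtil, SparkUtils}

    val spark = SparkUtils.InitEnv("phoenix upsert demo")
    import spark.implicits._

    val props = new Properties()
    props.put("driver", "org.apache.phoenix.queryserver.client.Driver")
    props.put("fetchsize", "1000")

    // With the vendored JdbcUtils on the classpath, this append is sent as
    // "UPSERT INTO TEST_P (...) VALUES (...)" rather than an INSERT.
    Seq(("1", "2", "3")).toDF("k", "s", "time")
      .write.mode("append")
      .jdbc(PhoenixUtil.getPhoenixUrl, "TEST_P", props)

    spark.stop()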