@@ -0,0 +1,857 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+import scala.util.Try
+import scala.util.control.NonFatal
+
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.InputMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.NextIterator
+
+/**
+ * Util functions for JDBC tables.
+ */
+object JdbcUtils extends Logging {
+  /**
+   * Returns a factory for creating connections to the given JDBC URL.
+   *
+   * @param options - JDBC options that contain the url, table and other information.
+   */
+  def createConnectionFactory(options: JDBCOptions): () => Connection = {
+    val driverClass: String = options.driverClass
+    () => {
+      DriverRegistry.register(driverClass)
+      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
+        case d if d.getClass.getCanonicalName == driverClass => d
+      }.getOrElse {
+        throw new IllegalStateException(
+          s"Did not find registered driver with class $driverClass")
+      }
+      driver.connect(options.url, options.asConnectionProperties)
+    }
+  }
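+
+  // Usage sketch (illustrative only; `options` stands for a pre-built JDBCOptions instance):
+  // {{{
+  //   val connect: () => Connection = JdbcUtils.createConnectionFactory(options)
+  //   val conn = connect()  // registers the driver, then opens a fresh connection
+  //   try { /* issue statements */ } finally { conn.close() }
+  // }}}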
+
+  /**
+   * Returns true if the table already exists in the JDBC database.
+   */
+  def tableExists(conn: Connection, options: JDBCOptions): Boolean = {
+    val dialect = JdbcDialects.get(options.url)
+
+    // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
+    // SQL database systems using JDBC meta data calls, considering "table" could also include
+    // the database name. The query used to check table existence can be overridden by dialects.
+    Try {
+      val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
+      try {
+        statement.executeQuery()
+      } finally {
+        statement.close()
+      }
+    }.isSuccess
+  }
+
+  /**
+   * Drops a table from the JDBC database.
+   */
+  def dropTable(conn: Connection, table: String): Unit = {
+    val statement = conn.createStatement
+    try {
+      statement.executeUpdate(s"DROP TABLE $table")
+    } finally {
+      statement.close()
+    }
+  }
+
+  /**
+   * Truncates a table from the JDBC database without side effects.
+   */
+  def truncateTable(conn: Connection, options: JDBCOptions): Unit = {
+    val dialect = JdbcDialects.get(options.url)
+    val statement = conn.createStatement
+    try {
+      statement.executeUpdate(dialect.getTruncateQuery(options.table))
+    } finally {
+      statement.close()
+    }
+  }
+
+  def isCascadingTruncateTable(url: String): Option[Boolean] = {
+    JdbcDialects.get(url).isCascadingTruncateTable()
+  }
+
+  /**
+   * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
+   */
+  def getInsertStatement(
+      table: String,
+      rddSchema: StructType,
+      tableSchema: Option[StructType],
+      isCaseSensitive: Boolean,
+      dialect: JdbcDialect): String = {
+    val columns = if (tableSchema.isEmpty) {
+      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+    } else {
+      val columnNameEquality = if (isCaseSensitive) {
+        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+      } else {
+        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+      }
+      // The generated insert statement needs to follow rddSchema's column sequence and
+      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
+      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead
+      // of RDD column names for user convenience.
+      val tableColumnNames = tableSchema.get.fieldNames
+      rddSchema.fields.map { col =>
+        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
+          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
+        }
+        dialect.quoteIdentifier(normalizedName)
+      }.mkString(",")
+    }
+    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
+    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
+  }
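+
+  // For example (illustrative, hypothetical schema): a two-column rddSchema over table "t"
+  // yields a statement of the shape
+  //   INSERT INTO t ("id","name") VALUES (?,?)
+  // where the identifier quoting comes from the dialect and each "?" is bound per row later.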
+
+  /**
+   * Retrieves the standard JDBC type for a Catalyst datatype.
+   *
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
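+
+  // Sketch of the mapping in action (values shown are the defaults above, before any
+  // dialect override):
+  // {{{
+  //   getCommonJDBCType(LongType)            // Some(JdbcType("BIGINT", java.sql.Types.BIGINT))
+  //   getCommonJDBCType(DecimalType(12, 2))  // Some(JdbcType("DECIMAL(12,2)", ...))
+  //   getCommonJDBCType(CalendarIntervalType) // None -- no portable JDBC equivalent
+  // }}}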
+
+  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
+      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+  }
+
+  /**
+   * Maps a JDBC type to a Catalyst type. This function is called only when
+   * the JdbcDialect class corresponding to your database driver returns null.
+   *
+   * @param sqlType - A field of java.sql.Types
+   * @return The Catalyst type corresponding to sqlType.
+   */
+  private def getCatalystType(
+      sqlType: Int,
+      precision: Int,
+      scale: Int,
+      signed: Boolean): DataType = {
+    val answer = sqlType match {
+      // scalastyle:off
+      case java.sql.Types.ARRAY => null
+      case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20, 0) }
+      case java.sql.Types.BINARY => BinaryType
+      case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
+      case java.sql.Types.BLOB => BinaryType
+      case java.sql.Types.BOOLEAN => BooleanType
+      case java.sql.Types.CHAR => StringType
+      case java.sql.Types.CLOB => StringType
+      case java.sql.Types.DATALINK => null
+      case java.sql.Types.DATE => DateType
+      case java.sql.Types.DECIMAL
+        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+      case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
+      case java.sql.Types.DISTINCT => null
+      case java.sql.Types.DOUBLE => DoubleType
+      case java.sql.Types.FLOAT => FloatType
+      case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
+      case java.sql.Types.JAVA_OBJECT => null
+      case java.sql.Types.LONGNVARCHAR => StringType
+      case java.sql.Types.LONGVARBINARY => BinaryType
+      case java.sql.Types.LONGVARCHAR => StringType
+      case java.sql.Types.NCHAR => StringType
+      case java.sql.Types.NCLOB => StringType
+      case java.sql.Types.NULL => null
+      case java.sql.Types.NUMERIC
+        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+      case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
+      case java.sql.Types.NVARCHAR => StringType
+      case java.sql.Types.OTHER => null
+      case java.sql.Types.REAL => DoubleType
+      case java.sql.Types.REF => StringType
+      case java.sql.Types.REF_CURSOR => null
+      case java.sql.Types.ROWID => LongType
+      case java.sql.Types.SMALLINT => IntegerType
+      case java.sql.Types.SQLXML => StringType
+      case java.sql.Types.STRUCT => StringType
+      case java.sql.Types.TIME => TimestampType
+      case java.sql.Types.TIME_WITH_TIMEZONE => null
+      case java.sql.Types.TIMESTAMP => TimestampType
+      case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => null
+      case java.sql.Types.TINYINT => IntegerType
+      case java.sql.Types.VARBINARY => BinaryType
+      case java.sql.Types.VARCHAR => StringType
+      case _ =>
+        throw new SQLException("Unrecognized SQL type " + sqlType)
+      // scalastyle:on
+    }
+
+    if (answer == null) {
+      throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName)
+    }
+    answer
+  }
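+
+  // A couple of concrete default mappings (illustrative):
+  // {{{
+  //   getCatalystType(java.sql.Types.INTEGER, 10, 0, signed = false) // LongType, to fit UINT32
+  //   getCatalystType(java.sql.Types.NUMERIC, 0, 0, signed = true)   // DecimalType.SYSTEM_DEFAULT
+  // }}}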
+
+  /**
+   * Returns the schema if the table already exists in the JDBC database.
+   */
+  def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = {
+    val dialect = JdbcDialects.get(options.url)
+
+    try {
+      val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
+      try {
+        Some(getSchema(statement.executeQuery(), dialect))
+      } catch {
+        case _: SQLException => None
+      } finally {
+        statement.close()
+      }
+    } catch {
+      case _: SQLException => None
+    }
+  }
+
+  /**
+   * Takes a [[ResultSet]] and returns its Catalyst schema.
+   *
+   * @param alwaysNullable If true, all the columns are nullable.
+   * @return A [[StructType]] giving the Catalyst schema.
+   * @throws SQLException if the schema contains an unsupported type.
+   */
+  def getSchema(
+      resultSet: ResultSet,
+      dialect: JdbcDialect,
+      alwaysNullable: Boolean = false): StructType = {
+    val rsmd = resultSet.getMetaData
+    val ncols = rsmd.getColumnCount
+    val fields = new Array[StructField](ncols)
+    var i = 0
+    while (i < ncols) {
+      val columnName = rsmd.getColumnLabel(i + 1)
+      val dataType = rsmd.getColumnType(i + 1)
+      val typeName = rsmd.getColumnTypeName(i + 1)
+      val fieldSize = rsmd.getPrecision(i + 1)
+      val fieldScale = rsmd.getScale(i + 1)
+      val isSigned = {
+        try {
+          rsmd.isSigned(i + 1)
+        } catch {
+          // Workaround for HIVE-14684:
+          case e: SQLException if
+            e.getMessage == "Method not supported" &&
+              rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
+        }
+      }
+      val nullable = if (alwaysNullable) {
+        true
+      } else {
+        rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+      }
+      val metadata = new MetadataBuilder().putLong("scale", fieldScale)
+      val columnType =
+        dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+          getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+      fields(i) = StructField(columnName, columnType, nullable)
+      i = i + 1
+    }
+    new StructType(fields)
+  }
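+
+  // Sketch of the result (hypothetical table): a ResultSet over
+  //   (id INTEGER NOT NULL, name VARCHAR(32))
+  // becomes, via the default mappings above,
+  //   StructType(StructField("id", IntegerType, nullable = false)
+  //     :: StructField("name", StringType, nullable = true) :: Nil)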
+
+  /**
+   * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
+   */
+  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+    val inputMetrics =
+      Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+    internalRows.map(encoder.fromRow)
+  }
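+
+  // Usage sketch (illustrative; `rs` is a ResultSet positioned before its first row, and
+  // `schema` matches it, e.g. from getSchema above):
+  // {{{
+  //   val rows: Iterator[Row] = JdbcUtils.resultSetToRows(rs, schema)
+  //   rows.foreach(r => println(r))  // each Row is decoded lazily as the iterator advances
+  // }}}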
+
+  private[spark] def resultSetToSparkInternalRows(
+      resultSet: ResultSet,
+      schema: StructType,
+      inputMetrics: InputMetrics): Iterator[InternalRow] = {
+    new NextIterator[InternalRow] {
+      private[this] val rs = resultSet
+      private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+      private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
+
+      override protected def close(): Unit = {
+        try {
+          rs.close()
+        } catch {
+          case e: Exception => logWarning("Exception closing resultset", e)
+        }
+      }
+
+      override protected def getNext(): InternalRow = {
+        if (rs.next()) {
+          inputMetrics.incRecordsRead(1)
+          var i = 0
+          while (i < getters.length) {
+            getters(i).apply(rs, mutableRow, i)
+            if (rs.wasNull) mutableRow.setNullAt(i)
+            i = i + 1
+          }
+          mutableRow
+        } else {
+          finished = true
+          null.asInstanceOf[InternalRow]
+        }
+      }
+    }
+  }
+
+  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+  // for `MutableRow`. The last argument `Int` is the index of the value to be set in
+  // the row, and is also used to locate the value in the `ResultSet` (shifted by one,
+  // since JDBC indices are 1-based).
+  private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit
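+
+  // For instance, the IntegerType getter below is essentially (illustrative restatement):
+  // {{{
+  //   val intGetter: JDBCValueGetter =
+  //     (rs, row, pos) => row.setInt(pos, rs.getInt(pos + 1))
+  // }}}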
+
+  /**
+   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
+   * each value from `ResultSet` to each field of [[InternalRow]] correctly.
+   */
+  private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+
+  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+    case BooleanType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+    case DateType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        // DateTimeUtils.fromJavaDate does not handle null values, so we need to check first.
+        val dateVal = rs.getDate(pos + 1)
+        if (dateVal != null) {
+          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+        } else {
+          row.update(pos, null)
+        }
+
+    // When connecting to Oracle through JDBC, the precision and scale of the BigDecimal
+    // returned by ResultSet.getBigDecimal do not necessarily match the table schema reported
+    // by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. If a value like 19999
+    // is inserted into a column of NUMBER(12, 2) type, ResultSet.getBigDecimal returns a
+    // BigDecimal with scale 0, while the DataFrame schema has the correct type
+    // DecimalType(12, 2). Saving that DataFrame to a Parquet file and reading it back would
+    // then yield the wrong result 199.99. So the precision and scale of the Decimal must be
+    // set from the JDBC metadata.
+    case DecimalType.Fixed(p, s) =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val decimal =
+          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+        row.update(pos, decimal)
+
+    case DoubleType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setDouble(pos, rs.getDouble(pos + 1))
+
+    case FloatType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setFloat(pos, rs.getFloat(pos + 1))
+
+    case IntegerType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setInt(pos, rs.getInt(pos + 1))
+
+    case LongType if metadata.contains("binarylong") =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val bytes = rs.getBytes(pos + 1)
+        var ans = 0L
+        var j = 0
+        while (j < bytes.length) {
+          ans = 256 * ans + (255 & bytes(j))
+          j = j + 1
+        }
+        row.setLong(pos, ans)
+
+    case LongType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setLong(pos, rs.getLong(pos + 1))
+
+    case ShortType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.setShort(pos, rs.getShort(pos + 1))
+
+    case StringType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+    case TimestampType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val t = rs.getTimestamp(pos + 1)
+        if (t != null) {
+          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+        } else {
+          row.update(pos, null)
+        }
+
+    case BinaryType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.update(pos, rs.getBytes(pos + 1))
+
+    case ArrayType(et, _) =>
+      val elementConversion = et match {
+        case TimestampType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+            }
+
+        case StringType =>
+          (array: Object) =>
+            // some underlying types are not String, such as uuid, inet, cidr, etc.
+            array.asInstanceOf[Array[java.lang.Object]]
+              .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))
+
+        case DateType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Date]].map { date =>
+              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+            }
+
+        case dt: DecimalType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+              nullSafeConvert[java.math.BigDecimal](
+                decimal, d => Decimal(d, dt.precision, dt.scale))
+            }
+
+        case LongType if metadata.contains("binarylong") =>
+          throw new IllegalArgumentException(s"Unsupported array element " +
+            s"type ${dt.simpleString} based on binary")
+
+        case ArrayType(_, _) =>
+          throw new IllegalArgumentException("Nested arrays unsupported")
+
+        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+      }
+
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val array = nullSafeConvert[java.sql.Array](
+          input = rs.getArray(pos + 1),
+          array => new GenericArrayData(elementConversion.apply(array.getArray)))
+        row.update(pos, array)
+
+    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
+  }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
+
+  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
+  // `PreparedStatement`. The last argument `Int` is the index of the value to be set
+  // in the SQL statement, and is also used to locate the value in the `Row` (JDBC
+  // parameter indices are 1-based, so the statement side is shifted by one).
+  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
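+
+  // Mirror image of the getter example above (illustrative restatement):
+  // {{{
+  //   val intSetter: JDBCValueSetter =
+  //     (stmt, row, pos) => stmt.setInt(pos + 1, row.getInt(pos))
+  // }}}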
+
+  private def makeSetter(
+      conn: Connection,
+      dialect: JdbcDialect,
+      dataType: DataType): JDBCValueSetter = dataType match {
+    case IntegerType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getInt(pos))
+
+    case LongType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setLong(pos + 1, row.getLong(pos))
+
+    case DoubleType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDouble(pos + 1, row.getDouble(pos))
+
+    case FloatType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setFloat(pos + 1, row.getFloat(pos))
+
+    case ShortType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getShort(pos))
+
+    case ByteType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getByte(pos))
+
+    case BooleanType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBoolean(pos + 1, row.getBoolean(pos))
+
+    case StringType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setString(pos + 1, row.getString(pos))
+
+    case BinaryType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
+
+    case TimestampType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
+
+    case DateType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
+
+    case t: DecimalType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
+
+    case ArrayType(et, _) =>
+      // remove type length parameters from end of type name
+      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+        .toLowerCase(Locale.ROOT).split("\\(")(0)
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        val array = conn.createArrayOf(
+          typeName,
+          row.getSeq[AnyRef](pos).toArray)
+        stmt.setArray(pos + 1, array)
+
+    case _ =>
+      (_: PreparedStatement, _: Row, pos: Int) =>
+        throw new IllegalArgumentException(
+          s"Can't translate non-null value for field $pos")
+  }
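+
+  // Note on the ArrayType branch: the element type name passed to Connection.createArrayOf
+  // is the dialect's type definition lower-cased with any length suffix stripped, so e.g.
+  // (illustrative) "VARCHAR(100)" becomes "varchar".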
+
+  /**
+   * Saves a partition of a DataFrame to the JDBC database. This is done in
+   * a single database transaction (unless the isolation level is "NONE")
+   * in order to avoid, as much as possible, inserting the same data repeatedly.
+   *
+   * It is still theoretically possible for rows in a DataFrame to be
+   * inserted into the database more than once if a stage somehow fails after
+   * the commit occurs but before the stage can return successfully.
+   *
+   * This is not a closure inside saveTable() because apparently cosmetic
+   * implementation changes elsewhere might easily render such a closure
+   * non-Serializable. Instead, we explicitly close over all variables that
+   * are used.
+   */
+  def savePartition(
+      getConnection: () => Connection,
+      table: String,
+      iterator: Iterator[Row],
+      rddSchema: StructType,
+      insertStmt: String,
+      batchSize: Int,
+      dialect: JdbcDialect,
+      isolationLevel: Int): Iterator[Byte] = {
+    val conn = getConnection()
+    var committed = false
+
+    var finalIsolationLevel = Connection.TRANSACTION_NONE
+    if (isolationLevel != Connection.TRANSACTION_NONE) {
+      try {
+        val metadata = conn.getMetaData
+        if (metadata.supportsTransactions()) {
+          // Update to at least use the default isolation, if any transaction level
+          // has been chosen and transactions are supported
+          val defaultIsolation = metadata.getDefaultTransactionIsolation
+          finalIsolationLevel = defaultIsolation
+          if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
+            // Finally update to actually requested level if possible
+            finalIsolationLevel = isolationLevel
+          } else {
+            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
+              s"falling back to default isolation level $defaultIsolation")
+          }
+        } else {
+          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
+        }
+      } catch {
+        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
+      }
+    }
+    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
+
+    try {
+      if (supportsTransactions) {
+        conn.setAutoCommit(false) // Everything in the same db transaction.
+        conn.setTransactionIsolation(finalIsolationLevel)
+      }
+      val stmt = conn.prepareStatement(insertStmt)
+      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
+      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
+      val numFields = rddSchema.fields.length
+
+      try {
+        var rowCount = 0
+        while (iterator.hasNext) {
+          val row = iterator.next()
+          var i = 0
+          while (i < numFields) {
+            if (row.isNullAt(i)) {
+              stmt.setNull(i + 1, nullTypes(i))
+            } else {
+              setters(i).apply(stmt, row, i)
+            }
+            i = i + 1
+          }
+          stmt.addBatch()
+          rowCount += 1
+          if (rowCount % batchSize == 0) {
+            stmt.executeBatch()
+            rowCount = 0
+          }
+        }
+        if (rowCount > 0) {
+          stmt.executeBatch()
+        }
+      } finally {
+        stmt.close()
+      }
+      if (supportsTransactions) {
+        conn.commit()
+      }
+      committed = true
+      Iterator.empty
+    } catch {
+      case e: SQLException =>
+        val cause = e.getNextException
+        if (cause != null && e.getCause != cause) {
+          // If there is no cause already, set 'next exception' as cause. If cause is null,
+          // it *may* be because no cause was set yet
+          if (e.getCause == null) {
+            try {
+              e.initCause(cause)
+            } catch {
+              // Or it may be null because the cause *was* explicitly initialized, to *null*,
+              // in which case this fails. There is no other way to detect it.
+              // addSuppressed in this case as well.
+              case _: IllegalStateException => e.addSuppressed(cause)
+            }
+          } else {
+            e.addSuppressed(cause)
+          }
+        }
+        throw e
+    } finally {
+      if (!committed) {
+        // The stage must fail. We got here through an exception path, so
+        // let the exception through unless rollback() or close() want to
+        // tell the user about another problem.
+        if (supportsTransactions) {
+          conn.rollback()
+        }
+        conn.close()
+      } else {
+        // The stage must succeed. We cannot propagate any exception close() might throw.
+        try {
+          conn.close()
+        } catch {
+          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
+        }
+      }
+    }
+  }
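+
+  // Batching walkthrough (illustrative numbers): with batchSize = 1000 and a 2500-row
+  // partition, executeBatch() fires after rows 1000 and 2000 (rowCount resets after each
+  // flush), and the final flush after the loop sends the remaining 500, so every row is
+  // written exactly once per attempt.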
+
+  /**
+   * Compute the schema string for this RDD.
+   */
+  def schemaString(
+      df: DataFrame,
+      url: String,
+      createTableColumnTypes: Option[String] = None): String = {
+    val sb = new StringBuilder()
+    val dialect = JdbcDialects.get(url)
+    val userSpecifiedColTypesMap = createTableColumnTypes
+      .map(parseUserSpecifiedCreateTableColumnTypes(df, _))
+      .getOrElse(Map.empty[String, String])
+    df.schema.fields.foreach { field =>
+      val name = dialect.quoteIdentifier(field.name)
+      val typ = userSpecifiedColTypesMap
+        .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition)
+      val nullable = if (field.nullable) "" else "NOT NULL"
+      sb.append(s", $name $typ $nullable")
+    }
+    if (sb.length < 2) "" else sb.substring(2)
+  }
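+
+  // Shape of the output (hypothetical two-column DataFrame, default type mappings), roughly:
+  //   "id" INTEGER NOT NULL, "name" TEXT
+  // The leading ", " built up in the loop is trimmed by the final substring(2).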
+
+  /**
+   * Parses the user specified createTableColumnTypes option value string, specified in the
+   * same format as create table ddl column types, and returns a Map from field name to the
+   * data type to use in place of the default data type.
+   */
+  private def parseUserSpecifiedCreateTableColumnTypes(
+      df: DataFrame,
+      createTableColumnTypes: String): Map[String, String] = {
+    def typeName(f: StructField): String = {
+      // char/varchar gets translated to string type. The real data type specified by the user
+      // is available in the field metadata as HIVE_TYPE_STRING
+      if (f.metadata.contains(HIVE_TYPE_STRING)) {
+        f.metadata.getString(HIVE_TYPE_STRING)
+      } else {
+        f.dataType.catalogString
+      }
+    }
+
+    val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
+    val nameEquality = df.sparkSession.sessionState.conf.resolver
+
+    // checks duplicate columns in the user specified column types.
+    SchemaUtils.checkColumnNameDuplication(
+      userSchema.map(_.name), "in the createTableColumnTypes option value", nameEquality)
+
+    // checks if user specified column names exist in the DataFrame schema
+    userSchema.fieldNames.foreach { col =>
+      df.schema.find(f => nameEquality(f.name, col)).getOrElse {
+        throw new AnalysisException(
+          s"createTableColumnTypes option column $col not found in schema " +
+            df.schema.catalogString)
+      }
+    }
+
+    val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap
+    val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis
+    if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
+  }
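+
+  // Illustrative option value and its parse result (hypothetical columns):
+  //   createTableColumnTypes = "name VARCHAR(128), comments STRING"
+  // would yield roughly Map("name" -> "varchar(128)", "comments" -> "string"), wrapped
+  // case-insensitively unless spark.sql.caseSensitive is enabled.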
+
+  /**
+   * Parses the user specified customSchema option value to a DataFrame schema, and
+   * returns a schema in which the data types of matching columns are replaced by those
+   * from the custom schema.
+   */
+  def getCustomSchema(
+      tableSchema: StructType,
+      customSchema: String,
+      nameEquality: Resolver): StructType = {
+    if (null != customSchema && customSchema.nonEmpty) {
+      val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
+
+      SchemaUtils.checkColumnNameDuplication(
+        userSchema.map(_.name), "in the customSchema option value", nameEquality)
+
+      // This is resolved by name, so use the custom field's dataType to replace the default
+      // dataType.
+      val newSchema = tableSchema.map { col =>
+        userSchema.find(f => nameEquality(f.name, col.name)) match {
+          case Some(c) => col.copy(dataType = c.dataType)
+          case None => col
+        }
+      }
+      StructType(newSchema)
+    } else {
+      tableSchema
+    }
+  }
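+
+  // Illustrative call: given a tableSchema of (id DECIMAL(38,0), name STRING) and
+  //   customSchema = "id LONG"
+  // the returned schema has id as LongType while name keeps StringType; only the matched
+  // column's dataType changes, and names and nullability come from the table schema.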
+
+  /**
+   * Saves the RDD to the database, using a single transaction per partition.
+   */
+  def saveTable(
+      df: DataFrame,
+      tableSchema: Option[StructType],
+      isCaseSensitive: Boolean,
+      options: JDBCOptions): Unit = {
+    val url = options.url
+    val table = options.table
+    val dialect = JdbcDialects.get(url)
+    val rddSchema = df.schema
+    val getConnection: () => Connection = createConnectionFactory(options)
+    val batchSize = options.batchSize
+    val isolationLevel = options.isolationLevel
+
+    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+    val repartitionedDF = options.numPartitions match {
+      case Some(n) if n <= 0 => throw new IllegalArgumentException(
+        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
+        "via JDBC. The minimum value is 1.")
+      case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
+      case _ => df
+    }
+    repartitionedDF.rdd.foreachPartition(iterator => savePartition(
+      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
+    )
+  }
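+
+  // Note the partitioning rule above: numPartitions only ever coalesces (shrinks) the
+  // DataFrame, since each partition maps to one JDBC connection. E.g. a 10-partition
+  // DataFrame written with numPartitions = 4 opens 4 connections, while numPartitions = 20
+  // leaves the 10 partitions as they are.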
+
+  /**
+   * Creates a table with a given schema.
+   */
+  def createTable(
+      conn: Connection,
+      df: DataFrame,
+      options: JDBCOptions): Unit = {
+    val strSchema = schemaString(
+      df, options.url, options.createTableColumnTypes)
+    val table = options.table
+    val createTableOptions = options.createTableOptions
+    // Create the table if the table does not exist.
+    // Certain options, such as table_options or partition_options, can be appended to the
+    // statement when creating a new table.
+    // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+    val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
+    val statement = conn.createStatement
+    try {
+      statement.executeUpdate(sql)
+    } finally {
+      statement.close()
+    }
+  }
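+
+  // Putting it together (illustrative sketch; `conn`, `df` and `options` are assumed to exist):
+  // {{{
+  //   if (!JdbcUtils.tableExists(conn, options)) {
+  //     JdbcUtils.createTable(conn, df, options)
+  //   }
+  //   JdbcUtils.saveTable(df, JdbcUtils.getSchemaOption(conn, options),
+  //     isCaseSensitive = false, options)
+  // }}}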
+}