Skip to content

Commit 2040335

Browse files
committed
[SPARK-54621][SQL] Merge Into Update Set * preserve nested fields if coerceNestedTypes is enabled
1 parent d9d7f1a commit 2040335

File tree

2 files changed

+189
-16
lines changed

2 files changed

+189
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala

Lines changed: 185 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import scala.collection.mutable
2121

2222
import org.apache.spark.sql.catalyst.SQLConfHelper
2323
import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE}
24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
24+
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
25+
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
2526
import org.apache.spark.sql.catalyst.plans.logical.Assignment
2627
import org.apache.spark.sql.catalyst.types.DataTypeUtils
2728
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -55,6 +56,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
5556
* (preserving existing fields).
5657
* @param coerceNestedTypes whether to coerce nested types to match the target type
5758
* for complex types
59+
* @param missingSourcePaths paths that exist in target but not in source
5860
* @return aligned update assignments that match table attributes
5961
*/
6062
def alignUpdateAssignments(
@@ -72,7 +74,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
7274
assignments,
7375
addError = err => errors += err,
7476
colPath = Seq(attr.name),
75-
coerceNestedTypes)
77+
coerceNestedTypes,
78+
fromStar)
7679
}
7780

7881
if (errors.nonEmpty) {
@@ -156,7 +159,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
156159
assignments: Seq[Assignment],
157160
addError: String => Unit,
158161
colPath: Seq[String],
159-
coerceNestedTypes: Boolean = false): Expression = {
162+
coerceNestedTypes: Boolean = false,
163+
updateStar: Boolean = false): Expression = {
160164

161165
val (exactAssignments, otherAssignments) = assignments.partition { assignment =>
162166
assignment.key.semanticEquals(colExpr)
@@ -178,11 +182,31 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
178182
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
179183
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
180184
} else if (exactAssignments.nonEmpty) {
181-
val value = exactAssignments.head.value
182-
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
183-
val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
184-
colPath, coerceMode)
185-
resolvedValue
185+
if (updateStar) {
186+
val value = exactAssignments.head.value
187+
col.dataType match {
188+
case structType: StructType =>
189+
// Expand assignments to leaf fields
190+
val structAssignment =
191+
applyNestedFieldAssignments(col, colExpr, value, addError, colPath,
192+
coerceNestedTypes)
193+
194+
// Wrap with null check for missing source fields
195+
fixNullExpansion(col, value, structType, structAssignment,
196+
colPath, addError)
197+
case _ =>
198+
// For non-struct types, resolve directly
199+
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
200+
TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath,
201+
coerceMode)
202+
}
203+
} else {
204+
val value = exactAssignments.head.value
205+
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
206+
val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
207+
colPath, coerceMode)
208+
resolvedValue
209+
}
186210
} else {
187211
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes)
188212
}
@@ -194,7 +218,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
194218
assignments: Seq[Assignment],
195219
addError: String => Unit,
196220
colPath: Seq[String],
197-
coerceNestedTypes: Boolean): Expression = {
221+
coerceNestedTyptes: Boolean): Expression = {
198222

199223
col.dataType match {
200224
case structType: StructType =>
@@ -204,14 +228,71 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
204228
}
205229
val updatedFieldExprs = fieldAttrs.zip(fieldExprs).map { case (fieldAttr, fieldExpr) =>
206230
applyAssignments(fieldAttr, fieldExpr, assignments, addError, colPath :+ fieldAttr.name,
207-
coerceNestedTypes)
231+
coerceNestedTyptes)
232+
}
233+
toNamedStruct(structType, updatedFieldExprs)
234+
235+
case otherType =>
236+
addError(
237+
"Updating nested fields is only supported for StructType but " +
238+
s"'${colPath.quoted}' is of type $otherType")
239+
colExpr
240+
}
241+
}
242+
243+
private def applyNestedFieldAssignments(
244+
col: Attribute,
245+
colExpr: Expression,
246+
value: Expression,
247+
addError: String => Unit,
248+
colPath: Seq[String],
249+
coerceNestedTyptes: Boolean): Expression = {
250+
251+
col.dataType match {
252+
case structType: StructType =>
253+
val fieldAttrs = DataTypeUtils.toAttributes(structType)
254+
255+
val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) =>
256+
val fieldPath = colPath :+ fieldAttr.name
257+
val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name))
258+
259+
// Try to find a corresponding field in the source value by name
260+
val sourceFieldValue: Expression = value.dataType match {
261+
case valueStructType: StructType =>
262+
valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match {
263+
case Some(matchingField) =>
264+
// Found matching field in source, extract it
265+
val fieldIndex = valueStructType.fieldIndex(matchingField.name)
266+
GetStructField(value, fieldIndex, Some(matchingField.name))
267+
case None =>
268+
// Field doesn't exist in source, use target's current value with null check
269+
TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath)
270+
}
271+
case _ =>
272+
// Value is not a struct, cannot extract field
273+
addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'")
274+
Literal(null, fieldAttr.dataType)
275+
}
276+
277+
// Recurse or resolve based on field type
278+
fieldAttr.dataType match {
279+
case nestedStructType: StructType =>
280+
// Field is a struct, recurse
281+
applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue,
282+
addError, fieldPath, coerceNestedTyptes)
283+
case _ =>
284+
// Field is not a struct, resolve with TableOutputResolver
285+
val coerceMode = if (coerceNestedTyptes) RECURSE else NONE
286+
TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError,
287+
fieldPath, coerceMode)
288+
}
208289
}
209290
toNamedStruct(structType, updatedFieldExprs)
210291

211292
case otherType =>
212293
addError(
213294
"Updating nested fields is only supported for StructType but " +
214-
s"'${colPath.quoted}' is of type $otherType")
295+
s"'${colPath.quoted}' is of type $otherType")
215296
colExpr
216297
}
217298
}
@@ -223,6 +304,99 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
223304
CreateNamedStruct(namedStructExprs)
224305
}
225306

307+
private def getMissingSourcePaths(targetType: StructType,
308+
sourceType: DataType,
309+
colPath: Seq[String],
310+
addError: String => Unit): Seq[Seq[String]] = {
311+
val nestedTargetPaths = DataTypeUtils.extractLeafFieldPaths(targetType, Seq.empty)
312+
val nestedSourcePaths = sourceType match {
313+
case sourceStructType: StructType =>
314+
DataTypeUtils.extractLeafFieldPaths(sourceStructType, Seq.empty)
315+
case _ =>
316+
addError(s"Value for struct type: " +
317+
s"${colPath.quoted} must be a struct but was ${sourceType.simpleString}")
318+
Seq()
319+
}
320+
nestedSourcePaths.diff(nestedTargetPaths)
321+
}
322+
323+
/**
324+
* Creates a null check for a field at the given path within a struct expression.
325+
* Navigates through the struct hierarchy following the path and returns an IsNull check
326+
* for the final field.
327+
*
328+
* @param rootExpr the root expression to navigate from
329+
* @param path the field path to navigate (sequence of field names)
330+
* @return an IsNull expression checking if the field at the path is null
331+
*/
332+
private def createNullCheckForFieldPath(
333+
rootExpr: Expression,
334+
path: Seq[String]): Expression = {
335+
var currentExpr: Expression = rootExpr
336+
path.foreach { fieldName =>
337+
currentExpr.dataType match {
338+
case st: StructType =>
339+
st.fields.find(f => conf.resolver(f.name, fieldName)) match {
340+
case Some(field) =>
341+
val fieldIndex = st.fieldIndex(field.name)
342+
currentExpr = GetStructField(currentExpr, fieldIndex, Some(field.name))
343+
case None => // No-op, should error later in TableOutputResolver
344+
}
345+
case _ => // Not a struct- no-op, should error later in TableOutputResolver
346+
}
347+
}
348+
IsNull(currentExpr)
349+
}
350+
351+
/**
352+
* As UPDATE SET * can assign struct fields individually (preserving existing fields),
353+
* this will lead to null expansion, ie, a struct is created where all fields are null.
354+
* Wraps a struct assignment with null checks for the source and missing source fields.
355+
* Return null if all are null.
356+
*
357+
* @param col the target column attribute
358+
* @param value the source value expression
359+
* @param structType the target struct type
360+
* @param structAssignment the struct assignment result to wrap
361+
* @param colPath the column path for error reporting
362+
* @param addError error reporting function
363+
* @return the wrapped expression with null checks
364+
*/
365+
private def fixNullExpansion(
366+
col: Attribute,
367+
value: Expression,
368+
structType: StructType,
369+
structAssignment: Expression,
370+
colPath: Seq[String],
371+
addError: String => Unit): Expression = {
372+
// As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add null check for
373+
// non-nullable column
374+
if (!col.nullable) {
375+
AssertNotNull(value)
376+
} else {
377+
// Check if source struct is null
378+
val valueIsNull = IsNull(value)
379+
380+
// Check if missing source paths (paths in target but not in source) are not null
381+
// These will be null for the case of UPDATE SET * and
382+
val missingSourcePaths = getMissingSourcePaths(structType, value.dataType, colPath, addError)
383+
val condition = if (missingSourcePaths.nonEmpty) {
384+
// Check if all target attributes at missing source paths are null
385+
val missingFieldNullChecks = missingSourcePaths.map { path =>
386+
createNullCheckForFieldPath(col, path)
387+
}
388+
// Combine all null checks with AND
389+
val allMissingFieldsNull = missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b))
390+
And(valueIsNull, allMissingFieldsNull)
391+
} else {
392+
valueIsNull
393+
}
394+
395+
// Return: If (condition) THEN NULL ELSE structAssignment
396+
If(condition, Literal(null, structAssignment.dataType), structAssignment)
397+
}
398+
}
399+
226400
/**
227401
* Checks whether assignments are aligned and compatible with table columns.
228402
*

sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3240,9 +3240,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
32403240
checkAnswer(
32413241
sql(s"SELECT * FROM $tableNameAsString"),
32423242
Seq(
3243-
Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
3243+
Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"),
32443244
Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
3245-
32463245
} else {
32473246
val exception = intercept[org.apache.spark.sql.AnalysisException] {
32483247
sql(mergeStmt)
@@ -5258,8 +5257,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
52585257
checkAnswer(
52595258
sql(s"SELECT * FROM $tableNameAsString"),
52605259
Seq(
5261-
Row(1, Row(10, Row(20, null)), "sales"),
5262-
Row(2, Row(20, Row(30, null)), "engineering")))
5260+
Row(1, Row(10, Row(20, true)), "sales"),
5261+
Row(2, Row(20, Row(30, false)), "engineering")))
52635262
} else {
52645263
val exception = intercept[org.apache.spark.sql.AnalysisException] {
52655264
sql(mergeStmt)
@@ -5918,7 +5917,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
59185917
checkAnswer(
59195918
sql(s"SELECT * FROM $tableNameAsString"),
59205919
Seq(
5921-
Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
5920+
Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"),
59225921
Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
59235922
} else {
59245923
val exception = intercept[org.apache.spark.sql.AnalysisException] {

0 commit comments

Comments
 (0)