fix: Normalize NaN and zeros for floating number comparison #953

Merged 5 commits on Sep 22, 2024
8 changes: 8 additions & 0 deletions docs/source/user-guide/compatibility.md
@@ -38,6 +38,14 @@ Comet uses the Rust regexp crate for evaluating regular expressions, and this ha
regular expression engine. Comet will fall back to Spark for patterns that are known to produce different results, but
this can be overridden by setting `spark.comet.regexp.allowIncompatible=true`.

## Floating number comparison

Spark normalizes NaN and zero for floating point numbers in several cases. See the `NormalizeFloatingNumbers` optimization rule in Spark.
Comparison is one exception: Spark does not normalize NaN and zero when comparing values
because its own comparison code already handles them (e.g., `SQLOrderingUtil.compareFloats`). The comparison
functions of arrow-rs used by DataFusion, however, do not normalize NaN and zero (e.g., [arrow::compute::kernels::cmp::eq](https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.eq.html#)),
so Comet adds an extra expression that normalizes NaN and zero on both sides of a comparison.
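
To illustrate what "normalizing NaN and zero" means, here is a minimal sketch in plain Scala with no Spark or Comet dependency. The names `NormalizeSketch`, `normalize`, and `bitwiseEqual` are hypothetical and exist only for this example; `bitwiseEqual` is a stand-in for any comparison that tells bit patterns apart, not a description of how arrow-rs is implemented.

```scala
// Illustrative sketch only; not Comet or Spark code.
object NormalizeSketch {
  // Map every NaN bit pattern to the canonical NaN, and -0.0 to 0.0.
  def normalize(d: Double): Double =
    if (d.isNaN) Double.NaN
    else if (d == 0.0d) 0.0d // -0.0 == 0.0 holds for primitive comparison
    else d

  // Hypothetical stand-in for a comparison that distinguishes bit patterns.
  def bitwiseEqual(a: Double, b: Double): Boolean =
    java.lang.Double.doubleToRawLongBits(a) ==
      java.lang.Double.doubleToRawLongBits(b)

  def main(args: Array[String]): Unit = {
    // A NaN whose payload differs from the canonical NaN.
    val otherNaN = java.lang.Double.longBitsToDouble(0x7ff8000000000001L)

    // Without normalization, 0.0 and -0.0 have different bit patterns.
    println(bitwiseEqual(0.0d, -0.0d))
    // After normalization, both zeros and both NaNs compare as equal.
    println(bitwiseEqual(normalize(0.0d), normalize(-0.0d)))
    println(bitwiseEqual(normalize(otherNaN), normalize(Double.NaN)))
  }
}
```

Normalizing both operands first means the downstream comparison only ever sees one representation of zero and one representation of NaN, which is the effect the added normalization expression aims for.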

## Cast

Cast operations in Comet fall into three levels of support:
@@ -25,8 +25,9 @@ import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.comet._
@@ -47,6 +48,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DoubleType, FloatType}

import org.apache.comet.CometConf._
import org.apache.comet.CometExplainInfo.getActualPlan
@@ -840,6 +842,52 @@ class CometSparkSessionExtensions
}
}

def normalizePlan(plan: SparkPlan): SparkPlan = {
  plan.transformUp {
    case p: ProjectExec =>
      val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
      ProjectExec(newProjectList, p.child)
    case f: FilterExec =>
      val newCondition = normalize(f.condition)
      FilterExec(newCondition, f.child)
  }
}

// Spark normalizes NaN and zero for floating point numbers in several cases.
// See the `NormalizeFloatingNumbers` optimization rule in Spark.
// Comparison operators are one exception: Spark does not normalize NaN and zero for them
// because its own comparison code already handles them (e.g., `SQLOrderingUtil.compareFloats`).
// The comparison functions in arrow-rs do not normalize NaN and zero, so Comet needs to
// normalize both operands of every comparison operator.
def normalize(expr: Expression): Expression = {
  expr.transformUp {
    case EqualTo(left, right) =>
      EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
    case EqualNullSafe(left, right) =>
      EqualNullSafe(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
    case GreaterThan(left, right) =>
      GreaterThan(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
    case GreaterThanOrEqual(left, right) =>
      GreaterThanOrEqual(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
    case LessThan(left, right) =>
      LessThan(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
    case LessThanOrEqual(left, right) =>
      LessThanOrEqual(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
  }
}

def normalizeNaNAndZero(expr: Expression): Expression = {
  expr match {
    // Already wrapped, so leave it alone and keep the rewrite idempotent.
    case _: KnownFloatingPointNormalized => expr
    case _ =>
      expr.dataType match {
        // Only float and double operands need normalization.
        case _: FloatType | _: DoubleType =>
          KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
        case _ => expr
      }
  }
}

override def apply(plan: SparkPlan): SparkPlan = {
// DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is
// enabled.
@@ -865,7 +913,7 @@ class CometSparkSessionExtensions
plan
}
} else {
var newPlan = transform(plan)
var newPlan = transform(normalizePlan(plan))

// if the plan cannot be run fully natively then explain why (when appropriate
// config is enabled)
@@ -129,7 +129,7 @@ Functions [2]: [sum(CASE WHEN (d_date#12 < 2000-03-11) THEN inv_quantity_on_hand

(22) CometFilter
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#15, inv_after#16]
Condition : (CASE WHEN (inv_before#15 > 0) THEN ((cast(inv_after#16 as double) / cast(inv_before#15 as double)) >= 0.666667) END AND CASE WHEN (inv_before#15 > 0) THEN ((cast(inv_after#16 as double) / cast(inv_before#15 as double)) <= 1.5) END)
Condition : (CASE WHEN (inv_before#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#16 as double) / cast(inv_before#15 as double)))) >= knownfloatingpointnormalized(normalizenanandzero(0.666667))) END AND CASE WHEN (inv_before#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#16 as double) / cast(inv_before#15 as double)))) <= knownfloatingpointnormalized(normalizenanandzero(1.5))) END)

(23) CometTakeOrderedAndProject
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#15, inv_after#16]
@@ -110,7 +110,7 @@ ReadSchema: struct<hd_demo_sk:int,hd_buy_potential:string,hd_dep_count:int,hd_ve

(16) CometFilter
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN ((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)) > 1.2) END) AND isnotnull(hd_demo_sk#12))
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)))) > knownfloatingpointnormalized(normalizenanandzero(1.2))) END) AND isnotnull(hd_demo_sk#12))

(17) CometProject
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
@@ -152,11 +152,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#3 as double)), avg(inv_qua

(22) CometFilter
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Condition : CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(23) CometProject
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]

(24) Scan parquet spark_catalog.default.inventory
Output [4]: [inv_item_sk#20, inv_warehouse_sk#21, inv_quantity_on_hand#22, inv_date_sk#23]
@@ -238,11 +238,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#22 as double)), avg(inv_qu

(41) CometFilter
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Condition : CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(42) CometProject
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]

(43) CometBroadcastExchange
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37]
@@ -152,11 +152,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#3 as double)), avg(inv_qua

(22) CometFilter
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Condition : (CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END AND CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.5) END)
Condition : (CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END AND CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.5))) END)

(23) CometProject
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#17, mean#18]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]
Arguments: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, cov#19], [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#18, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#19]

(24) Scan parquet spark_catalog.default.inventory
Output [4]: [inv_item_sk#20, inv_warehouse_sk#21, inv_quantity_on_hand#22, inv_date_sk#23]
@@ -238,11 +238,11 @@ Functions [2]: [stddev_samp(cast(inv_quantity_on_hand#22 as double)), avg(inv_qu

(41) CometFilter
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Condition : CASE WHEN (mean#18 = 0.0) THEN false ELSE ((stdev#17 / mean#18) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#17 / mean#18))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(42) CometProject
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, stdev#17, mean#18]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (mean#18 = 0.0) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]
Arguments: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37], [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#18 AS mean#36, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#18)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#17 / mean#18) END AS cov#37]

(43) CometBroadcastExchange
Input [5]: [w_warehouse_sk#26, i_item_sk#25, d_moy#30, mean#36, cov#37]
@@ -110,7 +110,7 @@ ReadSchema: struct<hd_demo_sk:int,hd_buy_potential:string,hd_dep_count:int,hd_ve

(16) CometFilter
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN ((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)) > 1.0) END) AND isnotnull(hd_demo_sk#12))
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END) AND isnotnull(hd_demo_sk#12))

(17) CometProject
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
@@ -136,7 +136,7 @@ Results [4]: [w_warehouse_name#7, i_item_id#9, sum(CASE WHEN (d_date#12 < 2000-0

(23) Filter [codegen id : 2]
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#19, inv_after#20]
Condition : (CASE WHEN (inv_before#19 > 0) THEN ((cast(inv_after#20 as double) / cast(inv_before#19 as double)) >= 0.666667) END AND CASE WHEN (inv_before#19 > 0) THEN ((cast(inv_after#20 as double) / cast(inv_before#19 as double)) <= 1.5) END)
Condition : (CASE WHEN (inv_before#19 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#20 as double) / cast(inv_before#19 as double)))) >= knownfloatingpointnormalized(normalizenanandzero(0.666667))) END AND CASE WHEN (inv_before#19 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(inv_after#20 as double) / cast(inv_before#19 as double)))) <= knownfloatingpointnormalized(normalizenanandzero(1.5))) END)

(24) TakeOrderedAndProject
Input [4]: [w_warehouse_name#7, i_item_id#9, inv_before#19, inv_after#20]
@@ -110,7 +110,7 @@ ReadSchema: struct<hd_demo_sk:int,hd_buy_potential:string,hd_dep_count:int,hd_ve

(16) CometFilter
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN ((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)) > 1.2) END) AND isnotnull(hd_demo_sk#12))
Condition : ((((isnotnull(hd_vehicle_count#15) AND ((hd_buy_potential#13 = >10000 ) OR (hd_buy_potential#13 = unknown ))) AND (hd_vehicle_count#15 > 0)) AND CASE WHEN (hd_vehicle_count#15 > 0) THEN (knownfloatingpointnormalized(normalizenanandzero((cast(hd_dep_count#14 as double) / cast(hd_vehicle_count#15 as double)))) > knownfloatingpointnormalized(normalizenanandzero(1.2))) END) AND isnotnull(hd_demo_sk#12))

(17) CometProject
Input [4]: [hd_demo_sk#12, hd_buy_potential#13, hd_dep_count#14, hd_vehicle_count#15]
@@ -162,10 +162,10 @@ Results [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stddev_samp(cast(inv_quan

(23) Filter [codegen id : 4]
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#24, mean#25]
Condition : CASE WHEN (mean#25 = 0.0) THEN false ELSE ((stdev#24 / mean#25) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#25)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#24 / mean#25))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(24) Project [codegen id : 4]
Output [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#25, CASE WHEN (mean#25 = 0.0) THEN null ELSE (stdev#24 / mean#25) END AS cov#26]
Output [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, mean#25, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#25)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#24 / mean#25) END AS cov#26]
Input [5]: [w_warehouse_sk#7, i_item_sk#6, d_moy#11, stdev#24, mean#25]

(25) Scan parquet spark_catalog.default.inventory
@@ -255,10 +255,10 @@ Results [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, stddev_samp(cast(inv_qu

(43) Filter [codegen id : 3]
Input [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, stdev#48, mean#49]
Condition : CASE WHEN (mean#49 = 0.0) THEN false ELSE ((stdev#48 / mean#49) > 1.0) END
Condition : CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#49)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN false ELSE (knownfloatingpointnormalized(normalizenanandzero((stdev#48 / mean#49))) > knownfloatingpointnormalized(normalizenanandzero(1.0))) END

(44) Project [codegen id : 3]
Output [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, mean#49, CASE WHEN (mean#49 = 0.0) THEN null ELSE (stdev#48 / mean#49) END AS cov#50]
Output [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, mean#49, CASE WHEN (knownfloatingpointnormalized(normalizenanandzero(mean#49)) = knownfloatingpointnormalized(normalizenanandzero(0.0))) THEN null ELSE (stdev#48 / mean#49) END AS cov#50]
Input [5]: [w_warehouse_sk#33, i_item_sk#32, d_moy#37, stdev#48, mean#49]

(45) BroadcastExchange