Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: window function range offset should be long instead of int #733

Merged
merged 10 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 71 additions & 13 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1692,16 +1692,46 @@ impl PhysicalPlanner {
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
LowerFrameBoundStruct::UnboundedPreceding(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
WindowFrameUnits::Range => {
WindowFrameBound::Preceding(ScalarValue::Int64(None))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
LowerFrameBoundStruct::Preceding(offset) => {
let offset_value = offset.offset.unsigned_abs() as u64;
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value)))
let offset_value = offset.offset.abs();
match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(
Some(offset_value as u64),
)),
WindowFrameUnits::Range => {
WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
}
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::Int64(None)),
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
};

let upper_bound: WindowFrameBound = match spark_window_frame
Expand All @@ -1710,15 +1740,43 @@ impl PhysicalPlanner {
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
UpperFrameBoundStruct::Following(offset) => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
UpperFrameBoundStruct::UnboundedFollowing(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
WindowFrameUnits::Range => {
WindowFrameBound::Following(ScalarValue::Int64(None))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
UpperFrameBoundStruct::Following(offset) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
WindowFrameUnits::Range => {
WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
}
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Following(ScalarValue::UInt64(None)),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::Int64(None)),
WindowFrameUnits::Groups => {
return Err(ExecutionError::GeneralError(
"WindowFrameUnits::Groups is not supported.".to_string(),
));
}
},
};

let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
Expand Down
4 changes: 2 additions & 2 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ message UpperWindowFrameBound {
}

message Preceding {
int32 offset = 1;
int64 offset = 1;
}

message Following {
int32 offset = 1;
int64 offset = 1;
}

message UnboundedPreceding {}
Expand Down
57 changes: 52 additions & 5 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
(None, exprToProto(windowExpr.windowFunction, output))
}

if (aggExpr.isEmpty && builtinFunc.isEmpty) {
return None
}

val f = windowExpr.windowSpec.frameSpecification

val (frameType, lowerBound, upperBound) = f match {
case SpecifiedWindowFrame(frameType, lBound, uBound) =>
val frameProto = frameType match {
case RowFrame => OperatorOuterClass.WindowFrameType.Rows
case RangeFrame =>
withInfo(windowExpr, "Range frame is not supported")
return None
case RangeFrame => OperatorOuterClass.WindowFrameType.Range
}

val lBoundProto = lBound match {
Expand All @@ -278,12 +280,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setPreceding(
OperatorOuterClass.Preceding
.newBuilder()
.setOffset(e.eval().asInstanceOf[Int])
.setOffset(offset)
.build())
.build()
}
Expand All @@ -300,12 +307,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}

OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setFollowing(
OperatorOuterClass.Following
.newBuilder()
.setOffset(e.eval().asInstanceOf[Int])
.setOffset(offset)
.build())
.build()
}
Expand Down Expand Up @@ -2774,6 +2787,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
return None
}

val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
val partitionExprs = partitionSpec.map(exprToProto(_, child.output))

Expand Down Expand Up @@ -3277,4 +3295,33 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
true
}
}

private def validatePartitionAndSortSpecsForWindowFunc(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
op: SparkPlan): Boolean = {
if (partitionSpec.length != orderSpec.length) {
withInfo(op, "Partitioning and sorting specifications do not match")
return false
}

val partitionColumnNames = partitionSpec.collect { case a: AttributeReference =>
a.name
}

val orderColumnNames = orderSpec.collect { case s: SortOrder =>
s.child match {
case a: AttributeReference => a.name
}
}

if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol, orderCol) =>
partCol != orderCol
}) {
withInfo(op, "Partitioning and sorting specifications must be the same.")
return false
}

true
}
}
39 changes: 26 additions & 13 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ class CometExecSuite extends CometTestBase {
}
}

test(
"fall back to Spark when the partition spec and order spec are not the same for window function") {
withTempView("test") {
sql("""
|CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
| (1, true), (1, false),
|(2, true), (3, false), (4, true) AS test(k, v)
|""".stripMargin)

val df = sql("""
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
|""".stripMargin)
checkSparkAnswer(df)
}
}

test("Native window operator should be CometUnaryExec") {
withTempView("testData") {
sql("""
Expand All @@ -164,11 +180,11 @@ class CometExecSuite extends CometTestBase {
|(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null)
|AS testData(val, val_long, val_double, val_date, val_timestamp, cate)
|""".stripMargin)
val df = sql("""
val df1 = sql("""
|SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW)
|FROM testData ORDER BY cate, val
|""".stripMargin)
checkSparkAnswer(df)
checkSparkAnswer(df1)
}
}

Expand All @@ -193,23 +209,21 @@ class CometExecSuite extends CometTestBase {
}
}

test("Window range frame should fall back to Spark") {
test("Window range frame with long boundary should not fail") {
val df =
Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")

checkAnswer(
checkSparkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))),
Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)))
checkAnswer(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))))
checkSparkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))),
Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)))
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))))
}

test("Unsupported window expression should fall back to Spark") {
Expand Down Expand Up @@ -1769,10 +1783,9 @@ class CometExecSuite extends CometTestBase {
aggregateFunctions.foreach { function =>
val queries = Seq(
s"SELECT $function OVER() FROM t1",
// TODO: Range frame is not supported yet.
// s"SELECT $function OVER(order by _2) FROM t1",
// s"SELECT $function OVER(order by _2 desc) FROM t1",
// s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
s"SELECT $function OVER(order by _2) FROM t1",
s"SELECT $function OVER(order by _2 desc) FROM t1",
s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
s"SELECT $function OVER(rows between 1 preceding and 1 following) FROM t1",
s"SELECT $function OVER(order by _2 rows between 1 preceding and current row) FROM t1",
s"SELECT $function OVER(order by _2 rows between current row and 1 following) FROM t1")
Expand Down
Loading