feat: Supports Stddev #348

Merged
merged 6 commits into from
May 7, 2024

Conversation

huaxingao
Contributor

Which issue does this PR close?

Closes #.

Rationale for this change

Supports STDDEV_SAMP and STDDEV_POP.
The implementation is mostly the same as DataFusion's. We need our own
implementation because DataFusion uses UInt64 for the count state field,
while Spark uses Double. This PR also adds null_on_divide_by_zero
to be consistent with Spark's implementation.
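
To make the two points above concrete, here is a minimal, self-contained sketch of a standard-deviation accumulator that keeps its count as an f64 and takes a null_on_divide_by_zero flag. It is not the code from this PR: `StddevState`, `StatsType`, and the method names are hypothetical, the update uses Welford's method as an assumption, and the single-row behaviour (None when the flag is set, NaN otherwise) is my understanding of Spark's semantics rather than something stated in this thread.

```rust
/// Hypothetical stand-in for the accumulator state; not the code from this PR.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StatsType {
    Population,
    Sample,
}

#[derive(Debug, Clone)]
pub struct StddevState {
    count: f64, // kept as a double, mirroring Spark, instead of an unsigned integer
    mean: f64,
    m2: f64, // running sum of squared deviations from the mean
    stats_type: StatsType,
    null_on_divide_by_zero: bool,
}

impl StddevState {
    pub fn new(stats_type: StatsType, null_on_divide_by_zero: bool) -> Self {
        Self { count: 0.0, mean: 0.0, m2: 0.0, stats_type, null_on_divide_by_zero }
    }

    /// Welford-style single-value update.
    pub fn update(&mut self, value: f64) {
        self.count += 1.0;
        let delta = value - self.mean;
        self.mean += delta / self.count;
        self.m2 += delta * (value - self.mean);
    }

    /// Combines a partial state computed elsewhere (for example another partition).
    pub fn merge(&mut self, other: &StddevState) {
        if other.count == 0.0 {
            return;
        }
        let new_count = self.count + other.count;
        let delta = other.mean - self.mean;
        self.m2 += other.m2 + delta * delta * self.count * other.count / new_count;
        self.mean += delta * other.count / new_count;
        self.count = new_count;
    }

    /// Returns None for SQL NULL.
    pub fn evaluate(&self) -> Option<f64> {
        if self.count == 0.0 {
            return None;
        }
        match self.stats_type {
            StatsType::Population => Some((self.m2 / self.count).sqrt()),
            // A single row would divide by (count - 1) = 0 for the sample variant.
            StatsType::Sample if self.count == 1.0 => {
                if self.null_on_divide_by_zero { None } else { Some(f64::NAN) }
            }
            StatsType::Sample => Some((self.m2 / (self.count - 1.0)).sqrt()),
        }
    }
}
```

Feeding the eight values 2, 4, 4, 4, 5, 5, 7, 9 into a `Population` state should evaluate to roughly 2.0, while a `Sample` state that has seen a single row evaluates to None when the flag is set.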

What changes are included in this PR?

How are these changes tested?

@@ -31,7 +31,7 @@ struct<w_warehouse_sk:int,i_item_sk:int,d_moy:int,mean:double,cov:double,w_wareh
1 12259 1 326.5 1.219693210219279 1 12259 2 292.6666666666667 1.2808898286830026
1 12641 1 321.25 1.1286221893301993 1 12641 2 279.25 1.129134558577743
1 13043 1 260.5 1.355894484625015 1 13043 2 295.0 1.056210118409035
-1 13157 1 260.5 1.5242630430075292 1 13157 2 413.5 1.0422561797285326
+1 13157 1 260.5 1.524263043007529 1 13157 2 413.5 1.0422561797285326
Contributor

Wondering what is causing the digit difference...

Contributor Author

I am not sure what caused the digit difference.
Actually, for SortMergeJoin I got 1.524263043007529, but for BroadCastJoin I still got 1.5242630430075292. Is it OK if I change the expected result based on the join type?

Contributor Author

also cc @viirya

Member

Ideally, it would be good to compare floating point numbers based on an epsilon to make sure they are within some tolerance threshold. I assume we are currently just comparing text file output directly? Do we have a way to generate the output into a structured file type such as CSV or JSON?
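
As an illustration of the epsilon idea, the following sketch compares two whitespace-separated result rows, treating tokens that parse as floating-point numbers with a relative tolerance and everything else as exact text. The helpers `approx_eq` and `rows_match` and the tolerance value are hypothetical; nothing here is taken from Spark's TPCDSQuerySuite or from Comet.

```rust
/// Returns true when `a` and `b` agree within the relative tolerance `rel_tol`.
pub fn approx_eq(a: f64, b: f64, rel_tol: f64) -> bool {
    if a == b {
        return true; // exact match, including infinities of the same sign
    }
    if a.is_nan() || b.is_nan() {
        return a.is_nan() && b.is_nan(); // treat NaN as equal only to NaN
    }
    if a.is_infinite() || b.is_infinite() {
        return false; // an infinity never approximately equals anything else
    }
    let scale = a.abs().max(b.abs());
    (a - b).abs() <= rel_tol * scale
}

/// Compares two rows token by token: numeric tokens by tolerance, others exactly.
pub fn rows_match(expected: &str, actual: &str, rel_tol: f64) -> bool {
    let e: Vec<&str> = expected.split_whitespace().collect();
    let a: Vec<&str> = actual.split_whitespace().collect();
    e.len() == a.len()
        && e.iter().zip(a.iter()).all(|(x, y)| {
            match (x.parse::<f64>(), y.parse::<f64>()) {
                (Ok(p), Ok(q)) => approx_eq(p, q, rel_tol),
                _ => x == y,
            }
        })
}
```

With a tolerance such as 1e-12, the two rows that differ only in the last digit of 1.5242630430075292 would be accepted as matching, while a genuine difference such as 1.52 versus 1.53 would still fail.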

Member

The difference may be down to order of operations, for example the order in which batches from different partitions are processed. I don't think we can expect it to be 100% deterministic in a distributed system.
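
A small sketch of why ordering matters: floating-point addition is not associative, so accumulating the same values in a different order can change the last digit. The helper `sum_in_order` is hypothetical and only stands in for whatever order batches happen to arrive in.

```rust
/// Adds the values strictly left to right, the way an accumulator would
/// see them if batches arrived in this order.
pub fn sum_in_order(values: &[f64]) -> f64 {
    values.iter().fold(0.0, |acc, v| acc + v)
}
```

Summing `[0.1, 0.2, 0.3]` and `[0.3, 0.2, 0.1]` with this function gives two results that differ in the last bit, the same kind of discrepancy as the one seen in q39a.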

Contributor Author

Yes, we are currently just comparing text file output directly. We are using Spark's TPCDSQuerySuite. There doesn't seem to be a way to generate the output into a structured file type.

Comment on lines 170 to 179
match variance {
    ScalarValue::Float64(e) => {
        if e.is_none() {
            Ok(ScalarValue::Float64(None))
        } else {
            Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
        }
    }
    _ => internal_err!("Variance should be f64"),
}
Member

We can leverage pattern matching to simplify this.

Suggested change
- match variance {
-     ScalarValue::Float64(e) => {
-         if e.is_none() {
-             Ok(ScalarValue::Float64(None))
-         } else {
-             Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
-         }
-     }
-     _ => internal_err!("Variance should be f64"),
- }
+ match variance {
+     ScalarValue::Float64(Some(e)) => Ok(ScalarValue::Float64(Some(e.sqrt()))),
+     ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
+     _ => internal_err!("Variance should be f64"),
+ }
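
For readers without DataFusion at hand, the same simplification can be tried in isolation. The sketch below uses a hypothetical `Scalar` enum and a plain `String` error in place of `ScalarValue` and `internal_err!`, so it only mirrors the shape of the match. It also shows a further variant not proposed in this thread: since `Option::map` already returns `None` for `None`, both Float64 arms can collapse into one.

```rust
/// Hypothetical stand-in for DataFusion's ScalarValue, reduced to two variants.
#[derive(Debug, PartialEq)]
pub enum Scalar {
    Float64(Option<f64>),
    Int64(Option<i64>),
}

/// Takes a variance and returns its square root, passing NULL through unchanged.
pub fn stddev_from_variance(variance: Scalar) -> Result<Scalar, String> {
    match variance {
        // Option::map leaves None untouched, so no separate None arm is needed.
        Scalar::Float64(v) => Ok(Scalar::Float64(v.map(f64::sqrt))),
        _ => Err("Variance should be f64".to_string()),
    }
}
```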

Contributor Author

Changed. Thanks

Member

@andygrove andygrove left a comment

LGTM. Thanks @huaxingao

// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can evaluated at runtime during query execution
Member

@viirya viirya May 6, 2024

Seems copied from somewhere and not related?

Contributor Author

removed. Thanks

Comment on lines 55 to 56
// the result of stddev just support FLOAT64 and Decimal data type.
assert!(matches!(data_type, DataType::Float64));
Member

Hmm? So we also need to add DecimalType here?

Contributor Author

It's FLOAT64 only. Removed "and Decimal data type" from the comment.

}
}

/// An accumulator to compute the average
Member

Suggested change
- /// An accumulator to compute the average
+ /// An accumulator to compute the standard deviation

Contributor Author

Changed. Thanks

Comment on lines 76 to 81
// TODO: comment 39a and 39b for now because the expected result for stddev failed:
// expected: 1.5242630430075292, actual: 1.524263043007529.
// Will change the comparison logic to detect floating-point numbers and compare
// with epsilon
// "q39a",
// "q39b",
Member

We should create a ticket for this.

Contributor Author

opened #392

@viirya
Member

viirya commented May 6, 2024

Some minor comments.

@viirya viirya merged commit c40bc7c into apache:main May 7, 2024
28 checks passed
@viirya
Member

viirya commented May 7, 2024

Merged. Thanks @huaxingao @kazuyukitanimura @andygrove

@huaxingao
Contributor Author

Thanks, everyone!

@huaxingao huaxingao deleted the stddev branch May 7, 2024 01:11
himadripal pushed a commit to himadripal/datafusion-comet that referenced this pull request Sep 7, 2024
* feat: Supports Stddev

* fix fmt

* update q39a.sql.out

* address comments

* disable q93a and q93b for now

* address comments

---------

Co-authored-by: Huaxin Gao <huaxin.gao@apple.com>
4 participants