Skip to content

Commit

Permalink
Merge pull request #159 from rokostik/spreadsheet-group-by
Browse files Browse the repository at this point in the history
Implement spredsheet group-by
  • Loading branch information
refaktor committed Mar 17, 2024
2 parents 3f2d5f5 + d800d7a commit 4d87aaf
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 9 deletions.
2 changes: 1 addition & 1 deletion env/spreadsheet.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (s Spreadsheet) GetRowValue(column string, rrow SpreadsheetRow) (any, error
}
}
if index < 0 {
return "", nil
return "", fmt.Errorf("column %s not found", column)
}
return rrow.Values[index], nil
}
Expand Down
151 changes: 143 additions & 8 deletions evaldo/builtins_spreadsheet.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,52 @@ var Builtins_spreadsheet = map[string]*env.Builtin{
}
},
},
"group-by": {
Argsn: 3,
Doc: "Groups a spreadsheet by the given column and (optional) aggregations.",
Fn: func(ps *env.ProgramState, arg0 env.Object, arg1 env.Object, arg2 env.Object, arg3 env.Object, arg4 env.Object) (res env.Object) {
switch spr := arg0.(type) {
case env.Spreadsheet:
switch aggBlock := arg2.(type) {
case env.Block:
if len(aggBlock.Series.S)%2 != 0 {
return MakeBuiltinError(ps, "Aggregation block must contain pairs of column name and function for each aggregation.", "group-by")
}
aggregations := make(map[string][]string)
for i := 0; i < len(aggBlock.Series.S); i += 2 {
col := aggBlock.Series.S[i]
fun, ok := aggBlock.Series.S[i+1].(env.Word)
if !ok {
return MakeBuiltinError(ps, "Aggregation function must be a word", "group-by")
}
colStr := ""
switch col := col.(type) {
case env.Tagword:
colStr = ps.Idx.GetWord(col.Index)
case env.String:
colStr = col.Value
default:
return MakeBuiltinError(ps, "Aggregation column must be a word or string", "group-by")
}
funStr := ps.Idx.GetWord(fun.Index)
aggregations[colStr] = append(aggregations[colStr], funStr)
}
switch col := arg1.(type) {
case env.Word:
return GroupBy(ps, spr, ps.Idx.GetWord(col.Index), aggregations)
case env.String:
return GroupBy(ps, spr, col.Value, aggregations)
default:
return MakeArgError(ps, 2, []env.Type{env.WordType, env.StringType}, "group-by")
}
default:
return MakeArgError(ps, 3, []env.Type{env.BlockType}, "group-by")
}
default:
return MakeArgError(ps, 1, []env.Type{env.SpreadsheetType}, "group-by")
}
},
},
}

func GenerateColumn(ps *env.ProgramState, s env.Spreadsheet, name env.Word, extractCols env.Block, code env.Block) env.Object {
Expand Down Expand Up @@ -759,7 +805,7 @@ func GenerateColumnRegexReplace(ps *env.ProgramState, s *env.Spreadsheet, name e
// get value from current row
val, err := s.GetRowValue(ps.Idx.GetWord(fromColName.Index), row)
if err != nil {
return MakeError(ps, "Couldn't retrieve value at row "+strconv.Itoa(ix))
return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", ix, err))
}

var newVal any
Expand Down Expand Up @@ -1041,10 +1087,10 @@ func LeftJoin(ps *env.ProgramState, s1 env.Spreadsheet, s2 env.Spreadsheet, col1
}
}
nspr := env.NewSpreadsheet(combinedCols)
for _, row1 := range s1.GetRows() {
for i, row1 := range s1.GetRows() {
val1, err := s1.GetRowValue(col1, row1)
if err != nil {
return MakeError(ps, "Couldn't retrieve value at row")
return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", i, err))
}
newRow := make([]any, len(combinedCols))

Expand All @@ -1057,13 +1103,13 @@ func LeftJoin(ps *env.ProgramState, s1 env.Spreadsheet, s2 env.Spreadsheet, col1
s2RowId = rowIds[0]
}
} else {
for i, row2 := range s2.GetRows() {
for j, row2 := range s2.GetRows() {
val2, err := s2.GetRowValue(col2, row2)
if err != nil {
return MakeError(ps, "Couldn't retrieve value at row")
return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", j, err))
}
if val1.(env.Object).Equal(val2.(env.Object)) {
s2RowId = i
s2RowId = j
break
}
}
Expand All @@ -1077,11 +1123,100 @@ func LeftJoin(ps *env.ProgramState, s1 env.Spreadsheet, s2 env.Spreadsheet, col1
newRow[i+len(s1.Cols)] = v
}
} else {
for i := range s2.Cols {
newRow[i+len(s1.Cols)] = env.Void{}
for k := range s2.Cols {
newRow[k+len(s1.Cols)] = env.Void{}
}
}
nspr.AddRow(*env.NewSpreadsheetRow(newRow, nspr))
}
return *nspr
}

func GroupBy(ps *env.ProgramState, s env.Spreadsheet, col string, aggregations map[string][]string) env.Object {
if !slices.Contains(s.Cols, col) {
return MakeBuiltinError(ps, "Column not found.", "group-by")
}

aggregatesByGroup := make(map[string]map[string]float64)
countByGroup := make(map[string]int)
for i, row := range s.Rows {
groupingVal, err := s.GetRowValue(col, row)
if err != nil {
return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", i, err))
}
groupValStr, ok := groupingVal.(env.String)
if !ok {
return MakeBuiltinError(ps, "Grouping column value must be a string", "group-by")
}

if _, ok := aggregatesByGroup[groupValStr.Value]; !ok {
aggregatesByGroup[groupValStr.Value] = make(map[string]float64)
}
groupAggregates := aggregatesByGroup[groupValStr.Value]

for aggCol, funs := range aggregations {
for _, fun := range funs {
colAgg := aggCol + "_" + fun
if fun == "count" {
if aggCol != col {
return MakeBuiltinError(ps, "Count aggregation can only be applied on the grouping column", "group-by")
}
groupAggregates[colAgg]++
continue
}
valObj, err := s.GetRowValue(aggCol, row)
if err != nil {
return MakeError(ps, fmt.Sprintf("Couldn't retrieve value at row %d (%s)", i, err))
}
var val float64
switch valObj := env.ToRyeValue(valObj).(type) {
case env.Integer:
val = float64(valObj.Value)
case env.Decimal:
val = valObj.Value
default:
return MakeBuiltinError(ps, "Aggregation column value must be a number", "group-by")
}
switch fun {
case "sum":
groupAggregates[colAgg] += val
case "avg":
groupAggregates[colAgg] += val
countByGroup[groupValStr.Value]++
case "min":
if min, ok := groupAggregates[colAgg]; !ok || val < min {
groupAggregates[colAgg] = val
}
case "max":
if max, ok := groupAggregates[colAgg]; !ok || val > max {
groupAggregates[colAgg] = val
}
default:
return MakeBuiltinError(ps, fmt.Sprintf("Unknown aggregation function: %s", fun), "group-by")
}
}
}
}
newCols := []string{col}
for aggCol, funs := range aggregations {
for _, fun := range funs {
newCols = append(newCols, aggCol+"_"+fun)
}
}
newS := env.NewSpreadsheet(newCols)
for groupVal, groupAggregates := range aggregatesByGroup {
newRow := make([]any, len(newCols))
newRow[0] = *env.NewString(groupVal)
for i, col := range newCols[1:] {
if strings.HasSuffix(col, "_count") {
newRow[i+1] = *env.NewInteger(int64(groupAggregates[col]))
} else if strings.HasSuffix(col, "_avg") {
newRow[i+1] = *env.NewDecimal(groupAggregates[col] / float64(countByGroup[groupVal]))
} else {
newRow[i+1] = *env.NewDecimal(groupAggregates[col])
}
}
newS.AddRow(*env.NewSpreadsheetRow(newRow, newS))
}
return *newS
}
17 changes: 17 additions & 0 deletions tests/structures.rye
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,23 @@ section "Spreadsheet related functions"
names .inner-join houses 'id 'id
} spreadsheet { "id" "name" "id_2" "house" } { 1 "Paul" 1 "Atreides" 3 "Vladimir" 3 "Harkonnen" }
}

group "group by"
mold\nowrap ?group-by
{ { block } }
{
equal { spreadsheet { "name" "val" } { "a" 1 "b" 2 } |group-by 'name { } |sort-col! 'name
} spreadsheet { "name" } { "a" "b" }

equal { spreadsheet { "name" "val" } { "a" 1 "b" 6 "a" 5 "b" 10 "a" 7 }
|group-by 'name { 'name count 'val sum 'val min 'val max 'val avg }
|sort-col! 'name
} spreadsheet { "name" "name_count" "val_sum" "val_min" "val_max" "val_avg" }
{
"a" 3 13.0 1.0 7.0 4.333333333333333
"b" 2 16.0 6.0 10.0 8.0
}
}
}


Expand Down

0 comments on commit 4d87aaf

Please sign in to comment.