Skip to content

Commit

Permalink
refactor(snowflake): remove some unnecessary checking of input in the…
Browse files Browse the repository at this point in the history
… snowflake compiler
  • Loading branch information
cpcloud committed Nov 2, 2024
1 parent ef6634c commit 4fceebf
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class SnowflakeCompiler(SQLGlotCompiler):
)

SIMPLE_OPS = {
# overrides booland_agg/boolor_agg because neither of those can be used
# in a cumulative window frame, while min and max can
ops.All: "min",
ops.Any: "max",
ops.ArrayDistinct: "array_distinct",
Expand All @@ -130,10 +132,17 @@ class SnowflakeCompiler(SQLGlotCompiler):
ops.Hash: "hash",
ops.Median: "median",
ops.Mode: "mode",
ops.RandomUUID: "uuid_string",
ops.StringToDate: "to_date",
ops.StringToTimestamp: "to_timestamp_tz",
ops.TimeFromHMS: "time_from_parts",
ops.TimestampFromYMDHMS: "timestamp_from_parts",
ops.ToJSONMap: "as_object",
ops.ToJSONArray: "as_array",
ops.UnwrapJSONString: "as_varchar",
ops.UnwrapJSONInt64: "as_integer",
ops.UnwrapJSONFloat64: "as_double",
ops.UnwrapJSONBoolean: "as_boolean",
}

def __init__(self):
Expand Down Expand Up @@ -306,24 +315,6 @@ def visit_Cast(self, op, *, arg, to):
return self.if_(self.f.is_array(arg), arg, NULL)
return super().visit_Cast(op, arg=arg, to=to)

def visit_ToJSONMap(self, op, *, arg):
return self.if_(self.f.is_object(arg), arg, NULL)

def visit_ToJSONArray(self, op, *, arg):
return self.if_(self.f.is_array(arg), arg, NULL)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(self.f.is_varchar(arg), self.f.as_varchar(arg), NULL)

def visit_UnwrapJSONInt64(self, op, *, arg):
return self.if_(self.f.is_integer(arg), self.f.as_integer(arg), NULL)

def visit_UnwrapJSONFloat64(self, op, *, arg):
return self.if_(self.f.is_double(arg), self.f.as_double(arg), NULL)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(self.f.is_boolean(arg), self.f.as_boolean(arg), NULL)

def visit_IsNan(self, op, *, arg):
return arg.eq(self.NAN)

Expand Down Expand Up @@ -383,14 +374,11 @@ def visit_MapLength(self, op, *, arg):
def visit_Log(self, op, *, arg, base):
return self.f.log(base, arg)

def visit_RandomScalar(self, op, **kwargs):
def visit_RandomScalar(self, op, **_):
return self.f.uniform(
self.f.to_double(0.0), self.f.to_double(1.0), self.f.random()
)

def visit_RandomUUID(self, op, **kwargs):
return self.f.uuid_string()

def visit_ApproxMedian(self, op, *, arg, where):
return self.agg.approx_percentile(arg, 0.5, where=where)

Expand Down

0 comments on commit 4fceebf

Please sign in to comment.