Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Jupyter compile-time DF type not recognized #401

Merged
merged 7 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1786,12 +1786,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { `[colsOf][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.colsOf]`<`[String][String]`>().`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
@Suppress("UNCHECKED_CAST")
public fun <C> ColumnSet<C>.cols(
predicate: ColumnFilter<C> = { true },
Expand Down Expand Up @@ -1829,12 +1827,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { `[colsOf][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.colsOf]`<`[String][String]`>().`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public operator fun <C> ColumnSet<C>.get(
    predicate: ColumnFilter<C> = { true },
): TransformableColumnSet<C> {
    // Operator spelling of `cols`: forwards the predicate unchanged.
    return cols(predicate)
}
Expand Down Expand Up @@ -1928,12 +1924,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { myColumnGroup`[`[`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`{ ... }`[`]`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]` }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public fun SingleColumn<*>.cols(
    predicate: ColumnFilter<*> = { true },
): TransformableColumnSet<*> {
    // Thin public entry point; the selection itself is done by `colsInternal`.
    return colsInternal(predicate)
}
Expand Down Expand Up @@ -1979,12 +1973,11 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { myColumnGroup`[`[`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`{ ... }`[`]`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]` }`
*
* @see [all]
*
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
* @see [all]
*
*/
public operator fun SingleColumn<*>.get(
predicate: ColumnFilter<*> = { true },
Expand Down Expand Up @@ -2172,12 +2165,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { Type::columnGroup.`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public fun KProperty<*>.cols(
    predicate: ColumnFilter<*> = { true },
): TransformableColumnSet<*> {
    // Resolve the property to a column group first, then reuse that receiver's `cols`.
    val group = colGroup(this)
    return group.cols(predicate)
}
Expand Down Expand Up @@ -2212,12 +2203,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { Type::columnGroup.`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public operator fun KProperty<*>.get(
    predicate: ColumnFilter<*> = { true },
): TransformableColumnSet<Any?> {
    // Operator spelling of `cols` for property receivers: forwards the predicate unchanged.
    return cols(predicate)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,33 @@ import org.jetbrains.dataframe.impl.codeGen.ReplCodeGenerator
import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.api.Convert
import org.jetbrains.kotlinx.dataframe.api.FormattedFrame
import org.jetbrains.kotlinx.dataframe.api.Gather
import org.jetbrains.kotlinx.dataframe.api.GroupBy
import org.jetbrains.kotlinx.dataframe.api.Merge
import org.jetbrains.kotlinx.dataframe.api.Pivot
import org.jetbrains.kotlinx.dataframe.api.PivotGroupBy
import org.jetbrains.kotlinx.dataframe.api.ReducedGroupBy
import org.jetbrains.kotlinx.dataframe.api.ReducedPivot
import org.jetbrains.kotlinx.dataframe.api.ReducedPivotGroupBy
import org.jetbrains.kotlinx.dataframe.api.Split
import org.jetbrains.kotlinx.dataframe.api.SplitWithTransform
import org.jetbrains.kotlinx.dataframe.api.Update
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataFrame
import org.jetbrains.kotlinx.dataframe.api.columnsCount
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.frames
import org.jetbrains.kotlinx.dataframe.api.into
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
import org.jetbrains.kotlinx.dataframe.api.name
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.values
import org.jetbrains.kotlinx.dataframe.codeGen.CodeWithConverter
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
Expand All @@ -31,6 +55,7 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
import org.jetbrains.kotlinx.jupyter.api.libraries.resources
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.full.isSubtypeOf

/** Users will get an error if their Kotlin Jupyter kernel is older than this version. */
Expand All @@ -45,6 +70,101 @@ internal class Integration(

val version = options["v"]

/**
 * Builds a snippet from [codeWithConverter] for the given [argument] and runs it on this kernel host.
 *
 * Nothing is executed when the built snippet is blank.
 *
 * @param codeWithConverter generated code, optionally carrying a converter.
 * @param argument text handed to [CodeWithConverter.with] to produce the final snippet.
 * @return the name of the execution result when [codeWithConverter] has a converter, otherwise `null`.
 */
private fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, argument: String): VariableName? {
    val snippet = codeWithConverter.with(argument)
    if (snippet.isBlank()) return null

    // The snippet is executed even when there is no converter; only the returned name depends on it.
    val executionResult = execute(snippet)
    return if (codeWithConverter.hasConverter) executionResult.name else null
}

/**
 * Runs [codeWithConverter] against an expression that refers to [property] cast to [type].
 *
 * The expression has the form `(<name> as <type>)`; a `!!` is appended to the name when the
 * property's declared return type is marked nullable.
 *
 * @param codeWithConverter generated code to run.
 * @param property the variable the generated code should refer to.
 * @param type the type the variable is cast to inside the expression.
 * @return whatever the two-argument `execute(codeWithConverter, argument)` overload returns.
 */
private fun KotlinKernelHost.execute(
    codeWithConverter: CodeWithConverter,
    property: KProperty<*>,
    type: KType,
): VariableName? {
    val nonNullAssertion = if (property.returnType.isMarkedNullable) "!!" else ""
    val castExpression = "(${property.name}$nonNullAssertion as $type)"
    return execute(codeWithConverter, castExpression)
}

/**
 * Generates code for the schema behind [importDataSchema]'s URL and executes it on this kernel host.
 *
 * The generated schema is named `<property name>DataSchema`. On success the generated code is
 * executed together with the read method's additional imports and a confirmation is displayed;
 * on failure only an error message is displayed.
 *
 * @param importDataSchema holder of the URL to read the schema from.
 * @param property the variable that holds [importDataSchema]; its name seeds the schema name.
 * @return the schema name on success, `null` when reading the schema failed.
 */
private fun KotlinKernelHost.updateImportDataSchemaVariable(
    importDataSchema: ImportDataSchema,
    property: KProperty<*>,
): VariableName? {
    val codeGenFormats = supportedFormats.filterIsInstance<SupportedCodeGenerationFormat>()
    val schemaName = property.name + "DataSchema"
    val readResult = CodeGenerator.urlCodeGenReader(importDataSchema.url, schemaName, codeGenFormats, true)

    // NOTE(review): the URL and the failure reason are interpolated unescaped into code that is
    // then executed; a `"` in either value would presumably break the DISPLAY call — confirm.
    return when (readResult) {
        is CodeGenerationReadResult.Success -> {
            val readDfMethod = readResult.getReadDfMethod(importDataSchema.url.toExternalForm())
            val importLines = readDfMethod.additionalImports.joinToString("\n")

            execute(importLines + "\n" + readResult.code)
            execute("""DISPLAY("Data schema successfully imported as ${property.name}: $schemaName")""")

            schemaName
        }

        is CodeGenerationReadResult.Error -> {
            execute("""DISPLAY("Failed to read data schema from ${importDataSchema.url}: ${readResult.reason}")""")
            null
        }
    }
}

/**
 * Runs the code that [codeGen] produces for the data frame [df] held by [property].
 *
 * The variable is referenced in the executed code cast to a star-projected, non-nullable
 * [DataFrame] type.
 *
 * @return the result of the property-based `execute` overload.
 */
private fun KotlinKernelHost.updateAnyFrameVariable(
    df: AnyFrame,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(df, property)
    val frameType = DataFrame::class.createStarProjectedType(false)
    return execute(generated, property, frameType)
}

/**
 * Runs the code that [codeGen] produces for the data row [row] held by [property].
 *
 * The variable is referenced in the executed code cast to a star-projected, non-nullable
 * [DataRow] type.
 *
 * @return the result of the property-based `execute` overload.
 */
private fun KotlinKernelHost.updateAnyRowVariable(
    row: AnyRow,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(row, property)
    val rowType = DataRow::class.createStarProjectedType(false)
    return execute(generated, property, rowType)
}

/**
 * Runs the code that [codeGen] produces for the column group [col] held by [property].
 *
 * The group is converted with `asDataFrame()` before code generation; the variable itself is
 * referenced in the executed code cast to a star-projected, non-nullable [ColumnGroup] type.
 *
 * @return the result of the property-based `execute` overload.
 */
private fun KotlinKernelHost.updateColumnGroupVariable(
    col: ColumnGroup<*>,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(col.asDataFrame(), property)
    val groupType = ColumnGroup::class.createStarProjectedType(false)
    return execute(generated, property, groupType)
}

/**
 * Runs generated code for [col] when it is a column group; other columns are ignored.
 *
 * The generated converter is wrapped so that it receives the variable expression with
 * `.asColumnGroup()` appended. The variable is referenced in the executed code cast to a
 * star-projected, non-nullable [DataColumn] type.
 *
 * @return the result of the property-based `execute` overload, or `null` when [col] is not a column group.
 */
private fun KotlinKernelHost.updateAnyColVariable(
    col: AnyCol,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    if (!col.isColumnGroup()) return null

    val generated = codeGen.process(col.asColumnGroup().asDataFrame(), property)
    // Keep the declarations, but feed the original converter a column-group view of the variable.
    val groupAware = CodeWithConverter(generated.declarations) { expression ->
        generated.converter("$expression.asColumnGroup()")
    }
    val columnType = DataColumn::class.createStarProjectedType(false)
    return execute(groupAware, property, columnType)
}

override fun Builder.onLoaded() {
if (version != null) {
dependencies(
Expand Down Expand Up @@ -152,65 +272,17 @@ internal class Integration(
import("org.jetbrains.kotlinx.dataframe.dataTypes.*")
import("org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader")

// Builds a snippet from `codeWithConverter` for `argument` and runs it on the kernel host.
// Blank snippets are not executed. Returns the execution result's name only when a converter is present.
fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, argument: String): VariableName? {
    val snippet = codeWithConverter.with(argument)
    if (snippet.isBlank()) return null

    val executionResult = execute(snippet)
    return if (codeWithConverter.hasConverter) executionResult.name else null
}

// Runs `codeWithConverter` against the property's name, with `!!` appended when the
// property's declared return type is marked nullable.
fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, property: KProperty<*>): VariableName? {
    val nonNullAssertion = if (property.returnType.isMarkedNullable) "!!" else ""
    return execute(codeWithConverter, property.name + nonNullAssertion)
}

updateVariable<ImportDataSchema> { importDataSchema, property ->
val formats = supportedFormats.filterIsInstance<SupportedCodeGenerationFormat>()
val name = property.name + "DataSchema"
when (val codeGenResult = CodeGenerator.urlCodeGenReader(importDataSchema.url, name, formats, true)) {
is CodeGenerationReadResult.Success -> {
val readDfMethod = codeGenResult.getReadDfMethod(importDataSchema.url.toExternalForm())
val code = readDfMethod.additionalImports.joinToString("\n") +
"\n" +
codeGenResult.code

execute(code)
execute("""DISPLAY("Data schema successfully imported as ${property.name}: $name")""")

name
}

is CodeGenerationReadResult.Error -> {
execute("""DISPLAY("Failed to read data schema from ${importDataSchema.url}: ${codeGenResult.reason}")""")
null
}
// Single handler registered for `Any`: picks the update routine from the runtime type of the value.
// Values matching none of the listed types yield null.
// NOTE(review): branches are tried top to bottom, so their order matters if these types overlap
// (presumably a ColumnGroup is also an AnyFrame — confirm before reordering).
updateVariable<Any> { instance, property ->
when (instance) {
is AnyCol -> updateAnyColVariable(instance, property, codeGen)
is ColumnGroup<*> -> updateColumnGroupVariable(instance, property, codeGen)
is AnyRow -> updateAnyRowVariable(instance, property, codeGen)
is AnyFrame -> updateAnyFrameVariable(instance, property, codeGen)
is ImportDataSchema -> updateImportDataSchemaVariable(instance, property)
else -> null
}
}

// Handler for data-frame variables: runs the code generated for the frame against its property.
updateVariable<AnyFrame> { frame, prop ->
    val generated = codeGen.process(frame, prop)
    execute(generated, prop)
}

// Handler for data-row variables: runs the code generated for the row against its property.
updateVariable<AnyRow> { dataRow, prop ->
    val generated = codeGen.process(dataRow, prop)
    execute(generated, prop)
}

// Handler for column-group variables: the group is converted with asDataFrame() before code generation.
updateVariable<ColumnGroup<*>> { group, prop ->
    val generated = codeGen.process(group.asDataFrame(), prop)
    execute(generated, prop)
}

// Handler for column variables: only column groups get generated code; other columns yield null.
updateVariable<AnyCol> { column, prop ->
    if (!column.isColumnGroup()) {
        null
    } else {
        val generated = codeGen.process(column.asColumnGroup().asDataFrame(), prop)
        // Feed the original converter the variable expression with `.asColumnGroup()` appended.
        val groupAware = CodeWithConverter(generated.declarations) { expression ->
            generated.converter("$expression.asColumnGroup()")
        }
        execute(groupAware, prop)
    }
}

fun KotlinKernelHost.addDataSchemas(classes: List<KClass<*>>) {
val code = classes.joinToString("\n") {
codeGen.process(it)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dataframe.jupyter

import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.jupyter.api.Code
import org.junit.Test

Expand All @@ -11,9 +12,20 @@ class CodeGenerationTests : DataFrameJupyterTest() {
}
}

/**
 * Passes a snippet to `checkCompilation()` in which `create()` is declared to return `Any?`,
 * so `df` has no static data-frame type, and the column accessor `df.a` is then used on it.
 */
@Test
fun `Type erased dataframe`() {
@Language("kts")
val a = """
fun create(): Any? = dataFrameOf("a")(1)
val df = create()
df.a
""".checkCompilation()
}

@Test
fun `nullable dataframe`() {
"""
@Language("kts")
val a = """
fun create(): AnyFrame? = dataFrameOf("a")(1)
val df = create()
df.a
Expand All @@ -22,7 +34,8 @@ class CodeGenerationTests : DataFrameJupyterTest() {

@Test
fun `nullable columnGroup`() {
"""
@Language("kts")
val a = """
fun create(): AnyCol? = dataFrameOf("a")(1).asColumnGroup().asDataColumn()
val col = create()
col.a
Expand All @@ -31,7 +44,8 @@ class CodeGenerationTests : DataFrameJupyterTest() {

@Test
fun `nullable dataRow`() {
"""
@Language("kts")
val a = """
fun create(): AnyRow? = dataFrameOf("a")(1).single()
val row = create()
row.a
Expand Down
Loading