Skip to content

Commit

Permalink
Add APIs to get the strongly typed columns from a DataFrame (dotnet#2878
Browse files Browse the repository at this point in the history
)

* CP

* sq

* sq

* Improve docs
  • Loading branch information
Prashanth Govindarajan authored Mar 20, 2020
1 parent 8d7fb66 commit 7ef10ba
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 0 deletions.
273 changes: 273 additions & 0 deletions src/Microsoft.Data.Analysis/DataFrameColumnCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,5 +175,278 @@ public DataFrameColumn this[string columnName]
}
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="PrimitiveDataFrameColumn{T}"/>.
/// </summary>
/// <typeparam name="T">The unmanaged element type the caller expects the column to hold.</typeparam>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="PrimitiveDataFrameColumn{T}"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="PrimitiveDataFrameColumn{T}"/>.</exception>
public PrimitiveDataFrameColumn<T> GetPrimitiveColumn<T>(string name)
    where T : unmanaged
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is PrimitiveDataFrameColumn<T> typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(T));
        throw new ArgumentException(message, nameof(T));
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as an <see cref="ArrowStringDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as an <see cref="ArrowStringDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not an <see cref="ArrowStringDataFrameColumn"/>.</exception>
public ArrowStringDataFrameColumn GetArrowStringColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is ArrowStringDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(string));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="StringDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="StringDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="StringDataFrameColumn"/>.</exception>
public StringDataFrameColumn GetStringColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is StringDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(string));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="BooleanDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="BooleanDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="BooleanDataFrameColumn"/>.</exception>
public BooleanDataFrameColumn GetBooleanColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is BooleanDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(bool));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="ByteDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="ByteDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="ByteDataFrameColumn"/>.</exception>
public ByteDataFrameColumn GetByteColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is ByteDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(byte));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="CharDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="CharDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="CharDataFrameColumn"/>.</exception>
public CharDataFrameColumn GetCharColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is CharDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(char));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="DoubleDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="DoubleDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="DoubleDataFrameColumn"/>.</exception>
public DoubleDataFrameColumn GetDoubleColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is DoubleDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(double));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="DecimalDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="DecimalDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="DecimalDataFrameColumn"/>.</exception>
public DecimalDataFrameColumn GetDecimalColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is DecimalDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(decimal));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="SingleDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="SingleDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="SingleDataFrameColumn"/>.</exception>
public SingleDataFrameColumn GetSingleColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is SingleDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(float));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as an <see cref="Int32DataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as an <see cref="Int32DataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not an <see cref="Int32DataFrameColumn"/>.</exception>
public Int32DataFrameColumn GetInt32Column(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is Int32DataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(int));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as an <see cref="Int64DataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as an <see cref="Int64DataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not an <see cref="Int64DataFrameColumn"/>.</exception>
public Int64DataFrameColumn GetInt64Column(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is Int64DataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(long));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as an <see cref="SByteDataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as an <see cref="SByteDataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not an <see cref="SByteDataFrameColumn"/>.</exception>
public SByteDataFrameColumn GetSByteColumn(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is SByteDataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(sbyte));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as an <see cref="Int16DataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as an <see cref="Int16DataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not an <see cref="Int16DataFrameColumn"/>.</exception>
public Int16DataFrameColumn GetInt16Column(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is Int16DataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(short));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Gets the <see cref="UInt32DataFrameColumn"/> with the specified <paramref name="name"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="UInt32DataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="UInt32DataFrameColumn"/>.</exception>
public UInt32DataFrameColumn GetUInt32Column(string name)
{
    DataFrameColumn column = this[name];
    if (column is UInt32DataFrameColumn ret)
    {
        return ret;
    }

    // Name UInt32 as the requested type so the message matches what the caller asked for,
    // in line with the sibling typed accessors.
    throw new ArgumentException(string.Format(Strings.BadColumnCast, column.DataType, typeof(UInt32)));
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="UInt64DataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="UInt64DataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="UInt64DataFrameColumn"/>.</exception>
public UInt64DataFrameColumn GetUInt64Column(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is UInt64DataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(ulong));
        throw new ArgumentException(message);
    }

    return typed;
}

/// <summary>
/// Looks up the column called <paramref name="name"/> and returns it typed as a <see cref="UInt16DataFrameColumn"/>.
/// </summary>
/// <param name="name">The name of the column to retrieve.</param>
/// <returns>The matching column as a <see cref="UInt16DataFrameColumn"/>.</returns>
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or the column found is not a <see cref="UInt16DataFrameColumn"/>.</exception>
public UInt16DataFrameColumn GetUInt16Column(string name)
{
    DataFrameColumn candidate = this[name];
    if (!(candidate is UInt16DataFrameColumn typed))
    {
        // The message reports the column's actual element type next to the requested one.
        string message = string.Format(Strings.BadColumnCast, candidate.DataType, typeof(ushort));
        throw new ArgumentException(message);
    }

    return typed;
}

}
}
9 changes: 9 additions & 0 deletions src/Microsoft.Data.Analysis/strings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/Microsoft.Data.Analysis/strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@
<resheader name="writer">
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
</resheader>
<data name="BadColumnCast" xml:space="preserve">
<value>Cannot cast column holding {0} values to type {1}</value>
</data>
<data name="CannotResizeDown" xml:space="preserve">
<value>Cannot resize down</value>
</data>
Expand Down
70 changes: 70 additions & 0 deletions tests/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2179,5 +2179,75 @@ public void TestBinaryOperationsOnExplodedNumericColumns()
Assert.True(reverseInPlace.ElementwiseEquals(ints).All());
Assert.False(reverseInPlace.ElementwiseEquals(reverse).All());
}

// Exercises every strongly typed column accessor twice: once against a column of the
// matching type (expects a non-null result) and once against a column of a different
// type (expects an ArgumentException).
[Fact]
public void GetColumnTests()
{
    DataFrame frame = MakeDataFrameWithAllColumnTypes(10);

    // Generic accessor.
    PrimitiveDataFrameColumn<int> genericIntColumn = frame.Columns.GetPrimitiveColumn<int>("Int");
    Assert.NotNull(genericIntColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetPrimitiveColumn<float>("Int"));

    // String-backed columns: the two string column kinds are not interchangeable.
    StringDataFrameColumn stringColumn = frame.Columns.GetStringColumn("String");
    Assert.NotNull(stringColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetStringColumn("ArrowString"));

    ArrowStringDataFrameColumn arrowStringColumn = frame.Columns.GetArrowStringColumn("ArrowString");
    Assert.NotNull(arrowStringColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetArrowStringColumn("String"));

    // Primitive accessors; the mismatch case asks for a Single column, except for the
    // Single column itself, which is requested as a Double column.
    ByteDataFrameColumn byteColumn = frame.Columns.GetByteColumn("Byte");
    Assert.NotNull(byteColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Byte"));

    Int32DataFrameColumn intColumn = frame.Columns.GetInt32Column("Int");
    Assert.NotNull(intColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Int"));

    BooleanDataFrameColumn boolColumn = frame.Columns.GetBooleanColumn("Bool");
    Assert.NotNull(boolColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Bool"));

    CharDataFrameColumn charColumn = frame.Columns.GetCharColumn("Char");
    Assert.NotNull(charColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Char"));

    DecimalDataFrameColumn decimalColumn = frame.Columns.GetDecimalColumn("Decimal");
    Assert.NotNull(decimalColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Decimal"));

    DoubleDataFrameColumn doubleColumn = frame.Columns.GetDoubleColumn("Double");
    Assert.NotNull(doubleColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Double"));

    SingleDataFrameColumn floatColumn = frame.Columns.GetSingleColumn("Float");
    Assert.NotNull(floatColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetDoubleColumn("Float"));

    Int64DataFrameColumn longColumn = frame.Columns.GetInt64Column("Long");
    Assert.NotNull(longColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Long"));

    SByteDataFrameColumn sbyteColumn = frame.Columns.GetSByteColumn("Sbyte");
    Assert.NotNull(sbyteColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Sbyte"));

    Int16DataFrameColumn shortColumn = frame.Columns.GetInt16Column("Short");
    Assert.NotNull(shortColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Short"));

    UInt32DataFrameColumn uintColumn = frame.Columns.GetUInt32Column("Uint");
    Assert.NotNull(uintColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Uint"));

    UInt64DataFrameColumn ulongColumn = frame.Columns.GetUInt64Column("Ulong");
    Assert.NotNull(ulongColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Ulong"));

    UInt16DataFrameColumn ushortColumn = frame.Columns.GetUInt16Column("Ushort");
    Assert.NotNull(ushortColumn);
    Assert.Throws<ArgumentException>(() => frame.Columns.GetSingleColumn("Ushort"));
}
}
}

0 comments on commit 7ef10ba

Please sign in to comment.