
Commit

Add ability to load into spark dataframe (#4)
* add ability to load into spark dataframe

* relax spark constraint

* relax spark constraint even further

* use _ imports

* optional spark context
tomcarter23 authored Dec 21, 2023
1 parent 957251b commit 55feb67
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ dynamic = ["version"]

dependencies = [
"pandas >= 1.2",
"pyspark>=0.7.0",
]

[project.optional-dependencies]
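A note on the new dependency: `pyspark>=0.7.0` is a deliberately loose floor (see the "relax spark constraint" messages above), so virtually any installable pyspark release satisfies it. It does land in the core dependencies, so pandas-only users install Spark too. A hedged alternative sketch, not what this commit does, would be to keep pyspark behind `[project.optional-dependencies]` and guard the import:

```python
# Hypothetical alternative, NOT part of this commit: keep pyspark optional
# and fail with a clear message only when Spark functionality is used.
try:
    import pyspark.sql as _ps
except ImportError:  # pyspark not installed
    _ps = None


def _require_spark() -> None:
    """Raise a helpful error if pyspark is missing (hypothetical helper)."""
    if _ps is None:
        raise ImportError(
            "pyspark is required for load_spark(); install it with "
            "`pip install synthesized-datasets[spark]` (hypothetical extra)"
        )
```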
16 changes: 16 additions & 0 deletions src/synthesized_datasets/_datasets.py
@@ -1,8 +1,12 @@
import sys as _sys
import typing as _typing
from enum import Enum as _Enum
import os as _os

import pandas as _pd
import pyspark.sql as _ps
from pyspark import SparkFiles as _SparkFiles


_ROOT_URL = "https://raw.githubusercontent.com/synthesized-io/datasets/master/"

@@ -51,6 +55,18 @@ def load(self) -> _pd.DataFrame:
df.attrs["name"] = self.name
return df

def load_spark(self, spark: _typing.Optional[_ps.SparkSession] = None) -> _ps.DataFrame:
"""Loads the dataset as a Spark DataFrame."""

if spark is None:
spark = _ps.SparkSession.builder.getOrCreate()

spark.sparkContext.addFile(self.url)
_, ext = _os.path.splitext(self.url)
df = spark.read.csv(_SparkFiles.get("".join([self.name, ext])), header=True, inferSchema=True)
df.name = self.name
return df

def __repr__(self):
return f"<Dataset: {self.url}>"

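How the new method works: `sparkContext.addFile` downloads the dataset URL to every node in the cluster, and `SparkFiles.get` resolves the local copy by file name. The method rebuilds that file name from `self.name` plus the URL's extension, so it implicitly assumes the dataset name matches the basename of the URL. A minimal usage sketch follows; the `CREDIT` registry attribute is hypothetical, since the diff does not show how datasets are exposed:

```python
# Minimal usage sketch. `CREDIT` is a hypothetical dataset attribute; the
# commit does not show how datasets are registered or exposed.
import synthesized_datasets as datasets
from pyspark.sql import SparkSession

ds = datasets.CREDIT  # hypothetical

# With no argument, load_spark() creates or reuses a default session:
df = ds.load_spark()
df.printSchema()

# Or pass an explicitly configured session (the "optional spark context"
# change in this commit):
spark = SparkSession.builder.appName("synthesized-datasets").getOrCreate()
df = ds.load_spark(spark=spark)
df.show(5)
```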
