Skip to content

Commit

Permalink
fix issue with data loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Chloe He committed Jun 27, 2024
1 parent c318efe commit 6bc9364
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions ibis/backends/pyspark/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

import os
from datetime import datetime, timedelta, timezone
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
import pytest
from filelock import FileLock

import ibis
from ibis import util
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest
from ibis.backends.tests.data import json_types, topk, win

if TYPE_CHECKING:
from pathlib import Path


def set_pyspark_database(con, database):
con._session.catalog.setCurrentDatabase(database)
Expand Down Expand Up @@ -194,6 +198,31 @@ def _load_data(self, **_: Any) -> None:
t = t.sort(sort_col)
t.createOrReplaceTempView(name)

@classmethod
def load_data(
cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any
) -> BackendTest:
"""Load testdata from `data_dir`."""
# handling for multi-processes pytest

# get the temp directory shared by all workers
root_tmp_dir = tmpdir.getbasetemp() / "streaming"
if worker_id != "master":
root_tmp_dir = root_tmp_dir.parent

fn = root_tmp_dir / cls.name()
with FileLock(f"{fn}.lock"):
cls.skip_if_missing_deps()

inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw)

if inst.stateful:
inst.stateful_load(fn, **kw)
else:
inst.stateless_load(**kw)
inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw)
return inst

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
from pyspark.sql import SparkSession
Expand Down Expand Up @@ -324,8 +353,8 @@ def con(data_dir, tmp_path_factory, worker_id):

@pytest.fixture(scope="session")
def con_streaming(data_dir, tmp_path_factory, worker_id):
pytest.set_trace()
backend_test = TestConfForStreaming.load_data(data_dir, tmp_path_factory, worker_id)
backend_test._load_data()
return backend_test.connection


Expand Down

0 comments on commit 6bc9364

Please sign in to comment.