Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion python/pyspark/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from py4j.protocol import Py4JJavaError

from pyspark import SparkConf, SparkContext
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, eventually
from pyspark.testing.utils import (
ReusedPySparkTestCase,
PySparkTestCase,
QuietTest,
eventually,
)


class WorkerTests(ReusedPySparkTestCase):
Expand Down Expand Up @@ -272,6 +277,65 @@ def test_worker_crash(self):
rdd.map(lambda x: os.getpid()).collect()


class SimpleWorkerTests(WorkerTests):
    """Run worker tests through the non-daemon (simple-worker) path.

    Windows always uses this path; Linux/macOS use it when
    spark.python.use.daemon=false.
    """

    @classmethod
    def conf(cls):
        """Return the parent conf with the Python worker daemon disabled.

        Disabling the daemon forces each executor to launch workers via the
        simple-worker code path, which is what this suite exercises.
        """
        _conf = super(SimpleWorkerTests, cls).conf()
        _conf.set("spark.python.use.daemon", "false")
        return _conf

    def test_create_dataframe(self):
        """DataFrame creation through the simple-worker path."""
        from pyspark.sql import SparkSession

        spark = SparkSession(self.sc)
        df = spark.createDataFrame([("Alice", 30), ("Bob", 25)], ["name", "age"])
        self.assertEqual(df.count(), 2)
        # Also verify contents: a worker that silently corrupts data
        # (rather than crashing) would still pass a count-only check.
        self.assertEqual(
            sorted((r.name, r.age) for r in df.collect()),
            [("Alice", 30), ("Bob", 25)],
        )

    def test_udf(self):
        """UDF execution through the simple-worker path."""
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import udf
        from pyspark.sql.types import StringType

        spark = SparkSession(self.sc)
        str_udf = udf(lambda x: f"val_{x}", StringType())
        rows = spark.range(5).withColumn("x", str_udf("id")).collect()
        self.assertEqual(len(rows), 5)
        # Check that the UDF actually computed the expected values, not
        # just that the right number of rows came back.
        self.assertEqual(
            sorted((r.id, r.x) for r in rows),
            [(i, f"val_{i}") for i in range(5)],
        )

    def test_datasource_read(self):
        """Python Data Source read through the simple-worker path."""
        from pyspark.sql import SparkSession
        from pyspark.sql.datasource import DataSource, DataSourceReader

        class TestReader(DataSourceReader):
            def read(self, partition):
                # Yield two fixed rows matching the declared schema.
                yield (0, "a")
                yield (1, "b")

        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return "test_simple_worker"

            def schema(self):
                return "id INT, value STRING"

            def reader(self, schema):
                return TestReader()

        spark = SparkSession(self.sc)
        spark.dataSource.register(TestDataSource)
        df = spark.read.format("test_simple_worker").load()
        self.assertEqual(df.count(), 2)
        # Verify the rows produced by the Python data source reader.
        self.assertEqual(
            sorted((r.id, r.value) for r in df.collect()),
            [(0, "a"), (1, "b")],
        )


if __name__ == "__main__":
from pyspark.testing import main

Expand Down