How to run a function on all Spark workers before processing data in PySpark?
Solution 1
If all you want is to distribute a file between worker machines, the simplest approach is to use the SparkFiles mechanism:
some_path = ... # local file, a file in DFS, an HTTP, HTTPS or FTP URI.
sc.addFile(some_path)
and retrieve it on the workers using SparkFiles.get and standard IO tools:
import os
from pyspark import SparkFiles
# SparkFiles.get expects the file name, not the original path or URI.
with open(SparkFiles.get(os.path.basename(some_path))) as fw:
    ... # Do something
If you want to make sure that the model is actually loaded, the simplest approach is to load it on module import. Assuming config can be used to retrieve the model path:
- model.py:
import os
from pyspark import SparkFiles

config = ...

class MyClassifier:
    clf = None

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        # SparkFiles.get expects the file name, not the original path.
        path = SparkFiles.get(os.path.basename(config.get("model_file")))
        MyClassifier.clf = load_from_file(path)

# Executed once per interpreter
MyClassifier.load_models(config)
- main.py:
from pyspark import SparkContext

config = ...

sc = SparkContext("local", "foo")

# Executed before StreamingContext starts
sc.addFile(config.get("model_file"))
sc.addPyFile("model.py")

import model

ssc = ...
stream = ...

stream.map(model.MyClassifier.do_something).pprint()

ssc.start()
ssc.awaitTermination()
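To see why loading behind a class-level guard (or on import) pays the loading cost only once per Python interpreter, here is a minimal sketch that runs without Spark. `fake_load_from_file`, `LOAD_CALLS`, `predict` and the file name `model.bin` are hypothetical stand-ins introduced for illustration; they are not part of the answer above.

```python
# Hypothetical stand-in for the real loader; counts how often it is called.
LOAD_CALLS = 0


def fake_load_from_file(path):
    global LOAD_CALLS
    LOAD_CALLS += 1
    return {"path": path}  # pretend this is an expensive model object


class MyClassifier:
    clf = None  # class attribute: shared by every call in this interpreter

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(path):
        MyClassifier.clf = fake_load_from_file(path)

    @staticmethod
    def predict(record):
        # Lazy guard: only the first call in an interpreter triggers loading.
        if not MyClassifier.is_loaded():
            MyClassifier.load_models("model.bin")
        return (record, MyClassifier.clf["path"])


# Three "records" processed in the same interpreter load the model once.
results = [MyClassifier.predict(r) for r in ["a", "b", "c"]]
```

On a cluster each Python worker process has its own copy of the class attribute, so the load would happen once per worker process rather than once per cluster, and only for as long as that process is reused.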
Solution 2
This is a typical use case for Spark's broadcast variables. Let's say fetch_models returns the models rather than saving them locally; you would then do something like:
bc_models = sc.broadcast(fetch_models())
spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
stream.union(*create_kafka_streams())\
.repartition(spark_partitions)\
.foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition, bc_models.value)))
This does assume that your models fit in memory, on the driver and the executors.
You may be worried that broadcasting the models from the single driver to all the executors is inefficient, but it uses 'efficient broadcast algorithms' that, according to this analysis, can significantly outperform distributing through HDFS.
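As a rough, self-contained illustration of the broadcast approach, the sketch below uses a plain batch RDD in local mode instead of a Kafka stream. The "model" is a toy dictionary of spam words, and `classify_partition`, `run_demo` and the sample messages are assumptions made up for this example.

```python
def classify_partition(records, models):
    """Label each record using an already-loaded model (here a plain dict)."""
    spam_words = models["spam_words"]
    for text in records:
        is_spam = any(word in spam_words for word in text.lower().split())
        yield (text, "spam" if is_spam else "ham")


def run_demo():
    # Imported here so classify_partition can be used without Spark installed.
    from pyspark import SparkContext

    sc = SparkContext("local[2]", "broadcast-demo")
    try:
        # Built once on the driver, then made available to the executors.
        bc_models = sc.broadcast({"spam_words": {"lottery", "prize"}})
        rdd = sc.parallelize(["win the lottery", "lunch at noon"], 2)
        # bc_models.value is read inside the task, on the executor side.
        return rdd.mapPartitions(
            lambda part: classify_partition(part, bc_models.value)
        ).collect()
    finally:
        sc.stop()
```

Calling `run_demo()` requires a working PySpark installation; `classify_partition` itself is an ordinary generator and can be tried on a plain list.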
Oscar
Updated on September 15, 2022
Comments
-
Oscar about 1 year
I'm running a Spark Streaming task in a cluster using YARN. Each node in the cluster runs multiple spark workers. Before the streaming starts I want to execute a "setup" function on all workers on all nodes in the cluster.
The streaming task classifies incoming messages as spam or not spam, but before it can do that it needs to download the latest pre-trained models from HDFS to local disk, like this pseudo code example:
def fetch_models():
    if hadoop.version > local.version:
        hadoop.download()
I've seen the following examples here on SO:
sc.parallelize().map(fetch_models)
But in Spark 1.6 parallelize() requires some data to be used, like this shitty work-around I'm doing now:
sc.parallelize(range(1, 1000)).map(fetch_models)
Just to be fairly sure that the function is run on ALL workers I set the range to 1000. I also don't exactly know how many workers are in the cluster when running.
I've read the programming documentation and googled relentlessly but I can't seem to find any way to actually just distribute anything to all workers without any data.
After this initialization phase is done, the streaming task is as usual, operating on incoming data from Kafka.
The way I'm using the models is by running a function similar to this:
spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
stream.union(*create_kafka_streams())\
    .repartition(spark_partitions)\
    .foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition)))
Theoretically I could check whether or not the models are up to date in the on_partition function, though it would be really wasteful to do this on each batch. I'd like to do it before Spark starts retrieving batches from Kafka, since the downloading from HDFS can take a couple of minutes...
UPDATE:
To be clear: it's not an issue on how to distribute the files or how to load them, it's about how to run an arbitrary method on all workers without operating on any data.
To clarify what actually loading models means currently:
def on_partition(config, partition):
    if not MyClassifier.is_loaded():
        MyClassifier.load_models(config)
    handle_partition(config, partition)
While MyClassifier is something like this:
class MyClassifier:
    clf = None

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        MyClassifier.clf = load_from_file(config)
Static methods since PySpark doesn't seem to be able to serialize classes with non-static methods (the state of the class is irrelevant with relation to another worker). Here we only have to call load_models() once, and on all future batches MyClassifier.clf will be set. This is something that should really not be done for each batch, it's a one time thing. Same with downloading the files from HDFS using fetch_models().
-
Oscar over 7 years
Thanks for the input. One question though, since this is a streaming task and not a batch job, wouldn't Spark distribute these broadcast models on every batch? If the batch interval is low (say 5 seconds), that would be quite the bottleneck.
-
Oscar over 7 years
Or are the broadcast models created using the SparkContext and not the StreamingContext? So the models are broadcast once before streaming starts and then just reused on each batch?
-
sgvd over 7 years
Exactly, broadcast is a method on SparkContext, and the broadcast is done just once before you start your stream. Note though what the documentation says about deserializing the data for each task: each batch creates (a) new task(s), implying that deserialization would happen for each batch. How bad this is depends on your case, but it surely beats reading a file from disk at every batch.
-
Oscar over 7 years
I didn't know about SparkFiles, I wouldn't have to deal with HDFS directly, thanks. Though the main issue is still there: how to "retrieve it on the workers". I.e. how do I execute this on all workers before I start streaming? I don't want to have to do this on every batch, just once, since the models could be huge. Maybe easier to just do this as a separate step in the deployment phase instead of trying to use Spark for it.
-
zero323 over 7 years
Files are already there. SparkFiles.get(some_path) only returns a local path where the file resides. There are some subtleties here. As far as I know Streaming doesn't reuse Python workers between batches, so it has to be read on each batch. On the other hand, using files gives some interesting options. If you can memory map your files then you can significantly reduce memory requirements.
-
zero323 over 7 years
@sgvd Correct me if I am wrong, but last time I checked Python broadcasts were passed to workers using disk. Am I missing something here?
-
Oscar over 7 years
@zero323 If you have HDFS it will distribute using that I think, so not directly "using disk".
-
Oscar over 7 years
@sgvd What do you mean "beats reading a file from disk at every batch"? Reading the models from local disk? It's only done once, before the batching starts. The loaded models are stored as class variables, so on every future batch to process, no loading has to be done since they're already loaded. Currently I have to check on every batch "if they have not been loaded yet, block stream and load from local disk, otherwise continue"; instead I want to be able to assume it's been loaded for each batch, since it's supposed to be done before the streaming even starts.
-
sgvd over 7 years
@zero323: they are initially indeed passed using disk, but unless I am missing something, the value property caches the value on first read after creation of the Broadcast object on the worker, which happens either at Python worker startup (which shouldn't happen at each task with Python worker reuse enabled by default) or when the variable is broadcast when the worker is already running. After that the Broadcast is retrieved from the registry at every use, with the cached value data after the first use. If you think this is wrong I'd be happy to open a new question to discuss this.
-
zero323 over 7 years
Scratch that. Discarding the interpreter in streaming looks like a Python-version-specific bug :/
-
retrocookie over 7 years
It frustrates me that import solves this issue. It feels especially opaque that import statements are the only way to run setup work on executors. Is there any reason for this? Can we somehow use the underlying mechanism that module imports use directly?
-
zero323 over 7 years
@retrocookie You can use the Borg pattern if it makes you feel better, but at the end of the day it is not a huge difference. And why this way? Because you should never depend on the state of the executor, especially in Python.
-
Ansari about 6 years
This is a great answer. I thought I'd hit a dead end when I couldn't broadcast my models but this is a really neat workaround.
-
hangkongwang over 4 years
@zero323 I have used your approach, but in the worker's log I still see load_models being called every second, since my stream interval is 1 sec. I added a print in load_models.
-
hangkongwang over 4 years
@zero323 By the way, I use it in Spark Streaming.
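
The memory-mapping idea raised in the comments can be sketched with Python's standard `mmap` module. This is only an illustration under assumptions: `read_slice` is a hypothetical helper, and the temporary file stands in for a model file whose path would come from SparkFiles.get on a real worker.

```python
import mmap
import os
import tempfile


def read_slice(path, start, length):
    """Read `length` bytes at offset `start` through a read-only memory map."""
    with open(path, "rb") as f:
        # A length of 0 maps the whole file; the OS loads pages on demand.
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            return mm[start:start + length]


# Tiny demonstration with a throwaway file standing in for a model file.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"header:weights-go-here")
    demo_path = tmp.name

header = read_slice(demo_path, 0, 6)
os.remove(demo_path)
```

Because the mapping is read-only and backed by the file, several worker processes on the same machine can typically share the cached pages instead of each holding a private copy, which is where the memory saving would come from.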