How to run a function on all Spark workers before processing data in PySpark?


Solution 1

If all you want is to distribute a file among the worker machines, the simplest approach is to use the SparkFiles mechanism:

some_path = ...  # local file, a file in DFS, an HTTP, HTTPS or FTP URI.
sc.addFile(some_path)

and retrieve it on the workers using SparkFiles.get and standard IO tools:

import os

from pyspark import SparkFiles

# SparkFiles.get expects the file name, not the original path or URI
with open(SparkFiles.get(os.path.basename(some_path))) as f:
    ... # Do something
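
In practice the retrieval happens inside a task running on an executor. A minimal sketch, assuming an existing SparkContext sc; the file name "lookup.txt", the helper use_distributed_file and the input path are illustrative and not part of the original answer:

from pyspark import SparkFiles

def use_distributed_file(lines):
    # Runs on an executor: SparkFiles.get resolves the local copy of the
    # file that sc.addFile() shipped to this machine.
    with open(SparkFiles.get("lookup.txt")) as f:
        lookup = set(line.strip() for line in f)
    return (line for line in lines if line.strip() in lookup)

sc.addFile("/path/to/lookup.txt")  # hypothetical path; the distributed name is "lookup.txt"
filtered = sc.textFile("hdfs:///some/input").mapPartitions(use_distributed_file)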

If you want to make sure that the model is actually loaded, the simplest approach is to load it on module import. Assuming config can be used to retrieve the model path:

  • model.py:

    from pyspark import SparkFiles
    
    config = ...
    class MyClassifier:
        clf = None
    
        @staticmethod
        def is_loaded():
            return MyClassifier.clf is not None
    
        @staticmethod
        def load_models(config):
            path = SparkFiles.get(config.get("model_file"))
            MyClassifier.clf = load_from_file(path)
    
    # Executed once per interpreter 
    MyClassifier.load_models(config)  
    
  • main.py:

    from pyspark import SparkContext
    
    config = ...
    
    sc = SparkContext("local", "foo")
    
    # Executed before StreamingContext starts
    sc.addFile(config.get("model_file"))
    sc.addPyFile("model.py")
    
    import model
    
    ssc = ...
    stream = ...
    stream.map(model.MyClassifier.do_something).pprint()
    
    ssc.start()
    ssc.awaitTermination()
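
The stream above maps model.MyClassifier.do_something over incoming messages, but that method is not defined in the answer. A purely illustrative sketch of what it could look like, assuming clf follows the scikit-learn predict() convention and extract_features is an application-specific helper:

# Hypothetical addition to model.py -- not part of the original answer
class MyClassifier:
    ...  # clf, is_loaded() and load_models() as shown above

    @staticmethod
    def do_something(message):
        # extract_features and the predict() interface are assumptions
        features = extract_features(message)
        return message, MyClassifier.clf.predict([features])[0]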
    

Solution 2

This is a typical use case for Spark's broadcast variables. Assuming fetch_models returns the models rather than saving them locally, you would do something like:

bc_models = sc.broadcast(fetch_models())

spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
stream.union(*create_kafka_streams())\
    .repartition(spark_partitions)\
    .foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition, bc_models.value)))

This does assume that your models fit in memory, on the driver and the executors.
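
Here spam.on_partition is application code referenced from the question; a purely illustrative sketch of how it might consume the broadcast value (the model interface, extract_features and handle_message are assumptions):

# Illustrative sketch only -- the real spam.on_partition is application code.
def on_partition(config, partition, models):
    # `models` is bc_models.value, already available on this executor,
    # so nothing is fetched from HDFS or local disk per batch.
    for message in partition:
        is_spam = models.predict([extract_features(message)])[0]
        handle_message(config, message, is_spam)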

You may be worried that broadcasting the models from the single driver to all the executors is inefficient, but broadcast uses 'efficient broadcast algorithms' that, according to this analysis, can significantly outperform distribution through HDFS.


Comments

  • Oscar
    Oscar about 1 year

    I'm running a Spark Streaming task in a cluster using YARN. Each node in the cluster runs multiple Spark workers. Before the streaming starts I want to execute a "setup" function on all workers on all nodes in the cluster.

    The streaming task classifies incoming messages as spam or not spam, but before it can do that it needs to download the latest pre-trained models from HDFS to local disk, like this pseudo code example:

    def fetch_models():
        if hadoop.version > local.version:
            hadoop.download()
    

    I've seen the following examples here on SO:

    sc.parallelize().map(fetch_models)
    

    But in Spark 1.6 parallelize() requires some data to be used, like this shitty work-around I'm doing now:

    sc.parallelize(range(1, 1000)).map(fetch_models)
    

    Just to be fairly sure that the function is run on ALL workers, I set the range to 1000. I also don't know exactly how many workers are in the cluster when running.

    I've read the programming documentation and googled relentlessly, but I can't seem to find any way to actually just distribute anything to all workers without any data.

    After this initialization phase is done, the streaming task proceeds as usual, operating on incoming data from Kafka.

    The way I'm using the models is by running a function similar to this:

    spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
    stream.union(*create_kafka_streams())\
        .repartition(spark_partitions)\
        .foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition)))
    

    Theoretically I could check whether or not the models are up to date in the on_partition function, though it would be really wasteful to do this on each batch. I'd like to do it before Spark starts retrieving batches from Kafka, since the downloading from HDFS can take a couple of minutes...

    UPDATE:

    To be clear: this is not about how to distribute the files or how to load them; it's about how to run an arbitrary method on all workers without operating on any data.

    To clarify what loading the models actually means currently:

    def on_partition(config, partition):
        if not MyClassifier.is_loaded():
            MyClassifier.load_models(config)
    
        handle_partition(config, partition)
    

    While MyClassifier is something like this:

    class MyClassifier:
        clf = None
    
        @staticmethod
        def is_loaded():
            return MyClassifier.clf is not None
    
        @staticmethod
        def load_models(config):
            MyClassifier.clf = load_from_file(config)
    

    The methods are static since PySpark doesn't seem to be able to serialize classes with non-static methods (the state of the class is irrelevant to any other worker). Here we only have to call load_models() once, and on all future batches MyClassifier.clf will be set. This is something that should really not be done for each batch; it's a one-time thing. The same goes for downloading the files from HDFS using fetch_models().

  • Oscar
    Oscar over 7 years
    Thanks for the input. One question though: since this is a streaming task and not a batch job, wouldn't Spark distribute these broadcast models on every batch? If the batch interval is low (say 5 seconds), that would be quite a bottleneck.
  • Oscar
    Oscar over 7 years
    Or are the broadcast models created using the SparkContext and not the StreamingContext? So the models are broadcast once before streaming starts and then just reused on each batch?
  • sgvd
    sgvd over 7 years
    Exactly, broadcast is a method on SparkContext, and the broadcast is done just once before you start your stream. Note, though, what the documentation says about deserializing the data for each task: each batch creates new task(s), implying that deserialization would happen for each batch. How bad this is depends on your case, but it at least surely beats reading a file from disk at every batch.
  • Oscar
    Oscar over 7 years
    I didn't know about SparkFiles; I wouldn't have to deal with HDFS directly, thanks. Though the main issue is still there: how to "retrieve it on the workers", i.e. how do I execute this on all workers before I start streaming? I don't want to have to do this on every batch, just once, since the models could be huge. Maybe it's easier to do this as a separate step in the deployment phase instead of trying to use Spark for it.
  • zero323
    zero323 over 7 years
    Files are already there. SparkFiles.get(some_path) only returns a local path where the file resides. There are some subtleties here. As far as I know, Streaming doesn't reuse Python workers between batches, so the file has to be read on each batch. On the other hand, using files gives some interesting options. If you can memory-map your files, you can significantly reduce memory requirements.
  • zero323
    zero323 over 7 years
    @sgvd Correct me if I am wrong, but last time I checked, Python broadcasts were passed to workers using disk. Am I missing something here?
  • Oscar
    Oscar over 7 years
    @zero323 If you have HDFS it will distribute using that I think, so not directly "using disk".
  • Oscar
    Oscar over 7 years
    @sgvd What do you mean "beats reading a file from disk at every batch"? Reading the models from local disk? It's only done once, before the batching starts. The loaded models are stored as class variables, so on every future batch to process, no loading has to be done since they're already loaded. Currently I have to check on every batch "if they have not been loaded yet, block stream and load from local disk, otherwise continue"; instead I want to be able to assume it's been loaded for each batch, since it's supposed to be done before the streaming even starts.
  • sgvd
    sgvd over 7 years
    @zero323: they are initially indeed passed using disk, but unless I am missing something, the value property caches the value on first read after creation of the Broadcast object on the worker, which happens either at Python worker startup (which shouldn't happen at each task with Python worker reuse enabled by default) or when the variable is broadcast when the worker is already running. After that the Broadcast is retrieved from the registry at every use, with the cached value data after the first use. If you think this is wrong I'd be happy to open a new question to discuss this.
  • zero323
    zero323 over 7 years
    Scratch that. Discarding the interpreter in streaming looks like a Python-version-specific bug :/
  • retrocookie
    retrocookie over 7 years
    It frustrates me that import solves this issue. It feels especially opaque that import statements are the only way to run setup work on executors. Is there any reason for this? Can we somehow use the underlying mechanism that module imports use directly?
  • zero323
    zero323 over 7 years
    @retrocookie You can use the Borg pattern if it makes you feel better, but at the end of the day it is not a huge difference (a minimal sketch of that pattern follows these comments). And why this way? Because you should never depend on the state of the executor, especially in Python.
  • Ansari
    Ansari about 6 years
    This is a great answer. I thought I'd hit a dead end when I couldn't broadcast my models but this is a really neat workaround.
  • hangkongwang
    hangkongwang over 4 years
    @zero323 I have used your approach, but in the worker's log I still see load_models being called every second, since my stream interval is 1 second. I added a print in load_models.
  • hangkongwang
    hangkongwang over 4 years
    @zero323 By the way, I use it in Spark Streaming.
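
For reference, a minimal sketch of the Borg pattern mentioned in the comments above; it is purely illustrative (load_from_file is the same placeholder used in Solution 1). All instances share one state dict, so the model is loaded at most once per Python interpreter on an executor:

class ModelBorg:
    # Every instance shares this dict, so state set by one instance
    # (e.g. the loaded classifier) is visible to all later instances
    # created in the same interpreter.
    _shared_state = {}

    def __init__(self):
        self.__dict__ = ModelBorg._shared_state

    def get_model(self, path):
        if "clf" not in self.__dict__:
            self.clf = load_from_file(path)  # placeholder loader, as in Solution 1
        return self.clf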