PySpark - how local file reads & writes can help performance

A quick guide to solving heap space and garbage collection issues in your dataframe

Picture this...

Scene/

You are writing a long, winding series of Spark transformations on a dataset. It’s all going well, but after about ten joins and hundreds of lines of code, you notice that processing your dataset goes from taking a few minutes to a few hours, eventually ending in garbage collection, heap space, or other memory errors. You start reading up and trying different Spark calls: adjusting how the data is shuffled, rewriting code, reallocating local memory, all to no avail.

/Scene

If you’re working in PySpark (or Spark in general), Spark should be doing a lot of optimization behind the scenes. However, Spark’s optimizer can struggle when a query piles up many joins across different datasets or other expensive computations.

If Spark is unable to optimize your work, you might run into garbage collection or heap space issues. If you’ve already tried calling repartition, coalesce, persist, and cache, and none of them worked, it may be time to have Spark write the dataframe to a local file and read it back. Because Spark is lazily evaluated, every transformation you chain is added to one ever-growing plan that Spark has to keep track of, analyze, and optimize. Writing the dataframe to a file and reading it back replaces that accumulated history with a simple file read, which can help Spark clear the backlog.

However, as a warning, if you write out an intermediate dataframe to a file, you can’t keep reusing the same path. The problem arises when you read from a path and then try to overwrite that same path. Spark reads only a small portion of your data into memory at a time, leaving the bulk of it on disk until it’s needed, so a dataframe read from a path still depends on the files stored there. Overwriting the path removes those files before Spark has finished using them, which conflicts with Spark’s usual order of operations.

Backstory on how I first encountered this problem

I am a software engineer at Capital One, occasionally working with Python and Spark. I recently had an issue with Spark transformations that appeared logically correct, but resulted in memory and performance issues. There were a lot of transformations, but they were all as simple as I thought they could be. If it wasn’t a heap space issue, it was garbage collection, either in the transformation itself or in the tests. Even running commands one by one in a Jupyter Notebook, it was difficult to find where performance suffered. Expanding the memory available to my notebook, or to PyCharm when running tests, didn’t make any difference.

The function below is only a few lines of code, but it can clear up a lot of confusion. I managed to solve the problem using a combination of Stack Overflow suggestions and help from other, more experienced developers here at Capital One. As always, there are some places this function won’t work, but we’ll talk more about that after looking at the code.

Give me the code!

Consider the following Python code. I’ve tried to follow PEP 8 formatting for convenience, so if you’re new to the language you can promise your team that you know and follow PEP 8 standards (and definitely didn’t come from a Java team that didn’t follow them):

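    import tempfile


    def clear_computation_graph(data_frame, spark_session):
        """Returns a 'cleared' dataframe after saving it for PySpark to work from.

        This will 'clear' the computation graph for you,
        since occasionally Spark will poorly optimize commands.
        This is useful to avoid having too many nested operations,
        especially more complex ones that tank performance.

        Keyword arguments:
        data_frame -- the dataframe you want to clear the graph for
        spark_session -- your current spark session
        """
        # tempfile creates the directory on the driver's local disk, so this
        # assumes local mode (or a path your executors can also reach).
        # The directory and its contents are deleted when the function returns.
        with tempfile.TemporaryDirectory() as path:
            # The directory already exists, so "overwrite" is needed here.
            data_frame.write.parquet(path, mode="overwrite")
            # The re-read dataframe's plan starts from this parquet scan.
            data_frame = spark_session.read.parquet(path)
            # cache() is lazy; count() touches every partition and fills the
            # cache before the files on disk are removed.
            data_frame.cache()
            data_frame.count()
            return data_frame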

Explanation

Why are we doing all these things?

The function first creates a temporary directory that exists only while the function runs; the directory deletes itself and its contents when the function returns. It then writes your dataframe to that directory as a parquet file and immediately reads it back in. Finally, it caches the dataframe, performs an action to fill the cache, and returns the dataframe.

Writing to a temporary directory that deletes itself avoids leaving stale intermediate files behind on disk. Additionally, we don’t really care which directory we write to, so we get the added bonus of not having to know anything about our current directory structure, or doing the extra work of passing in a location to save things.
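The cleanup behaviour described above comes from Python’s standard `tempfile.TemporaryDirectory`. This small, Spark-free sketch (the function name is made up for illustration) shows that the directory is removed as soon as the `with` block is left, even when that happens through a `return` inside it:

```python
import os
import tempfile


def write_and_return_path():
    """Hypothetical example: write into a temporary directory and return its path."""
    with tempfile.TemporaryDirectory() as path:
        with open(os.path.join(path, "part-0.txt"), "w") as handle:
            handle.write("intermediate results")
        # Inside the block, the directory and its file exist.
        assert os.path.isfile(os.path.join(path, "part-0.txt"))
        # Returning leaves the with-block, which deletes the directory and its contents.
        return path


returned_path = write_and_return_path()
print(os.path.exists(returned_path))  # False: the directory is already gone
```

That is why `clear_computation_graph` has to cache and materialize the dataframe before it returns: afterwards, the parquet files it read from no longer exist.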

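Writing to a parquet file and immediately reading it back “clears” the computation graph: the dataframe you get back is defined only by a scan of those files, so Spark starts from a fresh slate at that point instead of carrying every earlier join and transformation.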

Cache is a lazily evaluated operation, meaning Spark won’t run that command until an “action” is called. Actions cause Spark to compute the graph up to that point. Count is an action; calling it ensures Spark actually runs all the commands up to this point and fills the cache. The count itself doesn’t matter, as long as the action touches every partition, which count does. Cache and count are there so the dataframe remains available in memory, since the directory on disk cleans itself up and you’d lose the contents otherwise. Keep in mind that if cached partitions are ever lost (for example, if an executor dies), Spark will try to recompute them from the deleted files and fail.

Where might this solution not work?

There are some cases where this solution won’t work:

  • If Spark is already optimizing things properly (Spark can’t optimize across the forced write and read, so in that case you only add cost).
  • If you cannot afford the disk space to write to a local file (alternatives to this in a bit).
  • If your dataframe is so large that you cannot afford to have Spark perform an action, or cache your dataframe in memory.
  • If your Spark transformation is simple and has performance issues better solved with tweaks, such as removing looping.

Caching is best reserved for small data sets that are read over and over again, like small CSV sheets. If you’re working with a large Parquet data set, caching it can interfere with Spark’s native optimizations, make your queries less efficient, and potentially cause memory exhaustion problems.

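An alternative is to write to a directory of your choosing (such as a random path under /tmp) and skip the cache and count. Because nothing deletes that directory automatically, the data stays available on disk, but cleaning it up is up to you. In some cases disk space is much cheaper to expand than memory. That would look like this: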

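    from pyspark.sql import SparkSession


    def saveandload(df, path):
        """
        Save a Spark dataframe to disk and immediately read it back.

        This can be used as part of a checkpointing scheme
        as well as breaking Spark's computation graph.
        """
        df.write.parquet(path, mode="overwrite")
        # Reuse the active session rather than relying on a global `spark`.
        spark = SparkSession.builder.getOrCreate()
        return spark.read.parquet(path)


    # Use a new path for each call; the files are not cleaned up automatically.
    my_df = saveandload(my_df, "/tmp/abcdef")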

Rebuttal!

But wait, why does this work exactly? These operations are pretty expensive.

In theory, this function is less efficient than just caching, and Spark should work in such a way that you wouldn’t need it. Spark can’t optimize across the write and read the function forces, and writing to disk and calling an action are expensive operations on their own. But in some cases, caching or persisting the dataframe doesn’t actually break up the calculations: a cached dataframe still carries its full plan, so Spark keeps analyzing every earlier step. If you have a lot of complex joins or other logic, Spark might need a hard break so it can start from what it sees as a fresh dataframe.

Make this solution your own

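Lots of this can be switched around. If you can’t write your dataframe locally, you can write to an S3 bucket. You don’t have to save your dataframe as a parquet file, or even use overwrite (though with tempfile.TemporaryDirectory the directory already exists, so the default mode, which raises an error for an existing path, won’t work there). You can use another Spark action instead of count, as long as it touches every partition; actions like take or first may stop early and leave part of the dataframe uncached. Cache can be switched for persist with whatever storage level you want.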

So what would it look like if I call this method?

Probably something like this:

You have a Spark session:

    from pyspark.sql import SparkSession

    spark_session = (
        SparkSession.builder
        .appName("Spark SQL basic example")
        .config("spark.some.config.option", "some-value")
        .getOrCreate()
    )

    complex_dataframe = spark_session.read.csv("/src/resources/file.csv")
    # …
    # Some transformations on complex_dataframe
    # ...

And once you start seeing memory issues when running, you can make a call to clear the dataframe’s computation graph:

    complex_dataframe = clear_computation_graph(complex_dataframe, spark_session)
  

That’s all you need! Then you can continue to run transformations on your dataframe.

Big thanks to my colleagues at Capital One

Capital One has incredible developers and engineers. I’ve had the good fortune to work with some of them, and came to this solution with their attention, experience, and guidance. I specifically want to acknowledge Robin Neufeld and Cleland Loszewski for their help with this. Thank you!


David Slimak, Software Engineer

David Slimak is a Software Engineer working with PySpark and Airflow to support Data Analysts. He graduated in 2019 with a Bachelor's degree in Computer Science from Michigan State University and has been working at Capital One since then. David enjoys seeing new things and interesting ideas in board games, video games, and tech. (https://www.linkedin.com/in/david-slimak-a4724a113)
