A Guide to Flattening a List of Lists in Python

In this tutorial, we will walk through various methods of list flattening, ranging from the simple yet powerful Python built-ins to the more advanced techniques utilizing libraries like NumPy.

Whether you are working with two-dimensional lists or more complex multidimensional data, we’ve got you covered.

Each method will be illustrated with example code, detailed explanations, and specific outputs to ensure you grasp the concept clearly and completely.
By the end of this tutorial, you’ll have a toolkit of techniques at your disposal to flatten any list object that comes your way.



Using Nested for Loops

The concept involves creating a new list and using a nested for loop to iterate through the elements in the original list. Let’s illustrate this with an example:

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Create a new list to store the flattened elements
new_list = []
for sub_list in main_list:
    for item in sub_list:
        # Add each item from the sublist to the new list
        new_list.append(item)

print(new_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In this example, we first created an empty new list. Then, we used two for loops: the outer one to iterate through each list in the main list, and the inner one to go through each item within those lists.

The append() method was used to add each item to the new list, producing a flat list.
One drawback of this method is its space complexity.

When dealing with large lists, it needs a significant amount of extra memory because it builds a second list that holds a reference to every element. Nonetheless, for smaller data sets, this method works fine and is easy to understand.
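If that second list is a concern, one option (a variation, not part of the method above) is to turn the same nested loop into a generator function that yields items one at a time instead of storing them. The name iter_flat is hypothetical:

```python
def iter_flat(nested_list):
    # Same nested loop as above, but yield each item
    # instead of appending it to a second list
    for sub_list in nested_list:
        for item in sub_list:
            yield item

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Items are produced lazily, so a consumer such as sum()
# never needs the whole flat list in memory
total = sum(iter_flat(main_list))
print(total)  # 45

# Call list() only when you actually need a list
flat_list = list(iter_flat(main_list))
print(flat_list)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The trade-off is that a generator can be iterated only once; call iter_flat() again if you need a second pass.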


Using the sum() Function

The built-in sum() function adds the items of an iterable to a start value, from left to right, and returns the result. Because + concatenates lists, passing an empty list as the start value makes sum() join the sublists together.
Let’s consider an example:

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Create a flat list using sum function
flat_list = sum(main_list, [])
print(flat_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In the above code, we’re calling sum() with two arguments. The first argument is the list we want to flatten, and the second is an empty list that serves as the start value.

Starting from that empty list, sum() concatenates each sublist of main_list onto the running result. As a result, we obtain a single list that includes all elements from the original lists.
While this method is concise, it becomes much slower than nested loops as the lists grow.

This is because each + creates a brand-new list and copies every element accumulated so far into it, so the total work grows quadratically with the number of sublists.
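To make that cost concrete, here is a rough model of what sum(main_list, []) does in CPython, written as an explicit loop. The helper name count_copied_elements is hypothetical; it simply tallies how many element copies the repeated concatenation performs:

```python
def count_copied_elements(lists):
    # Mimic sum(lists, []): every + builds a new list and
    # copies all elements accumulated so far into it
    result = []
    copied = 0
    for sublist in lists:
        result = result + sublist
        copied += len(result)  # the whole new list was copied
    return result, copied

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat_list, copies = count_copied_elements(main_list)
print(flat_list)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(copies)     # 18, i.e. 3 + 6 + 9
```

For 1000 sublists of 500 items each, the same count comes to 500 × (1 + 2 + ... + 1000) = 250,250,000 copies, compared with 500,000 appends for a single pass with nested loops.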


Recursive Functions for Flattening a List of Lists

We can create a recursive solution in Python to flatten a list of lists.

This method is handy when dealing with lists whose nesting depth is unknown or uneven (irregularly nested lists). Let’s see how we can achieve this:

def flatten_list(nested_list):
    # if the input is not a list, return a list containing the input item
    if not isinstance(nested_list, list):
        return [nested_list]

    flattened = []
    for item in nested_list:
        flattened += flatten_list(item)

    return flattened

main_list = [[1, 2, [3, 3.1]], [4, [5, [5.1, 5.2]], 6], 7, [8, 9]]
print(flatten_list(main_list))

[1, 2, 3, 3.1, 4, 5, 5.1, 5.2, 6, 7, 8, 9]

In this code, we defined a recursive function flatten_list to flatten any nested list. The function checks if the input is not a list, in which case it wraps the input in a list and returns it.

However, if the input is a list, it iterates over each list item and makes a recursive call to flatten_list.

This process continues until all nested lists are flattened.
It’s worth mentioning that when I tried this recursive solution on an extremely deeply nested list, it raised a RecursionError because the nesting depth exceeded Python’s maximum recursion depth (1000 by default in CPython; see sys.getrecursionlimit()).
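One way around that limit, sketched here as an alternative technique rather than the method above, is to replace recursion with an explicit stack of iterators, so the nesting depth is bounded by memory instead of the call stack. The function name flatten_iterative is hypothetical, and like flatten_list it treats only list objects as nested:

```python
def flatten_iterative(nested_list):
    # A stack of iterators stands in for the call stack
    flattened = []
    stack = [iter(nested_list)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                # Pause the current list and descend into the sublist
                stack.append(iter(item))
                break
            flattened.append(item)
        else:
            # The current list is exhausted; resume its parent
            stack.pop()
    return flattened

main_list = [[1, 2, [3, 3.1]], [4, [5, [5.1, 5.2]], 6], 7, [8, 9]]
print(flatten_iterative(main_list))  # [1, 2, 3, 3.1, 4, 5, 5.1, 5.2, 6, 7, 8, 9]

# A list nested 5000 levels deep, well past the default recursion limit
deep = [1]
for _ in range(5000):
    deep = [deep]
print(flatten_iterative(deep))  # [1]
```

Because each paused iterator remembers its position, the items come out in the same order as with the recursive version.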


Using List Comprehension

A list comprehension is a compact way of creating lists, and it can also be used to flatten a list of lists in Python.

Let’s look at how to use list comprehension to flatten a list:

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat_list = [item for sublist in main_list for item in sublist]
print(flat_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

The list comprehension in the code above is equivalent to our earlier nested for loop example.

The expression [item for sublist in main_list for item in sublist] does the same job of iterating over each list and then each item within those lists, all in one compact line.
This method is fast and concise. It can be harder to read for those new to Python, but once you have used it a few times, it may well become your favorite.
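Because the for clauses appear in the same order as the equivalent nested loops, you can also attach a condition at the end to filter while flattening. A small sketch that keeps only even numbers (the filter itself is just an illustration):

```python
main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Outer loop first, inner loop second, then the condition,
# exactly as in the equivalent nested for loops
even_items = [item for sublist in main_list for item in sublist if item % 2 == 0]
print(even_items)  # [2, 4, 6, 8]
```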


Using itertools

The chain() function of the itertools module lets you flatten a list of lists by iterating over the items of each list as if they were all in one list. Here’s how you can use it:

from itertools import chain
main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat_list = list(chain(*main_list))
print(flat_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In this code snippet, the chain function is used with the unpacking operator (*).

The chain function takes several iterables as arguments and returns an iterator that yields all the elements of the first iterable, then all the elements of the second, and so on.

The unpacking operator is used to pass the lists in main_list as separate arguments to chain.
One advantage of this method is that chain() itself does not build any intermediate lists: it yields elements on demand, and only the final list() call creates the flat list. This makes it a memory-friendly solution, especially when dealing with large lists.
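itertools also provides chain.from_iterable(), which takes the outer list itself instead of unpacked arguments. This skips the * unpacking step and also works when the outer iterable is a generator. The rows generator below is a made-up example input:

```python
from itertools import chain

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Pass the outer list directly; no * unpacking required
flat_list = list(chain.from_iterable(main_list))
print(flat_list)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]

# The outer iterable can itself be lazy, e.g. a generator
# that builds each sublist only when chain asks for it
rows = ([n, n * 10] for n in range(1, 4))
print(list(chain.from_iterable(rows)))  # [1, 10, 2, 20, 3, 30]
```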


Using functools reduce()

The reduce() function from the functools module applies a function of two arguments cumulatively to the items of iterable from left to right, so as to reduce the iterable to a single output.
Here’s an example of how you can use reduce to flatten a list of lists:

from functools import reduce
main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Flatten the list using functools.reduce
flat_list = reduce(lambda x, y: x + y, main_list)
print(flat_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In the code above, reduce is applied to main_list using a lambda function that combines two lists (x and y) using the + operator.

It starts with the first two lists in main_list, combines them, and then combines the result with the next list, continuing this way until it has combined all the lists, thereby creating a single list.
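Like sum(), reduce() with + copies the accumulated list at every step. A variant worth knowing, not used in the example above, passes operator.iadd with an empty list as the initial value, so each step extends a single list in place:

```python
import operator
from functools import reduce

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# operator.iadd(a, b) performs a += b, which extends list a in place
# and returns it, so no new list is created at each step.
# The [] initializer matters: without it, reduce() would start from
# main_list[0] and iadd would modify that sublist.
flat_list = reduce(operator.iadd, main_list, [])
print(flat_list)     # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(main_list[0])  # [1, 2, 3] (unchanged)
```

The initializer also makes the call return [] for an empty input, whereas the lambda version without one raises a TypeError.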


Using NumPy

NumPy arrays have a flatten() method that returns a one-dimensional copy of the array, which we can then convert back into a list.
Let’s see how you can use ndarray.flatten() to create a flat list:

import numpy as np
main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Convert the list of lists into a numpy array
array = np.array(main_list)

# Use the flatten method to flatten the array
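flat_array = array.flatten()

# Convert the NumPy array back into a regular Python list
flat_list = flat_array.tolist()
print(flat_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]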

In the above code, we converted the main list into a NumPy array using np.array().

The flatten() method was then used on the array to create a one-dimensional array.

Lastly, we converted the NumPy array back into a list.
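One thing to keep in mind is that this method only works when all sublists have the same length, because NumPy needs a rectangular array; with ragged sublists, recent NumPy versions (1.24 and later) raise a ValueError. NumPy also converts all elements to a single common type, so mixing integers and floats, for example, gives you a list of floats.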


Using the extend() Method

The list.extend() method adds each element of an iterable to the end of the list it is called on.
Let’s see how you can use the extend method to flatten a list of lists:

main_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat_list = []

# Iterate over each list in the main list
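for sublist in main_list:
    # Add every element of the sublist to the end of flat_list
    flat_list.extend(sublist)

print(flat_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]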

In the above code, we created an empty list flat_list. We then looped over each list in the main list and used the extend method to add the elements of each sublist to flat_list.

This method does not create any intermediate lists; it grows flat_list in place, which makes it a memory-efficient solution.
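A common slip is to call append() here instead of extend(). The short comparison below shows the difference: append() adds each sublist as a single element, so the result stays nested, while += on a list behaves like extend():

```python
main_list = [[1, 2, 3], [4, 5, 6]]

appended = []
extended = []
plus_equal = []
for sublist in main_list:
    appended.append(sublist)  # adds the sublist object itself
    extended.extend(sublist)  # adds each element of the sublist
    plus_equal += sublist     # in-place, same effect as extend()

print(appended)    # [[1, 2, 3], [4, 5, 6]]
print(extended)    # [1, 2, 3, 4, 5, 6]
print(plus_equal)  # [1, 2, 3, 4, 5, 6]
```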


Performance Comparison

For the benchmark, I’ll create a list of 1000 lists, each containing 500 random integers between 0 and 500 (inclusive):

nested_list = [[random.randint(0, 500) for _ in range(500)] for _ in range(1000)]

After that, I will implement each of these methods and measure their execution time using Python’s timeit module.

import timeit
import random
import numpy as np
import itertools
from functools import reduce

# Create the nested list
nested_list = [[random.randint(0, 500) for _ in range(500)] for _ in range(1000)]

# Define the flattening functions
def flatten_for_loop(nested_list):
    flat_list = []
    for sublist in nested_list:
        for item in sublist:
            flat_list.append(item)
    return flat_list

def flatten_sum(nested_list):
    return sum(nested_list, [])

def flatten_recursive(nested_list):
    def flatten(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                # Descend into sublists
                yield from flatten(item)
            else:
                # Yield non-list items as they are
                yield item
    return list(flatten(nested_list))

def flatten_list_comprehension(nested_list):
    return [item for sublist in nested_list for item in sublist]

def flatten_itertools(nested_list):
    return list(itertools.chain(*nested_list))

def flatten_reduce(nested_list):
    return reduce(lambda x, y: x + y, nested_list)

def flatten_numpy(nested_list):
    return np.array(nested_list).flatten().tolist()

def flatten_extend(nested_list):
    flat_list = []
    for sublist in nested_list:
        flat_list.extend(sublist)
    return flat_list

# Define a list of methods
methods = [flatten_for_loop, flatten_sum, flatten_recursive, flatten_list_comprehension,
           flatten_itertools, flatten_reduce, flatten_numpy, flatten_extend]

# Create a dictionary to store the results
performance = {}

# Measure the time of a single call to each method
for method in methods:
    start_time = timeit.default_timer()
    method(nested_list)
    elapsed = timeit.default_timer() - start_time
    performance[method.__name__] = elapsed

# Print the results
for name, elapsed in performance.items():
    print(f"{name}: {elapsed:.4f} s")

Benchmark results (a single call per method, so exact times will vary between machines and runs):

Method                 Time (seconds)
Nested for loops       0.0319
sum() function         1.7832
Recursive function     0.0574
List comprehension     0.0126
itertools chain()      0.0097
functools reduce()     1.1578
NumPy flatten()        0.0364
extend() method        0.0085

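In this run, the extend() method was the fastest, followed closely by itertools chain() and then the list comprehension. The most clear-cut result, however, is that sum() and functools reduce() were roughly 20 to 200 times slower than the other six methods, because each + concatenation copies all the elements accumulated so far.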
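Because each method above is timed with a single call, small gaps such as the one between extend() and chain() can flip from run to run. A sturdier variation, offered as a sketch rather than the setup behind the numbers above, repeats each measurement with timeit.repeat() and keeps the fastest trial. The benchmark() helper is hypothetical, and only three standard-library methods are included so the block runs without NumPy:

```python
import random
import timeit
from itertools import chain

def flatten_list_comprehension(nested_list):
    return [item for sublist in nested_list for item in sublist]

def flatten_itertools(nested_list):
    return list(chain(*nested_list))

def flatten_extend(nested_list):
    flat_list = []
    for sublist in nested_list:
        flat_list.extend(sublist)
    return flat_list

def benchmark(funcs, data, repeat=5, number=5):
    # Each trial calls the function `number` times; keep the fastest
    # trial (least disturbed by other processes) and report per call
    results = {}
    for func in funcs:
        trials = timeit.repeat(lambda: func(data), repeat=repeat, number=number)
        results[func.__name__] = min(trials) / number
    return results

nested_list = [[random.randint(0, 500) for _ in range(500)] for _ in range(1000)]
timings = benchmark([flatten_list_comprehension, flatten_itertools, flatten_extend], nested_list)

# Print from fastest to slowest
for name, seconds in sorted(timings.items(), key=lambda pair: pair[1]):
    print(f"{name:28} {seconds:.5f} s")
```

Taking the minimum rather than the mean is the approach the timeit documentation itself suggests, since higher values are usually caused by interference rather than by the code being measured.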


Flatten for Good

I often encounter diverse tasks in my freelancing journey. One of these was analyzing hashtags from multiple Twitter posts.

The raw data pulled from Twitter’s API presented itself as a list of lists, where each sublist contained hashtags from a single tweet. Initially, the data looked something like this:

hashtags_in_tweets = [["#AI", "#MachineLearning"], ["#DeepLearning", "#AI"], ["#DataScience", "#AI", "#Python"]]

Each sublist was a separate tweet, and within each of those lists were the individual hashtags. The end goal was to analyze the popularity of different hashtags across all tweets.

Now, it’s rather tricky to analyze data when it’s nested within multiple layers of lists. I needed a flat list: a single sequence in which each hashtag could be analyzed in the context of the whole data set.

I was able to flatten the list of lists with a single line of code, using a list comprehension:

all_hashtags = [hashtag for tweet in hashtags_in_tweets for hashtag in tweet]

With this simple yet powerful line of code, I was able to transform my nested data into a flat list, like this:

['#AI', '#MachineLearning', '#DeepLearning', '#AI', '#DataScience', '#AI', '#Python']

Then I used the collections.Counter class from Python’s standard library to count the occurrences of each hashtag:

from collections import Counter
hashtag_counts = Counter(all_hashtags)

What I got was a clear, concise summary of the popularity of each hashtag:

Counter({'#AI': 3, '#MachineLearning': 1, '#DeepLearning': 1, '#DataScience': 1, '#Python': 1})

Thus, the task of flattening the nested list structure and analyzing the Twitter data became a smooth and efficient process.
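Since Counter accepts any iterable, the flattening and counting steps can also be combined by feeding it itertools.chain.from_iterable() over the tweets, so no intermediate list is built; Counter.most_common() then ranks the hashtags. This is a variation on the list comprehension approach above, not the code from the original task:

```python
from collections import Counter
from itertools import chain

hashtags_in_tweets = [["#AI", "#MachineLearning"], ["#DeepLearning", "#AI"], ["#DataScience", "#AI", "#Python"]]

# chain.from_iterable yields each hashtag in turn; Counter tallies them
hashtag_counts = Counter(chain.from_iterable(hashtags_in_tweets))

# The most frequent hashtag and its count
print(hashtag_counts.most_common(1))  # [('#AI', 3)]
```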