Introduction to matplotlib

Overview:

  • Teaching: 5 min
  • Exercises: 10 min

Questions

  • How do I plot data from a pandas DataFrame using matplotlib?
  • Why should I use a programmatic approach to processing data and producing plots?

Objectives

  • See how to quickly plot from a data frame.
  • Use the plot, bar and scatter functions to produce different plots.

Plotting data in NumPy and pandas is handled by an external Python module called matplotlib. Like NumPy and pandas, it is a large library that has been around for a while (it was first released in 2003), so we won't cover all of its functionality in this lesson.

After the course, to see the wide range of possibilities matplotlib offers, look at the matplotlib example gallery, Nicolas P. Rougier's tutorial and Ben Root's tutorial.

Here we will cover its basic uses and how it integrates with NumPy and pandas. While working through these examples you may want to refer to the matplotlib documentation.

Plotting using NumPy and matplotlib

The most common interface to matplotlib is its pyplot module, which provides a way to affect the current state of matplotlib directly. As with both NumPy and pandas, there is a conventional way to import matplotlib, which is as follows:

In [1]:
# Jupyter/IPython only: render inline figures as SVG
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
import numpy as np
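Here is a minimal sketch of plotting NumPy data with pyplot. It is our own illustration, and the choice of function and range is arbitrary. We evaluate np.sin on an evenly spaced array and pass both arrays to plt.plot(x, y):

```python
import numpy as np
import matplotlib.pyplot as plt

# 100 evenly spaced points between 0 and 2*pi
x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)  # NumPy applies sin element-wise

plt.figure()  # start a new, empty figure
# plt.plot draws y against x on the current Axes and returns the Line2D objects
lines = plt.plot(x, y)
plt.title('Sine curve')
plt.xlabel('x (radians)')
plt.ylabel('sin(x)')
# In a script, call plt.show() here to display the figure
```

Every call above acts on pyplot's current figure. This is the same global-state style used for the pandas plots below.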

Plotting using pandas and matplotlib

First we import pandas in the same way as we did previously.

In [2]:
import pandas as pd
from pandas import Series, DataFrame

Some matplotlib functionality is provided directly through pandas (such as the plot() method), but for extra functionality you need to import the matplotlib interface itself as well, which we already did above.

We first need to import some data to plot. Let's start with the data from the end of the pandas section (available from cetml1659on.dat) and import it into a DataFrame. This is also the solution to the last exercise in the previous section.

In [3]:
csv_file = 'data/cetml1659on.dat'
df = pd.read_csv(csv_file,  # file name
                 skiprows=6,  # skip header lines
                 sep=r'\s+',  # whitespace separated (raw string avoids an invalid-escape warning)
                 na_values=['-99.9', '-99.99']  # read these missing-value markers as NaN
                )
df.head()
Out[3]:
      JAN  FEB  MAR  APR   MAY   JUN   JUL   AUG   SEP   OCT  NOV  DEC  YEAR
1659  3.0  4.0  6.0  7.0  11.0  13.0  16.0  16.0  13.0  10.0  5.0  2.0  8.87
1660  0.0  4.0  6.0  9.0  11.0  14.0  15.0  16.0  13.0  10.0  6.0  5.0  9.10
1661  5.0  5.0  6.0  8.0  11.0  14.0  15.0  15.0  13.0  11.0  8.0  6.0  9.78
1662  5.0  6.0  6.0  8.0  11.0  15.0  15.0  15.0  13.0  11.0  6.0  3.0  9.52
1663  1.0  1.0  5.0  7.0  10.0  14.0  15.0  15.0  13.0  10.0  7.0  5.0  8.63
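To see what the read_csv() arguments do without the data file, here is a minimal sketch that parses a small hand-written sample through io.StringIO. The sample is hypothetical. It only assumes the layout implied by the output above: a header row naming the months, with each data row starting with the year. Because each data row has one more field than the header, pandas uses that extra first field as the index, which is why the years appear as row labels.

```python
import io
import pandas as pd

# Hypothetical two-row sample in the assumed layout (not taken from the real file);
# -99.9 plays the role of a missing-value marker
sample = """JAN FEB MAR
1659 3.0 4.0 -99.9
1660 0.0 4.0 6.0
"""

small = pd.read_csv(io.StringIO(sample),
                    sep=r'\s+',           # one or more whitespace characters
                    na_values=['-99.9'])  # read this marker as NaN

print(small)
```

The real file also needs skiprows=6 to step over its header lines; the sample has none, so that argument is left out.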

Pandas integrates matplotlib directly, so any column of a DataFrame can be plotted simply by calling its plot() method. This draws the plot (and returns the matplotlib Axes it was drawn on), which you can then edit and alter, for example by setting the axis labels with the plt.ylabel() function, before displaying it with plt.show().

As mentioned above, pyplot operates on a single global state, and calling any function on plt alters that state. Calling df['JAN'].plot() sets the current plot, plt.ylabel() then alters it, and plt.show() displays it.

In [4]:
df['JAN'].plot()

plt.title('January Climate Plot')
plt.xlabel('Year')
plt.ylabel(r'Temperature ($^\circ$C)')  # raw string keeps the LaTeX backslash intact

plt.show()
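Pyplot's global state can be hard to follow once several figures are open. For that case, matplotlib also offers an object-oriented style in which you call methods on a specific Axes. Series.plot() accepts an ax argument and returns that Axes. The sketch below uses the first five January values from df.head() above, so it runs without the data file.

```python
import pandas as pd
import matplotlib.pyplot as plt

# First five January values from df.head() above, standing in for df['JAN']
jan = pd.Series([3.0, 0.0, 5.0, 5.0, 1.0],
                index=[1659, 1660, 1661, 1662, 1663], name='JAN')

fig, ax = plt.subplots()  # create a Figure and one Axes explicitly
jan.plot(ax=ax)           # draw the Series on that particular Axes
ax.set_title('January Climate Plot')
ax.set_xlabel('Year')
ax.set_ylabel(r'Temperature ($^\circ$C)')
```

The set_title/set_xlabel/set_ylabel methods act on this Axes only. Their effect matches the plt.title/plt.xlabel/plt.ylabel calls in the cell above, but it does not depend on which plot is "current".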

Summer climate

  • Try reproducing the plot above but for the month of June.
  • Try putting in two plot() calls with different months (both January and June for example) before calling show().
  • Add a legend to distinguish the two lines.

Solution

Bar charts

Matplotlib can plot more than just line graphs; another plot type is the bar chart. We will construct a bar chart of the average temperature per decade.

We start by adding a new column to the data frame which represents the decade. We create it by taking the index (which is a list of years), converting each element to a string and then replacing the fourth character with a '0'.

In [5]:
years = Series(df.index, index=df.index).apply(str)
decade = years.apply(lambda x: x[:3]+'0')

df['decade'] = decade
df.head()
Out[5]:
      JAN  FEB  MAR  APR   MAY   JUN   JUL   AUG   SEP   OCT  NOV  DEC  YEAR  decade
1659  3.0  4.0  6.0  7.0  11.0  13.0  16.0  16.0  13.0  10.0  5.0  2.0  8.87    1650
1660  0.0  4.0  6.0  9.0  11.0  14.0  15.0  16.0  13.0  10.0  6.0  5.0  9.10    1660
1661  5.0  5.0  6.0  8.0  11.0  14.0  15.0  15.0  13.0  11.0  8.0  6.0  9.78    1660
1662  5.0  6.0  6.0  8.0  11.0  15.0  15.0  15.0  13.0  11.0  6.0  3.0  9.52    1660
1663  1.0  1.0  5.0  7.0  10.0  14.0  15.0  15.0  13.0  10.0  7.0  5.0  8.63    1660
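The index already holds the years as integers, so the same decade labels could also be computed arithmetically. Floor division by 10 drops the units digit, and multiplying by 10 restores the magnitude. This is an alternative we are suggesting, not the lesson's method. It yields integers rather than strings, and unlike x[:3] + '0' it does not assume four-digit years. The sketch uses the first five years from the output above.

```python
import pandas as pd

# First five years from the index shown above
years = pd.Index([1659, 1660, 1661, 1662, 1663])

# String approach used in the lesson: keep the first three characters, append '0'
decade_str = pd.Series(years, index=years).apply(str).apply(lambda x: x[:3] + '0')

# Integer alternative: floor-divide by 10 to drop the units digit, then multiply back
decade_int = pd.Series(years // 10 * 10, index=years)

print(decade_str.tolist())  # ['1650', '1660', '1660', '1660', '1660']
print(decade_int.tolist())  # [1650, 1660, 1660, 1660, 1660]
```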

Info: Lambda functions

You may be wondering how

lambda x: x[:3]+'0'

replaces the fourth character with a '0'. The slice x[:3] keeps the first three characters of the string, and + '0' appends a zero in place of the fourth. Wrapping that expression in a Python lambda function (see the Python documentation on lambda expressions) is a quick way of succinctly writing a one-line function with no name. The following is an equivalent, but longer, way of doing the same thing:

def year2decade(year):
    '''Takes a year as a string and returns the decade, also as a string
    '''
    decade = year[:3] + '0'
    return decade

decade = years.apply(year2decade)

.

Once we have our decade column, we can use the pandas groupby() method to gather our data by decade and then aggregate it by taking the mean of each decade.

In [6]:
by_decade = df.groupby('decade')
agg = by_decade.mean()  # mean of every column within each decade group

agg.head()
Out[6]:
         JAN   FEB   MAR   APR    MAY    JUN    JUL    AUG    SEP    OCT   NOV   DEC   YEAR
decade
1650    3.00  4.00  6.00  7.00  11.00  13.00  16.00  16.00  13.00  10.00  5.00  2.00  8.870
1660    2.60  4.00  5.10  7.70  10.60  14.50  16.00  15.70  13.30  10.00  6.30  3.80  9.157
1670    3.25  2.35  4.50  7.25  11.05  14.40  15.80  15.25  12.40   8.95  5.20  2.45  8.607
1680    2.50  2.80  4.80  7.40  11.45  14.00  15.45  14.90  12.70   9.55  5.45  4.05  8.785
1690    1.89  2.49  3.99  6.79   9.60  13.44  15.27  14.65  11.93   8.64  5.26  3.31  8.134
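Here is a step-by-step sketch of what groupby() and mean() do, using a tiny DataFrame. It holds the YEAR values and decade labels of the five rows from df.head() above. groupby('decade') splits the rows into one group per label, and mean() averages each column within each group.

```python
import pandas as pd

# YEAR values and decade labels for 1659-1663, copied from df.head() above
small = pd.DataFrame({'YEAR': [8.87, 9.10, 9.78, 9.52, 8.63],
                      'decade': ['1650', '1660', '1660', '1660', '1660']},
                     index=[1659, 1660, 1661, 1662, 1663])

# Split the rows by decade label, then average each remaining column per group
decade_means = small.groupby('decade').mean()
print(decade_means)
```

Only four years of the 1660s are included here. As a result, the 1660 mean (about 9.26) differs from the 9.157 in the full agg table, which averages the whole decade.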

At this point, agg is a standard pandas DataFrame, so we can plot it like any other. To get a bar chart of its YEAR column, we call plot.bar() instead of plot():

In [7]:
agg.YEAR.plot.bar()

plt.title('Average temperature per decade')
plt.xlabel('Decade')
plt.ylabel(r'Temperature ($^\circ$C)')  # raw string keeps the LaTeX backslash intact

plt.show()
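pandas' plot.bar() is a convenience layer over matplotlib. For comparison, here is a sketch of a similar chart drawn with pyplot's own bar() function, which takes the bar positions and heights. The values are the first five decade means copied from agg.head() above. Passing strings as the x positions relies on matplotlib's categorical axis support (available since matplotlib 2.1).

```python
import matplotlib.pyplot as plt

# Decade labels and mean annual temperatures copied from agg.head() above
decades = ['1650', '1660', '1670', '1680', '1690']
year_means = [8.870, 9.157, 8.607, 8.785, 8.134]

plt.figure()  # start a new, empty figure
# One bar per decade; string x values are placed as evenly spaced categories
bars = plt.bar(decades, year_means)
plt.title('Average temperature per decade')
plt.xlabel('Decade')
plt.ylabel(r'Temperature ($^\circ$C)')
# In a script, call plt.show() here to display the figure
```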

Other graphs

  1. Modify the above code to plot a bar chart of the average temperature per century.

  2. Plot a histogram of the average annual temperature. Make sure that the x-axis is labelled correctly. Hint: Look in the documentation for the right command to run.

  3. Plot a scatter plot of each year's February temperature plotted against that year's January temperature. Is there an obvious correlation?

Solution

Key Points:

  • We can plot a function quickly using plt.plot(x, y).
  • We can (and should) add a title and axis labels to our plots.
  • The bar() function creates bar charts.
  • The hist() function creates histograms.
  • The scatter() function creates scatter diagrams.