Plotting data in NumPy and pandas is handled by an external Python module called matplotlib. Like NumPy and pandas, it is a large library and has been around for a while (it was first released in 2003), so we won't cover all of its functionality in this lesson.
After the course, have a look at the matplotlib example gallery, Nicolas P. Rougier's tutorial and Ben Root's tutorial to see the wide range of things you can do with matplotlib.
Here we will cover the basic uses of it and how it integrates with NumPy and pandas. While working through these examples you may want to refer to the matplotlib documentation.
The most common interface to matplotlib is its pyplot module, which provides a way to affect the current state of matplotlib directly. As with both NumPy and pandas, there is a conventional way to import matplotlib, which is as follows:
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
import numpy as np
We also import pandas in the same way as we did previously.
import pandas as pd
from pandas import Series, DataFrame
Some matplotlib functionality is provided directly through pandas (such as the plot() method), but for extra functionality you need to import the matplotlib interface itself as well, which we have already done above.
We first need to import some data to plot. Let's start with the data from the end of the pandas section (available from cetml1659on.dat) and import it into a DataFrame. This is also the solution to the last exercise in the previous section.
csv_file = 'data/cetml1659on.dat'
df = pd.read_csv(csv_file,  # file name
                 skiprows=6,  # skip the header lines
                 sep=r'\s+',  # whitespace separated (raw string avoids an invalid-escape warning)
                 na_values=['-99.9', '-99.99']  # treat these sentinel values as NaN
                 )
df.head()
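The sketch below shows what each read_csv() argument does without needing the data file. It parses a tiny snippet from memory via io.StringIO. The header lines, the column subset and all values are hypothetical stand-ins for cetml1659on.dat, and there are only 2 header lines to skip instead of 6. The column-name row has one fewer field than the data rows, so pandas uses the first column (the year) as the index.

```python
import io

import pandas as pd

# Hypothetical stand-in for data/cetml1659on.dat: two descriptive lines to
# skip, a column-name row, then whitespace-separated data rows whose first
# field is the year. The numbers are made up for illustration only.
raw = """Hypothetical description line 1
Hypothetical description line 2
JAN FEB YEAR
1659 3.0 4.0 8.9
1660 0.0 -99.9 9.1
"""

df_demo = pd.read_csv(io.StringIO(raw),
                      skiprows=2,                     # skip the two description lines
                      sep=r'\s+',                     # any run of whitespace separates fields
                      na_values=['-99.9', '-99.99'])  # sentinel values become NaN

print(df_demo)
```

Here df_demo.loc[1660, 'FEB'] comes back as NaN because -99.9 is listed in na_values.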
Pandas integrates matplotlib directly, so any column of a DataFrame can be plotted simply by calling its plot() method. This draws the plot and returns the matplotlib Axes it was drawn on. You can then edit and alter the plot, for example by setting the axis labels with plt.ylabel(), before displaying it with plt.show().
As mentioned above, matplotlib operates on a single global state, and calling any function on plt alters that state. Calling df['JAN'].plot() sets the current plot. plt.title(), plt.xlabel() and plt.ylabel() then alter that state, and plt.show() displays it.
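The following sketch checks the claim about global state. It assumes a short made-up Series in place of df['JAN'] and uses the non-interactive Agg backend so that it runs without a display. The Axes object returned by plot() is the same one that plt.gca() ("get current axes") reports, so a pyplot call such as plt.ylabel() and a method call such as ax.set_xlabel() act on the same plot.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for df['JAN']: three made-up values indexed by year.
jan = pd.Series([3.0, 0.0, 1.5], index=[1659, 1660, 1661])

fig = plt.figure()   # this new figure becomes pyplot's current figure
ax = jan.plot()      # pandas draws on the current Axes and returns that Axes
plt.ylabel(r'Temperature ($^\circ$C)')  # a pyplot call: acts on the current Axes

same_axes = ax is plt.gca()  # True: pandas and pyplot refer to one Axes object
ylabel = ax.get_ylabel()     # the label set through pyplot is visible on ax

ax.set_xlabel('Year')  # object-oriented equivalent of plt.xlabel('Year')
plt.close(fig)
print(same_axes, ylabel)
```

Calling methods on the returned ax directly is matplotlib's object-oriented style. It avoids depending on whichever plot happens to be current.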
df['JAN'].plot()
plt.title('January Climate Plot')
plt.xlabel('Year')
plt.ylabel(r'Temperature ($^\circ$C)')  # raw string so '\c' is not treated as an escape
plt.show()
Matplotlib can plot more than just line graphs; another plot type is the bar chart. We will construct a bar chart of the average temperature per decade.
We start by adding a new column to the DataFrame that represents the decade. We create it by taking the index (which holds the years), converting each element to a string and then replacing the fourth character with a '0'.
years = Series(df.index, index=df.index).apply(str)  # each year as a string, e.g. '1659'
decade = years.apply(lambda x: x[:3]+'0')            # e.g. '1659' -> '1650'
df['decade'] = decade
df.head()
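For comparison, the same decades can be computed numerically with integer floor division, which skips the round trip through strings. This alternative is not the lesson's method, and the years below are hypothetical. The string version works here because every year in the data set has four digits. The integer version returns numbers rather than strings, and either kind of key can be passed to groupby().

```python
import pandas as pd

# Hypothetical years standing in for df.index.
index = pd.Index([1659, 1660, 1668, 1671])

# String approach from the lesson: keep the first three digits, append '0'.
decade_str = pd.Series(index, index=index).apply(str).apply(lambda x: x[:3] + '0')

# Alternative (not the lesson's method): integer floor division.
# For example, 1668 // 10 == 166 and 166 * 10 == 1660.
years_int = pd.Series(index, index=index)
decade_int = years_int // 10 * 10

print(decade_str.tolist())
print(decade_int.tolist())
```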
You may be wondering how lambda x: x[:3]+'0' replaces the fourth character with a '0'. The slice x[:3] keeps the first three characters of the string, and +'0' appends a '0' in place of the fourth, so '1659' becomes '1650'. The function itself is written as a Python lambda function (lambda expressions are documented in the Python language reference). Here a lambda is a quick way of succinctly writing a function with no name in one line. The following is an equivalent, but longer, way of doing the same thing:
def year2decade(year):
    '''Takes a year as a string and returns the decade, also as a string
    '''
    decade = year[:3] + '0'
    return decade

decade = years.apply(year2decade)
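As a quick sanity check, the sketch below shows the slice step on its own and confirms that the lambda and the named function give identical results. It uses a few hypothetical year strings. year2decade is redefined in a condensed form so that the block runs on its own.

```python
def year2decade(year):
    '''Takes a year as a string and returns the decade, also as a string'''
    return year[:3] + '0'

print('1668'[:3])        # 166  (the first three characters)
print('1668'[:3] + '0')  # 1660

# Hypothetical sample of year strings.
sample_years = ['1659', '1660', '1699', '1700']

via_lambda = list(map(lambda x: x[:3] + '0', sample_years))
via_def = list(map(year2decade, sample_years))

print(via_lambda)             # ['1650', '1660', '1690', '1700']
print(via_lambda == via_def)  # True
```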
Once we have our decade column, we can use the pandas groupby() method to gather the data by decade and then aggregate it by taking the mean of each decade.
by_decade = df.groupby('decade')
agg = by_decade.aggregate('mean')  # mean of every column within each decade
agg.head()
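The sketch below shows what groupby() followed by aggregate() produces. It runs on a hypothetical four-row DataFrame that reuses the lesson's YEAR and decade column names; the temperatures are made up. Each decade's rows are averaged, and the decade labels become the index of the result.

```python
import pandas as pd

# Hypothetical toy data: made-up annual means for four years in two decades.
toy = pd.DataFrame({'YEAR': [9.0, 9.5, 8.0, 9.0],
                    'decade': ['1650', '1650', '1660', '1660']},
                   index=[1658, 1659, 1660, 1661])

# groupby() splits the rows by decade; aggregate('mean') averages each group.
per_decade = toy.groupby('decade').aggregate('mean')
print(per_decade)
```

Here 1650 averages 9.0 and 9.5 to 9.25, and 1660 averages 8.0 and 9.0 to 8.5.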
At this point, agg is a standard pandas DataFrame, so we can plot it like any other. To draw its YEAR column as a bar chart, we call .bar() on the plot accessor, i.e. .plot.bar() instead of .plot():
agg.YEAR.plot.bar()
plt.title('Average temperature per decade')
plt.xlabel('Decade')
plt.ylabel(r'Temperature ($^\circ$C)')  # raw string so '\c' is not treated as an escape
plt.show()
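The same bar chart can also be built in the object-oriented style by passing an explicit Axes to plot.bar() through its ax argument. This sketch uses made-up per-decade means in place of agg.YEAR. Instead of calling plt.show(), it renders to an in-memory PNG with the Agg backend, then reads the bar heights back to confirm that one bar was drawn per decade.

```python
import io

import matplotlib
matplotlib.use('Agg')  # off-screen backend so no display is needed
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical per-decade means standing in for agg.YEAR (values made up).
yearly = pd.Series([9.25, 8.5, 9.1], index=['1650', '1660', '1670'], name='YEAR')

fig, ax = plt.subplots()
yearly.plot.bar(ax=ax)   # draw one bar per decade onto this specific Axes
ax.set_title('Average temperature per decade')
ax.set_xlabel('Decade')
ax.set_ylabel(r'Temperature ($^\circ$C)')

buf = io.BytesIO()
fig.savefig(buf, format='png')  # render to memory instead of calling plt.show()
heights = [bar.get_height() for bar in ax.patches]  # each bar is a Rectangle patch
plt.close(fig)
print(heights)
```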