August 28, 2009
Plotting data with matplotlib
In continuing my slow migration away from “office-like” tools for working with my data, I’ve been taking a look lately at matplotlib. I’ve banged together a rough script to do some simple data plotting with a bit of flexibility:
#! /usr/bin/env python
# https://biostumblematic.wordpress.com
# An interface to matplotlib

# Import modules
import csv, sys
import matplotlib.pyplot as plt
import numpy as np

# Introduce the program
print '-'*60
print 'Your data should be in CSV format, with Y-values'
print 'in odd columns and X-values in even columns.'
print 'If your file contains a header row, these will be'
print 'automatically detected'
print '-'*60

# Open the data
datafile = sys.argv[1]
f = open(datafile, 'r')

# Check to see if the file starts with headers or data:
dialect = csv.Sniffer().has_header(f.read(1024))
f.seek(0)
reader = csv.reader(f)

# Assign the data to series via a dict
if dialect is True:
    reader.next() # Move down a line to skip headers
else:
    pass

series_dict = {}
for row in reader:
    i = 0
    for column in row:
        i += 1
        if series_dict.has_key(i):
            try:
                series_dict[i].append(float(column))
            except ValueError:
                pass
        else:
            series_dict[i] = [float(column)]

# Plot each data series
num_cols = len(series_dict)
i = 1

while i < num_cols:
    plt.plot(series_dict[i], series_dict[i+1], 'o')
    i += 2

# Get axis labels
xaxis_label = raw_input('X-axis label > ')
yaxis_label = raw_input('Y-axis label > ')

# Show the plot
plt.ylabel(yaxis_label)
plt.xlabel(xaxis_label)
plt.show()

# Enter loop for customizing appearance

# Stop
f.close()
As is, this will read in a CSV file with any number of columns and plot them as alternating Y-value/X-value pairs.
Some things that feel nasty:
- Having to use the dictionaries to get the column data assembled. I feel like the CSV reader module should have a “transpose” function
- The section near the end where I’m generating the different plots by iterating over the number of columns.
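Both complaints come down to turning rows into columns. The csv module has no transpose of its own, but the built-in `zip(*rows)` does the same job. Here is a minimal sketch in Python 3 (the script above is Python 2), assuming every row has the same number of numeric cells. The sample data and the helper name `read_columns` are made up for illustration.

```python
import csv
import io

def read_columns(lines):
    """Return the CSV columns as a list of lists of floats."""
    rows = [[float(cell) for cell in row] for row in csv.reader(lines)]
    # zip(*rows) groups the first cells, the second cells, and so on,
    # which transposes rows into columns (assumes equal-length rows).
    return [list(col) for col in zip(*rows)]

# Hypothetical data: two pairs of columns, no header row.
sample = io.StringIO("1,10,1,100\n2,20,2,200\n3,30,3,300\n")
columns = read_columns(sample)

# Walk the columns two at a time instead of keeping a manual counter.
pairs = list(zip(columns[0::2], columns[1::2]))
for first, second in pairs:
    print(first, second)
```

Each `(first, second)` pair could be handed straight to `plt.plot`, which would replace the `while` loop and its counter.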
Some things that would be nice to implement, but I haven’t figured out yet:
- More differentiation of the appearance for each series’ plot
- Automatic generation of a legend using headers for the X-values from the initial file (or else requested from the user at run-time if not in the file)
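Both wishlist items map onto standard matplotlib features: a `marker` argument per series, and a `label` argument that `legend()` picks up. This is a rough Python 3 sketch, not a drop-in patch for the script. The series names and values are invented; in practice the names would come from the header row or from a prompt.

```python
import itertools
import matplotlib
matplotlib.use("Agg")  # no display needed; remove this line to use plt.show()
import matplotlib.pyplot as plt

# Hypothetical data: series name -> the two columns of one pair.
series = {
    "run A": ([1, 2, 3], [10, 20, 30]),
    "run B": ([1, 2, 3], [12, 18, 33]),
}

markers = itertools.cycle(["o", "s", "^", "D"])  # one marker shape per series
fig, ax = plt.subplots()
for name, (xs, ys) in series.items():
    # label= is the text the legend will show for this series
    ax.plot(xs, ys, linestyle="none", marker=next(markers), label=name)
ax.legend()

legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
print(legend_labels)
```

Cycling through a short list of markers keeps the series distinguishable without hard-coding a style for each one.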
fitzgeraldsteele said
Maybe csv.DictReader is what you’re looking for? http://docs.python.org/library/csv.html#csv.DictReader
Also, you could save a lot of code between lines 26-52 if you look into generator functions. Once you grok generators (and list comprehensions), you’ll wonder how you ever lived without them. Those should address the nastiness you’re feeling.
This is where I learned about them:
http://www.dabeaz.com/generators/ (The presentation slides are very good)
Here’s an example of how I used them to parse a csv file:
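The commenter's example is not reproduced in this copy of the thread. As a stand-in (this is not fitzgeraldsteele's code), here is a minimal Python 3 sketch of the idea: a generator function that yields one parsed row at a time and silently skips anything that is not numeric. The name `numeric_rows` and the sample data are assumptions made for illustration.

```python
import csv
import io

def numeric_rows(lines):
    """Yield each CSV row as a list of floats, skipping rows that don't parse."""
    for row in csv.reader(lines):
        try:
            yield [float(cell) for cell in row]
        except ValueError:
            continue  # e.g. a header row or a malformed line

# Hypothetical data with a header row and one bad line.
sample = io.StringIO("time,energy\n0,-1.5\n1,-1.7\noops,\n2,-1.8\n")

# zip(*...) turns the stream of rows into one tuple per column.
times, energies = zip(*numeric_rows(sample))
print(times)
print(energies)
```

Because the generator drops non-numeric rows, a header line needs no special handling here, although that also means real data errors go unreported.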
jwinget said
I’ve tried dealing with DictReader on several occasions. Either the documentation is lacking or I’m missing something fundamental, because I could never wrangle it into doing the things I was interested in.
Thanks for the tip on generator functions, I will definitely look into those!
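For anyone stuck at the same point with DictReader, here is a minimal Python 3 sketch of one way to use it for this task: collecting each column into a list keyed by its header. It assumes the file has a header row and that every cell is numeric; the sample data is invented.

```python
import csv
import io
from collections import defaultdict

# Hypothetical two-column file with a header row.
sample = io.StringIO("time,energy\n0,-1.5\n1,-1.7\n2,-1.8\n")

columns = defaultdict(list)
# DictReader takes its keys from the first row, so each row is {header: cell}.
for row in csv.DictReader(sample):
    for header, cell in row.items():
        columns[header].append(float(cell))

print(columns["time"])
print(columns["energy"])
```

The headers double as natural labels for a legend or for the axes.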
achemistfollowscode said
Thank you for posting this up! It was really helpful. I am still not sure of a few things, and I hard-coded labels for the plot so it isn’t dynamic in that respect, but very useful for me to monitor energy changes over time on running jobs. Thanks!
jmayer said
It’s much easier if you use the pandas library to read the CSV data into a DataFrame and get it into the shape you want (there’s an easy transpose method, for example), then plot the series from the DataFrame…
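A minimal Python 3 sketch of that suggestion, assuming pandas is installed and the file has a header row. The sample data is invented, and the plotting call is left as a comment because it needs matplotlib and a display.

```python
import io
import pandas as pd

# Hypothetical two-column file with a header row.
sample = io.StringIO("time,energy\n0,-1.5\n1,-1.7\n2,-1.8\n")

df = pd.read_csv(sample)        # the header row becomes the column names
print(df["energy"].tolist())    # one column as a plain list
print(df.T.shape)               # .T is the transpose: columns become rows

# Plotting goes through matplotlib, for example:
# df.plot(x="time", y="energy", style="o")
```

With the columns already named, there is no need for the dictionary bookkeeping from the original script.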
Max said
I added a few lines that parse x and y labels if they’re in the csv file and if not it prompts you for them.
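#!/usr/bin/python
import csv
import sys
import numpy as np
import matplotlib.pyplot as plt

f = open(sys.argv[1], 'r')
dialect = csv.Sniffer().has_header(f.read(1024))
f.seek(0)
reader = csv.reader(f)

axes_dict = {}
axes_dict['labels'] = []
series_dict = {}
r = 0
for row in reader:
    i = 0
    for column in row:
        i += 1
        r += 1
        # With a header row, the first two cells become the axis labels
        if dialect is True and r < 3:
            axes_dict['labels'].append(column)
            continue
        if series_dict.has_key(i):
            try:
                series_dict[i].append(float(column))
            except ValueError:
                pass
        else:
            series_dict[i] = [float(column)]

num_cols = len(series_dict)
i = 1
# NOTE: the text between "while i" and the X-axis prompt was lost in this
# copy of the comment. The loop below is restored from the script in the
# post; the "if" test guarding the prompts is a guess.
while i < num_cols:
    plt.plot(series_dict[i], series_dict[i+1], 'o')
    i += 2

if not axes_dict['labels']:
    axes_dict['labels'].append(raw_input('X-axis label> '))
    axes_dict['labels'].append(raw_input('Y-axis label> '))

xaxis_label = axes_dict['labels'][0]
yaxis_label = axes_dict['labels'][1]
plt.ylabel(yaxis_label)
plt.xlabel(xaxis_label)
plt.show()
f.close()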