Complete Matplotlib Guide

Learn data visualization with Matplotlib, from the basics to advanced techniques, through tutorials, real-world examples, and hands-on practice.

Introduction to Matplotlib

Matplotlib is the most widely used plotting library in Python. It provides a comprehensive set of tools for creating static, animated, and interactive visualizations, from quick exploratory plots to publication-quality figures.

Why Matplotlib?

  • Complete control over every visual element
  • Publication-quality plots ready for scientific journals
  • Extensive customization options
  • Foundation for other libraries like Seaborn
  • Industry standard for Python visualization

Interactive Demo: Your First Plot

Let's start with a simple line plot of a sine wave.

import matplotlib.pyplot as plt
import numpy as np

# Create simple data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2, label='sin(x)')

# Add labels and title
plt.xlabel('X Values', fontsize=12)
plt.ylabel('Y Values', fontsize=12)
plt.title('Your First Matplotlib Plot!', fontsize=16, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.legend()

# Show the plot
plt.show()

Essential Plot Types

Matplotlib offers a wide variety of plot types, each suited for different data visualization needs. Understanding when and how to use each type is crucial for effective data communication.

Examples

Multiple Line Plot with Styling

import matplotlib.pyplot as plt
import numpy as np

# Create data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
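y3 = np.tan(x)
y3[np.abs(y3) > 5] = np.nan  # mask values near tan's asymptotes so no vertical lines are drawn across them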

# Create plot
plt.figure(figsize=(12, 8))
plt.plot(x, y1, label='sin(x)', color='#1f77b4', linewidth=2.5)
plt.plot(x, y2, label='cos(x)', color='#ff7f0e', linewidth=2.5, linestyle='--')
plt.plot(x, y3, label='tan(x)', color='#2ca02c', linewidth=2, alpha=0.7)

# Styling
plt.title('Trigonometric Functions Comparison', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('X Values', fontsize=12, fontweight='bold')
plt.ylabel('Y Values', fontsize=12, fontweight='bold')
plt.legend(fontsize=11, frameon=True, fancybox=True, shadow=True)
plt.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
plt.ylim(-5, 5)

# Add annotations
plt.annotate('sin(π/2) = 1', xy=(np.pi/2, 1), xytext=(4, 3),
             arrowprops=dict(arrowstyle='->', color='blue', alpha=0.7),
             fontsize=10, ha='center')

plt.tight_layout()
plt.show()
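
The example above covers only line plots. As a minimal sketch of three other common plot types (scatter, bar, and histogram), the following uses made-up data; the variable names and values are illustrative assumptions, not part of the original tutorial.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic data, for illustration only
rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = 0.5 * x + rng.normal(scale=0.5, size=200)
categories = ['A', 'B', 'C', 'D']
counts = [12, 30, 21, 17]

fig, (ax_scatter, ax_bar, ax_hist) = plt.subplots(1, 3, figsize=(15, 4))

# Scatter: relationship between two numeric variables
ax_scatter.scatter(x, y, s=15, alpha=0.6)
ax_scatter.set_title('Scatter')

# Bar: one value per category
ax_bar.bar(categories, counts, color='#ff7f0e')
ax_bar.set_title('Bar')

# Histogram: distribution of a single variable
n_per_bin, bin_edges, _ = ax_hist.hist(x, bins=20, color='#2ca02c', edgecolor='black')
ax_hist.set_title('Histogram')

fig.tight_layout()
```

Call plt.show() afterwards to display the figure.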

Real-World Applications

Matplotlib is used across various industries and research fields. Understanding real-world applications helps you choose the right visualization for your specific needs.

Financial Analysis

Stock price trends, portfolio performance, risk analysis, and market research visualization.

• Stock price charts with volume
• Portfolio allocation pie charts
• Risk-return scatter plots
• Time series forecasting

Scientific Research

Publication-ready plots for research papers, experimental data visualization, and statistical analysis.

• Experimental data plots
• Error bar visualization
• Statistical distribution plots
• Correlation matrices
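
As a rough sketch of error bar visualization, the following plots the mean of repeated measurements with the standard error of the mean as the error bar. The measurements are randomly generated and the condition and trial counts are assumptions chosen for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical measurements: 5 repeated trials at each of 6 conditions
rng = np.random.default_rng(1)
conditions = np.arange(1, 7)
trials = 2.0 * conditions[:, None] + rng.normal(scale=1.0, size=(6, 5))

means = trials.mean(axis=1)
# Standard error of the mean: sample standard deviation / sqrt(number of trials)
sem = trials.std(axis=1, ddof=1) / np.sqrt(trials.shape[1])

fig, ax = plt.subplots(figsize=(8, 5))
# fmt='o-' draws markers joined by a line; capsize adds caps to the error bars
ax.errorbar(conditions, means, yerr=sem, fmt='o-', capsize=4, label='mean ± SEM')
ax.set_xlabel('Condition')
ax.set_ylabel('Measured value')
ax.legend()
```

Call plt.show() afterwards to display the figure.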

Business Intelligence

Sales performance, customer analytics, market trends, and operational metrics visualization.

• Sales trend analysis
• Customer segmentation
• Performance dashboards
• Market share analysis

Healthcare Analytics

Patient data visualization, clinical trial results, epidemiological studies, and medical research.

• Patient vital signs monitoring
• Clinical trial results
• Disease spread modeling
• Treatment effectiveness

Real-World Example: Stock Price Analysis

Let's build a stock price visualization that combines a price chart with moving averages and crossover signals, a volume chart, and a histogram of daily price changes. The data below is randomly generated for illustration.

import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta

# Generate sample stock data
dates = [datetime(2024, 1, 1) + timedelta(days=i) for i in range(100)]
np.random.seed(42)
price_changes = np.random.randn(100) * 2 + 0.1
prices = 100 + np.cumsum(price_changes)
volumes = np.random.randint(1000000, 5000000, 100)

# Create subplots
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 12), 
                                     gridspec_kw={'height_ratios': [3, 1, 1]})

# 1. Price chart with moving averages
ax1.plot(dates, prices, label='Stock Price', color='#1f77b4', linewidth=2)

# Add moving averages (until a full window is available, the mean covers only the points seen so far)
ma_20 = [np.mean(prices[max(0, i-19):i+1]) for i in range(len(prices))]
ma_50 = [np.mean(prices[max(0, i-49):i+1]) for i in range(len(prices))]

ax1.plot(dates, ma_20, label='MA 20', color='orange', alpha=0.8, linewidth=1.5)
ax1.plot(dates, ma_50, label='MA 50', color='red', alpha=0.8, linewidth=1.5)

# Highlight crossovers: buy when MA 20 rises above MA 50, sell when it falls below
buy_signals = [i for i in range(20, len(prices)) if ma_20[i] > ma_50[i] and ma_20[i-1] <= ma_50[i-1]]
sell_signals = [i for i in range(20, len(prices)) if ma_20[i] < ma_50[i] and ma_20[i-1] >= ma_50[i-1]]

for signal in buy_signals:
    ax1.scatter(dates[signal], prices[signal], color='green', s=100, marker='^', zorder=5)
for signal in sell_signals:
    ax1.scatter(dates[signal], prices[signal], color='red', s=100, marker='v', zorder=5)

ax1.set_title('Stock Price Analysis with Technical Indicators', fontsize=16, fontweight='bold')
ax1.set_ylabel('Price ($)', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Volume chart
ax2.bar(dates, volumes, color='gray', alpha=0.6, width=0.8)
ax2.set_ylabel('Volume', fontweight='bold')
ax2.grid(True, alpha=0.3)

# 3. Price change histogram
price_changes_pct = np.diff(prices) / prices[:-1] * 100
ax3.hist(price_changes_pct, bins=20, alpha=0.7, color='lightblue', edgecolor='black')
ax3.axvline(np.mean(price_changes_pct), color='red', linestyle='--', linewidth=2, 
           label=f'Mean: {np.mean(price_changes_pct):.2f}%')
ax3.set_xlabel('Daily Price Change (%)', fontweight='bold')
ax3.set_ylabel('Frequency', fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Best Practices for Real-World Applications

  • Always label your axes clearly
  • Use appropriate chart types for your data
  • Consider your audience when choosing colors
  • Include legends for multiple data series
  • Test your visualizations on different devices
  • Use consistent styling across related charts
  • Include source references when applicable
  • Optimize for both digital and print media

Subplots & Layouts

Subplots allow you to create multiple plots within a single figure. This is essential for comparing different datasets or showing multiple aspects of your data in one view.

Subplot Methods

plt.subplot(rows, cols, index)

Simple grid layout - specify position by index

plt.subplots()

Create figure and axes together - more flexible

GridSpec

Complex, irregular layouts with custom sizing
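
The first two methods can be compared side by side in a short sketch like the one below (GridSpec is covered by the larger example later in this section). The figure sizes and plotted functions are arbitrary choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

# plt.subplot(rows, cols, index): index counts from 1, left to right, top to bottom
fig_a = plt.figure(figsize=(8, 3))
ax_left = plt.subplot(1, 2, 1)
ax_left.plot(x, np.sin(x))
ax_right = plt.subplot(1, 2, 2)
ax_right.plot(x, np.cos(x))

# plt.subplots(): returns the figure and a NumPy array of axes in one call
fig_b, axes = plt.subplots(2, 2, figsize=(8, 6))
axes[0, 0].plot(x, np.sin(x))
axes[1, 1].plot(x, np.cos(x))
```

With plt.subplots() the axes are addressed by row and column, which tends to be easier to loop over than the 1-based index of plt.subplot().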

Layout Best Practices

• Use consistent scales for meaningful comparison
• Maintain aspect ratios when possible
• Use tight_layout() to prevent overlap
• Consider color blindness in design
• Group related plots together
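
A few of these practices can be sketched together: a shared y axis for consistent scales, tight_layout() to prevent overlap, and a colorblind-friendly built-in style. The signals and panel names below are made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 200)
# Hypothetical signals with different amplitudes
signals = {
    'slow': np.sin(x),
    'fast': 0.5 * np.sin(3 * x),
    'damped': np.exp(-x / 4) * np.sin(2 * x),
}

# 'tableau-colorblind10' is a built-in style with a colorblind-friendly color cycle
with plt.style.context('tableau-colorblind10'):
    # sharey=True gives every panel the same y scale, so amplitudes are comparable
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
    for ax, (name, y) in zip(axes, signals.items()):
        ax.plot(x, y)
        ax.set_title(name)
    axes[0].set_ylabel('Amplitude')
    fig.tight_layout()  # keep titles and labels from overlapping
```

Without sharey=True each panel would autoscale independently, and the smaller 'fast' signal would look as large as the others.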

Advanced Subplot Layout

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

# Create complex layout with GridSpec
fig = plt.figure(figsize=(16, 12))
gs = GridSpec(4, 4, figure=fig, hspace=0.3, wspace=0.3)

# Main plot (spans 2x2)
ax_main = fig.add_subplot(gs[0:2, 0:2])

# Side plots
ax_right = fig.add_subplot(gs[0:2, 2])
ax_bottom = fig.add_subplot(gs[0:2, 3])
ax_corner = fig.add_subplot(gs[2, 0:2])
ax_bottom2 = fig.add_subplot(gs[2, 2:])
ax_final = fig.add_subplot(gs[3, :])

# Generate data for main plot
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.exp(-x/3) * np.sin(2*x)

# Main plot with multiple functions
ax_main.plot(x, y1, label='sin(x)', color='#1f77b4', linewidth=2)
ax_main.plot(x, y2, label='cos(x)', color='#ff7f0e', linewidth=2)
ax_main.plot(x, y3, label='exp(-x/3)*sin(2x)', color='#2ca02c', linewidth=2)
ax_main.set_title('Main Plot: Function Comparison', fontweight='bold', fontsize=14)
ax_main.legend()
ax_main.grid(True, alpha=0.3)

# Right plot - histogram of sin(x)
ax_right.hist(y1, bins=20, color='lightblue', alpha=0.7, orientation='horizontal')
ax_right.set_title('Distribution', fontweight='bold', fontsize=12)
ax_right.set_ylabel('sin(x) values')

# Bottom plot - bar chart
categories = ['A', 'B', 'C', 'D', 'E']
values = [20, 35, 30, 25, 40]
bars = ax_bottom.bar(categories, values, color='lightcoral', alpha=0.7)
ax_bottom.set_title('Category Analysis', fontweight='bold', fontsize=12)

# Add value labels
for bar in bars:
    height = bar.get_height()
    ax_bottom.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                   f'{height}', ha='center', va='bottom', fontweight='bold')

# Corner plot - scatter
scatter_x = np.random.randn(50)
scatter_y = np.random.randn(50)
colors = np.random.rand(50)
ax_corner.scatter(scatter_x, scatter_y, c=colors, alpha=0.6, s=60)
ax_corner.set_title('Random Scatter', fontweight='bold', fontsize=12)

# Bottom 2 - pie chart
pie_data = [30, 25, 20, 15, 10]
pie_labels = ['Q1', 'Q2', 'Q3', 'Q4', 'Other']
ax_bottom2.pie(pie_data, labels=pie_labels, autopct='%1.1f%%', startangle=90)
ax_bottom2.set_title('Quarterly Distribution', fontweight='bold', fontsize=12)

# Final plot - time series
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
sales_data = np.random.randint(100, 500, 12)
ax_final.plot(months, sales_data, marker='o', linewidth=2, markersize=8, color='purple')
ax_final.fill_between(months, sales_data, alpha=0.3, color='purple')
ax_final.set_title('Monthly Sales Trend', fontweight='bold', fontsize=12)
ax_final.set_ylabel('Sales ($K)')
ax_final.grid(True, alpha=0.3)

plt.suptitle('Advanced Subplot Layout with GridSpec', fontsize=16, fontweight='bold')
plt.show()

Advanced Topics

3D Plotting

Create stunning 3D visualizations for complex data analysis

• Surface plots
• Wireframe plots
• 3D scatter plots
• Contour plots in 3D

Animations

Bring your data to life with smooth animations

• Function animations
• Artist animations
• Interactive controls
• Export to video formats

Specialized Plots

Advanced visualization techniques for specific use cases

• Heatmaps and color maps
• Contour and filled contour plots
• Quiver and stream plots
• Polar coordinate plots
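
As a minimal sketch of the four plot families listed above, the following draws a heatmap, a filled contour plot, a quiver plot, and a polar plot in one figure. The scalar field, the vector field, and the polar curve are arbitrary functions chosen for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2, 2, 40)
y = np.linspace(-2, 2, 40)
X, Y = np.meshgrid(x, y)
Z = np.exp(-(X**2 + Y**2))  # a Gaussian bump as sample data

fig = plt.figure(figsize=(12, 10))

# Heatmap: imshow maps a 2D array to colors
ax1 = fig.add_subplot(2, 2, 1)
im = ax1.imshow(Z, extent=[-2, 2, -2, 2], origin='lower', cmap='viridis')
fig.colorbar(im, ax=ax1)
ax1.set_title('Heatmap (imshow)')

# Filled contours of the same field
ax2 = fig.add_subplot(2, 2, 2)
cs = ax2.contourf(X, Y, Z, levels=10, cmap='plasma')
fig.colorbar(cs, ax=ax2)
ax2.set_title('Filled contour')

# Quiver: arrows for a vector field (a simple rotation, U = -Y, V = X), subsampled
ax3 = fig.add_subplot(2, 2, 3)
step = 4
ax3.quiver(X[::step, ::step], Y[::step, ::step],
           -Y[::step, ::step], X[::step, ::step])
ax3.set_title('Quiver')

# Polar: angle on the circular axis, radius outward
ax4 = fig.add_subplot(2, 2, 4, projection='polar')
theta = np.linspace(0, 2 * np.pi, 200)
ax4.plot(theta, 1 + 0.5 * np.cos(3 * theta))
ax4.set_title('Polar')

fig.tight_layout()
```

Call plt.show() afterwards to display the figure.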

Export & Integration

Save and share your visualizations in various formats

• Multiple formats (PNG, PDF, SVG)
• High DPI for publication
• LaTeX integration
• Web embedding capabilities
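
Saving one figure in several formats might look like the sketch below. The output directory and file name are hypothetical; a temporary directory is used here only to keep the example self-contained.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, np.sin(x))
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')

# Hypothetical output location
out_dir = tempfile.mkdtemp()
saved = []
for ext in ('png', 'pdf', 'svg'):
    path = os.path.join(out_dir, f'figure.{ext}')
    # dpi matters for raster output (PNG); bbox_inches='tight' trims surrounding whitespace
    fig.savefig(path, dpi=300, bbox_inches='tight')
    saved.append(path)

plt.close(fig)
```

The format is inferred from the file extension. PDF and SVG are vector formats, so they scale without loss in print, while PNG quality depends on the dpi setting.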

3D Plotting Examples

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2 to register the '3d' projection

# Create 3D plot
fig = plt.figure(figsize=(14, 10))

# 1. 3D Surface plot
ax1 = fig.add_subplot(2, 2, 1, projection='3d')
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))

surf = ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8, 
                       linewidth=0, antialiased=True)
ax1.set_title('3D Surface: sin(√(x²+y²))', fontweight='bold')
fig.colorbar(surf, ax=ax1, shrink=0.5, aspect=10)

# 2. 3D Scatter plot
ax2 = fig.add_subplot(2, 2, 2, projection='3d')
n = 100
xs = np.random.randn(n)
ys = np.random.randn(n)
zs = xs**2 + ys**2 + np.random.randn(n)
colors = zs
cmap = plt.cm.viridis

scatter = ax2.scatter(xs, ys, zs, c=colors, s=60, alpha=0.6, cmap=cmap)
ax2.set_title('3D Scatter: Random Points', fontweight='bold')
fig.colorbar(scatter, ax=ax2, shrink=0.5, aspect=10)

# 3. Wireframe plot
ax3 = fig.add_subplot(2, 2, 3, projection='3d')
X2, Y2 = np.meshgrid(np.linspace(-3, 3, 20), np.linspace(-3, 3, 20))
Z2 = np.exp(-(X2**2 + Y2**2))
ax3.plot_wireframe(X2, Y2, Z2, color='blue', alpha=0.7)
ax3.set_title('Wireframe: Gaussian', fontweight='bold')

# 4. Contour plot in 3D
ax4 = fig.add_subplot(2, 2, 4, projection='3d')
X3, Y3 = np.meshgrid(np.linspace(-2, 2, 30), np.linspace(-2, 2, 30))
Z3 = X3 * np.exp(-X3**2 - Y3**2)
ax4.contour3D(X3, Y3, Z3, 50, cmap='binary')
ax4.set_title('3D Contour: x*e^(-x²-y²)', fontweight='bold')

plt.tight_layout()
plt.show()

Animation Performance Tips

  • Use blitting for better performance
  • Limit the number of animated elements
  • Optimize data processing in each frame
  • Consider frame rate vs. smoothness trade-offs
  • Use appropriate file formats for export
  • Test on target hardware for performance
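
As a minimal sketch of blitting, the following animates a travelling sine wave with FuncAnimation and blit=True. The frame count, interval, and phase step are arbitrary values chosen for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

x = np.linspace(0, 2 * np.pi, 200)
fig, ax = plt.subplots()
line, = ax.plot(x, np.sin(x))
ax.set_ylim(-1.1, 1.1)  # fix the limits so the axes never need redrawing

def update(frame):
    # Only the line's data changes per frame; axes, ticks, and labels stay fixed
    line.set_ydata(np.sin(x + 0.1 * frame))
    return (line,)  # with blit=True, return the artists that were modified

# blit=True redraws only the returned artists instead of the whole figure;
# interval=33 ms is roughly 30 frames per second
anim = FuncAnimation(fig, update, frames=60, interval=33, blit=True)
```

Keep a reference to the animation object (here anim) for as long as it should run, then call plt.show() to display it.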

Performance Optimization

When working with large datasets or creating complex visualizations, performance becomes crucial. These optimization techniques will help you create faster, more efficient plots.

Speed Optimization

  • Use vectorized operations instead of loops
  • Downsample large datasets when possible
  • Choose appropriate plot types for data size
  • Use blitting for animations
  • Optimize marker styles and sizes
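
To illustrate the first point, here is a sketch that computes a 20-point moving average with a Python loop and with a single vectorized NumPy call. The price series is synthetic, and unlike the stock example earlier, both versions only produce values where a full window is available.

```python
import numpy as np

# Synthetic price series, for illustration only
rng = np.random.default_rng(42)
prices = 100 + np.cumsum(rng.normal(0.1, 2.0, size=1000))
window = 20

# Loop version: one mean computation per output point
ma_loop = np.array([prices[i - window + 1:i + 1].mean()
                    for i in range(window - 1, len(prices))])

# Vectorized version: convolve with a uniform kernel;
# mode='valid' keeps only positions where the full window fits
kernel = np.ones(window) / window
ma_vec = np.convolve(prices, kernel, mode='valid')
```

Both arrays hold the same values; the vectorized form moves the per-point work out of the Python interpreter and into NumPy.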

Memory Management

  • Clear figure references when done
  • Use context managers for file operations
  • Batch process large datasets
  • Monitor memory usage with large arrays
  • Use appropriate data types
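
Releasing figures when generating many plots in a loop could be sketched as follows. The in-memory buffers stand in for output files and are an assumption made to keep the example self-contained.

```python
import io

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
n_before = len(plt.get_fignums())  # figures already open before the loop

buffers = []
for k in range(1, 6):
    fig, ax = plt.subplots()
    ax.plot(x, np.sin(k * x))
    buf = io.BytesIO()  # in-memory target; a file path works the same way
    fig.savefig(buf, format='png')
    buffers.append(buf)
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive

n_after = len(plt.get_fignums())  # same as n_before: nothing is left open
```

Without plt.close(fig), pyplot retains every figure created in the loop, and memory use grows with each iteration.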

Common Pitfalls

  • Plotting millions of data points unnecessarily
  • Using scatter plots for very large datasets
  • Creating animations without optimization
  • Forgetting to close figures in loops
  • Using inappropriate chart types
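
One possible alternative to a scatter plot for very large datasets is to aggregate the points into bins before drawing, for example with hexbin. The sketch below uses synthetic data; the sample size and gridsize are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic correlated data, for illustration only
rng = np.random.default_rng(0)
n = 200_000
x = rng.normal(size=n)
y = x + rng.normal(scale=0.8, size=n)

fig, ax = plt.subplots(figsize=(7, 6))
# hexbin aggregates points into hexagonal bins, so the amount of drawing
# depends on gridsize rather than on n; mincnt=1 hides empty bins
hb = ax.hexbin(x, y, gridsize=60, cmap='viridis', mincnt=1)
fig.colorbar(hb, ax=ax, label='points per bin')
ax.set_xlabel('x')
ax.set_ylabel('y')
```

The color of each hexagon shows how many points fall inside it, which also reveals density that overlapping scatter markers would hide.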

Performance Benchmarks

These figures are rough rules of thumb; actual limits depend on hardware, backend, and figure complexity.

  • Line plots: ~10,000 points optimal
  • Scatter plots: ~100,000 points max
  • Histograms: ~1M points with proper binning
  • Animations: 30-60 FPS target

Performance Comparison

import matplotlib.pyplot as plt
import numpy as np
import time

# Performance comparison: different approaches
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

# Generate large dataset
n = 10000
x = np.linspace(0, 10, n)
y = np.sin(x) + 0.1 * np.random.randn(n)

# Note: these timers measure only how long Matplotlib takes to create the artists.
# Rendering happens later (at plt.show() or savefig) and is not included.

# 1. Regular plot (baseline)
start_time = time.perf_counter()
ax1.plot(x, y, 'b-', linewidth=1)
ax1.set_title('Regular Plot\n(Baseline)', fontweight='bold')
regular_time = time.perf_counter() - start_time

# 2. Downsampled plot
step = 10
start_time = time.perf_counter()
ax2.plot(x[::step], y[::step], 'r-', linewidth=1, alpha=0.7)
ax2.set_title(f'Downsampled Plot\n(every {step}th point)', fontweight='bold')
downsample_time = time.perf_counter() - start_time

# 3. Scatter plot (for comparison)
start_time = time.perf_counter()
ax3.scatter(x[::50], y[::50], alpha=0.6, s=1)
ax3.set_title('Scatter Plot\n(sparse)', fontweight='bold')
scatter_time = time.perf_counter() - start_time

# 4. Optimized plot with markers
start_time = time.perf_counter()
ax4.plot(x[::20], y[::20], 'g-', linewidth=2, marker='o', 
         markersize=3, markevery=5, alpha=0.8)
ax4.set_title('Optimized Plot\n(markers every 5th)', fontweight='bold')
optimized_time = time.perf_counter() - start_time

# Add timing summary
performance_text = f"""
Artist creation time (base dataset: {n:,} points; rendering not included):
• Regular Plot: {regular_time:.3f}s
• Downsampled: {downsample_time:.3f}s
• Scatter Plot: {scatter_time:.3f}s
• Optimized: {optimized_time:.3f}s

Recommendations:
• Use downsampling for large datasets
• Consider scatter plots for sparse data
• Optimize marker placement
• Profile your specific use case
"""

plt.figtext(0.02, 0.02, performance_text, fontsize=10, va='bottom',
           bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8))

plt.suptitle('Matplotlib Performance Comparison', fontsize=16, fontweight='bold')
# Reserve the bottom of the figure for the text box so it does not cover the lower plots
plt.tight_layout(rect=[0, 0.25, 1, 0.96])
plt.show()
