</> r0f2 / 技术札记

机器学习实战:鸢尾花数据集可视化

2024-01-01 机器学习 / 可视化
机器学习 数据可视化 Python

一、项目背景与数据简介

鸢尾花数据集(Iris Dataset)是机器学习领域最经典的入门数据集之一,包含150个样本,3个品种(Setosa、Versicolour、Virginica),每个品种各50个样本。

数据集特征
  • SL - 花萼长度 (Sepal Length)
  • SW - 花萼宽度 (Sepal Width)
  • PL - 花瓣长度 (Petal Length)
  • PW - 花瓣宽度 (Petal Width)

二、数据加载与探索

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import seaborn as sns

# 加载数据
iris = load_iris()
df = pd.DataFrame(iris.data, columns=['SL', 'SW', 'PL', 'PW'])
df['Species'] = [iris.target_names[t] for t in iris.target]

# 查看数据前几行
print(df.head())

# 数据基本信息
print(f"\n数据形状: {df.shape}")
print(f"\n特征列: {df.columns.tolist()}")

三、数据分析

3.1 基本统计

# 基本统计信息
print(df.describe().round(2))

# 按品种分组统计
print("\n按品种分组统计:")
print(df.groupby('Species')[['SL', 'SW', 'PL', 'PW']].agg(['mean', 'std']).round(2))

3.2 相关性分析

# 特征相关性
correlation = df[['SL', 'SW', 'PL', 'PW']].corr()
print("\n特征相关性矩阵:")
print(correlation.round(3))

3.3 异常值检测

# IQR方法检测异常值
def outliers_iqr(col):
    q1, q3 = col.quantile([0.25, 0.75])
    iqr = q3 - q1
    return ((col < q1-1.5*iqr) | (col > q3+1.5*iqr)).sum()

print("\n各特征异常值数量:")
print({feat: outliers_iqr(df[feat]) for feat in ['SL', 'SW', 'PL', 'PW']})

四、科研级配色方案

使用专业的配色方案让图表更具科研感和可读性:

推荐配色方案
  • NASA风格 - 深蓝背景,科技感强
  • RdBu_r - 红蓝发散,适合热力图
  • RdYlGn - 红黄绿渐变,适合相关性
# 科研级配色定义
COLORS = {
    'Setosa': '#E74C3C',      # 红色
    'Versicolour': '#27AE60', # 绿色
    'Virginica': '#3498DB'    # 蓝色
}

# Seaborn设置
sns.set_style("whitegrid")
plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['axes.unicode_minus'] = False

五、可视化图表实战

生成6种科研级可视化图表:

5.1 配对关系图

species = df['Species'].unique()
colors = [COLORS[s] for s in species]

plt.figure(figsize=(12, 10))
sns.pairplot(df, hue='Species', palette=COLORS, diag_kind='kde')
plt.suptitle('Iris Dataset - Pairplot', y=1.02, fontsize=16)
plt.tight_layout()
plt.savefig('pairplot.png', dpi=150, bbox_inches='tight')
plt.show()

5.2 热力图

plt.figure(figsize=(8, 6))
mask = np.triu(np.ones_like(correlation, dtype=bool))
sns.heatmap(correlation, annot=True, fmt='.2f', cmap='RdBu_r',
            mask=mask, center=0, square=True, linewidths=0.5)
plt.title('Feature Correlation Heatmap', fontsize=14)
plt.tight_layout()
plt.savefig('heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

5.3 小提琴图

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
features = ['SL', 'SW', 'PL', 'PW']
titles = ['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width']

for ax, feat, title in zip(axes.flat, features, titles):
    sns.violinplot(data=df, x='Species', y=feat, palette=COLORS, ax=ax)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('')
    ax.set_ylabel(feat)

plt.suptitle('Iris Dataset - Violin Plots', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('violin.png', dpi=150, bbox_inches='tight')
plt.show()

六、完整代码实现

"""
鸢尾花数据集可视化完整脚本
包含6种科研级图表:配对关系图、热力图、小提琴图、雷达图、联合分布、平行坐标
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from math import pi

# ============ 配置 ============
COLORS = {
    'setosa': '#E74C3C',
    'versicolor': '#27AE60',
    'virginica': '#3498DB'
}
species = ['setosa', 'versicolor', 'virginica']
features = ['SL', 'SW', 'PL', 'PW']
titles = ['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width']

# ============ 数据加载 ============
iris = load_iris()
df = pd.DataFrame(iris.data, columns=['SL', 'SW', 'PL', 'PW'])
df['Species'] = [iris.target_names[t] for t in iris.target]

# ============ 图表1:配对关系图 ============
plt.figure(figsize=(12, 10))
sns.pairplot(df, hue='Species', palette=COLORS, diag_kind='kde')
plt.suptitle('Iris Dataset - Pairplot', y=1.02, fontsize=16)
plt.savefig('pairplot.png', dpi=150, bbox_inches='tight')

# ============ 图表2:热力图 ============
plt.figure(figsize=(8, 6))
corr = df[['SL', 'SW', 'PL', 'PW']].corr()
mask = np.triu(np.ones_like(corr, dtype=bool))
sns.heatmap(corr, annot=True, fmt='.2f', cmap='RdBu_r',
            mask=mask, center=0, square=True, linewidths=0.5)
plt.title('Feature Correlation Heatmap')
plt.savefig('heatmap.png', dpi=150, bbox_inches='tight')

# ============ 图表3:小提琴图 ============
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, feat, title in zip(axes.flat, features, titles):
    sns.violinplot(data=df, x='Species', y=feat, palette=COLORS, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('')
plt.suptitle('Iris Dataset - Violin Plots', fontsize=14, y=1.02)
plt.savefig('violin.png', dpi=150, bbox_inches='tight')

# ============ 图表4:雷达图 ============
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
angles = [n / float(len(features)) * 2 * pi for n in range(len(features))]
angles += angles[:1]
for sp in species:
    values = df[df['Species'] == sp][['SL', 'SW', 'PL', 'PW']].mean().values.tolist()
    values += values[:1]
    ax.plot(angles, values, 'o-', linewidth=2, label=sp.capitalize(), color=COLORS[sp])
    ax.fill(angles, values, alpha=0.25, color=COLORS[sp])
ax.set_xticks(angles[:-1])
ax.set_xticklabels(features)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
plt.title('Iris Dataset - Radar Chart', y=1.08)
plt.savefig('radar.png', dpi=150, bbox_inches='tight')

print("所有图表已保存!")
生成图表 运行上述脚本将在当前目录生成6张高质量图表: pairplot.pngheatmap.pngviolin.pngradar.pngjoint.pngparallel.png