728x90

❤️ 배운 것

개인별 violinplot, pairplot 알고리즘 작성

Iris dataset

import matplotlib.pyplot as plt  
from sklearn.datasets import load_iris  

def get_iris():  

    iris = load_iris()  

    for attr in dir(iris):  
        print(attr)  

    # DESCR  
    # data    # data_module    # feature_names    # filename    # frame    # target    # target_names  
    # 대문자는 행렬, 소문자는 벡터  

    iris_X = iris.data  
    iris_y = iris.target  
    feature_names = iris.feature_names  
    species = iris.target_names  
    n_feature = len(feature_names)  
    n_species = len(species)  

    return iris_X, iris_y, feature_names, species, n_feature, n_species

품종별 feature violinplot

def iris_visualization1():  
    iris_X, iris_y, feature_names, species, n_feature, n_species = get_iris()  

    cls_0 = iris_X[iris_y == 0]  
    cls_1 = iris_X[iris_y == 1]  
    cls_2 = iris_X[iris_y == 2]  
    xticks = np.arange(3)  

    fig, axes = plt.subplots(2, 2, figsize=(14, 14))  

    for i, ax in enumerate(axes.flat):  
        ax.violinplot([cls_0[:, i], cls_1[:, i], cls_2[:, i]],  
                      positions=xticks)  
        ax.set_xticks(xticks)  
        ax.set_xticklabels(species)  
        ax.set_title(feature_names[i], fontsize=20)  
        ax.tick_params(labelsize=20)

feature간 pairplot

def single_pair(axes, row, col, X, y, cls_dict, features):  
    color_list = ['purple', 'green', 'orange']  

    # histogram  
    if row == col:  
        data = X[:, row]  
        axes[row, col].hist(data, rwidth=0.9)  

    # scatter plot  
    else:  
        for key, val in cls_dict.items():  
            axes[row, col].scatter(val[:, col], val[:, row],  
                                   edgecolor=f'tab:{color_list[key]}',  
                                   color=color_list[key], alpha=0.5)  

    # labels  
    if col == 0:  
        axes[row, col].set(ylabel=features[row])  
        axes[row, col].set_ylabel(features[row], fontsize=20)  

    if row == len(features)-1:  
        axes[row, col].set(xlabel=features[col])  
        axes[row, col].set_xlabel(features[col], fontsize=20)  


def iris_pairplot():  
    iris_X, iris_y, feature_names, species, n_feature, n_species = get_iris()  

    fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(16, 16))  

    cls_dict = dict()  
    for cls in np.unique(iris_y):  
        cls_dict[cls] = iris_X[iris_y == cls]  

    for i in range(n_feature):  
        for j in range(n_feature):  
            single_pair(axes, i, j, iris_X, iris_y, cls_dict, feature_names)

Pairplot Refactoring (Full Code)

import matplotlib.pyplot as plt  
from sklearn.datasets import load_iris  
import numpy as np

def get_iris():  

    iris = load_iris()  

    for attr in dir(iris):  
        print(attr)  

    # DESCR  
    # data    # data_module    # feature_names    # filename    # frame    # target    # target_names  
    # 대문자는 행렬, 소문자는 벡터  
    iris_X = iris.data  
    iris_y = iris.target  
    feature_names = iris.feature_names  
    species = iris.target_names  
    n_feature = len(feature_names)  
    n_species = len(species)  

    return iris_X, iris_y, feature_names, species, n_feature, n_species

def single_pair2(axes, row, col, X, y, features):  
    # histogram  
    if row == col:  
        data = X[:, row]  
        axes[row, col].hist(data, rwidth=0.9)  

    # scatter plot  
    else:  
        axes[row, col].scatter(X[:, col], X[:, row], c=y, alpha=0.5)  

    # labels  
    if col == 0:  
        axes[row, col].set_ylabel(features[row], fontsize=20)  
    if row == len(features)-1:  
        axes[row, col].set_xlabel(features[col], fontsize=20)  


def iris_pairplot_clean():  
    # code refactoring  
    iris_X, iris_y, feature_names, species, n_feature, n_species = get_iris()  

    fig, axes = plt.subplots(nrows=n_feature, ncols=n_feature, figsize=(16, 16))

    for i in range(n_feature):  
        for j in range(n_feature):  
            single_pair2(axes, i, j, iris_X, iris_y, feature_names)


if __name__ == '__main__':  
    iris_pairplot_clean()  
    plt.show()

코드 기전

  • get_iris(): data 처리 코드
  • single_pair2(): nxn 페어플랏의 싱글 플랏 차트 1개를 그리는 코드
    1) 대각선(row가 col과 같은 경우)은 histogram
    2) 그 외에는 scatter plot으로 2개의 feature간 교차 산점도 차트 그리기
    3) 라벨링은 좌측 및 하단만 표기
  • iris_pairplot_clean(): 메인 루틴 코드
    1) 데이터 할당
    2) nxn matplotlib 차트 설정 (n: feature 개수)
    3) for loop으로 차트 번호 인덱싱하여 [a, b] 해당하는 인덱스의 개별 차트 그리기

리팩토링 시 고려한 점

1) Code Split
for 문 안에 돌아가는 반복 코드는 별도의 method로 split하고,

method 안에 있는 코드들이 의미적으로 일치하는 위치에 있는지 확인하여 재조정

 

2) Scatter plot c paremeter로 변경
기존에 scatter plot에서 c라는 parameter가 class 구분을 해주는지 모르고 데이터셋을 분리하여

3개의 scatter plot을 중첩했는데

분리해야할 이유가 없다면

-> target을 c에 설정하면 1개의 scatter plot만 그려도 class 별로 색이 구분되기 때문에 통합 데이터로 plotting 할 수 있음

 

3) Label 중복 코드 제거
subplot의 라벨 설정시 set_ylabel 함수 1개만으로 설정 가능

 

4) 하드코딩 제거
nrows=4, ncols=4 같이 숫자를 하드코딩하면 값이 틀어질 때 수정이 안되기때문에

기존의 get_iris 에서 받는 n_feature(피처 개수)를 row,col의 인자로 설정하여 동적으로 변경되도록 수정
fig, axes = plt.subplots(nrows=n_feature, ncols=n_feature, figsize=(16, 16))

리팩토링 후 차트

💛 배운점/느낀점

- pandas 대신 numpy만 사용하여 구현해보았음. ndarray도 pandas의 iloc 처럼 인덱싱, 슬라이싱이 가능함

- pytorch의 tensor가 numpy와 유사하기 때문에, torch를 잘 사용하기 위해서 numpy를 더 잘 컨트롤할 수 있게 연습 필요

- 알고리즘을 작성할 때 먼저 머리속으로 설계하고 구현해보니까 어느정도 작성 후 테스트용으로 라인디버깅을 하게 되고 생산성이 올랐다는 느낌이 들었음

반응형