ECG Heartbeat Categorization Dataset

Author

SEOYEON CHOI

Published

August 18, 2025


Abstract - This dataset is composed of two collections of heartbeat signals derived from two well-known datasets in heartbeat classification: the MIT-BIH Arrhythmia Dataset and the PTB Diagnostic ECG Database. Both collections contain enough samples to train a deep neural network. The dataset has been used to explore heartbeat classification with deep neural network architectures and to examine some of the capabilities of transfer learning. The signals correspond to electrocardiogram (ECG) heartbeat shapes for the normal case and for cases affected by different arrhythmias and by myocardial infarction. The signals are preprocessed and segmented, and each segment corresponds to one heartbeat.

Content

The PTB Diagnostic ECG Database
  • Number of samples: 14552
  • Number of categories: 2
  • Sampling frequency: 125 Hz
  • Data source: PhysioNet's PTB Diagnostic Database
  • Remark: all samples are cropped, downsampled, and zero-padded where necessary to a fixed dimension of 188.
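Below is a minimal sketch of the final crop/zero-pad step. It assumes that a beat has already been resampled to 125 Hz and that the fixed dimension of 188 means 187 signal samples plus the class label, which is the layout of the CSV files used below. `to_fixed_length` and `to_row` are hypothetical helpers, not part of the dataset's own tooling.

```python
import numpy as np

def to_fixed_length(beat, n_samples=187):
    """Crop or zero-pad a 1-D beat to exactly n_samples values (hypothetical helper)."""
    beat = np.asarray(beat, dtype=float)
    if beat.size >= n_samples:
        return beat[:n_samples]          # crop the tail
    out = np.zeros(n_samples)            # zero-pad the tail
    out[:beat.size] = beat
    return out

def to_row(beat, label, n_samples=187):
    """Append the class label as the final element, giving n_samples + 1 values."""
    return np.append(to_fixed_length(beat, n_samples), float(label))

row = to_row([1.0, 0.7, 0.1], label=1)
print(row.shape)  # (188,)
```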

Data Files - This dataset consists of a series of CSV files. Each CSV file contains a matrix in which every row is one example from that portion of the dataset. The final element of each row is the example's class label.
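With that layout, a minimal loader might look like the following. `load_split` is a hypothetical helper, and the two-row CSV string is made-up demo data standing in for a real file.

```python
import io

import numpy as np
import pandas as pd

def load_split(path_or_buffer):
    """Read a header-less CSV and split it into signals X and integer labels y."""
    df = pd.read_csv(path_or_buffer, header=None)
    X = df.iloc[:, :-1].to_numpy(dtype=float)   # every column except the last
    y = df.iloc[:, -1].to_numpy().astype(int)   # last column = class label
    return X, y

# Tiny in-memory stand-in for a file such as mitbih_test.csv (3 signal values + label)
demo_csv = io.StringIO("1.0,0.5,0.0,0\n0.9,0.4,0.1,2\n")
X_demo, y_demo = load_split(demo_csv)
print(X_demo.shape, y_demo)  # (2, 3) [0 2]
```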

Acknowledgements - Mohammad Kachuee, Shayan Fazeli, and Majid Sarrafzadeh. “ECG Heartbeat Classification: A Deep Transferable Representation.” arXiv preprint arXiv:1805.00794 (2018).

Inspiration - Can you identify myocardial infarction?

Import

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import train_test_split

import pywt

from scipy.stats import f_oneway

import scikit_posthocs as sp

from scipy.stats import kruskal

from scipy.stats import skew, kurtosis
from scipy.fft import rfft, rfftfreq

Data Understanding & Preprocessing

Checking the Dataset Structure

df = pd.read_csv('../../../delete/mitbih_test.csv',header=None)
df.head(5)
0 1 2 3 4 5 6 7 8 9 ... 178 179 180 181 182 183 184 185 186 187
0 1.000000 0.758264 0.111570 0.000000 0.080579 0.078512 0.066116 0.049587 0.047521 0.035124 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.908425 0.783883 0.531136 0.362637 0.366300 0.344322 0.333333 0.307692 0.296703 0.300366 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.730088 0.212389 0.000000 0.119469 0.101770 0.101770 0.110619 0.123894 0.115044 0.132743 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 1.000000 0.910417 0.681250 0.472917 0.229167 0.068750 0.000000 0.004167 0.014583 0.054167 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.570470 0.399329 0.238255 0.147651 0.000000 0.003356 0.040268 0.080537 0.070470 0.090604 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

5 rows × 188 columns

df.shape
(21892, 188)
sample_idx = np.random.choice(df.shape[0], 9, replace=False)

plt.figure(figsize=(12, 10))

for i, idx in enumerate(sample_idx, 1):
    # Column 187 is the class label, so plot only the 187 signal values (columns 0-186)
    ecg = df.iloc[idx, :-1].values
    label = int(df.iloc[idx, -1])

    plt.subplot(3, 3, i)
    plt.plot(ecg, color="blue")
    plt.title(f"Sample {idx} (class {label})")
    plt.xlabel("Time index (0-186)")
    plt.ylabel("Amplitude")

plt.tight_layout()
plt.show()

df[187].unique()
array([0., 1., 2., 3., 4.])
  • MIT-BIH (arrhythmia → 5 classes: N, S, V, F, Q)

Preprocessing

  • Normalization (z-score scaling)
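# Separate the 187 signal values (features) from the class label (last column).
# z-score scaling is applied later, per sample, before training the neural networks.
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values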
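The bullet above does not say along which axis to standardize, and there are two common options. Per-sample scaling puts each beat at mean 0 and std 1; the CNN section later does this. Per-feature scaling uses scikit-learn's `StandardScaler` to standardize each time index across beats, and should be fit on the training data only. Both are sketched below. The random matrix is placeholder data, not ECG.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_demo = rng.random((6, 187))          # placeholder: 6 beats x 187 time points

# Per-sample z-score: each row (beat) gets mean 0 and std 1, keeping only the waveform shape
eps = 1e-8
X_row = (X_demo - X_demo.mean(axis=1, keepdims=True)) / (X_demo.std(axis=1, keepdims=True) + eps)

# Per-feature z-score: each column (time index) is standardized across beats.
# Fit on the training rows only, then reuse the fitted scaler on validation/test rows.
scaler = StandardScaler().fit(X_demo[:4])
X_col_train = scaler.transform(X_demo[:4])
X_col_test = scaler.transform(X_demo[4:])

print(np.allclose(X_row.mean(axis=1), 0), np.allclose(X_col_train.mean(axis=0), 0))
```

Per-sample scaling also moves the zero-padded tail of each beat away from 0, which is worth keeping in mind when interpreting the padded region.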

Exploratory Data Analysis (EDA)

Mean Waveform Visualization

label_col = df.columns[-1]
feature_cols = df.columns[:-1]
n_points = len(feature_cols)
fs = 125.0
t = np.arange(n_points) / fs
label_map = {0:'N', 1:'S', 2:'V', 3:'F', 4:'Q'}
# Compute the mean waveform for each label
# (the loop variable is named cls so that it does not overwrite the label array y)
means = []
labels_present = sorted(df[label_col].unique())
for cls in labels_present:
    m = df.loc[df[label_col] == cls, feature_cols].mean(axis=0).values
    means.append((cls, m))

# Plot the per-label mean waveforms on a single figure
plt.figure(figsize=(12,6))
for cls, m in means:
    plt.plot(t, m, label=f"{label_map.get(int(cls), cls)} ({int(cls)})")
plt.title("Average ECG Waveform by Class")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Spectral analysis: compare frequency characteristics across classes using the FFT or a wavelet transform.

FFT

# ----- Basic settings -----
fs = 125.0
label_col = df.columns[-1]
feature_cols = df.columns[:-1]
N = len(feature_cols)               # 187 signal samples per beat
t = np.arange(N) / fs
freqs = np.fft.rfftfreq(N, d=1/fs)  # 0 Hz up to just below Nyquist (62.5 Hz); bin spacing fs/N ≈ 0.67 Hz
window = np.hanning(N)              # Hann window (reduces spectral leakage)

# Label mapping (for readability)
label_map = {0:'N', 1:'S', 2:'V', 3:'F', 4:'Q'}

# ----- Average power spectrum per class -----
plt.figure(figsize=(12,6))
for cls in sorted(df[label_col].unique()):
    Xc = df.loc[df[label_col] == cls, feature_cols].values    # (n_class_samples, N)
    # Detrend (remove each beat's mean) + windowing + FFT
    Xc_detrended = Xc - Xc.mean(axis=1, keepdims=True)
    Xw = Xc_detrended * window
    F = np.fft.rfft(Xw, axis=1)
    P = (np.abs(F)**2) / (window**2).sum()                 # normalize by window energy
    P_mean = P.mean(axis=0)

    plt.plot(freqs, P_mean, label=f"{label_map.get(int(cls), cls)} ({int(cls)})")

plt.title("Average Power Spectrum by Class (FFT)")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
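The same steps (mean removal, Hann window, `rfft`, normalization by window energy) can be wrapped in a small function. The function below is a sketch and is sanity-checked on a synthetic 10 Hz sine sampled like the beats (125 Hz, 187 points). `mean_power_spectrum` is a hypothetical helper, and the test signal is an assumption for illustration, not ECG data.

```python
import numpy as np

def mean_power_spectrum(X, fs=125.0):
    """Average Hann-windowed power spectrum over the rows of X (n_beats, n_points)."""
    X = np.asarray(X, dtype=float)
    n = X.shape[1]
    window = np.hanning(n)
    Xw = (X - X.mean(axis=1, keepdims=True)) * window     # detrend, then taper
    P = np.abs(np.fft.rfft(Xw, axis=1)) ** 2 / np.sum(window ** 2)
    freqs = np.fft.rfftfreq(n, d=1 / fs)
    return freqs, P.mean(axis=0)

# Synthetic check: three copies of a 10 Hz sine with different DC offsets
fs, n = 125.0, 187
t = np.arange(n) / fs
X_sine = np.vstack([np.sin(2 * np.pi * 10 * t) + offset for offset in (0.0, 0.5, 1.0)])
freqs, P = mean_power_spectrum(X_sine, fs)
print(f"peak at {freqs[np.argmax(P)]:.2f} Hz (bin spacing {fs / n:.2f} Hz)")
```

Because the mean is removed first, the three offsets give identical spectra. The peak lands on the bin nearest 10 Hz rather than exactly on 10 Hz, because with 187 points the resolution is only about 0.67 Hz.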

Wavelet

# ----- DWT parameters -----
wavelet = 'db4'
level = 4  # for 187 samples and db4 (filter length 8), pywt.dwt_max_level gives 4

# Extract the sub-band energy vector from a single beat
def wavelet_band_energy(x):
    # Remove the mean (reduces the influence of the baseline)
    x = x - x.mean()
    coeffs = pywt.wavedec(x, wavelet=wavelet, level=level)
    # coeffs = [cA_L, cD_L, cD_{L-1}, ..., cD_1]
    energies = [np.sum(c**2) for c in coeffs]
    total = np.sum(energies) + 1e-12
    # Relative energies (summing to 1) reduce the effect of amplitude differences between beats
    return np.array(energies) / total

# Process every beat (NumPy rows instead of df.iterrows(), which is much slower)
X_sig = df[feature_cols].to_numpy(dtype=float)
E = np.vstack([wavelet_band_energy(x) for x in X_sig])   # shape: (n_samples, 1 + level)
y = df[label_col].to_numpy().astype(int)

# Mean relative energy per class
classes = sorted(np.unique(y))
E_mean = {c: E[y==c].mean(axis=0) for c in classes}

# Plot: compare sub-band energy across classes (grouped bar chart)
bands = [f"cA_{level}"] + [f"cD_{i}" for i in range(level, 0, -1)]  # coarse scale → fine scale
x_pos = np.arange(len(bands))

plt.figure(figsize=(12,6))
bar_width = 0.14
for i, c in enumerate(classes):
    plt.bar(x_pos + i*bar_width, E_mean[c], width=bar_width, label=f"{c}")

plt.xticks(x_pos + bar_width*(len(classes)-1)/2, bands)
plt.ylabel("Relative Energy")
plt.title(f"Wavelet Subband Energy by Class (wavelet={wavelet}, level={level})")
plt.legend(title="Class (0=N,1=S,2=V,3=F,4=Q)")
plt.tight_layout()
plt.show()
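The usual dyadic idealization assumes each DWT level halves the band and ignores the filters' finite roll-off. Under that assumption, the sub-bands in the chart map to approximate frequency ranges at fs = 125 Hz. The helpers below are hypothetical. The `dwt_max_level` function mirrors the rule documented for `pywt.dwt_max_level`, floor(log2(n / (filter_len - 1))).

```python
import math

def dwt_max_level(n_samples, filter_len):
    """Maximum useful decomposition level: floor(log2(n / (filter_len - 1)))."""
    return int(math.floor(math.log2(n_samples / (filter_len - 1))))

def dwt_band_edges(fs, level):
    """Idealized (low, high) frequency range in Hz of each wavedec output, in wavedec order."""
    nyquist = fs / 2
    bands = {f"cA_{level}": (0.0, nyquist / 2 ** level)}
    for j in range(level, 0, -1):                  # cD_level (coarsest) ... cD_1 (finest)
        bands[f"cD_{j}"] = (nyquist / 2 ** j, nyquist / 2 ** (j - 1))
    return bands

print(dwt_max_level(187, 8))  # db4 has filter length 8
for name, (lo, hi) in dwt_band_edges(125.0, 4).items():
    print(f"{name}: {lo:.2f}-{hi:.2f} Hz")
```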

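Statistical comparison: test whether features such as amplitude, variance, and QRS duration differ across the classes ['N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4] (t-test, Mann-Whitney). The t-test and Mann-Whitney compare two groups, but there are five classes here. The code below therefore uses their multi-group counterparts, one-way ANOVA and the Kruskal-Wallis test, followed by Dunn's post-hoc test for pairwise comparisons.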
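The two-group tests can still be applied pairwise. The sketch below assumes a `DataFrame` with one feature column and one label column. `pairwise_mannwhitney` is a hypothetical helper and the demo data are synthetic. The Bonferroni correction multiplies each p-value by the number of pairs, capped at 1.

```python
from itertools import combinations

import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu

def pairwise_mannwhitney(data, value_col, group_col):
    """Two-sided Mann-Whitney U test for every pair of groups, Bonferroni-adjusted."""
    groups = {g: s.to_numpy() for g, s in data.groupby(group_col)[value_col]}
    pairs = list(combinations(sorted(groups), 2))
    rows = []
    for a, b in pairs:
        stat, p = mannwhitneyu(groups[a], groups[b], alternative="two-sided")
        rows.append({"group_a": a, "group_b": b, "U": stat,
                     "p_bonferroni": min(p * len(pairs), 1.0)})
    return pd.DataFrame(rows)

# Synthetic demo: class 2 is shifted upward relative to classes 0 and 1
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    "amp_mean": np.concatenate([rng.normal(0, 1, 50), rng.normal(0, 1, 50), rng.normal(3, 1, 50)]),
    "label": np.repeat([0, 1, 2], 50),
})
res = pairwise_mannwhitney(demo, "amp_mean", "label")
print(res)
```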

# Column layout of df: 0-186 = ECG samples, 187 = class label.
# Select these explicitly: once the first new column is appended, df.iloc[:, :-1]
# would include the label, and df.iloc[:, -1] would no longer be the label.
signal_cols = list(range(187))
label_col = 187
sig = df[signal_cols]

df['amp_mean'] = sig.mean(axis=1)
df['amp_max']  = sig.max(axis=1)
df['amp_min']  = sig.min(axis=1)
df['amp_range'] = df['amp_max'] - df['amp_min']
df['variance'] = sig.var(axis=1)

def qrs_duration(signal, threshold_ratio=0.5):
    # Crude width proxy: span (in samples) between the first and last points at or
    # above threshold_ratio x the maximum, which is assumed to be the R peak.
    # Divide by fs (125 Hz) to convert to seconds.
    peak = np.max(signal)
    threshold = peak * threshold_ratio
    above = np.where(signal >= threshold)[0]
    if len(above) == 0:
        return 0
    return above[-1] - above[0]   # index difference (number of samples)

df['qrs_dur'] = sig.apply(lambda row: qrs_duration(row.values), axis=1)
label_map = {0:'N', 1:'S', 2:'V', 3:'F', 4:'Q'}

# Descriptive statistics per class, then one-way ANOVA across the five classes
desc = df.groupby(label_col)[['amp_mean','variance','qrs_dur']].describe()
class_ids = sorted(df[label_col].unique())

groups = [df.loc[df[label_col] == c, 'amp_mean'] for c in class_ids]
anova_amp = f_oneway(*groups)

groups = [df.loc[df[label_col] == c, 'variance'] for c in class_ids]
anova_var = f_oneway(*groups)

groups = [df.loc[df[label_col] == c, 'qrs_dur'] for c in class_ids]
anova_qrs = f_oneway(*groups)

print("ANOVA amp_mean:", anova_amp)
print("ANOVA variance:", anova_var)
print("ANOVA qrs_dur:", anova_qrs)
ANOVA amp_mean: F_onewayResult(statistic=97.23061517562662, pvalue=0.0)
ANOVA variance: F_onewayResult(statistic=673.9205351617329, pvalue=0.0)
ANOVA qrs_dur: F_onewayResult(statistic=inf, pvalue=0.0)
/home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages/scipy/stats/_stats_py.py:3903: ConstantInputWarning: Each of the input arrays is constant;the F statistic is not defined or infinite
  warnings.warn(stats.ConstantInputWarning(msg))
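`qrs_duration` measures the span from the first to the last sample above half the maximum. If the 187-sample window contains more than one above-threshold region (for example, a second tall peak), that span can stretch across most of the beat. A hedged alternative is sketched below with a hypothetical helper, `peak_width_s`. It walks outward from the maximum only while the signal stays above the threshold, and it reports the width in seconds at fs = 125 Hz. It is still a crude half-maximum width proxy rather than a clinical QRS measurement, and the synthetic beat is made up.

```python
import numpy as np

def peak_width_s(signal, fs=125.0, threshold_ratio=0.5):
    """Width (s) of the contiguous region around the global maximum that stays
    at or above threshold_ratio * max. A crude proxy, not a clinical QRS duration."""
    x = np.asarray(signal, dtype=float)
    i_peak = int(np.argmax(x))
    threshold = x[i_peak] * threshold_ratio
    left = i_peak
    while left > 0 and x[left - 1] >= threshold:
        left -= 1
    right = i_peak
    while right < len(x) - 1 and x[right + 1] >= threshold:
        right += 1
    return (right - left) / fs

# Synthetic beat: a narrow peak at the start and a second tall peak near the end
beat = np.zeros(187)
beat[0:4] = [1.0, 0.8, 0.6, 0.2]
beat[150:153] = [0.7, 0.9, 0.7]
print(peak_width_s(beat))   # only the region around the first (global) peak is measured
```

On this synthetic beat, the first-to-last span used by `qrs_duration` would be 152 samples, while the contiguous width is 2 samples (0.016 s).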
# Kruskal-Wallis H test (rank-based, no normality assumption), grouped by the class label in column 187
class_ids = sorted(df[187].unique())

groups = [df.loc[df[187] == c, 'amp_mean'] for c in class_ids]
kw_amp = kruskal(*groups)

groups = [df.loc[df[187] == c, 'variance'] for c in class_ids]
kw_var = kruskal(*groups)

groups = [df.loc[df[187] == c, 'qrs_dur'] for c in class_ids]
kw_qrs = kruskal(*groups)

print("Kruskal amp_mean:", kw_amp)
print("Kruskal variance:", kw_var)
print("Kruskal qrs_dur:", kw_qrs)
Kruskal amp_mean: KruskalResult(statistic=4731.5315295296, pvalue=0.0)
Kruskal variance: KruskalResult(statistic=8376.756741998353, pvalue=0.0)
Kruskal qrs_dur: KruskalResult(statistic=21891.0, pvalue=0.0)
In the recorded ANOVA and Kruskal-Wallis outputs above, the groups were formed from `df.iloc[:, -1]`. By that point, this was the appended `qrs_dur` column rather than the class label in column 187. Grouping `qrs_dur` by its own values makes every group constant. That is why the ANOVA F statistic is `inf` (with a `ConstantInputWarning`) and the Kruskal statistic equals n - 1 = 21891. The Dunn post-hoc output below, whose group labels run from 4 to 191, has the same cause.
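# Example: Dunn's post-hoc test comparing amp_mean between every pair of classes (Bonferroni-adjusted)
posthoc_amp = sp.posthoc_dunn(
    df, 
    val_col='amp_mean', 
    group_col=187,  # column 187 holds the class label (df.columns[-1] is the last appended feature)
    p_adjust='bonferroni'
)

print("Dunn post-hoc test (amp_mean):")
print(posthoc_amp)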
Dunn post-hoc test (amp_mean):
               4    5    6    7    9    10   12   14   15   16   ...  182  \
4     1.000000e+00  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
5     1.000000e+00  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
6     1.000000e+00  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
7     1.000000e+00  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
9     1.000000e+00  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
..             ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...   
187   1.353411e-35  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
188   3.901249e-05  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
189   1.318389e-12  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
190   2.638355e-03  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   
191  4.113517e-283  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  ...  1.0   

     183       184           185            186            187           188  \
4    1.0  1.000000  1.000000e+00   7.231955e-26   1.353411e-35  3.901249e-05   
5    1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
6    1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
7    1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
9    1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
..   ...       ...           ...            ...            ...           ...   
187  1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
188  1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
189  1.0  1.000000  1.000000e+00   1.000000e+00   1.000000e+00  1.000000e+00   
190  1.0  1.000000  1.000000e+00   3.018640e-04   6.368383e-07  1.000000e+00   
191  1.0  0.000027  4.575056e-07  8.248003e-126  8.271528e-161  2.793909e-34   

              189           190            191  
4    1.318389e-12  2.638355e-03  4.113517e-283  
5    1.000000e+00  1.000000e+00   1.000000e+00  
6    1.000000e+00  1.000000e+00   1.000000e+00  
7    1.000000e+00  1.000000e+00   1.000000e+00  
9    1.000000e+00  1.000000e+00   1.000000e+00  
..            ...           ...            ...  
187  1.000000e+00  6.368383e-07  8.271528e-161  
188  1.000000e+00  1.000000e+00   2.793909e-34  
189  1.000000e+00  3.952042e-01   1.699180e-68  
190  3.952042e-01  1.000000e+00   2.442212e-88  
191  1.699180e-68  2.442212e-88   1.000000e+00  

[181 rows x 181 columns]
# Boxplot of amp_mean per class (column 187 holds the class label)
plt.figure(figsize=(8,6))
sns.boxplot(x=df[187], y=df['amp_mean'])
plt.title("Amplitude Mean by Class")
plt.xlabel("Class (0=N,1=S,2=V,3=F,4=Q)")
plt.ylabel("amp_mean")
plt.show()

# Posthoc heatmap
plt.figure(figsize=(6,5))
sns.heatmap(posthoc_amp, annot=True, cmap="coolwarm", cbar=True, fmt=".3f")
plt.title("Post-hoc Dunn Test (p-values)")
plt.show()

Feature-Based Machine Learning Approach

Feature extraction
  • Time domain: mean, std, skewness, kurtosis, peak-to-peak
  • Frequency domain: dominant frequency, power spectrum
  • Wavelet: ST-segment-related coefficients (implemented below as relative DWT sub-band energies)

def extract_time_features(x):
    return {
        'mean': float(np.mean(x)),
        'std': float(np.std(x)),
        'skew': float(skew(x)),
        'kurtosis': float(kurtosis(x)),
        'ptp': float(np.ptp(x)),  # peak-to-peak
    }

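def extract_freq_features(x, fs=125):
    N = len(x)
    freqs = rfftfreq(N, d=1/fs)
    spectrum = np.abs(rfft(x))**2
    # Skip the DC bin (index 0): every beat has a large positive mean, so without
    # this the "dominant" frequency would always be 0 Hz
    dom_freq = float(freqs[1:][np.argmax(spectrum[1:])])
    power = float(np.sum(spectrum))  # total power, DC included
    return {'dom_freq': dom_freq, 'power': power}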

def extract_wavelet_features(x, wavelet='db4', level=4):
    x = x - x.mean()
    coeffs = pywt.wavedec(x, wavelet=wavelet, level=level)
    energies = [np.sum(c**2) for c in coeffs]
    total = np.sum(energies) + 1e-12
    rel = [float(e/total) for e in energies]
    # cA_L, cD_L, ..., cD_1
    names = [f'wavelet_energy_cA_{level}'] + [f'wavelet_energy_cD_{i}' for i in range(level, 0, -1)]
    return dict(zip(names, rel))

def extract_features(x, fs=125):
    feats = {}
    feats.update(extract_time_features(x))
    feats.update(extract_freq_features(x, fs=fs))
    feats.update(extract_wavelet_features(x))
    return feats
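
# ---- Define label/feature columns explicitly ----
label_col = 187                  # column 187 holds the class label
feature_cols = list(range(187))  # columns 0-186 hold the ECG time series
# (df.columns[-1] / df.columns[:-1] would pick up the feature columns appended earlier)

# ---- Pull the data into NumPy arrays and loop (fast and safe) ----
X = df.loc[:, feature_cols].to_numpy(dtype=float)
y = df.loc[:, label_col].to_numpy()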

feature_list = []
labels = []

for xi, yi in zip(X, y):
    feats = extract_features(xi, fs=125)
    feature_list.append(feats)
    labels.append(int(yi))

feature_df = pd.DataFrame(feature_list)
feature_df['label'] = labels

print(feature_df.head())
print(feature_df.shape)
       mean       std      skew   kurtosis  ptp  dom_freq        power  \
0  0.088813  0.165259  4.148272  18.524207  1.0       0.0   802.459636   
1  0.201932  0.247202  1.129980   0.747387  1.0       0.0  2657.005592   
2  0.119641  0.163150  2.984116  12.369159  1.0       0.0  1028.928658   
3  0.168079  0.222119  1.874741   3.945423  1.0       0.0  1971.182895   
4  0.207004  0.219523  1.185874   1.866668  1.0       0.0  2493.665686   

   wavelet_energy_cA_4  wavelet_energy_cD_4  wavelet_energy_cD_3  \
0             0.581605             0.070714             0.143888   
1             0.781645             0.074139             0.056645   
2             0.494606             0.057453             0.107020   
3             0.792486             0.068899             0.062562   
4             0.613792             0.090211             0.101761   

   wavelet_energy_cD_2  wavelet_energy_cD_1  label  
0             0.068162             0.135632    191  
1             0.009751             0.077819    191  
2             0.097060             0.243862    191  
3             0.008464             0.067589    191  
4             0.025589             0.168648    191  
(21892, 13)
In the recorded table above, `label` is 191 and `dom_freq` is 0.0 in every row shown. The label was read from `df.columns[-1]`, which at that point was the appended `qrs_dur` column. In addition, unless the DC bin is excluded, the spectral peak of these positive-valued beats always falls at 0 Hz.

ML modeling
  • Models: Logistic Regression, Random Forest, XGBoost
  • Goal: establish baseline accuracy / ROC-AUC
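Below is a minimal baseline sketch. It assumes the 12 extracted features and the integer label, as in `feature_df`, and uses a stratified 80/20 split. It reports accuracy and macro one-vs-rest ROC-AUC. `evaluate_baselines` is a hypothetical helper and the data below are synthetic. XGBoost (`xgboost.XGBClassifier`) could be added to the dictionary in the same way if it is installed; it is left out here to keep the sketch dependency-free.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def evaluate_baselines(X, y, seed=42):
    """Fit simple baselines on a stratified 80/20 split; return accuracy and macro OvR ROC-AUC."""
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=seed, stratify=y)
    models = {
        "logreg": make_pipeline(StandardScaler(),
                                LogisticRegression(max_iter=1000, class_weight="balanced")),
        "random_forest": RandomForestClassifier(n_estimators=200, class_weight="balanced",
                                                random_state=seed, n_jobs=-1),
    }
    results = {}
    for name, model in models.items():
        model.fit(X_tr, y_tr)
        proba = model.predict_proba(X_te)
        results[name] = {
            "accuracy": accuracy_score(y_te, model.predict(X_te)),
            "roc_auc_macro_ovr": roc_auc_score(y_te, proba, multi_class="ovr", average="macro"),
        }
    return results

# Synthetic stand-in for feature_df (12 features, 5 classes). With the real data this would be
# X = feature_df.drop(columns="label").to_numpy(), y = feature_df["label"].to_numpy()
X_demo, y_demo = make_classification(n_samples=1000, n_features=12, n_informative=6,
                                     n_classes=5, random_state=0)
res = evaluate_baselines(X_demo, y_demo)
print(res)
```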

Deep Learning Approach (Using the Raw Signal)

1D CNN

  • Automatically learns local ECG patterns (QRS complex, ST segment).
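Whether the convolutional stack in `build_1d_cnn` can "see" a whole QRS complex depends on its receptive field. The sketch below uses the standard recurrence: the field grows by (k - 1) x jump per layer, and jump is multiplied by each layer's stride. It is applied to the layer sequence Conv1D(k=7), MaxPool1D(2), Conv1D(k=5), MaxPool1D(2), Conv1D(k=3). The helper is hypothetical, and it ignores the final global average pooling, which pools over the whole sequence.

```python
def receptive_field(layers):
    """layers: list of (kernel_size, stride). Returns the receptive field in input samples."""
    r, jump = 1, 1
    for k, s in layers:
        r += (k - 1) * jump   # each layer widens the field by (k - 1) input-space steps
        jump *= s             # downsampling multiplies the step size
    return r

# (kernel, stride) for Conv1D(7) -> MaxPool1D(2) -> Conv1D(5) -> MaxPool1D(2) -> Conv1D(3)
cnn_layers = [(7, 1), (2, 2), (5, 1), (2, 2), (3, 1)]
rf = receptive_field(cnn_layers)
print(rf, "samples =", rf / 125.0, "s at 125 Hz")
```

This gives 26 samples (about 0.21 s), which is wider than a typical QRS complex (roughly 0.06 to 0.10 s). Longer-range structure is combined only after `GlobalAveragePooling1D`, which averages over positions.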
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras import layers, models
# --- Split labels and features ---
label_col = 187                  # column 187 = class label
feature_cols = list(range(187))  # columns 0-186 = all 187 ECG samples (df.columns[:186] would drop column 186)

X = df[feature_cols].to_numpy(dtype=np.float32)   # (N, L)
y = df[label_col].to_numpy(dtype=np.int64)        # (N, )

# --- Per-sample z-score standardization (keeps only the waveform shape) ---
eps = 1e-8
mu = X.mean(axis=1, keepdims=True)
sd = X.std(axis=1, keepdims=True)
X = (X - mu) / (sd + eps)

# --- 80:20 train/test split, then 20% of train as validation (final 64/16/20) ---
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.20, random_state=42, stratify=y)
X_tr, X_va, y_tr, y_va = train_test_split(X_tr, y_tr, test_size=0.20, random_state=42, stratify=y_tr)
classes = np.unique(y)
class_weights = compute_class_weight('balanced', classes=classes, y=y_tr)
class_weight = {int(c): float(w) for c, w in zip(classes, class_weights)}
print("class_weight:", class_weight)
class_weight: {0: 0.24165588615782665, 1: 7.870786516853933, 2: 3.025917926565875, 3: 26.942307692307693, 4: 2.7230320699708455}
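`compute_class_weight('balanced', ...)` assigns each class the weight n_samples / (n_classes x count_c), so rare classes such as F (3) receive large weights. The check below uses made-up class counts, not this dataset's.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_demo = np.repeat([0, 1, 2, 3, 4], [800, 25, 65, 7, 73])   # made-up, imbalanced class counts
classes = np.unique(y_demo)

manual = len(y_demo) / (len(classes) * np.bincount(y_demo))  # n_samples / (n_classes * count)
sk = compute_class_weight("balanced", classes=classes, y=y_demo)
print(np.round(manual, 3))
print(np.allclose(manual, sk))
```

Reading the formula backwards, the printed weight of 0.2417 for class 0 implies that N makes up about 1 / (5 x 0.2417) ≈ 83% of the training split. This matches the test-set share of 3624 / 4379 in the classification report further down.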
# 2) Model: 1D CNN (automatically learns local patterns such as the QRS complex and ST segment)
L = X_tr.shape[1]  # sequence length
n_classes = len(np.unique(y))

def build_1d_cnn(input_length=L, n_classes=n_classes):
    inp = layers.Input(shape=(input_length, 1))
    x = layers.Conv1D(32, kernel_size=7, padding='same')(inp)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPool1D(2)(x)
    x = layers.Dropout(0.2)(x)

    x = layers.Conv1D(64, kernel_size=5, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPool1D(2)(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv1D(128, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.GlobalAveragePooling1D()(x)

    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    out = layers.Dense(n_classes, activation='softmax')(x)

    model = models.Model(inp, out)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    return model

model = build_1d_cnn()
model.summary()
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 186, 1)]          0         
                                                                 
 conv1d_6 (Conv1D)           (None, 186, 32)           256       
                                                                 
 batch_normalization_6 (Batc  (None, 186, 32)          128       
 hNormalization)                                                 
                                                                 
 re_lu_6 (ReLU)              (None, 186, 32)           0         
                                                                 
 max_pooling1d_4 (MaxPooling  (None, 93, 32)           0         
 1D)                                                             
                                                                 
 dropout_6 (Dropout)         (None, 93, 32)            0         
                                                                 
 conv1d_7 (Conv1D)           (None, 93, 64)            10304     
                                                                 
 batch_normalization_7 (Batc  (None, 93, 64)           256       
 hNormalization)                                                 
                                                                 
 re_lu_7 (ReLU)              (None, 93, 64)            0         
                                                                 
 max_pooling1d_5 (MaxPooling  (None, 46, 64)           0         
 1D)                                                             
                                                                 
 dropout_7 (Dropout)         (None, 46, 64)            0         
                                                                 
 conv1d_8 (Conv1D)           (None, 46, 128)           24704     
                                                                 
 batch_normalization_8 (Batc  (None, 46, 128)          512       
 hNormalization)                                                 
                                                                 
 re_lu_8 (ReLU)              (None, 46, 128)           0         
                                                                 
 global_average_pooling1d_2   (None, 128)              0         
 (GlobalAveragePooling1D)                                        
                                                                 
 dense_4 (Dense)             (None, 128)               16512     
                                                                 
 dropout_8 (Dropout)         (None, 128)               0         
                                                                 
 dense_5 (Dense)             (None, 5)                 645       
                                                                 
=================================================================
Total params: 53,317
Trainable params: 52,869
Non-trainable params: 448
_________________________________________________________________
  • Ways to improve performance
    • Add one more convolutional block.
    • Increase the number of channels (32->64->128…).
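Both ideas increase the parameter count, and this can be estimated before training. The sketch below counts parameters for the `build_1d_cnn` layout: Conv1D, BatchNorm, Dense(128) and the softmax layer. Pooling and dropout layers have no parameters. The "extra block" (kernel 3, 256 channels) and the "doubled channels" schedules are hypothetical examples, and `cnn_params` is a hypothetical helper.

```python
def conv1d_params(k, c_in, c_out):
    return k * c_in * c_out + c_out            # kernel weights + bias

def cnn_params(channels, kernels, dense_units=128, n_classes=5, in_channels=1):
    """(total, non_trainable) parameter counts of the build_1d_cnn layout for a given schedule."""
    total, non_trainable, c_in = 0, 0, in_channels
    for c_out, k in zip(channels, kernels):
        total += conv1d_params(k, c_in, c_out)
        total += 4 * c_out                      # BatchNorm: gamma, beta, moving mean, moving variance
        non_trainable += 2 * c_out              # moving mean and variance are not trained
        c_in = c_out
    total += c_in * dense_units + dense_units   # Dense(128) after global average pooling
    total += dense_units * n_classes + n_classes  # softmax output layer
    return total, non_trainable

print(cnn_params([32, 64, 128], [7, 5, 3]))          # layout used above
print(cnn_params([32, 64, 128, 256], [7, 5, 3, 3]))  # one extra block (hypothetical)
print(cnn_params([64, 128, 256], [7, 5, 3]))         # doubled channels (hypothetical)
```

The first line reproduces the summary's 53,317 total and 448 non-trainable parameters. This is a quick check that the counting matches Keras before comparing the larger variants.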
# 3) Training
cb = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
]
hist = model.fit(
    X_tr, y_tr,
    validation_data=(X_va, y_va),
    epochs=50,
    batch_size=256,
    class_weight=class_weight,
    callbacks=cb,
    verbose=1
)
Epoch 1/50
55/55 [==============================] - 4s 41ms/step - loss: 1.3983 - accuracy: 0.2974 - val_loss: 1.6888 - val_accuracy: 0.0234 - lr: 0.0010
Epoch 2/50
55/55 [==============================] - 2s 36ms/step - loss: 1.0734 - accuracy: 0.4141 - val_loss: 2.2370 - val_accuracy: 0.0305 - lr: 0.0010
Epoch 3/50
55/55 [==============================] - 2s 36ms/step - loss: 0.9189 - accuracy: 0.4235 - val_loss: 2.7122 - val_accuracy: 0.0311 - lr: 0.0010
Epoch 4/50
55/55 [==============================] - 2s 37ms/step - loss: 0.8343 - accuracy: 0.4803 - val_loss: 2.9398 - val_accuracy: 0.0400 - lr: 0.0010
Epoch 5/50
55/55 [==============================] - 2s 37ms/step - loss: 0.7863 - accuracy: 0.4909 - val_loss: 2.7439 - val_accuracy: 0.0403 - lr: 5.0000e-04
Epoch 6/50
55/55 [==============================] - 11s 200ms/step - loss: 0.7505 - accuracy: 0.5074 - val_loss: 2.6304 - val_accuracy: 0.0514 - lr: 5.0000e-04
Epoch 7/50
55/55 [==============================] - 12s 220ms/step - loss: 0.7204 - accuracy: 0.5488 - val_loss: 2.3196 - val_accuracy: 0.1048 - lr: 5.0000e-04
Epoch 8/50
55/55 [==============================] - 2s 34ms/step - loss: 0.7089 - accuracy: 0.5453 - val_loss: 1.8006 - val_accuracy: 0.2643 - lr: 2.5000e-04
Epoch 9/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6870 - accuracy: 0.5555 - val_loss: 1.5701 - val_accuracy: 0.3394 - lr: 2.5000e-04
Epoch 10/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6709 - accuracy: 0.5479 - val_loss: 1.3493 - val_accuracy: 0.4099 - lr: 2.5000e-04
Epoch 11/50
55/55 [==============================] - 2s 35ms/step - loss: 0.6614 - accuracy: 0.5671 - val_loss: 1.4117 - val_accuracy: 0.3954 - lr: 2.5000e-04
Epoch 12/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6540 - accuracy: 0.5715 - val_loss: 1.3136 - val_accuracy: 0.4116 - lr: 2.5000e-04
Epoch 13/50
55/55 [==============================] - 2s 37ms/step - loss: 0.6511 - accuracy: 0.5828 - val_loss: 1.1824 - val_accuracy: 0.4382 - lr: 2.5000e-04
Epoch 14/50
55/55 [==============================] - 14s 258ms/step - loss: 0.6341 - accuracy: 0.6059 - val_loss: 1.1752 - val_accuracy: 0.4599 - lr: 2.5000e-04
Epoch 15/50
55/55 [==============================] - 8s 138ms/step - loss: 0.6238 - accuracy: 0.5824 - val_loss: 1.1429 - val_accuracy: 0.4950 - lr: 2.5000e-04
Epoch 16/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6284 - accuracy: 0.5942 - val_loss: 1.0967 - val_accuracy: 0.4747 - lr: 2.5000e-04
Epoch 17/50
55/55 [==============================] - 2s 35ms/step - loss: 0.6149 - accuracy: 0.6060 - val_loss: 1.1215 - val_accuracy: 0.4967 - lr: 2.5000e-04
Epoch 18/50
55/55 [==============================] - 2s 35ms/step - loss: 0.6041 - accuracy: 0.6247 - val_loss: 1.2169 - val_accuracy: 0.4459 - lr: 2.5000e-04
Epoch 19/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5890 - accuracy: 0.6167 - val_loss: 1.1233 - val_accuracy: 0.4425 - lr: 2.5000e-04
Epoch 20/50
55/55 [==============================] - 2s 37ms/step - loss: 0.5924 - accuracy: 0.6163 - val_loss: 1.1555 - val_accuracy: 0.4371 - lr: 1.2500e-04
Epoch 21/50
55/55 [==============================] - 5s 88ms/step - loss: 0.5795 - accuracy: 0.6340 - val_loss: 1.1013 - val_accuracy: 0.4804 - lr: 1.2500e-04
Epoch 22/50
55/55 [==============================] - 15s 269ms/step - loss: 0.5854 - accuracy: 0.6367 - val_loss: 1.1438 - val_accuracy: 0.4459 - lr: 1.2500e-04
Epoch 23/50
55/55 [==============================] - 4s 75ms/step - loss: 0.5819 - accuracy: 0.6288 - val_loss: 1.1051 - val_accuracy: 0.4859 - lr: 6.2500e-05
Epoch 24/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5769 - accuracy: 0.6416 - val_loss: 1.0771 - val_accuracy: 0.4990 - lr: 6.2500e-05
Epoch 25/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5672 - accuracy: 0.6254 - val_loss: 1.0909 - val_accuracy: 0.4927 - lr: 6.2500e-05
Epoch 26/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5637 - accuracy: 0.6376 - val_loss: 1.0779 - val_accuracy: 0.4921 - lr: 6.2500e-05
Epoch 27/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5760 - accuracy: 0.6438 - val_loss: 1.0539 - val_accuracy: 0.5090 - lr: 6.2500e-05
Epoch 28/50
55/55 [==============================] - 2s 37ms/step - loss: 0.5686 - accuracy: 0.6291 - val_loss: 1.0418 - val_accuracy: 0.5136 - lr: 6.2500e-05
Epoch 29/50
55/55 [==============================] - 9s 164ms/step - loss: 0.5743 - accuracy: 0.6503 - val_loss: 1.0573 - val_accuracy: 0.5104 - lr: 6.2500e-05
Epoch 30/50
55/55 [==============================] - 14s 251ms/step - loss: 0.5613 - accuracy: 0.6415 - val_loss: 1.0678 - val_accuracy: 0.5090 - lr: 6.2500e-05
Epoch 31/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5497 - accuracy: 0.6428 - val_loss: 1.0438 - val_accuracy: 0.5201 - lr: 6.2500e-05
Epoch 32/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5624 - accuracy: 0.6422 - val_loss: 1.0511 - val_accuracy: 0.5161 - lr: 3.1250e-05
Epoch 33/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5558 - accuracy: 0.6507 - val_loss: 1.0276 - val_accuracy: 0.5233 - lr: 3.1250e-05
Epoch 34/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5521 - accuracy: 0.6499 - val_loss: 1.0458 - val_accuracy: 0.5144 - lr: 3.1250e-05
Epoch 35/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5578 - accuracy: 0.6471 - val_loss: 1.0451 - val_accuracy: 0.5101 - lr: 3.1250e-05
Epoch 36/50
55/55 [==============================] - 2s 39ms/step - loss: 0.5499 - accuracy: 0.6488 - val_loss: 1.0554 - val_accuracy: 0.5059 - lr: 3.1250e-05
Epoch 37/50
55/55 [==============================] - 14s 258ms/step - loss: 0.5488 - accuracy: 0.6470 - val_loss: 1.0474 - val_accuracy: 0.5107 - lr: 1.5625e-05
Epoch 38/50
55/55 [==============================] - 8s 154ms/step - loss: 0.5546 - accuracy: 0.6442 - val_loss: 1.0524 - val_accuracy: 0.5081 - lr: 1.5625e-05
Epoch 39/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5499 - accuracy: 0.6490 - val_loss: 1.0415 - val_accuracy: 0.5158 - lr: 1.5625e-05
Epoch 40/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5419 - accuracy: 0.6495 - val_loss: 1.0420 - val_accuracy: 0.5161 - lr: 7.8125e-06
Epoch 41/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5347 - accuracy: 0.6526 - val_loss: 1.0405 - val_accuracy: 0.5178 - lr: 7.8125e-06
# 4) Evaluation: accuracy / ROC-AUC (one-vs-rest)
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.preprocessing import label_binarize

# Predicted class probabilities and labels on the test split
proba_te = model.predict(X_te)
pred_te  = proba_te.argmax(axis=1)

# Accuracy
acc = accuracy_score(y_te, pred_te)
print("Test Accuracy:", acc)

# ROC-AUC (multiclass → one-vs-rest, macro-averaged over classes)
y_te_bin = label_binarize(y_te, classes=np.arange(n_classes))
auc = roc_auc_score(y_te_bin, proba_te, average='macro', multi_class='ovr')
print("Test ROC-AUC (macro OvR):", auc)

# Per-class report and confusion matrix
print(classification_report(y_te, pred_te, digits=3))
print(confusion_matrix(y_te, pred_te))
137/137 [==============================] - 1s 4ms/step
Test Accuracy: 0.5181548298698333
Test ROC-AUC (macro OvR): 0.9432353917756396
              precision    recall  f1-score   support

           0      0.987     0.445     0.614      3624
           1      0.078     0.829     0.143       111
           2      0.446     0.807     0.574       290
           3      0.041     0.938     0.079        32
           4      0.949     0.929     0.939       322

    accuracy                          0.518      4379
   macro avg      0.500     0.789     0.470      4379
weighted avg      0.918     0.518     0.619      4379

[[1614 1080  276  640   14]
 [   9   92    4    6    0]
 [   2    6  234   46    2]
 [   0    1    1   30    0]
 [  11    0   10    2  299]]
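The printed metrics can be reproduced from the confusion matrix alone, and doing so makes the trade-off visible. 1,080 + 276 + 640 normal beats are predicted as S, V, or F. This raises minority-class recall but lowers those classes' precision, and it pulls overall accuracy down to about 0.52. The macro ROC-AUC (0.943) stays high because it measures how well the probabilities rank the classes, not the argmax decision. The matrix below is copied from the output above.

```python
import numpy as np

# Confusion matrix printed above (rows = true class, columns = predicted class)
cm = np.array([[1614, 1080, 276, 640,  14],
               [   9,   92,   4,   6,   0],
               [   2,    6, 234,  46,   2],
               [   0,    1,   1,  30,   0],
               [  11,    0,  10,   2, 299]])

accuracy  = np.trace(cm) / cm.sum()
recall    = np.diag(cm) / cm.sum(axis=1)   # per true class
precision = np.diag(cm) / cm.sum(axis=0)   # per predicted class
print(round(accuracy, 4))
print(np.round(recall, 3))
print(np.round(precision, 3))
```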

LSTM/GRU

  • Captures sequential dependencies in the ECG signal.
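A GRU layer uses three gates instead of the LSTM's four. Swapping `layers.LSTM` for `layers.GRU` in `build_bilstm` would therefore shrink the recurrent layers by roughly a quarter. The sketch below uses the standard parameter-count formulas and assumes Keras defaults, including `reset_after=True` for GRU, which gives each gate two bias vectors. The helper names are hypothetical.

```python
def lstm_params(units, input_dim):
    # 4 gates (input, forget, cell, output), each with input weights, recurrent weights and a bias
    return 4 * (units * (input_dim + units) + units)

def gru_params(units, input_dim, reset_after=True):
    # 3 gates (update, reset, candidate); reset_after=True uses two bias vectors per gate
    n_bias = 2 if reset_after else 1
    return 3 * (units * (input_dim + units) + n_bias * units)

# A Bidirectional wrapper doubles the count (one forward and one backward layer)
print("BiLSTM(64), 1 input channel:", 2 * lstm_params(64, 1))
print("BiLSTM(64) on 128 features: ", 2 * lstm_params(64, 128))
print("BiGRU(64),  1 input channel:", 2 * gru_params(64, 1))
print("BiGRU(64) on 128 features:  ", 2 * gru_params(64, 128))
```

The two LSTM figures match the `bidirectional` and `bidirectional_1` rows of the model summary (33,792 and 98,816).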

LSTM

# (tensorflow, layers and models are already imported in the 1D CNN section)

L = X_tr.shape[1]              # signal length per beat
n_classes = len(np.unique(y))  # 5

def build_bilstm(input_len=L, n_classes=n_classes):
    inp = layers.Input(shape=(input_len, 1))
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(inp)
    x = layers.Dropout(0.3)(x)
    x = layers.Bidirectional(layers.LSTM(64))(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    out = layers.Dense(n_classes, activation='softmax')(x)
    model = models.Model(inp, out)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    return model

lstm_model = build_bilstm()
lstm_model.summary()
Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_4 (InputLayer)        [(None, 186, 1)]          0         
                                                                 
 bidirectional (Bidirectiona  (None, 186, 128)         33792     
 l)                                                              
                                                                 
 dropout_9 (Dropout)         (None, 186, 128)          0         
                                                                 
 bidirectional_1 (Bidirectio  (None, 128)              98816     
 nal)                                                            
                                                                 
 dropout_10 (Dropout)        (None, 128)               0         
                                                                 
 dense_6 (Dense)             (None, 128)               16512     
                                                                 
 dropout_11 (Dropout)        (None, 128)               0         
                                                                 
 dense_7 (Dense)             (None, 5)                 645       
                                                                 
=================================================================
Total params: 149,765
Trainable params: 149,765
Non-trainable params: 0
_________________________________________________________________
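The `Param #` column above can be checked by hand. An LSTM layer has 4 gate blocks (input, forget, cell candidate, output), each with an input kernel, a recurrent kernel and a bias, giving `4 * (units * (input_dim + units) + units)` weights; `Bidirectional` keeps a separate forward and backward copy, so it doubles that. A small sketch of the arithmetic (the helper names are illustrative, not Keras functions):

```python
def lstm_params(input_dim, units):
    # 4 gates, each with input kernel (input_dim x units),
    # recurrent kernel (units x units) and a bias (units).
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units     # kernel + bias

bi1 = 2 * lstm_params(1, 64)     # forward + backward; 1 input feature per time step
bi2 = 2 * lstm_params(128, 64)   # input is the 2 * 64 = 128 concatenated features
d1 = dense_params(128, 128)
d2 = dense_params(128, 5)

print(bi1, bi2, d1, d2, bi1 + bi2 + d1 + d2)
# 33792 98816 16512 645 149765
```

The sequence length (186) does not appear anywhere in the count: the same weights are reused at every time step.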
GRU
def build_bigru(input_len=L, n_classes=n_classes):
    inp = layers.Input(shape=(input_len, 1))
    x = layers.Bidirectional(layers.GRU(64, return_sequences=True))(inp)
    x = layers.Dropout(0.3)(x)
    x = layers.Bidirectional(layers.GRU(64))(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    out = layers.Dense(n_classes, activation='softmax')(x)
    model = models.Model(inp, out)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    return model

gru_model = build_bigru()
gru_model.summary()
Model: "model_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 186, 1)]          0         
                                                                 
 bidirectional_2 (Bidirectio  (None, 186, 128)         25728     
 nal)                                                            
                                                                 
 dropout_12 (Dropout)        (None, 186, 128)          0         
                                                                 
 bidirectional_3 (Bidirectio  (None, 128)              74496     
 nal)                                                            
                                                                 
 dropout_13 (Dropout)        (None, 128)               0         
                                                                 
 dense_8 (Dense)             (None, 128)               16512     
                                                                 
 dropout_14 (Dropout)        (None, 128)               0         
                                                                 
 dense_9 (Dense)             (None, 5)                 645       
                                                                 
=================================================================
Total params: 117,381
Trainable params: 117,381
Non-trainable params: 0
_________________________________________________________________
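The same check works for the GRU summary. A GRU has 3 gate blocks (update, reset, candidate) instead of 4. In TF 2.x Keras the default is `reset_after=True`, which gives each gate separate input and recurrent bias vectors, so the bias term is `2 * units`. A sketch of the arithmetic (helper names are illustrative):

```python
def gru_params(input_dim, units, reset_after=True):
    # 3 gates; with reset_after=True there are two bias vectors per gate.
    bias = 2 * units if reset_after else units
    return 3 * (units * (input_dim + units) + bias)

def dense_params(input_dim, units):
    return input_dim * units + units

bi1 = 2 * gru_params(1, 64)      # Bidirectional: forward + backward copies
bi2 = 2 * gru_params(128, 64)
total = bi1 + bi2 + dense_params(128, 128) + dense_params(128, 5)

print(bi1, bi2, total)           # 25728 74496 117381
```

With 3 gate blocks instead of 4, the GRU stack ends up with about 78% of the BiLSTM's parameters (117,381 vs 149,765).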
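# Halve the learning rate after 3 epochs without val_loss improvement;
# stop after 8 such epochs and restore the best weights.
cbs = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
]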

hist_lstm = lstm_model.fit(
    X_tr, y_tr,
    validation_data=(X_va, y_va),
    epochs=50, batch_size=256,
    class_weight=class_weight,
    callbacks=cbs, verbose=1
)

hist_gru = gru_model.fit(
    X_tr, y_tr,
    validation_data=(X_va, y_va),
    epochs=50, batch_size=256,
    class_weight=class_weight,
    callbacks=cbs, verbose=1
)
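The `lr` column in the logs below follows from these two callbacks. A plain-Python sketch that replays their logic on a list of `val_loss` values, assuming the Keras defaults not set explicitly above (`min_delta=1e-4`, `cooldown=0`, `min_lr=0` for `ReduceLROnPlateau`; `min_delta=0` for `EarlyStopping`). It is a simplified re-implementation for illustration, not the Keras source, and the logged losses are rounded to 4 decimals, which could matter only for near-ties.

```python
import math

def simulate_callbacks(val_losses, lr0=1e-3, factor=0.5, lr_patience=3,
                       lr_min_delta=1e-4, es_patience=8):
    """Replay ReduceLROnPlateau + EarlyStopping on a list of val_loss values.

    Returns (lr in effect during each epoch, 1-based epoch at which training
    would stop, or None if it never stops within the list).
    """
    lr = lr0
    lr_best, lr_wait = math.inf, 0
    es_best, es_wait = math.inf, 0
    lrs = []
    for epoch, v in enumerate(val_losses, start=1):
        lrs.append(lr)                        # lr used for this epoch
        # ReduceLROnPlateau: must beat the best by more than min_delta.
        if v < lr_best - lr_min_delta:
            lr_best, lr_wait = v, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr *= factor                  # takes effect from the next epoch
                lr_wait = 0
        # EarlyStopping with min_delta=0: any decrease counts as improvement.
        if v < es_best:
            es_best, es_wait = v, 0
        else:
            es_wait += 1
            if es_wait >= es_patience:
                return lrs, epoch
    return lrs, None

# GRU val_loss for epochs 1-11, copied from the log below.
gru_val = [1.3737, 1.3370, 1.2278, 1.1295, 0.9094, 1.1571,
           1.0658, 0.9522, 0.7449, 0.8811, 0.7216]
lrs, stop = simulate_callbacks(gru_val)
print(lrs)    # 0.001 for epochs 1-8, then 0.0005 from epoch 9, as in the log
print(stop)   # None
```

Epochs 6 to 8 fail to improve on the epoch-5 best (0.9094), so the third miss triggers the halving and epoch 9 runs at 5e-4.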
Epoch 1/50
55/55 [==============================] - 53s 860ms/step - loss: 1.4447 - accuracy: 0.2344 - val_loss: 1.3851 - val_accuracy: 0.2281 - lr: 0.0010
Epoch 2/50
55/55 [==============================] - 55s 985ms/step - loss: 1.0811 - accuracy: 0.3211 - val_loss: 1.1173 - val_accuracy: 0.4059 - lr: 0.0010
Epoch 3/50
55/55 [==============================] - 62s 1s/step - loss: 0.8547 - accuracy: 0.4879 - val_loss: 0.7796 - val_accuracy: 0.7025 - lr: 0.0010
Epoch 4/50
55/55 [==============================] - 50s 921ms/step - loss: 0.7357 - accuracy: 0.6732 - val_loss: 0.7299 - val_accuracy: 0.8007 - lr: 0.0010
Epoch 5/50
55/55 [==============================] - 55s 969ms/step - loss: 0.7025 - accuracy: 0.6655 - val_loss: 0.8919 - val_accuracy: 0.6691 - lr: 0.0010
Epoch 6/50
55/55 [==============================] - 59s 1s/step - loss: 0.6482 - accuracy: 0.7323 - val_loss: 0.8052 - val_accuracy: 0.7439 - lr: 0.0010
Epoch 7/50
55/55 [==============================] - 55s 1s/step - loss: 0.6159 - accuracy: 0.7287 - val_loss: 0.7832 - val_accuracy: 0.7973 - lr: 0.0010
Epoch 8/50
55/55 [==============================] - 46s 823ms/step - loss: 0.5925 - accuracy: 0.7509 - val_loss: 0.6506 - val_accuracy: 0.8356 - lr: 5.0000e-04
Epoch 9/50
55/55 [==============================] - 49s 899ms/step - loss: 0.5229 - accuracy: 0.7952 - val_loss: 0.6809 - val_accuracy: 0.8139 - lr: 5.0000e-04
Epoch 12/50
55/55 [==============================] - 50s 896ms/step - loss: 0.4991 - accuracy: 0.7887 - val_loss: 0.6454 - val_accuracy: 0.8233 - lr: 2.5000e-04
Epoch 13/50
55/55 [==============================] - 59s 1s/step - loss: 0.4935 - accuracy: 0.7840 - val_loss: 0.6209 - val_accuracy: 0.8441 - lr: 2.5000e-04
Epoch 14/50
55/55 [==============================] - 58s 1s/step - loss: 0.4958 - accuracy: 0.8033 - val_loss: 0.6542 - val_accuracy: 0.8367 - lr: 2.5000e-04
Epoch 15/50
55/55 [==============================] - 46s 836ms/step - loss: 0.4884 - accuracy: 0.8053 - val_loss: 0.6998 - val_accuracy: 0.8139 - lr: 2.5000e-04
Epoch 16/50
55/55 [==============================] - 56s 1s/step - loss: 0.4653 - accuracy: 0.8044 - val_loss: 0.6575 - val_accuracy: 0.8167 - lr: 2.5000e-04
Epoch 17/50
55/55 [==============================] - 60s 1s/step - loss: 0.4719 - accuracy: 0.8073 - val_loss: 0.6300 - val_accuracy: 0.8390 - lr: 1.2500e-04
Epoch 18/50
55/55 [==============================] - 52s 957ms/step - loss: 0.4520 - accuracy: 0.8132 - val_loss: 0.5995 - val_accuracy: 0.8396 - lr: 1.2500e-04
Epoch 19/50
55/55 [==============================] - 49s 873ms/step - loss: 0.4573 - accuracy: 0.8223 - val_loss: 0.6010 - val_accuracy: 0.8410 - lr: 1.2500e-04
Epoch 20/50
55/55 [==============================] - 61s 1s/step - loss: 0.4486 - accuracy: 0.8178 - val_loss: 0.5740 - val_accuracy: 0.8567 - lr: 1.2500e-04
Epoch 21/50
55/55 [==============================] - 58s 1s/step - loss: 0.4543 - accuracy: 0.8139 - val_loss: 0.5989 - val_accuracy: 0.8393 - lr: 1.2500e-04
Epoch 22/50
55/55 [==============================] - 45s 821ms/step - loss: 0.4522 - accuracy: 0.8105 - val_loss: 0.5704 - val_accuracy: 0.8567 - lr: 1.2500e-04
Epoch 23/50
55/55 [==============================] - 57s 1s/step - loss: 0.4456 - accuracy: 0.8170 - val_loss: 0.5845 - val_accuracy: 0.8464 - lr: 1.2500e-04
Epoch 24/50
55/55 [==============================] - 60s 1s/step - loss: 0.4436 - accuracy: 0.8166 - val_loss: 0.5719 - val_accuracy: 0.8464 - lr: 1.2500e-04
Epoch 25/50
55/55 [==============================] - 48s 873ms/step - loss: 0.4467 - accuracy: 0.8261 - val_loss: 0.6169 - val_accuracy: 0.8333 - lr: 1.2500e-04
Epoch 26/50
55/55 [==============================] - 58s 1s/step - loss: 0.4420 - accuracy: 0.8108 - val_loss: 0.5781 - val_accuracy: 0.8527 - lr: 6.2500e-05
Epoch 27/50
55/55 [==============================] - 57s 1s/step - loss: 0.4332 - accuracy: 0.8208 - val_loss: 0.5613 - val_accuracy: 0.8536 - lr: 6.2500e-05
Epoch 28/50
55/55 [==============================] - 47s 854ms/step - loss: 0.4198 - accuracy: 0.8161 - val_loss: 0.5954 - val_accuracy: 0.8461 - lr: 6.2500e-05
Epoch 29/50
55/55 [==============================] - 53s 953ms/step - loss: 0.4212 - accuracy: 0.8185 - val_loss: 0.5794 - val_accuracy: 0.8461 - lr: 6.2500e-05
Epoch 30/50
55/55 [==============================] - 59s 1s/step - loss: 0.4263 - accuracy: 0.8286 - val_loss: 0.5395 - val_accuracy: 0.8578 - lr: 6.2500e-05
Epoch 31/50
55/55 [==============================] - 44s 812ms/step - loss: 0.4313 - accuracy: 0.8131 - val_loss: 0.5670 - val_accuracy: 0.8470 - lr: 6.2500e-05
Epoch 32/50
55/55 [==============================] - 54s 996ms/step - loss: 0.4285 - accuracy: 0.8227 - val_loss: 0.5752 - val_accuracy: 0.8484 - lr: 6.2500e-05
Epoch 33/50
55/55 [==============================] - 48s 853ms/step - loss: 0.4234 - accuracy: 0.8249 - val_loss: 0.5688 - val_accuracy: 0.8501 - lr: 6.2500e-05
Epoch 34/50
55/55 [==============================] - 49s 891ms/step - loss: 0.4210 - accuracy: 0.8283 - val_loss: 0.5749 - val_accuracy: 0.8470 - lr: 3.1250e-05
Epoch 35/50
55/55 [==============================] - 56s 1s/step - loss: 0.4159 - accuracy: 0.8269 - val_loss: 0.5576 - val_accuracy: 0.8530 - lr: 3.1250e-05
Epoch 36/50
55/55 [==============================] - 45s 826ms/step - loss: 0.4144 - accuracy: 0.8286 - val_loss: 0.5677 - val_accuracy: 0.8490 - lr: 3.1250e-05
Epoch 37/50
55/55 [==============================] - 59s 1s/step - loss: 0.4163 - accuracy: 0.8278 - val_loss: 0.5573 - val_accuracy: 0.8507 - lr: 1.5625e-05
Epoch 38/50
55/55 [==============================] - 56s 1s/step - loss: 0.4125 - accuracy: 0.8261 - val_loss: 0.5558 - val_accuracy: 0.8533 - lr: 1.5625e-05
Epoch 1/50
55/55 [==============================] - 38s 601ms/step - loss: 1.4778 - accuracy: 0.2935 - val_loss: 1.3737 - val_accuracy: 0.3097 - lr: 0.0010
Epoch 2/50
55/55 [==============================] - 33s 608ms/step - loss: 1.2010 - accuracy: 0.3377 - val_loss: 1.3370 - val_accuracy: 0.2412 - lr: 0.0010
Epoch 3/50
55/55 [==============================] - 33s 607ms/step - loss: 1.0183 - accuracy: 0.3458 - val_loss: 1.2278 - val_accuracy: 0.4171 - lr: 0.0010
Epoch 4/50
55/55 [==============================] - 31s 572ms/step - loss: 0.9265 - accuracy: 0.3909 - val_loss: 1.1295 - val_accuracy: 0.4256 - lr: 0.0010
Epoch 5/50
55/55 [==============================] - 34s 620ms/step - loss: 0.8314 - accuracy: 0.4499 - val_loss: 0.9094 - val_accuracy: 0.5678 - lr: 0.0010
Epoch 6/50
55/55 [==============================] - 33s 605ms/step - loss: 0.7969 - accuracy: 0.5272 - val_loss: 1.1571 - val_accuracy: 0.3825 - lr: 0.0010
Epoch 7/50
55/55 [==============================] - 33s 597ms/step - loss: 0.7091 - accuracy: 0.5333 - val_loss: 1.0658 - val_accuracy: 0.5133 - lr: 0.0010
Epoch 8/50
55/55 [==============================] - 33s 595ms/step - loss: 0.6783 - accuracy: 0.5704 - val_loss: 0.9522 - val_accuracy: 0.6472 - lr: 0.0010
Epoch 9/50
55/55 [==============================] - 31s 575ms/step - loss: 0.6493 - accuracy: 0.6109 - val_loss: 0.7449 - val_accuracy: 0.7268 - lr: 5.0000e-04
Epoch 10/50
55/55 [==============================] - 33s 597ms/step - loss: 0.6403 - accuracy: 0.6452 - val_loss: 0.8811 - val_accuracy: 0.6594 - lr: 5.0000e-04
Epoch 11/50
55/55 [==============================] - 34s 624ms/step - loss: 0.6106 - accuracy: 0.6439 - val_loss: 0.7216 - val_accuracy: 0.7608 - lr: 5.0000e-04
Epoch 12/50
15/55 [=======>......................] - ETA: 25s - loss: 0.5640 - accuracy: 0.6893
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.preprocessing import label_binarize

def evaluate(model, X_te, y_te, n_classes):
    proba = model.predict(X_te)
    pred = proba.argmax(axis=1)
    acc = accuracy_score(y_te, pred)
    y_bin = label_binarize(y_te, classes=np.arange(n_classes))
    auc = roc_auc_score(y_bin, proba, average='macro', multi_class='ovr')
    print("Accuracy:", acc)
    print("ROC-AUC (macro OvR):", auc)
    print(classification_report(y_te, pred, digits=3))
    print(confusion_matrix(y_te, pred))
    return acc, auc

print("=== LSTM ===")
acc_lstm, auc_lstm = evaluate(lstm_model, X_te, y_te, n_classes)

print("\n=== GRU ===")
acc_gru, auc_gru = evaluate(gru_model, X_te, y_te, n_classes)
=== LSTM ===
137/137 [==============================] - 9s 62ms/step
Accuracy: 0.8597853391185202
ROC-AUC (macro OvR): 0.9577867423130026
              precision    recall  f1-score   support

           0      0.980     0.858     0.915      3624
           1      0.280     0.748     0.408       111
           2      0.743     0.817     0.778       290
           3      0.154     0.906     0.264        32
           4      0.763     0.950     0.846       322

    accuracy                          0.860      4379
   macro avg      0.584     0.856     0.642      4379
weighted avg      0.924     0.860     0.883      4379

[[3110  199   78  148   89]
 [  24   83    2    1    1]
 [  27   12  237    9    5]
 [   1    1    1   29    0]
 [  13    1    1    1  306]]

=== GRU ===
137/137 [==============================] - 7s 54ms/step
Accuracy: 0.798812514272665
ROC-AUC (macro OvR): 0.9510862539659783
              precision    recall  f1-score   support

           0      0.975     0.787     0.871      3624
           1      0.194     0.721     0.305       111
           2      0.581     0.841     0.687       290
           3      0.101     0.875     0.182        32
           4      0.849     0.910     0.879       322

    accuracy                          0.799      4379
   macro avg      0.540     0.827     0.585      4379
weighted avg      0.914     0.799     0.840      4379

[[2853  327  170  229   45]
 [  26   80    2    3    0]
 [  18    6  244   15    7]
 [   2    0    2   28    0]
 [  26    0    2    1  293]]
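The "ROC-AUC (macro OvR)" figures above average one binary AUC per class, each computed as "class k vs. all other classes". A binary AUC equals the probability that a randomly chosen positive beat gets a higher score than a randomly chosen negative one (ties count one half). The NumPy sketch below computes it that way on a tiny hand-made example, not on the ECG data; the helper names are illustrative. One detail worth knowing: because the notebook passes the `label_binarize` output (a 2-D indicator matrix), scikit-learn treats the input as multilabel, and its documentation states that `multi_class` is only used for multiclass targets; the macro average over columns is still the one-vs-rest macro AUC.

```python
import numpy as np

def binary_auc(is_pos, score):
    """AUC as P(score of a random positive > score of a random negative), ties = 1/2."""
    is_pos = np.asarray(is_pos, dtype=bool)
    score = np.asarray(score, dtype=float)
    pos, neg = score[is_pos], score[~is_pos]
    # Compare every positive with every negative (fine for small arrays).
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

def macro_ovr_auc(y, proba):
    aucs = [binary_auc(y == k, proba[:, k]) for k in range(proba.shape[1])]
    return float(np.mean(aucs)), aucs

# Toy 3-class example; beat 4 (true class 1) is misclassified as class 0.
y = np.array([0, 1, 2, 0, 1, 2])
proba = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.3, 0.6],
    [0.6, 0.3, 0.1],
    [0.7, 0.2, 0.1],
    [0.3, 0.2, 0.5],
])
macro, aucs = macro_ovr_auc(y, proba)
print(aucs)    # per-class AUCs: 0.875, 0.6875, 1.0
print(macro)   # about 0.854
```

Only the ordering of the scores matters here, which is why the LSTM (0.958) and GRU (0.951) macro AUCs sit close together even though their accuracies (0.860 vs 0.799) and minority-class precisions differ more.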