import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pywt
from scipy.stats import f_oneway
import scikit_posthocs as sp
from scipy.stats import kruskal
from scipy.stats import skew, kurtosis
from scipy.fft import rfft, rfftfreqECG Heartbeat Categorization Dataset
Abstract - This dataset is composed of two collections of heartbeat signals derived from two famous datasets in heartbeat classification, the MIT-BIH Arrhythmia Dataset and The PTB Diagnostic ECG Database. The number of samples in both collections is large enough for training a deep neural network. - This dataset has been used in exploring heartbeat classification using deep neural network architectures, and observing some of the capabilities of transfer learning on it. The signals correspond to electrocardiogram (ECG) shapes of heartbeats for the normal case and the cases affected by different arrhythmias and myocardial infarction. These signals are preprocessed and segmented, with each segment corresponding to a heartbeat.
Content
- Arrhythmia Dataset
- Number of Samples: 109446
- Number of Categories: 5
- Sampling Frequency: 125Hz
- Data Source: Physionet’s MIT-BIH Arrhythmia Dataset
- Classes: [‘N’: 0, ‘S’: 1, ‘V’: 2, ‘F’: 3, ‘Q’: 4]
The PTB Diagnostic ECG Database - Number of Samples: 14552 - Number of Categories: 2 - Sampling Frequency: 125Hz - Data Source: Physionet’s PTB Diagnostic Database - Remark: All the samples are cropped, downsampled and padded with zeroes if necessary to the fixed dimension of 188.
Data Files - This dataset consists of a series of CSV files. Each of these CSV files contain a matrix, with each row representing an example in that portion of the dataset. The final element of each row denotes the class to which that example belongs.
Acknowledgements - Mohammad Kachuee, Shayan Fazeli, and Majid Sarrafzadeh. “ECG Heartbeat Classification: A Deep Transferable Representation.” arXiv preprint arXiv:1805.00794 (2018).
Inspiration - Can you identify myocardial infarction?
Import
데이터 이해 & 전처리
데이터셋 구성 확인
df = pd.read_csv('../../../delete/mitbih_test.csv',header=None)df.head(5)| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.000000 | 0.758264 | 0.111570 | 0.000000 | 0.080579 | 0.078512 | 0.066116 | 0.049587 | 0.047521 | 0.035124 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 0.908425 | 0.783883 | 0.531136 | 0.362637 | 0.366300 | 0.344322 | 0.333333 | 0.307692 | 0.296703 | 0.300366 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 0.730088 | 0.212389 | 0.000000 | 0.119469 | 0.101770 | 0.101770 | 0.110619 | 0.123894 | 0.115044 | 0.132743 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 1.000000 | 0.910417 | 0.681250 | 0.472917 | 0.229167 | 0.068750 | 0.000000 | 0.004167 | 0.014583 | 0.054167 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 0.570470 | 0.399329 | 0.238255 | 0.147651 | 0.000000 | 0.003356 | 0.040268 | 0.080537 | 0.070470 | 0.090604 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
5 rows × 188 columns
df.shape(21892, 188)
sample_idx = np.random.choice(df.shape[0], 9, replace=False)
plt.figure(figsize=(12, 10))
for i, idx in enumerate(sample_idx, 1):
ecg = df.iloc[idx, :-1].values if df.shape[1] > 188 else df.iloc[idx, :].values
plt.subplot(3, 3, i)
plt.plot(ecg, color="blue")
plt.title(f"Sample {idx}")
plt.xlabel("Time index (0-187)")
plt.ylabel("Amplitude")
plt.tight_layout()
plt.show()
df[187].unique()array([0., 1., 2., 3., 4.])
- MIT-BIH (부정맥 → 5 classes)
전처리
- Normalization (z-score scaling)
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values탐색적 데이터 분석 (EDA)
평균 파형 시각화
label_col = df.columns[-1]
feature_cols = df.columns[:-1]
n_points = len(feature_cols)fs = 125.0
t = np.arange(n_points) / fslabel_map = {0:'N', 1:'S', 2:'V', 3:'F', 4:'Q'}# 라벨별 평균 파형 계산
means = []
labels_present = sorted(df[label_col].unique())
for y in labels_present:
m = df.loc[df[label_col] == y, feature_cols].mean(axis=0).values
means.append((y, m))
# 한 그래프에 라벨별 평균 파형 그리기
plt.figure(figsize=(12,6))
for y, m in means:
plt.plot(t, m, label=f"{label_map.get(int(y), y)} ({int(y)})")
plt.title("Average ECG Waveform by Class")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
스펙트럼 분석 - FFT or Wavelet 변환을 통한 frequency 특징 비교.
FFT
# ----- 기본 설정 -----
fs = 125.0
label_col = df.columns[-1]
feature_cols = df.columns[:-1]
N = len(feature_cols) # 188
t = np.arange(N) / fs
freqs = np.fft.rfftfreq(N, d=1/fs) # 0 ~ Nyquist(62.5 Hz)
window = np.hanning(N) # 한닝 윈도우(누설 감소)
# 라벨 매핑(가독성)
label_map = {0:'N', 1:'S', 2:'V', 3:'F', 4:'Q'}
# ----- 클래스별 평균 파워 스펙트럼 -----
plt.figure(figsize=(12,6))
for y in sorted(df[label_col].unique()):
X = df.loc[df[label_col] == y, feature_cols].values # (n_class_samples, N)
# detrend(평균 제거) + windowing + FFT
X_detrended = X - X.mean(axis=1, keepdims=True)
Xw = X_detrended * window
F = np.fft.rfft(Xw, axis=1)
P = (np.abs(F)**2) / (window**2).sum() # 파워 정규화
P_mean = P.mean(axis=0)
plt.plot(freqs, P_mean, label=f"{label_map.get(int(y), y)} ({int(y)})")
plt.title("Average Power Spectrum by Class (FFT)")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
Wavelet
# ----- DWT 파라미터 -----
wavelet = 'db4'
level = 4 # 188 포인트면 4~5레벨 적당
# 한 샘플에서 서브밴드 에너지 벡터 추출
def wavelet_band_energy(x):
# 평균 제거(기저선 영향 완화)
x = x - x.mean()
coeffs = pywt.wavedec(x, wavelet=wavelet, level=level)
# coeffs = [cA_L, cD_L, cD_{L-1}, ..., cD_1]
energies = [np.sum(c**2) for c in coeffs]
total = np.sum(energies) + 1e-12
# 상대 에너지(합=1)로 반환하면 샘플 간 진폭 차이 영향 완화
return np.array(energies) / total
# 전체 샘플 처리
feat_list = []
labels = []
for _, row in df.iterrows():
x = row[feature_cols].values.astype(float)
e = wavelet_band_energy(x)
feat_list.append(e)
labels.append(int(row[label_col]))
E = np.vstack(feat_list) # shape: (n_samples, 1 + level)
y = np.array(labels)
# 클래스별 평균 상대 에너지
classes = sorted(np.unique(y))
E_mean = {c: E[y==c].mean(axis=0) for c in classes}
# 플롯: 각 클래스의 subband 에너지 비교 (막대그래프)
bands = [f"cA_{level}"] + [f"cD_{i}" for i in range(level, 0, -1)] # 상위 스케일→하위 스케일
x_pos = np.arange(len(bands))
plt.figure(figsize=(12,6))
bar_width = 0.14
for i, c in enumerate(classes):
plt.bar(x_pos + i*bar_width, E_mean[c], width=bar_width, label=f"{c}")
plt.xticks(x_pos + bar_width*(len(classes)-1)/2, bands)
plt.ylabel("Relative Energy")
plt.title(f"Wavelet Subband Energy by Class (wavelet={wavelet}, level={level})")
plt.legend(title="Class (0=N,1=S,2=V,3=F,4=Q)")
plt.tight_layout()
plt.show()
통계적 비교 - Classes: [‘N’: 0, ‘S’: 1, ‘V’: 2, ‘F’: 3, ‘Q’: 4] 간에 amplitude, variance, QRS duration 등 특징 차이 통계 검정 (t-test, Mann-Whitney).
df['amp_mean'] = df.iloc[:, :-1].mean(axis=1)
df['amp_max'] = df.iloc[:, :-1].max(axis=1)
df['amp_min'] = df.iloc[:, :-1].min(axis=1)
df['amp_range'] = df['amp_max'] - df['amp_min']df['variance'] = df.iloc[:, :-1].var(axis=1)def qrs_duration(signal, threshold_ratio=0.5):
# R peak amplitude 추정
peak = np.max(signal)
threshold = peak * threshold_ratio
above = np.where(signal >= threshold)[0]
if len(above) == 0:
return 0
return above[-1] - above[0] # index 차이 (포인트 수)
df['qrs_dur'] = df.iloc[:, :-1].apply(lambda row: qrs_duration(row.values), axis=1)label_map = {0:'N', 1:'S', 2:'V', 3:'F', 4:'Q'}
desc = df.groupby(df.iloc[:, -1])[['amp_mean','variance','qrs_dur']].describe()groups = [df.loc[df.iloc[:, -1] == c, 'amp_mean'] for c in sorted(df.iloc[:, -1].unique())]
anova_amp = f_oneway(*groups)
groups = [df.loc[df.iloc[:, -1] == c, 'variance'] for c in sorted(df.iloc[:, -1].unique())]
anova_var = f_oneway(*groups)
groups = [df.loc[df.iloc[:, -1] == c, 'qrs_dur'] for c in sorted(df.iloc[:, -1].unique())]
anova_qrs = f_oneway(*groups)
print("ANOVA amp_mean:", anova_amp)
print("ANOVA variance:", anova_var)
print("ANOVA qrs_dur:", anova_qrs)ANOVA amp_mean: F_onewayResult(statistic=97.23061517562662, pvalue=0.0)
ANOVA variance: F_onewayResult(statistic=673.9205351617329, pvalue=0.0)
ANOVA qrs_dur: F_onewayResult(statistic=inf, pvalue=0.0)
/home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages/scipy/stats/_stats_py.py:3903: ConstantInputWarning: Each of the input arrays is constant;the F statistic is not defined or infinite
warnings.warn(stats.ConstantInputWarning(msg))
groups = [df.loc[df.iloc[:, -1] == c, 'amp_mean'] for c in sorted(df.iloc[:, -1].unique())]
kw_amp = kruskal(*groups)
groups = [df.loc[df.iloc[:, -1] == c, 'variance'] for c in sorted(df.iloc[:, -1].unique())]
kw_var = kruskal(*groups)
groups = [df.loc[df.iloc[:, -1] == c, 'qrs_dur'] for c in sorted(df.iloc[:, -1].unique())]
kw_qrs = kruskal(*groups)
print("Kruskal amp_mean:", kw_amp)
print("Kruskal variance:", kw_var)
print("Kruskal qrs_dur:", kw_qrs)Kruskal amp_mean: KruskalResult(statistic=4731.5315295296, pvalue=0.0)
Kruskal variance: KruskalResult(statistic=8376.756741998353, pvalue=0.0)
Kruskal qrs_dur: KruskalResult(statistic=21891.0, pvalue=0.0)
posthoc_amp = sp.posthoc_dunn(df, val_col='amp_mean', group_col=df.columns[-1], p_adjust='bonferroni')
print(posthoc_amp) 4 5 6 7 9 10 12 14 15 16 ... 182 \
4 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
5 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
6 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
7 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
9 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
.. ... ... ... ... ... ... ... ... ... ... ... ...
187 1.353411e-35 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
188 3.901249e-05 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
189 1.318389e-12 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
190 2.638355e-03 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
191 4.113517e-283 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
183 184 185 186 187 188 \
4 1.0 1.000000 1.000000e+00 7.231955e-26 1.353411e-35 3.901249e-05
5 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
6 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
7 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
9 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
.. ... ... ... ... ... ...
187 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
188 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
189 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
190 1.0 1.000000 1.000000e+00 3.018640e-04 6.368383e-07 1.000000e+00
191 1.0 0.000027 4.575056e-07 8.248003e-126 8.271528e-161 2.793909e-34
189 190 191
4 1.318389e-12 2.638355e-03 4.113517e-283
5 1.000000e+00 1.000000e+00 1.000000e+00
6 1.000000e+00 1.000000e+00 1.000000e+00
7 1.000000e+00 1.000000e+00 1.000000e+00
9 1.000000e+00 1.000000e+00 1.000000e+00
.. ... ... ...
187 1.000000e+00 6.368383e-07 8.271528e-161
188 1.000000e+00 1.000000e+00 2.793909e-34
189 1.000000e+00 3.952042e-01 1.699180e-68
190 3.952042e-01 1.000000e+00 2.442212e-88
191 1.699180e-68 2.442212e-88 1.000000e+00
[181 rows x 181 columns]
# 예: amp_mean 특징을 클래스별 비교
posthoc_amp = sp.posthoc_dunn(
df,
val_col='amp_mean',
group_col=df.columns[-1], # 마지막 열이 라벨
p_adjust='bonferroni'
)
print("Dunn post-hoc test (amp_mean):")
print(posthoc_amp)Dunn post-hoc test (amp_mean):
4 5 6 7 9 10 12 14 15 16 ... 182 \
4 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
5 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
6 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
7 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
9 1.000000e+00 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
.. ... ... ... ... ... ... ... ... ... ... ... ...
187 1.353411e-35 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
188 3.901249e-05 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
189 1.318389e-12 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
190 2.638355e-03 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
191 4.113517e-283 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0
183 184 185 186 187 188 \
4 1.0 1.000000 1.000000e+00 7.231955e-26 1.353411e-35 3.901249e-05
5 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
6 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
7 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
9 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
.. ... ... ... ... ... ...
187 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
188 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
189 1.0 1.000000 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00
190 1.0 1.000000 1.000000e+00 3.018640e-04 6.368383e-07 1.000000e+00
191 1.0 0.000027 4.575056e-07 8.248003e-126 8.271528e-161 2.793909e-34
189 190 191
4 1.318389e-12 2.638355e-03 4.113517e-283
5 1.000000e+00 1.000000e+00 1.000000e+00
6 1.000000e+00 1.000000e+00 1.000000e+00
7 1.000000e+00 1.000000e+00 1.000000e+00
9 1.000000e+00 1.000000e+00 1.000000e+00
.. ... ... ...
187 1.000000e+00 6.368383e-07 8.271528e-161
188 1.000000e+00 1.000000e+00 2.793909e-34
189 1.000000e+00 3.952042e-01 1.699180e-68
190 3.952042e-01 1.000000e+00 2.442212e-88
191 1.699180e-68 2.442212e-88 1.000000e+00
[181 rows x 181 columns]
# Boxplot
plt.figure(figsize=(8,6))
sns.boxplot(x=df.iloc[:, -1], y=df['amp_mean'])
plt.title("Amplitude Mean by Class")
plt.xlabel("Class (0=N,1=S,2=V,3=F,4=Q)")
plt.ylabel("amp_mean")
plt.show()
# Posthoc heatmap
plt.figure(figsize=(6,5))
sns.heatmap(posthoc_amp, annot=True, cmap="coolwarm", cbar=True, fmt=".3f")
plt.title("Post-hoc Dunn Test (p-values)")
plt.show()
특징 기반 머신러닝 접근
특징 추출 - 시간영역: mean, std, skewness, kurtosis, peak-to-peak - 주파수영역: dominant frequency, power spectrum - Wavelet: ST segment 관련 coefficient
def extract_time_features(x):
return {
'mean': float(np.mean(x)),
'std': float(np.std(x)),
'skew': float(skew(x)),
'kurtosis': float(kurtosis(x)),
'ptp': float(np.ptp(x)), # peak-to-peak
}
def extract_freq_features(x, fs=125):
N = len(x)
freqs = rfftfreq(N, d=1/fs)
spectrum = np.abs(rfft(x))**2
dom_freq = float(freqs[np.argmax(spectrum)])
power = float(np.sum(spectrum))
return {'dom_freq': dom_freq, 'power': power}
def extract_wavelet_features(x, wavelet='db4', level=4):
x = x - x.mean()
coeffs = pywt.wavedec(x, wavelet=wavelet, level=level)
energies = [np.sum(c**2) for c in coeffs]
total = np.sum(energies) + 1e-12
rel = [float(e/total) for e in energies]
# cA_L, cD_L, ..., cD_1
names = [f'wavelet_energy_cA_{level}'] + [f'wavelet_energy_cD_{i}' for i in range(level, 0, -1)]
return dict(zip(names, rel))
def extract_features(x, fs=125):
feats = {}
feats.update(extract_time_features(x))
feats.update(extract_freq_features(x, fs=fs))
feats.update(extract_wavelet_features(x))
return feats# ---- 라벨/피처 컬럼 안전하게 정의 ----
label_col = df.columns[-1] # 마지막 열을 라벨로 사용
feature_cols = df.columns[:-1] # 나머지는 ECG 시계열
# ---- 넘파이로 뽑아서 루프 (빠르고 안전) ----
X = df.loc[:, feature_cols].to_numpy(dtype=float)
y = df.loc[:, label_col].to_numpy()
feature_list = []
labels = []
for xi, yi in zip(X, y):
feats = extract_features(xi, fs=125)
feature_list.append(feats)
labels.append(int(yi))
feature_df = pd.DataFrame(feature_list)
feature_df['label'] = labels
print(feature_df.head())
print(feature_df.shape) mean std skew kurtosis ptp dom_freq power \
0 0.088813 0.165259 4.148272 18.524207 1.0 0.0 802.459636
1 0.201932 0.247202 1.129980 0.747387 1.0 0.0 2657.005592
2 0.119641 0.163150 2.984116 12.369159 1.0 0.0 1028.928658
3 0.168079 0.222119 1.874741 3.945423 1.0 0.0 1971.182895
4 0.207004 0.219523 1.185874 1.866668 1.0 0.0 2493.665686
wavelet_energy_cA_4 wavelet_energy_cD_4 wavelet_energy_cD_3 \
0 0.581605 0.070714 0.143888
1 0.781645 0.074139 0.056645
2 0.494606 0.057453 0.107020
3 0.792486 0.068899 0.062562
4 0.613792 0.090211 0.101761
wavelet_energy_cD_2 wavelet_energy_cD_1 label
0 0.068162 0.135632 191
1 0.009751 0.077819 191
2 0.097060 0.243862 191
3 0.008464 0.067589 191
4 0.025589 0.168648 191
(21892, 13)
ML 모델링 - Logistic Regression, Random Forest, XGBoost - 목표: baseline accuracy/ROC-AUC 확인
딥러닝 접근 (Raw Signal 사용)
1D CNN
- Local ECG pattern (QRS, ST-segment)을 자동으로 학습.
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras import layers, models# --- 라벨/피처 분리 ---
label_col = df.columns[187]
feature_cols = df.columns[:186]
X = df[feature_cols].to_numpy(dtype=np.float32) # (N, L)
y = df[label_col].to_numpy(dtype=np.int64) # (N, )
# # --- 샘플별 Z-score 표준화(파형 모양만 남기기) ---
eps = 1e-8
mu = X.mean(axis=1, keepdims=True)
sd = X.std(axis=1, keepdims=True)
X = (X - mu) / (sd + eps)
# # --- 80:20 split → train에서 20%를 val (최종 64/16/20) ---
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.20, random_state=42, stratify=y)
X_tr, X_va, y_tr, y_va = train_test_split(X_tr, y_tr, test_size=0.20, random_state=42, stratify=y_tr)classes = np.unique(y)
class_weights = compute_class_weight('balanced', classes=classes, y=y_tr)
class_weight = {int(c): float(w) for c, w in zip(classes, class_weights)}
print("class_weight:", class_weight)class_weight: {0: 0.24165588615782665, 1: 7.870786516853933, 2: 3.025917926565875, 3: 26.942307692307693, 4: 2.7230320699708455}
# 2) 모델: 1D CNN (QRS, ST 등 로컬 패턴 자동 학습)
L = X_tr.shape[1] # 시퀀스 길이
n_classes = len(np.unique(y))
def build_1d_cnn(input_length=L, n_classes=n_classes):
inp = layers.Input(shape=(input_length, 1))
x = layers.Conv1D(32, kernel_size=7, padding='same')(inp)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.MaxPool1D(2)(x)
x = layers.Dropout(0.2)(x)
x = layers.Conv1D(64, kernel_size=5, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.MaxPool1D(2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv1D(128, kernel_size=3, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.4)(x)
out = layers.Dense(n_classes, activation='softmax')(x)
model = models.Model(inp, out)
model.compile(
optimizer=tf.keras.optimizers.Adam(1e-3),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy']
)
return model
model = build_1d_cnn()
model.summary()Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 186, 1)] 0
conv1d_6 (Conv1D) (None, 186, 32) 256
batch_normalization_6 (Batc (None, 186, 32) 128
hNormalization)
re_lu_6 (ReLU) (None, 186, 32) 0
max_pooling1d_4 (MaxPooling (None, 93, 32) 0
1D)
dropout_6 (Dropout) (None, 93, 32) 0
conv1d_7 (Conv1D) (None, 93, 64) 10304
batch_normalization_7 (Batc (None, 93, 64) 256
hNormalization)
re_lu_7 (ReLU) (None, 93, 64) 0
max_pooling1d_5 (MaxPooling (None, 46, 64) 0
1D)
dropout_7 (Dropout) (None, 46, 64) 0
conv1d_8 (Conv1D) (None, 46, 128) 24704
batch_normalization_8 (Batc (None, 46, 128) 512
hNormalization)
re_lu_8 (ReLU) (None, 46, 128) 0
global_average_pooling1d_2 (None, 128) 0
(GlobalAveragePooling1D)
dense_4 (Dense) (None, 128) 16512
dropout_8 (Dropout) (None, 128) 0
dense_5 (Dense) (None, 5) 645
=================================================================
Total params: 53,317
Trainable params: 52,869
Non-trainable params: 448
_________________________________________________________________
- 성능 높이기
- 한 층 더 추가한다.
- 23->64->128… 채널 수를 늘린다.
# 3) 학습
cb = [
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
]
hist = model.fit(
X_tr, y_tr,
validation_data=(X_va, y_va),
epochs=50,
batch_size=256,
class_weight=class_weight,
callbacks=cb,
verbose=1
)Epoch 1/50
55/55 [==============================] - 4s 41ms/step - loss: 1.3983 - accuracy: 0.2974 - val_loss: 1.6888 - val_accuracy: 0.0234 - lr: 0.0010
Epoch 2/50
55/55 [==============================] - 2s 36ms/step - loss: 1.0734 - accuracy: 0.4141 - val_loss: 2.2370 - val_accuracy: 0.0305 - lr: 0.0010
Epoch 3/50
55/55 [==============================] - 2s 36ms/step - loss: 0.9189 - accuracy: 0.4235 - val_loss: 2.7122 - val_accuracy: 0.0311 - lr: 0.0010
Epoch 4/50
55/55 [==============================] - 2s 37ms/step - loss: 0.8343 - accuracy: 0.4803 - val_loss: 2.9398 - val_accuracy: 0.0400 - lr: 0.0010
Epoch 5/50
55/55 [==============================] - 2s 37ms/step - loss: 0.7863 - accuracy: 0.4909 - val_loss: 2.7439 - val_accuracy: 0.0403 - lr: 5.0000e-04
Epoch 6/50
55/55 [==============================] - 11s 200ms/step - loss: 0.7505 - accuracy: 0.5074 - val_loss: 2.6304 - val_accuracy: 0.0514 - lr: 5.0000e-04
Epoch 7/50
55/55 [==============================] - 12s 220ms/step - loss: 0.7204 - accuracy: 0.5488 - val_loss: 2.3196 - val_accuracy: 0.1048 - lr: 5.0000e-04
Epoch 8/50
55/55 [==============================] - 2s 34ms/step - loss: 0.7089 - accuracy: 0.5453 - val_loss: 1.8006 - val_accuracy: 0.2643 - lr: 2.5000e-04
Epoch 9/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6870 - accuracy: 0.5555 - val_loss: 1.5701 - val_accuracy: 0.3394 - lr: 2.5000e-04
Epoch 10/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6709 - accuracy: 0.5479 - val_loss: 1.3493 - val_accuracy: 0.4099 - lr: 2.5000e-04
Epoch 11/50
55/55 [==============================] - 2s 35ms/step - loss: 0.6614 - accuracy: 0.5671 - val_loss: 1.4117 - val_accuracy: 0.3954 - lr: 2.5000e-04
Epoch 12/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6540 - accuracy: 0.5715 - val_loss: 1.3136 - val_accuracy: 0.4116 - lr: 2.5000e-04
Epoch 13/50
55/55 [==============================] - 2s 37ms/step - loss: 0.6511 - accuracy: 0.5828 - val_loss: 1.1824 - val_accuracy: 0.4382 - lr: 2.5000e-04
Epoch 14/50
55/55 [==============================] - 14s 258ms/step - loss: 0.6341 - accuracy: 0.6059 - val_loss: 1.1752 - val_accuracy: 0.4599 - lr: 2.5000e-04
Epoch 15/50
55/55 [==============================] - 8s 138ms/step - loss: 0.6238 - accuracy: 0.5824 - val_loss: 1.1429 - val_accuracy: 0.4950 - lr: 2.5000e-04
Epoch 16/50
55/55 [==============================] - 2s 36ms/step - loss: 0.6284 - accuracy: 0.5942 - val_loss: 1.0967 - val_accuracy: 0.4747 - lr: 2.5000e-04
Epoch 17/50
55/55 [==============================] - 2s 35ms/step - loss: 0.6149 - accuracy: 0.6060 - val_loss: 1.1215 - val_accuracy: 0.4967 - lr: 2.5000e-04
Epoch 18/50
55/55 [==============================] - 2s 35ms/step - loss: 0.6041 - accuracy: 0.6247 - val_loss: 1.2169 - val_accuracy: 0.4459 - lr: 2.5000e-04
Epoch 19/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5890 - accuracy: 0.6167 - val_loss: 1.1233 - val_accuracy: 0.4425 - lr: 2.5000e-04
Epoch 20/50
55/55 [==============================] - 2s 37ms/step - loss: 0.5924 - accuracy: 0.6163 - val_loss: 1.1555 - val_accuracy: 0.4371 - lr: 1.2500e-04
Epoch 21/50
55/55 [==============================] - 5s 88ms/step - loss: 0.5795 - accuracy: 0.6340 - val_loss: 1.1013 - val_accuracy: 0.4804 - lr: 1.2500e-04
Epoch 22/50
55/55 [==============================] - 15s 269ms/step - loss: 0.5854 - accuracy: 0.6367 - val_loss: 1.1438 - val_accuracy: 0.4459 - lr: 1.2500e-04
Epoch 23/50
55/55 [==============================] - 4s 75ms/step - loss: 0.5819 - accuracy: 0.6288 - val_loss: 1.1051 - val_accuracy: 0.4859 - lr: 6.2500e-05
Epoch 24/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5769 - accuracy: 0.6416 - val_loss: 1.0771 - val_accuracy: 0.4990 - lr: 6.2500e-05
Epoch 25/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5672 - accuracy: 0.6254 - val_loss: 1.0909 - val_accuracy: 0.4927 - lr: 6.2500e-05
Epoch 26/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5637 - accuracy: 0.6376 - val_loss: 1.0779 - val_accuracy: 0.4921 - lr: 6.2500e-05
Epoch 27/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5760 - accuracy: 0.6438 - val_loss: 1.0539 - val_accuracy: 0.5090 - lr: 6.2500e-05
Epoch 28/50
55/55 [==============================] - 2s 37ms/step - loss: 0.5686 - accuracy: 0.6291 - val_loss: 1.0418 - val_accuracy: 0.5136 - lr: 6.2500e-05
Epoch 29/50
55/55 [==============================] - 9s 164ms/step - loss: 0.5743 - accuracy: 0.6503 - val_loss: 1.0573 - val_accuracy: 0.5104 - lr: 6.2500e-05
Epoch 30/50
55/55 [==============================] - 14s 251ms/step - loss: 0.5613 - accuracy: 0.6415 - val_loss: 1.0678 - val_accuracy: 0.5090 - lr: 6.2500e-05
Epoch 31/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5497 - accuracy: 0.6428 - val_loss: 1.0438 - val_accuracy: 0.5201 - lr: 6.2500e-05
Epoch 32/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5624 - accuracy: 0.6422 - val_loss: 1.0511 - val_accuracy: 0.5161 - lr: 3.1250e-05
Epoch 33/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5558 - accuracy: 0.6507 - val_loss: 1.0276 - val_accuracy: 0.5233 - lr: 3.1250e-05
Epoch 34/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5521 - accuracy: 0.6499 - val_loss: 1.0458 - val_accuracy: 0.5144 - lr: 3.1250e-05
Epoch 35/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5578 - accuracy: 0.6471 - val_loss: 1.0451 - val_accuracy: 0.5101 - lr: 3.1250e-05
Epoch 36/50
55/55 [==============================] - 2s 39ms/step - loss: 0.5499 - accuracy: 0.6488 - val_loss: 1.0554 - val_accuracy: 0.5059 - lr: 3.1250e-05
Epoch 37/50
55/55 [==============================] - 14s 258ms/step - loss: 0.5488 - accuracy: 0.6470 - val_loss: 1.0474 - val_accuracy: 0.5107 - lr: 1.5625e-05
Epoch 38/50
55/55 [==============================] - 8s 154ms/step - loss: 0.5546 - accuracy: 0.6442 - val_loss: 1.0524 - val_accuracy: 0.5081 - lr: 1.5625e-05
Epoch 39/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5499 - accuracy: 0.6490 - val_loss: 1.0415 - val_accuracy: 0.5158 - lr: 1.5625e-05
Epoch 40/50
55/55 [==============================] - 2s 35ms/step - loss: 0.5419 - accuracy: 0.6495 - val_loss: 1.0420 - val_accuracy: 0.5161 - lr: 7.8125e-06
Epoch 41/50
55/55 [==============================] - 2s 36ms/step - loss: 0.5347 - accuracy: 0.6526 - val_loss: 1.0405 - val_accuracy: 0.5178 - lr: 7.8125e-06
# 4) 평가: Accuracy / ROC-AUC(One-vs-Rest)
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.preprocessing import label_binarize
# 예측
proba_te = model.predict(X_te)
pred_te = proba_te.argmax(axis=1)
# Accuracy
acc = accuracy_score(y_te, pred_te)
print("Test Accuracy:", acc)
# ROC-AUC (다중클래스 → OvR)
y_te_bin = label_binarize(y_te, classes=np.arange(n_classes))
auc = roc_auc_score(y_te_bin, proba_te, average='macro', multi_class='ovr')
print("Test ROC-AUC (macro OvR):", auc)
# 추가 리포트/매트릭스
print(classification_report(y_te, pred_te, digits=3))
print(confusion_matrix(y_te, pred_te))137/137 [==============================] - 1s 4ms/step
Test Accuracy: 0.5181548298698333
Test ROC-AUC (macro OvR): 0.9432353917756396
precision recall f1-score support
0 0.987 0.445 0.614 3624
1 0.078 0.829 0.143 111
2 0.446 0.807 0.574 290
3 0.041 0.938 0.079 32
4 0.949 0.929 0.939 322
accuracy 0.518 4379
macro avg 0.500 0.789 0.470 4379
weighted avg 0.918 0.518 0.619 4379
[[1614 1080 276 640 14]
[ 9 92 4 6 0]
[ 2 6 234 46 2]
[ 0 1 1 30 0]
[ 11 0 10 2 299]]
LSTM/GRU
- ECG 신호의 순차적 dependency 반영.
LSTM
# import tensorflow as tf
# from tensorflow.keras import layers, models
L = X_tr.shape[1] # 188
n_classes = len(np.unique(y)) # 5
def build_bilstm(input_len=L, n_classes=n_classes):
inp = layers.Input(shape=(input_len, 1))
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(inp)
x = layers.Dropout(0.3)(x)
x = layers.Bidirectional(layers.LSTM(64))(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.4)(x)
out = layers.Dense(n_classes, activation='softmax')(x)
model = models.Model(inp, out)
model.compile(
optimizer=tf.keras.optimizers.Adam(1e-3),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy']
)
return model
lstm_model = build_bilstm()
lstm_model.summary()Model: "model_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_4 (InputLayer) [(None, 186, 1)] 0
bidirectional (Bidirectiona (None, 186, 128) 33792
l)
dropout_9 (Dropout) (None, 186, 128) 0
bidirectional_1 (Bidirectio (None, 128) 98816
nal)
dropout_10 (Dropout) (None, 128) 0
dense_6 (Dense) (None, 128) 16512
dropout_11 (Dropout) (None, 128) 0
dense_7 (Dense) (None, 5) 645
=================================================================
Total params: 149,765
Trainable params: 149,765
Non-trainable params: 0
_________________________________________________________________
- GRU
def build_bigru(input_len=L, n_classes=n_classes):
inp = layers.Input(shape=(input_len, 1))
x = layers.Bidirectional(layers.GRU(64, return_sequences=True))(inp)
x = layers.Dropout(0.3)(x)
x = layers.Bidirectional(layers.GRU(64))(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.4)(x)
out = layers.Dense(n_classes, activation='softmax')(x)
model = models.Model(inp, out)
model.compile(
optimizer=tf.keras.optimizers.Adam(1e-3),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy']
)
return model
gru_model = build_bigru()
gru_model.summary()Model: "model_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 186, 1)] 0
bidirectional_2 (Bidirectio (None, 186, 128) 25728
nal)
dropout_12 (Dropout) (None, 186, 128) 0
bidirectional_3 (Bidirectio (None, 128) 74496
nal)
dropout_13 (Dropout) (None, 128) 0
dense_8 (Dense) (None, 128) 16512
dropout_14 (Dropout) (None, 128) 0
dense_9 (Dense) (None, 5) 645
=================================================================
Total params: 117,381
Trainable params: 117,381
Non-trainable params: 0
_________________________________________________________________
cbs = [
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
]
hist_lstm = lstm_model.fit(
X_tr, y_tr,
validation_data=(X_va, y_va),
epochs=50, batch_size=256,
class_weight=class_weight,
callbacks=cbs, verbose=1
)
hist_gru = gru_model.fit(
X_tr, y_tr,
validation_data=(X_va, y_va),
epochs=50, batch_size=256,
class_weight=class_weight,
callbacks=cbs, verbose=1
)Epoch 1/50
55/55 [==============================] - 53s 860ms/step - loss: 1.4447 - accuracy: 0.2344 - val_loss: 1.3851 - val_accuracy: 0.2281 - lr: 0.0010
Epoch 2/50
55/55 [==============================] - 55s 985ms/step - loss: 1.0811 - accuracy: 0.3211 - val_loss: 1.1173 - val_accuracy: 0.4059 - lr: 0.0010
Epoch 3/50
55/55 [==============================] - 62s 1s/step - loss: 0.8547 - accuracy: 0.4879 - val_loss: 0.7796 - val_accuracy: 0.7025 - lr: 0.0010
Epoch 4/50
55/55 [==============================] - 50s 921ms/step - loss: 0.7357 - accuracy: 0.6732 - val_loss: 0.7299 - val_accuracy: 0.8007 - lr: 0.0010
Epoch 5/50
55/55 [==============================] - 55s 969ms/step - loss: 0.7025 - accuracy: 0.6655 - val_loss: 0.8919 - val_accuracy: 0.6691 - lr: 0.0010
Epoch 6/50
55/55 [==============================] - 59s 1s/step - loss: 0.6482 - accuracy: 0.7323 - val_loss: 0.8052 - val_accuracy: 0.7439 - lr: 0.0010
Epoch 7/50
55/55 [==============================] - 55s 1s/step - loss: 0.6159 - accuracy: 0.7287 - val_loss: 0.7832 - val_accuracy: 0.7973 - lr: 0.0010
Epoch 8/50
55/55 [==============================] - 46s 823ms/step - loss: 0.5925 - accuracy: 0.7509 - val_loss: 0.6506 - val_accuracy: 0.8356 - lr: 5.0000e-04
Epoch 9/50
55/55 [==============================] - 49s 899ms/step - loss: 0.5229 - accuracy: 0.7952 - val_loss: 0.6809 - val_accuracy: 0.8139 - lr: 5.0000e-04
Epoch 12/50
55/55 [==============================] - 50s 896ms/step - loss: 0.4991 - accuracy: 0.7887 - val_loss: 0.6454 - val_accuracy: 0.8233 - lr: 2.5000e-04
Epoch 13/50
55/55 [==============================] - 59s 1s/step - loss: 0.4935 - accuracy: 0.7840 - val_loss: 0.6209 - val_accuracy: 0.8441 - lr: 2.5000e-04
Epoch 14/50
55/55 [==============================] - 58s 1s/step - loss: 0.4958 - accuracy: 0.8033 - val_loss: 0.6542 - val_accuracy: 0.8367 - lr: 2.5000e-04
Epoch 15/50
55/55 [==============================] - 46s 836ms/step - loss: 0.4884 - accuracy: 0.8053 - val_loss: 0.6998 - val_accuracy: 0.8139 - lr: 2.5000e-04
Epoch 16/50
55/55 [==============================] - 56s 1s/step - loss: 0.4653 - accuracy: 0.8044 - val_loss: 0.6575 - val_accuracy: 0.8167 - lr: 2.5000e-04
Epoch 17/50
55/55 [==============================] - 60s 1s/step - loss: 0.4719 - accuracy: 0.8073 - val_loss: 0.6300 - val_accuracy: 0.8390 - lr: 1.2500e-04
Epoch 18/50
55/55 [==============================] - 52s 957ms/step - loss: 0.4520 - accuracy: 0.8132 - val_loss: 0.5995 - val_accuracy: 0.8396 - lr: 1.2500e-04
Epoch 19/50
55/55 [==============================] - 49s 873ms/step - loss: 0.4573 - accuracy: 0.8223 - val_loss: 0.6010 - val_accuracy: 0.8410 - lr: 1.2500e-04
Epoch 20/50
55/55 [==============================] - 61s 1s/step - loss: 0.4486 - accuracy: 0.8178 - val_loss: 0.5740 - val_accuracy: 0.8567 - lr: 1.2500e-04
Epoch 21/50
55/55 [==============================] - 58s 1s/step - loss: 0.4543 - accuracy: 0.8139 - val_loss: 0.5989 - val_accuracy: 0.8393 - lr: 1.2500e-04
Epoch 22/50
55/55 [==============================] - 45s 821ms/step - loss: 0.4522 - accuracy: 0.8105 - val_loss: 0.5704 - val_accuracy: 0.8567 - lr: 1.2500e-04
Epoch 23/50
55/55 [==============================] - 57s 1s/step - loss: 0.4456 - accuracy: 0.8170 - val_loss: 0.5845 - val_accuracy: 0.8464 - lr: 1.2500e-04
Epoch 24/50
55/55 [==============================] - 60s 1s/step - loss: 0.4436 - accuracy: 0.8166 - val_loss: 0.5719 - val_accuracy: 0.8464 - lr: 1.2500e-04
Epoch 25/50
55/55 [==============================] - 48s 873ms/step - loss: 0.4467 - accuracy: 0.8261 - val_loss: 0.6169 - val_accuracy: 0.8333 - lr: 1.2500e-04
Epoch 26/50
55/55 [==============================] - 58s 1s/step - loss: 0.4420 - accuracy: 0.8108 - val_loss: 0.5781 - val_accuracy: 0.8527 - lr: 6.2500e-05
Epoch 27/50
55/55 [==============================] - 57s 1s/step - loss: 0.4332 - accuracy: 0.8208 - val_loss: 0.5613 - val_accuracy: 0.8536 - lr: 6.2500e-05
Epoch 28/50
55/55 [==============================] - 47s 854ms/step - loss: 0.4198 - accuracy: 0.8161 - val_loss: 0.5954 - val_accuracy: 0.8461 - lr: 6.2500e-05
Epoch 29/50
55/55 [==============================] - 53s 953ms/step - loss: 0.4212 - accuracy: 0.8185 - val_loss: 0.5794 - val_accuracy: 0.8461 - lr: 6.2500e-05
Epoch 30/50
55/55 [==============================] - 59s 1s/step - loss: 0.4263 - accuracy: 0.8286 - val_loss: 0.5395 - val_accuracy: 0.8578 - lr: 6.2500e-05
Epoch 31/50
55/55 [==============================] - 44s 812ms/step - loss: 0.4313 - accuracy: 0.8131 - val_loss: 0.5670 - val_accuracy: 0.8470 - lr: 6.2500e-05
Epoch 32/50
55/55 [==============================] - 54s 996ms/step - loss: 0.4285 - accuracy: 0.8227 - val_loss: 0.5752 - val_accuracy: 0.8484 - lr: 6.2500e-05
Epoch 33/50
55/55 [==============================] - 48s 853ms/step - loss: 0.4234 - accuracy: 0.8249 - val_loss: 0.5688 - val_accuracy: 0.8501 - lr: 6.2500e-05
Epoch 34/50
55/55 [==============================] - 49s 891ms/step - loss: 0.4210 - accuracy: 0.8283 - val_loss: 0.5749 - val_accuracy: 0.8470 - lr: 3.1250e-05
Epoch 35/50
55/55 [==============================] - 56s 1s/step - loss: 0.4159 - accuracy: 0.8269 - val_loss: 0.5576 - val_accuracy: 0.8530 - lr: 3.1250e-05
Epoch 36/50
55/55 [==============================] - 45s 826ms/step - loss: 0.4144 - accuracy: 0.8286 - val_loss: 0.5677 - val_accuracy: 0.8490 - lr: 3.1250e-05
Epoch 37/50
55/55 [==============================] - 59s 1s/step - loss: 0.4163 - accuracy: 0.8278 - val_loss: 0.5573 - val_accuracy: 0.8507 - lr: 1.5625e-05
Epoch 38/50
55/55 [==============================] - 56s 1s/step - loss: 0.4125 - accuracy: 0.8261 - val_loss: 0.5558 - val_accuracy: 0.8533 - lr: 1.5625e-05
Epoch 1/50
55/55 [==============================] - 38s 601ms/step - loss: 1.4778 - accuracy: 0.2935 - val_loss: 1.3737 - val_accuracy: 0.3097 - lr: 0.0010
Epoch 2/50
55/55 [==============================] - 33s 608ms/step - loss: 1.2010 - accuracy: 0.3377 - val_loss: 1.3370 - val_accuracy: 0.2412 - lr: 0.0010
Epoch 3/50
55/55 [==============================] - 33s 607ms/step - loss: 1.0183 - accuracy: 0.3458 - val_loss: 1.2278 - val_accuracy: 0.4171 - lr: 0.0010
Epoch 4/50
55/55 [==============================] - 31s 572ms/step - loss: 0.9265 - accuracy: 0.3909 - val_loss: 1.1295 - val_accuracy: 0.4256 - lr: 0.0010
Epoch 5/50
55/55 [==============================] - 34s 620ms/step - loss: 0.8314 - accuracy: 0.4499 - val_loss: 0.9094 - val_accuracy: 0.5678 - lr: 0.0010
Epoch 6/50
55/55 [==============================] - 33s 605ms/step - loss: 0.7969 - accuracy: 0.5272 - val_loss: 1.1571 - val_accuracy: 0.3825 - lr: 0.0010
Epoch 7/50
55/55 [==============================] - 33s 597ms/step - loss: 0.7091 - accuracy: 0.5333 - val_loss: 1.0658 - val_accuracy: 0.5133 - lr: 0.0010
Epoch 8/50
55/55 [==============================] - 33s 595ms/step - loss: 0.6783 - accuracy: 0.5704 - val_loss: 0.9522 - val_accuracy: 0.6472 - lr: 0.0010
Epoch 9/50
55/55 [==============================] - 31s 575ms/step - loss: 0.6493 - accuracy: 0.6109 - val_loss: 0.7449 - val_accuracy: 0.7268 - lr: 5.0000e-04
Epoch 10/50
55/55 [==============================] - 33s 597ms/step - loss: 0.6403 - accuracy: 0.6452 - val_loss: 0.8811 - val_accuracy: 0.6594 - lr: 5.0000e-04
Epoch 11/50
55/55 [==============================] - 34s 624ms/step - loss: 0.6106 - accuracy: 0.6439 - val_loss: 0.7216 - val_accuracy: 0.7608 - lr: 5.0000e-04
Epoch 12/50
15/55 [=======>......................] - ETA: 25s - loss: 0.5640 - accuracy: 0.6893
# from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
# from sklearn.preprocessing import label_binarize
def evaluate(model, X_te, y_te, n_classes):
proba = model.predict(X_te)
pred = proba.argmax(axis=1)
acc = accuracy_score(y_te, pred)
y_bin = label_binarize(y_te, classes=np.arange(n_classes))
auc = roc_auc_score(y_bin, proba, average='macro', multi_class='ovr')
print("Accuracy:", acc)
print("ROC-AUC (macro OvR):", auc)
print(classification_report(y_te, pred, digits=3))
print(confusion_matrix(y_te, pred))
return acc, auc
print("=== LSTM ===")
acc_lstm, auc_lstm = evaluate(lstm_model, X_te, y_te, n_classes)
print("\n=== GRU ===")
acc_gru, auc_gru = evaluate(gru_model, X_te, y_te, n_classes)=== LSTM ===
137/137 [==============================] - 9s 62ms/step
Accuracy: 0.8597853391185202
ROC-AUC (macro OvR): 0.9577867423130026
precision recall f1-score support
0 0.980 0.858 0.915 3624
1 0.280 0.748 0.408 111
2 0.743 0.817 0.778 290
3 0.154 0.906 0.264 32
4 0.763 0.950 0.846 322
accuracy 0.860 4379
macro avg 0.584 0.856 0.642 4379
weighted avg 0.924 0.860 0.883 4379
[[3110 199 78 148 89]
[ 24 83 2 1 1]
[ 27 12 237 9 5]
[ 1 1 1 29 0]
[ 13 1 1 1 306]]
=== GRU ===
137/137 [==============================] - 7s 54ms/step
Accuracy: 0.798812514272665
ROC-AUC (macro OvR): 0.9510862539659783
precision recall f1-score support
0 0.975 0.787 0.871 3624
1 0.194 0.721 0.305 111
2 0.581 0.841 0.687 290
3 0.101 0.875 0.182 32
4 0.849 0.910 0.879 322
accuracy 0.799 4379
macro avg 0.540 0.827 0.585 4379
weighted avg 0.914 0.799 0.840 4379
[[2853 327 170 229 45]
[ 26 80 2 3 0]
[ 18 6 244 15 7]
[ 2 0 2 28 0]
[ 26 0 2 1 293]]