Date:     Updated:

카테고리:

태그: , , , ,


Survival Analysis(생존분석)은 통계학의 한 분야로, 어떠한 현상이 발생하기까지 걸리는 시간을 예측하는 분야이다. 암 환자가 관찰시작부터 사망에 이르는 시간을 분석하는 것, 사망 여부를 예측하는 것이 생존분석의 예시라고 할 수 있다. 뿐만 아니라 고객의 이탈 같은 비즈니스 문제에서도 생존분석이 사용된다.

필자는 ai 경진대회에서 다룬 주제가 생존분석이어서 처음 이 분야를 접하게 됐다. 생존분석에 쓰이는 딥러닝 툴로는 pytorch 기반의 pycox가 있는데, 생소한 분야이다보니 관련 정보가 드물었다.

때문에 pycox github 문서들을 참고하여 이 내용들을 풀이해보았다.

havakv/pycox github 링크

모듈, 라이브러리 설치

우선 필요한 모듈 설치부터 한다.

!pip install pycox
!pip install sklearn-pandas
import torch # For building the networks 
import torchtuples as tt # Some useful functions

#pycox
from sklearn_pandas import DataFrameMapper
from sklearn.preprocessing import StandardScaler
from pycox.preprocessing.feature_transforms import OrderedCategoricalLong
#continuous time
from pycox.models import CoxTime
from pycox.models.cox_time import MLPVanillaCoxTime
from pycox.models import CoxCC
from pycox.models import CoxPH
from pycox.models import PCHazard
#discrete time
from pycox.models import LogisticHazard
from pycox.models import PMF
from pycox.models import DeepHit
from pycox.models import DeepHitSingle
from pycox.models import MTLR
from pycox.models import LogisticHazard, BCESurv
#metric
from pycox.evaluation import EvalSurv


데이터 전처리

우선 pycox 모듈에서 제공되는 기본 데이터셋인 support를 불러온다. 그리고 train set과 validation set으로 나눠준다. 따로 rain_test_split함수를 안 쓰고 이렇게 train과 val set을 나누었다.

df_train = support.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

support 데이터 셋은 총 14개의 변수가 있고, target은 2개(duration, event)이다. numerical변수 8개는 standardize하고, binary변수 3개는 그대로 둔다. categorical변수 3개는 OrderedCategoricalLong으로 인코딩해준다.

cols_standardize =  ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
cols_leave = ['x1', 'x4', 'x5']
cols_categorical =  ['x2', 'x3', 'x6']


standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

OrderedCategoricalLong()은 임베딩을 생성해주는데, 이는 원핫 인코딩의 메모리 한계를 보완하기 위한 방법이라고 한다. 아마 원핫인코딩보다 차원을 축소하는 방식인 것 같다.

categorical = [(col, OrderedCategoricalLong()) for col in cols_categorical]


DataFrameMapper는 결측치 대체, 표준화 등의 작업을 한꺼번에, 수월하게 처리해주는 모듈이다. 두 개를 나눠서 만든 이유는 자료형이 달라서라고 하는데, 뒤에 나올 임베딩 작업에서 나눠서 처리할 거다.

DataFrameMapper가 어떤 변수를 어떻게 전처리할 것인지를 보여준다.

x_mapper_float = DataFrameMapper(standardize + leave) # float32
x_mapper_long = DataFrameMapper(categorical) # int64

매개변수 df를 받고, df를 아까 만들어놓은 x_mapper_floatx_mapper_long에 따라 전처리를 하라는 의미 tuplefy는 중첩된 tuple에서 작업하게 해주는 함수라고 한다.

x_fit_transform = lambda df: tt.tuplefy(x_mapper_float.fit_transform(df), x_mapper_long.fit_transform(df))
x_transform = lambda df: tt.tuplefy(x_mapper_float.transform(df), x_mapper_long.transform(df))

train에서 훈련한 transform을 val, test에 적용한다.

x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')


duration 구간화

이제 모델링을 해야 한다. DeepHit 모델은 discrete time에 적절하고 DeppSurv 모델은 continuous time에 적절하다. 주어진 supprot 데이터셋의 duration을 보자. 연속형이다. 즉, continuous time이긴 하지만 DeepHit을 적용해보도록 한다. 그래서 처음으로 해야 할 것이 continuous한 것을 discrete하게 바꿔주는것이다.

num_durations = 10 # 나눌 구간의 수를 정한다. 10개의 time 구간을 만든다는 의미.
scheme = 'equidistant' # 같은 시간 간격으로 나눈다는 의미같다. 예를 들어 20분 단위로 뚝뚝 끊기게 만든다거나... 

분위수 단위로 나누려면 quantiles을 쓴다. 위에 설정한 매개변수대로 target(label)을 구간화시키기 위한 준비

labtrans = DeepHitSingle.label_transform(num_durations, scheme)

target 값만 빼와서 이제 구간화를 시켜준다.

get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

test set은 구간화를 시킬 필요가 없다. 그냥 있는 값을 쓴다.

durations_test, events_test = get_target(df_test)
from pycox.utils import kaplan_meier
import matplotlib.pyplot as plt
plt.vlines(labtrans.cuts, 0, 1, colors='black', linestyles="--", label='Discretization Grid')
kaplan_meier(*get_target(df_train)).plot(label='Kaplan-Meier')
plt.ylabel('S(t)')
plt.legend()
_ = plt.xlabel('Time')

모델을 돌리기 전에 임베딩(categorical변수 3개)의 차원을 결정해야 한다.

x_train[0]엔 numerical변수와 binary 변수가 담겨 있고, x_train[1]엔 categorical 변수가 담겨있다.

임베딩의 수와 차원을 결정할 것이기에 x_train[1]을 쓰기로 한다. categorical 변수 중 가장 차원이 큰 것을 찾아 1을 더해준다.

num_embeddings = x_train[1].max(0) + 1

#그리고 그것을 절반으로 나눈 값을 임베딩의 차원으로 정한다.
embedding_dims = num_embeddings // 2


모델링

이번에 사용할 DeppHit 모델은 2 다층 퍼셉트론이고, 각각 64개의 node를 가진다.

in_features = x_train[0].shape[1]
out_features = labtrans.out_features
num_nodes = [64, 64]
batch_norm = True # 각 배치별로 다양한 분산을 가진 데이터를 정규화해줌(batch normalization)
dropout = 0.2 # 0.2만큼 dropout하여 overfitting 방지

MLPVanilla를 사용하면 2개 층의 신경망을 자동으로 생성해준다. activation 함수로는 ReLU가 기본으로 들어가 있다

net = tt.practical.MixedInputMLP(in_features, num_embeddings, embedding_dims,
                                 num_nodes, out_features, batch_norm, dropout)

최적화는 AdamWR 방식을 채택했다. 보통은 Adam을 쓰긴 한다. decoupled_weight_decay는 overfitting을 억제하는 penalty항 역할을 한다.

optimizer = tt.optim.AdamWR(decoupled_weight_decay=0.01, cycle_eta_multiplier=0.8,
                            cycle_multiplier=2)

model = DeepHitSingle(net, optimizer, duration_index=labtrans.cuts)

lr(learning rate)을 설정해야하는데, lr_finder를 사용할 수 있다. lr_finder는 비록 최적의 lr을 알려주진 않지만, 대강 이쯤부터 lr을 찾아봐라 하는 추천값을 알려준다. 보통 추천해준 값보다 낮은 값을 입력하여 최적화를 시켜준다.

batch_size = 256
lrfind = model.lr_finder(x_train, y_train, batch_size, tolerance=50)
lrfind.get_best_lr()
model.optimizer.set_lr(0.05)


훈련 및 추론

#훈련
epochs = 100
callbacks = [tt.cb.EarlyStoppingCycle()]
verbose = True 

log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val)

아래 그래프는 test set에서 5명의 환자의 생존확률을 분석한 결과이다.

surv = model.interpolate(10).predict_surv_df(x_test)
surv.iloc[:, :5].plot(drawstyle='steps-post')
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

학습한 결과로 val set에 대한 검증 점수를 출력한다.

ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
ev.concordance_td('antolini')


부록. scikit-survival

scikit-survival이라는 scikit-laern 기반 생존분석 모듈도 있다. 간단하게 알아보기로 하자.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

set_config(display="text")
X, y = load_gbsg2()

결측치를 확인

X.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 686 entries, 0 to 685
Data columns (total 8 columns):
 #   Column    Non-Null Count  Dtype   
---  ------    --------------  -----   
 0   age       686 non-null    float64 
 1   estrec    686 non-null    float64 
 2   horTh     686 non-null    category
 3   menostat  686 non-null    category
 4   pnodes    686 non-null    float64 
 5   progrec   686 non-null    float64 
 6   tgrade    686 non-null    category
 7   tsize     686 non-null    float64 
dtypes: category(3), float64(5)
memory usage: 29.2 KB

cardinality 확인. 각 변수가 가질 수 있는 값의 개수를 확인한다.

X.nunique()
age          54
estrec      244
horTh         2
menostat      2
pnodes       30
progrec     242
tgrade        3
tsize        58
dtype: int64

tgrade는 범주가 3개라서 onehotendoer를 쓰지 않고 OrdinalEncoder를 사용한다. 원핫인코딩은 범주가 많아지면, 새로 생성해야 하는 컬럼도 많아서 성능저하를 야기할 수도 있다. OrdinalEncoder는 원핫인코딩과 달리 각 특성별로 순서를 매겨주는 방식으로, 단점을 보완해준다.

grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)

나머지 범주형 변수는 binary하므로 원핫인코딩을 하여도 컬럼이 과하게 많아지지 않는다.

X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)

# 앞서 OrdinalEncoder한 tgrade 추가
Xt.loc[:, "tgrade"] = grade_num
random_state = 20

X_train, X_test, y_train, y_test = train_test_split(
    Xt, y, test_size=0.25, random_state=random_state)

#training
rsf = RandomSurvivalForest(n_estimators=1000,
                           min_samples_split=10,
                           min_samples_leaf=15,
                           n_jobs=-1,
                           random_state=random_state)
rsf.fit(X_train, y_train)
RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10,
                     n_estimators=1000, n_jobs=-1, random_state=20)
rsf.score(X_test, y_test)
0.6759696016771488