본문 바로가기

Data & AI

TAFAS (AAAI 2025)

반응형

안녕하세요 

 

오늘은 Battling the Non-stationarity in Time Series Forecasting via Test-time Adaptation이라는 논문을 읽고

 

서비스의 트레이딩 모델에 어떻게 적용할 수 있을지 연구해보았습니다.

 

해당 논문은 

 

시계열 예측이라는 Task의 out-of-distribution(OOD) 일반화 문제를 

 

Invariant Learning을 적용하여 해결하는 내용을 담고 있습니다. 

 

우리가 주가를 예측할 때, 참고하는 데이터들 외에 다른 외적인 요인들에 의해 영향을 받을 수 있기 때문에

 

관측되지 않은 Unobserved core variables로 논문에서 정의하고 있습니다. 

 

또한 시간 시계열 데이터셋에 적합한 environment labels이 부족하다는 점을 문제로 삼았습니다. 

 

그리고 TTA(Test-time Adaptation)라는 기법을 이용해

 

사전 학습된 모델을 테스트 입력 데이터에 동적으로 적응시키는데

 

논문에서는 TTA를 TSF에 적용하여 

 

분포 이동으로 인한 성능 저하를 완화하는데 활용했습니다. 

TTA를 TSF에 적용한 TAFAS 프레임워크 구현을 하여

 

비정상 시계열을 위한 테스트 시간 적응 예측 모델을 만들어내는 것을 제안합니다. 

 

주요 구성요소는 주기성 인식 적응 스케줄링 PAAS 파트와

 

게이트 교정 모듈 GCM으로 구성됩니다. 

 

 

Out-of-Distribution (OOD)

 

머신러닝 모델이 훈련된 데이터의 분포와

 

다른 분포를 가진 입력 데이터를 가리키는 개념입니다.

 

이는 모델이 학습된 환경과 다른 조건에서 발생하는 데이터를 의미하며,

 

모델의 신뢰성과 일반화 능력을 저해할 수 있는 주요 도전 과제입니다.

 

특히 가격의 예측은 학습데이터에서 주어진 글로벌한 매크로 패턴등은 포착할 수 있겠지만

 

갑작스럽게 발생되는 이슈들에는 다른 분포를 가지게 됩니다. 

 

역사적으로 기록된 주가의 데이터셋 내에서도 시간의 흐름에 따라

 

분포의 환경이 변화하게 되고 

 

해당 모델에서는 분포의 변화보다는 환경의 변화를 우선 학습하는 것에 중점을 두고 있습니다. 

 

 

논문의 시작에서도 Pre-trained forecaster가 distribution-shifted 된 시계열입력에 저조한 성과를 내는 걸 이야기합니다. 

 

특히 외적 요인에 민감하게 움직이는 금융데이터에선 모델의 일반화 능력이 부족하고 적절한 적응 매커니즘의 필요성이 제기됩니다. 

 

b 부분을 보면 시계열의 sequential nature이 전체 ground truth를 획득하기 전에 

 

partially-observed ground truth를 활용하여 예측기를 proactively adapt 할 수 있는 기회를 제공하고 있습니다. 

 

논문에서 이야기하는 핵심으로, test-time adaptation이라는 아이디어인데

 

분포 이동에 대응하여 모델의 예측 신뢰성을 강화하는 방안을 제시하고자 합니다

 

 

TSF-TTA

 

TSF-TTA는 테스트 시간 적응(Test-time Adaptation, TTA)을

 

시계열 예측(Time Series Forecasting, TSF)에 특화하여 적용한 프레임워크입니다.

 

이 개념은 사전 학습된 예측 모델이 테스트 데이터의 지속적인 분포 변화에 취약한 문제를 해결하기 위해 도입되었습니다.

 

기존 TTA는 컴퓨터 비전 분야에서 주로 분류 작업에 초점을 맞추었으나,

 

TSF-TTA는 시계열 데이터의 고유 특성, 즉 순차적 의존성과 지연된 ground truth 접근 가능성을 고려합니다.

 

구체적으로, TSF-TTA는 테스트 입력에 모델을 동적으로 적응시켜 일반화 성능을 유지하며,

 

모델의 가중치를 업데이트하지 않고 입력 보정 방식을 채택합니다.

 

논문에서 TSF-TTA는 비정상성으로 인한 분포 이동을 완화하는 핵심 전략으로 제시되며,

 

부분 관측 ground truth를 활용하여 실세계 배포 환경에서의 안정성을 강화합니다.

 

이는 장기 예측 시나리오에서 특히 효과적이며, 다양한 아키텍처에 모델-agnostic 하게 적용 가능합니다.

 

 

TAFAS

 

그것의 실제 구현체를 TAFAS라고 부르고 있습니다. 

 

비정상 시계열을 위한 테스트 시간 적응 예측 시스템으로 

 

이 프레임워크는 사전 학습된 소스 예측기를 freezing 시키고 

 

테스트 데이터의 변화하는 분포에 적응하도록 설계되었습니다. 

 

주요 구성 요소는 주기성 인식 적응 스케줄링 PAAS와 게이트 교정 모듈 GCM입니다. 

 

작동원리는 테스트 입력을 교정하여 소스 모델이 처리하기 쉬운 분포로 변환하는 데 있습니다. 

 

벤치마크데이터셋에선 우수한 성능을 보였지만 추후에 서비스 모델들에도 적용을 할 필요가 있습니다. 

 

import torch
from torch.nn import Module, Linear, Sigmoid

class TAFAS(Module):
    def __init__(self, base_model, hidden_dim, config):
        super().__init__()
        self.base_model = base_model
        self.calibration_gate = Linear(hidden_dim, 1)  # 게이트 교정 모듈
        self.sigmoid = Sigmoid()
        self.config = config  # PAAS 설정 포함

    def paas_schedule(self, data, period):
        # 주기성에 따라 적응 강도 조정
        adjusted_strength = data.mean(dim=1) * period_factor(self.config.period)
        return adjusted_strength * data

 

 

  • TAFAS 클래스는 기본 모델과 게이트를 초기화하며, PAAS 스케줄링 함수를 정의합니다.
  • 주기성을 고려하여 데이터 강도를 조정하며, 비 IID 특성을 처리합니다.

 

 

POGT(Partially-Observed Ground Truth)

 

시계열 예측 과정에서 전체 값들이 지연되어 관측되기 전에 

 

부분적으로 접근 가능한 실제 값을 의미합니다. 

 

 

 

위 그림에서 설명되듯, 시계열 데이터의 순차적 특성을 활용하여 

 

예측 후 초기 시점의 ground truth를 미리 이용할 수 있습니다. 

 

시장에서 예를 들자면 상승장에 학습된 모델을 투입했더니

 

하락장이 시작되는 걸 포착하는 기간이 POGT라고 할 수 있겠습니다. 

 

이는 기존 TTA의 가정을 위배하여 엔트로피 기반 손실 대신 MSE, MAE 같은 회귀 손실을 적용 가능하게 됩니다. 

 

POGT는 적응 지연을 최소화하고 분포 변화에 대한 즉각적인 대응을 촉진하여 TAFAS에서 PAAS와 결합되어

 

의미 있는 주기적 패턴을 반영합니다. 

 

 

 

GCM (Gated Calibration Module)

 

GCM은 TAFAS 프레임워크의 핵심 모듈로

 

테스트 시간 입력을 사전 학습된 모델이 효과적으로 처리할 수 있는 분포로 교정하는 역할을 수행합니다. 

 

게이트 메커니즘을 통해 교정 정도를 동적으로 제어하고 

 

글로벌 분포 이동을 고려하여 과적합을 방지합니다. 

 

GCM은 입력 데이터를 보정한 후 결과를 소스 예측기와 결합하여 최종 예측을 생성합니다 

 

이런 model-agnostic 방식으로 다양한 아키텍처에 적용 가능합니다. 

 

 

PAAS(Periodicity-Aware Adaptation Scheduling)

 

시계열 데이터의 주기적 패턴을 인식하여 

 

적응 스케줄을 동적으로 조정합니다. 

 

이는 부분 관측 ground truth의 충분한 길이를 확보하여

 

의미 있는 주기성을 반영하며 적응 강도를 최적화합니다. 

 

PAAS는 non-iid 특성을 처리하기 위해 설계되었으며 

 

POGT를 활용하여 선제적 적응을 촉진합니다. 논문에서 PAAS는 데이터의 주기성을 분석하여 스케줄링을 수행하고

 

과도한 적응으로 인한 노이즈를 방지합니다. 이 모듈이 TAFAS의 전체 프레임워크에서 GCM과 결합되어 

 

장기 예측 시 분포 이동을 효과적으로 대응합니다. 

 

 

def forward(self, input, test_data):
    pred = self.base_model(input)  # 기본 예측
    pgt = test_data[:self.config.partial_len]  # POGT 활용
    calibrated = self.calibration_gate(pred)  # GCM 교정
    gate = self.sigmoid(calibrated)  # 게이트 메커니즘 적용
    adapted_pred = gate * pred + (1 - gate) * pgt.mean()  # 적응된 예측 계산
    scheduled_data = self.paas_schedule(adapted_pred, test_data.period)
    return scheduled_data

 

  • forward 메서드는 입력을 처리하고, POGT를 활용하여 GCM으로 교정합니다.
  • 게이트 메커니즘으로 글로벌 이동을 제어하며, PAAS를 통해 주기적 적응을 적용합니다. 이는 분포 이동 완화의 핵심 로직입니다.

 

 

 

TAFAS는 기존 모델에 융합하여 적용하는 프레임워크입니다. 

 

그래서 다른 모델소개와 달리 다른 모델들에 추가적인 pytorch forward 로직 처리가 필요해서 

 

이번 포스팅에 별도의 예측기 성능에 대한 지표는 준비하지 못했는데

 

조만간 연구를 진행해서 포스팅하도록 하겠습니다. 

 

 

반응형