본문으로 건너뛰기

© 2026 Molayo

Dev.to헤드라인2026. 05. 30. 16:48

NumPy에서 JAX로: 가속화된 AI를 경험하며 느낀 첫 번째 '아하!' 모먼트

요약

NumPy에서 JAX로 전환하며 경험한 패러다임 변화와 핵심 메커니즘을 다룹니다. 불변성 원칙, 하드웨어 인식 및 샤딩, JIT 컴파일을 통한 연산 최적화 방법을 설명합니다.

핵심 포인트

  • JAX 배열은 불변성을 유지하여 부작용을 방지함
  • 인덱스 기반 업데이트 구문으로 배열 수정 가능
  • 하드웨어 가속기(GPU/TPU)를 자동으로 인식하고 활용
  • jax.jit을 통한 연산 시퀀스 최적화 및 속도 향상

나의 '100일간의 AI 에이전트(100 Days of AI Agents)' 챌린지를 위한 오픈 소스 솔루션을 구축한다는 것은, 표준 NumPy 및 PyTorch보다 더 잘 확장되는 프레임워크를 살펴봐야 함을 의미했습니다. 이는 필연적으로 저를 JAX로 이끌었습니다.

JAX로 전환하려면 약간의 패러다임 전환(paradigm shift)이 필요합니다. 표준 Python 데이터 과학 스택에 익숙하다면, JAX는 배열 연산(array operations), 메모리, 그리고 하드웨어 실행(hardware execution)에 대해 생각하는 방식을 재설계하도록 강제합니다.

저는 오늘 핵심 메커니즘을 파고드는 시간을 보냈으며, 제가 얻은 가장 중요한 3가지 교훈과 이해를 도왔던 정확한 코드 스니펫을 공유하고자 합니다.

1. 불변성(Immutability)은 버그가 아니라 기능이다

이것이 저의 첫 번째 큰 장애물이었습니다. 표준 NumPy에서는 배열의 요소를 변경하고 싶다면, 단순히 제자리(in-place)에서 재할당하면 됩니다.

Python

**import numpy as np
x = np.arange(10)
x[0] = 10
print(x) # 출력: [10 1 2 3 4 5 6 7 8 9]

JAX에서 정확히 똑같은 시도를 하면, 다음과 같은 오류를 내뱉으며 비명을 지릅니다: TypeError: JAX arrays are immutable.

JAX 배열(jax.Array)은 생성된 후 변경할 수 없습니다. 이는 JAX의 함수형 프로그래밍(functional programming) 특성과 자동 미분(automatic differentiation)을 가능하게 하는 핵심 설계 원칙입니다. 배열을 업데이트하려면, JAX는 업데이트된 복사본을 반환하는 인덱스 기반 업데이트 구문(indexed update syntax)을 제공합니다.

Python

**import jax.numpy as jnp
x = jnp.arange(10)
y = x.at[0].set(10)

print(y) # 출력: [10 1 2 3 4 5 6 7 8 9]
print(x) # 출력: [0 1 2 3 4 5 6 7 8 9]**

주의할 점: 복사본을 생성하기 때문에 메모리 오버헤드(memory overhead)가 발생하지만, 분산 컴퓨팅(distributed computing)을 악몽으로 만드는 부작용(side-effects)을 완전히 제거해 줍니다.

2. 네이티브 하드웨어 인식 및 샤딩(Sharding)

JAX 배열은 본질적으로 자신이 어디에 위치해 있는지 알고 있습니다. 데이터가 CPU, GPU 또는 TPU 중 어디에 있는지 알아내기 위해 번거로운 과정을 거칠 필요가 없습니다.

기본적으로 JAX는 사용 가능한 가장 빠른 가속기(accelerator)로 연산을 밀어넣습니다. 저의 MSI Raider에서 이를 로컬로 실행하면서, .devices()를 사용하여 배열이 정확히 어디에 저장되어 있는지 쉽게 확인할 수 있었습니다.

Python

x.devices()

출력: {CpuDevice(id=0)}

더 중요한 점은, JAX 배열이 병렬 실행을 위해 여러 장치(devices)에 걸쳐 샤딩(sharding)될 수 있다는 것입니다. .sharding 속성을 통해 이를 확인할 수 있습니다:

Python

**x.sharding

출력: SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)

이는 현대적인 하드웨어 스케일링(hardware scaling)을 위해 처음부터 설계된 것처럼 느껴집니다.

3. JIT 컴파일의 마법

기본적으로 JAX는 연산을 하나씩 순차적으로 실행합니다(표준 Python과 동일합니다). 하지만 함수를 Just-In-Time (jax.jit) 컴파일로 감싸면, JAX는 전체 연산 시퀀스를 최적화하여 한 번에 실행합니다.

이를 테스트하기 위해 간단한 정규화(normalization) 함수를 작성했습니다:

Python

from jax import jit
import jax.numpy as jnp
import numpy as np

def norm(X):
X = X - X.mean(0)
return X / X.std(0)

norm_compiled = jit(norm)

더미 데이터 생성

np.random.seed(22)
X = jnp.array(np.random.rand(100000, 10))

%timeit를 사용하여 두 함수를 벤치마킹했습니다 (JAX의 비동기 디스패치(asynchronous dispatch)를 고려하여 .block_until_ready()를 추가했습니다). 결과는 즉각적이었습니다:

**표준 실행: 루프당 1.52 ms ± 16.3 μs

JIT 실행: 루프당 1.16 ms ± 26.2 μs**

컴파일러가 실행의 정확한 청사진(blueprint)을 사전에 알고 있기 때문에 속도가 현저히 빨라집니다. 유일한 제한 사항은 무엇일까요? 모든 JAX 코드가 JIT 컴파일될 수 있는 것은 아닙니다. 배열의 형태(shape)가 정적(static)이어야 하며 컴파일 시점에 알려져 있어야 합니다.

다음 단계는?
이것은 단지 표면을 긁어본 것에 불과합니다. 저의 다음 심층 탐구에서는 함수형 무작위성 (jax.random), 자동 미분 (jax.grad), 그리고 자동 벡터화 (jax.vmap)를 다룰 예정입니다.

최근에 JAX로 넘어오신 분이 계신가요? 가장 큰 학습 곡선(learning curve)은 무엇이었나요? 아래에 댓글을 남겨주세요!

GitHub logo
ShinigamiFlanker0208 / JAX

*** 태그: #machinelearning #python #jax #ai #opensource

AI 자동 생성 콘텐츠

본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.

원문 바로가기
1

댓글

0