ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [리뷰] JAX/Flax로 딥러닝 레벨업
    서평 2024. 10. 2. 20:48

    딥러닝 업계는 매우 빠르게 변한다. 이는 심지어 모델 뿐만이 아니라 사용되는 프레임워크에도 통하는 이야기이다. 중학교 때 처음 scikit-learn을 배워 글자 인식 인공지능을 만들어보고, 케라스가 나왔을 때에는 이렇게 편한 프레임워크를 두고 설마 다른 프레임워크가 나오겠냐는 생각을 했었다. 그러나, 오늘날은 PyTorch의 세상이다. 그리고, Jax가 다시 한번 이 세상을 뒤집으려 하고 있다.

     

    지난 학기에 친구의 손에 이끌려 인공지능 개론이라는 수업을 수강했다. 딥러닝 모델 중 하나를 선택해 이를 개선해보는 프로젝트형 과목이었는데, 우리는 DreamerV3이라는 WM 프로젝트를 Jax를 이용하여 구현하는 프로젝트를 진행했다. 이러한 경험 때문에 서평단을 신청해보게 되었다.

     

    Jax를 한마디로 요약하자면 autograd가 붙어 있는 numpy이다. 배열을 GPU/TPU에 올릴 수 있으며, LAX를 사용해 JIT까지 해준다. 그렇기 때문에 모델을 구현할 때 다른 프레임워크보다 모델 그 자체에 집중할 수 있다는 느낌이 든다. 또, 나는 개인적으로 함수형 프로그래밍 스타일을 매우 좋아하는데, Jax의 코딩 스타일이 매우 마음에 들었다.

     

    하지만, Jax가 만능은 아니다. Jax의 JIT는 강력하지만 역설적이게도 문제점이기도 하다. Jax는 함수형 스타일을 지키기 위해 모든 연산이 복사 연산이며, 이로 인한 성능 문제는 JIT로 해결한다. 다시 말해, JIT가 잘 작동하지 못하는 일부 경우에는 끔찍한 성능을 보여준다. 또 하나의 문제는 Jax의 이코시스템이다. JAX를 공부하는 과정에서 느낀 것은, JAX 이코시스템이 아직 정리가 덜 되어 있다는 점이었다. 예를 들어 Flax, Equinox, Ninjax 같은 여러 라이브러리들이 존재하지만, 이들이 마치 퍼즐 조각처럼 따로따로 흩어져 있는 느낌이었다. 하나로 모아서 체계적으로 배우고 싶은데, 그게 생각만큼 쉽지 않았다.

     

    그래도 나는 JAX가 앞으로 더 발전할 여지가 충분하다고 생각한다. 이 책은 그런 JAX의 장단점을 잘 균형 잡아 다루고 있어서, 처음 JAX를 접하는 사람들에게 큰 도움이 될 것이다. 무엇보다도 파편화된 이코시스템 속에서 어떻게 이 도구들을 조화롭게 활용할 수 있을지 명확한 가이드를 제공한다는 점에서, 나처럼 JAX를 제대로 배우고 싶은 사람들에게 추천하고 싶다.

     

    JAX를 통해 머신러닝의 새로운 가능성을 탐험하는 여정을 함께 떠날 준비가 되었다면, 이 책이 훌륭한 동반자가 되어줄 것이다.

     

    [본 글은 제이펍출판사에서 도서를 제공받아 작성한 서평입니다.]

    댓글

2022 서리.