Gumbel AlphaZero 핵심 알고리즘 구현 심층 분석

Gumbel AlphaZero 핵심 알고리즘 구현 심층 분석

들어가며: 이론에서 코드로

지금까지 세 편의 포스트를 통해 Gumbel AlphaZero의 핵심 아이디어인 행동 선택과 정책 학습의 이론적 배경을 살펴보았습니다. 이번 마지막 포스트에서는 이 이론들이 실제 코드에서 어떻게 구현되어 있는지, Gumbel AlphaZero 구현 코드를 중심으로 심층 분석합니다.

이론과 코드를 함께 이해하면 알고리즘의 동작 방식을 훨씬 더 명확하게 파악할 수 있을 것입니다.

핵심 함수 search 분석

모든 탐색 과정의 시작점은 search 함수입니다. 이 함수는 Gumbel AlphaZero의 핵심 탐색 로직을 구현하며, 루트 노드와 내부 노드에서 서로 다른 행동 선택 전략을 사용합니다.


def search(
    params,
    rng_key: chex.PRNGKey,
    root: RootFnOutput,
    recurrent_fn: RecurrentFn,
    root_action_selection_fn: ActionSelectionFn,
    interior_action_selection_fn: ActionSelectionFn,
    num_simulations: int,
    max_depth: int = 1000,
    loop_fn: LoopFn = jax.lax.fori_loop,
    qtransform: QTransform = qtransforms.qtransform_by_parent_and_siblings,
    invalid_actions: Optional[chex.Array] = None
) -> PolicyOutput:

주요 처리 과정

  1. 초기 트리 구성: 루트 노드로부터 탐색 트리를 초기화합니다.
  2. 시뮬레이션 루프: num_simulations 횟수만큼 반복하며 트리를 확장합니다.
  3. 행동 선택: 각 노드에서 적절한 행동 선택 함수를 호출합니다.
  4. 백프로파게이션: 리프 노드에서 얻은 가치를 상위 노드로 전파합니다.

시뮬레이션 루프의 핵심 로직


def simulation_step(simulation_index, loop_state):
    """단일 시뮬레이션 스텝 실행"""
    tree, rng_key = loop_state
    
    # 1. 루트에서 리프까지 경로 선택
    path = select_path(tree, root_action_selection_fn, interior_action_selection_fn)
    
    # 2. 리프 노드 확장
    leaf_node = expand_leaf(tree, path, recurrent_fn)
    
    # 3. 가치 백프로파게이션
    tree = backup_values(tree, path, leaf_node.value)
    
    return tree, rng_key

이 과정을 통해 탐색 트리가 점진적으로 확장되며, 각 노드의 가치 추정이 개선됩니다.

루트 노드 행동 선택: gumbel_muzero_root_action_selection

루트 노드에서의 행동 선택은 gumbel_muzero_root_action_selection 함수에서 구현됩니다. 이 함수는 Gumbel-Top-k 트릭과 순차적 반감법을 결합한 정교한 알고리즘입니다.

함수 구현 분석


def gumbel_muzero_root_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    num_simulations: chex.Numeric,
    max_num_considered_actions: chex.Numeric,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
    # 1. 기본 정보 수집
    visit_counts = tree.children_visits[node_index]
    prior_logits = tree.children_prior_logits[node_index]
    
    # 2. Q값 변환 (완성된 Q값 계산)
    completed_qvalues = qtransform(tree, node_index)
    
    # 3. 순차적 반감법 방문 계획 조회
    simulation_index = jnp.sum(visit_counts, axis=-1, dtype=jnp.int32)
    table = seq_halving.get_table_of_considered_visits(
        max_num_considered_actions, num_simulations)
    num_considered, considered_visit = table[simulation_index]
    
    # 4. Gumbel 노이즈 사용
    gumbel = tree.extra_data.root_gumbel
    
    # 5. 최종 스코어 계산 및 행동 선택
    to_argmax = seq_halving.score_considered(
        considered_visit, gumbel, prior_logits, completed_qvalues, visit_counts)
    
    return masked_argmax(to_argmax, tree.root_invalid_actions)

핵심 처리 단계

  1. Q값 변환: qtransform_completed_by_mix_value를 통해 완성된 Q값을 계산합니다.
  2. 방문 계획 조회: 현재 시뮬레이션 단계에서 고려해야 할 행동 수와 방문 횟수를 결정합니다.
  3. 스코어 계산: Gumbel 노이즈, 사전 확률, 완성된 Q값을 결합하여 최종 스코어를 계산합니다.

Sequential Halving 스코어 계산


def score_considered(considered_visit, gumbel, prior_logits, completed_qvalues, visit_counts):
    """순차적 반감법에 따른 행동 스코어 계산"""
    # 현재 고려 중인 행동들만 선택
    considered_mask = jnp.arange(len(gumbel)) < considered_visit
    
    # 최종 스코어 = Gumbel + prior_logits + completed_qvalues
    scores = gumbel + prior_logits + completed_qvalues
    
    # 고려 대상이 아닌 행동은 매우 낮은 스코어 부여
    scores = jnp.where(considered_mask, scores, -jnp.inf)
    
    return scores

내부 노드 행동 선택: gumbel_muzero_interior_action_selection

내부 노드에서의 행동 선택은 gumbel_muzero_interior_action_selection 함수가 담당합니다. 루트 노드와 달리 내부 노드에서는 탐험보다는 활용(Exploitation)에 초점을 맞춘 결정론적 선택을 수행합니다.

함수 구현 분석


def gumbel_muzero_interior_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
    # 1. 기본 정보 수집
    visit_counts = tree.children_visits[node_index]
    prior_logits = tree.children_prior_logits[node_index]
    
    # 2. 완성된 Q값 계산
    completed_qvalues = qtransform(tree, node_index)
    
    # 3. 개선된 정책 구성: prior + Q값
    improved_policy_logits = prior_logits + completed_qvalues
    probs = jax.nn.softmax(improved_policy_logits)
    
    # 4. 방문 빈도 기반 스코어 계산
    to_argmax = _prepare_argmax_input(
        probs=probs, 
        visit_counts=visit_counts
    )
    
    # 5. 최댓값 선택 (결정론적)
    return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32)

개선된 정책의 핵심 아이디어

내부 노드에서는 정규화된 정책 최적화(Regularized Policy Optimization) 원리를 따릅니다:

$\pi_{\text{improved}}(a) \propto \exp(\log \pi_{\text{prior}}(a) + Q_{\text{completed}}(a))$

이는 다음과 같이 해석할 수 있습니다:

  • 사전 정책: 정책 네트워크의 초기 예측
  • 완성된 Q값: 탐색을 통해 얻은 가치 정보
  • 결합: 두 정보를 곱셈적으로 결합하여 더 정확한 정책 생성

_prepare_argmax_input 함수


def _prepare_argmax_input(probs, visit_counts):
    """방문 횟수를 고려한 argmax 입력 준비"""
    # 목표 방문 횟수 계산
    total_visits = jnp.sum(visit_counts)
    target_visits = probs * total_visits
    
    # 현재 방문 횟수와 목표 방문 횟수의 차이
    visit_deficit = target_visits - visit_counts
    
    # 방문이 부족한 행동에 높은 우선순위 부여
    return visit_deficit

이 함수는 현재 방문 분포가 목표 정책 분포에 근사하도록 유도하는 핵심 메커니즘입니다.

순차적 반감법 방문 계획 구현

순차적 반감법의 핵심은 '언제, 어떤 후보들을, 몇 번 방문할지'를 미리 정하는 방문 계획에 있습니다. 이 계획은 get_sequence_of_considered_visits 함수에서 생성됩니다.

방문 계획 생성 함수


def get_sequence_of_considered_visits(max_num_considered_actions, num_simulations):
    """Sequential Halving의 방문 스케줄 생성"""
    if max_num_considered_actions <= 1:
        return tuple(range(num_simulations))
    
    log2max = int(math.ceil(math.log2(max_num_considered_actions)))
    sequence = []
    visits = [0] * max_num_considered_actions
    num_considered = max_num_considered_actions
    
    while len(sequence) < num_simulations:
        # 현재 페이즈에서 각 후보당 추가 방문 횟수
        num_extra_visits = max(1, int(num_simulations / (log2max * num_considered)))
        
        # 현재 고려 중인 후보들에게 방문 배분
        for _ in range(num_extra_visits):
            for i in range(num_considered):
                if len(sequence) < num_simulations:
                    sequence.append(i)
                    visits[i] += 1
        
        # 다음 페이즈를 위해 후보 수 반감
        num_considered = max(1, num_considered // 2)
    
    return tuple(sequence)

방문 테이블 생성

실제 구현에서는 효율성을 위해 방문 계획을 미리 테이블로 생성합니다:


def get_table_of_considered_visits(max_num_considered_actions, num_simulations):
    """방문 계획을 테이블 형태로 미리 계산"""
    table = []
    
    for simulation_index in range(num_simulations + 1):
        # 현재 시뮬레이션에서 고려할 후보 수와 방문 횟수 결정
        phase = determine_phase(simulation_index, max_num_considered_actions, num_simulations)
        num_considered = max_num_considered_actions >> phase
        considered_visit = simulation_index % num_considered
        
        table.append((num_considered, considered_visit))
    
    return jnp.array(table)

실제 동작 예시

8개 후보(max_num_considered_actions=8)를 32번 시뮬레이션(num_simulations=32)하는 경우:

페이즈 고려 후보 수 각 후보당 방문 총 방문 횟수
1 8 4 32
2 4 4 16
3 2 8 16

이러한 계획을 통해 초반에는 넓게 탐색하고, 점진적으로 유망한 후보에 집중하는 효율적인 자원 배분이 가능합니다.

통합 아키텍처 및 성능 최적화

Gumbel AlphaZero의 모든 구성 요소들이 어떻게 통합되어 동작하는지 살펴보겠습니다.


def gumbel_muzero_policy(
    params,
    rng_key: chex.PRNGKey,
    root: RootFnOutput,
    recurrent_fn: RecurrentFn,
    num_simulations: int,
    max_depth: int = 1000,
    max_num_considered_actions: int = 16,
    qtransform: QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> PolicyOutput:
    """Gumbel AlphaZero의 통합 정책 함수"""
    
    # 1. Gumbel 노이즈 생성 및 저장
    rng_key, gumbel_rng = jax.random.split(rng_key)
    root_gumbel = jax.random.gumbel(gumbel_rng, shape=root.prior_logits.shape)
    
    # 2. 행동 선택 함수 정의
    root_action_selection_fn = functools.partial(
        gumbel_muzero_root_action_selection,
        num_simulations=num_simulations,
        max_num_considered_actions=max_num_considered_actions,
        qtransform=qtransform
    )
    
    interior_action_selection_fn = functools.partial(
        gumbel_muzero_interior_action_selection,
        qtransform=qtransform
    )
    
    # 3. 탐색 실행
    search_tree = search(
        params=params,
        rng_key=rng_key,
        root=root,
        recurrent_fn=recurrent_fn,
        root_action_selection_fn=root_action_selection_fn,
        interior_action_selection_fn=interior_action_selection_fn,
        num_simulations=num_simulations,
        max_depth=max_depth,
        qtransform=qtransform
    )
    
    # 4. 정책 출력 생성
    return create_policy_output(search_tree, qtransform)

JAX 기반 성능 최적화

Gumbel AlphaZero는 JAX의 강력한 기능들을 활용하여 높은 성능을 달성합니다:

  1. JIT 컴파일: @jax.jit 데코레이터를 통한 즉시 컴파일
  2. 벡터화: jax.vmap을 통한 배치 처리
  3. 병렬화: jax.pmap을 통한 다중 디바이스 활용
  4. 메모리 효율성: JAX의 함수형 프로그래밍 패러다임

핵심 성능 특징

  • 메모리 효율성: 트리 구조의 효율적인 표현
  • 계산 최적화: 불필요한 연산 최소화
  • 수치 안정성: 언더플로우/오버플로우 방지
  • 확장성: 대규모 행동 공간 지원

시리즈를 마치며

총 네 편의 포스트를 통해 Gumbel AlphaZero의 이론적 배경부터 실제 구현까지 상세히 살펴보았습니다.

주요 포인트 정리

  1. 1편: 기본 개념과 Mctx 라이브러리 소개
  2. 2편: Gumbel-Top-k와 순차적 반감법의 행동 선택 알고리즘
  3. 3편: 완성된 Q값을 활용한 정책 학습
  4. 4편: 실제 구현의 세부사항과 통합 아키텍처

Gumbel AlphaZero는 기존 AlphaZero의 한계를 명확히 인식하고, 이론적으로 보장된 정책 개선과 효율적인 탐색을 동시에 달성한 혁신적인 알고리즘입니다. 특히 적은 시뮬레이션만으로도 높은 성능을 보장한다는 점에서 실용적 가치가 매우 높습니다.

이 시리즈가 Gumbel AlphaZero를 이해하고 자신의 프로젝트에 적용해보고자 하는 분들께 도움이 되었기를 바랍니다.

Previous Post

Gumbel AlphaZero 핵심 알고리즘 2: 정책 학습

Gumbel AlphaZero의 핵심 알고리즘, 정책 학습에 대해 이야기합니다.

Gumbel AlphaZero 핵심 알고리즘 2: 정책 학습

Next Post

새 블로그로 이사했습니다.

블로그를 이사하게 된 이유와 이 블로그의 방향성에 대해 간단히 이야기합니다.

새 블로그로 이사했습니다.

Recommended Reading

scroll to top