220v
젝무의 개발새발
220v
전체 방문자
오늘
어제
  • 분류 전체보기 (255)
    • AI (35)
      • ML, DL 학습 (30)
      • 논문 리뷰 (4)
      • 실습 및 프로젝트 (1)
    • Algorithm (145)
      • LeetCode (13)
      • 프로그래머스 (35)
      • 백준 (96)
      • 알고리즘, 문법 정리 (1)
    • Mobile, Application (17)
      • Flutter (10)
      • iOS, MacOS (7)
    • BackEnd (7)
      • Flask (1)
      • Node.js (5)
      • Spring, JSP..etc (1)
    • Web - FrontEnd (18)
      • JavaScript, JQuery, HTML, C.. (12)
      • React (6)
    • DataBase (1)
      • MySQL (1)
      • Firebase Firestore (0)
      • Supabase (0)
    • Git (1)
    • 기타 툴 및 오류 해결 (3)
    • 강의 (5)
      • Database (3)
      • 암호학 (2)
      • 알고리즘 (0)
    • 후기와 회고 (2)
    • 블로그 꾸미기 (1)
    • 일상과 이것저것 (20)
      • 맛집 (12)
      • 세상사는일 (4)
      • 도서리뷰 (1)
      • 이런저런 생각들 (잡글) (3)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • two pointer
  • topological sort
  • 위상 정렬
  • 백준
  • 프로그래머스
  • Priority Queue
  • union-find
  • 티스토리챌린지
  • Greedy
  • 구현
  • Backtracking
  • disjoint set
  • simulation
  • dfs
  • top-down
  • Mathematics
  • Prefix Sum
  • 오블완
  • brute-Force
  • Minimum Spanning Tree
  • IMPLEMENT
  • REACT
  • Lis
  • implementation
  • bitmasking
  • Dynamic Programming
  • BFS
  • binary search
  • dp
  • 다익스트라

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
220v

젝무의 개발새발

AI/실습 및 프로젝트

[nltk] nltk tokenizer 사용 중 nltk LookupError 해결 (nltk.tokenize.word_tokenize)

2024. 11. 28. 00:59

문제 상황

from pycocoevalcap.cider.cider import Cider
import matplotlib.pyplot as plt
import numpy as np
import nltk
import re
import os
import torch
from nltk.tokenize import word_tokenize

# NLTK 데이터 다운로드
nltk.download('punkt')

# 캡션 전처리 함수
def preprocess_caption(caption):
    # 소문자 변환
    caption = caption.lower()
    # 특수문자 제거
    caption = re.sub(r'[^\w\s]', '', caption)
    # 토크나이즈
    tokens = word_tokenize(caption)
    # 문자열로 다시 결합
    return ' '.join(tokens)

# CIDEr 평가 함수
def evaluate_and_visualize_with_cider(model, test_loader, word_index, device, image_dir, captions_dict, num_display=5):
    """
    CIDEr 점수 계산 및 결과 시각화

    Args:
        model: 캡션 생성 모델
        test_loader: 테스트 데이터 로더
        word_index: 단어-인덱스 매핑
        device: PyTorch 디바이스 (CPU/GPU)
        image_dir: 이미지 디렉토리 경로
        captions_dict: 이미지 이름과 참조 캡션 매핑 딕셔너리
        num_display: 시각화할 이미지 수
    """
    model.eval()
    index_to_word = {idx: word for word, idx in word_index.items()}
    results = []
    test_examples = []
    
    with torch.no_grad():
        for batch_idx, (features, captions, image_names_batch) in enumerate(test_loader):
            features = features.to(device)
            
            for i in range(features.size(0)):  # 배치 크기만큼 반복
                # Get the image name
                image_name = image_names_batch[i]

                # 캡션 생성
                generated_caption = generate_caption(model, features[i], word_index)
                # 전처리된 캡션
                generated_caption_proc = preprocess_caption(generated_caption)

                # 참조 캡션 가져오기 및 전처리
                references = captions_dict.get(image_name, [])
                references_proc = [preprocess_caption(ref) for ref in references]

                # 평가 데이터 준비
                results.append({
                    "image_id": image_name,
                    "candidate": generated_caption_proc,
                    "references": references_proc
                })

                # 테스트 예시 저장 (num_display 개수만 저장)
                if len(test_examples) < num_display:
                    image_path = os.path.join(image_dir, image_name)
                    test_examples.append({
                        "image_path": image_path,
                        "generated": generated_caption,
                        "references": references
                    })
    
    # CIDEr 점수 계산
    print("Calculating CIDEr scores...")
    cider_scorer = Cider()
    
    # gts와 res 딕셔너리 생성
    gts = {}
    res = {}
    for res_item in results:
        image_id = res_item["image_id"]
        gts[image_id] = res_item["references"]  # 리스트 형태의 참조 캡션들
        res[image_id] = [res_item["candidate"]]  # 생성된 캡션을 리스트로 감싸서 전달

    # CIDEr 점수 계산
    cider_score, cider_scores = cider_scorer.compute_score(gts, res)
    
    avg_cider_score = cider_score  # 평균 CIDEr 점수
    print(f"\nCIDEr Metric Evaluation:")
    print(f"Average CIDEr Score: {avg_cider_score:.4f}")

    # CIDEr 점수 분포 시각화
    plt.figure(figsize=(10, 5))
    plt.hist(cider_scores, bins=50, alpha=0.7)
    plt.title("Distribution of CIDEr Scores on Test Set")
    plt.xlabel("CIDEr Score")
    plt.ylabel("Count")
    plt.show()

    # 예시 출력
    print("\nExample Generations:")
    for idx, example in enumerate(test_examples):
        print(f"\nExample {idx + 1}")
        print(f"Image Path: {example['image_path']}")
        print(f"Generated Caption: {example['generated']}")
        print(f"References: {example['references']}")

    return avg_cider_score, cider_scores

# CIDEr 평가 실행
print("Starting CIDEr evaluation...")
avg_cider_score, cider_scores = evaluate_and_visualize_with_cider(
    model, test_loader, word_index, device, image_dir, captions_dict
)

print(f"Average CIDEr Score: {avg_cider_score:.4f}")
print(f"Total samples used for CIDEr calculation: {len(cider_scores)}")

위와 같이 nltk를 이용하여 CIDEr score를 계산하려 했는데,

이미 nltk.download('punkt') 로 다운로드를 했음에도, 아래와 같은 에러가 발생함.

 

---------------------------------------------------------------------------
LookupError                               Traceback (most recent call last)
Cell In[65], line 115
    113 # CIDEr 평가 실행
    114 print("Starting CIDEr evaluation...")
--> 115 avg_cider_score, cider_scores = evaluate_and_visualize_with_cider(
    116     model, test_loader, word_index, device, image_dir, captions_dict
    117 )
    119 print(f"Average CIDEr Score: {avg_cider_score:.4f}")
    120 print(f"Total samples used for CIDEr calculation: {len(cider_scores)}")

Cell In[65], line 54
     52 generated_caption = generate_caption(model, features[i], word_index)
     53 # 전처리된 캡션
---> 54 generated_caption_proc = preprocess_caption(generated_caption)
     56 # 참조 캡션 가져오기 및 전처리
     57 references = captions_dict.get(image_name, [])

Cell In[65], line 20
     18 caption = re.sub(r'[^\w\s]', '', caption)
     19 # 토크나이즈
---> 20 tokens = word_tokenize(caption)
     21 # 문자열로 다시 결합
     22 return ' '.join(tokens)

File ~/.conda/envs/DL/lib/python3.10/site-packages/nltk/tokenize/__init__.py:142, in word_tokenize(text, language, preserve_line)
    127 def word_tokenize(text, language="english", preserve_line=False):
    128     """
    129     Return a tokenized copy of *text*,
    130     using NLTK's recommended word tokenizer
   (...)
    140     :type preserve_line: bool
    141     """
--> 142     sentences = [text] if preserve_line else sent_tokenize(text, language)
    143     return [
    144         token for sent in sentences for token in _treebank_word_tokenizer.tokenize(sent)
    145     ]

File ~/.conda/envs/DL/lib/python3.10/site-packages/nltk/tokenize/__init__.py:119, in sent_tokenize(text, language)
    109 def sent_tokenize(text, language="english"):
    110     """
    111     Return a sentence-tokenized copy of *text*,
    112     using NLTK's recommended sentence tokenizer
   (...)
    117     :param language: the model name in the Punkt corpus
...
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

 

해결

우선 아래와 같은 코드로 tokenizer가 안되는 이유를 디버깅.

import nltk
from nltk.tokenize import word_tokenize

try:
    word_tokenize("Test sentence for debugging.")
    print("Tokenization successful!")
except LookupError as e:
    print("Error:", e)
    print("NLTK data path:", nltk.data.path)

 

Error: 
**********************************************************************
  Resource punkt_tab not found.
  Please use the NLTK Downloader to obtain the resource:

  >>> import nltk
  >>> nltk.download('punkt_tab')
  
  For more information see: https://www.nltk.org/data.html

  Attempted to load tokenizers/punkt_tab/english/

  Searched in:
    - '/home/gpu_04/nltk_data'
    - '/home/gpu_04/.conda/envs/DL/nltk_data'
    - '/home/gpu_04/.conda/envs/DL/share/nltk_data'
    - '/home/gpu_04/.conda/envs/DL/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
    - '/home/gpu_04/.conda/envs/DL/nltk_data'
**********************************************************************

출력대로 nltk.download('punkt_tab') 으로 설치해주니 해결.

만약 path 문제라면, path를 추가해주면 될 것 같다.

    220v
    220v
    DGU CSE 20 / Apple Developer Academy @ POSTECH 2nd Jr.Learner.

    티스토리툴바