かもブログ

かもかも(@kam0_2)の雑記。地理やITに関するにわか稚拙センテンスの掃き溜め。

低メモリ(RAM)環境で、gensimのword2vecモデルを使うテクニック(備忘録)

この記事について

 Python 3.x で自然言語処理ライブラリgensimでword2vecを使います。学習済みのモデルデータを読み込んで、各単語のベクトルをえられるようにしたいとき、低メモリ(RAM)環境では、うまく行かないことがあります。この記事では、低メモリ環境で、学習済みword2vecモデルから各単語ベクトルを抽出する方法を紹介します。

問題の根源

 gensimのword2vecモデルはmost_similar()メソッドのように、「すべてのベクトルをメモリ上に展開していないと使えないメソッド」を含んでいます。そのため、モデルをロードする際に、モデルのサイズ以上の空きメモリがないと、読み込みに失敗します。自分の手元には日本語版Wikipediaから作ったモデルがありますが、モデルと付随するファイル含めて合計で1.3GBあります。一般的なパソコンでは問題ないですが、VPSなどの極めてメモリが少ない環境では問題になります。

解決法1 KeyedVectorsのみにする

 gensimのword2vecモデルは単語のベクトル情報以外にも様々な情報を持っています。これらを切り落として、単純なベクトルのみにします。

f:id:two_headed_duck:20200405150704p:plain
https://radimrehurek.com/gensim/models/keyedvectors.html より引用
再学習はできなくなりますが、オブジェクトを軽量化できます。また、オリジナルのword2vec形式での保存/読み込みが可能になります。これは後に重要になります。

from gensim.models import Word2Vec
model = Word2Vec.load("jawiki.model") # gensim形式のモデルをロードします
model.wv.save("jawiki.kv")  # KeyedVectorsのみを保存します。

当方では、1.3GBから732MBに減量することができました。これを読み込むには以下のようにします。

from gensim.models import KeyedVectors
kvs = KeyedVectors.load("jawiki.kv")
for l in kvs.most_similar(positive=["女性", "皇帝"]):
    print(l)

most_similar()similarity()などがそのまま使えます。kvs["安倍晋三"]のようにしてここのベクトルが得られます。

解決法2 正規化して元データを消す

 正規化とは「すべての単語ベクトルを単位ベクトルにする」ことです。単語ベクトルを扱う場合は、ベクトルの長さを無視して考えたほうが効果的であることが知られています。出現頻度や単語の文字長にベクトルの長さは影響されるため、単語同士の類似度を考える上では余計です。また、similarity()で利用されるコサイン類似度は比較対象のベクトルの角度差のみで計算されます。事前に正規化することで、後の計算処理を減らすことができます。
 事前正規化を行っている場合、元のオリジナルベクトルを消しておくことで、メモリー消費量を減らせます。半分にできるはずです。事前正規化はinit_sims()メソッドで行いますが、init_sims(replace=True)とすると、オリジナルのベクトルを忘れます。これを行うと、再学習はできなくなります。

kvs.init_sims(replace=True)
kvs.save("jawiki_normalized.kv")

解決法3 読み込む単語数を絞り込む

 word2vecのオリジナル形式からモデルを読み込む際は読み込む単語数を絞り込むことができます。まずは、gensim形式のモデルをword2vec形式に変換します。

kvs.save_word2vec_format("jawiki_normalized.kv.bin", binary=True)

バイナリ形式で保存することで、若干のファイル軽量化効果があります。50MBぐらい小さくなりました。また、単一のファイルになります。読み込む場合はこうします。

kvs = KeyedVectors.load_word2vec_format(
        "jawiki_normalized.kv.bin",
        binary=True,
        limit=10000)
print(len(kvs.vocab))  # -> 10000

limit=*を書き換えることで任意の数の単語のみを読み込みます。どの単語が選ばれるかなどはよく調べてません。

最終解決法

 今までの方法は結局の所、モデルをすべてメモリにロードする必要がありました。最終解決法ではgensimを使うのをやめて、必要な単語のベクトルだけをメモリにロードできるようにします。most_similar()などは使えなくなりますが、similarity()などは自分で実装することで、使えるようになります。
 まずは、gensimからモデルをテキスト形式で吐かせます。

kvs = KeyedVectors.load("jawiki_normalized.kv")
kvs.save_word2vec_format("jawiki_normalized.kv.txt", binary=False)

headしてみるとこんなデータができます

879000 200
の -0.18367043 -0.029534407 -0.04568029 0.059862584 -0.07997447 -0.046858925 0.14023784 0.07425413 -0.008315135 -0.12953435 -0.026671728 0.0070677847 0.16410053 -0.018117158 0.010087145 0.01466953 0.041376486 -0.1008434 -0.06180911 -0.058874626 0.008909045 0.044234663 0.014539371 -0.0028310632 0.018776815 0.049506992 0.073248215 -3.157805e-05 0.10673941 0.03583274 -0.10652217 -0.059396442 -0.0261147 -0.03284311 0.081868224 -0.012962754 -0.034510043 -0.10235525 -0.0769274 0.082873024 -0.012097447 0.0727865 0.014861626 0.03009095 -0.030071974 0.027952265 -0.025505234 -0.05018914 0.0417648 -0.02850465 0.07273282 -0.09483565 -0.04345289 0.15165983 -0.031632055 0.17059058 -0.044084977 0.0052842456 -0.028103009 -0.04878716 -0.024238652 0.054972656 -0.03674599 -0.040995907 -0.0050267056 -0.055824015 -0.062449194 0.13537866 0.07244164 -0.0714956 -0.019936046 -0.16671564 -0.100499384 0.007035132 -0.11507068 0.04457269 -0.13684298 -0.0803739 -0.017586963 0.08030843 0.033362087 -0.108916104 0.039713085 0.04878837 -0.04515106 -0.018243527 0.12831812 -0.017795842 0.04866434 -0.060816325 -0.050964024 -0.01270505 -0.033732776 -0.08854351 0.06676082 -0.05967011 0.07151932 0.061369702 -0.044702165 0.05061344 0.14359577 -0.051136304 0.0615069 -0.022738786 0.051260248 -0.07742282 0.02800144 0.03599446 0.06877219 -0.01719892 0.06675788 0.06307474 -0.036053922 -0.22192973 -0.048423205 -0.04822691 0.059284963 0.052187417 -0.07442425 -0.07336827 -0.118471846 -0.08167546 -0.013464366 0.006391276 0.0012051899 -0.114997625 -0.055676162 -0.053286873 0.047492016 -0.053404514 -0.042560868 0.039681405 0.11049425 0.07539424 -0.08105653 0.06514334 0.116133265 0.0396685 0.050337143 -0.055580735 -0.12020915 0.010553302 0.100789204 0.057277914 -0.01211393 -0.0009179714 0.08036089 0.0055075907 -0.17496698 0.13960186 -0.016151525 -0.062465787 0.038296852 0.017309219 0.012427443 -0.02194943 -0.09682389 -0.06353024 -0.060949646 0.020672444 0.13299905 0.028146993 0.07839865 0.022320326 -0.047637332 0.0005690668 0.07873213 -0.02366066 0.053062834 -0.011705997 0.056978896 -0.13023455 -0.048329093 -0.012450362 0.08268035 0.027270395 -0.09271215 0.014355482 0.11091846 -0.016312167 -0.06322232 0.024058044 -0.15635346 -0.060165983 -0.06801029 0.053233985 -0.091366336 -0.04244707 0.03506038 0.010074944 -0.039310183 0.003523075 0.026648186 -0.09815444 0.0450355 0.0151189305 -0.09069526 -0.015519027 0.06754098 -0.021642119
(略)

先頭行に単語数 ベクトル次元数があり、以降の行に単語 ベクトル...となっています。必要な単語だけこのファイルからベクトルを読み込むようにすれば、低メモリ環境でも何十万というヴォキャブラリを活かすことができます。

SQLiteのDBにしてしまう。

 個人的なおすすめはこのファイルをさらにデータベースにしてしまうことです。CSVなどより検索が早くなります。クソみたいなコードを貼ります。

import sqlite3
import logging

WORD2VEC_TEXT = "jawiki_normalized.kv.txt"
DB_PATH = "jawiki_normalized.kv.db"

def return_column_names(n):
    if n >= 0:
        r = "keyword TEXT PRIMARY KEY,"
        for i in range(n):
            r += f" vec{i} integer,"
    else:
        r = "?,"
        for i in range(-1*n):
            r += "?,"
    return r[0:-1]

if __name__=="__main__":

    fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
    logging.basicConfig(level=logging.INFO, format=fmt)

    con = sqlite3.connect(DB_PATH)
    c = con.cursor()

    with open(WORD2VEC_TEXT, "r") as f:
        fl = f.readline()
        logging.info("Start: " + fl)
        count = int(fl.split(" ")[0])
        size = int(fl.split(" ")[1])
        c.execute(f"create table kv( {return_column_names(size)} );")
        for i in range(2, count+2):
            line = f.readline()
            row = line.split(" ")
            if len(row) != (size+1):
                logging.warn(f"Index error: line: {i} text: {line}")
                continue
            c.execute(f"insert into kv values ({return_column_names(-1*size)});", row)
            if i % 10**4 == 0:
                logging.info(f"{i} / {count} done. {i/count*100} percent.")
                con.commit()
    logging.info("DONE!!!")
    con.commit()
    con.close()

DBができたら、インデックスを作っておきましょう。

$ sqlite3 jawiki_normalized.kv.db
SQLite version 3.22.0 2018-01-22 18:45:57
Enter ".help" for usage hints.
sqlite> create index keyword_index on kv(keyword);
sqlite> .exit

DBのCursorを受け取って、ベクトルを返す関数は以下のようになります。

def get_vec(c_kv, surface):
    c_kv.execute("select * from kv where keyword=?;", (surface,))
    r = c_kv.fetchall()[0]
    return np.array(r[1:])

c_kvがDBのCursorで、surfaceが単語、単語が見つからない場合はIndexErrorになります。

コサイン類似度の算出

 別にユークリッド距離でも構わないと思いますが、ユークリッド距離は「短いほど近く」コサイン類似度は「大きいほど近く」なります。正規化を行っている場合は、ユークリッド距離は2の時が一番遠くなります。コサイン類似度では-1が一番遠く、1が一番近くなります。コサイン類似度はsimilarity()などで使われています。gensimを使わないので自分で実装する必要があります。

コサイン類似度ってなんやねん

コサイン類似度は2つのベクトルabの間の角度をθとしたときのcosθのこと
(やたら難しい概念かと思ったら高校1年生でもわかりそうなものでびっくりしちゃった。)具体的にどう求めるかというと、ベクトルの内角の公式より、

f:id:two_headed_duck:20200405164656p:plain

そんだけ。ちなみに、abユークリッド距離をdとおくと、
f:id:two_headed_duck:20200405171204p:plain
となり、ここで、abは正規化されていて単位ベクトルなので、
f:id:two_headed_duck:20200405171937p:plain
という関係がコサイン類似度とユークリッド距離で成り立つ。

実装

ユークリッド距離を求めてから、コサイン類似度を求めたほうが早い気がするのでそうする。
f:id:two_headed_duck:20200405172324p:plain より、

def get_similarity(vec1, vec2):
    # vec1, vec2は事前正規化済みとして
    # ユークリッド距離
    diff = (vec1 - vec2) ** 2
    d_2 = diff.sum()    # ユークリッド距離の二乗
    # コサイン類似度に
    cos_sim = (2 - d_2) / 2
    return cos_sim

とすると求められる。簡単じゃん。

その他

 最終的解決法をつかうとモデルをメモリにロードする必要がなくなるので、ロード時間を大幅に減らせます。most_similar()などを使わなければ、限りなくいい方法だと思います。よんでくれてありがとね。