この記事について
Python 3.x で自然言語処理ライブラリgensimでword2vecを使います。学習済みのモデルデータを読み込んで、各単語のベクトルをえられるようにしたいとき、低メモリ(RAM)環境では、うまく行かないことがあります。この記事では、低メモリ環境で、学習済みword2vecモデルから各単語ベクトルを抽出する方法を紹介します。
問題の根源
gensimのword2vecモデルはmost_similar()
メソッドのように、「すべてのベクトルをメモリ上に展開していないと使えないメソッド」を含んでいます。そのため、モデルをロードする際に、モデルのサイズ以上の空きメモリがないと、読み込みに失敗します。自分の手元には日本語版Wikipediaから作ったモデルがありますが、モデルと付随するファイル含めて合計で1.3GBあります。一般的なパソコンでは問題ないですが、VPSなどの極めてメモリが少ない環境では問題になります。
解決法1 KeyedVectorsのみにする
gensimのword2vecモデルは単語のベクトル情報以外にも様々な情報を持っています。これらを切り落として、単純なベクトルのみにします。
from gensim.models import Word2Vec model = Word2Vec.load("jawiki.model") # gensim形式のモデルをロードします model.wv.save("jawiki.kv") # KeyedVectorsのみを保存します。
当方では、1.3GBから732MBに減量することができました。これを読み込むには以下のようにします。
from gensim.models import KeyedVectors kvs = KeyedVectors.load("jawiki.kv") for l in kvs.most_similar(positive=["女性", "皇帝"]): print(l)
most_similar()
やsimilarity()
などがそのまま使えます。kvs["安倍晋三"]
のようにしてここのベクトルが得られます。
解決法2 正規化して元データを消す
正規化とは「すべての単語ベクトルを単位ベクトルにする」ことです。単語ベクトルを扱う場合は、ベクトルの長さを無視して考えたほうが効果的であることが知られています。出現頻度や単語の文字長にベクトルの長さは影響されるため、単語同士の類似度を考える上では余計です。また、similarity()
で利用されるコサイン類似度は比較対象のベクトルの角度差のみで計算されます。事前に正規化することで、後の計算処理を減らすことができます。
事前正規化を行っている場合、元のオリジナルベクトルを消しておくことで、メモリー消費量を減らせます。半分にできるはずです。事前正規化はinit_sims()
メソッドで行いますが、init_sims(replace=True)
とすると、オリジナルのベクトルを忘れます。これを行うと、再学習はできなくなります。
kvs.init_sims(replace=True) kvs.save("jawiki_normalized.kv")
解決法3 読み込む単語数を絞り込む
word2vecのオリジナル形式からモデルを読み込む際は読み込む単語数を絞り込むことができます。まずは、gensim形式のモデルをword2vec形式に変換します。
kvs.save_word2vec_format("jawiki_normalized.kv.bin", binary=True)
バイナリ形式で保存することで、若干のファイル軽量化効果があります。50MBぐらい小さくなりました。また、単一のファイルになります。読み込む場合はこうします。
kvs = KeyedVectors.load_word2vec_format( "jawiki_normalized.kv.bin", binary=True, limit=10000) print(len(kvs.vocab)) # -> 10000
limit=*
を書き換えることで任意の数の単語のみを読み込みます。どの単語が選ばれるかなどはよく調べてません。
最終解決法
今までの方法は結局の所、モデルをすべてメモリにロードする必要がありました。最終解決法ではgensimを使うのをやめて、必要な単語のベクトルだけをメモリにロードできるようにします。most_similar()
などは使えなくなりますが、similarity()
などは自分で実装することで、使えるようになります。
まずは、gensimからモデルをテキスト形式で吐かせます。
kvs = KeyedVectors.load("jawiki_normalized.kv") kvs.save_word2vec_format("jawiki_normalized.kv.txt", binary=False)
headしてみるとこんなデータができます
879000 200 の -0.18367043 -0.029534407 -0.04568029 0.059862584 -0.07997447 -0.046858925 0.14023784 0.07425413 -0.008315135 -0.12953435 -0.026671728 0.0070677847 0.16410053 -0.018117158 0.010087145 0.01466953 0.041376486 -0.1008434 -0.06180911 -0.058874626 0.008909045 0.044234663 0.014539371 -0.0028310632 0.018776815 0.049506992 0.073248215 -3.157805e-05 0.10673941 0.03583274 -0.10652217 -0.059396442 -0.0261147 -0.03284311 0.081868224 -0.012962754 -0.034510043 -0.10235525 -0.0769274 0.082873024 -0.012097447 0.0727865 0.014861626 0.03009095 -0.030071974 0.027952265 -0.025505234 -0.05018914 0.0417648 -0.02850465 0.07273282 -0.09483565 -0.04345289 0.15165983 -0.031632055 0.17059058 -0.044084977 0.0052842456 -0.028103009 -0.04878716 -0.024238652 0.054972656 -0.03674599 -0.040995907 -0.0050267056 -0.055824015 -0.062449194 0.13537866 0.07244164 -0.0714956 -0.019936046 -0.16671564 -0.100499384 0.007035132 -0.11507068 0.04457269 -0.13684298 -0.0803739 -0.017586963 0.08030843 0.033362087 -0.108916104 0.039713085 0.04878837 -0.04515106 -0.018243527 0.12831812 -0.017795842 0.04866434 -0.060816325 -0.050964024 -0.01270505 -0.033732776 -0.08854351 0.06676082 -0.05967011 0.07151932 0.061369702 -0.044702165 0.05061344 0.14359577 -0.051136304 0.0615069 -0.022738786 0.051260248 -0.07742282 0.02800144 0.03599446 0.06877219 -0.01719892 0.06675788 0.06307474 -0.036053922 -0.22192973 -0.048423205 -0.04822691 0.059284963 0.052187417 -0.07442425 -0.07336827 -0.118471846 -0.08167546 -0.013464366 0.006391276 0.0012051899 -0.114997625 -0.055676162 -0.053286873 0.047492016 -0.053404514 -0.042560868 0.039681405 0.11049425 0.07539424 -0.08105653 0.06514334 0.116133265 0.0396685 0.050337143 -0.055580735 -0.12020915 0.010553302 0.100789204 0.057277914 -0.01211393 -0.0009179714 0.08036089 0.0055075907 -0.17496698 0.13960186 -0.016151525 -0.062465787 0.038296852 0.017309219 0.012427443 -0.02194943 -0.09682389 -0.06353024 -0.060949646 0.020672444 0.13299905 0.028146993 0.07839865 0.022320326 -0.047637332 0.0005690668 0.07873213 -0.02366066 0.053062834 -0.011705997 0.056978896 -0.13023455 -0.048329093 -0.012450362 0.08268035 0.027270395 -0.09271215 0.014355482 0.11091846 -0.016312167 -0.06322232 0.024058044 -0.15635346 -0.060165983 -0.06801029 0.053233985 -0.091366336 -0.04244707 0.03506038 0.010074944 -0.039310183 0.003523075 0.026648186 -0.09815444 0.0450355 0.0151189305 -0.09069526 -0.015519027 0.06754098 -0.021642119 (略)
先頭行に単語数 ベクトル次元数
があり、以降の行に単語 ベクトル...
となっています。必要な単語だけこのファイルからベクトルを読み込むようにすれば、低メモリ環境でも何十万というヴォキャブラリを活かすことができます。
SQLiteのDBにしてしまう。
個人的なおすすめはこのファイルをさらにデータベースにしてしまうことです。CSVなどより検索が早くなります。クソみたいなコードを貼ります。
import sqlite3 import logging WORD2VEC_TEXT = "jawiki_normalized.kv.txt" DB_PATH = "jawiki_normalized.kv.db" def return_column_names(n): if n >= 0: r = "keyword TEXT PRIMARY KEY," for i in range(n): r += f" vec{i} integer," else: r = "?," for i in range(-1*n): r += "?," return r[0:-1] if __name__=="__main__": fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" logging.basicConfig(level=logging.INFO, format=fmt) con = sqlite3.connect(DB_PATH) c = con.cursor() with open(WORD2VEC_TEXT, "r") as f: fl = f.readline() logging.info("Start: " + fl) count = int(fl.split(" ")[0]) size = int(fl.split(" ")[1]) c.execute(f"create table kv( {return_column_names(size)} );") for i in range(2, count+2): line = f.readline() row = line.split(" ") if len(row) != (size+1): logging.warn(f"Index error: line: {i} text: {line}") continue c.execute(f"insert into kv values ({return_column_names(-1*size)});", row) if i % 10**4 == 0: logging.info(f"{i} / {count} done. {i/count*100} percent.") con.commit() logging.info("DONE!!!") con.commit() con.close()
DBができたら、インデックスを作っておきましょう。
$ sqlite3 jawiki_normalized.kv.db SQLite version 3.22.0 2018-01-22 18:45:57 Enter ".help" for usage hints. sqlite> create index keyword_index on kv(keyword); sqlite> .exit
DBのCursorを受け取って、ベクトルを返す関数は以下のようになります。
def get_vec(c_kv, surface): c_kv.execute("select * from kv where keyword=?;", (surface,)) r = c_kv.fetchall()[0] return np.array(r[1:])
c_kv
がDBのCursorで、surface
が単語、単語が見つからない場合はIndexError
になります。
コサイン類似度の算出
別にユークリッド距離でも構わないと思いますが、ユークリッド距離は「短いほど近く」コサイン類似度は「大きいほど近く」なります。正規化を行っている場合は、ユークリッド距離は2
の時が一番遠くなります。コサイン類似度では-1
が一番遠く、1
が一番近くなります。コサイン類似度はsimilarity()
などで使われています。gensimを使わないので自分で実装する必要があります。
コサイン類似度ってなんやねん
コサイン類似度は2つのベクトルaとbの間の角度をθとしたときのcosθのこと
(やたら難しい概念かと思ったら高校1年生でもわかりそうなものでびっくりしちゃった。)具体的にどう求めるかというと、ベクトルの内角の公式より、
そんだけ。ちなみに、aとbのユークリッド距離をdとおくと、
となり、ここで、aとbは正規化されていて単位ベクトルなので、
という関係がコサイン類似度とユークリッド距離で成り立つ。
実装
ユークリッド距離を求めてから、コサイン類似度を求めたほうが早い気がするのでそうする。
より、
def get_similarity(vec1, vec2): # vec1, vec2は事前正規化済みとして # ユークリッド距離 diff = (vec1 - vec2) ** 2 d_2 = diff.sum() # ユークリッド距離の二乗 # コサイン類似度に cos_sim = (2 - d_2) / 2 return cos_sim
とすると求められる。簡単じゃん。
その他
最終的解決法をつかうとモデルをメモリにロードする必要がなくなるので、ロード時間を大幅に減らせます。most_similar()
などを使わなければ、限りなくいい方法だと思います。よんでくれてありがとね。