Last active
February 4, 2026 10:34
-
-
Save trueroad/96bbe0d936bd68454d10bf20595eb424 to your computer and use it in GitHub Desktop.
Calc distance matrix from SMFs (Standard MIDI Files).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Calc distance matrix from SMFs (Standard MIDI Files). | |
| https://gist.github.com/trueroad/96bbe0d936bd68454d10bf20595eb424 | |
| Copyright (C) 2026 Masamichi Hosoda. | |
| All rights reserved. | |
| Redistribution and use in source and binary forms, with or without | |
| modification, are permitted provided that the following conditions | |
| are met: | |
| * Redistributions of source code must retain the above copyright notice, | |
| this list of conditions and the following disclaimer. | |
| * Redistributions in binary form must reproduce the above copyright notice, | |
| this list of conditions and the following disclaimer in the documentation | |
| and/or other materials provided with the distribution. | |
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
| ARE DISCLAIMED. | |
| IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS | |
| OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |
| HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |
| LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |
| OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF | |
| SUCH DAMAGE. | |
| """ | |
| import os | |
| from pathlib import Path | |
| import sys | |
| from typing import Optional, Union | |
| import numpy as np | |
| import numpy.typing as npt | |
| import pandas as pd | |
| # https://gist.github.com/trueroad/97477dab8beca099afeb4af5199634e2 | |
| import smf_diff | |
| # https://gist.github.com/trueroad/a704a7ab54a851c0ddfbb359b2e7f87a | |
| import smf_distance | |
| class DistanceMatrix: | |
| """ | |
| Distance Matrix Class. | |
| 距離行列計算クラス | |
| """ | |
| def __init__(self) -> None: | |
| """__init__.""" | |
| # 差分比較パラメータ | |
| # (フィルタは論文と同様のパラメータで双方フィルタされるようにする) | |
| self.max_misalignment: Optional[float] = 0.05 | |
| self.filter_velocity: Optional[int] = 10 | |
| self.filter_duration: Optional[float] = 0.05 | |
| self.filter_noteno_margin: Optional[int] = None | |
| self.b_filter_both: bool = True | |
| self.b_octave_reduction: bool = False | |
| self.b_strict_diff: bool = False | |
| # SMF(.midファイル)のリスト | |
| self.path_list: list[Path] | |
| # 距離行列 | |
| self.matrix: npt.NDArray[np.float64] | |
| # 距離行列のラベル | |
| self.labels: list[str] | |
| # 距離行列のデータフレーム | |
| self.df: pd.DataFrame | |
| def load(self, dir: Union[str, os.PathLike[str]]) -> None: | |
| """ | |
| Load SMFs (.mid files) in the directory. | |
| ディレクトリにあるSMF(.midファイル)をロードする | |
| Args: | |
| dir (Union[str, os.PathLike[str]]): ディレクトリ | |
| """ | |
| if type(dir) is not Path: | |
| # Pathではなかった場合はPathに変換する | |
| dir = Path(dir) | |
| # ディレクトリからSMF(.midファイル)のパスを列挙する | |
| self.path_list = sorted(list(dir.glob('*.mid'))) | |
| def calc(self) -> None: | |
| """ | |
| Calc matrix. | |
| 距離行列を計算する | |
| """ | |
| # 距離行列を初期化 | |
| self.matrix = np.zeros((len(self.path_list), len(self.path_list)), | |
| dtype=np.float64) | |
| # 距離行列のラベルを初期化 | |
| self.labels = [] | |
| # 行ループ | |
| for i in range(len(self.path_list)): | |
| # SMF(.midファイル)のパスからラベル(stemだけ)を生成 | |
| label: str = str(self.path_list[i].stem) | |
| # ラベルのリストに追加 | |
| self.labels.append(label) | |
| # 列ループ | |
| for j in range(len(self.path_list)): | |
| # 今回使用している距離(非類似度)計算は、 | |
| # 非可換(非対称)なので、前後入れ替えたものも含め | |
| # すべての組み合わせを計算する。 | |
| # 同じファイル同士の距離は0になるハズなので | |
| # 計算しなくても良いが、とりあえずそのまま計算する。 | |
| # ただし、フィルタはデフォルト・論文設定だと | |
| # 評価対象側だけフィルタされるので、 | |
| # なるべく対称となるよう双方フィルタされる設定にする。 | |
| # 一部、距離がnanになってしまう組み合わせがある。 | |
| # 片方向だけnanになる場合もある。 | |
| # 差分比較クラスのインスタンスを生成 | |
| sd: smf_diff.smf_difference = smf_diff.smf_difference( | |
| verbose=0, | |
| max_misalignment=self.max_misalignment, | |
| filter_velocity=self.filter_velocity, | |
| filter_duration=self.filter_duration, | |
| filter_noteno_margin=self.filter_noteno_margin, | |
| b_filter_both=self.b_filter_both, | |
| b_octave_reduction=self.b_octave_reduction, | |
| b_strict_diff=self.b_strict_diff) | |
| # 1つ目のSMF(.midファイル)をロード | |
| sd.load_model(self.path_list[i]) | |
| # 2つ目のSMF(.midファイル)をロード | |
| sd.load_foreval(self.path_list[j]) | |
| # 差分をとる | |
| sd.diff() | |
| # タイミング系の集計 | |
| sd.calc_note_timing() | |
| # 距離計算クラスのインスタンスを生成 | |
| dist: smf_distance.Distance = smf_distance.Distance() | |
| # 距離計算して距離行列へ格納 | |
| self.matrix[i][j] = dist.calc(sd) | |
| # 距離行列をラベル含めてデータフレーム化する | |
| self.df = pd.DataFrame(self.matrix, | |
| index=self.labels, | |
| columns=self.labels) | |
| def save(self, filename: Union[str, os.PathLike[str]]) -> None: | |
| """ | |
| Save matrix. | |
| 距離行列のデータフレームをCSVファイルとして保存する | |
| Args: | |
| filename (Union[str, os.PathLike[str]]): 保存するCSVファイル名 | |
| """ | |
| # 距離がnanの項目は空欄として出力される。 | |
| # pandasのread_csvで空欄はnanになるので元に戻る(round trip)。 | |
| # CSVをExcelで読み込む際の文字化けを防ぐため | |
| # encodingをUTF-8 BOM付きとして出力する。 | |
| # pandasのread_csvはBOMがついていたら無視するため問題ない。 | |
| # 浮動小数点数をround tripする | |
| # (CSVを読み込んだ際に正確に再現する)ため、 | |
| # float64の精度を記録できる17桁で出力する。 | |
| self.df.to_csv(filename, | |
| encoding='utf-8-sig', | |
| float_format='%.17g') | |
| def main() -> None: | |
| """Do main.""" | |
| if len(sys.argv) != 3: | |
| print('Usage: ./calc_distance_matrix.py ' | |
| '[(int)SMF_DIR (out)DISTANCE_MATRIX.csv]') | |
| sys.exit(1) | |
| dm: DistanceMatrix = DistanceMatrix() | |
| dm.load(sys.argv[1]) | |
| dm.calc() | |
| dm.save(sys.argv[2]) | |
| if __name__ == '__main__': | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment