Skip to the content.

MLIRの実装

導入

前回はMLIRの概念的な部分の説明に徹して、実装部分の話はほとんどしなかった。

今回はいよいよ実装部分に踏み込む。 とはいえ、あまり深追いはしない。 ここではDialectを作成、拡張する上で必要になる知識をまとめる。

結構雑に書いているので、この記事の基になっている公式のチュートリアルも適宜参照するとよいと思う。

実装方法

Dialect

まずはDialect自身を定義する。

MLIRはそれ自身が中間言語というわけではなく、中間言語を定義するためのフレームワークでしかない。 そのため、Dialectを実装しないことには始まらない。 (MLIRのプロジェクトの中で実装されているDialectも多数あるが、例えばFIRはMLIRのプロジェクトから外れたところにいる)

ベースとなるクラスはMLIRの中に定義されており、mlir/include/mlir/IR/Dialect.hに実装されている。 このmlir::Dialectを継承してDialectを定義する。

実はC++で直接定義する以外にもう一つ定義する方法がある。 それがTableGenを使った方法で、これをMLIRではODS(Operation Definition Specification)と呼んでいる。 ODSで定義する場合は、mlir/include/mlir/IR/DialectBase.tdにあるDialectを継承して定義する。 TableGenは独自の文法からC++コードを生成する仕組みだが、それはmlir-tblgen -gen-dialect-declsというコマンドで実行される。
また、ODSの別の利点として、ドキュメント化が楽というのがある。(-gen-dialect-doc) そのため、ODSでの定義が推奨されている。以降でもODSでの定義方法を中心に説明する。

メンバとして定義できるものをいくつかピックアップする。

メンバ 必要性 説明 (C++コードとの対応)
name 必須 Dialectの名前そのものズバリ getDialectNamespace()
cppNamespace 推奨 C++コードでDialect固有の要素を定義するときに使う名前空間
(省略した場合はnameと同じになる)
namespace {}
summary 推奨 Dialectの概要の説明  
description 推奨 Dialectの具体的な説明  
dependentDialects 適宜 定義するDialectに必要な既存のDialectのリスト  

それ以外のメンバについては、適宜DialectBase.tdを参照して欲しい。

Dialectは定義するだけではダメで、実際に使用するにはMLIRContextloadDialect()で読み込ませる必要がある。(どのタイミングで?)

Op (Operation)

Dialectが定義できたら中身を実装していく。 まずはIRの基本構成要素であるOperationを定義する。

DialectのOperationを定義するには、mlir/include/mlir/IR/OpDefinition.hにあるmlir::Opを継承してクラスを定義すればよい。 mlir::OpはCRTP(Curiously Reccuring Template Pattern)というもので、詳細はググってほしいがクラスを継承する際にテンプレート引数にサブクラス自身を渡す手法が使われている。
(ちなみにOperationというクラスもあるが、これはIRの構成要素としてのOperationであって似て非なるものである。詳細は公式チュートリアルを参照。)

ODSでも定義できるが、いきなりmlir/include/mlir/IR/OpBase.tdにあるOpを継承して定義するのはあまり推奨されていなさそう。 ODSではまずDialectのOperationを定義するOpのサブクラスを定義して、さらにそれを継承してOperationを定義するというのがベストプラクティスとされている。 生成コマンドはmlir-tblgen -gen-op-decls(クラス定義)とmlir-tblgen -gen-op-defs(関数定義)である。

定義できるものとして以下がある。

メンバ 必要性 説明 (C++コードとの対応)
mnemonic(テンプレート引数) 必須 Operationの名前 getOperationName()
traits(テンプレート引数) 任意 Operationの性質 対応するOpTrait
arguments 必須 Operationの引数(OperandとAttribute)
なければ省略可
mlir::OpTrait::ZeroOperandsなど
results 必須 Operationの返り値
なければ省略可
mlir::OpTrait::ZeroResultsなど
summary 推奨 Operationの概要の説明  
description 推奨 Operationの具体的な説明  
hasVerifier 推奨 生成されたOperationの妥当性を確認する関数verify()をユーザが定義するか verify()
builders 任意 コンストラクタの追加 build()
assemblyFormat
hasCustomAssemblyFormat
任意 アセンブリ形式での表現 print()およびparse()

その他、regionsuccessorsで分岐処理が作れる(よく分かっていない)ほか、OpBase.tdにいくつかあるので適宜参照。

Operationも定義するだけではダメで、使用するためには先ほど定義したDialectのクラスのinitialize()内でaddOperations()を呼ぶ必要がある。
このとき、GET_OP_LISTというマクロを定義した上で-gen-op-defsで生成されたファイルをインクルードすると簡潔に書ける。(参考)

Type

Interface

一番使用頻度が高いのはOperationのInterfaceだと思うのでそれを中心に説明する。

OperationのInterfaceはmlir/include/mlir/IR/OpDefinition.hにあるOpInterfaceを継承して定義する。 そして定義したInterfaceにinterface methodを実装する。 (Op側にinterface methodを宣言する必要があるが、Interfaceのクラスを継承させればよい?後述するODSでのやり方しか情報がないため不明。 多分tblgenで自動生成されるファイルの中身を見ればわかるとは思うが)

ODSで定義する場合は、mlir/include/mlir/IR/Interfaces.tdOpInterfaceを継承して定義する。 interface methodはmethodというメンバの中に列挙していく。 各interface methodはInterfaceMethodというテンプレートを使って宣言を書く。 最後にOperationからinterface methodを呼べるように、TraitsにDeclareOpInterfaceMethodsをつけてInterfaceを指定すると、Opクラスのメンバとして必要な関数の宣言が自動で挿入される。(ただしdefaultImplementationを指定したmethodについては明示的に指定しないとオーバーライドできない) あとはそれらに対して定義を書けばよい。

DialectInterfaceを一から定義する方法は不明だが、定義したDialectInterfaceを使用するためにはDialectのクラスのinitialize()内でaddInterfaces()を呼ぶ必要がある。

Pass

DialectにはOperationが定義されれば十分かといわれるとそうではない。 Dialectによる中間表現は、表現できることだけでなく、LLVM dialect、そしてLLVM IRに変換されていくことが求められる。 多くのDialectではIRというディレクトリにOperationの定義、TransformsというディレクトリにOperationの変換規則(Pass)の定義がされている。 ここではそのPassを扱う。

ちなみにMLIRではLoweringパスもOptimizationパスも等しく”変換パス”という扱いになっている。 これは前回も述べた通りPartial Loweringが可能であるため、そもそもLoweringのフェーズというものが存在しないからである。 (ただし実装の観点では、インターフェースこそ共通だが中身の実装は明らかに違っている)

まずはパス自身を定義する。 パス自身はPassWrapperというクラスを使って定義する。 PassWrapperもCRTPであり、一番目の引数に定義するパス自身、二番目の引数には継承するクラス(OperationPass<mlir::ModuleOp>など)を指定する。 そしてメンバ関数であるrunOnOperation()などに具体的な処理内容を記述していく。

パスに関しても、C++で直接定義する以外にODSで定義する方法がある。 mlir/include/mlir/Pass/PassBase.tdにあるPassを継承することで定義できる。 ただ、結局ODSはガワしか作ってくれないので中身は自分で実装していく必要がある。

次にrunOnOperation()の中に変換処理を実装するにあたって役に立つ機能をいくつか紹介する。

まず1つ目はgetOperation()である。
これはパスの変換対象となるOperationを返してくれる関数である。 この関数はパスの中であればどこでも呼び出すことができる。 (当たり前と思われるかもしれないが、LLVMだと変換対象のInstructionを取得しようと思ったときに、そこから見えているクラスから辿っていかないと取得できないことがあり、面倒くさい)

2つ目はmlir::RewritePatternである。
実際にはmlir::OpRewritePatternmlir::OpInterfaceRewritePatternmlir::ConversionPatternなどを継承して使う。 その名から分かるように、パターンマッチングでOperationを書き換えていく仕組みである。 変換処理の本体はmatchAndRewrite()であり、パターンマッチングをした後mlir::PatternRewriterreplaceOp()によって実際に変換する。
この時のパターンマッチをC++で書く方法もあるが、簡単なパターンマッチであればTableGenを使って簡潔に書くことができる。 MLIRではこれをDRR(Declarative Rewrite Rule)と呼んでいる。

3つ目はmlir::ConversionTargetである。
大抵の場合、変換パスでは変換対象となるものとならないものがあるわけで、そのあたりの区別はちゃんとする必要がある。 特にDialect間の変換ではこの辺りを確認するコードを一から書いているのでは大変だし漏れがあるかもしれない。 そこでこの仕組みを使う。 (Conversionという単語自体はDialect間の変換に限らずDialect内での変換も含むはずだが、ConversionTargetはDialect間の変換(Lowering)に使われることがほとんどのようだ)
このConversionTargetに、addLegalDialect, addDynamicallyLegalDialect, addIllegalDialect, addLegalOp, addDynamicallyLegalOp, addIllegalOpといった関数で情報を追加していく。 例えばaddLegalDialectaddIllegalDialectを組み合わせることでaddIllegalDialectのOperationをすべてaddLegalDialectに変換するといったことを表現できる。 また、addIllegalDialectのあるOperationをaddLegalOpに追加することで、例外的にそのOperationが変換されなくても許されるようになる。
あとはRewritePatternSetConversionPatternを追加し、mlir::applyPartialConversion()またはmlir::applyFullConversion()を呼ぶことで変換が実行される。 (ちなみによく使われる変換パターンは既にまとめられていて、例えばmlir::arith::populateArithToLLVMConversionPatterns()といった関数を呼べばArith DialectからLLVM DialectへのRewritePatternを一括取得できる。)

パスもやはり定義するだけではダメで、mlir::PassManageraddPass()でパスを追加した上でrun()を呼んで初めてパスが実行される。
このとき、addPass()に渡すパス(へのunique_ptr)を生成する必要が出てくるわけで、そのための関数(create*Pass())も必要になる。 中でやっていることは大したことはなく、ODSで勝手に作ってくれる。 (ODSでconstructorを明に定義しておけばユーザ好みにカスタマイズできる)

IR

Dialectを定義する方法を前節で述べた。ただこれだけの情報ではパスの中身は実装できないと思う。 ここからはMLIRの構造の実装を見ていく。

構造

LLVM IRの場合は、Module->Function->BasicBlock->Instructionと階層がはっきり分かれている。 対してMLIRの場合は、前回説明した通りOperationを中心として、Operation自身が階層構造を持つようになっている。 ModuleもFunctionもInstructionも、MLIRにおいては等しくOperationである。

2020-02-26 - CGO 2020 Talkより図を拝借して、まずは各階層へのアクセスの仕方を確認する。

MLIRの構造


Operation->親Block: getBlock
Block->親Region: getParent
Operation->親Region: getParentRegion
Region->親Region: getParentRegion
Operation->親Operation: getParentOp
Block->親Operation: getParentOp
Region->親Operaion: getParentOp

Region->Block: getBlocks
Region->BlockArgument: getArguments
Region->Operation: getOps

Block->BlockArgument: getArguments
Block->Operation: getOperations

Operation->OpOperand: getOperands
Operation->Attribute: getAttr

Value->定義元Operation: getDefiningOp
Value->親Block: getParentBlock
Value->親Region: getParentRegion

Operation->その結果を使っているOpOperand: getUses
Value->それを使っている(対応する)OpOperand: getUses
Operation->その結果を使っているOperation: getUsers

OpOperand->親Operation: getOwner
OpOperand->親Value: get


Operation*mlir::dyn_castを使って前項で定義したような任意のOpのサブクラスに変換できる。 また、OperationがあるInterfaceを持っているかどうかは、該当Interfaceのクラスにdyn_castできるかで判別できる。

Walker

例えばModuleからIR全体を探索して各Operationに対して同じ処理を行いたいとなったときに、便利な機能がある。 それがWalkerである。 正確に言えばOperation::walk()である。

この関数は引数に関数を取ることができ、そのOperationを起点に順次Operationを探索していき関数を適用していく。 デフォルトでは後行順の深さ優先探索(葉から辿る)をするが、テンプレート引数で設定すれば先行順で(根から)辿る。

https://mlir.llvm.org/docs/Tutorials/UnderstandingTheIRStructure/#walkers