この記事はGo2 Advent Calendar 2019の記事。

自分はprotoc-gen-gohttpというプラグインを作っているが、先日grpc.UnaryServerInterceptorに対応した。

その際、どうやればいいかを調査していたら、ChainUnaryServerという関数を見つけた。 中身を読んでみると、特定のinterfaceを受け取ってそれを実行する機能を持った関数に対し、interfaceを実行するときに複数の関数でインターセプトする方法の知見を得られたのため記事にしようと思う。

コード例 Link to heading

まずはインターセプトしたいコードの例を書く。

type WrapStruct struct {
	pingServer pb.PingAPIServer
}

func (w *WrapStruct) Call() {
	ctx := context.Background()
	arg := &pb.PingRequest{
		Msg: "PING",
	}

	ret, err := w.pingServer.Ping(ctx, arg)

	fmt.Println(ret, err)
}

上記例は以下のProtocol Buffersの定義をgRPCオプション付きでコンパイルすると生成されるコードを使用している。

syntax = "proto3";

package protobuf1;

option go_package = "pb";

service PingAPI {
  rpc Ping(PingRequest) returns (PingResponse);
}

message PingRequest {
  string msg = 1;
}

message PingResponse {
  string msg = 1;
}

ただ自身を実装している構造体がもつ PingAPIServer interfaceの Ping メソッドに PING の文字列を渡して実行し、その結果を出力している。

以下のように、 PingAPIServer interfaceを実装した構造体を渡すことで実行できる。

func main() {
	pingServer := &PingAPIServer{}
	wrapStruct := &WrapStruct{
		pingServer: pingServer,
	}

	wrapStruct.Call()
}

type PingAPIServer struct{}

func (p *PingAPIServer) Ping(ctx context.Context, arg *pb.PingRequest) (*pb.PingResponse, error) {
	fmt.Println("CALLED PING")
	return &pb.PingResponse{
		Msg: fmt.Sprintf("PONG: %s", arg.GetMsg()),
	}, nil
}

実行結果は以下のようになる。

CALLED PING
msg:"PONG: PING"  <nil>

さて、今回インターセプトしたい部分は最初のコードの以下の部分。

	ret, err := w.pingServer.Ping(ctx, arg)

このコードの前後に、 WrapStruct.Call に関数を渡すことで処理をインターセプトできるようにしたい。

インターセプトできるようにしたコードを以下に貼ってみる。

type RPC func(context.Context, *pb.PingRequest) (*pb.PingResponse, error)
type Interceptor func(context.Context, *pb.PingRequest, RPC) (*pb.PingResponse, error)

type WrapStruct struct {
	pingServer pb.PingAPIServer
}

func (w *WrapStruct) Call(interceptors ...Interceptor) {
	ctx := context.Background()
	arg := &pb.PingRequest{
		Msg: "PING",
	}

	n := len(interceptors)
	chained := func(ctx context.Context, arg *pb.PingRequest, rpc RPC) (*pb.PingResponse, error) {
		chainer := func(currentInter Interceptor, currentRPC RPC) RPC {
			return func(currentCtx context.Context, currentReq *pb.PingRequest) (*pb.PingResponse, error) {
				return currentInter(currentCtx, currentReq, currentRPC)
			}
		}

		chainedRPC := rpc
		for i := n - 1; i >= 0; i-- {
			chainedRPC = chainer(interceptors[i], chainedRPC)
		}
		return chainedRPC(ctx, arg)
	}

	ret, err := chained(ctx, arg, w.pingServer.Ping)

	fmt.Println(ret, err)
}

以下のようにして実行することで内部の Ping をインターセプトできる。

func main() {
	pingServer := &PingAPIServer{}
	wrapStruct := &WrapStruct{
		pingServer: pingServer,
	}

	wrapStruct.Call()
	fmt.Println("-------------------")
	wrapStruct.Call(
		printInterceptor,
	)
	fmt.Println("-------------------")
	wrapStruct.Call(
		printInterceptor,
		printInterceptor2,
	)
}

func printInterceptor(ctx context.Context, arg *pb.PingRequest, rpc RPC) (*pb.PingResponse, error) {
	fmt.Println("BEFORE")
	ret, err := rpc(ctx, arg)
	fmt.Println("AFTER")
	return ret, err
}

func printInterceptor2(ctx context.Context, arg *pb.PingRequest, rpc RPC) (*pb.PingResponse, error) {
	fmt.Println("BEFORE2")
	ret, err := rpc(ctx, arg)
	fmt.Println("AFTER2")
	return ret, err
}

上記コードの実行結果は以下。

CALLED PING
msg:"PONG: PING"  <nil>
-------------------
BEFORE
CALLED PING
AFTER
msg:"PONG: PING"  <nil>
-------------------
BEFORE
BEFORE2
CALLED PING
AFTER2
AFTER
msg:"PONG: PING"  <nil>

Pingメソッドが内部で出力している CALLED PING の前後に任意の文字列を出力できている。

解説 Link to heading

コードを見ればわかるが、一応どういう変更を入れたのか簡単に解説をしようと思う。

まずは下記のintraface。

type RPC func(context.Context, *pb.PingRequest) (*pb.PingResponse, error)
type Interceptor func(context.Context, *pb.PingRequest, RPC) (*pb.PingResponse, error)

RPC は内部でインターセプトされる側のメソッドを表現したもの。

今回は例のため型をベタに書いているが、汎用性が必要な場合は interface{} 型にしたり、引数や返り値にinterfaceを使うほうがよい。 今回の例も proto.Message 型で抽象化できるが、わかりやすさのために構造体を直に書いている。

Interceptor は文字通りインターセプトをするための関数の定義。

引数に、「RPCの引数」と「RPCそのもの」を受け取り、「RPCの戻り値」を返す定義になっている。

次に Call メソッドの引数の解説。

func (w *WrapStruct) Call(interceptors ...Interceptor) {

ここは単純に、可変長引数でInterceptorを受け取っている。

必須なら可変長引数にしないほうがいいが、ここでは渡さなくても動く例を示せるよう可変長引数にしている。

さて、一番重要な部分。

	n := len(interceptors)
	chained := func(ctx context.Context, arg *pb.PingRequest, rpc RPC) (*pb.PingResponse, error) {
		chainer := func(currentInter Interceptor, currentRPC RPC) RPC {
			return func(currentCtx context.Context, currentReq *pb.PingRequest) (*pb.PingResponse, error) {
				return currentInter(currentCtx, currentReq, currentRPC)
			}
		}

		chainedRPC := rpc
		for i := n - 1; i >= 0; i-- {
			chainedRPC = chainer(interceptors[i], chainedRPC)
		}
		return chainedRPC(ctx, arg)
	}

	ret, err := chained(ctx, arg, w.pingServer.Ping)

1つずつ分解してみる。

chained 関数は Interceptor と同じ引数戻り値をしている。

重要なのは chained 関数の内部。

内部で定義されている chainerInterceptorRPC を受け取り、 RPC を返す関数になっている。

返している RPC は何をしているかというと、 chainer が受け取った Interceptor を内部で実行してその戻り値を返している。

つまり chainer は、「受け取った RPC を、同時に受け取った Interceptor で実行する RPC へ変換する」機能を持った関数となる。

それ以降は簡単。

chained の引数で受け取った RPC を起点にして、forループで chainer を実行し、受け取った InterceptorRPC に変換していく。 逆順に回している理由は、 最初に渡した Interceptor を最後に変換することで、渡した順に Interceptor を実行するため。

最後に、変換が終わった RPCchained が受け取った RPC の引数を渡して実行することで、 chained の処理は終わり。

こうすることで、 Call メソッドに1つの Interceptor が渡された場合、その渡された Interceptor には Ping メソッドが渡された状態で RPC に変換して RPC を実行することでインターセプトを可能にできる。

まとめ Link to heading

このテクニックを知っていると「特定のinterfaceを受け取ってそのinterfaceで実行する」機能を持ったものを内部でインターセプトして任意の処理を挟むことができるようになる。

あまり業務のコードで使えるようなテクニックではないが、上記のような構造を持つライブラリやコードジェネレータを作るときに知っておくと使う側にある程度の自由度を与えることができるようになる。

おまけ Link to heading

今回の例では出てこなかったテクニックとして、Interceptorに追加の情報を渡すということができる。

上記の例では Interceptor は引数に、「RPC の引数」と「RPC そのもの」を受け取るようになっているが、これは最低限の定義でしかないため、追加で別のものを受け取るように定義することで実現できる。

実際にgrpc.UnaryServerInterceptorは、ctxとreqとhandlerの他に、 UnaryServerInfo という構造体も渡されるように作られている。

こうすることにより、インターセプトする側がされる側の情報を取得して処理を挟むことができるようになる。