gRPC之拦截器与元数据

  • A+
所属分类: gRPC

简介

  1. 用户可以通过访问或者修改Metadata来传递额外的信息(即HTTP/2的Header信息),比如认证信息、TraceId、RequestId等
    • Metadata以key-value的形式存储数据:key是字符串(不区分大小写,会统一转换为小写);value可以是ASCII字符串,也可以是二进制数据(此时key必须以"-bin"结尾),并且同一个key可以对应多个value
    • Metadata 的生命周期为一次 RPC 调用
  2. gRPC可以在四个地方增加拦截处理
    • 客户端调用前的拦截
    • 客户端收到的回复拦截
    • 服务端收到的请求拦截
    • 服务端回复前的拦截

客户端拦截器案例

  1. 客户端调用前拦截器

    /**
     * Client-side interceptor that logs every outbound step of a call
     * (start, request, sendMessage, halfClose) and then forwards it unchanged.
     */
    @Slf4j
    public class ClientPreInterceptor implements ClientInterceptor {

        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
            final String fullName = method.getFullMethodName();
            final ClientCall<ReqT, RespT> downstream = next.newCall(method, callOptions);

            return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(downstream) {

                // Called once when the call begins, before any message is sent.
                @Override
                public void start(Listener<RespT> responseListener, Metadata headers) {
                    log.info("调用{}开始", fullName);
                    super.start(responseListener, headers);
                }

                // Asks for up to numMessages response messages to be delivered to the listener.
                @Override
                public void request(int numMessages) {
                    log.info("方法:{} 传递给侦听器的请求的消息数:{}", fullName, numMessages);
                    super.request(numMessages);
                }

                @Override
                public void sendMessage(ReqT message) {
                    log.info("方法:{}发送消息:{}", fullName, message);
                    super.sendMessage(message);
                }

                // The client has finished sending request messages.
                @Override
                public void halfClose() {
                    log.info("方法:{}客户端半关闭", fullName);
                    super.halfClose();
                }
            };
        }
    }
    
  2. 客户端调用后拦截器(拦截客户端收到的回复)

    @Slf4j
    public class ClientPostInterceptor implements ClientInterceptor {
         
        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
         
            final String methodName = method.getFullMethodName();
            return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
         
                @Override
                public void start(Listener<RespT> responseListener, Metadata headers) {
         
                    ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT> simpleForwardingClientCallListener = new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {
         
                        @Override
                        public void onMessage(RespT message) {
         
                            log.info("客户端已经接收到响应消息,methodName:{},message:{}", methodName, message);
                            super.onMessage(message);
                        }
    
                        @Override
                        public void onHeaders(Metadata headers) {
         
                            log.info("已经收到响应头,methodName:{}", methodName);
                            super.onHeaders(headers);
                        }
    
                        @Override
                        public void onClose(Status status, Metadata trailers) {
         
                            log.info("客户端关闭连接,methodName:{},code:{}", methodName, status.getCode());
                            super.onClose(status, trailers);
                        }
    
                        @Override
                        public void onReady() {
         
                            log.info("客户端onReady,methodName:{}", methodName);
                            super.onReady();
                        }
                    };
    
                    super.start(simpleForwardingClientCallListener, headers);
                }
            };
        }
    }
    
  3. 客户端传递Metadata

    // Lives in the shared api module (both the server and the client depend on it).
    /**
     * Keys used to propagate a trace id: the Context key carries it inside a process,
     * the Metadata key carries it on the wire as a request header.
     */
    public final class Constant {

        /** Context key under which the trace id is stored for the current call. */
        public static final Context.Key<String> TRACE_ID_CTX_KEY = Context.key("traceId");

        /** Request-header key used to transmit the trace id (gRPC lower-cases header names). */
        public static final Metadata.Key<String> TRACE_ID_METADATA_KEY = Metadata.Key.of("traceId", Metadata.ASCII_STRING_MARSHALLER);

        /** Not instantiable: this class only holds constants. */
        private Constant() {
        }
    }
    
    /**
     * Copies the trace id from the current gRPC Context, when one is present,
     * into the outgoing request headers.
     */
    public class TraceIdClientInterceptor implements ClientInterceptor {

        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> methodDescriptor, CallOptions callOptions, Channel channel) {
            ClientCall<ReqT, RespT> underlying = channel.newCall(methodDescriptor, callOptions);
            return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(underlying) {

                @Override
                public void start(Listener<RespT> responseListener, Metadata headers) {
                    // Headers are still mutable until super.start() sends them.
                    String traceId = Constant.TRACE_ID_CTX_KEY.get();
                    if (traceId != null) {
                        headers.put(Constant.TRACE_ID_METADATA_KEY, traceId);
                    }
                    super.start(responseListener, headers);
                }
            };
        }
    }
    
  4. 客户端拦截器配置

    /**
     * Client-side examples: registering interceptors globally on the channel, per stub,
     * and propagating a trace id through the gRPC Context.
     */
    @Slf4j
    public class GrpcConsumerInterceptor {

        public static final String IP = "127.0.0.1";
        public static final int PORT = 8081;

        /** Interceptors registered on the channel apply to every call made through it. */
        @Test
        public void testGlobalInterceptor() {
            ManagedChannel channel = ManagedChannelBuilder.forAddress(IP, PORT)
                    .usePlaintext() // plaintext, no TLS
                    .intercept(
                            new ClientPreInterceptor()
                            , new ClientPostInterceptor()
                            , new TraceIdClientInterceptor()
                    ) // global interceptors, registered when the channel is created
                    .build();
            try {
                // blocking (synchronous) call
                HelloServiceGrpc.HelloServiceBlockingStub stub
                        = HelloServiceGrpc.newBlockingStub(channel);
                HelloResponse helloResponse = stub.hello(newRequest());
                log.info("Response received from server:{}", helloResponse);
            } finally {
                // Release the channel even when the call fails.
                channel.shutdown();
            }
        }

        /** Interceptors attached to a stub apply only to calls made through that stub. */
        @Test
        public void testMethodInterceptor() {
            ManagedChannel channel = newPlaintextChannel();
            try {
                HelloServiceGrpc.HelloServiceBlockingStub stub
                        = HelloServiceGrpc.newBlockingStub(channel);
                HelloResponse helloResponse = stub.withInterceptors(new ClientPreInterceptor(),
                                new ClientPostInterceptor(),
                                new TraceIdClientInterceptor())
                        .hello(newRequest());
                log.info("Response received from server:{}", helloResponse);
            } finally {
                channel.shutdown();
            }
        }

        /**
         * Runs the call inside a Context carrying a random trace id, which
         * TraceIdClientInterceptor copies into the request headers.
         */
        @Test
        public void testTraceId() {
            ManagedChannel channel = newPlaintextChannel();
            try {
                Context.current().withValue(Constant.TRACE_ID_CTX_KEY, UUID.randomUUID().toString()).run(() -> {
                    HelloServiceGrpc.HelloServiceBlockingStub stub
                            = HelloServiceGrpc.newBlockingStub(channel);
                    HelloResponse helloResponse = stub.withInterceptors(new ClientPreInterceptor(),
                                    new ClientPostInterceptor(),
                                    new TraceIdClientInterceptor())
                            .hello(newRequest());
                    log.info("Response received from server:{}", helloResponse);
                });
            } finally {
                channel.shutdown();
            }
        }

        /** Builds a plaintext (no TLS) channel to the demo server, without interceptors. */
        private static ManagedChannel newPlaintextChannel() {
            return ManagedChannelBuilder.forAddress(IP, PORT)
                    .usePlaintext()
                    .build();
        }

        /** Builds the sample request used by every test. */
        private static HelloRequest newRequest() {
            return HelloRequest.newBuilder()
                    .setFirstName("Jannal")
                    .setLastName("Jan")
                    .build();
        }
    }
    
  5. 输出结果

    ---------------------客户端调用前的拦截---------------------
    客户端传递的traceId:3f393abc-d5ae-4ad0-b24a-aaa4596e8f7f
    调用cn.jannal.grpc.facade.dto.HelloService/hello开始
    方法:cn.jannal.grpc.facade.dto.HelloService/hello 传递给侦听器的请求的消息数:2
    方法:cn.jannal.grpc.facade.dto.HelloService/hello发送消息:firstName: "Jannal"
    lastName: "Jan"
    ---------------------客户端调用后的拦截--------------------- 
    方法:cn.jannal.grpc.facade.dto.HelloService/hello客户端半关闭
    已经收到响应头,methodName:cn.jannal.grpc.facade.dto.HelloService/hello
    客户端已经接收到响应消息,methodName:cn.jannal.grpc.facade.dto.HelloService/hello,message:greeting: "Hello, Jannal Jan"
    客户端关闭连接,methodName:cn.jannal.grpc.facade.dto.HelloService/hello,code:OK
    

服务端拦截器案例

  1. 服务端收到的请求拦截

    /**
     * Server-side interceptor for incoming requests: logs each inbound event of a call
     * (ready, message, half-close, cancel, complete) before handing it to the next listener.
     */
    @Slf4j
    public class ServerPreInterceptor implements ServerInterceptor {

        @Override
        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
            final String fullName = call.getMethodDescriptor().getFullMethodName();
            // Start the rest of the chain first, then decorate the listener it returns.
            final ServerCall.Listener<ReqT> downstream = next.startCall(call, headers);

            return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(downstream) {

                @Override
                public void onReady() {
                    log.info("服务端onReady,methodName:{}", fullName);
                    super.onReady();
                }

                @Override
                public void onMessage(ReqT message) {
                    log.info("服务端接收消息,methodName:{},message:{}", fullName, message);
                    super.onMessage(message);
                }

                @Override
                public void onHalfClose() {
                    log.info("服务端半关闭,methodName:{}", fullName);
                    super.onHalfClose();
                }

                @Override
                public void onCancel() {
                    log.info("服务端调用被取消,methodName:{}", fullName);
                    super.onCancel();
                }

                @Override
                public void onComplete() {
                    log.info("服务端调用完成,methodName:{}", fullName);
                    super.onComplete();
                }
            };
        }
    }
    
  2. 服务端回复前的拦截

    /**
     * Server-side interceptor for outgoing responses: decorates the ServerCall passed
     * down the chain so headers, messages and the final status are logged as they are sent.
     */
    @Slf4j
    public class ServerPostInterceptor implements ServerInterceptor {

        @Override
        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
            final String fullName = call.getMethodDescriptor().getFullMethodName();

            ServerCall<ReqT, RespT> loggingCall = new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {

                @Override
                public void sendHeaders(Metadata responseHeaders) {
                    log.info("服务端发送响应头,methodName:{}", fullName);
                    super.sendHeaders(responseHeaders);
                }

                @Override
                public void sendMessage(RespT message) {
                    log.info("服务端发送消息,methodName:{},message:{}", fullName, message);
                    super.sendMessage(message);
                }

                @Override
                public void close(Status status, Metadata trailers) {
                    log.info("服务端关闭连接,methodName:{},code:{}", fullName, status.getCode());
                    super.close(status, trailers);
                }
            };
            // Downstream handlers respond through loggingCall, so outbound traffic passes through it.
            return next.startCall(loggingCall, headers);
        }
    }
    
  3. 服务端获取Metadata

    /**
     * Reads the trace id sent by the client from the request headers and exposes it
     * via Constant.TRACE_ID_CTX_KEY for the rest of the call. The value may be null
     * when the client did not send one.
     */
    public class TraceIdServerInterceptor implements ServerInterceptor {

        @Override
        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall, Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
            // withValue derives a new Context; Contexts.interceptCall makes it current
            // while the call's listener is created and while its events are delivered.
            Context callContext = Context.current()
                    .withValue(Constant.TRACE_ID_CTX_KEY, metadata.get(Constant.TRACE_ID_METADATA_KEY));
            return Contexts.interceptCall(callContext, serverCall, metadata, serverCallHandler);
        }
    }
    
  4. 服务端拦截器配置

    /**
     * Starts the demo gRPC server on port 8081 with the logging, trace-id and
     * StatusRuntimeException-translating interceptors installed, and keeps the JVM
     * alive until the server terminates.
     */
    @Slf4j
    public class GrpcProvider {

        public static void main(String[] args) throws IOException {
            int port = 8081;
            // Server-wide interceptors run in the reverse order in which they are added.
            // Every interceptor must be added BEFORE build(): anything added afterwards is
            // not picked up by the already-built server.
            ServerBuilder<?> serverBuilder = ServerBuilder
                    .forPort(port)
                    .intercept(new ServerPreInterceptor())
                    .intercept(new ServerPostInterceptor())
                    .intercept(new TraceIdServerInterceptor())
                    .intercept(TransmitStatusRuntimeExceptionInterceptor.instance())
                    .addService(new HelloServiceImpl());
            // To intercept a single service only:
            //.addService(ServerInterceptors.intercept(new HelloServiceImpl(), new ServerPreInterceptor()))

            Server server = serverBuilder.build();
            Runtime.getRuntime().addShutdownHook(new Thread(() -> {
                server.shutdown();
                log.info("Server Shutdown!");
            }));
            server.start();
            log.info("Server start port {} !", port);
            startDaemonAwaitThread(server);
        }

        /**
         * Blocks a thread on {@code server.awaitTermination()} so the JVM stays alive while
         * the server runs. Despite the method name, the thread is deliberately NOT a daemon:
         * a daemon thread would not keep the JVM running after main() returns.
         */
        private static void startDaemonAwaitThread(Server server) {
            Thread awaitThread = new Thread(() -> {
                try {
                    server.awaitTermination();
                } catch (InterruptedException e) {
                    // Restore the interrupt status instead of silently discarding it.
                    Thread.currentThread().interrupt();
                }
            });
            awaitThread.setDaemon(false);
            awaitThread.start();
        }
    }
    
  5. 输出结果

    ---------------------服务端收到的请求拦截---------------------
    客户端传递的TraceId:479a91b0-854b-482e-ab27-221d6d1a0e5e
    服务端onReady,methodName:cn.jannal.grpc.facade.dto.HelloService/hello
    服务端接收消息,methodName:cn.jannal.grpc.facade.dto.HelloService/hello,message:firstName: "Jannal"
    lastName: "Jan"
    服务端半关闭,methodName:cn.jannal.grpc.facade.dto.HelloService/hello  
    ---------------------服务端响应前的请求拦截---------------------
    服务端发送响应头,methodName:cn.jannal.grpc.facade.dto.HelloService/hello
    服务端发送消息,methodName:cn.jannal.grpc.facade.dto.HelloService/hello,message:greeting: "Hello, Jannal Jan"
    服务端关闭连接,methodName:cn.jannal.grpc.facade.dto.HelloService/hello,code:OK
    服务端调用完成,methodName:cn.jannal.grpc.facade.dto.HelloService/hello
    
    

身份认证

  1. gRPC内置了以下身份验证机制:

    • SSL/TLS:gRPC 集成了 SSL/TLS,可用于对服务端进行身份验证,并对客户端与服务端之间交换的所有数据进行加密。客户端还可以选择提供自己的证书,实现双向认证(mutual TLS)。
    • 基于令牌的身份验证(如Google令牌认证):gRPC提供了一种通用机制,可将基于元数据的凭据附加到请求和响应上。gRPC还提供了一套以Credentials对象为统一概念的简单认证API,既可以在创建整个gRPC Channel时使用,也可以用于单次调用。
  2. 客户端继承CallCredentials实现身份认证令牌

    // Lives in the shared api module used by both client and server.
    /** Keys for passing the authorization token as a request header and as a Context value. */
    public class Constant {

        /** Context key under which the server exposes the token received from the client. */
        public static final Context.Key<String> TOKEN_CONTEXT_KEY = Context.key("token");

        /** Request-header key carrying the client's token (gRPC lower-cases header names). */
        public static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
                Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);

        /** Not instantiable: this class only holds constants. */
        private Constant() {
        }
    }
    
    /**
     * CallCredentials that attach a fixed token to every call as the "Authorization" request header.
     */
    public class RequestToken extends CallCredentials {

        /** The token sent to the server; never changes after construction. */
        private final String value;

        public RequestToken(String value) {
            this.value = value;
        }

        /**
         * Builds the Authorization header on the supplied executor and hands it to the applier.
         * Any failure (for example a null token, which Metadata.put rejects) fails the call
         * with UNAUTHENTICATED.
         */
        @Override
        public void applyRequestMetadata(RequestInfo requestInfo, Executor executor, MetadataApplier metadataApplier) {
            // NOTE(review): the token is attached regardless of transport security; with a
            // usePlaintext() channel it travels unencrypted. Consider checking
            // requestInfo.getSecurityLevel() before attaching it outside of demos.
            executor.execute(() -> {
                try {
                    Metadata headers = new Metadata();
                    headers.put(Constant.AUTHORIZATION_METADATA_KEY, value);
                    metadataApplier.apply(headers);
                } catch (Throwable e) {
                    metadataApplier.fail(Status.UNAUTHENTICATED.withCause(e));
                }
            });
        }

        @Override
        public void thisUsesUnstableApi() {
            // noop: marker method required by the experimental CallCredentials API
        }
    }
    
    
  3. 客户端请求携带凭证

    /** Calls hello with a RequestToken attached to the stub as call credentials. */
    @Test
    public void testCredentials() {
        ManagedChannel channel = ManagedChannelBuilder.forAddress(IP, PORT)
                .usePlaintext() // plaintext: the token travels unencrypted, acceptable only for a local demo
                .build();
        try {
            HelloServiceGrpc.HelloServiceBlockingStub stub
                    = HelloServiceGrpc.newBlockingStub(channel);
            // Attach call credentials (not an interceptor): RequestToken adds the Authorization header.
            HelloResponse helloResponse = stub
                    .withCallCredentials(new RequestToken("123456"))
                    .hello(HelloRequest.newBuilder()
                            .setFirstName("jannal")
                            .setLastName("Jan")
                            .build());
            log.info("Response received from server:{}", helloResponse);
        } finally {
            // Release the channel even when the call fails (e.g. UNAUTHENTICATED).
            channel.shutdown();
        }
    }
    
  4. 服务端编写权限拦截器

    /**
     * Rejects calls that do not carry a non-empty "Authorization" header with UNAUTHENTICATED;
     * otherwise exposes the token to the service via Constant.TOKEN_CONTEXT_KEY.
     */
    @Slf4j
    public class AuthorizationServerInterceptor implements ServerInterceptor {

        @Override
        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall, Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
            String value = metadata.get(Constant.AUTHORIZATION_METADATA_KEY);

            // NOTE(review): this logs the raw credential; mask or drop it outside of demos.
            log.info("客户端携带的token:{}", value);

            Status status;
            if (value == null || value.isEmpty()) {
                status = Status.UNAUTHENTICATED.withDescription("Authorization token is missing");
            } else {
                try {
                    Context ctx = Context.current().withValue(Constant.TOKEN_CONTEXT_KEY, value);
                    return Contexts.interceptCall(ctx, serverCall, metadata, serverCallHandler);
                } catch (Exception e) {
                    status = Status.UNAUTHENTICATED.withDescription(e.getMessage()).withCause(e);
                }
            }

            // Close with fresh, empty trailers. Passing the incoming request headers here
            // would echo the client's metadata (including its token) back as response trailers.
            serverCall.close(status, new Metadata());
            // The call is already closed; ignore any further events for it.
            return new ServerCall.Listener<ReqT>() {
            };
        }
    }
    
w3cjava