Engineering Practice: Pushing Real-Time PyTorch Features to Android Clients over a Serverless WebSocket Architecture


Changes in business requirements are often the most direct force behind architectural evolution. The challenge we faced: implement a dynamic, machine-learning-driven "smart recommendation card" in the Android client. The card must update its content in near real time based on the user's live behavior, combined with complex features computed offline. The traditional client-side polling approach could not meet the requirements for either latency or server load, so a technology-stack upgrade was inevitable.

The initial idea was a long-lived connection service with server-initiated pushes; WebSocket is the obvious choice. But a question follows: how do we build a WebSocket service that can carry potentially millions of connections, stay highly available, and keep costs under control? Building our own WebSocket cluster on EC2 or containers means handling load balancing, node state synchronization, failover, and a host of other complexities, at a high operational cost. That is what turned our attention to a Serverless architecture.

Our goal is an end-to-end real-time feature push pipeline: data originates in an S3 data lake, a PyTorch model performs batch processing and feature extraction, the results land in a low-latency store, and finally a Serverless WebSocket service pushes feature changes in real time to online Android devices.

Technology Selection

  1. WebSocket gateway: AWS API Gateway WebSocket API. It neatly solves connection management and horizontal scaling, leaving us to focus on business logic. It routes connection events ($connect, $disconnect) and messages ($default) directly to backend Lambda functions. Note that keep-alive pings and reconnection remain client-side responsibilities, as the client code later in this article shows.
  2. Backend compute: AWS Lambda. The heart of serverless: on-demand execution with automatic scaling. For bursty, stateless workloads such as WebSocket connection events, it is extremely cost-effective.
  3. Connection state management: Amazon DynamoDB. A Serverless environment is stateless by nature, but a WebSocket connection is stateful. The mapping between connectionId and a user identifier must be persisted somewhere. DynamoDB offers a low-latency, highly available key-value store and is the natural choice here.
  4. Offline feature computation: S3 + AWS Glue + PyTorch. Raw user behavior logs live in the S3 data lake. AWS Glue periodically triggers a batch job that runs a PyTorch model, reads data from S3, computes the latest user feature vectors, and writes out the results (a minimal sketch of such a job follows this list).
  5. Hot feature store: Amazon DynamoDB. The offline-computed feature vectors need a place Lambda can read quickly. A second DynamoDB table, keyed by user ID, stores this "hot" feature data.
  6. Client: Android (Kotlin) + OkHttp. OkHttp provides a stable, reliable WebSocket client implementation.
  7. Data protocol: Protocol Buffers (Protobuf). Compared with JSON, Protobuf offers faster serialization/deserialization and smaller payloads, which matters a great deal on mobile networks.
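
To make item 4 concrete, below is a minimal sketch of what the Glue-hosted PyTorch batch job might look like. Everything specific in it is an assumption for illustration: the TorchScript model file, the shape of the input tensors, and the JSON-string encoding of the feature vector (chosen to match how the pusher code later reads it).

import json

import boto3
import torch

# Hypothetical artifact: a TorchScript model exported from training and
# already downloaded from S3 to local disk by the job's bootstrap step.
model = torch.jit.load("/tmp/user_feature_model.pt")
model.eval()

features_table = boto3.resource("dynamodb").Table("UserFeatures")

def compute_and_store_features(user_ids, behavior_tensors):
    """
    Runs the model in inference mode and upserts one row per user.
    behavior_tensors: a float tensor of shape (batch, input_dim), built
    from the raw S3 logs by upstream Glue ETL (not shown here).
    """
    with torch.no_grad():
        vectors = model(behavior_tensors)  # shape: (batch, feature_dim)

    with features_table.batch_writer() as batch:
        for user_id, vector in zip(user_ids, vectors):
            batch.put_item(Item={
                "userId": user_id,
                # Stored as a JSON string so the pusher can read it via
                # new_image['featureVector']['S'].
                "featureVector": json.dumps(vector.tolist()),
            })

Each put_item lands on the UserFeatures table and, thanks to the stream enabled on that table, ultimately triggers the online push path described next.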

Architecture Overview

The system splits into two main parts: offline processing and online push.

graph TD
    subgraph "Offline Feature Engineering (Batch Processing)"
        direction TB
        S3_Lake[S3 Data Lake: User Behavior Logs] --> Glue_Job{AWS Glue PyTorch Job};
        PyTorch_Model[PyTorch Model] --> Glue_Job;
        Glue_Job --> Feature_DB[(DynamoDB: Hot Features Table)];
    end

    subgraph "Online Real-Time Push"
        direction LR
        Android_Client[Android Client] -- WebSocket Connect --> API_GW{API Gateway WebSocket};

        API_GW -- $connect / $disconnect --> Connect_Lambda[Lambda: Connection Manager];
        Connect_Lambda <--> Connection_DB[(DynamoDB: Connections Table)];

        Feature_DB -- DynamoDB Streams --> Trigger_Lambda[Lambda: Change Capture];
        Trigger_Lambda -- SQS --> Push_Queue[SQS Queue];
        Push_Queue --> Push_Lambda[Lambda: Feature Pusher];

        Push_Lambda -- Read connectionId --> Connection_DB;
        Push_Lambda -- PostToConnection --> API_GW;
        API_GW -- Push Feature Update --> Android_Client;
    end

    style S3_Lake fill:#f9f,stroke:#333,stroke-width:2px
    style Feature_DB fill:#ccf,stroke:#333,stroke-width:2px
    style Connection_DB fill:#ccf,stroke:#333,stroke-width:2px

Step-by-Step Implementation

1. Infrastructure as Code (AWS SAM)

In a real project, hand-configuring cloud resources is unacceptable. We use the AWS Serverless Application Model (SAM) to define the entire backend architecture.

template.yaml:

AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Description: Real-time PyTorch feature push service via Serverless WebSockets.

Globals:
  Function:
    Timeout: 10
    Runtime: python3.9
    MemorySize: 256
    Architectures:
      - x86_64

Resources:
  # DynamoDB Table for WebSocket Connections
  ConnectionsTable:
    Type: AWS::Serverless::SimpleTable
    Properties:
      TableName: WebSocketConnections
      PrimaryKey:
        Name: connectionId
        Type: String
      ProvisionedThroughput:
        ReadCapacityUnits: 5
        WriteCapacityUnits: 5

  # DynamoDB Table for hot features generated by PyTorch.
  # AWS::Serverless::SimpleTable does not support StreamSpecification,
  # so a full AWS::DynamoDB::Table is declared instead.
  FeaturesTable:
    Type: AWS::DynamoDB::Table
    Properties:
      TableName: UserFeatures
      AttributeDefinitions:
        - AttributeName: userId
          AttributeType: S
      KeySchema:
        - AttributeName: userId
          KeyType: HASH
      BillingMode: PAY_PER_REQUEST
      StreamSpecification:
        StreamViewType: NEW_AND_OLD_IMAGES # Enable stream to capture changes

  # The WebSocket API Gateway
  WebSocketApi:
    Type: AWS::ApiGatewayV2::Api
    Properties:
      Name: FeaturePushWebSocketApi
      ProtocolType: WEBSOCKET
      RouteSelectionExpression: "$request.body.action"

  # Routes for the WebSocket API
  ConnectRoute:
    Type: AWS::ApiGatewayV2::Route
    Properties:
      ApiId: !Ref WebSocketApi
      RouteKey: $connect
      AuthorizationType: NONE
      OperationName: ConnectRoute
      Target: !Join
        - /
        - - "integrations"
          - !Ref ConnectIntegration

  DisconnectRoute:
    Type: AWS::ApiGatewayV2::Route
    Properties:
      ApiId: !Ref WebSocketApi
      RouteKey: $disconnect
      AuthorizationType: NONE
      OperationName: DisconnectRoute
      Target: !Join
        - /
        - - "integrations"
          - !Ref DisconnectIntegration
  
  DefaultRoute:
    Type: AWS::ApiGatewayV2::Route
    Properties:
      ApiId: !Ref WebSocketApi
      RouteKey: $default
      AuthorizationType: NONE
      OperationName: DefaultRoute
      Target: !Join
        - /
        - - "integrations"
          - !Ref DefaultIntegration

  # Lambda Integrations for Routes
  ConnectIntegration:
    Type: AWS::ApiGatewayV2::Integration
    Properties:
      ApiId: !Ref WebSocketApi
      IntegrationType: AWS_PROXY
      IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${ConnectFunction.Arn}/invocations"

  DisconnectIntegration:
    Type: AWS::ApiGatewayV2::Integration
    Properties:
      ApiId: !Ref WebSocketApi
      IntegrationType: AWS_PROXY
      IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DisconnectFunction.Arn}/invocations"
      
  DefaultIntegration:
    Type: AWS::ApiGatewayV2::Integration
    Properties:
      ApiId: !Ref WebSocketApi
      IntegrationType: AWS_PROXY
      IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DefaultFunction.Arn}/invocations"

  # API Gateway must be granted permission to invoke the route handlers
  ConnectPermission:
    Type: AWS::Lambda::Permission
    Properties:
      Action: lambda:InvokeFunction
      FunctionName: !Ref ConnectFunction
      Principal: apigateway.amazonaws.com

  DisconnectPermission:
    Type: AWS::Lambda::Permission
    Properties:
      Action: lambda:InvokeFunction
      FunctionName: !Ref DisconnectFunction
      Principal: apigateway.amazonaws.com

  DefaultPermission:
    Type: AWS::Lambda::Permission
    Properties:
      Action: lambda:InvokeFunction
      FunctionName: !Ref DefaultFunction
      Principal: apigateway.amazonaws.com

  # Lambda Functions
  ConnectFunction:
    Type: AWS::Serverless::Function
    Properties:
      FunctionName: WebSocketConnectFunction
      CodeUri: src/connection_manager/
      Handler: app.connect_handler
      Policies:
        - DynamoDBCrudPolicy:
            TableName: !Ref ConnectionsTable
      Environment:
        Variables:
          CONNECTIONS_TABLE: !Ref ConnectionsTable

  DisconnectFunction:
    Type: AWS::Serverless::Function
    Properties:
      FunctionName: WebSocketDisconnectFunction
      CodeUri: src/connection_manager/
      Handler: app.disconnect_handler
      Policies:
        - DynamoDBCrudPolicy:
            TableName: !Ref ConnectionsTable
      Environment:
        Variables:
          CONNECTIONS_TABLE: !Ref ConnectionsTable
          
  DefaultFunction:
    Type: AWS::Serverless::Function
    Properties:
      FunctionName: WebSocketDefaultFunction
      CodeUri: src/connection_manager/
      Handler: app.default_handler

  PushFeaturesFunction:
    Type: AWS::Serverless::Function
    Properties:
      FunctionName: PushFeaturesFunction
      CodeUri: src/feature_pusher/
      Handler: app.handler
      Policies:
        - DynamoDBReadPolicy:
            TableName: !Ref ConnectionsTable
        - Statement:
          - Effect: Allow
            Action:
              - "execute-api:ManageConnections"
            Resource: !Sub "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${WebSocketApi}/*"
      Environment:
        Variables:
          CONNECTIONS_TABLE: !Ref ConnectionsTable
          # Required by the pusher code to build the management API endpoint
          WEBSOCKET_API_ID: !Ref WebSocketApi
          API_STAGE: prod
      Events:
        Stream:
          Type: DynamoDB
          Properties:
            Stream: !GetAtt FeaturesTable.StreamArn
            BatchSize: 100
            StartingPosition: LATEST

  # Deployment
  Deployment:
    Type: AWS::ApiGatewayV2::Deployment
    DependsOn:
      - ConnectRoute
      - DisconnectRoute
      - DefaultRoute
    Properties:
      ApiId: !Ref WebSocketApi

  Stage:
    Type: AWS::ApiGatewayV2::Stage
    Properties:
      ApiId: !Ref WebSocketApi
      DeploymentId: !Ref Deployment
      StageName: prod

Outputs:
  WebSocketURI:
    Description: "The WSS URI of the WebSocket API"
    Value: !Sub "wss://${WebSocketApi}.execute-api.${AWS::Region}.amazonaws.com/prod"

This SAM template defines all the required resources and wires them together. The key detail is that FeaturesTable enables StreamSpecification, so every data change on that table emits an event into a stream. PushFeaturesFunction subscribes to exactly that stream, which is what makes the push data-driven.
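
For reference, an INSERT record arriving on that stream has roughly the following shape (trimmed to the fields the pusher code below actually reads); note that DynamoDB Streams wraps every attribute value in a type descriptor such as 'S':

# Illustrative stream record, trimmed; values are examples only
stream_record = {
    "eventName": "INSERT",
    "dynamodb": {
        "Keys": {"userId": {"S": "user-123"}},
        "NewImage": {
            "userId": {"S": "user-123"},
            "featureVector": {"S": "[0.12, -0.58, 0.93]"},
        },
        "StreamViewType": "NEW_AND_OLD_IMAGES",
    },
}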

2. The Connection Management Lambda

This code handles the $connect and $disconnect events and maintains the ConnectionsTable.

src/connection_manager/app.py:

import os
import json
import logging
import boto3
from botocore.exceptions import ClientError

# Configure logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Initialize DynamoDB client
dynamodb = boto3.resource('dynamodb')
table_name = os.environ.get('CONNECTIONS_TABLE')
table = dynamodb.Table(table_name)

def connect_handler(event, context):
    """
    Handles new WebSocket connections.
    Stores the connectionId and userId in DynamoDB.
    """
    connection_id = event.get('requestContext', {}).get('connectionId')
    # In a real application, you would get userId from an authorizer token
    # queryStringParameters can be absent (or null) when no query string is sent
    query_params = event.get('queryStringParameters') or {}
    user_id = query_params.get('userId')

    if not connection_id or not user_id:
        logger.error("Missing connectionId or userId")
        return {'statusCode': 400, 'body': 'connectionId and userId are required.'}

    logger.info(f"Connecting: {connection_id} for user: {user_id}")

    try:
        table.put_item(
            Item={
                'connectionId': connection_id,
                'userId': user_id
            }
        )
        logger.info(f"Successfully stored connection for userId {user_id}")
    except ClientError as e:
        logger.error(f"Failed to store connection: {e.response['Error']['Message']}")
        return {'statusCode': 500, 'body': 'Failed to connect.'}

    return {'statusCode': 200, 'body': 'Connected.'}


def disconnect_handler(event, context):
    """
    Handles WebSocket disconnections.
    Removes the connectionId from DynamoDB.
    """
    connection_id = event.get('requestContext', {}).get('connectionId')
    if not connection_id:
        logger.error("Missing connectionId on disconnect")
        return {'statusCode': 400, 'body': 'connectionId is required.'}

    logger.info(f"Disconnecting: {connection_id}")

    try:
        table.delete_item(
            Key={
                'connectionId': connection_id
            }
        )
        logger.info(f"Successfully removed connection {connection_id}")
    except ClientError as e:
        logger.error(f"Failed to remove connection: {e.response['Error']['Message']}")
        # Don't fail the entire request, just log it.
        pass

    return {'statusCode': 200, 'body': 'Disconnected.'}

def default_handler(event, context):
    """
    Default handler for any message that doesn't match a route.
    Could be used for heartbeat/ping-pong implementation.
    """
    logger.info(f"Received default message: {event.get('body')}")
    # Simply acknowledge receipt
    return {'statusCode': 200, 'body': 'Message received.'}

Key points:

  • User identity binding: in connect_handler we bind the connectionId to the userId. In production, the userId should be obtained securely from a JWT or other auth token via a custom Lambda authorizer, not from a query parameter; a sketch follows this list.
  • Error handling: even a trivial database operation must be wrapped in a try...except block with detailed error logging.
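
A minimal sketch of such an authorizer for the $connect route follows; it would be attached via an AWS::ApiGatewayV2::Authorizer resource. The validate_token helper is hypothetical, standing in for real JWT verification (e.g., with PyJWT), and the token query parameter name is an assumption.

def validate_token(token):
    """Hypothetical JWT verification; returns a claims dict or None.
    Replace with real signature and expiry checks (e.g., via PyJWT)."""
    return None

def authorizer_handler(event, context):
    """REQUEST-type Lambda authorizer attached to the $connect route."""
    token = (event.get('queryStringParameters') or {}).get('token')
    claims = validate_token(token)
    if not claims:
        # Raising signals API Gateway to reject the handshake
        raise Exception('Unauthorized')

    return {
        'principalId': claims['sub'],
        'policyDocument': {
            'Version': '2012-10-17',
            'Statement': [{
                'Action': 'execute-api:Invoke',
                'Effect': 'Allow',
                'Resource': event['methodArn'],
            }],
        },
        # Surfaces to connect_handler as
        # event['requestContext']['authorizer']['userId']
        'context': {'userId': claims['sub']},
    }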

3. The Feature Push Lambda

This is the heart of the system. When the UserFeatures table is updated (written by the offline PyTorch job), DynamoDB Streams triggers this Lambda.

src/feature_pusher/app.py:

import os
import json
import logging
import boto3
from boto3.dynamodb.conditions import Attr
from botocore.exceptions import ClientError

# Configure logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Initialize DynamoDB client
connections_table_name = os.environ.get('CONNECTIONS_TABLE')
dynamodb = boto3.resource('dynamodb')
connections_table = dynamodb.Table(connections_table_name)

# The management API client must target the WebSocket API's own endpoint;
# the default endpoint does not work for post_to_connection. Because this
# Lambda is triggered by a DynamoDB stream rather than by API Gateway,
# there is no requestContext to read the domain from, so the endpoint is
# assembled from environment variables set in the SAM template.
apigateway_management_api = None

def get_api_gateway_client():
    """
    Lazy initializer for the APIGatewayManagementAPI client.
    """
    global apigateway_management_api
    if apigateway_management_api is None:
        api_id = os.environ['WEBSOCKET_API_ID']
        region = os.environ['AWS_REGION']
        stage = os.environ.get('API_STAGE', 'prod')
        endpoint_url = f"https://{api_id}.execute-api.{region}.amazonaws.com/{stage}"
        apigateway_management_api = boto3.client(
            'apigatewaymanagementapi',
            endpoint_url=endpoint_url
        )
    return apigateway_management_api

def handler(event, context):
    """
    Triggered by DynamoDB stream from the FeaturesTable.
    Pushes updated features to connected clients.
    """
    logger.info(f"Received {len(event['Records'])} records from DynamoDB stream.")

    for record in event['Records']:
        # Only process MODIFY or INSERT events
        if record['eventName'] not in ['INSERT', 'MODIFY']:
            continue
            
        try:
            # The new feature data is in NewImage
            new_image = record['dynamodb']['NewImage']
            user_id = new_image['userId']['S']
            
            # Here you would typically serialize your features using Protobuf
            # For simplicity, we'll send the raw JSON representation.
            feature_payload = json.dumps({
                'userId': user_id,
                'features': new_image['featureVector']['S']
            })

            # Find all active connections for this user
            # A GSI on userId would be required for performance at scale.
            response = connections_table.scan(
                FilterExpression=Attr('userId').eq(user_id)
            )
            
            connections = response.get('Items', [])
            if not connections:
                logger.warning(f"No active connections found for user {user_id}. Skipping push.")
                continue

            api_client = get_api_gateway_client()
            
            stale_connections = []
            for connection in connections:
                connection_id = connection['connectionId']
                try:
                    logger.info(f"Pushing features to connection {connection_id} for user {user_id}")
                    api_client.post_to_connection(
                        ConnectionId=connection_id,
                        Data=feature_payload.encode('utf-8')
                    )
                except ClientError as e:
                    # GoneException indicates the connection is no longer active
                    if e.response['Error']['Code'] == 'GoneException':
                        logger.warning(f"Connection {connection_id} is stale. Marking for deletion.")
                        stale_connections.append(connection_id)
                    else:
                        logger.error(f"Error posting to connection {connection_id}: {e}")

            # Cleanup stale connections
            if stale_connections:
                with connections_table.batch_writer() as batch:
                    for conn_id in stale_connections:
                        batch.delete_item(Key={'connectionId': conn_id})
                logger.info(f"Cleaned up {len(stale_connections)} stale connections.")

        except Exception as e:
            logger.error(f"Error processing stream record: {record}. Error: {e}")
    
    return {'statusCode': 200, 'body': 'Push processed.'}
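
As the comment in the handler notes, a production build would serialize the payload with Protobuf instead of the JSON used above for simplicity. A minimal sketch, assuming a hypothetical user_feature.proto (string user_id = 1; repeated float feature_vector = 2;) compiled with protoc into a user_feature_pb2 module:

# Assumes: protoc --python_out=. user_feature.proto
from user_feature_pb2 import UserFeature

def build_payload(user_id, vector):
    """Builds the compact binary frame passed to post_to_connection."""
    msg = UserFeature()
    msg.user_id = user_id
    msg.feature_vector.extend(vector)  # repeated float field
    return msg.SerializeToString()     # typically much smaller than JSON

The Android listener's onMessage(webSocket, bytes) would then parse the same frame with the Kotlin/Java class generated from the identical .proto file.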


Architecture pitfalls and optimizations:

  • Performance bottleneck: connections_table.scan() is a very inefficient operation. In production, you must create a global secondary index (GSI) on userId for the ConnectionsTable so that the scan() can be replaced with an efficient query(); see the sketch after this list.
  • Zombie connections: when a client drops abnormally (network handover, app crash), the $disconnect event may never fire. The GoneException handling in the code above is therefore essential: it passively cleans up stale connections.
  • Fan-out: if a single feature update must be pushed to thousands of devices (say, a group notification), looping inside one Lambda is slow and can even hit the Lambda timeout. A more robust design, as drawn in the architecture diagram, splits the logic of PushFeaturesFunction: it reads records from the stream and enqueues the userIds to push into an SQS queue, and a separate Lambda (scaled to many concurrent instances) consumes the queue and performs the actual pushes, achieving large-scale fan-out.
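
A sketch of that query() replacement, assuming a GSI named userId-index whose partition key is userId (the index name is an assumption, and defining it would require promoting ConnectionsTable to a full AWS::DynamoDB::Table, since SimpleTable does not support indexes):

from boto3.dynamodb.conditions import Key

# Query the assumed "userId-index" GSI instead of scanning the table
response = connections_table.query(
    IndexName='userId-index',
    KeyConditionExpression=Key('userId').eq(user_id)
)
connections = response.get('Items', [])

Unlike scan(), the cost of this query is proportional to the number of connections the user actually has, not to the size of the whole table.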

4. The Android Client (Kotlin)

The client initializes the OkHttp WebSocket, listens for messages, and manages the connection lifecycle.

import okhttp3.*
import okio.ByteString
import java.util.concurrent.TimeUnit
import android.util.Log

class FeaturePushWebSocketListener : WebSocketListener() {

    companion object {
        private const val TAG = "FeaturePushWS"
        private const val NORMAL_CLOSURE_STATUS = 1000
    }

    override fun onOpen(webSocket: WebSocket, response: Response) {
        Log.i(TAG, "WebSocket connection opened.")
        // You could send an initial message or just wait for server pushes
        // webSocket.send("{\"action\":\"ping\"}")
    }

    override fun onMessage(webSocket: WebSocket, text: String) {
        // This would be used if you expect text frames. We expect binary.
        Log.i(TAG, "Received text message: $text")
    }

    override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
        Log.i(TAG, "Received binary message of size: ${bytes.size}")
        try {
            // Here you would use your Protobuf generated class to parse the bytes
            // val userFeature = UserFeatureProto.UserFeature.parseFrom(bytes.toByteArray())
            // Log.i(TAG, "Parsed feature for userId: ${userFeature.userId}")
            // updateUI(userFeature)
            
            // For demonstration, let's just log the string content
            val messageContent = bytes.utf8()
            Log.i(TAG, "Message content: $messageContent")

        } catch (e: Exception) {
            Log.e(TAG, "Failed to parse protobuf message", e)
        }
    }

    override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
        webSocket.close(NORMAL_CLOSURE_STATUS, null)
        Log.i(TAG, "WebSocket closing: $code / $reason")
    }

    override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
        Log.e(TAG, "WebSocket connection failure", t)
        // Implement retry logic here with backoff strategy
        // e.g., schedule a reconnect attempt after 5, 10, 20 seconds.
    }
}

class WebSocketManager(private val userId: String) {

    private var webSocket: WebSocket? = null
    private val client: OkHttpClient = OkHttpClient.Builder()
        .readTimeout(0, TimeUnit.MILLISECONDS) // Important for long-lived connections
        .pingInterval(30, TimeUnit.SECONDS) // Keep the connection alive
        .build()

    fun connect(webSocketUrl: String) {
        if (webSocket != null) {
            Log.w("WebSocketManager", "Already connected or connecting.")
            return
        }
        
        // Append userId to the connection URL for the connect handler
        val authenticatedUrl = "$webSocketUrl?userId=$userId"

        val request = Request.Builder()
            .url(authenticatedUrl)
            .build()
            
        webSocket = client.newWebSocket(request, FeaturePushWebSocketListener())
    }

    fun disconnect() {
        webSocket?.close(1000, "User initiated disconnect.")
        webSocket = null
    }
}

Key points:

  • Read timeout: readTimeout(0, ...) is mandatory; otherwise OkHttp's default read timeout would tear down the long-lived connection.
  • Keep-alive: pingInterval(...) makes the client send periodic ping frames, preventing middleboxes (NATs, firewalls) from closing the TCP connection as idle.
  • Connection auth: passing userId as a query parameter is for demonstration only. Production should use something safer, such as a short-lived auth ticket carried in the Sec-WebSocket-Protocol header.
  • Reconnection: onFailure is where robustness is won. Implement a reconnect strategy with exponential backoff there (e.g., retry after 5, 10, 20 seconds) to ride out network jitter or brief server-side unavailability.

Limitations and Future Iterations

This architecture solves the core problem, but it is no silver bullet in production; there are boundaries to weigh and directions to optimize.

First, latency. The whole pipeline hinges on offline batch processing, so there is an inherent window between a user action and the arrival of the updated feature (bounded by the batch frequency). For scenarios demanding sub-second responsiveness this architecture does not apply. A future iteration is to replace the offline batch with a real-time streaming pipeline on Flink or Kinesis Data Analytics that consumes the live behavior stream, updates features on the fly, and triggers pushes directly.

Second, Lambda cold starts. For the WebSocket $connect request, a cold Lambda instance can add hundreds of milliseconds, even seconds, of connection latency. Background pushes are less sensitive to this, but for interactive use the experience suffers. Lambda Provisioned Concurrency can keep instances warm, at the price of a fixed cost that cuts against the pay-per-use ethos of Serverless, so it calls for a cost-benefit analysis.

Finally, the choice of hot feature store. DynamoDB is sufficient for most cases, but if features update extremely frequently and read-latency requirements are severe (say, tens of thousands of feature reads per second), an in-memory store such as ElastiCache for Redis may be the better choice. It is again a cost-versus-performance trade-off: Redis is faster but more expensive and brings cluster-management overhead (a rough sketch of the swap appears below).
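
If that swap were made, the hot read might look like the following redis-py sketch; the endpoint, the feature:{userId} key scheme, and the JSON encoding are all assumptions:

import json
import redis

# Assumed ElastiCache endpoint; in practice read it from configuration
r = redis.Redis(host='my-cluster.xxxxxx.cache.amazonaws.com', port=6379)

def get_hot_features(user_id):
    """Reads a feature vector previously written by the batch job."""
    raw = r.get(f"feature:{user_id}")
    return json.loads(raw) if raw else None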

The value of this design is that it provides a highly scalable, low-operational-cost architectural pattern for real-time intelligent experiences on mobile, stitching together the seemingly separate worlds of data engineering, machine learning, and mobile development.

