Fusing Neo4j Graph Queries and TensorFlow Inference Within a Custom API Gateway Plugin


The initial system architecture was failing under its own weight. We were serving a fraud detection model where feature staleness was directly correlated with financial loss. Our existing feature store, built on a key-value paradigm, could only serve pre-calculated, flattened data. It was blind to the intricate, real-time relationships between entities—a classic graph problem. A request to score a transaction for user U1 couldn’t answer questions like, “Is U1 connected within three hops to a known money laundering ring identified five minutes ago?” or “Does the device U1 is using share a fingerprint with any account that has been flagged for chargebacks in the last hour?” These are not features you can batch-process overnight; they must be computed at the moment of the transaction.

Our first attempt involved creating a monolithic Python service. It took an incoming request, queried Neo4j for graph features, bundled them with static features, called a local TensorFlow model for inference, and returned a score. While functionally correct, this approach was an operational nightmare. Redeploying the model required a full service restart. A bug in the feature generation logic could crash the entire prediction endpoint. Scaling was inefficient—we had to scale the entire monolith even if the bottleneck was only the CPU-intensive model inference. The coupling was too tight.

This led to a new architectural hypothesis: decouple feature generation, model inference, and client-facing orchestration. The natural home for the model was TensorFlow Serving—a dedicated, high-performance system for this exact purpose. The graph data belonged in Neo4j, optimized for real-time traversals. The critical missing piece was the orchestrator. The obvious choice, an intermediate microservice, simply moved the problem, adding another network hop and another component to maintain.

The real technical challenge was to perform this complex, multi-system orchestration with sub-100ms P99 latency. Adding another service was counterproductive to this goal. The solution had to lie within the infrastructure layer itself. We decided to embed this orchestration logic directly into our API Gateway. Not as a simple proxy rule, but as a custom, high-performance plugin that would become an active participant in the request lifecycle. This is the build log of that system.

Technology Selection Rationale

In a real-world project, every technology choice is a trade-off. Here’s why we landed on this specific stack:

  1. Neo4j: We evaluated relational databases and other NoSQL options. A relational approach for deep, variable-depth relationship queries would involve multiple, unpredictable JOIN operations, which are notoriously slow at scale. Neo4j’s native graph storage and processing engine, with its index-free adjacency, is purpose-built for this kind of traversal. For our use case, the performance difference between a Cypher query and a recursive SQL CTE was an order of magnitude.

  2. TensorFlow Serving: We needed to decouple model updates from our application code. TF Serving provides this out-of-the-box with gRPC endpoints for high throughput and low latency. It handles model versioning, rollouts, and hardware acceleration (GPU utilization) transparently, freeing the application team from low-level MLOps concerns.

  3. Custom API Gateway (FastAPI): We considered extending an existing gateway like Kong or Apache APISIX with a Lua plugin. While powerful, the learning curve for our team and the debugging complexity of Lua in a high-concurrency environment were significant risks. Furthermore, our orchestration logic required a rich ecosystem of libraries for connecting to Neo4j and TF Serving, which is native to Python. Building a lightweight, specialized gateway using a high-performance Python framework like FastAPI gave us the best of both worlds: full control over the request lifecycle, access to a mature library ecosystem, and exceptional performance thanks to its asyncio foundation and uvicorn server. It acts as our “plugin” host.

The final architecture looks like this:

graph TD
    subgraph "API Gateway (FastAPI on K8s)"
        A[Ingress: /predict] --> B{Orchestration Logic};
        B -- 1. Extract User/Transaction ID --> C{Async Neo4j Client};
        B -- 3. Forward Features --> D{Async TF Serving gRPC Client};
    end

    subgraph "Backend Systems"
        E[Neo4j AuraDB Cluster];
        F[TensorFlow Serving on K8s];
    end

    C -- 2. Cypher Query for Graph Features --> E;
    E -- Graph Features (Vector) --> C;
    C --> B;
    D -- 4. gRPC Inference Request --> F;
    F -- 5. Prediction Score --> D;
    D --> B;
    B -- 6. Final Response --> G[Client];

    style A fill:#f9f,stroke:#333,stroke-width:2px
    style F fill:#FFC300,stroke:#333,stroke-width:2px
    style E fill:#008cc1,stroke:#333,stroke-width:2px

Core Implementation: The Gateway Orchestrator

The entire system hinges on the gateway’s ability to perform its tasks asynchronously. Any blocking I/O call to Neo4j or TF Serving would stall the server’s event loop, destroying concurrency and throughput. Python’s asyncio is not optional here; it’s a fundamental requirement.

1. Configuration and Dependencies

A production system cannot have hardcoded values. We use Pydantic's settings support (the pydantic-settings package in Pydantic v2) to load configuration from environment variables. This makes the application portable across different environments (dev, staging, prod).

# requirements.txt
fastapi
uvicorn[standard]
pydantic
pydantic-settings # BaseSettings lives here in Pydantic v2
neo4j==5.14.0 # Make sure it's an async-compatible version
grpcio
tensorflow-serving-api
structlog # For structured logging
aiohttp # For health checks

# app/config.py
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings

class Neo4jSettings(BaseSettings):
    uri: str = Field(..., validation_alias="NEO4J_URI")
    user: str = Field(..., validation_alias="NEO4J_USER")
    password: str = Field(..., validation_alias="NEO4J_PASSWORD")
    db: str = Field("neo4j", validation_alias="NEO4J_DATABASE")

class TensorFlowServingSettings(BaseSettings):
    host: str = Field("localhost", validation_alias="TFS_HOST")
    port: int = Field(8500, validation_alias="TFS_PORT")
    model_name: str = Field(..., validation_alias="TFS_MODEL_NAME")
    timeout_seconds: int = Field(2, validation_alias="TFS_TIMEOUT_SECONDS")

class AppSettings(BaseModel):
    neo4j: Neo4jSettings = Neo4jSettings()
    tfs: TensorFlowServingSettings = TensorFlowServingSettings()

# The single source of truth for configuration
settings = AppSettings()

This setup ensures that a missing environment variable will cause a validation error on startup, preventing runtime failures due to misconfiguration. A common mistake is to handle configuration lazily, which can lead to hard-to-debug issues in a running system.

2. The Asynchronous Neo4j Feature Extractor

This component is responsible for connecting to Neo4j and executing a Cypher query to generate features. The official neo4j Python driver supports asyncio out of the box.

The key challenge is writing a Cypher query that is both powerful and performant. A poorly written query can easily become the bottleneck.

# app/features/graph_features.py
import asyncio
from typing import List, Optional
import logging

from neo4j import AsyncGraphDatabase, AsyncDriver
from neo4j.exceptions import ServiceUnavailable

from ..config import settings

logger = logging.getLogger(__name__)

class GraphFeatureExtractor:
    _driver: Optional[AsyncDriver] = None

    @classmethod
    async def get_driver(cls) -> AsyncDriver:
        if cls._driver is None:
            logger.info(f"Initializing Neo4j driver for URI: {settings.neo4j.uri}")
            cls._driver = AsyncGraphDatabase.driver(
                settings.neo4j.uri,
                auth=(settings.neo4j.user, settings.neo4j.password)
            )
        return cls._driver

    @classmethod
    async def close_driver(cls):
        if cls._driver:
            logger.info("Closing Neo4j driver.")
            await cls._driver.close()
            cls._driver = None

    async def get_features(self, user_id: str, device_id: str) -> List[float]:
        """
        Generates real-time graph features for a given user and device.
        This is where the core business logic for feature extraction resides.
        """
        driver = await self.get_driver()
        
        # This query is the secret sauce. It must be optimized.
        # It calculates features like:
        # 1. Degree of the user node.
        # 2. If the device has been used by more than N other users (potential device farm).
        # 3. Shortest path distance to a known fraudulent user cluster.
        # NOTE: In production, this query would be far more complex.
        cypher_query = """
        MATCH (u:User {id: $userId})
        OPTIONAL MATCH (d:Device {id: $deviceId})
        
        // Feature 1: User's transaction velocity (e.g., count of recent transactions)
        CALL {
            WITH u
            MATCH (u)-[:PERFORMED]->(t:Transaction)
            WHERE t.timestamp > timestamp() - 3600000 // last hour
            RETURN count(t) as recent_transactions
        }
        
        // Feature 2: Device sharing count
        CALL {
            WITH d
            MATCH (d)<-[:USED]-(other_user:User)
            RETURN count(other_user) as device_sharing_count
        }

        // Feature 3: Proximity to the nearest known fraud ring
        CALL {
            WITH u
            OPTIONAL MATCH (fraud_ring:FraudRing)
            OPTIONAL MATCH p = shortestPath((u)-[*]-(fraud_ring))
            // min() aggregates over all rings and ignores missing paths;
            // coalesce maps "no ring reachable" to a large sentinel value
            RETURN coalesce(min(length(p)), 99) as fraud_proximity
        }

        RETURN recent_transactions, device_sharing_count, fraud_proximity
        """
        
        params = {"userId": user_id, "deviceId": device_id}

        try:
            # Using an async session to avoid blocking the event loop
            async with driver.session(database=settings.neo4j.db) as session:
                result = await session.run(cypher_query, params)
                record = await result.single()
                
                if not record:
                    # A critical case to handle: user or device not found.
                    # Return a default feature vector representing an unknown entity.
                    logger.warning(f"No graph features found for user {user_id}. Returning default vector.")
                    return [0.0, 1.0, 99.0]
                
                # Ensure features are floats for the ML model
                features = [float(v) for v in record.values()]
                return features

        except (asyncio.TimeoutError, ServiceUnavailable) as e:
            logger.error(f"Neo4j unavailable or timed out for user {user_id}: {e}")
            # A fallback strategy is crucial for system resilience.
            # Returning a default vector allows the system to still make a (less accurate) prediction.
            return [0.0, 1.0, 99.0]
        except Exception as e:
            logger.exception(f"Error querying Neo4j for user {user_id}: {e}")
            # In case of other DB errors, also return the default vector.
            return [0.0, 1.0, 99.0]

# Instantiate a singleton for use in the API endpoint
feature_extractor = GraphFeatureExtractor()

The pitfall here is resource management. Creating a new driver for each request would be disastrous. We use a class-level singleton pattern to initialize the driver once and reuse it across all requests. The close_driver method is registered with FastAPI’s shutdown event handler to ensure graceful termination. The error handling is also non-negotiable; a failure in the feature store should not cause a 500 Internal Server Error. Instead, it should return a default feature vector, allowing the system to degrade gracefully.
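If the driver's defaults ever become a limiting factor under load, the same constructor accepts connection-pool settings. A brief sketch with illustrative values, not our production configuration:

# Illustrative pool tuning via the Neo4j driver constructor.
from neo4j import AsyncGraphDatabase

driver = AsyncGraphDatabase.driver(
    "neo4j://localhost:7687",
    auth=("neo4j", "password"),
    max_connection_pool_size=50,         # cap concurrent Bolt connections
    connection_acquisition_timeout=2.0,  # fail fast instead of queueing indefinitely
    max_connection_lifetime=3600,        # recycle long-lived connections (seconds)
)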

3. The Asynchronous TensorFlow Serving Client

Communicating with TF Serving via gRPC is significantly more performant than using its REST API, but it’s also more complex. It requires building request objects using Protocol Buffers.

# app/models/tfs_client.py
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import logging

from ..config import settings

logger = logging.getLogger(__name__)

class TFServingClient:
    _stub = None
    
    @classmethod
    def get_stub(cls):
        if cls._stub is None:
            logger.info(f"Initializing gRPC channel to TF Serving at {settings.tfs.host}:{settings.tfs.port}")
            # Use an async channel for non-blocking communication
            channel = grpc.aio.insecure_channel(f"{settings.tfs.host}:{settings.tfs.port}")
            cls._stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        return cls._stub

    async def predict(self, feature_vector: list[float]) -> dict:
        stub = self.get_stub()
        
        # Step 1: Create the gRPC request object
        request = predict_pb2.PredictRequest()
        request.model_spec.name = settings.tfs.model_name
        request.model_spec.signature_name = 'serving_default'

        # Step 2: Populate the request with input data.
        # The key 'graph_features_input' must match the name of the input tensor in
        # the model's serving signature, and the tensor shape must match what that
        # signature expects. This is a common point of error.
        tensor = tf.make_tensor_proto(feature_vector, dtype=tf.float32)
        request.inputs['graph_features_input'].CopyFrom(tensor)

        try:
            # Step 3: Make the async gRPC call with a timeout
            response = await stub.Predict(request, timeout=settings.tfs.timeout_seconds)

            # Step 4: Unpack the prediction from the response protobuf.
            # The output key 'fraud_probability' must match the model's output tensor name.
            outputs_tensor_proto = response.outputs['fraud_probability']
            prediction = tf.make_ndarray(outputs_tensor_proto)[0]
            
            return {"fraud_probability": float(prediction)}

        except grpc.aio.AioRpcError as e:
            logger.error(f"gRPC call to TF Serving failed: {e.details()} (code: {e.code()})")
            # If the model service is down or times out, we need a fallback.
            # Returning a specific error allows upstream systems to handle it.
            # Do NOT return a default score, as that could be misleading.
            return {"error": "Prediction service unavailable", "details": e.details()}
        except Exception as e:
            logger.exception("An unexpected error occurred during prediction.")
            return {"error": "Unexpected prediction error", "details": str(e)}

# Singleton instance
tfs_client = TFServingClient()

The most critical part of this code is matching the input/output tensor names (graph_features_input, fraud_probability) with the model’s signature. A mismatch here will result in cryptic gRPC errors. This requires close collaboration between the data science team that builds the model and the engineering team that builds the serving layer. We established a contract, checked into source control, that defines these signatures.
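To enforce that contract mechanically, the exported SavedModel can be checked before it is ever pushed to TF Serving. A minimal sketch, assuming the artifact is available locally; the path and expected names below are illustrative:

# Hypothetical pre-deployment check: assert that the serving signature exposes
# the tensor names the gateway hardcodes.
import tensorflow as tf

EXPECTED_INPUT = "graph_features_input"
EXPECTED_OUTPUT = "fraud_probability"

model = tf.saved_model.load("/models/fraud_model/1")
signature = model.signatures["serving_default"]

input_names = set(signature.structured_input_signature[1].keys())
output_names = set(signature.structured_outputs.keys())

assert EXPECTED_INPUT in input_names, f"missing input tensor: {input_names}"
assert EXPECTED_OUTPUT in output_names, f"missing output tensor: {output_names}"
print("Serving signature matches the gateway contract.")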

4. The FastAPI Gateway Endpoint

This is where everything comes together. The endpoint orchestrates the calls to the two clients. Note that these two calls are inherently sequential: the model cannot run until the graph features exist. Where concurrency does pay off is when there are multiple independent I/O-bound feature sources; in that case asyncio.gather lets the gateway fetch them concurrently rather than one after another, as sketched below.
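Here is that pattern in miniature, with static_store standing in as a hypothetical second feature source (it is not part of this codebase):

# Sketch only: both fetches are independent, so they can run concurrently.
import asyncio

async def fetch_all_features(user_id: str, device_id: str) -> list[float]:
    graph_features, static_features = await asyncio.gather(
        feature_extractor.get_features(user_id, device_id),
        static_store.get_profile_features(user_id),  # hypothetical async client
    )
    return graph_features + static_features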

# app/main.py
import logging
import time

from fastapi import FastAPI, Request
from pydantic import BaseModel

from .config import settings
from .features.graph_features import feature_extractor
from .models.tfs_client import tfs_client

# Configure structured logging early
# (Configuration for structlog would go here)
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

app = FastAPI()

@app.on_event("startup")
async def startup_event():
    # This pre-connects and warms up the connection pools.
    await feature_extractor.get_driver()
    tfs_client.get_stub()
    logger.info("Application startup complete. Services initialized.")

@app.on_event("shutdown")
async def shutdown_event():
    await feature_extractor.close_driver()
    logger.info("Application shutdown complete. Connections closed.")

class PredictionRequest(BaseModel):
    user_id: str
    device_id: str
    # Other static features can be included here
    transaction_amount: float

class PredictionResponse(BaseModel):
    request_id: str
    fraud_probability: float | None
    error: str | None
    latency_ms: float

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest, http_request: Request):
    start_time = time.perf_counter()
    request_id = http_request.headers.get("X-Request-ID", "N/A")

    logger.info("Received prediction request", extra={"user_id": request.user_id, "req_id": request_id})

    # Asynchronously fetch graph features
    graph_features = await feature_extractor.get_features(request.user_id, request.device_id)

    # In a real system, you might combine graph features with static features from the request
    # For simplicity, we're only using graph features here.
    feature_vector = graph_features # + [request.transaction_amount]

    # Asynchronously get prediction from the model
    prediction_result = await tfs_client.predict(feature_vector)
    
    if "error" in prediction_result:
        latency = (time.perf_counter() - start_time) * 1000
        return PredictionResponse(
            request_id=request_id,
            fraud_probability=None,
            error=prediction_result["error"],
            latency_ms=latency
        )

    latency = (time.perf_counter() - start_time) * 1000
    logger.info(f"Prediction successful for request {request_id}", extra={"latency_ms": latency})

    return PredictionResponse(
        request_id=request_id,
        fraud_probability=prediction_result["fraud_probability"],
        error=None,
        latency_ms=latency
    )

@app.get("/health")
async def health_check():
    # In a real system, this would also check connectivity to Neo4j and TF Serving
    return {"status": "ok"}

This endpoint is the public face of our system. It’s lean and its only job is orchestration. All the heavy lifting is delegated. The structured logging includes a request ID, which is indispensable for tracing a single request’s journey across logs from multiple systems in a distributed environment.
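Rather than threading the request ID through every call by hand, structlog's contextvars support can bind it once per request so every subsequent log line carries it automatically. A minimal sketch (it would live alongside the routes in app/main.py), assuming structlog is configured with its merge_contextvars processor:

# Bind a request ID for the lifetime of each request.
import uuid

import structlog
from fastapi import Request

@app.middleware("http")
async def bind_request_id(request: Request, call_next):
    # Reuse the caller's X-Request-ID if present, otherwise mint one.
    request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
    structlog.contextvars.clear_contextvars()
    structlog.contextvars.bind_contextvars(request_id=request_id)
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response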

Testing Strategy

A common mistake is to only test components in isolation. For this architecture, integration testing is paramount.

  1. Unit Tests: We use pytest with pytest-asyncio. The Neo4j and TF Serving clients are tested by mocking their respective drivers/stubs. We test the happy path, error conditions (e.g., DB down, model not found), and fallback logic (e.g., returning default features); a sketch of one such test follows this list.
  2. Integration Tests: This is the most important part. We use testcontainers to programmatically spin up Docker containers for Neo4j and TensorFlow Serving during the test suite execution.
    • The Neo4j container is pre-populated with a small, known graph structure.
    • The TF Serving container is loaded with a dummy model that has a known input/output signature.
    • The tests then make real HTTP requests to our FastAPI application (using httpx.AsyncClient) and assert that the entire flow—from HTTP request to Neo4j query to TF Serving inference to HTTP response—works as expected. This catches issues like configuration mismatches, network policy problems, and signature incompatibilities that unit tests would miss.
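To make the first item concrete, a minimal fallback test might look like this (a sketch; it assumes pytest-asyncio and the module layout used throughout this article):

# tests/test_graph_features.py (sketch)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from app.features.graph_features import GraphFeatureExtractor

@pytest.mark.asyncio
async def test_returns_default_vector_when_neo4j_is_down():
    # Simulate a driver whose session() call blows up immediately.
    failing_driver = MagicMock()
    failing_driver.session.side_effect = RuntimeError("connection refused")

    with patch.object(GraphFeatureExtractor, "get_driver", AsyncMock(return_value=failing_driver)):
        features = await GraphFeatureExtractor().get_features("u-123", "d-456")

    # The extractor should degrade gracefully, not raise.
    assert features == [0.0, 1.0, 99.0]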

Performance and Lingering Issues

After deployment and load testing, the P99 latency for the /predict endpoint stabilized around 85ms, meeting our sub-100ms target. The vast majority of this time was spent in the Neo4j query execution (~60ms). The gRPC call to TF Serving was consistently under 10ms.

However, this architecture is not without its trade-offs and remaining challenges.

The most significant architectural debt we’ve incurred is placing complex orchestration logic within the gateway layer. While it met our performance goals, it blurs the line between infrastructure and application. The gateway is no longer a simple, “dumb” proxy. Any changes to the feature generation logic require a redeployment of this critical infrastructure component. This increases the risk profile of each deployment.

Furthermore, the current implementation has no caching. If the same user performs multiple transactions in a short period, we re-compute the same graph features every single time. A future iteration will explore introducing a caching layer (e.g., Redis) between the gateway and Neo4j. This introduces new complexities: What is the right TTL? How do we handle cache invalidation when the graph is updated? The trade-off between latency reduction and data freshness will need careful analysis.
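To make the idea concrete, here is a minimal sketch of a read-through cache in front of the extractor, assuming redis-py's asyncio client; the key scheme and the 30-second TTL are illustrative, not decided:

# Sketch of a read-through cache for graph features.
import json

import redis.asyncio as redis

from app.features.graph_features import feature_extractor

cache = redis.Redis(host="localhost", port=6379, decode_responses=True)
CACHE_TTL_SECONDS = 30  # freshness budget; needs tuning against fraud-loss data

async def get_features_cached(user_id: str, device_id: str) -> list[float]:
    key = f"graph-features:{user_id}:{device_id}"
    cached = await cache.get(key)
    if cached is not None:
        return json.loads(cached)
    features = await feature_extractor.get_features(user_id, device_id)
    await cache.set(key, json.dumps(features), ex=CACHE_TTL_SECONDS)
    return features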

Finally, while the fallback mechanisms (default feature vectors, error responses) make the system resilient, they also mask underlying problems. A sustained issue with Neo4j might not trigger alarms if the gateway is simply serving default features, leading to a silent degradation in model accuracy. This requires more sophisticated monitoring that tracks not just system uptime, but also the rate of fallback path execution, to provide a true picture of system health.
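One concrete way to surface that signal, sketched here with prometheus_client (an assumption; it is not wired into the code above), is to count every fallback execution so a degraded Neo4j shows up as a rising rate on a dashboard instead of silently eroding model accuracy:

# Sketch: a counter incremented on every default-vector fallback.
from prometheus_client import Counter

GRAPH_FEATURE_FALLBACKS = Counter(
    "graph_feature_fallbacks_total",
    "Predictions served with the default graph feature vector",
)

# Inside GraphFeatureExtractor.get_features, each fallback branch would call:
#     GRAPH_FEATURE_FALLBACKS.inc()
#     return [0.0, 1.0, 99.0]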

