The initial system architecture was failing under its own weight. We were serving a fraud detection model where feature staleness was directly correlated with financial loss. Our existing feature store, built on a key-value paradigm, could only serve pre-calculated, flattened data. It was blind to the intricate, real-time relationships between entities—a classic graph problem. When scoring a transaction for user U1, the system couldn’t answer questions like, “Is U1 connected within three hops to a known money laundering ring identified five minutes ago?” or “Does the device U1 is using share a fingerprint with any account that has been flagged for chargebacks in the last hour?” These are not features you can batch-process overnight; they must be computed at the moment of the transaction.
Our first attempt involved creating a monolithic Python service. It took an incoming request, queried Neo4j for graph features, bundled them with static features, called a local TensorFlow model for inference, and returned a score. While functionally correct, this approach was an operational nightmare. Redeploying the model required a full service restart. A bug in the feature generation logic could crash the entire prediction endpoint. Scaling was inefficient—we had to scale the entire monolith even if the bottleneck was only the CPU-intensive model inference. The coupling was too tight.
This led to a new architectural hypothesis: decouple feature generation, model inference, and client-facing orchestration. The natural home for the model was TensorFlow Serving—a dedicated, high-performance system for this exact purpose. The graph data belonged in Neo4j, optimized for real-time traversals. The critical missing piece was the orchestrator. The obvious choice, an intermediate microservice, simply moved the problem, adding another network hop and another component to maintain.
The real technical challenge was to perform this complex, multi-system orchestration with sub-100ms P99 latency. Adding another service was counterproductive to this goal. The solution had to lie within the infrastructure layer itself. We decided to embed this orchestration logic directly into our API Gateway. Not as a simple proxy rule, but as a custom, high-performance plugin that would become an active participant in the request lifecycle. This is the build log of that system.
Technology Selection Rationale
In a real-world project, every technology choice is a trade-off. Here’s why we landed on this specific stack:
- Neo4j: We evaluated relational databases and other NoSQL options. A relational approach for deep, variable-depth relationship queries would involve multiple, unpredictable JOIN operations, which are notoriously slow at scale. Neo4j’s native graph storage and processing engine, with its index-free adjacency, is purpose-built for this kind of traversal. For our use case, the performance difference between a Cypher query and a recursive SQL CTE was an order of magnitude.
- TensorFlow Serving: We needed to decouple model updates from our application code. TF Serving provides this out-of-the-box with gRPC endpoints for high throughput and low latency. It handles model versioning, rollouts, and hardware acceleration (GPU utilization) transparently, freeing the application team from low-level MLOps concerns.
- Custom API Gateway (FastAPI): We considered extending an existing gateway like Kong or Apache APISIX with a Lua plugin. While powerful, the learning curve for our team and the debugging complexity of Lua in a high-concurrency environment were significant risks. Furthermore, our orchestration logic required a rich ecosystem of libraries for connecting to Neo4j and TF Serving, which is native to Python. Building a lightweight, specialized gateway using a high-performance Python framework like FastAPI gave us the best of both worlds: full control over the request lifecycle, access to a mature library ecosystem, and exceptional performance thanks to its asyncio foundation and uvicorn server. It acts as our “plugin” host.
The final architecture looks like this:
graph TD
    subgraph "API Gateway (FastAPI on K8s)"
        A[Ingress: /predict] --> B{Orchestration Logic};
        B -- 1. Extract User/Transaction ID --> C{Async Neo4j Client};
        B -- 3. Forward Features --> D{Async TF Serving gRPC Client};
    end
    subgraph "Backend Systems"
        E[Neo4j AuraDB Cluster];
        F[TensorFlow Serving on K8s];
    end
    C -- 2. Cypher Query for Graph Features --> E;
    E -- Graph Features (Vector) --> C;
    C --> B;
    D -- 4. gRPC Inference Request --> F;
    F -- 5. Prediction Score --> D;
    D --> B;
    B -- 6. Final Response --> G[Client];
    style A fill:#f9f,stroke:#333,stroke-width:2px
    style F fill:#FFC300,stroke:#333,stroke-width:2px
    style E fill:#008cc1,stroke:#333,stroke-width:2px
Core Implementation: The Gateway Orchestrator
The entire system hinges on the gateway’s ability to perform its tasks asynchronously. Any blocking I/O call to Neo4j or TF Serving would stall the server’s event loop, destroying concurrency and throughput. Python’s asyncio is not optional here; it’s a fundamental requirement.
1. Configuration and Dependencies
A production system cannot have hardcoded values. We use Pydantic for settings management, loading configuration from environment variables. This makes the application portable across different environments (dev, staging, prod).
# requirements.txt
fastapi
uvicorn[standard]
pydantic
pydantic-settings # BaseSettings lives in this package in Pydantic v2
neo4j==5.14.0 # Make sure it's an async-compatible version
grpcio
tensorflow-serving-api
structlog # For structured logging
aiohttp # For health checks
# app/config.py
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class Neo4jSettings(BaseSettings):
    # Loaded from NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, NEO4J_DATABASE
    model_config = SettingsConfigDict(env_prefix="NEO4J_")
    uri: str
    user: str
    password: str
    db: str = Field("neo4j", validation_alias="NEO4J_DATABASE")

class TensorFlowServingSettings(BaseSettings):
    # Loaded from TFS_HOST, TFS_PORT, TFS_MODEL_NAME, TFS_TIMEOUT_SECONDS
    model_config = SettingsConfigDict(env_prefix="TFS_")
    host: str = "localhost"
    port: int = 8500
    model_name: str
    timeout_seconds: int = 2

class AppSettings(BaseModel):
    # Instantiating the nested settings here forces validation at import time.
    neo4j: Neo4jSettings = Neo4jSettings()
    tfs: TensorFlowServingSettings = TensorFlowServingSettings()

# The single source of truth for configuration
settings = AppSettings()
This setup ensures that a missing environment variable will cause a validation error on startup, preventing runtime failures due to misconfiguration. A common mistake is to handle configuration lazily, which can lead to hard-to-debug issues in a running system.
2. The Asynchronous Neo4j Feature Extractor
This component is responsible for connecting to Neo4j and executing a Cypher query to generate features. The official neo4j Python driver supports asyncio out of the box.
The key challenge is writing a Cypher query that is both powerful and performant. A poorly written query can easily become the bottleneck.
# app/features/graph_features.py
import asyncio
import logging
from typing import List

from neo4j import AsyncGraphDatabase, AsyncDriver

from ..config import settings

logger = logging.getLogger(__name__)

# Client-side cap on query time; kept local here for brevity, it would live in config in production.
GRAPH_QUERY_TIMEOUT_SECONDS = 1.0

# Default vector for unknown entities or backend failures:
# [recent_transactions, device_sharing_count, fraud_proximity]
DEFAULT_FEATURES = [0.0, 1.0, 99.0]


class GraphFeatureExtractor:
    _driver: AsyncDriver | None = None

    @classmethod
    async def get_driver(cls) -> AsyncDriver:
        if cls._driver is None:
            logger.info(f"Initializing Neo4j driver for URI: {settings.neo4j.uri}")
            cls._driver = AsyncGraphDatabase.driver(
                settings.neo4j.uri,
                auth=(settings.neo4j.user, settings.neo4j.password),
            )
        return cls._driver

    @classmethod
    async def close_driver(cls):
        if cls._driver:
            logger.info("Closing Neo4j driver.")
            await cls._driver.close()
            cls._driver = None

    async def get_features(self, user_id: str, device_id: str) -> List[float]:
        """
        Generates real-time graph features for a given user and device.
        This is where the core business logic for feature extraction resides.
        """
        driver = await self.get_driver()

        # This query is the secret sauce. It must be optimized.
        # It calculates features like:
        # 1. The user's recent transaction velocity (count in the last hour).
        # 2. How many other users have used the device (potential device farm).
        # 3. Shortest path distance to a known fraudulent user cluster.
        # NOTE: In production, this query would be far more complex.
        cypher_query = """
        MATCH (u:User {id: $userId})
        OPTIONAL MATCH (d:Device {id: $deviceId})

        // Feature 1: User's transaction velocity (e.g., count of recent transactions)
        CALL {
            WITH u
            MATCH (u)-[:PERFORMED]->(t:Transaction)
            WHERE t.timestamp > timestamp() - 3600000 // last hour
            RETURN count(t) AS recent_transactions
        }

        // Feature 2: Device sharing count
        CALL {
            WITH d
            MATCH (d)<-[:USED]-(other_user:User)
            RETURN count(other_user) AS device_sharing_count
        }

        // Feature 3: Proximity to the nearest known fraud ring
        CALL {
            WITH u
            MATCH (fraud_ring:FraudRing)
            MATCH p = shortestPath((u)-[*]-(fraud_ring))
            // Aggregate to a single row; return a large number if no path is found
            RETURN coalesce(min(length(p)), 99) AS fraud_proximity
        }

        RETURN recent_transactions, device_sharing_count, fraud_proximity
        """
        params = {"userId": user_id, "deviceId": device_id}

        async def run_query():
            # Using an async session to avoid blocking the event loop
            async with driver.session(database=settings.neo4j.db) as session:
                result = await session.run(cypher_query, params)
                return await result.single()

        try:
            # Enforce a client-side timeout so a slow traversal cannot blow the latency budget.
            record = await asyncio.wait_for(run_query(), timeout=GRAPH_QUERY_TIMEOUT_SECONDS)
            if not record:
                # A critical case to handle: user or device not found.
                # Return a default feature vector representing an unknown entity.
                logger.warning(f"No graph features found for user {user_id}. Returning default vector.")
                return DEFAULT_FEATURES
            # Ensure features are floats for the ML model
            return [float(v) for v in record.values()]
        except asyncio.TimeoutError:
            logger.error(f"Timeout while querying Neo4j for user {user_id}.")
            # A fallback strategy is crucial for system resilience.
            # Returning a default vector allows the system to still make a (less accurate) prediction.
            return DEFAULT_FEATURES
        except Exception:
            logger.exception(f"Error querying Neo4j for user {user_id}.")
            # In case of other DB errors, also return the default vector.
            return DEFAULT_FEATURES


# Instantiate a singleton for use in the API endpoint
feature_extractor = GraphFeatureExtractor()
The pitfall here is resource management. Creating a new driver for each request would be disastrous. We use a class-level singleton pattern to initialize the driver once and reuse it across all requests. The close_driver method is registered with FastAPI’s shutdown event handler to ensure graceful termination. The error handling is also non-negotiable; a failure in the feature store should not cause a 500 Internal Server Error. Instead, it should return a default feature vector, allowing the system to degrade gracefully.
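One more performance note: the query above only stays fast if the entry-point lookups on User.id and Device.id are index-backed. The exact provisioning workflow will vary; the sketch below shows the constraints and index we would expect to exist, run once against the database with the same settings (file and constraint names are illustrative, not part of the production codebase).

# scripts/provision_indexes.py
# One-off provisioning sketch (not part of the request path); names are illustrative.
import asyncio

from neo4j import AsyncGraphDatabase

from app.config import settings

STATEMENTS = [
    "CREATE CONSTRAINT user_id_unique IF NOT EXISTS FOR (u:User) REQUIRE u.id IS UNIQUE",
    "CREATE CONSTRAINT device_id_unique IF NOT EXISTS FOR (d:Device) REQUIRE d.id IS UNIQUE",
    # Range index so the recent-transaction window scan does not touch every Transaction node.
    "CREATE INDEX tx_timestamp IF NOT EXISTS FOR (t:Transaction) ON (t.timestamp)",
]


async def main():
    driver = AsyncGraphDatabase.driver(
        settings.neo4j.uri, auth=(settings.neo4j.user, settings.neo4j.password)
    )
    async with driver.session(database=settings.neo4j.db) as session:
        for stmt in STATEMENTS:
            result = await session.run(stmt)
            await result.consume()
    await driver.close()


if __name__ == "__main__":
    asyncio.run(main())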
3. The Asynchronous TensorFlow Serving Client
Communicating with TF Serving via gRPC is significantly more performant than using its REST API, but it’s also more complex. It requires building request objects using Protocol Buffers.
# app/models/tfs_client.py
import logging

import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

from ..config import settings

logger = logging.getLogger(__name__)


class TFServingClient:
    _stub = None

    @classmethod
    def get_stub(cls):
        if cls._stub is None:
            logger.info(f"Initializing gRPC channel to TF Serving at {settings.tfs.host}:{settings.tfs.port}")
            # Use an async channel for non-blocking communication
            channel = grpc.aio.insecure_channel(f"{settings.tfs.host}:{settings.tfs.port}")
            cls._stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        return cls._stub

    async def predict(self, feature_vector: list[float]) -> dict:
        stub = self.get_stub()

        # Step 1: Create the gRPC request object
        request = predict_pb2.PredictRequest()
        request.model_spec.name = settings.tfs.model_name
        request.model_spec.signature_name = 'serving_default'

        # Step 2: Populate the request with input data.
        # The key 'graph_features_input' must match the name of the input tensor in the
        # saved model's signature. This is a common point of error. Depending on the
        # exported signature, a batch dimension may also be required
        # (i.e. shape=[1, len(feature_vector)]).
        tensor = tf.make_tensor_proto(feature_vector, dtype=tf.float32)
        request.inputs['graph_features_input'].CopyFrom(tensor)

        try:
            # Step 3: Make the async gRPC call with a timeout
            response = await stub.Predict(request, timeout=settings.tfs.timeout_seconds)

            # Step 4: Unpack the prediction from the response protobuf.
            # The output key 'fraud_probability' must match the model's output tensor name.
            outputs_tensor_proto = response.outputs['fraud_probability']
            prediction = tf.make_ndarray(outputs_tensor_proto).flatten()[0]
            return {"fraud_probability": float(prediction)}
        except grpc.aio.AioRpcError as e:
            logger.error(f"gRPC call to TF Serving failed: {e.details()} (code: {e.code()})")
            # If the model service is down or times out, we need a fallback.
            # Returning a specific error allows upstream systems to handle it.
            # Do NOT return a default score, as that could be misleading.
            return {"error": "Prediction service unavailable", "details": e.details()}
        except Exception as e:
            logger.exception("An unexpected error occurred during prediction.")
            return {"error": "Unexpected prediction error", "details": str(e)}


# Singleton instance
tfs_client = TFServingClient()
The most critical part of this code is matching the input/output tensor names (graph_features_input, fraud_probability) with the model’s signature. A mismatch here will result in cryptic gRPC errors. This requires close collaboration between the data science team that builds the model and the engineering team that builds the serving layer. We established a contract, checked into source control, that defines these signatures.
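What that contract looks like is a team choice; ours amounts to a small module that both the serving code and the model-export checks import, so a rename on either side breaks CI rather than production. A sketch with illustrative values follows; the exported SavedModel can also be cross-checked with TensorFlow’s saved_model_cli.

# app/models/contract.py
# Sketch of the shared model contract; the concrete values are illustrative and
# must match what the data science team actually exports.
MODEL_NAME = "fraud_detector"          # hypothetical; should mirror TFS_MODEL_NAME
SIGNATURE_NAME = "serving_default"
INPUT_TENSOR = "graph_features_input"
OUTPUT_TENSOR = "fraud_probability"
# Order matters as much as naming: the gateway must build the vector in this order.
FEATURE_ORDER = ["recent_transactions", "device_sharing_count", "fraud_proximity"]

# The exported SavedModel can be inspected against this file in CI, e.g.:
#   saved_model_cli show --dir <model_dir>/<version> --tag_set serve --signature_def serving_default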
4. The FastAPI Gateway Endpoint
This is where everything comes together. The endpoint orchestrates the calls to the two clients. In this minimal version the two awaits run sequentially because the model call depends on the graph features; any independent I/O-bound lookups (for example, static features from another store) should instead be fired concurrently with asyncio.gather rather than stacking up on the critical path. A sketch of that fan-out follows the endpoint code.
# app/main.py
import logging
import time

from fastapi import FastAPI, Request
from pydantic import BaseModel

from .config import settings
from .features.graph_features import feature_extractor
from .models.tfs_client import tfs_client

# Configure structured logging early
# (Configuration for structlog would go here)
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

app = FastAPI()


@app.on_event("startup")
async def startup_event():
    # This pre-connects and warms up the connection pools.
    await feature_extractor.get_driver()
    tfs_client.get_stub()
    logger.info("Application startup complete. Services initialized.")


@app.on_event("shutdown")
async def shutdown_event():
    await feature_extractor.close_driver()
    logger.info("Application shutdown complete. Connections closed.")


class PredictionRequest(BaseModel):
    user_id: str
    device_id: str
    # Other static features can be included here
    transaction_amount: float


class PredictionResponse(BaseModel):
    request_id: str
    fraud_probability: float | None
    error: str | None
    latency_ms: float


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest, http_request: Request):
    start_time = time.perf_counter()
    request_id = http_request.headers.get("X-Request-ID", "N/A")
    logger.info("Received prediction request", extra={"user_id": request.user_id, "req_id": request_id})

    # Asynchronously fetch graph features
    graph_features = await feature_extractor.get_features(request.user_id, request.device_id)

    # In a real system, you might combine graph features with static features from the request.
    # For simplicity, we're only using graph features here.
    feature_vector = graph_features  # + [request.transaction_amount]

    # Asynchronously get prediction from the model
    prediction_result = await tfs_client.predict(feature_vector)

    latency = (time.perf_counter() - start_time) * 1000

    if "error" in prediction_result:
        return PredictionResponse(
            request_id=request_id,
            fraud_probability=None,
            error=prediction_result["error"],
            latency_ms=latency,
        )

    logger.info(f"Prediction successful for request {request_id}", extra={"latency_ms": latency})
    return PredictionResponse(
        request_id=request_id,
        fraud_probability=prediction_result["fraud_probability"],
        error=None,
        latency_ms=latency,
    )


@app.get("/health")
async def health_check():
    # In a real system, this would also check connectivity to Neo4j and TF Serving
    # (e.g. via the driver's verify_connectivity()).
    return {"status": "ok"}
This endpoint is the public face of our system. It’s lean and its only job is orchestration. All the heavy lifting is delegated. The structured logging includes a request ID, which is indispensable for tracing a single request’s journey across logs from multiple systems in a distributed environment.
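As noted above, the two awaits inside /predict are sequential only because the model call needs the graph features first. Once a second, independent feature source exists, the fan-out belongs in asyncio.gather so the slower call hides the faster one. A minimal sketch, where fetch_static_features is a hypothetical stand-in for any lookup that does not depend on the Neo4j result:

# app/features/fan_out.py
# Sketch of concurrent feature fan-out; `fetch_static_features` is hypothetical.
import asyncio

from .graph_features import feature_extractor


async def fetch_static_features(user_id: str) -> list[float]:
    # Placeholder: in practice this might hit a key-value store or another service.
    await asyncio.sleep(0)
    return []


async def gather_features(user_id: str, device_id: str) -> list[float]:
    # Both coroutines run on the event loop concurrently, so the total latency is
    # roughly the max of the two calls rather than their sum.
    graph_features, static_features = await asyncio.gather(
        feature_extractor.get_features(user_id, device_id),
        fetch_static_features(user_id),
    )
    return graph_features + static_features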
Testing Strategy
A common mistake is to only test components in isolation. For this architecture, integration testing is paramount.
- Unit Tests: We use pytest with pytest-asyncio. The Neo4j and TF Serving clients are tested by mocking their respective drivers/stubs. We test the happy path, error conditions (e.g., DB down, model not found), and fallback logic (e.g., returning default features). A minimal sketch of this style of test follows the list.
- Integration Tests: This is the most important part. We use testcontainers to programmatically spin up Docker containers for Neo4j and TensorFlow Serving during the test suite execution.
  - The Neo4j container is pre-populated with a small, known graph structure.
  - The TF Serving container is loaded with a dummy model that has a known input/output signature.
  - The tests then make real HTTP requests to our FastAPI application (using httpx.AsyncClient) and assert that the entire flow—from HTTP request to Neo4j query to TF Serving inference to HTTP response—works as expected. This catches issues like configuration mismatches, network policy problems, and signature incompatibilities that unit tests would miss.
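As referenced above, here is a minimal sketch of an endpoint-level unit test with both clients mocked out; the file name, stub values, and dummy environment variables are illustrative rather than lifted from our real suite.

# tests/test_predict_unit.py
# Sketch of an endpoint test with Neo4j and TF Serving mocked away.
import os

# config.py validates environment variables at import time, so provide dummies first.
os.environ.setdefault("NEO4J_URI", "bolt://localhost:7687")
os.environ.setdefault("NEO4J_USER", "neo4j")
os.environ.setdefault("NEO4J_PASSWORD", "test")
os.environ.setdefault("TFS_MODEL_NAME", "fraud_detector")

import pytest
from httpx import ASGITransport, AsyncClient

from app.features.graph_features import GraphFeatureExtractor
from app.main import app
from app.models.tfs_client import TFServingClient


@pytest.mark.asyncio
async def test_predict_happy_path(monkeypatch):
    async def fake_features(self, user_id, device_id):
        return [3.0, 1.0, 99.0]

    async def fake_predict(self, feature_vector):
        return {"fraud_probability": 0.42}

    # Patch at the class level so the module-level singletons pick up the fakes.
    monkeypatch.setattr(GraphFeatureExtractor, "get_features", fake_features)
    monkeypatch.setattr(TFServingClient, "predict", fake_predict)

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        resp = await client.post(
            "/predict",
            json={"user_id": "U1", "device_id": "D1", "transaction_amount": 10.0},
        )

    assert resp.status_code == 200
    assert resp.json()["fraud_probability"] == 0.42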
Performance and Lingering Issues
After deployment and load testing, the P99 latency for the /predict endpoint stabilized around 85ms, meeting our sub-100ms target. The vast majority of this time was spent in the Neo4j query execution (~60ms). The gRPC call to TF Serving was consistently under 10ms.
However, this architecture is not without its trade-offs and remaining challenges.
The most significant architectural debt we’ve incurred is placing complex orchestration logic within the gateway layer. While it met our performance goals, it blurs the line between infrastructure and application. The gateway is no longer a simple, “dumb” proxy. Any changes to the feature generation logic require a redeployment of this critical infrastructure component. This increases the risk profile of each deployment.
Furthermore, the current implementation has no caching. If the same user performs multiple transactions in a short period, we re-compute the same graph features every single time. A future iteration will explore introducing a caching layer (e.g., Redis) between the gateway and Neo4j. This introduces new complexities: What is the right TTL? How do we handle cache invalidation when the graph is updated? The trade-off between latency reduction and data freshness will need careful analysis.
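To make that trade-off concrete, a read-through cache in front of get_features might look like the sketch below. This is exploratory only: it assumes redis.asyncio, a hypothetical REDIS_URL, and an illustrative 30-second TTL, and it deliberately leaves invalidation on graph updates open because that is the unresolved question.

# app/features/cached_features.py
# Exploratory sketch of a read-through cache; not part of the current deployment.
import json
import os

import redis.asyncio as redis

from .graph_features import feature_extractor

FEATURE_TTL_SECONDS = 30  # illustrative; choosing the "right" TTL is the open question
_redis = redis.from_url(os.environ.get("REDIS_URL", "redis://localhost:6379/0"))


async def get_features_cached(user_id: str, device_id: str) -> list[float]:
    key = f"graph_features:{user_id}:{device_id}"
    cached = await _redis.get(key)
    if cached is not None:
        return json.loads(cached)

    features = await feature_extractor.get_features(user_id, device_id)
    # A short TTL bounds staleness at the cost of more Neo4j traffic; invalidation
    # on graph writes would be needed to do better than that.
    await _redis.set(key, json.dumps(features), ex=FEATURE_TTL_SECONDS)
    return features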
Finally, while the fallback mechanisms (default feature vectors, error responses) make the system resilient, they also mask underlying problems. A sustained issue with Neo4j might not trigger alarms if the gateway is simply serving default features, leading to a silent degradation in model accuracy. This requires more sophisticated monitoring that tracks not just system uptime, but also the rate of fallback path execution, to provide a true picture of system health.
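One inexpensive way to get that signal is to instrument the fallback path itself and alert on its rate alongside the latency dashboards. A sketch using prometheus_client, with a metric name of our own choosing:

# app/observability.py
# Sketch of fallback-rate instrumentation; the metric name is illustrative.
from prometheus_client import Counter

GRAPH_FEATURE_FALLBACKS = Counter(
    "graph_feature_fallbacks_total",
    "Times the gateway served the default feature vector instead of live graph features",
)

# Each fallback branch in GraphFeatureExtractor.get_features would then call:
#   GRAPH_FEATURE_FALLBACKS.inc()
# Alerting on rate(graph_feature_fallbacks_total[5m]) catches silent degradation even
# while /health and the model endpoint both look perfectly healthy.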