Propagating SAML Assertions from Ruby on Rails to a PyTorch Service for Secure Model Inference


The technical debt was clear: our monolithic Ruby on Rails application, which successfully handles enterprise Single Sign-On via SAML for multi-tenant customers, was becoming a bottleneck for deploying specialized, high-performance services. The latest requirement was a document analysis feature powered by a PyTorch model. The data science team delivered a clean, stateless Python service, but the immediate and contentious question was authentication and authorization. The monolith holds the user session, including the original SAML assertion data from the Identity Provider (IdP). How do we securely get that identity context to the new Python service without compromising its stateless nature or creating a brittle, chatty dependency on the monolith’s database?

Our first whiteboard session produced the most obvious, and ultimately incorrect, solution: an internal API endpoint on the Rails monolith, say /api/internal/v1/auth_check, that the PyTorch service could call with a session token on every incoming request. In practice this is an anti-pattern: it couples the services at runtime, ties the PyTorch service’s latency to the monolith’s, and turns the monolith into a single point of failure for authentication. We would effectively be building a distributed monolith.

The second proposal was more standard: have the Rails application mint a short-lived JWT after a user logs in via SAML. The Rails backend would then pass this JWT to the PyTorch service. This is a significant improvement. It decouples the services, as the PyTorch service can validate the JWT signature using a shared public key. The pitfall here is subtle but critical for enterprise-grade security. The JWT is an assertion made by our system (the monolith). The original, cryptographically signed assertion from the customer’s IdP is lost. The PyTorch service must place its full trust in the monolith’s JWT-minting logic. For a service handling sensitive customer data, we wanted a stronger guarantee—an architecture where the service could independently verify the original authentication event.
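To make that trust boundary concrete, here is a minimal stdlib sketch of the rejected design: the monolith mints a signed token, and the downstream service can verify the signature but has no way to check the original authentication event. It uses HS256 with a shared secret for brevity; the actual proposal used an asymmetric key pair so the PyTorch service would only hold the public key, and a real implementation would use the jwt gem (Rails) and PyJWT (Python). The claim names here are illustrative.

```python
import base64
import hashlib
import hmac
import json
import time

def b64url(data: bytes) -> str:
    """Base64url without padding, as JWT requires."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")

def mint_jwt(claims: dict, secret: bytes) -> str:
    """What the monolith would do after a successful SAML login."""
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = b64url(json.dumps(claims).encode())
    signing_input = f"{header}.{payload}".encode()
    sig = b64url(hmac.new(secret, signing_input, hashlib.sha256).digest())
    return f"{header}.{payload}.{sig}"

def verify_jwt(token: str, secret: bytes) -> dict:
    """What the PyTorch service would do. Note the limitation: it can
    only verify that the *monolith* signed this, not that the IdP did."""
    header, payload, sig = token.split(".")
    signing_input = f"{header}.{payload}".encode()
    expected = b64url(hmac.new(secret, signing_input, hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        raise ValueError("bad signature")
    padded = payload + "=" * (-len(payload) % 4)
    claims = json.loads(base64.urlsafe_b64decode(padded))
    if claims.get("exp", 0) < time.time():
        raise ValueError("expired")
    return claims
```

The `verify_jwt` docstring is the whole argument in one line: the IdP's signature is gone, and the service's trust chain terminates at the monolith.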

This led to our final design: propagating the entire, signed SAML Response XML from the Rails session to the PyTorch service on each request. The Rails application acts as a pass-through for the authentication artifact, and the PyTorch service becomes responsible for its own security by validating the assertion directly against the customer’s IdP certificate. This aligns with zero-trust principles, where services don’t inherently trust each other and verify identity independently. It’s more complex to implement but provides a far more robust and auditable security posture.

Here is the data flow we settled on:

sequenceDiagram
    participant User
    participant Browser
    participant RailsApp as Ruby on Rails Monolith
    participant IdP as Customer Identity Provider
    participant PyTorchService as PyTorch/Flask Service

    User->>Browser: Accesses secure feature
    Browser->>RailsApp: GET /protected-feature
    Note over RailsApp: No session, initiates SSO
    RailsApp->>Browser: Redirect to IdP (SAMLRequest)
    Browser->>IdP: Redirect with SAMLRequest
    IdP-->>User: Prompts for credentials
    User-->>IdP: Enters credentials
    IdP->>Browser: Redirect to Rails ACS (SAMLResponse)
    Browser->>RailsApp: POST /saml/acs (SAMLResponse)
    Note over RailsApp: Validates SAMLResponse, creates session, stores entire XML
    RailsApp->>Browser: Renders feature page with JS

    Browser->>RailsApp: (AJAX) POST /api/v1/analyze_document
    Note over RailsApp: Controller retrieves SAML XML from session
    RailsApp->>PyTorchService: POST /infer (with X-SAML-Response header)
    Note over PyTorchService: Middleware intercepts request
    PyTorchService->>PyTorchService: Validates SAML Signature & Conditions
    Note over PyTorchService: Extracts tenant_id, user_email from Attributes
    PyTorchService->>PyTorchService: Loads tenant-specific model and performs inference
    PyTorchService-->>RailsApp: Returns inference result (JSON)
    RailsApp-->>Browser: Returns inference result (JSON)

Part 1: Configuring Rails as the SAML Service Provider

The first step is ensuring the Rails monolith correctly handles the SAML SSO flow and, crucially, preserves the raw SAML Response XML in the user’s session. A common mistake is to parse the attributes, discard the original XML, and move on. For this pattern, the raw XML is the payload.

We use the ruby-saml gem. The configuration lives in an initializer.

# config/initializers/saml_idp.rb

# This configuration assumes a single IdP for simplicity.
# In a multi-tenant app, you would load these settings dynamically
# based on the request domain or a tenant identifier.

Rails.application.config.saml_settings = OneLogin::RubySaml::Settings.new

# Service Provider (SP) metadata. This is our Rails app.
Rails.application.config.saml_settings.assertion_consumer_service_url = "http://localhost:3000/saml/acs"
Rails.application.config.saml_settings.sp_entity_id                   = "urn:rails:sp:myapplication"
Rails.application.config.saml_settings.authn_context                  = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
Rails.application.config.saml_settings.protocol_binding               = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"

# Identity Provider (IdP) metadata. This is provided by the customer.
# In a real-world project, this comes from a database, not hardcoded.
Rails.application.config.saml_settings.idp_entity_id                  = "https://idp.test.com/metadata"
Rails.application.config.saml_settings.idp_sso_target_url             = "https://idp.test.com/sso"
Rails.application.config.saml_settings.idp_slo_target_url             = "https://idp.test.com/slo"

# The IdP's public certificate for validating their signature.
# Store this securely.
idp_cert_path = Rails.root.join('config', 'saml', 'idp_cert.pem')
if File.exist?(idp_cert_path)
  Rails.application.config.saml_settings.idp_cert = File.read(idp_cert_path)
else
  Rails.logger.warn("SAML IdP certificate not found at #{idp_cert_path}")
end

# We need to sign our requests.
sp_key_path = Rails.root.join('config', 'saml', 'sp_key.pem')
sp_cert_path = Rails.root.join('config', 'saml', 'sp_cert.pem')
if File.exist?(sp_key_path) && File.exist?(sp_cert_path)
  Rails.application.config.saml_settings.certificate = File.read(sp_cert_path)
  Rails.application.config.saml_settings.private_key = File.read(sp_key_path)
  Rails.application.config.saml_settings.security[:authn_requests_signed] = true
  Rails.application.config.saml_settings.security[:logout_requests_signed] = true
  Rails.application.config.saml_settings.security[:want_assertions_signed] = true
  Rails.application.config.saml_settings.security[:signature_method] = XMLSecurity::Document::RSA_SHA256
else
  Rails.logger.warn("SAML SP key/cert pair not found. AuthnRequests will not be signed.")
end

The key modification is in the Assertion Consumer Service (ACS) controller, which receives the POST from the IdP.

# app/controllers/saml_controller.rb

class SamlController < ApplicationController
  # The IdP posts the SAMLResponse to this endpoint.
  skip_before_action :verify_authenticity_token, only: [:acs]

  def acs
    saml_response = OneLogin::RubySaml::Response.new(
      params[:SAMLResponse],
      settings: Rails.application.config.saml_settings
    )

    unless saml_response.is_valid?
      logger.error "Invalid SAML Response: #{saml_response.errors}"
      # In production, redirect to a more user-friendly error page.
      return render plain: "SAML Response was invalid.", status: :unauthorized
    end

    # Find or create the user based on the NameID or another attribute.
    user = User.find_or_create_by(email: saml_response.nameid)
    
    # Standard session creation
    session[:user_id] = user.id
    session[:user_email] = user.email
    session[:saml_attributes] = saml_response.attributes.to_h

    # --- CRITICAL STEP ---
    # Store the raw, original, Base64-encoded SAML Response.
    # This is what we will propagate to downstream services.
    # Note: this value is several KB, so the session must use a
    # server-side store rather than the default 4 KB cookie store.
    session[:raw_saml_response] = params[:SAMLResponse]
    
    logger.info "SAML login successful for #{user.email}. Session created."
    
    # Redirect to the originally requested URL or a dashboard.
    redirect_to root_path
  end

  def metadata
    meta = OneLogin::RubySaml::Metadata.new
    render xml: meta.generate(Rails.application.config.saml_settings), content_type: 'application/samlmetadata+xml'
  end
end
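One operational caveat before wiring up the proxy: a Base64-encoded SAMLResponse is typically several kilobytes, which overflows Rails’ default cookie session store (browsers cap cookies at 4 KB). The session has to live server-side so the cookie carries only a session ID. A minimal sketch using Rails’ built-in cache-backed session store; the Redis cache store and its options are standard Rails, but treat the key name and expiry as assumptions to adapt:

```ruby
# config/initializers/session_store.rb
#
# The raw SAMLResponse is far larger than the 4 KB cookie limit, so
# keep session data server-side; the cookie holds only the session ID.
Rails.application.config.session_store :cache_store,
  key: '_myapp_session',
  expire_after: 8.hours

# config/environments/production.rb (excerpt)
# Back the cache with Redis so sessions survive across app servers:
#
#   config.cache_store = :redis_cache_store, { url: ENV['REDIS_URL'] }
```

With the default cookie store, the `session[:raw_saml_response] = ...` line in the ACS controller would raise a CookieOverflow error on the very first login.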

Part 2: The Rails Proxy API Endpoint

Next, we create the internal API endpoint that the frontend will call. This endpoint acts as a secure proxy, attaching the authentication artifact before forwarding the request.

# app/controllers/api/v1/inference_controller.rb

module Api
  module V1
    class InferenceController < ApplicationController
      # This controller requires an authenticated user session.
      before_action :require_login
      
      # Define the target service URL. In production, this comes from
      # Rails credentials or an environment variable, not a constant.
      PYTORCH_SERVICE_URL = 'http://localhost:5000/infer'.freeze

      def analyze
        # 1. Check for the required SAML response in the session.
        raw_saml_response = session[:raw_saml_response]
        unless raw_saml_response
          logger.warn "User #{current_user.id} attempted to access inference API without a SAML session."
          return render json: { error: 'Authentication context is missing' }, status: :unauthorized
        end

        # 2. Prepare the request payload for the PyTorch service.
        # We simply forward the original request body from the client.
        request_payload = request.body.read
        
        # 3. Make the proxied request.
        begin
          response = RestClient.post(
            PYTORCH_SERVICE_URL,
            request_payload,
            {
              # The crucial header containing the authentication artifact.
              'X-SAML-Response': raw_saml_response,
              'Content-Type': 'application/json',
              'Accept': 'application/json'
            }
          )
          
          # Forward the response from the PyTorch service back to the client.
          render json: JSON.parse(response.body), status: response.code
        
        rescue RestClient::ExceptionWithResponse => e
          logger.error "Error proxying to PyTorch service: #{e.response}"
          render json: { error: 'Inference service failed' }, status: :internal_server_error
        rescue RestClient::Exception => e
          logger.error "Network error connecting to PyTorch service: #{e.message}"
          render json: { error: 'Could not connect to inference service' }, status: :service_unavailable
        end
      end

      private

      def require_login
        unless current_user
          render json: { error: 'Not authenticated' }, status: :unauthorized
        end
      end
    end
  end
end

Part 3: The PyTorch Service with SAML Validation

Now for the Python side. We’ll use Flask as a lightweight web server to wrap our PyTorch model. The security logic will be implemented in a decorator using the python3-saml library. This is where the zero-trust principle is enforced.

First, the project structure and setup:

pytorch_service/
├── app.py
├── auth.py
├── saml/
│   ├── settings.json
│   ├── idp_cert.pem
│   └── sp/
│       ├── sp.crt
│       └── sp.key
└── models/
    └── tenant_123_model.pth

The saml/settings.json is critical. Even though we are not a full SP handling redirects, the library needs this configuration to know about the IdP and how to validate the assertion.

// saml/settings.json
{
    "strict": true,
    "debug": false,
    "sp": {
        "entityId": "urn:rails:sp:myapplication",
        "assertionConsumerService": {
            "url": "http://localhost:3000/saml/acs",
            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
        },
        "x509cert": "",
        "privateKey": ""
    },
    "idp": {
        "entityId": "https://idp.test.com/metadata",
        "singleSignOnService": {
            "url": "https://idp.test.com/sso",
            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
        },
        "x509cert": "-----BEGIN CERTIFICATE-----\nMIIC...YOUR...IDP...CERT...HERE...\n-----END CERTIFICATE-----\n"
    },
    "security": {
        "wantAssertionsSigned": true,
        "signatureAlgorithm": "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
        "digestAlgorithm": "http://www.w3.org/2001/04/xmlenc#sha256"
    }
}
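A subtle failure mode with this file is drift between the two services: the sp.entityId here must match the Rails sp_entity_id exactly or the assertion’s Audience check fails, and "strict": false silently disables most validation. A small startup check along these lines catches that early (the function and message strings are ours, not part of python3-saml):

```python
import json

# Fields python3-saml needs before it can validate anything.
REQUIRED = [
    ("sp", "entityId"),
    ("idp", "entityId"),
    ("idp", "x509cert"),
]

def check_saml_settings(settings: dict, expected_sp_entity_id: str) -> list:
    """Return human-readable problems; an empty list means the config looks sane."""
    problems = []
    for section, key in REQUIRED:
        if not settings.get(section, {}).get(key):
            problems.append(f"missing {section}.{key}")
    if settings.get("sp", {}).get("entityId") != expected_sp_entity_id:
        problems.append(
            "sp.entityId does not match the Rails sp_entity_id; "
            "Audience validation will fail"
        )
    if not settings.get("strict", False):
        problems.append("strict is false; python3-saml skips critical checks")
    return problems

# Typical use at service startup:
# with open("saml/settings.json") as f:
#     problems = check_saml_settings(json.load(f), "urn:rails:sp:myapplication")
#     if problems:
#         raise RuntimeError("; ".join(problems))
```

Failing fast at boot is much cheaper than debugging opaque 401s from the decorator later.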

The core of the security logic resides in auth.py. We’ll build a decorator that handles the entire validation process.

# auth.py

import os
from functools import wraps
from flask import request, g, jsonify, current_app
from onelogin.saml2.auth import OneLogin_Saml2_Auth

def prepare_saml_request(req):
    """
    Builds the request_data dict that python3-saml expects, from the
    Flask request context. The library is designed for full browser
    flows, so we reconstruct the environment it needs in an API context.
    """
    # Pitfall: in strict mode, python3-saml validates the response's
    # Destination and SubjectConfirmation Recipient against the URL
    # built from these fields. They must therefore describe the ACS
    # endpoint the IdP actually addressed (the Rails app), not this
    # service's own host and path.
    return {
        'https': 'off',  # 'on' if the Rails ACS URL uses https
        'http_host': 'localhost:3000',
        'script_name': '/saml/acs',
        'get_data': {},
        'post_data': {
            # We inject the SAMLResponse from our custom header here.
            'SAMLResponse': req.headers.get('X-SAML-Response')
        }
    }

def saml_login_required(f):
    """
    A decorator to protect Flask endpoints. It validates the SAML assertion
    passed in the 'X-SAML-Response' header.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        saml_response_b64 = request.headers.get('X-SAML-Response')
        if not saml_response_b64:
            current_app.logger.warning("Auth decorator: Missing X-SAML-Response header.")
            return jsonify({'error': 'Missing SAML authentication token'}), 401

        # Prepare the request context for the SAML library.
        req = prepare_saml_request(request)
        
        # Load settings from JSON file. In production, consider caching this.
        # This path assumes the 'saml' folder is at the same level as app.py
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'saml')
        auth = OneLogin_Saml2_Auth(req, old_settings=None, custom_base_path=path)

        # Process the response (this performs the validation).
        auth.process_response()

        # Check for errors.
        if auth.get_errors():
            error_reason = auth.get_last_error_reason()
            current_app.logger.error(f"SAML validation failed: {error_reason}")
            return jsonify({'error': 'Invalid SAML assertion', 'details': error_reason}), 401
        
        # If valid, extract attributes and store them in the request context (g).
        if not auth.is_authenticated():
            current_app.logger.warning("SAML assertion processed but not authenticated.")
            return jsonify({'error': 'Authentication failed'}), 401

        # The 'g' object is a per-request global context in Flask.
        g.saml_attributes = auth.get_attributes()
        g.saml_nameid = auth.get_nameid()
        current_app.logger.info(f"Successfully authenticated user {g.saml_nameid} via SAML assertion.")

        return f(*args, **kwargs)

    return decorated_function

Finally, the main Flask application app.py ties everything together. It defines the protected /infer endpoint and uses the attributes extracted by the decorator.

# app.py

import logging
from flask import Flask, request, jsonify, g
from auth import saml_login_required

# In a real app, the PyTorch model loading and inference logic
# would be more sophisticated.
# import torch
# model_cache = {}

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({"status": "ok"}), 200

@app.route('/infer', methods=['POST'])
@saml_login_required
def perform_inference():
    """
    This endpoint is protected by our SAML decorator.
    If the code reaches here, the SAML assertion was valid.
    """
    # We can now safely use the identity information from the request context.
    # A common pitfall is to not validate the presence of required attributes.
    if 'tenant_id' not in g.saml_attributes or not g.saml_attributes['tenant_id']:
        app.logger.error(f"User {g.saml_nameid} authenticated but is missing a 'tenant_id' attribute.")
        return jsonify({'error': 'Missing tenant identifier in SAML assertion'}), 403
    
    tenant_id = g.saml_attributes['tenant_id'][0] # Attributes are lists
    user_email = g.saml_nameid
    
    app.logger.info(f"Inference request for tenant '{tenant_id}' by user '{user_email}'.")

    # Use the tenant_id to load the correct model or apply tenant-specific logic.
    # This is a critical step for data isolation in a multi-tenant environment.
    # model_path = f"models/tenant_{tenant_id}_model.pth"
    # if tenant_id not in model_cache:
    #     try:
    #         model_cache[tenant_id] = torch.load(model_path)
    #     except FileNotFoundError:
    #         app.logger.error(f"Model file not found for tenant: {tenant_id}")
    #         return jsonify({'error': 'Inference model not available for this tenant'}), 500
    
    # model = model_cache[tenant_id]
    
    # Get the input data from the request.
    input_data = request.get_json()
    if not input_data:
        return jsonify({'error': 'Invalid JSON payload'}), 400

    # --- Dummy Inference Logic ---
    # In a real app, you would pass input_data to your model.
    # result = model(input_data)
    analysis_result = {
        "user": user_email,
        "tenant": tenant_id,
        "document_id": input_data.get("doc_id"),
        "confidence": 0.95,
        "entities": ["Person: John Doe", "Organization: Acme Corp"]
    }
    # ----------------------------

    return jsonify(analysis_result), 200


if __name__ == '__main__':
    # For production, use a proper WSGI server like Gunicorn.
    app.run(host='0.0.0.0', port=5000)
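For completeness, a typical production invocation with Gunicorn; the worker and thread counts are illustrative starting points, not tuned values:

```shell
# SAML signature validation is CPU-bound XML crypto, so scale workers
# with available cores. The bind address must match the Rails-side
# PYTORCH_SERVICE_URL constant.
gunicorn --workers 4 --threads 2 --bind 0.0.0.0:5000 app:app
```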

Part 4: The Testing Approach

A system like this requires robust testing, especially on the security boundaries.

  1. Rails Controller Test (spec/requests/api/v1/inference_controller_spec.rb):

    • Use RSpec to test the proxy controller.
    • Stub the RestClient.post call to avoid actual network requests.
    • Test the “happy path”: a logged-in user with a raw_saml_response in their session. Verify that the X-SAML-Response header is correctly passed to the stubbed service.
    • Test the failure path: a user without a session, or a user whose session is missing the SAML data. Ensure it returns a 401 or 403 and does not attempt to call the downstream service.
  2. Python Auth Decorator Test (tests/test_auth.py):

    • This is the most critical set of tests.
    • Obtain a valid, signed SAML Response XML from a test IdP. Save this as a fixture.
    • Use Pytest and the Flask test client.
    • Test 1: Send a request with a valid X-SAML-Response header. Assert that the endpoint returns 200 OK and that the logic inside the endpoint (which can be mocked) was called.
    • Test 2: Send a request with no header. Assert a 401 Unauthorized response.
    • Test 3: Manipulate the fixture XML to invalidate the signature. Assert a 401 response.
    • Test 4: Change the NotOnOrAfter timestamp in the fixture to be in the past. Assert a 401 response for an expired assertion.
    • Test 5: Change the Audience element in the fixture to something other than our SP entity ID. Assert a 401 response.
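For Test 4, a small helper can derive the expired fixture from the valid one (the function name is ours). One caveat worth stating in the test plan: the Conditions element lives inside the signed assertion, so editing NotOnOrAfter also breaks the signature. An expired-but-correctly-signed fixture must come from the test IdP itself (or be re-signed with a test key); this naive rewrite is only sufficient when the assertion being tested is that any tampering yields a 401:

```python
import base64
import re

def expire_assertion(saml_response_b64: str) -> str:
    """
    Rewrite every NotOnOrAfter attribute in a Base64-encoded SAML
    response to a date in the past, then re-encode it as it would
    travel in the X-SAML-Response header.
    """
    xml = base64.b64decode(saml_response_b64).decode("utf-8")
    expired_xml = re.sub(
        r'NotOnOrAfter="[^"]*"',
        'NotOnOrAfter="2000-01-01T00:00:00Z"',
        xml,
    )
    return base64.b64encode(expired_xml.encode("utf-8")).decode("ascii")
```

The same pattern (decode, regex-edit one attribute, re-encode) covers Test 5's Audience swap.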

This architecture, while more complex upfront, created a robustly decoupled system. The PyTorch service is truly stateless, its security is self-contained and based on a widely adopted cryptographic standard, and the Rails monolith’s role is simplified to that of a trusted proxy.

The main limitation of this approach is the overhead of transmitting and parsing the SAML XML on every API call. For very high-frequency, low-latency services, this could become a bottleneck. A potential optimization path would be a hybrid model: the PyTorch service could validate the full SAML assertion on the first request and then mint its own, very short-lived (e.g., 60 seconds) scoped JWT for the client to use on subsequent calls. This preserves the strong initial authentication while improving performance for a burst of related requests. Furthermore, managing the IdP certificates on the PyTorch service instances requires a mature secret management infrastructure, like HashiCorp Vault, to avoid configuration drift and insecurely stored credentials. The presented solution does not address SAML Single Log-Out (SLO); if an IdP session is terminated, the PyTorch service will continue to accept the assertion until it expires, a trade-off that is often acceptable for internal service-to-service communication.

