
TensorFlow & ML Frameworks

TensorFlow vs PyTorch vs JAX vs MLX — framework comparison, ONNX conversion workflow, and when to use each for different deployment targets.

1. Framework Landscape (April 2026)

The ML framework ecosystem in 2026 is dominated by a handful of mature platforms, each with distinct use cases and community adoption:

Major Frameworks

  • TensorFlow (Google): Production-grade, end-to-end platform with extensive deployment options (TF Lite for mobile, TF Serving for production, TF.js for web)
  • PyTorch (Meta): Dynamic computation graphs, research-friendly, increasingly adopted in production with better deployment tooling
  • JAX (Google): Functional programming, NumPy-like API, popular in research and scientific computing
  • ONNX (Open Standard): Language-agnostic model format enabling cross-framework portability
  • MLX (Apple): Unified memory optimization for Apple Silicon, growing adoption in M-series development
  • Hugging Face Transformers: High-level library built on PyTorch/TensorFlow, dominates NLP and multimodal model distribution

Adoption Snapshot

  • PyTorch: Fastest-growing adoption, especially in startups and research. Meta’s continued investment and ecosystem maturity (TorchServe, ONNX export) make production use increasingly viable.
  • TensorFlow: Entrenched at Google, in large enterprises, and in production systems. Still dominant for mobile/edge (TF Lite), but facing PyTorch’s momentum.
  • ONNX: Becoming the de facto standard for model portability, with growing support across inference runtimes.
  • MLX: Rapidly adopted by Apple Silicon users; significant performance gains over generic frameworks on M-series chips.
  • Hugging Face: Effectively the standard distribution channel for transformer models (LLMs, vision models). Makes framework choice transparent to end users.

2. TensorFlow

What It Is

TensorFlow is Google’s end-to-end machine learning platform. It handles the entire ML lifecycle: data loading, model definition, training, optimization, and deployment. TensorFlow 2.x moved to eager execution (running code immediately) and a Pythonic API, making it more accessible than the graph-focused TensorFlow 1.x.

Core Concepts

  • Eager execution: Code runs immediately, like regular Python (TensorFlow 2.x default)
  • Keras API: High-level interface for defining models (sequential, functional, or subclassing)
  • tf.data: Efficient pipeline for data loading and preprocessing
  • tf.function: JIT compilation of Python functions to graphs for performance
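A minimal sketch of tf.function tracing (the function and values are illustrative):

```python
import tensorflow as tf

@tf.function
def scaled_sum(x, y):
    # The first call traces this Python function into a graph;
    # later calls with the same input signature reuse the compiled graph
    return tf.reduce_sum(x * 2.0 + y)

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
print(scaled_sum(a, b).numpy())  # 13.0
```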

Ecosystem

  • TensorFlow Lite: Mobile and embedded inference (Android, iOS)
  • TensorFlow Serving: Production serving with versioning, A/B testing, and monitoring
  • TensorFlow.js: Browser-based inference
  • TensorFlow Extended (TFX): Production ML pipelines with data validation, model analysis, and serving

Strengths

  • Production-ready: Mature deployment ecosystem, battle-tested at scale
  • Mobile/edge optimized: TF Lite enables inference on phones and embedded devices
  • Comprehensive tools: End-to-end platform from data to production
  • Performance: XLA compilation for optimization, quantization, and pruning support
  • Community and learning resources: Extensive documentation, courses, examples

Weaknesses

  • Steeper learning curve: API still feels large and sometimes verbose
  • Debugging: Graph-based execution (when using tf.function) can be harder to debug than eager code
  • Research flexibility: Dynamic-graph workflows (PyTorch’s strength) require more boilerplate in TensorFlow
  • Newer features sometimes lag: Recent innovations often appear in PyTorch first

Used By

  • Google (search, recommendations)
  • Major tech companies (Airbnb, Twitter, Airbus)
  • Financial institutions (for risk modeling, fraud detection)
  • Production ML systems requiring mobile or edge deployment

Keras Integration

Keras is the official high-level API within TensorFlow. You typically write models using Keras, then train and deploy with TensorFlow infrastructure:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=32)

3. PyTorch

What It Is

PyTorch is Meta’s machine learning framework based on dynamic computation graphs. Unlike static graph frameworks, PyTorch builds the computation graph as code executes, enabling Pythonic, imperative programming. This makes it feel like writing regular Python with array operations.

Core Concepts

  • Dynamic graphs: Graphs are built on-the-fly during forward pass
  • torch.autograd: Automatic differentiation for backpropagation
  • Tensor: PyTorch’s fundamental data structure (like NumPy arrays, but GPU-aware)
  • nn.Module: Base class for all models and layers
  • torch.optim: Optimizers (SGD, Adam, etc.)
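A minimal autograd sketch (values illustrative) showing how the graph is built during the forward pass and differentiated on `backward()`:

```python
import torch

# y = x^2 + 2x, so dy/dx = 2x + 2
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x   # forward pass records the graph dynamically
y.backward()          # autograd walks the graph to compute gradients
print(x.grad)         # tensor(8.)
```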

Strengths

  • Research-friendly: Intuitive, Pythonic syntax makes experimentation fast
  • Flexibility: Easy to write custom training loops and complex model architectures
  • Debugging: Eager execution makes debugging straightforward (inspect tensors directly)
  • Community: Explosive growth; now dominant in research and increasingly in industry
  • Ecosystem: Growing production tools (TorchServe, Lightning, Ray), excellent third-party libraries

Weaknesses

  • Production deployment: Historically required more work; TorchServe is newer and less battle-tested than TF Serving
  • Mobile: Limited mobile tooling (PyTorch Mobile / ExecuTorch) compared to TF Lite
  • Learning curve for production: Easy to prototype, harder to scale safely
  • Ecosystem fragmentation: Less integrated than TensorFlow (data pipelines, serving, monitoring are separate tools)

Used By

  • Meta (research, recommendations, content moderation)
  • Startups and researchers (OpenAI, Hugging Face, most academic labs)
  • Industry adoption rapidly growing (Tesla, Uber, Stripe, etc.)

Production Deployment: TorchServe

TorchServe is PyTorch’s production serving framework (introduced around 2020, now stable):

# Package model
torch-model-archiver --model-name my_model --version 1.0 --model-file model.py --serialized-file model.pt --handler my_handler.py

# Serve
torchserve --start --model-store model_store --models my_model.mar

Provides REST and gRPC endpoints, model versioning, and horizontal scaling.

Exporting to ONNX

PyTorch models can be exported to ONNX format for framework-agnostic deployment:

import torch
import torch.onnx

model = MyModel()
model.eval()  # switch to inference mode before export
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=14)

3b. PyTorch Detailed Implementation Examples

PyTorch Fundamentals

PyTorch uses dynamic computation graphs, making it feel like writing regular Python:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a model
class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create model, optimizer, loss
model = SimpleNet(784, 128, 10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Dummy dataset for illustration (substitute your real data)
x = torch.randn(512, 784)
y = torch.randint(0, 10, (512,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

# Training loop
for epoch in range(10):
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        
        # Forward pass
        logits = model(batch_x)
        loss = loss_fn(logits, batch_y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # Report the last batch's loss once per epoch
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Save model
torch.save(model.state_dict(), "model.pt")

# Load model
model.load_state_dict(torch.load("model.pt"))

PyTorch vs TensorFlow Code Comparison

Example 1: Custom Training Loop

TensorFlow:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=32)

PyTorch:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    for batch_x, batch_y in train_loader:
        logits = model(batch_x)
        loss = loss_fn(logits, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Key difference: PyTorch is explicit (you write the loop), TensorFlow abstracts it (Keras API).

Example 2: Convolutional Network

TensorFlow:

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

PyTorch:

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64*5*5, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
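The `64*5*5` flatten size follows from the layer arithmetic (28 → 26 after the 3×3 conv, → 13 after pooling, → 11 → 5). A quick shape check that mirrors the feature extractor above:

```python
import torch
import torch.nn as nn

# Mirror the CNN's feature extractor to confirm the 64*5*5 flatten size
conv1 = nn.Conv2d(1, 32, kernel_size=3)   # 28 -> 26
conv2 = nn.Conv2d(32, 64, kernel_size=3)  # 13 -> 11
pool = nn.MaxPool2d(2)                    # halves spatial size (floor)

x = torch.randn(1, 1, 28, 28)
x = pool(torch.relu(conv1(x)))  # (1, 32, 13, 13)
x = pool(torch.relu(conv2(x)))  # (1, 64, 5, 5)
print(x.shape)
```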

Example 3: Transfer Learning

TensorFlow:

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet'
)

# Freeze base layers
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

PyTorch:

import torch.nn as nn
from torchvision import models

base_model = models.mobilenet_v2(weights="IMAGENET1K_V1")  # 'pretrained=True' is deprecated

# Freeze base layers
for param in base_model.parameters():
    param.requires_grad = False

# Replace head
num_features = base_model.classifier[-1].in_features
base_model.classifier = nn.Sequential(
    nn.Linear(num_features, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
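When the backbone is frozen, it is idiomatic to hand the optimizer only the trainable parameters. A self-contained sketch (the tiny `backbone`/`head` modules are stand-ins; the real code above freezes MobileNetV2):

```python
import torch
import torch.nn as nn

# Tiny stand-in backbone (hypothetical; real code uses a frozen mobilenet_v2)
backbone = nn.Linear(16, 8)
for p in backbone.parameters():
    p.requires_grad = False

head = nn.Linear(8, 2)  # new, trainable classification head
model = nn.Sequential(backbone, head)

# Pass only the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
print(len(trainable))  # 2 (head weight + bias)
```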

4. ONNX Conversion Workflow

Why ONNX Matters

ONNX (Open Neural Network Exchange) lets you train in any framework and deploy everywhere. This is critical for production systems.

Complete Workflow: PyTorch → ONNX → Multiple Runtimes

import torch
import onnx
import onnxruntime as ort
import numpy as np
from torchvision import models

# Step 1: Load or train a PyTorch model
model = models.resnet50(weights="IMAGENET1K_V1")  # 'pretrained=True' is deprecated
model.eval()

# Step 2: Export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "resnet50.onnx",
    input_names=['input'],
    output_names=['output'],
    opset_version=14,  # ONNX opset version
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

# Step 3: Verify ONNX model is valid
onnx_model = onnx.load("resnet50.onnx")
onnx.checker.check_model(onnx_model)
print("✓ ONNX model is valid")

# Step 4: Run inference with ONNX Runtime
session = ort.InferenceSession("resnet50.onnx")

# Input must be numpy array
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Run inference
outputs = session.run(None, {"input": test_input})
print(f"Output shape: {outputs[0].shape}")
print(f"Top-5 class predictions: {np.argsort(outputs[0][0])[-5:]}")

# Step 5: Optimize ONNX for inference
# (onnxruntime.transformers.optimizer is transformer-specific; for a CNN,
#  use ONNX Runtime's generic graph optimizations instead)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.optimized_model_filepath = "resnet50_optimized.onnx"
ort.InferenceSession("resnet50.onnx", sess_options)  # writes the optimized model

# Step 6: Convert to platform-specific formats

# Convert to TensorRT (NVIDIA GPU)
import tensorrt as trt
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
parser.parse_from_file("resnet50.onnx")
config = builder.create_builder_config()
serialized_engine = builder.build_serialized_network(network, config)  # build_engine() is deprecated
# ... write serialized_engine to disk; deserialize with trt.Runtime at load time ...

# Convert to CoreML (iOS/macOS)
# Note: the old onnx-coreml converter is unmaintained; coremltools now
# converts directly from a traced PyTorch model instead of going through ONNX
import coremltools as ct
traced = torch.jit.trace(model, dummy_input)
coreml_model = ct.convert(traced, inputs=[ct.TensorType(shape=dummy_input.shape)])
coreml_model.save("ResNet50.mlpackage")

# Convert to TF Lite (Mobile/Edge): ONNX → TF SavedModel → TF Lite
from onnx_tf.backend import prepare
tf_rep = prepare(onnx_model)          # onnx_model was loaded in Step 3
tf_rep.export_graph("resnet50_tf")    # writes a SavedModel directory
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model("resnet50_tf")
with open("resnet50.tflite", "wb") as f:
    f.write(converter.convert())

TensorFlow → ONNX

import tensorflow as tf
import tf2onnx

# Build (or load) a Keras model
model = tf.keras.applications.ResNet50()

# Convert to ONNX with tf2onnx (writes resnet50.onnx directly)
spec = (tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, output_path="resnet50.onnx")

5. Framework Comparison Table: TensorFlow vs PyTorch vs JAX vs MLX

Head-to-Head Comparison

| Aspect | TensorFlow | PyTorch | JAX | MLX |
|---|---|---|---|---|
| Learning curve | Steep (large API) | Gentle (Python-like) | Very steep (functional) | Moderate (PyTorch-like) |
| Research agility | Good | Excellent (dominant) | Excellent (researchers love it) | Growing (specialized) |
| Production maturity | Mature, proven at scale | Improving (TorchServe good) | Immature (research only) | Emerging (Apple ecosystem) |
| Debugging | Graph execution tricky | Eager execution simple | Functional paradigm complex | Straightforward |
| Mobile/edge | Superior (TF Lite) | Limited (PyTorch Mobile) | Poor | Excellent (Apple native) |
| Deployment options | TF Serving, TF Lite, TF.js, TFX | TorchServe, ONNX | Custom | MLX native, ONNX export |
| Community size | Large, Google-backed | Largest, Meta + ecosystem | Growing, research-focused | Small, Apple ecosystem |
| Performance | XLA compilation, good | torch.compile, improving | XLA-compiled, excellent autodiff | Exceptional on M-series |
| GPU support | NVIDIA, TPU, AMD (via ROCm) | NVIDIA, AMD, Intel | GPU, TPU, CPU | Apple GPU + Neural Engine only |
| Training speed | Medium | Fast | Fast | Very fast on M-series |
| Inference speed | Fast | Medium | Fast once compiled (JIT warm-up cost) | Very fast on M-series |
| Memory efficiency | Good (quantization tools) | Good | Excellent (JAX tracing) | Excellent (unified memory) |
| Model distribution | SavedModel, Keras | TorchScript, ONNX | Checkpoints, ONNX | MLX format |
| For harnesses | Best if production + mobile | Best if flexibility needed | Best for research | Best on M-series hardware |
| Learning resources | Excellent (TF tutorials) | Excellent (PyTorch docs) | Good (research papers) | Growing (official docs) |

When to Use Each Framework

TensorFlow

Best for:

  • Production ML systems at large companies
  • Mobile/embedded deployment (TF Lite is unmatched)
  • End-to-end ML pipelines (TFX)
  • Teams that value stability over innovation

Typical use cases:

  • Google services (search, recommendations)
  • Mobile apps (TensorFlow Lite on Android/iOS)
  • Enterprise deployments (financial, healthcare)
  • Large-scale data processing

PyTorch

Best for:

  • Research and experimentation
  • Rapid prototyping
  • Flexible custom training loops
  • Startups and teams that value agility

Typical use cases:

  • Academic research
  • Startups building new architectures
  • Custom training logic
  • Integration with other libraries (Hugging Face, Lightning)

JAX

Best for:

  • Research requiring automatic differentiation of arbitrary Python code
  • Scientific computing
  • Numerical optimization
  • Teams comfortable with functional programming

Typical use cases:

  • Physics simulations
  • Advanced optimization research
  • Reinforcement learning (Brax)
  • Numerical analysis

MLX

Best for:

  • Apple Silicon optimization (M1, M2, M3, M4)
  • On-device inference on MacBooks/iPads
  • Teams already in Apple ecosystem
  • Maximum performance on M-series

Typical use cases:

  • Local LLM inference on MacBook
  • iOS/macOS ML apps
  • Apple developer teams
  • Edge inference on M-series chips

Decision Framework

Start here: What hardware will you run on?

Do you have Apple Silicon (M1/M2/M3/M4)?
├─ YES → MLX (best performance on M-series)
└─ NO → Continue

Do you need to deploy on mobile (iOS/Android)?
├─ YES → TensorFlow (TF Lite is essential)
└─ NO → Continue

Are you in research or academia?
├─ YES → PyTorch (the large majority of new research uses it)
└─ NO → Continue

Do you need production-grade ML pipelines?
├─ YES → TensorFlow (TFX is mature)
└─ NO → Continue

Do you have custom training requirements?
├─ YES → PyTorch
└─ NO → TensorFlow (safer choice)

Are you comfortable with functional programming?
├─ YES → JAX (if numerical optimization is core)
└─ NO → PyTorch or TensorFlow

RESULT: Choose based on your constraints above.

Quick Cheat Sheet

| Scenario | Framework | Why |
|---|---|---|
| Building a chatbot | PyTorch + Transformers | PyTorch dominates the NLP ecosystem |
| Mobile app (iOS/Android) | TensorFlow Lite | TF Lite has no equal |
| Apple laptop ML | MLX | 2-5x faster than alternatives on M-series |
| GCP/TPU training | TensorFlow | Google’s own framework |
| AWS/general cloud | PyTorch | Growing standard |
| Research paper | PyTorch | The large majority of papers use it |
| Enterprise ML pipeline | TensorFlow | TFX handles everything |
| Physics simulation | JAX | Automatic differentiation for ODEs/PDEs |
| Uncertain deployment | ONNX export | Export to ONNX, decide later |

TensorFlow vs PyTorch at a Glance

| Aspect | TensorFlow | PyTorch |
|---|---|---|
| Learning curve | Steeper, large API surface | Gentle, feels like Python |
| Research agility | Good, but PyTorch faster | Excellent, dominant in research |
| Production maturity | Mature, battle-tested at scale | Improving rapidly (TorchServe) |
| Debugging | Graph execution can be tricky | Eager execution, straightforward |
| Mobile/edge | Superior (TF Lite) | Limited (PyTorch Mobile) |
| Deployment options | TF Serving, TF Lite, TF.js, TFX | TorchServe, ONNX, limited native mobile |
| Community size | Large, Google-backed | Rapidly growing, Meta-backed, larger research community |
| Performance | XLA optimization, good | torch.compile, improving |
| Model distribution | SavedModel, Keras formats | TorchScript, ONNX |
| For harnesses | Better if production + mobile needed | Better if prototyping + flexible inference |

Bottom Line: TensorFlow or PyTorch

  • Choose TensorFlow if: Building production system with mobile deployment, need ecosystem maturity, optimizing for scale at large companies
  • Choose PyTorch if: Rapid prototyping, research-oriented work, flexible custom logic, industry startup/scale-up environment

5b. ONNX (Open Neural Network Exchange)

What It Is

ONNX is a standardized, language-agnostic format for representing machine learning models. It decouples model definition from framework and enables interoperability across the ML ecosystem.

Why ONNX Matters

The core promise: Train in any framework, deploy everywhere.

You can:

  • Train a model in PyTorch
  • Export to ONNX
  • Run on ONNX Runtime (CPU, GPU, various hardware accelerators)
  • Convert to CoreML for iOS, TensorRT for NVIDIA GPUs, etc.
  • Use the same model across web (ONNX.js), mobile, and server

Format & Operators

ONNX models are graphs of operators. Each operator is well-defined (matrix multiply, ReLU, convolution, attention, etc.). This standardization enables:

  • Framework-independent optimization
  • Hardware-specific optimization (TensorRT compiles ONNX for NVIDIA)
  • Cross-platform deployment

Runtimes

  • ONNX Runtime (Microsoft): CPU and GPU inference, multiple backends (CUDA, TensorRT, CPU, etc.)
  • TensorRT: NVIDIA’s GPU-optimized inference engine (ONNX → compiled GPU code)
  • CoreML: Apple’s native format (ONNX can be converted to CoreML)
  • ONNX.js: Browser-based inference
  • TensorFlow Lite: ONNX models can be converted to TF Lite format (via onnx-tf and the TFLiteConverter)

Conversion Paths

PyTorch model → torch.onnx.export() → model.onnx → ONNX Runtime (CPU/GPU)
                                                 → TensorRT parser → NVIDIA GPU
                                                 → coremltools (or direct PyTorch → CoreML) → iOS/macOS

TensorFlow model → tf2onnx → model.onnx → [same as above]

Practical Example: PyTorch → ONNX → Inference

import torch
import onnx
import onnxruntime as ort
import numpy as np

# Train or load model
model = MyModel()
model.eval()  # inference mode before export

# Export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=14, input_names=['input'], output_names=['output'])

# Verify ONNX model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Run inference
session = ort.InferenceSession("model.onnx")
input_data = {"input": np.random.randn(1, 3, 224, 224).astype(np.float32)}
outputs = session.run(None, input_data)

Use Cases

  • Maximum portability: One model format, multiple deployment targets
  • Hardware optimization: ONNX can be compiled to leverage specific hardware (GPUs, TPUs, NPUs)
  • Mixed-framework teams: PyTorch model, TensorFlow infrastructure, both use ONNX as bridge
  • Uncertain deployment: Export to ONNX, decide deployment platform later

6. MLX (Apple’s Framework)

What It Is

MLX is Apple’s machine learning framework, optimized for Apple Silicon (M1, M2, M3, M4 families). It leverages unified memory architecture to achieve efficiency gains impossible on traditional architectures.

Key Advantages on Apple Silicon

  • Unified memory: GPU and CPU share memory, eliminating expensive data copies
  • Performance: 2-5x faster than generic frameworks on equivalent M-series hardware
  • Energy efficiency: Better battery life on MacBook and iPad
  • Pythonic API: Similar to PyTorch, easy for researchers to adopt

API Style

MLX mimics NumPy and PyTorch:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Layers are defined once in __init__, as in PyTorch
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def __call__(self, x):
        x = nn.relu(self.fc1(x))
        return self.fc2(x)

model = MyModel()
optimizer = optim.Adam(learning_rate=1e-3)

Growing Adoption

  • Hugging Face Transformers (starting to include MLX support)
  • OpenAI and researchers increasingly targeting M-series
  • Better LLM inference on MacBook than on equivalent Linux boxes
  • Community libraries building around MLX (LM Studio, etc.)

Integration Strategy

  • For M-series development: Consider MLX for native performance
  • For cross-platform: Export PyTorch model to ONNX, run on ONNX Runtime (less efficient on M-series but portable)
  • Hybrid: Use MLX for local development, PyTorch/ONNX for server deployment

7. Hugging Face Ecosystem

The Hub: 2M+ Models, One Interface

Hugging Face’s Transformers library has become the distribution standard for:

  • Large language models (LLaMA, Mistral, GPT-2, etc.)
  • Vision models (ViT, DINO, CLIP, etc.)
  • Multimodal models (BLIP-2, LLaVA, etc.)

Transformers Library

Built on PyTorch and TensorFlow, Transformers abstracts framework details:

from transformers import pipeline

# One line to load any model and run inference
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
result = classifier("I love this!")

The same code works for:

  • Different model architectures
  • Different frameworks (PyTorch/TensorFlow backend handled automatically)
  • Different sizes (distilled models vs. large models)

Model Hub

  • Hosted at huggingface.co
  • 2M+ models (LLMs, vision, audio, multimodal)
  • Model cards with documentation, usage examples, licensing
  • Integration with popular frameworks and runtimes
  • Community contributions and model updates

Fine-Tuning: The Trainer API

Simplified training loop:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

trainer.train()

Inference Options

  • Local inference: Load model and run (PyTorch/TensorFlow)
  • Hugging Face Inference API: Serverless inference endpoint
  • Hugging Face Spaces: Deploy interactive demos
  • ONNX export: Convert to ONNX for portability

8. Framework Integration in Harnesses

What Is a Harness?

A harness is a coordination layer that orchestrates AI reasoning. It typically includes:

  • LLM backend: Model for generating responses (Claude, GPT-4, or local)
  • Orchestration logic: Loop, state management, tool calling
  • Tools/Functions: Domain-specific functions the LLM can invoke
  • Memory/Context: Maintaining conversation history and state

Role of ML Frameworks

The framework choice depends on which model backend you use:

  1. API-based (Claude, OpenAI): No framework needed; you call HTTP endpoints
  2. Local open-source model: Framework required to load and run the model
  3. Hybrid: Fast local model for screening, Claude API for complex reasoning

Example: Harness with Claude API (No Framework)

import anthropic

client = anthropic.Anthropic()

def harness_loop(user_input):
    response = client.messages.create(
        model="claude-sonnet-4",
        max_tokens=1024,
        messages=[{"role": "user", "content": user_input}]
    )
    return response.content[0].text

Framework needed: None. Anthropic handles model execution.

Example: Harness with Local PyTorch Model

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load open-source model
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

def harness_loop(user_input):
    inputs = tokenizer.encode(user_input, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

Framework needed: PyTorch + Transformers.

Example: Harness with TensorFlow Model

import tensorflow as tf

# Load model
model = tf.keras.models.load_model("my_classifier.h5")

def harness_loop(user_input):
    # Preprocessing
    processed = preprocess(user_input)
    
    # Inference
    predictions = model.predict(processed)
    
    return format_output(predictions)

Framework needed: TensorFlow.


9. Running Models in Harnesses: Architecture Patterns

Pattern 1: API-Based (Simplicity, Managed Quality)

Use: Claude API, OpenAI API, or Hugging Face Inference API

import anthropic

class HarnessWithAPI:
    def __init__(self):
        self.client = anthropic.Anthropic()
    
    def reason(self, query):
        response = self.client.messages.create(
            model="claude-sonnet-4",
            max_tokens=2048,
            messages=[{"role": "user", "content": query}]
        )
        return response.content[0].text

Pros: Simplicity, no infrastructure, always up-to-date, handled by provider
Cons: Latency, API costs, internet dependency, no offline capability
Best for: Most applications, when you want Claude’s intelligence


Pattern 2: Local Model (Research, Privacy, Offline)

Use: Open-source LLM running locally via PyTorch/ONNX

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class HarnessWithLocalModel:
    def __init__(self, model_name="mistralai/Mistral-7B"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, 
            torch_dtype=torch.float16,
            device_map="auto"  # Auto GPU placement
        )
    
    def reason(self, query):
        inputs = self.tokenizer.encode(query, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.95
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

Pros: Offline, privacy, control, no API costs
Cons: GPU requirements, slower than proprietary models, setup complexity, memory usage
Best for: Privacy-critical, offline-first, research, edge deployment


Pattern 3: Hybrid (Best of Both Worlds)

Use: Fast local classifier for routing + Claude API for complex reasoning

import torch
import anthropic
from transformers import pipeline

class HybridHarness:
    def __init__(self):
        # Fast local classifier for initial routing
        self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)
        
        # Claude for complex reasoning
        self.claude = anthropic.Anthropic()
    
    def reason(self, query):
        # Quick classification
        candidates = ["simple_qa", "complex_reasoning", "code_generation"]
        classification = self.classifier(query, candidates)
        
        # Route based on complexity
        if classification["scores"][0] > 0.8 and classification["labels"][0] == "simple_qa":
            # Use simple local model for fast answers
            return self._simple_answer(query)
        else:
            # Use Claude for complex reasoning
            return self._claude_reasoning(query)
    
    def _simple_answer(self, query):
        # Could use a smaller, faster local model
        pass
    
    def _claude_reasoning(self, query):
        response = self.claude.messages.create(
            model="claude-sonnet-4",
            max_tokens=2048,
            messages=[{"role": "user", "content": query}]
        )
        return response.content[0].text

Pros: Speed, quality, controlled costs, offline-capable for simple cases
Cons: Complexity, maintenance, careful routing logic needed
Best for: Production systems where cost and latency matter


10. TensorFlow Serving

What It Is

TensorFlow Serving is Google’s production-grade framework for serving models. It’s optimized for:

  • High throughput
  • Low latency
  • Model versioning (serve multiple versions simultaneously)
  • Canary deployments and A/B testing

Architecture

Client → REST/gRPC endpoint → Model Server → Model Registry → TensorFlow/ONNX models
                                   │
                          Hardware acceleration
                            (GPU, TPU, CPU)

Key Features

  • REST and gRPC APIs: Different trade-offs (REST is simpler, gRPC is faster)
  • Model versioning: Serve v1, v2, v3 simultaneously; route traffic between versions
  • Batching: Automatic request batching for efficiency
  • Monitoring and logging: Built-in observability
  • Horizontal scaling: Stateless design enables easy scaling

Example: Serving a TensorFlow Model

# Save model in SavedModel format (Python)
tf.saved_model.save(model, "model/1")

# Start TensorFlow Serving
docker run -t --rm -p 8501:8501 \
  -v "$(pwd)/model:/models/my_model" \
  -e MODEL_NAME=my_model \
  tensorflow/serving

# Query via REST
curl -X POST http://localhost:8501/v1/models/my_model:predict \
  -d '{"instances": [[[1,2,3,4]]]}'
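The same `:predict` request can be issued from Python with the standard library. A sketch assuming the docker server above is listening on localhost:8501 (the instance shape depends on your model's input signature):

```python
import json
import urllib.request

# Build TF Serving's REST :predict request body
payload = json.dumps({"instances": [[1, 2, 3, 4]]}).encode()
request = urllib.request.Request(
    "http://localhost:8501/v1/models/my_model:predict",
    data=payload,
    headers={"Content-Type": "application/json"},
)
# With a live server:
# predictions = json.loads(urllib.request.urlopen(request).read())["predictions"]
print(json.loads(payload))
```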

Use Cases

  • Production systems requiring sub-100ms latency
  • High-volume inference (millions of requests/day)
  • Model versioning and canary testing
  • GPU-optimized inference at scale

11. TorchServe

What It Is

TorchServe is PyTorch’s production serving framework. Modeled partly on TensorFlow Serving, it provides REST/gRPC endpoints, model versioning, and batching.

Status (2026): Stable and production-ready, though with a smaller production footprint than TF Serving.

Key Features

  • REST and gRPC endpoints: Similar to TF Serving
  • Model versioning: Serve and route between model versions
  • Batching: Automatic request batching
  • Model zoo: Pre-packaged models for common tasks
  • Metrics and logging: Prometheus integration

Packaging and Serving

Models are packaged in .mar (Model Archive) format:

# Package model
torch-model-archiver \
  --model-name my_model \
  --version 1.0 \
  --model-file model.py \
  --serialized-file model.pt \
  --handler my_handler.py

# Start TorchServe
torchserve --start --model-store model_store --models my_model.mar

# Query
curl http://localhost:8080/predictions/my_model -T input.json

Handler Example

Custom inference logic:

from ts.torch_handler.base_handler import BaseHandler
import torch

class MyHandler(BaseHandler):
    def preprocess(self, data):
        # Extract the JSON body of each request and convert to a tensor
        body = data[0].get("body") or data[0].get("data")
        return torch.tensor(body["inputs"], dtype=torch.float32)

    def inference(self, data):
        # Run the model without tracking gradients
        with torch.no_grad():
            return self.model(data)

    def postprocess(self, data):
        # TorchServe expects a list with one entry per request in the batch
        return data.tolist()

Use Cases

  • Production PyTorch models
  • When you prefer PyTorch’s ecosystem
  • Research teams transitioning to production

12. LLM-Specific Frameworks

For LLM serving specifically, specialized frameworks abstract away framework details and provide optimizations:

vLLM

High-performance LLM serving engine:

  • PagedAttention optimization (efficient memory use)
  • Continuous batching
  • Supports many open-source LLMs
  • OpenAI-compatible REST API

python -m vllm.entrypoints.openai.api_server \
  --model mistralai/Mistral-7B-Instruct-v0.1 \
  --gpu-memory-utilization 0.9

Ollama

Simplest approach to running LLMs locally:

ollama run mistral
ollama run llama2
ollama run phi

Provides REST API, works on CPU and GPU, easy for beginners.
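That REST API listens on port 11434 by default. A hedged sketch of a non-streaming call to Ollama's `/api/generate` endpoint using only the standard library:

```python
import json
import urllib.request

def build_generate_payload(prompt, model="mistral"):
    """Request body for Ollama's /api/generate endpoint (non-streaming)."""
    return {"model": model, "prompt": prompt, "stream": False}

def ollama_generate(prompt, model="mistral", host="http://localhost:11434"):
    """POST a prompt to a locally running Ollama server and return the text."""
    data = json.dumps(build_generate_payload(prompt, model)).encode()
    req = urllib.request.Request(
        f"{host}/api/generate",
        data=data,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["response"]
```

With `"stream": False` the server returns one JSON object; leaving streaming on yields newline-delimited JSON chunks instead.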

LM Studio

GUI-based local LLM runner:

  • Download models from Hugging Face
  • Run inference with slider controls
  • REST API for programmatic use
  • No terminal required

Text Generation WebUI

Feature-rich open-source interface:

  • Multiple backends (transformers, vLLM, GPTQ, etc.)
  • Chat interface, character roles
  • Extensions and plugins
  • API endpoint

Integration in Harnesses

These frameworks abstract the model backend, so your harness code doesn’t change:

# Works with any framework
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "mistral",
        "messages": [{"role": "user", "content": "What is 2+2?"}],
        "temperature": 0.7
    }
)

Whether using vLLM, Ollama, or LM Studio, the interface is consistent.


13. Building a Harness with TensorFlow or PyTorch

Scenario: Custom Text Classification Model

Goal: Build a harness that classifies user intent and routes to appropriate handler.

Implementation with PyTorch

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import anthropic

class IntentClassifier(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Linear(768, num_classes)
    
    def forward(self, input_ids, attention_mask):
        encoded = self.encoder(input_ids, attention_mask).last_hidden_state
        cls_token = encoded[:, 0]  # [CLS] token
        return self.classifier(cls_token)

class ClassificationHarness:
    def __init__(self, model_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = IntentClassifier().to(self.device)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.claude = anthropic.Anthropic()
        
        self.intents = ["greeting", "question", "request", "feedback", "error"]
    
    def process(self, user_input):
        # Classify intent
        with torch.no_grad():
            inputs = self.tokenizer(
                user_input,
                return_tensors="pt",
                max_length=128,
                padding=True,
                truncation=True
            ).to(self.device)
            
            logits = self.model(inputs["input_ids"], inputs["attention_mask"])
            intent_idx = logits.argmax(dim=1).item()
            intent = self.intents[intent_idx]
        
        # Route based on intent
        if intent == "greeting":
            return self._handle_greeting(user_input)
        elif intent == "error":
            return self._handle_error(user_input)
        else:
            return self._handle_with_claude(user_input, intent)
    
    def _handle_greeting(self, text):
        return "Hello! How can I help you today?"
    
    def _handle_error(self, text):
        return "I detected an error report. Let me escalate this."
    
    def _handle_with_claude(self, text, intent):
        prompt = f"User intent: {intent}\nUser message: {text}\nRespond helpfully:"
        response = self.claude.messages.create(
            model="claude-sonnet-4",
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}]
        )
        return response.content[0].text

# Usage
harness = ClassificationHarness("model.pt")
result = harness.process("I'm having trouble logging in")

Implementation with TensorFlow

import tensorflow as tf
import anthropic
from transformers import AutoTokenizer, TFAutoModel

class IntentClassifierTF(tf.keras.Model):
    def __init__(self, num_classes=5):
        super().__init__()
        self.encoder = TFAutoModel.from_pretrained("bert-base-uncased")
        self.dropout = tf.keras.layers.Dropout(0.1)
        self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax')
    
    def call(self, inputs, training=False):
        encoded = self.encoder(inputs).pooler_output  # pooled [CLS] representation
        x = self.dropout(encoded, training=training)
        return self.classifier(x)

class ClassificationHarnessTF:
    def __init__(self, model_path):
        self.model = IntentClassifierTF()
        self.model.load_weights(model_path)  # subclassed Keras models restore weights, not whole models
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.claude = anthropic.Anthropic()
        self.intents = ["greeting", "question", "request", "feedback", "error"]
    
    def process(self, user_input):
        # Classify intent
        inputs = self.tokenizer(
            user_input,
            return_tensors="tf",
            max_length=128,
            padding=True,
            truncation=True
        )
        
        logits = self.model(inputs)
        intent_idx = tf.argmax(logits, axis=1).numpy()[0]
        intent = self.intents[intent_idx]
        
        # Route based on intent
        if intent == "greeting":
            return "Hello! How can I help you today?"
        elif intent == "error":
            return "I detected an error report. Let me escalate this."
        else:
            prompt = f"User intent: {intent}\nUser message: {user_input}\nRespond helpfully:"
            response = self.claude.messages.create(
                model="claude-sonnet-4",
                max_tokens=1024,
                messages=[{"role": "user", "content": prompt}]
            )
            return response.content[0].text

# Usage
harness = ClassificationHarnessTF("model.h5")
result = harness.process("I'm having trouble logging in")

Key Observations

Both implementations:

  • Load a pretrained model (BERT)
  • Classify intent
  • Route to appropriate handler (local logic or Claude)
  • Use the same orchestration pattern

Differences:

  • PyTorch: Explicit with torch.no_grad(), .argmax().item() for scalar extraction
  • TensorFlow: Implicit inference mode, .argmax() returns tensor, .numpy() for conversion
  • API differs, but patterns are identical

14. Performance & Optimization

TensorFlow Optimization

XLA Compilation

model = tf.keras.Model(...)
# Compile with XLA
model.compile(optimizer='adam', loss='mse', jit_compile=True)

Quantization (reduce model size and latency)

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model(model)
quantize_model.compile(optimizer='adam', loss='mse')

SavedModel format (optimized for serving)

model.save('model', save_format='tf')  # Saved as SavedModel
# TensorFlow Serving can optimize this further

PyTorch Optimization

torch.compile (new in PyTorch 2.0, automatic graph optimization)

model = MyModel()
compiled_model = torch.compile(model)  # JIT compile to optimized kernel
y = compiled_model(x)  # Faster inference

Quantization

import torch.quantization as quantization

model_fp32 = MyModel()
model_int8 = quantization.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},
    dtype=torch.qint8
)

TorchScript (export optimized model)

scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'model.pt')

Cross-Framework Optimization via ONNX

# Export to ONNX (framework-agnostic optimization)
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=14)

# Optimize with ONNX Runtime's transformer optimizer
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "model.onnx", model_type="bert", num_heads=12, hidden_size=768
)
opt_model.save_model_to_file("model_optimized.onnx")

Common Optimization Techniques

  1. Quantization: 4-8x reduction in model size, 2-4x faster inference
  2. Pruning: Remove less important connections, smaller models
  3. Distillation: Train small student model to mimic large teacher model
  4. Batching: Amortize overhead across multiple requests
  5. Hardware acceleration: GPU/TPU/NPU vs. CPU
  6. Model architecture: Faster architectures (MobileNet, DistilBERT, TinyLLM)
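Of these, distillation is the least mechanical: the standard objective blends soft teacher targets (temperature-scaled KL divergence) with ordinary cross-entropy on the hard labels. A PyTorch sketch; the temperature `T` and mixing weight `alpha` are tunable assumptions, not fixed values:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Standard knowledge-distillation objective (Hinton-style)."""
    # Soft targets: KL divergence between temperature-scaled distributions,
    # rescaled by T^2 to keep gradient magnitudes comparable
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: ordinary cross-entropy against the true labels
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
```

During training, the teacher runs in eval mode under `torch.no_grad()` and only the student's parameters are updated.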

Benchmarking

import time
import statistics

# Warm up first (initial calls pay one-time compilation and cache costs)
for _ in range(10):
    model(input_data)

# Measure per-request latency
latencies = []
for _ in range(100):
    start = time.perf_counter()
    output = model(input_data)
    latencies.append(time.perf_counter() - start)

print(f"p50: {statistics.median(latencies) * 1000:.1f} ms")
print(f"p95: {statistics.quantiles(latencies, n=20)[18] * 1000:.1f} ms")

# Measure memory
import psutil
process = psutil.Process()
memory = process.memory_info().rss / 1024 / 1024  # MB

15. Deployment Paths

Path 1: Framework → Container → Cloud

Use: TensorFlow, PyTorch models in production cloud infrastructure

FROM python:3.11-slim

WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt

COPY model.pt .
COPY serve.py .

EXPOSE 8000
CMD ["python", "serve.py"]

Deployment:

docker build -t my-model:latest .
docker push gcr.io/my-project/my-model:latest

# Deploy to Cloud Run, Kubernetes, etc.

Best for: Flexible, portable, use any framework


Path 2: Framework → CoreML → iOS

Use: Model on iPhone/iPad

import torch
import coremltools as ct

# Train or load PyTorch model (the older onnx-coreml converter is deprecated;
# coremltools now converts TorchScript directly)
model = MyModel().eval()
traced = torch.jit.trace(model, dummy_input)

# Convert to CoreML
coreml_model = ct.convert(traced, inputs=[ct.TensorType(shape=dummy_input.shape)])
coreml_model.save("MyModel.mlpackage")

iOS integration:

import CoreML

let model = try MyModel(configuration: MLModelConfiguration())
let input = MyModelInput(image: cgImage)
let output = try model.prediction(input: input)

Best for: On-device inference, privacy, offline capability


Path 3: Framework → Hugging Face → Hugging Face Inference API

Use: Serverless inference, no infrastructure management

from transformers import pipeline

# Load locally
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")

# Or use Hugging Face Inference API
from huggingface_hub import InferenceClient

client = InferenceClient(token="hf_xxxx")
response = client.text_classification("I love this movie!")

Benefits: Auto-scaling, managed by Hugging Face, pay-per-use


Path 4: Framework → Quantized Model → Edge Device

Use: Extremely resource-constrained devices (Raspberry Pi, microcontroller)

# Quantize model
import torch
import torch.quantization as quantization
model_int8 = quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Script, then optimize (optimize_for_mobile expects a TorchScript module)
from torch.utils.mobile_optimizer import optimize_for_mobile
scripted = torch.jit.script(model_int8)
optimized_model = optimize_for_mobile(scripted)

# Save and deploy to edge
torch.jit.save(optimized_model, "model_edge.pt")

Embedded deployment: TensorFlow Lite, ONNX Runtime, or even TinyML frameworks


Path 5: LLM → vLLM/Ollama → Local/Docker → Harness

Use: Running open-source LLMs locally or in containers

# Option A: Ollama (simplest)
ollama run mistral:7b

# Option B: vLLM (optimized)
python -m vllm.entrypoints.openai.api_server \
  --model mistralai/Mistral-7B-Instruct-v0.1 \
  --gpu-memory-utilization 0.9

# Option C: Docker
docker run --gpus all -p 8000:8000 \
  vllm/vllm-openai:latest \
  --model mistralai/Mistral-7B-Instruct-v0.1

Harness integration:

import requests

response = requests.post("http://localhost:8000/v1/chat/completions", json={
    "model": "mistral",
    "messages": [{"role": "user", "content": "What is 2+2?"}]
})

Choosing a Framework and Deployment Path

Decision Tree

  1. Do you need an LLM?

    • Yes → Use Claude API (recommended) or open-source via vLLM/Ollama
    • No → Continue to 2
  2. Is your model a transformer (NLP/vision)?

    • Yes → Use Hugging Face Transformers library
    • No → Continue to 3
  3. What’s your deployment target?

    • Production cloud: TensorFlow Serving or TorchServe
    • iOS/Mobile: TensorFlow Lite or CoreML (via coremltools)
    • Web browser: ONNX.js or TensorFlow.js
    • Edge device: ONNX Runtime or TensorFlow Lite
    • Uncertain: Export to ONNX, stay flexible
  4. What’s your team’s background?

    • Research-focused: PyTorch
    • Production-focused: TensorFlow
    • Multi-framework team: ONNX
  5. Do you need offline capability?

    • Yes → Local model (PyTorch/TensorFlow) or edge deployment
    • No → API-based (Claude, Hugging Face Inference API)
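The decision tree above can be encoded as a small routing helper. A sketch; the labels and the short target keys (`"cloud"`, `"ios"`, and so on) are illustrative assumptions, not an exhaustive policy:

```python
def recommend(needs_llm=False, is_transformer=False, target=None, offline=False):
    """Toy encoding of the framework decision tree."""
    if needs_llm:
        # Hosted API unless offline capability is required
        return "vLLM/Ollama" if offline else "Claude API"
    if is_transformer:
        return "Hugging Face Transformers"
    # Otherwise the deployment target drives the choice
    return {
        "cloud": "TF Serving / TorchServe",
        "ios": "TF Lite / CoreML",
        "web": "ONNX.js / TensorFlow.js",
        "edge": "ONNX Runtime / TF Lite",
    }.get(target, "Export to ONNX, stay flexible")
```

Note the fallback: when the target is unknown, the tree's own advice applies and ONNX keeps the choice open.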

Conclusion

The ML framework landscape in 2026 offers mature tools for any deployment scenario:

  • Rapid prototyping: PyTorch + Hugging Face
  • Production systems: TensorFlow Serving or TorchServe
  • Maximum portability: ONNX
  • Apple Silicon optimization: MLX
  • Serverless LLM inference: Hugging Face Inference API or OpenAI API
  • Local LLM running: vLLM, Ollama, LM Studio

For harnesses, the choice is between API-based (simplicity, reliability) and local models (control, privacy, offline). Most production harnesses are hybrid: use local classifiers for fast routing, Claude API for complex reasoning.

The key is understanding your constraints (latency, privacy, cost, expertise) and choosing accordingly. ONNX provides an escape hatch—export early, stay flexible.


Validation Checklist

How do you know you got this right?

Performance Checks

  • Inference latency benchmarked on target deployment hardware (CPU, GPU, or Apple Silicon) with realistic batch sizes
  • ONNX export validated: output from ONNX Runtime matches original framework output within tolerance (cosine similarity >0.99)
  • Model serving throughput measured under concurrent load (requests/second at target latency SLA)

Implementation Checks

  • Framework selected using the decision tree (Section 5): matches your deployment target, team skills, and hardware
  • Model exports cleanly to ONNX format with onnx.checker.check_model() passing without errors
  • Conversion path tested end-to-end: train in framework -> export ONNX -> run in target runtime (ONNX Runtime, CoreML, TensorRT)
  • Serving infrastructure chosen: TF Serving for TensorFlow models, TorchServe for PyTorch, vLLM/Ollama for LLMs
  • Quantization applied and validated: int8 model produces acceptable quality on 50+ test inputs
  • API vs local model decision made and justified: API for simplicity and quality, local for privacy and offline
  • Hybrid routing logic implemented if using both local and cloud models (classifier routes simple queries locally, complex to API)

Integration Checks

  • Harness inference loop works with chosen framework: model loads, generates output, and returns results through the harness API
  • Model versioning strategy defined: can deploy new model version without downtime (A/B testing or canary rollout)
  • Framework dependencies documented and pinned: specific versions of PyTorch/TensorFlow/ONNX Runtime recorded in requirements

Common Failure Modes

  • ONNX export fails on dynamic shapes: Model uses variable-length inputs that ONNX can’t trace. Fix: set dynamic_axes parameter in torch.onnx.export(), or use fixed input shapes with padding.
  • Framework mismatch: PyTorch model won’t load in TensorFlow ecosystem. Fix: export to ONNX first, then convert to target format (TF Lite, CoreML, TensorRT).
  • Serving latency spikes under load: Model inference is fast but serving adds overhead. Fix: enable request batching in TF Serving/TorchServe, optimize batch size for throughput vs latency trade-off.
  • torch.compile breaks on custom operations: Non-standard layers fail JIT compilation. Fix: isolate custom ops, use torch.compile only on standard submodules, or fall back to eager mode for problematic layers.

Sign-Off Criteria

  • Can reproduce model training and inference on a clean machine using documented framework versions
  • Deployment path validated end-to-end: from trained model to serving endpoint with measurable latency and throughput
  • Fallback strategy defined: what happens if primary serving framework fails (restart, fall back to simpler model, route to API)
  • Framework lock-in assessed: ONNX export tested to confirm you can switch frameworks if needed
  • Cost per inference calculated: cloud API cost vs local hardware amortization for your expected query volume
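The cost comparison in the last item is simple break-even arithmetic. A sketch; all prices below are placeholder assumptions:

```python
def breakeven_queries_per_month(api_cost_per_query, hw_cost_per_month,
                                local_cost_per_query=0.0):
    """Monthly query volume above which local hardware beats the API."""
    saving = api_cost_per_query - local_cost_per_query
    if saving <= 0:
        return float("inf")  # local serving never pays off per-query
    return hw_cost_per_month / saving

# e.g. a $0.003/query API vs a $600/month GPU box breaks even
# around 200,000 queries/month
```

Below that volume the API is cheaper even before counting the engineering time spent operating local serving.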

See Also

  • Doc 03 (Hugging Face Ecosystem) — Most models are PyTorch-native; HF provides export utilities and deployment options
  • Doc 01 (Foundation Models) — Framework choice affects which models you can use and how easily you can deploy them
  • Doc 12 (Deployment Patterns) — Frameworks like TensorFlow Serving and TorchServe are deployed in the patterns shown in this doc
  • Doc 23 (Apple Intelligence & CoreML) — Mobile deployment requires conversion; frameworks must support ONNX or CoreML export