Network-accelerated Distributed Machine Learning for Multi-Tenant Settings

https://dl.acm.org/doi/pdf/10.1145/3419111.3421296

Presentation

  • Background:

    • ML is becoming ubiquitous in software industry

    • ML models are trained with large volumes of data typically over shared infrastructure

    • DML

      • More compute cycles

      • Process lots of distributed data

      • But requires frequent synchronization (e.g., SGD synchronizes after each iteration)

    • Architecture

      • Parameter server: Google DistBelief, Microsoft Adam

        • Async vs. sync

      • P2P MPI

        • Synchronous

  • Key factors affecting performance of distributed ML training

    • Async vs. sync SGD

    • A-SGD:

      • Staleness: the model version used to compute a gradient can differ from the model version being updated when that gradient is applied

      • Communication intensive

        • Compute update (~100ms)

        • Fetch model (~600ms)

        • R1: Reduce network load at the server (see the worker-loop sketch below)
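
A minimal sketch of the A-SGD worker loop these numbers describe, assuming hypothetical `fetch_model` / `push_update` RPCs to a parameter server. The point is that moving the model (~600 ms) dominates computing the update (~100 ms), so the server's NIC is the bottleneck that R1 targets.

```python
import time

# Hypothetical stand-ins for the parameter-server RPCs; the names and the
# ~600 ms / ~100 ms split are illustrative, taken from the notes above.
def fetch_model(server):
    return server.model.copy(), server.version   # network-bound, ~600 ms for large models

def push_update(server, grad, version):
    server.apply(grad, version)                   # network-bound again

def asgd_worker(server, data_iter, compute_gradient):
    while True:
        t0 = time.time()
        model, version = fetch_model(server)              # ~600 ms
        t1 = time.time()
        grad = compute_gradient(model, next(data_iter))   # ~100 ms
        t2 = time.time()
        push_update(server, grad, version)
        # Most of each iteration is spent moving the model over the server's
        # NIC, which is why R1 asks to reduce network load at the server.
        print(f"fetch: {t1 - t0:.3f}s, compute: {t2 - t1:.3f}s")
```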

      • Stragglers affect convergence

        • Halving the bandwidth of 10% of the workers --> ~35% more iterations to convergence

        • R2: Bound delays in the presence of stragglers (see the delay-bound sketch below)
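
A toy illustration of what bounding the delay means: each update carries the model version it was computed on, and the server refuses updates staler than a bound. The `BoundedDelayServer` class and the bound `D` are assumptions for illustration; MLFabric enforces the bound by ordering and dropping transfers in the network rather than at the server, so this only sketches the invariant, not the mechanism.

```python
D = 8  # assumed delay bound (in model versions); illustrative value

class BoundedDelayServer:
    """Toy A-SGD parameter server that tracks its model version and
    rejects updates that are staler than the bound D."""
    def __init__(self, model, lr=0.1):
        self.model, self.lr, self.version = list(model), lr, 0

    def apply(self, grad, computed_on_version):
        # Delay = versions the server advanced since the worker read the model.
        delay = self.version - computed_on_version
        if delay > D:
            return False  # straggler update: too stale, would hurt convergence
        self.model = [w - self.lr * g for w, g in zip(self.model, grad)]
        self.version += 1
        return True
```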

      • Fault tolerance has huge overhead

        • Chain replication (the server forwards every worker update to its replica)

          • The server's outgoing NIC:

            • Carries both models to the workers

            • And model updates to the replica

        • Alternative: workers forward updates directly to the replica

          • Asynchrony + a stateful update rule lead to server/replica inconsistency (toy example below)

        • R3: Non-divergent replication without server overhead
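
A toy example (not from the paper) of why forwarding updates directly to the replica breaks down when the server's update rule is stateful: with a per-update learning-rate decay, applying the same two updates in a different order yields different models, so an asynchronously reordered replica diverges from the server.

```python
def apply_sequence(updates, lr0=1.0, decay=0.5):
    """Stateful update rule: the learning rate decays with every applied
    update, so the final model depends on the order of application."""
    w, lr = 0.0, lr0
    for g in updates:
        w -= lr * g
        lr *= decay
    return w

server_order  = [1.0, 4.0]   # order in which the server applied the updates
replica_order = [4.0, 1.0]   # asynchrony: the replica saw them reordered

print(apply_sequence(server_order))   # -3.0
print(apply_sequence(replica_order))  # -4.5 -> server and replica diverge (R3)
```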

    • S-SGD

      • With PS architecture

        • Server NIC overload is a bottleneck for large models

        • Stragglers increase time per-iteration

      • MPI architecture

        • Ring-reduce algorithms are typically used, but they assume a homogeneous network

        • Compute & network stragglers increase time per-iteration

      • R4: Bandwidth-aware aggregation (back-of-the-envelope comparison below)
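
A back-of-the-envelope sketch (illustrative numbers and deliberate simplifications, e.g., aggregator ingress is ignored) of why R4 matters: a ring reduce is paced by the slowest link, while letting each worker ship its update toward an aggregator at its own rate avoids dragging every worker down to the straggler's bandwidth.

```python
MODEL_MB = 1000.0
# Illustrative per-worker available bandwidth (MB/s); one straggler link.
bandwidth = {"w0": 1000.0, "w1": 1000.0, "w2": 1000.0, "w3": 100.0}

def ring_allreduce_time(bw, size_mb):
    # A ring reduce moves ~2*(n-1)/n of the model over every link,
    # so the slowest link sets the pace for all workers.
    n = len(bw)
    return 2 * (n - 1) / n * size_mb / min(bw.values())

def bandwidth_aware_time(bw, size_mb):
    # Crude sketch: each worker sends its update at its own rate; the
    # iteration waits only for the slowest single transfer.
    return max(size_mb / b for b in bw.values())

print(ring_allreduce_time(bandwidth, MODEL_MB))   # ~15 s
print(bandwidth_aware_time(bandwidth, MODEL_MB))  # ~10 s
```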

Existing works

  • R1

    • Infrequent updates to the server

      • Workers aggregate locally, then transmit to server

      • But: cannot aggregate updates across multiple workers

  • R2

    • Dropping delayed updates

      • Delayed updates are still transmitted to the server anyway, so this does not reduce server load

  • R3

    • No prior work

  • R4

    • Hierarchical AllReduce, but it assumes a static bandwidth setting

MLFabric

  • Contributions

    • Prove bounded delay helps convergence

    • Network aware ordering and aggregation of updates

      • Helps bound delay

      • Reduce network load at parameter server

    • Model replication strategy

      • Guarantee bounded consistency between the server and its replica

  • Network aware ordering and aggregation of updates

    • Re-ordering updates

      • Buffering updates

        • Bounded delay (R2) at the cost of stale reads

      • Time-shared schedules

        • Updates scheduled for later can be aggregated off-server

        • SJF + Time-sharing satisfies R1 and R2

  • Architecture

  • Ordering and Aggregation algorithm

    • Scheduler batches transfer requests to server

      • Information: model version, size of the update, norm of gradients

    • Orders them iteratively in shortest-job-first (SJF) fashion toward the server

      • Completion times are estimated in a network-aware fashion

      • State update in each iteration

      • Updates that cannot meet their deadline are dropped

    • Queued updates are grouped, aggregated and then sent to server

      • Minimize total completion time

      • The aggregator is chosen randomly (see the scheduling sketch below)
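
A simplified sketch of one scheduling round as the bullets above describe it: requests are ordered shortest-job-first by bandwidth-aware completion-time estimates, updates that would miss their deadline are dropped, and updates queued behind the head of the line are grouped for off-server aggregation at a randomly chosen idle worker. All names, the `deadline` parameter, and the one-transfer-at-a-time model are assumptions for illustration.

```python
import random

def schedule(requests, deadline, idle_workers):
    """Sketch of one scheduling round. Each request is a dict with 'worker',
    'size_mb', 'bw_mbps' (available bandwidth toward the server), and 'version'.
    Returns (send_directly, (aggregator, to_aggregate), dropped)."""
    # 1. Shortest-job-first: order updates by estimated transfer time.
    ordered = sorted(requests, key=lambda r: r["size_mb"] / r["bw_mbps"])

    send_directly, to_aggregate, dropped = [], [], []
    clock = 0.0  # the server NIC is time-shared: one transfer at a time
    for r in ordered:
        finish = clock + r["size_mb"] / r["bw_mbps"]
        if finish > deadline:
            dropped.append(r)          # cannot meet its deadline -> dropped (R2)
        elif not send_directly:
            send_directly.append(r)    # head of the queue goes straight to the server
            clock = finish
        else:
            to_aggregate.append(r)     # queued behind: combine off-server first (R1);
            clock = finish             # in the real system these arrive as one combined transfer

    # 2. The aggregator for the grouped updates is picked at random from idle workers.
    aggregator = random.choice(idle_workers) if to_aggregate else None
    return send_directly, (aggregator, to_aggregate), dropped
```

Per the "state update in each iteration" bullet above, the completion-time estimates (and hence this schedule) would be recomputed each iteration as network conditions change.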

Paper

  • DML: compute and network contention, resulting in stragglers

    • Current systems take too simplistic a view of the network (either assuming fixed bandwidth between all workers, or treating it as a black box with unknown inter-worker bandwidth)

  • MLFabric: contention-aware DML system

    • orders transfers to improve convergence

    • opportunistically aggregates them at idle DML workers to improve resource efficiency

    • replicates them to support new notions of fault tolerance

    • systematically accounts for compute stragglers and network contention

    • implemented as a communication layer between DML applications and AllReduce libraries

  • Key idea

    • Control update delays

      • transfer available updates at non-overlapping times, reserving bandwidth per update, and explicitly ordering them

      • A worker's delay = the difference between the server model version the worker pulled and computed on, and the server model version at the time the worker's computed gradient is applied (written as a formula below)

    • Dynamically aggregating or dropping updates

    • Replicating updates for fault tolerance
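
The delay definition from the bullet above, written as a formula (notation assumed):

```latex
% v_read(i)  = server model version worker i pulled and computed its gradient on
% V_apply(i) = server model version at the moment that gradient is applied
\[
  \mathrm{delay}(i) \;=\; V_{\mathrm{apply}}(i) \;-\; v_{\mathrm{read}}(i),
  \qquad \text{with MLFabric keeping } \mathrm{delay}(i) \le D \text{ for a chosen bound } D .
\]
```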
