33 changes: 27 additions & 6 deletions deepspeed/runtime/fp16/loss_scaler.py
@@ -26,6 +26,11 @@
 from typing import Optional
 from enum import Enum
 from deepspeed.runtime.config_utils import DeepSpeedConfigObject
+from deepspeed.runtime.loss_scale_validation import (
+    validate_loss_scale_value,
+    validate_positive_finite,
+    validate_positive_int,
+)
 from deepspeed import comm as dist
 from deepspeed.utils import logger

@@ -67,6 +72,20 @@ class LossScaleProfileDefaults:
 }


+def _validated_dynamic_loss_args(dynamic_loss_args):
+    return {
+        INITIAL_LOSS_SCALE: validate_positive_finite(dynamic_loss_args[INITIAL_LOSS_SCALE],
+                                                     name=f"dynamic_loss_args['{INITIAL_LOSS_SCALE}']"),
+        SCALE_WINDOW: validate_positive_int(dynamic_loss_args[SCALE_WINDOW],
+                                            name=f"dynamic_loss_args['{SCALE_WINDOW}']"),
+        DELAYED_SHIFT: validate_positive_int(dynamic_loss_args[DELAYED_SHIFT],
+                                             name=f"dynamic_loss_args['{DELAYED_SHIFT}']"),
+        CONSECUTIVE_HYSTERESIS: dynamic_loss_args[CONSECUTIVE_HYSTERESIS],
+        MIN_LOSS_SCALE: validate_positive_finite(dynamic_loss_args[MIN_LOSS_SCALE],
+                                                 name=f"dynamic_loss_args['{MIN_LOSS_SCALE}']"),
+    }
+
+
 @dataclass
 class LossScaleConfig:
     use_grad_scaling: bool
@@ -100,7 +119,7 @@ def __init__(self,
         if not use_grad_scaling:
             return

-        self.cur_scale = static_loss_scale
+        self.cur_scale = validate_loss_scale_value(static_loss_scale, name="fp16.loss_scale")
         if not dynamic_loss_scale:
             return

@@ -111,14 +130,15 @@ def __init__(self,
         self.last_overflow_iter = -1
         self.scale_factor = defaults.scale_factor
         if dynamic_loss_args is None:
-            self.cur_scale = initial_dynamic_scale
+            self.cur_scale = validate_positive_finite(initial_dynamic_scale, name="dynamic_loss_args['init_scale']")
             self.scale_window = defaults.default_scale_window
             self.min_loss_scale = defaults.default_min_loss_scale
             return

-        self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
-        self.scale_window = dynamic_loss_args[SCALE_WINDOW]
-        self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
+        validated_dynamic_loss_args = _validated_dynamic_loss_args(dynamic_loss_args)
+        self.cur_scale = validated_dynamic_loss_args[INITIAL_LOSS_SCALE]
+        self.scale_window = validated_dynamic_loss_args[SCALE_WINDOW]
+        self.min_loss_scale = validated_dynamic_loss_args[MIN_LOSS_SCALE]


 # item() is a recent addition, so this helps with backward compatibility.
@@ -305,9 +325,10 @@ def update_scale(self, overflow):
 def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
     if dtype == torch.half and dynamic_scaling:
         assert dynamic_loss_args is not None, "Dynamic loss scaling parameters must be defined."
-        return DynamicLossScaler(dtype=dtype, **dynamic_loss_args)
+        return DynamicLossScaler(dtype=dtype, **_validated_dynamic_loss_args(dynamic_loss_args))

     loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
+    loss_scale_value = validate_loss_scale_value(loss_scale_value, name="fp16.loss_scale")
     return LossScaler(scale=loss_scale_value)


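For illustration (not part of the diff): a minimal sketch of the validated CreateLossScaler path, assuming a DeepSpeed install where these modules are importable.

    # Sketch: CreateLossScaler now validates scale values up front.
    import torch
    from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler

    # Non-fp16 dtypes fall back to a static scale of 1.0, which passes validation.
    scaler = CreateLossScaler(torch.float32, static_loss_scale=8.0, dynamic_scaling=False, dynamic_loss_args=None)
    assert scaler.cur_scale == 1.0

    # A non-finite static scale for fp16 now fails fast instead of propagating NaNs.
    try:
        CreateLossScaler(torch.half, static_loss_scale=float("nan"), dynamic_scaling=False, dynamic_loss_args=None)
    except ValueError as err:
        print(err)  # fp16.loss_scale must be finite, got nan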
47 changes: 47 additions & 0 deletions deepspeed/runtime/loss_scale_validation.py
@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math
from numbers import Integral, Real


def _to_finite_float(value, *, name: str) -> float:
    if isinstance(value, bool) or not isinstance(value, Real):
        raise ValueError(f"{name} must be a real number, got {type(value).__name__}")

    numeric_value = float(value)
    if not math.isfinite(numeric_value):
        raise ValueError(f"{name} must be finite, got {numeric_value}")
    return numeric_value


def validate_loss_scale_value(value, *, name: str = "loss_scale", allow_dynamic_zero: bool = False) -> float:
    """
    Validate static loss scale values.

    A value of 0 is accepted only when it represents dynamic loss scaling mode.
    """
    numeric_value = _to_finite_float(value, name=name)
    if allow_dynamic_zero and numeric_value == 0.0:
        return numeric_value
    if numeric_value <= 0.0:
        raise ValueError(f"{name} must be greater than 0, got {numeric_value}")
    return numeric_value


def validate_positive_finite(value, *, name: str) -> float:
    numeric_value = _to_finite_float(value, name=name)
    if numeric_value <= 0.0:
        raise ValueError(f"{name} must be greater than 0, got {numeric_value}")
    return numeric_value


def validate_positive_int(value, *, name: str) -> int:
    if isinstance(value, bool) or not isinstance(value, Integral):
        raise ValueError(f"{name} must be an integer, got {type(value).__name__}")
    int_value = int(value)
    if int_value <= 0:
        raise ValueError(f"{name} must be greater than 0, got {int_value}")
    return int_value
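A note on the bool guards above: bool is a subclass of int (and hence of numbers.Real) in Python, so without the explicit isinstance(value, bool) checks, True would silently validate as 1. For illustration (not part of the diff), a quick sketch of the validators' behavior:

    from deepspeed.runtime.loss_scale_validation import (
        validate_loss_scale_value,
        validate_positive_finite,
        validate_positive_int,
    )

    validate_positive_finite(65536.0, name="init_scale")   # returns 65536.0
    validate_positive_int(1000, name="scale_window")       # returns 1000
    # 0 is only accepted as the "use dynamic loss scaling" sentinel.
    validate_loss_scale_value(0, name="loss_scale", allow_dynamic_zero=True)  # returns 0.0

    try:
        validate_positive_int(True, name="scale_window")   # bools are rejected, not coerced
    except ValueError as err:
        print(err)  # scale_window must be an integer, got bool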
24 changes: 23 additions & 1 deletion deepspeed/runtime/precision_config.py
@@ -4,6 +4,8 @@
 # DeepSpeed Team

 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+from pydantic import field_validator
+from .loss_scale_validation import validate_loss_scale_value, validate_positive_finite, validate_positive_int
 from .fp16.loss_scaler import (
     INITIAL_LOSS_SCALE,
     SCALE_WINDOW,
@@ -133,12 +135,32 @@ class DeepSpeedFP16Config(DeepSpeedConfigModel):
     Refill hysteresis if iteration does not overflow/underflow.
     """

-    min_loss_scale: int = 1
+    min_loss_scale: float = 1
     """
     Minimum dynamic loss scale value.
     """

+    @field_validator("loss_scale")
+    @classmethod
+    def validate_loss_scale(cls, value):
+        return validate_loss_scale_value(value, name="fp16.loss_scale", allow_dynamic_zero=True)
+
+    @field_validator("loss_scale_window")
+    @classmethod
+    def validate_loss_scale_window(cls, value):
+        return validate_positive_int(value, name="fp16.loss_scale_window")
+
+    @field_validator("hysteresis")
+    @classmethod
+    def validate_hysteresis(cls, value):
+        return validate_positive_int(value, name="fp16.hysteresis")
+
+    @field_validator("min_loss_scale")
+    @classmethod
+    def validate_min_loss_scale(cls, value):
+        return validate_positive_finite(value, name="fp16.min_loss_scale")
+
     fp16_master_weights_and_grads: bool = False
     """
     Maintain master weights in optimizer state as fp16 instead of fp32 (valid with DeepSpeedCPUAdam only).
     """
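For illustration (not part of the diff): with these field validators in place, bad fp16 settings are rejected when the config model is built, rather than surfacing mid-training. A small sketch, mirroring the tests added below:

    from pydantic import ValidationError
    from deepspeed.runtime.precision_config import DeepSpeedFP16Config

    # loss_scale=0 is the documented sentinel for dynamic loss scaling, so it passes.
    cfg = DeepSpeedFP16Config(enabled=True, loss_scale=0)
    assert cfg.loss_scale == 0.0

    # Negative or non-finite values fail at config-parse time with a ValidationError.
    try:
        DeepSpeedFP16Config(enabled=True, loss_scale=-1.0)
    except ValidationError as err:
        print(err)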
10 changes: 6 additions & 4 deletions deepspeed/runtime/zero/stage3.py
@@ -19,6 +19,7 @@
 from deepspeed.utils import logger
 from deepspeed.utils.torch import register_grad_hook, required_torch_version
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
+from deepspeed.runtime.loss_scale_validation import validate_loss_scale_value
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
 from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
@@ -2212,10 +2213,11 @@ def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
             self._partitioned_params_swap_out(sub_group_id)

     def override_loss_scale(self, loss_scale):
-        if loss_scale != self.external_loss_scale:
-            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
+        validated_loss_scale = validate_loss_scale_value(loss_scale, name="loss_scale")
+        if validated_loss_scale != self.external_loss_scale:
+            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {validated_loss_scale}')
         self.custom_loss_scaler = True
-        self.external_loss_scale = loss_scale
+        self.external_loss_scale = validated_loss_scale

     @instrument_w_nvtx
     def step(self, closure=None):
@@ -2701,7 +2703,7 @@ def _get_loss_scale(self):
         return self.loss_scaler.cur_scale

     def _set_loss_scale(self, value):
-        self.loss_scaler.cur_scale = value
+        self.loss_scaler.cur_scale = validate_loss_scale_value(value, name="loss_scale")

     loss_scale = property(_get_loss_scale, _set_loss_scale)
     cur_scale = property(_get_loss_scale, _set_loss_scale)
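For illustration (not part of the diff): both override_loss_scale and the loss_scale/cur_scale property setter now route through the same check, so a bad value raises before it can poison gradient unscaling. A sketch using a SimpleNamespace stand-in, in the same style as the unit tests below:

    from types import SimpleNamespace
    from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

    # Stand-in object with only the attributes _set_loss_scale touches.
    optimizer = SimpleNamespace(loss_scaler=SimpleNamespace(cur_scale=1.0))

    DeepSpeedZeroOptimizer_Stage3._set_loss_scale(optimizer, 128.0)
    assert optimizer.loss_scaler.cur_scale == 128.0

    try:
        DeepSpeedZeroOptimizer_Stage3._set_loss_scale(optimizer, float("inf"))
    except ValueError as err:
        print(err)  # loss_scale must be finite, got inf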
10 changes: 6 additions & 4 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -19,6 +19,7 @@
 from deepspeed.runtime.zero.offload_states import offload_optimizer_states, reload_optimizer_states
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
+from deepspeed.runtime.loss_scale_validation import validate_loss_scale_value
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
                                      align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
@@ -2031,10 +2032,11 @@ def get_lr(self):
         return self.optimizer.param_groups[0]["lr"]

     def override_loss_scale(self, loss_scale):
-        if loss_scale != self.external_loss_scale:
-            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
+        validated_loss_scale = validate_loss_scale_value(loss_scale, name="loss_scale")
+        if validated_loss_scale != self.external_loss_scale:
+            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {validated_loss_scale}')
         self.custom_loss_scaler = True
-        self.external_loss_scale = loss_scale
+        self.external_loss_scale = validated_loss_scale

     def scaled_global_norm(self, norm_type=2):
         assert norm_type == 2, "only L2 norm supported"
@@ -2346,7 +2348,7 @@ def _get_loss_scale(self):
         return self.loss_scaler.cur_scale

     def _set_loss_scale(self, value):
-        self.loss_scaler.cur_scale = value
+        self.loss_scaler.cur_scale = validate_loss_scale_value(value, name="loss_scale")

     loss_scale = property(_get_loss_scale, _set_loss_scale)
     cur_scale = property(_get_loss_scale, _set_loss_scale)
59 changes: 59 additions & 0 deletions tests/unit/runtime/half_precision/test_loss_scale_validation.py
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from types import SimpleNamespace
import pytest
import torch

from deepspeed.runtime.fp16.loss_scaler import (
    CONSECUTIVE_HYSTERESIS,
    DELAYED_SHIFT,
    INITIAL_LOSS_SCALE,
    MIN_LOSS_SCALE,
    SCALE_WINDOW,
    CreateLossScaler,
    LossScaleConfig,
)
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3


def test_loss_scale_config_rejects_non_finite_static_loss_scale():
    with pytest.raises(ValueError, match="fp16.loss_scale must be finite"):
        LossScaleConfig(low_precision_dtype=torch.float16,
                        dynamic_loss_scale=False,
                        static_loss_scale=float("inf"),
                        dynamic_loss_args=None)


def test_create_loss_scaler_rejects_non_finite_dynamic_init_scale():
    dynamic_loss_args = {
        INITIAL_LOSS_SCALE: float("inf"),
        SCALE_WINDOW: 1000,
        DELAYED_SHIFT: 2,
        CONSECUTIVE_HYSTERESIS: False,
        MIN_LOSS_SCALE: 1.0,
    }
    with pytest.raises(ValueError, match="dynamic_loss_args\\['init_scale'\\] must be finite"):
        CreateLossScaler(torch.float16, static_loss_scale=0, dynamic_scaling=True, dynamic_loss_args=dynamic_loss_args)


def test_stage1_override_loss_scale_validates_values():
    optimizer = SimpleNamespace(external_loss_scale=None, custom_loss_scaler=False)
    with pytest.raises(ValueError, match="loss_scale must be finite"):
        DeepSpeedZeroOptimizer.override_loss_scale(optimizer, float("inf"))

    DeepSpeedZeroOptimizer.override_loss_scale(optimizer, 256.0)
    assert optimizer.custom_loss_scaler is True
    assert optimizer.external_loss_scale == 256.0


def test_stage3_set_loss_scale_validates_values():
    optimizer = SimpleNamespace(loss_scaler=SimpleNamespace(cur_scale=1.0))
    with pytest.raises(ValueError, match="loss_scale must be greater than 0"):
        DeepSpeedZeroOptimizer_Stage3._set_loss_scale(optimizer, 0)

    DeepSpeedZeroOptimizer_Stage3._set_loss_scale(optimizer, 128.0)
    assert optimizer.loss_scaler.cur_scale == 128.0
34 changes: 34 additions & 0 deletions tests/unit/runtime/test_ds_config_model.py
@@ -12,6 +12,7 @@

 from deepspeed.runtime import config as ds_config
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+from deepspeed.runtime.precision_config import DeepSpeedFP16Config


 class SimpleConf(DeepSpeedConfigModel):
@@ -84,3 +85,36 @@ def test_config_base_literalfail(config_dict):
 def test_config_base_deprecatedfail():
     with pytest.raises(AssertionError):
         config = SimpleConf(**{"param_2": ["DS"], "param_2_old": "DS"})
+
+
+@pytest.mark.parametrize(
+    "fp16_overrides",
+    [
+        {"loss_scale": float("inf")},
+        {"loss_scale": float("-inf")},
+        {"loss_scale": float("nan")},
+        {"loss_scale": -1.0},
+        {"loss_scale_window": 0},
+        {"hysteresis": 0},
+        {"min_loss_scale": 0},
+    ],
+)
+def test_fp16_config_rejects_invalid_loss_scale_inputs(fp16_overrides):
+    with pytest.raises(ValidationError):
+        DeepSpeedFP16Config(enabled=True, **fp16_overrides)
+
+
+def test_fp16_config_accepts_dynamic_loss_scale_sentinel():
+    config = DeepSpeedFP16Config(enabled=True, loss_scale=0)
+    assert config.loss_scale == 0.0
+
+
+def test_deepspeed_config_rejects_non_finite_fp16_loss_scale():
+    with pytest.raises(ValidationError):
+        ds_config.DeepSpeedConfig({
+            "train_batch_size": 1,
+            "fp16": {
+                "enabled": True,
+                "loss_scale": float("inf"),
+            },
+        })
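For illustration (not part of the diff): the same validation applies on the user-facing JSON config path, so a well-formed fp16 section still parses cleanly. A sketch, assuming the standard fp16 config keys used above:

    from deepspeed.runtime import config as ds_config

    # A valid fp16 section passes every field validator at parse time.
    cfg = ds_config.DeepSpeedConfig({
        "train_batch_size": 1,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,          # dynamic loss scaling sentinel
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
    })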