Skip to content

vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8

DYNAMIC_QUANT module-attribute

# Key into activation_quant_key_mapping selecting dynamic activation scaling
# (is_static_input_scheme=False).
DYNAMIC_QUANT = False

STATIC_QUANT module-attribute

# Key into activation_quant_key_mapping selecting static activation scaling;
# the constructor uses its is_static_input_scheme flag directly as the key.
STATIC_QUANT = True

__all__ module-attribute

# Public API of this module.
__all__ = ['CompressedTensorsW8A8Fp8']

activation_quant_key_mapping module-attribute

# Maps is_static_input_scheme (True/False) to the activation quant key passed
# to init_fp8_linear_kernel on the non-block kernel path.
activation_quant_key_mapping = {
    STATIC_QUANT: kFp8StaticTensorSym,
    DYNAMIC_QUANT: kFp8DynamicTokenSym,
}

logger module-attribute

# Module-level logger (not referenced by the code shown on this page).
logger = init_logger(__name__)

strategy_to_parameter_type module-attribute

# Parameter class used for ``weight_scale`` in create_weights, keyed by the
# weight quantization strategy.
strategy_to_parameter_type = {
    BLOCK: BlockQuantScaleParameter,
    CHANNEL: ChannelQuantScaleParameter,
    TENSOR: PerTensorScaleParameter,
}

weight_quant_key_mapping module-attribute

# Weight quant key for the non-block kernel path. Only CHANNEL and TENSOR are
# present, so any other strategy without a block structure raises KeyError
# in CompressedTensorsW8A8Fp8.__init__.
weight_quant_key_mapping = {
    CHANNEL: kFp8StaticTokenSym,
    TENSOR: kFp8StaticTensorSym,
}

CompressedTensorsW8A8Fp8

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
    """FP8-weight / FP8-activation (W8A8) scheme for compressed-tensors models.

    The execution path is chosen once, at construction time:

    * If ``weight_quant.block_structure`` is set, block-wise weight scales are
      used through ``W8A8BlockFp8LinearOp``. Only dynamic activation scaling
      is allowed on this path.
    * Otherwise (TENSOR or CHANNEL strategy) a kernel is obtained from
      ``init_fp8_linear_kernel``. It uses the keys looked up in
      ``activation_quant_key_mapping`` and ``weight_quant_key_mapping``.
    """

    def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
        """Record the quantization config and build the matching FP8 GEMM op.

        Args:
            weight_quant: Weight quantization arguments. ``strategy`` and
                ``block_structure`` decide which kernel path is used.
            is_static_input_scheme: True when activation scales are static
                (an ``input_scale`` parameter is registered and loaded).
                False selects dynamic activation scaling.

        Raises:
            ValueError: If a block structure is configured together with a
                static input scheme; the block path only supports dynamic
                activation scales.
            KeyError: If no block structure is set and ``strategy`` has no
                entry in ``weight_quant_key_mapping``.
        """
        self.weight_quant = weight_quant
        self.strategy = weight_quant.strategy
        self.out_dtype = torch.get_default_dtype()
        self.is_static_input_scheme = is_static_input_scheme
        self.weight_block_size = self.weight_quant.block_structure

        if self.weight_block_size is not None:
            # Use an explicit raise rather than ``assert``: this guards
            # caller-supplied configuration and must still fire under
            # ``python -O``. Validate before probing kernel support.
            if self.is_static_input_scheme:
                raise ValueError(
                    "Static input (activation) quantization is not supported "
                    "with block-wise FP8 weight quantization; use dynamic "
                    "activation scales instead."
                )
            self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
            self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
            weight_quant_key = weight_quant_key_mapping[self.strategy]
            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=weight_quant_key,
                out_dtype=self.out_dtype,
                module_name=self.__class__.__name__,
            )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability this scheme requires."""
        # lovelace and up (89 corresponds to compute capability 8.9)
        return 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the FP8 weight, its scale and, if static, the input scale.

        Also stores ``logical_widths``, ``weight_block_size`` and
        ``orig_dtype`` on ``layer``. For the BLOCK strategy the shapes are
        validated with ``validate_fp8_block_shape`` first.

        Args:
            layer: Module that receives the new parameters.
            input_size_per_partition: Number of input columns of the weight
                created for this layer.
            output_partition_sizes: Output widths of the sub-matrices; their
                sum is the row count of the created weight.
            input_size: Passed to ``validate_fp8_block_shape`` only.
            output_size: Passed to ``validate_fp8_block_shape`` only.
            params_dtype: Stored as ``layer.orig_dtype``.
            weight_loader: Loader callback handed to each parameter factory.
            **kwargs: Accepted for interface compatibility; unused.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.weight_block_size = None
        layer.orig_dtype = params_dtype

        if self.strategy == QuantizationStrategy.BLOCK:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            # Validate block quantization shapes
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = create_fp8_scale_parameter(
            strategy_to_parameter_type[self.strategy],
            output_partition_sizes,
            input_size_per_partition,
            layer.weight_block_size,
            weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (only loaded for static activation quantization)
        if self.is_static_input_scheme:
            input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer) -> None:
        """Convert loaded parameters into the form ``apply_weights`` expects.

        A strategy-specific ``process_fp8_weight_*_strategy`` helper rewrites
        the weight and its scale. The TENSOR and CHANNEL helpers also return
        an input scale, and their weight is transposed afterwards. The results
        are re-wrapped as ``Parameter(..., requires_grad=False)``. With a
        static input scheme, ``input_scale`` is reduced to its maximum;
        otherwise it is set to None. BLOCK layers then get a
        ``maybe_post_process_fp8_weight_block`` pass.

        Raises:
            ValueError: If ``self.strategy`` is not TENSOR, CHANNEL or BLOCK.
        """
        if self.strategy == QuantizationStrategy.TENSOR:
            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                layer.weight,
                layer.weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            weight = weight.t()
        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
                layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
            )
            weight = weight.t()

        elif self.strategy == QuantizationStrategy.BLOCK:
            # Internal invariant: __init__ rejects static input with blocks.
            assert self.is_static_input_scheme is False
            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale
            )
            input_scale = None

        else:
            raise ValueError(
                f"Unknown quantization strategy {self.strategy}: "
                f"should be one of {list(QuantizationStrategy)}"
            )

        # required by torch.compile to be torch.nn.Parameter
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale.data, requires_grad=False)

        # INPUT SCALE: collapse static scales to a single maximum value.
        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None
        if self.strategy == QuantizationStrategy.BLOCK:
            maybe_post_process_fp8_weight_block(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear for ``layer`` on input ``x`` (plus optional bias).

        Uses the block-wise op when a block structure was configured;
        otherwise delegates to the kernel built in ``__init__``.
        """
        if self.weight_block_size is not None:
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)

act_q_group_shape instance-attribute

# Block path only: activation quantization group shape handed to
# W8A8BlockFp8LinearOp; its second dimension is weight_block_size[0].
act_q_group_shape = GroupShape(1, weight_block_size[0])

cutlass_block_fp8_supported instance-attribute

# Block path only: result of the cutlass_block_fp8_supported() probe,
# forwarded to W8A8BlockFp8LinearOp.
cutlass_block_fp8_supported = cutlass_block_fp8_supported()

fp8_linear instance-attribute

# Non-block path only: the FP8 linear kernel used by apply_weights.
# NOTE: this rendering shows module_name=__name__; the constructor actually
# passes self.__class__.__name__.
fp8_linear = init_fp8_linear_kernel(
    activation_quant_key=activation_quant_key,
    weight_quant_key=weight_quant_key,
    out_dtype=out_dtype,
    module_name=__name__,
)

is_static_input_scheme instance-attribute

# True when activation scales are static (an input_scale parameter is
# registered and loaded); False selects dynamic activation scaling.
is_static_input_scheme = is_static_input_scheme

out_dtype instance-attribute

# Output dtype for the non-block kernel: torch's default dtype when the
# scheme is constructed.
out_dtype = get_default_dtype()

strategy instance-attribute

# weight_quant.strategy; TENSOR, CHANNEL and BLOCK are handled.
strategy = strategy

use_aiter_and_is_supported instance-attribute

# Block path only: value of rocm_aiter_ops.is_linear_fp8_enabled(),
# forwarded to W8A8BlockFp8LinearOp.
use_aiter_and_is_supported = is_linear_fp8_enabled()

w8a8_block_fp8_linear instance-attribute

# Block path only: the block-wise FP8 linear op invoked by apply_weights.
w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(*(weight_block_size)),
    act_quant_group_shape=act_q_group_shape,
    cutlass_block_fp8_supported=cutlass_block_fp8_supported,
    use_aiter_and_is_supported=use_aiter_and_is_supported,
)

weight_block_size instance-attribute

# weight_quant.block_structure; None selects the non-block kernel path.
weight_block_size = block_structure

weight_quant instance-attribute

# The QuantizationArgs passed to the constructor.
weight_quant = weight_quant

__init__

__init__(
    weight_quant: QuantizationArgs,
    is_static_input_scheme: bool,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
    """Record the quantization config and build the matching FP8 GEMM op.

    Args:
        weight_quant: Weight quantization arguments. ``strategy`` and
            ``block_structure`` decide which kernel path is used.
        is_static_input_scheme: True when activation scales are static
            (an ``input_scale`` parameter is registered and loaded). False
            selects dynamic activation scaling.

    Raises:
        ValueError: If a block structure is configured together with a
            static input scheme; the block path only supports dynamic
            activation scales.
        KeyError: If no block structure is set and ``strategy`` has no entry
            in ``weight_quant_key_mapping``.
    """
    self.weight_quant = weight_quant
    self.strategy = weight_quant.strategy
    self.out_dtype = torch.get_default_dtype()
    self.is_static_input_scheme = is_static_input_scheme
    self.weight_block_size = self.weight_quant.block_structure

    if self.weight_block_size is not None:
        # Use an explicit raise rather than ``assert``: this guards
        # caller-supplied configuration and must still fire under
        # ``python -O``. Validate before probing kernel support.
        if self.is_static_input_scheme:
            raise ValueError(
                "Static input (activation) quantization is not supported "
                "with block-wise FP8 weight quantization; use dynamic "
                "activation scales instead."
            )
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )
    else:
        activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
        weight_quant_key = weight_quant_key_mapping[self.strategy]
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=weight_quant_key,
            out_dtype=self.out_dtype,
            module_name=self.__class__.__name__,
        )

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    if self.weight_block_size is not None:
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )

    return self.fp8_linear.apply_weights(layer, x, bias)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    """Register the FP8 weight, its scale and, if static, the input scale.

    Also stores ``logical_widths``, ``weight_block_size`` and ``orig_dtype``
    on ``layer``. For the BLOCK strategy the shapes are validated with
    ``validate_fp8_block_shape`` before any parameter is created.

    Args:
        layer: Module that receives the new parameters.
        input_size_per_partition: Number of input columns of the created
            weight.
        output_partition_sizes: Output widths of the sub-matrices; their sum
            is the row count of the created weight.
        input_size: Passed to ``validate_fp8_block_shape`` only.
        output_size: Passed to ``validate_fp8_block_shape`` only.
        params_dtype: Stored as ``layer.orig_dtype``.
        weight_loader: Loader callback handed to each parameter factory.
        **kwargs: Accepted for interface compatibility; unused.
    """
    # Metadata read later by the loading / post-processing steps.
    layer.logical_widths = output_partition_sizes
    layer.weight_block_size = None
    layer.orig_dtype = params_dtype

    if self.strategy == QuantizationStrategy.BLOCK:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        validate_fp8_block_shape(
            layer,
            input_size,
            output_size,
            input_size_per_partition,
            output_partition_sizes,
            self.weight_block_size,
        )

    rows = sum(output_partition_sizes)
    layer.register_parameter(
        "weight",
        create_fp8_weight_parameter(rows, input_size_per_partition, weight_loader),
    )

    scale_param_cls = strategy_to_parameter_type[self.strategy]
    layer.register_parameter(
        "weight_scale",
        create_fp8_scale_parameter(
            scale_param_cls,
            output_partition_sizes,
            input_size_per_partition,
            layer.weight_block_size,
            weight_loader,
        ),
    )

    # Dynamic activation scaling needs no loadable input scale.
    if not self.is_static_input_scheme:
        return
    layer.register_parameter(
        "input_scale", create_fp8_input_scale(output_partition_sizes, weight_loader)
    )

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum GPU compute capability this scheme requires.

    89 corresponds to compute capability 8.9 (Ada Lovelace); newer
    architectures also qualify.
    """
    lovelace_capability = 89
    return lovelace_capability

process_weights_after_loading

process_weights_after_loading(layer) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def process_weights_after_loading(self, layer) -> None:
    if self.strategy == QuantizationStrategy.TENSOR:
        weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
            layer.weight,
            layer.weight_scale,
            layer.logical_widths,
            getattr(layer, "input_scale", None),
        )
        weight = weight.t()
    elif self.strategy == QuantizationStrategy.CHANNEL:
        weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
            layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
        )
        weight = weight.t()

    elif self.strategy == QuantizationStrategy.BLOCK:
        assert self.is_static_input_scheme is False
        weight, weight_scale = process_fp8_weight_block_strategy(
            layer.weight, layer.weight_scale
        )
        input_scale = None

    else:
        raise ValueError(
            f"Unknown quantization strategy {self.strategy}: "
            f"should be one of {list(QuantizationStrategy)}"
        )

    # required by torch.compile to be torch.nn.Parameter
    layer.weight = Parameter(weight.data, requires_grad=False)
    layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
    if input_scale is not None:
        layer.input_scale = Parameter(input_scale.data, requires_grad=False)

    # INPUT SCALE
    if self.is_static_input_scheme and hasattr(layer, "input_scale"):
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
    else:
        layer.input_scale = None
    if self.strategy == QuantizationStrategy.BLOCK:
        maybe_post_process_fp8_weight_block(layer)