forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathBlas.cpp
More file actions
1029 lines (916 loc) · 38.9 KB
/
Blas.cpp
File metadata and controls
1029 lines (916 loc) · 38.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/native/ScaledBlasUtils.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/cuBlasCommonArgs.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_MSLK
#include <mslk/gemm/gemm_torch.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
// Prepares one 3-D batched operand for a (strided-batched) cuBLAS gemm call and
// reports how cuBLAS should read it.
//
// Parameters:
//   tensor           - the batched operand; dims 1 and 2 are the matrix dims.
//   transpose_tensor - [out] whether cuBLAS must treat the operand as transposed.
//   ld_tensor        - [out] the leading-dimension stride to hand to cuBLAS.
//   transpose_result - whether the caller transposed the whole problem; this swaps
//                      which of dims 1/2 is expected to have unit stride.
//   m, n             - lower bounds the leading stride must meet for the operand
//                      to be used in place (non-transposed / transposed layout).
// Returns a borrowed tensor (with the conjugate bit resolved only when required)
// when the existing layout is usable, otherwise an owned contiguous clone.
c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) {
  const IntArrayRef strides = tensor.strides();
  const int fast_dim = transpose_result ? 2 : 1;
  const int leading_dim = transpose_result ? 1 : 2;

  // Layout already matches what cuBLAS expects for a non-transposed operand.
  if (strides[fast_dim] == 1 && strides[leading_dim] >= std::max<int64_t>(1, m)) {
    transpose_tensor = false;
    c10::MaybeOwned<Tensor> prepared = resolve_conj_if_indicated(tensor, true);
    ld_tensor = prepared->strides()[leading_dim];
    return prepared;
  }

  // Layout matches a transposed operand; a conjugate bit can be folded into 'c'.
  if (strides[leading_dim] == 1 && strides[fast_dim] >= std::max<int64_t>(1, n)) {
    transpose_tensor = true;
    c10::MaybeOwned<Tensor> prepared = resolve_conj_if_indicated(tensor, false);
    ld_tensor = prepared->strides()[fast_dim];
    return prepared;
  }

  // Neither layout fits: use a row-major (contiguous) operand.
  transpose_tensor = !transpose_result;
  // gemm call requires leading dimension and stride parameters to be non-zero
  const bool has_zero_stride = tensor.strides()[1] == 0 || tensor.strides()[2] == 0;
  c10::MaybeOwned<Tensor> prepared;
  if (!has_zero_stride && tensor.is_contiguous()) {
    prepared = resolve_conj_if_indicated(tensor, transpose_result);
  } else {
    prepared = c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
  }
  ld_tensor = prepared->strides()[1];
  return prepared;
}
namespace {
// Activation that addmm can apply to its output, either fused into the
// gemm-and-bias epilogue (Lt path) or applied afterwards (plain gemm path).
enum class Activation {
  None,
  RELU,
  GELU,
};

// Maps an Activation onto the corresponding cuda::blas gemm-and-bias epilogue.
// Raises (TORCH_CHECK) with a descriptive message if `a` holds a value outside the
// enumerators listed above.
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
  switch (a) {
    case Activation::None:
      return cuda::blas::GEMMAndBiasActivationEpilogue::None;
    case Activation::RELU:
      return cuda::blas::GEMMAndBiasActivationEpilogue::RELU;
    case Activation::GELU:
      return cuda::blas::GEMMAndBiasActivationEpilogue::GELU;
    default:
      TORCH_CHECK(false, "activation_to_gemm_and_blas_arg: unsupported Activation value ", static_cast<int>(a));
      return cuda::blas::GEMMAndBiasActivationEpilogue::None;
  }
}
/*
 * Checks whether the Lt interface (cublasLt / hipBLASLt) is globally disabled for addmm.
 *
 * The DISABLE_ADDMM_CUDA_LT environment variable is read once, on the first call,
 * and cached for the rest of the process.
 *
 * On CUDA, the Lt path is disabled only when DISABLE_ADDMM_CUDA_LT == "1";
 * `device` is not consulted.
 *
 * On ROCm, the following order of precedence applies:
 *   - When hipBLASLt is NOT supported on the architecture of `device`, return true.
 *   - If the environment variable is set, return whether it is set to "1".
 *   - If the environment variable is NOT set, return a value based on the preferred
 *     BLAS backend: disabled unless at::BlasBackend::Cublaslt is preferred.
 */
static bool isGloballyDisabledAddmmCudaLt([[maybe_unused]] const at::Device& device) {
  static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
#ifdef USE_ROCM
  const auto& archs = at::detail::getCUDAHooks().getHipblasltSupportedArchs();
  const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
  if (!is_hipblas_lt_arch_supported) {
    return true;
  }
  // If environment variable is explicitly set, respect it
  if (is_addmm_cuda_lt_disabled.has_value()) {
    return is_addmm_cuda_lt_disabled == "1";
  }
  // The available BLAS backends on ROCm are: rocBLAS, hipBLASLt, and CK.
  // Only an explicit preference for the Lt backend enables the Lt path here.
  return at::globalContext().blasPreferredBackend() != at::BlasBackend::Cublaslt;
#else
  // An unset variable compares unequal, leaving the Lt path enabled.
  return is_addmm_cuda_lt_disabled == "1";
#endif
}
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(
Tensor& result,
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Activation activation
) {
#ifdef USE_ROCM
// Implies 2D bias which we currently not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be ralexed in cases when a col-major
// copy of result is needed.
if (self.is_same(result) || self.dim() == 2) {
return false;
}
#endif
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// NOTE: row-major result is important when bias is 1D.
// This is because Lt broadcasts 1D bias over the columns
// while the aten::addmm API broadcasts it over the rows,
// and this is in conjunction with the data preparation
// procedure that does not transpose arguments with
// col-major result. For col-major result we need
// to explicitly transpose the problem so that bias is
// correctly applied.
// TODO: enable col-major result if needed.
// TODO: no need to check result's layout when
// !result.is_same(self) and self.dim() == 2, because
// self needs to be copied into result and the bias ptr
// will be ignored.
&& result.dim() == 2 && result.is_contiguous()
&& (
( // Conditions for bias to be fusable -- implies direct Lt path without copies.
self.is_contiguous() &&
// NOTE: fine to have 1-len dims to the left from the right-most one
(self.dim() == 1 || self.squeeze().dim() == 1) &&
self.sizes().back() == mat2_sizes[1]
)
|| ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
// and we need to copy self into result otherwise, so the self's layout becomes irrelevant.
// See also TODO from above.
activation != Activation::None && // Lt is faster when activation is fused
(self.dim() == 2 && at::is_expandable_to(self.sizes(), {mat1_sizes[0], mat2_sizes[1]}))
)
)
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
)
);
// no compliance by default
return false;
}
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::GemmAndBiasParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.alpha = alpha.to<at::opmath_type<scalar_t>>();
params.a = args.mata->const_data_ptr<scalar_t>();
params.lda = args.lda;
params.b = args.matb->const_data_ptr<scalar_t>();
params.ldb = args.ldb;
params.c = args.result->data_ptr<scalar_t>();
params.ldc = args.result_ld;
params.bias = bias;
params.activation = activation;
if (transa_ && transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T> gemm{};
gemm(¶ms);
}
else if (transa_ && !transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N> gemm{};
gemm(¶ms);
}
else if (!transa_ && transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T> gemm{};
gemm(¶ms);
}
else if (!transa_ && !transb_) {
static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N> gemm{};
gemm(¶ms);
}
else {
TORCH_CHECK(false, "unreachable");
}
}
/*
 * Runs the gemm described by `args` through the gemm-and-bias (Lt) API, or
 * through TunableOp when it is enabled, writing into args.result.
 *
 * The bias epilogue is used only when `self` holds a value: its data pointer is
 * passed as the bias. An empty `self` yields a null bias pointer, which means the
 * bias epilogue is ignored and the standard gemm-like API is used.
 *
 * Returns the success flag reported by at::cuda::blas::gemm_and_bias, or true when
 * the TunableOp path was taken.
 */
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
    // args contains result which is modified
    cublasCommonArgs& args,
    const std::optional<Tensor>& self,
    const Scalar& alpha,
    Activation activation = Activation::None
) {
  const scalar_t* bias_ptr = nullptr;
  if (self.has_value()) {
    bias_ptr = self->const_data_ptr<scalar_t>();
  }

  if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
    // TODO: maybe also return some success state?
    launchTunableGemmAndBias<scalar_t>(args, alpha, bias_ptr, activation_to_gemm_and_blas_arg(activation));
    return true;
  }

  const bool transpose_a = args.transa == 't';
  const bool transpose_b = args.transb == 't';
  const auto alpha_val = alpha.to<at::opmath_type<scalar_t>>();
  return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
      transpose_a, transpose_b,
      args.m, args.n, args.k,
      alpha_val,
      args.mata->const_data_ptr<scalar_t>(), args.lda,
      args.matb->const_data_ptr<scalar_t>(), args.ldb,
      bias_ptr,
      args.result->data_ptr<res_scalar_t>(), args.result_ld,
      activation_to_gemm_and_blas_arg(activation));
}
// Runs result = alpha * op(mata) @ op(matb) + beta * result with the classic
// cuBLAS gemm API; the transa/transb characters in `args` are forwarded unchanged
// and args.result is written in place. Always returns true.
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
    // args contains result which is modified
    cublasCommonArgs& args,
    const Scalar& alpha,
    const Scalar& beta
) {
  using opmath_t = at::opmath_type<scalar_t>;
  const opmath_t alpha_val = alpha.to<opmath_t>();
  const opmath_t beta_val = beta.to<opmath_t>();
  const scalar_t* a_ptr = args.mata->const_data_ptr<scalar_t>();
  const scalar_t* b_ptr = args.matb->const_data_ptr<scalar_t>();
  res_scalar_t* c_ptr = args.result->data_ptr<res_scalar_t>();

  at::cuda::blas::gemm<scalar_t, res_scalar_t>(
      args.transa, args.transb,
      args.m, args.n, args.k,
      alpha_val,
      a_ptr, args.lda,
      b_ptr, args.ldb,
      beta_val,
      c_ptr, args.result_ld);
  return true; // success!
}
/*
 * Shared implementation of addmm / _addmm_activation / mm on CUDA and ROCm:
 *   result = activation(beta * self + alpha * (mat1 @ mat2)).
 *
 * Two execution paths exist:
 *   - the Lt path (gemm-and-bias), taken unless it is disabled globally, by the
 *     preferred BLAS backend, by `disable_addmm_cuda_lt_override`, or by the input
 *     checks in isInputCompliesAddmmCudaLt. A 1-D self may then be fused as a
 *     bias pointer and the activation as an epilogue.
 *   - the classic cuBLAS gemm path, followed by an in-place relu_ / gelu_("tanh").
 * If the Lt launcher reports failure, the function calls itself again with
 * disable_addmm_cuda_lt_override = true so that the gemm path is used.
 *
 * Unless result aliases self, result is resized to [mat1.size(0), mat2.size(1)].
 * When beta != 0 and self is not fused as a bias, result also receives a
 * broadcast copy of self before the gemm runs. Returns result.
 */
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
  // Shape checks {
  // Make sure to keep addmm_cuda below in sync with this code; it
  // preflights a check to try to avoid actually needing to call
  // expand().
  TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D");
  TORCH_CHECK(
    mat1.dtype() == mat2.dtype(),
    "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
  );
  if (result.is_same(self)) {
    TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
    TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
    TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
  }
  // } Shape checks

  // NOLINTNEXTLINE(*c-array*)
  TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
  checkAllSameGPU(__func__, targs);

  // Handle whether to use the Lt interface {
  static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
  // if lt path fails, we recurse back into this function here and force the lt path to off
  // we cannot update variable disable_addmm_cuda_lt from above since it is static and would be permanent
  bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
  // NOTE: See https://github.com/pytorch/pytorch/issues/172231
  const auto preferred_cublas_backend = at::globalContext().blasPreferredBackend();
  disable_addmm_cuda_lt = !(
    preferred_cublas_backend == BlasBackend::Cublaslt
    || preferred_cublas_backend == BlasBackend::Default // Lt is default
  ) || disable_addmm_cuda_lt;
#ifdef USE_ROCM
  // Conditioned on the device index, which is not persistent
  disable_addmm_cuda_lt = disable_addmm_cuda_lt || isGloballyDisabledAddmmCudaLt(self.device());
#endif
  // Condition on the input
  disable_addmm_cuda_lt = disable_addmm_cuda_lt || !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation);

  at::ScalarType scalar_type = mat1.scalar_type();
  bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
#ifdef USE_ROCM
  disable_addmm_cuda_lt = disable_addmm_cuda_lt || is_float_output_with_half_input;
#endif

  bool use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
  // for float output with half input cublasLT with bias produces wrong results
  use_bias_ptr_lt &= !is_float_output_with_half_input;
  // } Lt interface selection

  // Handle result/self shapes
  if (!result.is_same(self)) {
    at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
    // We do not copy bias only when we need the bias ptr
    if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
      // NOTE: self should broadcast over result
      at::native::copy_(result, *expand_size(self, result.sizes(), "addmm"));
    }
  }

  // Short circuit on empty result
  if (result.numel() == 0) {
    return result;
  }

  // Short circuit if the reduction dim is empty
  if (mat1.sizes()[1] == 0) {
    // By definition, when beta==0, values in self should be ignored. nans and infs
    // should not propagate
    if (beta.toComplexDouble() == 0.) {
      return result.zero_();
    }
    // TODO: We could squeeze some perf by calling at::cuda::mul_out here instead, to bypass the dispatcher.
    // That requires some fixing some internal build dependencies though.
    return at::mul_out(
      result,
      self.expand(result.sizes()),
      at::native::scalar_tensor(
        beta,
        self.scalar_type(),
        std::nullopt /* layout */,
        at::kCPU,
        std::nullopt /* pin_memory */
      )
    );
  }

  cublasCommonArgs args(mat1, mat2, result);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

  // The Lt path
  if (!disable_addmm_cuda_lt) {
    bool lt_success = false;
    if (is_float_output_with_half_input) {
#ifdef USE_ROCM
      TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
#else
      if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
        TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
      }
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
        scalar_type,
        "addmm_cuda_lt",
        [&] {
          lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
        }
      );
#endif
    } else {
      // !is_float_output_with_half_input
      AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        scalar_type,
        "addmm_cuda_lt",
        [&] {
          lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
        }
      );
    } // end is_float_output_with_half_input

    if (!lt_success) {
      // lt path failed; recurse but disable lt path
      return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
    }
    // end Lt path
  } else {
    // No Lt, we use a GEMM instead
    if (is_float_output_with_half_input) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
        scalar_type,
        "addmm_cuda",
        [&] {
          launchGemmCublas<scalar_t, float>(args, alpha, beta);
        }
      );
    } else {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        scalar_type,
        "addmm_cuda",
        [&] {
          launchGemmCublas<scalar_t>(args, alpha, beta);
        }
      );
    }

    // Apply epilogue
    switch (activation) {
      case Activation::RELU:
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        at::relu_(const_cast<Tensor&>(*args.result));
        break;
      case Activation::GELU:
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
        break;
      default: break;
    }
  } // end GEMM path

  // Manual post-GELU for builds where the Lt GELU epilogue was unavailable.
  // NOTE(review): this block is only compiled when neither CUDA_VERSION nor
  // USE_ROCM is defined. Its original justification referenced a preprocessor
  // gate on activation_to_gemm_and_blas_arg that that function no longer has,
  // while the Lt path above always requests the GELU epilogue. Confirm whether
  // this block is still needed.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
  if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
    at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
  }
#endif

  if (!result.is_same(*args.result)) {
    result.copy_(*args.result);
  }
  return result;
}
const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
// handle pathological cases that blas may not like
if (result.numel() == 0) {
return result;
} else if (batch1.size(2) == 0) {
if (beta.to<c10::complex<double>>() == 0.0) {
return result.zero_();
} else {
return result.mul_(beta);
}
}
bool transpose_result = false;
c10::MaybeOwned<Tensor> result_;
IntArrayRef result_strides = result.strides();
IntArrayRef result_sizes = result.sizes();
if ((result_strides[1] == 1) &&
((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
result_ = resolve_conj_if_indicated(result, true);
} else if ((result_strides[2] == 1) &&
(result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
transpose_result = true;
result_ = resolve_conj_if_indicated(result, true);
} else {
result_ = c10::MaybeOwned<Tensor>::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2));
}
int leading_dim = transpose_result ? 1 : 2;
int64_t m = result_sizes[transpose_result ? 2 : 1];
int64_t n = result_sizes[leading_dim];
int64_t k = (transpose_result ? batch2 : batch1).sizes()[leading_dim];
int64_t lda = 0, ldb = 0, ldc = 0;
bool transpose_batch1 = false, transpose_batch2 = false;
auto batch1_ = prepare_batch_matrix_for_cublas(transpose_result ? batch2 : batch1, transpose_batch1, lda, transpose_result, m, k);
auto batch2_ = prepare_batch_matrix_for_cublas(transpose_result ? batch1 : batch2, transpose_batch2, ldb, transpose_result, k, n);
ldc = result_->strides()[leading_dim];
int64_t num_batches = result_->sizes()[0];
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
bool is_float_output_with_half_input = (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(batch1.scalar_type(), "baddbmm_cuda", [&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
float* result_ptr = result_->mutable_data_ptr<float>();
// If batch is 1 call gemm rather than bgemm
if (num_batches == 1) {
at::cuda::blas::gemm<scalar_t, float>(
transa, transb,
m, n, k,
alpha_val,
batch1_ptr, lda,
batch2_ptr, ldb,
beta_val,
result_ptr, ldc);
} else {
at::cuda::blas::bgemm<scalar_t, float>(
transa, transb,
m, n, k,
alpha_val,
batch1_ptr, lda, batch1_->strides()[0],
batch2_ptr, ldb, batch2_->strides()[0],
beta_val,
result_ptr, ldc, result_->strides()[0],
num_batches
);
}
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, batch1.scalar_type(), "baddbmm_cuda", [&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
// If batch is 1 call gemm rather than bgemm
if (num_batches == 1) {
at::cuda::blas::gemm<scalar_t>(
transa, transb,
m, n, k,
alpha_val,
batch1_ptr, lda,
batch2_ptr, ldb,
beta_val,
result_ptr, ldc);
} else {
at::cuda::blas::bgemm<scalar_t>(
transa, transb,
m, n, k,
alpha_val,
batch1_ptr, lda, batch1_->strides()[0],
batch2_ptr, ldb, batch2_->strides()[0],
beta_val,
result_ptr, ldc, result_->strides()[0],
num_batches
);
}
});
}
if (!result.is_same(*result_)) {
result.copy_(*result_);
}
return result;
}
} // anonymous namespace
// Structured-kernel entry points for addmm, _addmm_activation and mm on CUDA.
// addmm_out_cuda_impl takes `result` by non-const reference and writes into it,
// so const is cast away here.
TORCH_IMPL_FUNC(addmm_out_cuda)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha);
}

// use_gelu selects a GELU activation; otherwise RELU is applied.
TORCH_IMPL_FUNC(addmm_activation_out_cuda)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
  const Activation activation = use_gelu ? Activation::GELU : Activation::RELU;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha, activation);
}

// mm is addmm with result standing in for self, beta = 0 and alpha = 1.
TORCH_IMPL_FUNC(mm_out_cuda)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  addmm_out_cuda_impl(out, result, self, mat2, /*beta=*/0, /*alpha=*/1);
}

// baddbmm / bmm run with named-tensor propagation disabled for the duration of the call.
TORCH_IMPL_FUNC(baddbmm_out_cuda)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
  at::NoNamesGuard no_names_guard;
  baddbmm_out_cuda_impl(result, self, batch1, batch2, beta, alpha);
}

// bmm is baddbmm with result standing in for self, beta = 0 and alpha = 1.
TORCH_IMPL_FUNC(bmm_out_cuda)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) {
  const Scalar beta(0.0);
  const Scalar alpha(1.0);
  at::NoNamesGuard no_names_guard;
  baddbmm_out_cuda_impl(result, result, batch1, batch2, beta, alpha);
}
namespace {

// Validates the operands of dot_cuda / vdot_cuda: both must be 1-D, share a dtype
// and an element count, and have n, incx and incy no larger than INT_MAX, because
// the callers narrow those values to `int` before calling cuBLAS.
inline void dot_check(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(
      self.dim() == 1 && other.dim() == 1,
      "1D tensors expected, but got ",
      self.dim(),
      "D and ",
      other.dim(),
      "D tensors");

  TORCH_CHECK(
      self.scalar_type() == other.scalar_type(),
      "dot : expected both vectors to have same dtype, but found ",
      self.scalar_type(),
      " and ",
      other.scalar_type());

  TORCH_CHECK(
      self.numel() == other.numel(),
      "inconsistent tensor size, expected tensor [",
      self.numel(),
      "] and src [",
      other.numel(),
      "] to have the same number of elements, but got ",
      self.numel(),
      " and ",
      other.numel(),
      " elements respectively");

  // TORCH_CHECK concatenates its message arguments (it is not printf-style), so
  // the bound is appended directly rather than through a "%d" placeholder.
  TORCH_CHECK(
      (self.numel() <= INT_MAX) && (self.stride(0) <= INT_MAX) &&
          (other.stride(0) <= INT_MAX),
      "dot only supports n, incx, incy with the bound [val] <= ",
      INT_MAX);
}

} // anonymous namespace
// Unconjugated dot product sum_i self[i] * other[i] of two 1-D tensors on CUDA.
//
// Complex inputs that carry a lazy conjugate bit are re-expressed through
// dot_cuda / vdot_cuda on the unconjugated data, so cuBLAS never sees a conj bit.
// Returns a 0-dim tensor with self's options. If either operand is a zero tensor,
// an efficient zero tensor is returned instead.
Tensor dot_cuda(const Tensor& self, const Tensor& other) {
  if (self.is_complex()) {
    const bool self_conj = self.is_conj();
    const bool other_conj = other.is_conj();
    if (self_conj && other_conj) {
      return (dot_cuda(self.conj(), other.conj())).conj();
    }
    if (self_conj) {
      return vdot_cuda(self.conj(), other);
    }
    if (other_conj) {
      return vdot_cuda(other.conj(), self);
    }
  }

  at::NoNamesGuard guard;
  dot_check(self, other);

  const int n = static_cast<int>(self.numel());
  // For a single element the strides are irrelevant; use unit increments.
  const bool single_element = (n == 1);
  const int incx = single_element ? 1 : static_cast<int>(self.stride(0));
  const int incy = single_element ? 1 : static_cast<int>(other.stride(0));

  if (self._is_zerotensor() || other._is_zerotensor()) {
    return at::_efficientzerotensor({}, self.options());
  }

  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      self.scalar_type(), "dot",
      [&] {
        Tensor result = at::empty({}, self.options());

        auto handle = at::cuda::getCurrentCUDABlasHandle();
        // The output pointer passed below refers to device memory.
        at::cuda::blas::PointerModeGuard pointerModeGuard(handle, CUBLAS_POINTER_MODE_DEVICE);
        at::cuda::blas::dot<scalar_t>(
            handle, n,
            self.const_data_ptr<scalar_t>(), incx,
            other.const_data_ptr<scalar_t>(), incy,
            result.mutable_data_ptr<scalar_t>());

        return result;
      });
}
// Dot product of two 1-D tensors on CUDA with `self` conjugated.
//
// Real inputs are delegated to dot_cuda. Complex inputs that carry a lazy
// conjugate bit are re-expressed through dot_cuda / vdot_cuda on the unconjugated
// data. Returns a 0-dim tensor with self's options. If either operand is a zero
// tensor, an efficient zero tensor is returned instead.
Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
  if (!self.is_complex()) {
    return dot_cuda(self, other);
  }

  const bool self_conj = self.is_conj();
  const bool other_conj = other.is_conj();
  if (self_conj && other_conj) {
    return vdot_cuda(other.conj(), self.conj());
  }
  if (self_conj) {
    return dot_cuda(self.conj(), other);
  }
  if (other_conj) {
    return (dot_cuda(self, other.conj())).conj();
  }

  at::NoNamesGuard guard;
  dot_check(self, other);

  if (self._is_zerotensor() || other._is_zerotensor()) {
    return at::_efficientzerotensor({}, self.options());
  }

  const int n = static_cast<int>(self.numel());
  // For a single element the strides are irrelevant; use unit increments.
  const bool single_element = (n == 1);
  const int incx = single_element ? 1 : static_cast<int>(self.stride(0));
  const int incy = single_element ? 1 : static_cast<int>(other.stride(0));

  return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
    Tensor result = at::empty({}, self.options());

    auto handle = at::cuda::getCurrentCUDABlasHandle();
    // The output pointer passed below refers to device memory.
    at::cuda::blas::PointerModeGuard pointerModeGuard(
        handle, CUBLAS_POINTER_MODE_DEVICE);
    at::cuda::blas::vdot<scalar_t>(
        handle, n,
        self.const_data_ptr<scalar_t>(), incx,
        other.const_data_ptr<scalar_t>(), incy,
        result.mutable_data_ptr<scalar_t>());

    return result;
  });
}
// Computes result = beta * self + alpha * (mat @ vec) with cuBLAS gemv.
//
// self is broadcast to [mat.size(0)]. An empty mat short-circuits to beta * self,
// or to zeros when beta == 0 so that NaN/Inf in self do not propagate. Otherwise
// mat is passed to gemv as 'n' (unit stride along dim 0) or 't' (unit stride along
// dim 1) when its strides allow, and is made contiguous first when they do not.
TORCH_IMPL_FUNC(addmv_out_cuda)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta_.toComplexDouble();
  if (mat.numel() == 0) {
    // shortcut for an empty matrix
    // By definition, when beta==0, values in self should be ignored. nans and infs
    // should not propagate
    if (betaval == 0.0) {
      result.zero_();
    } else {
      // Use the broadcast self_ so the product already has result's shape
      // [mat.size(0)]. A merely broadcastable self (e.g. of size 1) would otherwise
      // give mul_out a differently-shaped output and cause it to resize result.
      at::mul_out(
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<Tensor&>(result),
          *self_,
          at::native::scalar_tensor(
              beta_, self.scalar_type(), std::nullopt /* layout */, at::kCPU, std::nullopt /* pin_memory */));
    }
  } else {
    if (!result.is_same(*self_) && betaval != 0.0) { //if beta is 0, result contents will be zeroed later
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      at::native::copy_(const_cast<Tensor&>(result), *self_);
    }
    if (result.numel() != 0) {
      auto r_stride = result.stride(0);
      auto vec_stride = vec.stride(0);

      // Check for contiguity of `vec` and update `vec_stride` accordingly
      const auto vec_contiguous = vec_stride == 0 ? vec.contiguous() : vec;
      // A vector can be contiguous and have a stride of zero if it is of length 1
      vec_stride = std::max<int64_t>(vec_contiguous.stride(0), 1LL);

      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "addmv_impl_cuda", [&] {
        auto beta = beta_.to<scalar_t>();
        auto alpha = alpha_.to<scalar_t>();
        if (mat.stride(0) == 1 && mat.stride(1) >= std::max<int64_t>(1, mat.size(0))) {
          at::cuda::blas::gemv<scalar_t>('n',
            mat.size(0), mat.size(1), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(1), vec_contiguous.const_data_ptr<scalar_t>(),
            vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
        else if (mat.stride(1) == 1 && mat.stride(0) >= std::max<int64_t>(1, mat.size(1))) {
          at::cuda::blas::gemv<scalar_t>('t',
            mat.size(1), mat.size(0), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(0),
            vec_contiguous.const_data_ptr<scalar_t>(), vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
        else {
          Tensor cmat = mat.contiguous();
          at::cuda::blas::gemv<scalar_t>('t',
              mat.size(1), mat.size(0), alpha, cmat.const_data_ptr<scalar_t>(), cmat.stride(0),
              vec_contiguous.const_data_ptr<scalar_t>(), vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
      });
    }
  }
}
// Computes result = self @ mat2 for int8 inputs with an int32 output using the
// cuBLAS(Lt) int8 gemm.
//
// Requirements (all enforced below): self, mat2 and result are 2-D; self and mat2
// are int8; result is int32 and contiguous with shape [self.size(0), mat2.size(1)];
// self.size(0) > 16; self.size(1) and mat2.size(1) are positive multiples of 8.
// Rank checks run before any size is indexed, so a wrong-rank argument reports the
// intended message rather than an index error.
Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) {
  // NOTE: cuBLAS is currently broken for some combination of transposed inputs.
  TORCH_CHECK(self.dim() == 2, "Expected self to be of dimension 2 but got ", self.dim());
  TORCH_CHECK(mat2.dim() == 2, "Expected mat2 to be of dimension 2 but got ", mat2.dim());
  TORCH_CHECK(result.dim() == 2, "Expected result to be of dimension 2 but got ", result.dim());
  // The operand pointers are read as int8_t below.
  TORCH_CHECK(self.dtype() == at::kChar, "Expected self dtype to be of type int8 but got ", self.dtype());
  TORCH_CHECK(mat2.dtype() == at::kChar, "Expected mat2 dtype to be of type int8 but got ", mat2.dtype());
  TORCH_CHECK(self.size(0) > 16, "self.size(0) needs to be greater than 16, but got ", self.size(0));
  TORCH_CHECK(self.size(1) > 0 && self.size(1) % 8 == 0, "self.size(1) needs to be greater than 0 and a multiple of 8, but got ", self.size(1));
  TORCH_CHECK(self.size(1) == mat2.size(0), "self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
  TORCH_CHECK(mat2.size(1) > 0 && mat2.size(1) % 8 == 0, "mat2.size(1) needs to be greater than 0 and a multiple of 8, but got ", mat2.size(1));
  TORCH_CHECK(result.dtype() == at::kInt, "Expected result dtype to be of type kInt but got ", result.dtype());
  TORCH_CHECK(result.size(0) == self.size(0), "Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
  TORCH_CHECK(result.size(1) == mat2.size(1), "Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
  TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");

  cublasCommonArgs args(self, mat2, result);

  at::cuda::blas::int8_gemm(
      args.transa == 't',
      args.transb == 't',
      args.m,
      args.n,
      args.k,
      args.mata->data_ptr<int8_t>(),
      args.lda,
      args.matb->data_ptr<int8_t>(),
      args.ldb,
      args.result->data_ptr<int32_t>(),
      args.result_ld);

  if (!result.is_same(*args.result)) {
    result.copy_(*args.result);
  }

  return result;
}
// Allocating form of _int_mm_out_cuda: creates an int32 tensor of shape
// [self.size(0), mat2.size(1)] with self's other tensor options, fills it, and returns it.
Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
  const int64_t out_rows = self.size(0);
  const int64_t out_cols = mat2.size(1);
  Tensor out = at::empty({out_rows, out_cols}, self.options().dtype(at::kInt));
  return _int_mm_out_cuda(self, mat2, out);
}
// Argument validation shared by the out_dtype variants of bmm and baddbmm.
// Mirrors common_checks_baddbmm_bmm in ATen/native/LinearAlgebra.cpp:
//   - batch1 and batch2 are 3-D, with batch1 = [b, n, m] and batch2 = [b, m, p];
//   - batch1 and batch2 share a dtype;
//   - out_dtype equals that dtype, or is Float when the inputs are Half/BFloat16;
//   - when self_baddbmm is supplied, it is 3-D with shape exactly [b, n, p].
// beta and alpha are accepted for signature parity but are not inspected here.
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
  TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");

  const auto lhs_shape = batch1.sizes();
  const auto rhs_shape = batch2.sizes();
  const int64_t num_batches = lhs_shape[0];
  const int64_t inner_dim = lhs_shape[2];

  TORCH_CHECK(rhs_shape[0] == num_batches && rhs_shape[1] == inner_dim,
              "Expected size for first two dimensions of batch2 tensor to be: [",
              num_batches, ", ", inner_dim, "] but got: [", rhs_shape[0], ", ", rhs_shape[1], "].");

  const at::ScalarType input_dtype = batch1.scalar_type();
  TORCH_CHECK(input_dtype == batch2.scalar_type(), "batch1 and batch2 must have the same dtype");

  const bool is_half_precision =
      input_dtype == at::ScalarType::Half || input_dtype == at::ScalarType::BFloat16;
  TORCH_CHECK(out_dtype == input_dtype || (out_dtype == at::ScalarType::Float && is_half_precision),
              "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");

  // Only the baddbmm path supplies `self`; nothing more to check otherwise.
  if (!self_baddbmm.has_value()) {
    return;
  }
  const Tensor& self = *self_baddbmm;
  TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
  const std::vector<int64_t> expected_shape{num_batches, lhs_shape[1], rhs_shape[2]};
  TORCH_CHECK(self.sizes() == expected_shape, "self must have the same shape as the output");
}
// Allocating form of _bmm_out_dtype_cuda: creates a [b, n, p] tensor
// (b, n from batch1; p from batch2) of dtype out_dtype, fills it, and returns it.
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
  const int64_t batches = batch1.size(0);
  const int64_t rows = batch1.size(1);
  const int64_t cols = batch2.size(2);
  Tensor result = at::empty({batches, rows, cols}, batch1.options().dtype(out_dtype));
  return _bmm_out_dtype_cuda(batch1, batch2, out_dtype, result);
}
// Batched matrix multiply of batch1 ([b, n, m]) and batch2 ([b, m, p]) into `out`,
// whose dtype must equal out_dtype. Returns `out`.
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
  // bmm is expressed as baddbmm_out_cuda_impl with beta = 0 and alpha = 1. Build the
  // Scalars once and use them for both validation and the kernel call.
  const Scalar beta(0.0);
  const Scalar alpha(1.0);
  baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype);
  // Same contract as _mm_dtype_out_cuda: without this check, the requested out_dtype
  // would be silently ignored in favour of out's actual dtype.
  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
  {
    NoNamesGuard guard;
    baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
  }
  return out;
}
// Allocating form of _baddbmm_out_dtype_cuda: creates a [b, n, p] tensor of dtype
// out_dtype and forwards self/batch1/batch2/beta/alpha to the out variant.
// self's dtype must equal either out_dtype or batch1's dtype.
Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
  const at::ScalarType self_dtype = self.scalar_type();
  const bool self_dtype_ok = self_dtype == out_dtype || self_dtype == batch1.scalar_type();
  TORCH_CHECK(self_dtype_ok, "self dtype must match either out_dtype or batch1 dtype");

  const auto out_options = batch1.options().dtype(out_dtype);
  Tensor result = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, out_options);
  return _baddbmm_out_dtype_cuda(self, batch1, batch2, out_dtype, beta, alpha, result);
}
// baddbmm with an explicit output dtype, written into `out`.
// Requires self and out to have the output shape [b, n, p] (b, n from batch1;
// p from batch2), and out's dtype to equal out_dtype. Returns `out`.
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
  // Validate `self` itself. out.copy_(self) below broadcasts, so a wrongly shaped
  // but broadcastable `self` would otherwise be accepted silently.
  baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, self);
  // `out` must match that same (already validated) shape, because
  // baddbmm_out_cuda_impl is handed `out` as its result buffer.
  TORCH_CHECK(out.sizes() == self.sizes(), "out must have the same shape as the output");
  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
  // Seed `out` with self's values; it is then passed as both result and self to the impl.
  out.copy_(self);
  {
    NoNamesGuard guard;
    baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
  }
  return out;
}
// Allocating form of _mm_dtype_out_cuda: creates a [self.size(0), mat2.size(1)]
// tensor of dtype out_dtype, fills it, and returns it.
Tensor _mm_dtype_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarType out_dtype) {
  const int64_t rows = self.size(0);
  const int64_t cols = mat2.size(1);
  Tensor out = at::empty({rows, cols}, self.options().dtype(out_dtype));
  return _mm_dtype_out_cuda(self, mat2, out_dtype, out);
}
// Matrix multiply self ([n, m]) @ mat2 ([m, p]) into `out`, with an explicit output dtype.
// Both inputs must share a dtype. out_dtype must equal that dtype, or be Float for
// Half/BFloat16 inputs. out's dtype must equal out_dtype. Returns `out`.
Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarType out_dtype, Tensor &out) {
  TORCH_CHECK(self.dim() == 2, "self must be a matrix, got ", self.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
  // Checked once here; the original repeated this identical check at the end.
  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
  TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "input dtypes must be the same");
  TORCH_CHECK(out_dtype == self.scalar_type() ||
              (out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
              "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
  // mm is addmm with beta = 0 and alpha = 1, using `out` as both result and bias.
  addmm_out_cuda_impl(out, out, self, mat2, 0, 1);
  return out;
}
Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {