|
| 1 | +// Copyright (C) 2020 THL A29 Limited, a Tencent company. |
| 2 | +// All rights reserved. |
| 3 | +// Licensed under the BSD 3-Clause License (the "License"); you may |
| 4 | +// not use this file except in compliance with the License. You may |
| 5 | +// obtain a copy of the License at |
| 6 | +// https://opensource.org/licenses/BSD-3-Clause |
| 7 | +// Unless required by applicable law or agreed to in writing, software |
| 8 | +// distributed under the License is distributed on an "AS IS" basis, |
| 9 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
| 10 | +// implied. See the License for the specific language governing |
| 11 | +// permissions and limitations under the License. |
| 12 | +// See the AUTHORS file for names of contributors. |
| 13 | +#include "blas.h" |
| 14 | +#define EIGEN_DONT_PARALLELIZE |
| 15 | +#include "unsupported/Eigen/CXX11/Tensor" |
| 16 | +extern "C" { |
| 17 | +void cblas_sgemm_batch(const CBLAS_ORDER Layout, |
| 18 | + const CBLAS_TRANSPOSE* transa_array, |
| 19 | + const CBLAS_TRANSPOSE* transb_array, |
| 20 | + const blasint* m_array, const blasint* n_array, |
| 21 | + const blasint* k_array, const float* alpha_array, |
| 22 | + const float** a_array, const blasint* lda_array, |
| 23 | + const float** b_array, const blasint* ldb_array, |
| 24 | + const float* beta_array, float** c_array, |
| 25 | + const blasint* ldc_array, const blasint group_count, |
| 26 | + const blasint* group_size) { |
| 27 | + int idx = 0; |
| 28 | + for (int i = 0; i < group_count; ++i) { |
| 29 | + auto alpha = alpha_array[i]; |
| 30 | + auto beta = beta_array[i]; |
| 31 | + for (int j = 0; j < group_size[i]; ++j) { |
| 32 | + cblas_sgemm(Layout, transa_array[i], transb_array[i], m_array[i], |
| 33 | + n_array[i], k_array[i], alpha, a_array[idx], lda_array[i], |
| 34 | + b_array[idx], ldb_array[i], beta, c_array[idx], ldc_array[i]); |
| 35 | + ++idx; |
| 36 | + } |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +using Vec = Eigen::TensorMap<Eigen::Tensor<float, 1>>; |
| 41 | + |
| 42 | +void vsTanh(blasint N, const float* in, float* out) { |
| 43 | + Vec input(const_cast<float*>(in), N); |
| 44 | + Vec output(out, N); |
| 45 | + |
| 46 | + // let use eigen to calculate tanh. |
| 47 | + // Eigen can use `FAST_MATH`. |
| 48 | + output = input.tanh(); |
| 49 | +} |
| 50 | +} |
0 commit comments