Open3D (C++ API)  0.16.1
BlasWrapper.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
27#pragma once
28
32
33namespace open3d {
34namespace core {
35
36template <typename scalar_t>
37inline void gemm_cpu(CBLAS_LAYOUT layout,
38 CBLAS_TRANSPOSE trans_A,
39 CBLAS_TRANSPOSE trans_B,
43 scalar_t alpha,
44 const scalar_t *A_data,
46 const scalar_t *B_data,
48 scalar_t beta,
49 scalar_t *C_data,
51 utility::LogError("Unsupported data type.");
52}
53
54template <>
55inline void gemm_cpu<float>(CBLAS_LAYOUT layout,
56 CBLAS_TRANSPOSE trans_A,
57 CBLAS_TRANSPOSE trans_B,
61 float alpha,
62 const float *A_data,
64 const float *B_data,
66 float beta,
67 float *C_data,
69 cblas_sgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
70 ldb, beta, C_data, ldc);
71}
72
73template <>
74inline void gemm_cpu<double>(CBLAS_LAYOUT layout,
75 CBLAS_TRANSPOSE trans_A,
76 CBLAS_TRANSPOSE trans_B,
80 double alpha,
81 const double *A_data,
83 const double *B_data,
85 double beta,
86 double *C_data,
88 cblas_dgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
89 ldb, beta, C_data, ldc);
90}
91
92#ifdef BUILD_CUDA_MODULE
93template <typename scalar_t>
94inline cublasStatus_t gemm_cuda(cublasHandle_t handle,
95 cublasOperation_t transa,
96 cublasOperation_t transb,
97 int m,
98 int n,
99 int k,
100 const scalar_t *alpha,
101 const scalar_t *A_data,
102 int lda,
103 const scalar_t *B_data,
104 int ldb,
105 const scalar_t *beta,
106 scalar_t *C_data,
107 int ldc) {
108 utility::LogError("Unsupported data type.");
109 return CUBLAS_STATUS_NOT_SUPPORTED;
110}
111
112template <typename scalar_t>
113inline cublasStatus_t trsm_cuda(cublasHandle_t handle,
114 cublasSideMode_t side,
115 cublasFillMode_t uplo,
116 cublasOperation_t trans,
117 cublasDiagType_t diag,
118 int m,
119 int n,
120 const scalar_t *alpha,
121 const scalar_t *A,
122 int lda,
123 scalar_t *B,
124 int ldb) {
125 utility::LogError("Unsupported data type.");
126 return CUBLAS_STATUS_NOT_SUPPORTED;
127}
128
129template <>
130inline cublasStatus_t gemm_cuda<float>(cublasHandle_t handle,
131 cublasOperation_t transa,
132 cublasOperation_t transb,
133 int m,
134 int n,
135 int k,
136 const float *alpha,
137 const float *A_data,
138 int lda,
139 const float *B_data,
140 int ldb,
141 const float *beta,
142 float *C_data,
143 int ldc) {
144 return cublasSgemm(handle, transa,
145 transb, // A, B transpose flag
146 m, n, k, // dimensions
147 alpha, static_cast<const float *>(A_data), lda,
148 static_cast<const float *>(B_data),
149 ldb, // input and their leading dims
150 beta, static_cast<float *>(C_data), ldc);
151}
152
153template <>
154inline cublasStatus_t gemm_cuda<double>(cublasHandle_t handle,
155 cublasOperation_t transa,
156 cublasOperation_t transb,
157 int m,
158 int n,
159 int k,
160 const double *alpha,
161 const double *A_data,
162 int lda,
163 const double *B_data,
164 int ldb,
165 const double *beta,
166 double *C_data,
167 int ldc) {
168 return cublasDgemm(handle, transa,
169 transb, // A, B transpose flag
170 m, n, k, // dimensions
171 alpha, static_cast<const double *>(A_data), lda,
172 static_cast<const double *>(B_data),
173 ldb, // input and their leading dims
174 beta, static_cast<double *>(C_data), ldc);
175}
176
177template <>
178inline cublasStatus_t trsm_cuda<float>(cublasHandle_t handle,
179 cublasSideMode_t side,
180 cublasFillMode_t uplo,
181 cublasOperation_t trans,
182 cublasDiagType_t diag,
183 int m,
184 int n,
185 const float *alpha,
186 const float *A,
187 int lda,
188 float *B,
189 int ldb) {
190 return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
191 ldb);
192}
193
194template <>
195inline cublasStatus_t trsm_cuda<double>(cublasHandle_t handle,
196 cublasSideMode_t side,
197 cublasFillMode_t uplo,
198 cublasOperation_t trans,
199 cublasDiagType_t diag,
200 int m,
201 int n,
202 const double *alpha,
203 const double *A,
204 int lda,
205 double *B,
206 int ldb) {
207 return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
208 ldb);
209}
210#endif
211
212} // namespace core
213} // namespace open3d
#define OPEN3D_CPU_LINALG_INT
Definition: LinalgHeadersCPU.h:42
#define LogError(...)
Definition: Logging.h:67
void gemm_cpu< double >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, double alpha, const double *A_data, OPEN3D_CPU_LINALG_INT lda, const double *B_data, OPEN3D_CPU_LINALG_INT ldb, double beta, double *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:74
void gemm_cpu(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, scalar_t alpha, const scalar_t *A_data, OPEN3D_CPU_LINALG_INT lda, const scalar_t *B_data, OPEN3D_CPU_LINALG_INT ldb, scalar_t beta, scalar_t *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:37
void gemm_cpu< float >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, float alpha, const float *A_data, OPEN3D_CPU_LINALG_INT lda, const float *B_data, OPEN3D_CPU_LINALG_INT ldb, float beta, float *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:55
Definition: PinholeCameraIntrinsic.cpp:35