Skip to content
Draft
27 changes: 25 additions & 2 deletions projects/miopen/src/conv/solver_finders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <miopen/perf_field.hpp>
#include <miopen/conv/problem_description.hpp>
#include <miopen/solution.hpp>
#include <miopen/solver/gemm_common.hpp>
#include <miopen/utility/modified_z.hpp>

MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_GEMM)
Expand Down Expand Up @@ -150,10 +151,11 @@ class GemmSolverFinder : public SolversFinderMixin<ProblemDescription, ConvFindP
}

bool IsEnabled(const ExecutionContext& /*ctx*/,
const ProblemDescription& /*problem*/,
const ProblemDescription& problem,
const ConvFindParameters& parameters) const override
{
return !parameters.use_winograd_only && !env::disabled(MIOPEN_DEBUG_CONV_GEMM);
return !parameters.use_winograd_only && !env::disabled(MIOPEN_DEBUG_CONV_GEMM) &&
!IsAlgorithmDisabled(miopenConvolutionAlgoGEMM, problem);
}

std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
Expand Down Expand Up @@ -436,6 +438,7 @@ FindCoreResult FindCore(const AnyInvokeParams& invoke_ctx,

namespace conv {

// Overload without problem parameter - checks only environment variable (global disable)
bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo)
{
switch(algo)
Expand All @@ -457,6 +460,26 @@ bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo)
} // clang-format on
}

// Overload with problem parameter - checks both environment variable and problem-specific
// constraints
bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo, const ProblemDescription& problem)
{
// First check if algorithm is globally disabled
if(IsAlgorithmDisabled(algo))
return true;

// Then check problem-specific constraints
switch(algo)
{ // clang-format off
#if MIOPEN_USE_GEMM
case miopenConvolutionAlgoGEMM:
return solver::conv::gemm::IsGEMMProblemTooLarge(problem);
#endif
default: // if not globally disabled and no problem-specific constraints: do not disable
return false;
} // clang-format on
}

bool IsEnoughWorkspace(std::string_view where,
const miopen::solver::Id& solver_id,
const std::size_t required_size,
Expand Down
4 changes: 4 additions & 0 deletions projects/miopen/src/include/miopen/conv/solver_finders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ FindCoreResult FindCore(const AnyInvokeParams& invoke_ctx,
bool force_attach_binary = false);

namespace conv {
// Check if algorithm is globally disabled via environment variable
bool MIOPEN_INTERNALS_EXPORT IsAlgorithmDisabled(miopenConvAlgorithm_t algo);
// Check if algorithm is disabled (globally or for specific problem)
bool MIOPEN_INTERNALS_EXPORT IsAlgorithmDisabled(miopenConvAlgorithm_t algo,
const ProblemDescription& problem);
bool MIOPEN_INTERNALS_EXPORT IsEnoughWorkspace(std::string_view where,
const miopen::solver::Id& solver_id,
std::size_t required_size,
Expand Down
18 changes: 18 additions & 0 deletions projects/miopen/src/include/miopen/solver/gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ bool IsAnyBufferFp16(const TensorDescriptor& xDesc,

double SlowdownFactor(int n_oper, double oper_factor, double multiple_oper_factor);

// Workaround for ALMIOPEN-1044: Temporary workspace size limit for GEMM solvers
// This can be removed once the underlying issue is resolved
#if MIOPEN_USE_GEMM
/// Determine if GEMM workspace size exceeds threshold.
///
/// This function estimates the workspace size required by GEMM solvers and compares it
/// against the configured limit. The calculation matches the workspace computation used
/// in GemmFwdRest solver, which represents the general (non-1x1) case.
///
/// The workspace size is calculated as:
/// C × filter_spatial × output_spatial × type_size × groups
/// For Int8, the size is doubled.
///
/// @param problem The convolution problem description.
/// @return true if estimated workspace exceeds limit and GEMM should be disabled, false otherwise.
bool IsGEMMProblemTooLarge(const miopen::conv::ProblemDescription& problem);
#endif

} // namespace gemm
} // namespace conv
} // namespace solver
Expand Down
70 changes: 70 additions & 0 deletions projects/miopen/src/solver/conv/gemm_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,25 @@
*******************************************************************************/

#include <miopen/env.hpp>
#include <miopen/logger.hpp>
#include <miopen/solver/gemm_common.hpp>
#include <boost/range/adaptors.hpp>
#include <numeric>

// Workaround for ALMIOPEN-1044: Temporary workspace size limit for GEMM solvers
// This can be removed once the underlying issue is resolved
#define MIOPEN_WORKAROUND_ALMIOPEN_1044 1

#if MIOPEN_WORKAROUND_ALMIOPEN_1044
// temporary workaround, essentially reverting #2393.
// PR 2393 removed this limit with this comment attached:
// "Workaround for MLOpen issue 1430. Vega20 fails to access GPU memory
// larger than the return value of GetMaxMemoryAllocSize() of Vega10.
// Due to historical reasons, this W/A is applied to all targets.
// We are going to keep it as is until the new GEMM backend
// is used instead of rocBLAS. See also issue #2809."
MIOPEN_DECLARE_ENV_VAR_UINT64(MIOPEN_DEBUG_CONV_GEMM_MAX_WORKSPACE_SIZE, 7287183769)
#endif

namespace miopen {
namespace solver {
Expand Down Expand Up @@ -63,6 +81,58 @@ double SlowdownFactor(const int n_oper, const double oper_factor, const double m
return 1.0;
}

#if MIOPEN_USE_GEMM
bool IsGEMMProblemTooLarge(const miopen::conv::ProblemDescription& problem)
{
#if MIOPEN_WORKAROUND_ALMIOPEN_1044
const std::size_t max_size = env::value(MIOPEN_DEBUG_CONV_GEMM_MAX_WORKSPACE_SIZE);
// 0 means no limit
if(max_size == 0)
return false;

const auto& conv = problem.GetConv();
const auto& wDesc = problem.GetWeights();
const auto& yDesc = problem.GetOut();

const auto spatial_dim = conv.GetSpatialDimension();
const auto wei_spatial = boost::adaptors::slice(wDesc.GetLengths(), 2, 2 + spatial_dim);
const auto out_spatial = boost::adaptors::slice(yDesc.GetLengths(), 2, 2 + spatial_dim);
const auto wei_c = wDesc.GetLengths()[1];

// Calculate workspace size using the same formula as GemmFwdRest::GetWorkspaceSize()
// workspace = C × filter_spatial × output_spatial × type_size × groups
const auto workspace_size = wei_c *
std::accumulate(wei_spatial.begin(),
wei_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
std::accumulate(out_spatial.begin(),
out_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
GetTypeSize(wDesc.GetType()) * conv.group_count;

// For Int8, workspace is doubled (for transpose operations)
const auto ws_sz = (wDesc.GetType() == miopenInt8 ? 2 * workspace_size : workspace_size);

// Workspace is within limit
if(ws_sz <= max_size)
{
return false;
}

MIOPEN_LOG_I2("GEMMSolverFinder disabled for workspace size "
<< ws_sz << " bytes > " << max_size
<< " bytes (MIOPEN_DEBUG_CONV_GEMM_MAX_WORKSPACE_SIZE)");
return true;
#else
// Workaround disabled - no size limit
(void)problem; // Suppress unused parameter warning
return false;
#endif
}
#endif

} // namespace gemm
} // namespace conv
} // namespace solver
Expand Down
Loading