createSession method

```dart
@override
Future<InferenceModelSession> createSession({
  double temperature = 0.8,
  int randomSeed = 1,
  int topK = 1,
  double? topP,
  String? loraPath,
  bool? enableVisionModality,
})
```

This method overrides the base implementation.

Creates a new InferenceModelSession for generation.

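- `temperature`, `randomSeed`, `topK`, `topP` — sampling parameters.
- `loraPath` — optional path to a LoRA model.
- `enableVisionModality` — enables the vision modality for multimodal models (not yet implemented on the web platform; in debug mode it only logs a warning).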

Implementation

/// Creates a new [InferenceModelSession] for generation on the web.
///
/// [temperature], [randomSeed], [topK] and [topP] are forwarded to
/// `LlmInferenceOptions` as sampling parameters. [loraPath] overrides the
/// LoRA file recorded for the active model; LoRA weights are only applied
/// when `loraRanks` is also set. [enableVisionModality] is not yet supported
/// on web and only triggers a warning in debug mode.
///
/// Initialization is shared: while a call is in flight, or after one has
/// succeeded, later calls return the same session future and their
/// arguments are ignored. If initialization fails, the error is delivered
/// to every waiting caller and the cached state is cleared so that a later
/// call can retry.
///
/// Throws an [Exception] when no active inference model is set, its file
/// paths cannot be resolved, or (in OPFS streaming mode) the web download
/// or OPFS services are unavailable.
@override
Future<InferenceModelSession> createSession({
  double temperature = 0.8,
  int randomSeed = 1,
  int topK = 1,
  double? topP,
  String? loraPath,
  bool? enableVisionModality, // Enabling vision modality support
}) async {
  // TODO: Implement vision modality for web
  if (enableVisionModality == true) {
    if (kDebugMode) {
      debugPrint('Warning: Vision modality is not yet implemented for web platform');
    }
  }

  // Reuse an in-flight or already-completed initialization instead of
  // loading the model a second time.
  if (_initCompleter case Completer<InferenceModelSession> completer) {
    return completer.future;
  }
  final completer = _initCompleter = Completer<InferenceModelSession>();
  try {
    // Use Modern API to get model path (same as mobile)
    final activeModel = modelManager.activeInferenceModel;
    if (activeModel == null) {
      throw Exception('No active inference model set');
    }

    final modelFilePaths = await modelManager.getModelFilePaths(activeModel);
    if (modelFilePaths == null || modelFilePaths.isEmpty) {
      throw Exception('Model file paths not found');
    }

    // Get model path from Modern API
    final modelPath = modelFilePaths[PreferencesKeys.installedModelFileName];
    if (modelPath == null) {
      throw Exception('Model path not found in file paths');
    }

    // NOTE(review): `@latest` is unpinned, so the WASM runtime can change
    // underneath the Dart bindings; consider pinning a version.
    final fileset = await FilesetResolver.forGenAiTasks(
            'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm'.toJS)
        .toDart;

    // An explicit loraPath wins over the LoRA file stored for the model;
    // LoRA is only configured when ranks are known as well.
    final loraPathToUse = loraPath ?? modelFilePaths[PreferencesKeys.installedLoraFileName];
    final hasLoraParams = loraPathToUse != null && loraRanks != null;

    // Check if using OPFS streaming mode
    final registry = ServiceRegistry.instance;
    final useStreaming = registry.useStreamingStorage;

    // Create base options based on storage mode
    final LlmInferenceBaseOptions baseOptions;
    if (useStreaming && modelPath.startsWith('opfs://')) {
      // OPFS streaming mode: Get ReadableStreamDefaultReader
      final filename = modelPath.substring('opfs://'.length);
      debugPrint('[WebInferenceModel] Loading from OPFS: $filename');

      final downloadService = registry.downloadService;
      if (downloadService is! WebDownloadService) {
        throw Exception('Expected WebDownloadService for web platform');
      }

      final opfsService = downloadService.opfsService;
      if (opfsService == null) {
        throw Exception('OPFS service not available');
      }

      final streamReader = await opfsService.getStreamReader(filename);
      debugPrint('[WebInferenceModel] Got OPFS stream reader');

      baseOptions = LlmInferenceBaseOptions(modelAssetBuffer: streamReader);
    } else {
      // Cache API / None mode: Use Blob URL
      debugPrint('[WebInferenceModel] Loading from Blob URL: $modelPath');
      baseOptions = LlmInferenceBaseOptions(modelAssetPath: modelPath);
    }

    final config = LlmInferenceOptions(
        baseOptions: baseOptions,
        maxTokens: maxTokens,
        randomSeed: randomSeed,
        topK: topK,
        temperature: temperature,
        topP: topP,
        supportedLoraRanks: hasLoraParams ? Int32List.fromList(loraRanks!).toJS : null,
        loraPath: hasLoraParams ? loraPathToUse : null,
        maxNumImages: supportImage ? (maxNumImages ?? 1) : null);

    final llmInference = await LlmInference.createFromOptions(fileset, config).toDart;

    // Keep a non-null local so the completer can never be completed with a
    // null value read back from the (possibly nullable) `session` field.
    final webSession = WebModelSession(
      modelType: modelType,
      fileType: fileType,
      llmInference: llmInference,
      supportImage: supportImage, // Enabling image support
      onClose: onClose,
    );
    session = webSession;

    completer.complete(webSession);
    return completer.future;
  } catch (e, st) {
    // Drop the failed attempt so the next call can retry. Only clear it if
    // the cache still refers to this attempt and was not replaced meanwhile.
    if (identical(_initCompleter, completer)) {
      _initCompleter = null;
    }
    // Concurrent callers awaiting completer.future still receive the error.
    completer.completeError(e, st);
    // This caller gets the error via `rethrow` and never listens to the
    // future. Mark it handled so Dart does not also report it as an
    // uncaught async error.
    completer.future.ignore();
    rethrow;
  }
}