Как использовать код, который опирается на ThreadLocal с сопрограммами Kotlin

Некоторые платформы JVM используют ThreadLocal хранить контекст вызова приложения, такого как SLF4j MDC, менеджеры транзакций, менеджеры безопасности и другие.

Тем не менее, сопрограммы Kotlin рассылаются по разным потокам, так как это может быть сделано для работы?

(Вопрос вдохновлен проблемой GitHub)

2 ответа

Решение

Аналог сопрограммы ThreadLocal является CoroutineContext,

Взаимодействовать с ThreadLocal -используя библиотеки, вам нужно реализовать ContinuationInterceptor который поддерживает специфичные для фреймворка локальные потоки.

Вот пример. Давайте предположим, что мы используем некоторую структуру, которая опирается на определенный ThreadLocal хранить некоторые специфичные для приложения данные (MyData в этом примере):

val myThreadLocal = ThreadLocal<MyData>()

Чтобы использовать его с сопрограммами, вам нужно реализовать контекст, который сохраняет текущее значение MyData и помещает его в соответствующий ThreadLocal каждый раз, когда сопрограмма возобновляется в потоке. Код должен выглядеть так:

class MyContext(
    private var myData: MyData,
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(Wrapper(continuation))

    inner class Wrapper<T>(private val continuation: Continuation<T>): Continuation<T> {
        private inline fun wrap(block: () -> Unit) {
            try {
                myThreadLocal.set(myData)
                block()
            } finally {
                myData = myThreadLocal.get()
            }
        }

        override val context: CoroutineContext get() = continuation.context
        override fun resume(value: T) = wrap { continuation.resume(value) }
        override fun resumeWithException(exception: Throwable) = wrap { continuation.resumeWithException(exception) }
    }
}

Чтобы использовать его в своих сопрограммах, вы оборачиваете диспетчер, с которым хотите использовать MyContext и дать ему начальное значение ваших данных. Это значение будет помещено в локальный поток в потоке, где возобновляется сопрограмма.

launch(MyContext(MyData(), CommonPool)) {
    // do something...
}

Приведенная выше реализация также будет отслеживать любые изменения в локальном потоке, которые были сделаны, и сохранять их в этом контексте, так что множественный вызов может совместно использовать данные "локального потока" через контекст.

ОБНОВЛЕНИЕ: Начиная с kotlinx.corutines версия 0.25.0 есть прямая поддержка для представления Java ThreadLocal экземпляры как сопрограммные элементы контекста. Смотрите эту документацию для деталей. Существует также встроенная поддержка SLF4J MDC через kotlinx-coroutines-slf4j модуль интеграции.

Хотя этот вопрос довольно старый, но я хотел бы добавить к ответу Романа еще один возможный подход с . Может быть, это будет полезно для кого-то еще.

      // Snippet from the source code's comment
class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
    companion object Key : CoroutineContext.Key<TraceContextElement>

    override val key: CoroutineContext.Key<TraceContextElement> = Key

    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = traceThreadLocal.get()
        traceThreadLocal.set(traceData)
        return oldState
    }

    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        traceThreadLocal.set(oldState)
    }

    override fun copyForChild(): TraceContextElement {
        // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
        // ThreadLocal writes between resumption of the parent coroutine and the launch of the
        // child coroutine visible to the child.
        return TraceContextElement(traceThreadLocal.get()?.copy())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        // Merge operation defines how to handle situations when both
        // the parent coroutine has an element in the context and
        // an element with the same key was also
        // explicitly passed to the child coroutine.
        // If merging does not require special behavior,
        // the copy of the element can be returned.
        return TraceContextElement(traceThreadLocal.get()?.copy())
    }
}

Обратите внимание, что этот метод позволяет вам распространять локальные данные потока, взятые из последней фазы возобновления родительской сопрограммы, в локальный контекст дочерней сопрограммы (какCopyablein подразумевает), потому что методcopyForChildбудет вызываться в потоке родительской сопрограммы, связанной с соответствующей фазой возобновления, когда была создана дочерняя сопрограмма.

Просто добавивTraceContextElementэлемент контекста в контекст корневой сопрограммы, он будет распространен на все дочерние сопрограммы как элемент контекста.

        runBlocking(Dispatchers.IO + TraceContextElement(someTraceDataInstance)){...}

В то время как с подходом может потребоваться дополнительная упаковка для построителей дочерних сопрограмм, если вы переопределите диспетчеры для дочерних сопрограмм.

      fun main() {
    runBlocking(WrappedDispatcher(Dispatchers.IO)) {
        delay(100)
        println("It is wrapped!")
        delay(100)
        println("It is also wrapped!")
        // NOTE: we don't wrap with the WrappedDispatcher class here
        // redefinition of the dispatcher leads to replacement of our custom ContinuationInterceptor
        // with logic taken from specified dispatcher (in the case below from Dispatchers.Default)
        withContext(Dispatchers.Default) {
            delay(100)
            println("It is nested coroutine, and it isn't wrapped!")
            delay(100)
            println("It is nested coroutine, and it isn't wrapped!")
        }
        delay(100)
        println("It is also wrapped!")
    }
}

с интерфейсом переопределения оболочки

      class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {

    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    private class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {

        override fun resumeWith(result: Result<T>) {
            println("------WRAPPED START-----")
            base.resumeWith(result)
            println("------WRAPPED END-------")
        }
    }
}

выход:

      ------WRAPPED START-----
------WRAPPED END-------
------WRAPPED START-----
It is wrapped!
------WRAPPED END-------
------WRAPPED START-----
It is also wrapped!
------WRAPPED END-------
It is nested coroutine, and it isn't wrapped!
It is nested coroutine, and it isn't wrapped!
------WRAPPED START-----
------WRAPPED END-------
------WRAPPED START-----
It is also wrapped!
------WRAPPED END-------

как видите, для дочерней (вложенной) сопрограммы наша обертка не применялась, так как мы переназначили предоставление другого диспетчера в качестве параметра. Это может привести к проблеме, поскольку вы можете по ошибке забыть обернуть диспетчер дочерней сопрограммы.


В качестве примечания: если вы решите выбрать этот подход с , рассмотрите возможность добавления такого расширения

      fun ContinuationInterceptor.withMyProjectWrappers() = WrappedDispatcher(this)

обернув ваш диспетчер всеми необходимыми оболочками, которые у вас есть в проекте, очевидно, его можно легко расширить, взяв определенные компоненты (обертки) из контейнера IoC, такого как Spring.


А также в качестве лишнего примераCopyableThreadContextElementгде локальные изменения потока сохраняются на всех этапах возобновления.

Executors.newFixedThreadPool(..).asCoroutineDispatcher()используется для лучшей иллюстрации того, что между фазами возобновления могут работать разные потоки.

      val counterThreadLocal: ThreadLocal<Int> = ThreadLocal.withInitial{ 1 }

fun showCounter(){
    println("-------------------------------------------------")
    println("Thread: ${Thread.currentThread().name}\n Counter value: ${counterThreadLocal.get()}")
}

fun main() {
    runBlocking(Executors.newFixedThreadPool(10).asCoroutineDispatcher() + CounterPropagator(1)) {
        showCounter()
        delay(100)
        showCounter()
        counterThreadLocal.set(2)
        delay(100)
        showCounter()
        counterThreadLocal.set(3)
        val nested = async(Executors.newFixedThreadPool(10).asCoroutineDispatcher()) {
            println("-----------NESTED START---------")
            showCounter()
            delay(100)
            counterThreadLocal.set(4)
            showCounter()
            println("------------NESTED END-----------")
        }
        nested.await()
        showCounter()
        println("---------------END------------")
    }
}

class CounterPropagator(private var counterFromParenCoroutine: Int) : CopyableThreadContextElement<Int> {
    companion object Key : CoroutineContext.Key<CounterPropagator>

    override val key: CoroutineContext.Key<CounterPropagator> = Key

    override fun updateThreadContext(context: CoroutineContext): Int {
        // initialize thread local on the resumption
        counterThreadLocal.set(counterFromParenCoroutine)
        return 0
    }

    override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
        // propagate thread local changes between resumption phases in the same coroutine
        counterFromParenCoroutine = counterThreadLocal.get()
    }

    override fun copyForChild(): CounterPropagator {
        // propagate thread local changes to children
        return CounterPropagator(counterThreadLocal.get())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        return CounterPropagator(counterThreadLocal.get())
    }
}

выход:

      -------------------------------------------------
Thread: pool-1-thread-1
 Counter value: 1
-------------------------------------------------
Thread: pool-1-thread-2
 Counter value: 1
-------------------------------------------------
Thread: pool-1-thread-3
 Counter value: 2
-----------NESTED START---------
-------------------------------------------------
Thread: pool-2-thread-1
 Counter value: 3
-------------------------------------------------
Thread: pool-2-thread-2
 Counter value: 4
------------NESTED END-----------
-------------------------------------------------
Thread: pool-1-thread-4
 Counter value: 3
---------------END------------

Вы можете добиться аналогичного поведения с помощьюContinuationInterceptor(но не забудьте повторно обернуть диспетчеры дочерних (вложенных) сопрограмм в конструктор сопрограмм, как было упомянуто выше)

      val counterThreadLocal: ThreadLocal<Int> = ThreadLocal()

class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor,
    private var savedCounter: Int = counterThreadLocal.get() ?: 0
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    private inner class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {

        override fun resumeWith(result: Result<T>) {
            counterThreadLocal.set(savedCounter)
            try {
                base.resumeWith(result)
            } finally {
                savedCounter = counterThreadLocal.get()
            }
        }
    }
}
Другие вопросы по тегам